Skip to main content

oxigdal_ws/
client.rs

1//! WebSocket client implementation.
2
3use crate::error::{Error, Result};
4use crate::protocol::{Compression, EventType, Message, MessageFormat, PROTOCOL_VERSION};
5use crate::stream::{
6    EventData, EventStream, FeatureData, FeatureStream, MessageStream, TileData, TileStream,
7};
8use futures::{SinkExt, StreamExt};
9use std::ops::Range;
10use std::time::Duration;
11use tokio::net::TcpStream;
12use tokio::sync::mpsc;
13use tokio::time::timeout;
14use tokio_tungstenite::{
15    MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message as WsMessage,
16};
17use tracing::{debug, info};
18
19/// WebSocket client configuration.
20#[derive(Debug, Clone)]
21pub struct ClientConfig {
22    /// Server URL
23    pub url: String,
24    /// Connection timeout
25    pub connect_timeout: Duration,
26    /// Message timeout
27    pub message_timeout: Duration,
28    /// Preferred message format
29    pub format: MessageFormat,
30    /// Preferred compression
31    pub compression: Compression,
32    /// Reconnect on disconnect
33    pub auto_reconnect: bool,
34    /// Maximum reconnect attempts
35    pub max_reconnect_attempts: usize,
36}
37
38impl Default for ClientConfig {
39    fn default() -> Self {
40        Self {
41            url: "ws://localhost:9001/ws".to_string(),
42            connect_timeout: Duration::from_secs(10),
43            message_timeout: Duration::from_secs(30),
44            format: MessageFormat::MessagePack,
45            compression: Compression::Zstd,
46            auto_reconnect: true,
47            max_reconnect_attempts: 5,
48        }
49    }
50}
51
52/// WebSocket client.
53pub struct WebSocketClient {
54    config: ClientConfig,
55    socket: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
56    /// Message sender for internal message queue (reserved for async processing)
57    #[allow(dead_code)]
58    message_tx: mpsc::UnboundedSender<Message>,
59    message_rx: Option<mpsc::UnboundedReceiver<Message>>,
60    format: MessageFormat,
61    compression: Compression,
62}
63
64impl WebSocketClient {
65    /// Create a new client with default configuration.
66    pub fn new() -> Self {
67        Self::with_config(ClientConfig::default())
68    }
69
70    /// Create a new client with custom configuration.
71    pub fn with_config(config: ClientConfig) -> Self {
72        let (message_tx, message_rx) = mpsc::unbounded_channel();
73        let format = config.format;
74        let compression = config.compression;
75
76        Self {
77            config,
78            socket: None,
79            message_tx,
80            message_rx: Some(message_rx),
81            format,
82            compression,
83        }
84    }
85
86    /// Connect to the WebSocket server.
87    pub async fn connect(url: &str) -> Result<Self> {
88        let config = ClientConfig {
89            url: url.to_string(),
90            ..Default::default()
91        };
92
93        let mut client = Self::with_config(config);
94        client.do_connect().await?;
95        client.handshake().await?;
96
97        Ok(client)
98    }
99
100    /// Perform the actual connection.
101    async fn do_connect(&mut self) -> Result<()> {
102        info!("Connecting to {}", self.config.url);
103
104        let connect_future = connect_async(&self.config.url);
105        let (ws_stream, _) = timeout(self.config.connect_timeout, connect_future)
106            .await
107            .map_err(|_| Error::Timeout("Connection timeout".to_string()))?
108            .map_err(|e| Error::Connection(e.to_string()))?;
109
110        self.socket = Some(ws_stream);
111        info!("Connected to {}", self.config.url);
112
113        Ok(())
114    }
115
116    /// Perform protocol handshake.
117    async fn handshake(&mut self) -> Result<()> {
118        debug!("Performing handshake");
119
120        let handshake_msg = Message::Handshake {
121            version: PROTOCOL_VERSION,
122            format: self.config.format,
123            compression: self.config.compression,
124        };
125
126        self.send_message(handshake_msg).await?;
127
128        // Wait for handshake acknowledgement
129        let ack = timeout(self.config.message_timeout, self.receive_message())
130            .await
131            .map_err(|_| Error::Timeout("Handshake timeout".to_string()))??;
132
133        match ack {
134            Message::HandshakeAck {
135                version,
136                format,
137                compression,
138            } => {
139                if version != PROTOCOL_VERSION {
140                    return Err(Error::Protocol(format!(
141                        "Protocol version mismatch: expected {}, got {}",
142                        PROTOCOL_VERSION, version
143                    )));
144                }
145                self.format = format;
146                self.compression = compression;
147                info!(
148                    "Handshake complete: format={:?}, compression={:?}",
149                    format, compression
150                );
151                Ok(())
152            }
153            _ => Err(Error::Protocol(
154                "Expected handshake acknowledgement".to_string(),
155            )),
156        }
157    }
158
159    /// Send a message to the server.
160    async fn send_message(&mut self, message: Message) -> Result<()> {
161        let socket = self
162            .socket
163            .as_mut()
164            .ok_or_else(|| Error::Connection("Not connected".to_string()))?;
165
166        let data = message.encode(self.format, self.compression)?;
167        socket
168            .send(WsMessage::Binary(data.into()))
169            .await
170            .map_err(|e| Error::Send(e.to_string()))?;
171
172        Ok(())
173    }
174
175    /// Receive a message from the server.
176    async fn receive_message(&mut self) -> Result<Message> {
177        let socket = self
178            .socket
179            .as_mut()
180            .ok_or_else(|| Error::Connection("Not connected".to_string()))?;
181
182        let msg = socket
183            .next()
184            .await
185            .ok_or_else(|| Error::Receive("Connection closed".to_string()))?
186            .map_err(|e| Error::Receive(e.to_string()))?;
187
188        let data = match msg {
189            WsMessage::Binary(payload) => payload.to_vec(),
190            WsMessage::Text(text) => text.as_bytes().to_vec(),
191            WsMessage::Close(_) => {
192                return Err(Error::Connection("Server closed connection".to_string()));
193            }
194            _ => {
195                return Err(Error::InvalidMessage("Unexpected message type".to_string()));
196            }
197        };
198
199        Message::decode(&data, self.format, self.compression)
200    }
201
202    /// Subscribe to tile updates.
203    pub async fn subscribe_tiles(
204        &mut self,
205        bbox: [f64; 4],
206        zoom_range: Range<u8>,
207    ) -> Result<String> {
208        let subscription_id = uuid::Uuid::new_v4().to_string();
209
210        let msg = Message::SubscribeTiles {
211            subscription_id: subscription_id.clone(),
212            bbox,
213            zoom_range,
214            tile_size: Some(256),
215        };
216
217        self.send_message(msg).await?;
218
219        // Wait for acknowledgement
220        let ack = timeout(self.config.message_timeout, self.receive_message())
221            .await
222            .map_err(|_| Error::Timeout("Subscribe timeout".to_string()))??;
223
224        match ack {
225            Message::Ack { success: true, .. } => Ok(subscription_id),
226            Message::Ack { message, .. } => Err(Error::Subscription(
227                message.unwrap_or_else(|| "Failed to subscribe".to_string()),
228            )),
229            Message::Error { message, .. } => Err(Error::Subscription(message)),
230            _ => Err(Error::Protocol("Expected acknowledgement".to_string())),
231        }
232    }
233
234    /// Subscribe to feature updates.
235    pub async fn subscribe_features(&mut self, layer: Option<String>) -> Result<String> {
236        let subscription_id = uuid::Uuid::new_v4().to_string();
237
238        let msg = Message::SubscribeFeatures {
239            subscription_id: subscription_id.clone(),
240            bbox: None,
241            filters: None,
242            layer,
243        };
244
245        self.send_message(msg).await?;
246
247        // Wait for acknowledgement
248        let ack = timeout(self.config.message_timeout, self.receive_message())
249            .await
250            .map_err(|_| Error::Timeout("Subscribe timeout".to_string()))??;
251
252        match ack {
253            Message::Ack { success: true, .. } => Ok(subscription_id),
254            Message::Ack { message, .. } => Err(Error::Subscription(
255                message.unwrap_or_else(|| "Failed to subscribe".to_string()),
256            )),
257            Message::Error { message, .. } => Err(Error::Subscription(message)),
258            _ => Err(Error::Protocol("Expected acknowledgement".to_string())),
259        }
260    }
261
262    /// Subscribe to events.
263    pub async fn subscribe_events(&mut self, event_types: Vec<EventType>) -> Result<String> {
264        let subscription_id = uuid::Uuid::new_v4().to_string();
265
266        let msg = Message::SubscribeEvents {
267            subscription_id: subscription_id.clone(),
268            event_types,
269        };
270
271        self.send_message(msg).await?;
272
273        // Wait for acknowledgement
274        let ack = timeout(self.config.message_timeout, self.receive_message())
275            .await
276            .map_err(|_| Error::Timeout("Subscribe timeout".to_string()))??;
277
278        match ack {
279            Message::Ack { success: true, .. } => Ok(subscription_id),
280            Message::Ack { message, .. } => Err(Error::Subscription(
281                message.unwrap_or_else(|| "Failed to subscribe".to_string()),
282            )),
283            Message::Error { message, .. } => Err(Error::Subscription(message)),
284            _ => Err(Error::Protocol("Expected acknowledgement".to_string())),
285        }
286    }
287
288    /// Unsubscribe from updates.
289    pub async fn unsubscribe(&mut self, subscription_id: &str) -> Result<()> {
290        let msg = Message::Unsubscribe {
291            subscription_id: subscription_id.to_string(),
292        };
293
294        self.send_message(msg).await?;
295
296        // Wait for acknowledgement
297        let ack = timeout(self.config.message_timeout, self.receive_message())
298            .await
299            .map_err(|_| Error::Timeout("Unsubscribe timeout".to_string()))??;
300
301        match ack {
302            Message::Ack { success: true, .. } => Ok(()),
303            Message::Ack { message, .. } => Err(Error::Subscription(
304                message.unwrap_or_else(|| "Failed to unsubscribe".to_string()),
305            )),
306            Message::Error { message, .. } => Err(Error::Subscription(message)),
307            _ => Err(Error::Protocol("Expected acknowledgement".to_string())),
308        }
309    }
310
311    /// Get a stream of all messages.
312    pub fn message_stream(&mut self) -> Option<MessageStream> {
313        self.message_rx.take().map(MessageStream::new)
314    }
315
316    /// Get a stream of tile data.
317    pub fn tile_stream(&mut self) -> TileStream {
318        let (tx, rx) = mpsc::unbounded_channel();
319
320        // Spawn task to filter tile messages
321        let message_rx = self.message_rx.take();
322        if let Some(mut message_rx) = message_rx {
323            tokio::spawn(async move {
324                while let Some(message) = message_rx.recv().await {
325                    if let Message::TileData {
326                        tile,
327                        data,
328                        mime_type,
329                        ..
330                    } = message
331                    {
332                        let tile_data = TileData::new(tile.0, tile.1, tile.2, data, mime_type);
333                        if tx.send(tile_data).is_err() {
334                            break;
335                        }
336                    }
337                }
338            });
339        }
340
341        TileStream::new(rx)
342    }
343
344    /// Get a stream of feature data.
345    pub fn feature_stream(&mut self) -> FeatureStream {
346        let (tx, rx) = mpsc::unbounded_channel();
347
348        // Spawn task to filter feature messages
349        let message_rx = self.message_rx.take();
350        if let Some(mut message_rx) = message_rx {
351            tokio::spawn(async move {
352                while let Some(message) = message_rx.recv().await {
353                    if let Message::FeatureData {
354                        geojson,
355                        change_type,
356                        ..
357                    } = message
358                    {
359                        let feature_data = FeatureData::new(geojson, change_type, None);
360                        if tx.send(feature_data).is_err() {
361                            break;
362                        }
363                    }
364                }
365            });
366        }
367
368        FeatureStream::new(rx)
369    }
370
371    /// Get a stream of events.
372    pub fn event_stream(&mut self) -> EventStream {
373        let (tx, rx) = mpsc::unbounded_channel();
374
375        // Spawn task to filter event messages
376        let message_rx = self.message_rx.take();
377        if let Some(mut message_rx) = message_rx {
378            tokio::spawn(async move {
379                while let Some(message) = message_rx.recv().await {
380                    if let Message::Event {
381                        event_type,
382                        payload,
383                        timestamp,
384                        ..
385                    } = message
386                    {
387                        if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&timestamp) {
388                            let event_data = EventData::with_timestamp(
389                                event_type,
390                                payload,
391                                ts.with_timezone(&chrono::Utc),
392                            );
393                            if tx.send(event_data).is_err() {
394                                break;
395                            }
396                        }
397                    }
398                }
399            });
400        }
401
402        EventStream::new(rx)
403    }
404
405    /// Send a ping.
406    pub async fn ping(&mut self, id: u64) -> Result<()> {
407        self.send_message(Message::Ping { id }).await
408    }
409
410    /// Close the connection.
411    pub async fn close(mut self) -> Result<()> {
412        if let Some(mut socket) = self.socket.take() {
413            socket
414                .close(None)
415                .await
416                .map_err(|e| Error::Connection(e.to_string()))?;
417        }
418        Ok(())
419    }
420}
421
422impl Default for WebSocketClient {
423    fn default() -> Self {
424        Self::new()
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented default of `ClientConfig`, including timeouts and reconnect limits.
    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert_eq!(config.url, "ws://localhost:9001/ws");
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.message_timeout, Duration::from_secs(30));
        assert_eq!(config.format, MessageFormat::MessagePack);
        assert_eq!(config.compression, Compression::Zstd);
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
    }

    /// A fresh client is not connected and starts with the configured format/compression.
    #[test]
    fn test_client_creation() {
        let client = WebSocketClient::new();
        assert!(client.socket.is_none());
        assert_eq!(client.format, MessageFormat::MessagePack);
        assert_eq!(client.compression, Compression::Zstd);
    }

    /// The internal queue can be handed out only once.
    #[test]
    fn test_message_stream_can_be_taken_once() {
        let mut client = WebSocketClient::new();
        assert!(client.message_stream().is_some());
        assert!(client.message_stream().is_none());
    }
}