1use crate::error::{Error, Result};
4use crate::protocol::{Compression, EventType, Message, MessageFormat, PROTOCOL_VERSION};
5use crate::stream::{
6 EventData, EventStream, FeatureData, FeatureStream, MessageStream, TileData, TileStream,
7};
8use futures::{SinkExt, StreamExt};
9use std::ops::Range;
10use std::time::Duration;
11use tokio::net::TcpStream;
12use tokio::sync::mpsc;
13use tokio::time::timeout;
14use tokio_tungstenite::{
15 MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message as WsMessage,
16};
17use tracing::{debug, info};
18
19#[derive(Debug, Clone)]
21pub struct ClientConfig {
22 pub url: String,
24 pub connect_timeout: Duration,
26 pub message_timeout: Duration,
28 pub format: MessageFormat,
30 pub compression: Compression,
32 pub auto_reconnect: bool,
34 pub max_reconnect_attempts: usize,
36}
37
38impl Default for ClientConfig {
39 fn default() -> Self {
40 Self {
41 url: "ws://localhost:9001/ws".to_string(),
42 connect_timeout: Duration::from_secs(10),
43 message_timeout: Duration::from_secs(30),
44 format: MessageFormat::MessagePack,
45 compression: Compression::Zstd,
46 auto_reconnect: true,
47 max_reconnect_attempts: 5,
48 }
49 }
50}
51
52pub struct WebSocketClient {
54 config: ClientConfig,
55 socket: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
56 #[allow(dead_code)]
58 message_tx: mpsc::UnboundedSender<Message>,
59 message_rx: Option<mpsc::UnboundedReceiver<Message>>,
60 format: MessageFormat,
61 compression: Compression,
62}
63
64impl WebSocketClient {
65 pub fn new() -> Self {
67 Self::with_config(ClientConfig::default())
68 }
69
70 pub fn with_config(config: ClientConfig) -> Self {
72 let (message_tx, message_rx) = mpsc::unbounded_channel();
73 let format = config.format;
74 let compression = config.compression;
75
76 Self {
77 config,
78 socket: None,
79 message_tx,
80 message_rx: Some(message_rx),
81 format,
82 compression,
83 }
84 }
85
86 pub async fn connect(url: &str) -> Result<Self> {
88 let config = ClientConfig {
89 url: url.to_string(),
90 ..Default::default()
91 };
92
93 let mut client = Self::with_config(config);
94 client.do_connect().await?;
95 client.handshake().await?;
96
97 Ok(client)
98 }
99
100 async fn do_connect(&mut self) -> Result<()> {
102 info!("Connecting to {}", self.config.url);
103
104 let connect_future = connect_async(&self.config.url);
105 let (ws_stream, _) = timeout(self.config.connect_timeout, connect_future)
106 .await
107 .map_err(|_| Error::Timeout("Connection timeout".to_string()))?
108 .map_err(|e| Error::Connection(e.to_string()))?;
109
110 self.socket = Some(ws_stream);
111 info!("Connected to {}", self.config.url);
112
113 Ok(())
114 }
115
116 async fn handshake(&mut self) -> Result<()> {
118 debug!("Performing handshake");
119
120 let handshake_msg = Message::Handshake {
121 version: PROTOCOL_VERSION,
122 format: self.config.format,
123 compression: self.config.compression,
124 };
125
126 self.send_message(handshake_msg).await?;
127
128 let ack = timeout(self.config.message_timeout, self.receive_message())
130 .await
131 .map_err(|_| Error::Timeout("Handshake timeout".to_string()))??;
132
133 match ack {
134 Message::HandshakeAck {
135 version,
136 format,
137 compression,
138 } => {
139 if version != PROTOCOL_VERSION {
140 return Err(Error::Protocol(format!(
141 "Protocol version mismatch: expected {}, got {}",
142 PROTOCOL_VERSION, version
143 )));
144 }
145 self.format = format;
146 self.compression = compression;
147 info!(
148 "Handshake complete: format={:?}, compression={:?}",
149 format, compression
150 );
151 Ok(())
152 }
153 _ => Err(Error::Protocol(
154 "Expected handshake acknowledgement".to_string(),
155 )),
156 }
157 }
158
159 async fn send_message(&mut self, message: Message) -> Result<()> {
161 let socket = self
162 .socket
163 .as_mut()
164 .ok_or_else(|| Error::Connection("Not connected".to_string()))?;
165
166 let data = message.encode(self.format, self.compression)?;
167 socket
168 .send(WsMessage::Binary(data.into()))
169 .await
170 .map_err(|e| Error::Send(e.to_string()))?;
171
172 Ok(())
173 }
174
175 async fn receive_message(&mut self) -> Result<Message> {
177 let socket = self
178 .socket
179 .as_mut()
180 .ok_or_else(|| Error::Connection("Not connected".to_string()))?;
181
182 let msg = socket
183 .next()
184 .await
185 .ok_or_else(|| Error::Receive("Connection closed".to_string()))?
186 .map_err(|e| Error::Receive(e.to_string()))?;
187
188 let data = match msg {
189 WsMessage::Binary(payload) => payload.to_vec(),
190 WsMessage::Text(text) => text.as_bytes().to_vec(),
191 WsMessage::Close(_) => {
192 return Err(Error::Connection("Server closed connection".to_string()));
193 }
194 _ => {
195 return Err(Error::InvalidMessage("Unexpected message type".to_string()));
196 }
197 };
198
199 Message::decode(&data, self.format, self.compression)
200 }
201
202 pub async fn subscribe_tiles(
204 &mut self,
205 bbox: [f64; 4],
206 zoom_range: Range<u8>,
207 ) -> Result<String> {
208 let subscription_id = uuid::Uuid::new_v4().to_string();
209
210 let msg = Message::SubscribeTiles {
211 subscription_id: subscription_id.clone(),
212 bbox,
213 zoom_range,
214 tile_size: Some(256),
215 };
216
217 self.send_message(msg).await?;
218
219 let ack = timeout(self.config.message_timeout, self.receive_message())
221 .await
222 .map_err(|_| Error::Timeout("Subscribe timeout".to_string()))??;
223
224 match ack {
225 Message::Ack { success: true, .. } => Ok(subscription_id),
226 Message::Ack { message, .. } => Err(Error::Subscription(
227 message.unwrap_or_else(|| "Failed to subscribe".to_string()),
228 )),
229 Message::Error { message, .. } => Err(Error::Subscription(message)),
230 _ => Err(Error::Protocol("Expected acknowledgement".to_string())),
231 }
232 }
233
234 pub async fn subscribe_features(&mut self, layer: Option<String>) -> Result<String> {
236 let subscription_id = uuid::Uuid::new_v4().to_string();
237
238 let msg = Message::SubscribeFeatures {
239 subscription_id: subscription_id.clone(),
240 bbox: None,
241 filters: None,
242 layer,
243 };
244
245 self.send_message(msg).await?;
246
247 let ack = timeout(self.config.message_timeout, self.receive_message())
249 .await
250 .map_err(|_| Error::Timeout("Subscribe timeout".to_string()))??;
251
252 match ack {
253 Message::Ack { success: true, .. } => Ok(subscription_id),
254 Message::Ack { message, .. } => Err(Error::Subscription(
255 message.unwrap_or_else(|| "Failed to subscribe".to_string()),
256 )),
257 Message::Error { message, .. } => Err(Error::Subscription(message)),
258 _ => Err(Error::Protocol("Expected acknowledgement".to_string())),
259 }
260 }
261
262 pub async fn subscribe_events(&mut self, event_types: Vec<EventType>) -> Result<String> {
264 let subscription_id = uuid::Uuid::new_v4().to_string();
265
266 let msg = Message::SubscribeEvents {
267 subscription_id: subscription_id.clone(),
268 event_types,
269 };
270
271 self.send_message(msg).await?;
272
273 let ack = timeout(self.config.message_timeout, self.receive_message())
275 .await
276 .map_err(|_| Error::Timeout("Subscribe timeout".to_string()))??;
277
278 match ack {
279 Message::Ack { success: true, .. } => Ok(subscription_id),
280 Message::Ack { message, .. } => Err(Error::Subscription(
281 message.unwrap_or_else(|| "Failed to subscribe".to_string()),
282 )),
283 Message::Error { message, .. } => Err(Error::Subscription(message)),
284 _ => Err(Error::Protocol("Expected acknowledgement".to_string())),
285 }
286 }
287
288 pub async fn unsubscribe(&mut self, subscription_id: &str) -> Result<()> {
290 let msg = Message::Unsubscribe {
291 subscription_id: subscription_id.to_string(),
292 };
293
294 self.send_message(msg).await?;
295
296 let ack = timeout(self.config.message_timeout, self.receive_message())
298 .await
299 .map_err(|_| Error::Timeout("Unsubscribe timeout".to_string()))??;
300
301 match ack {
302 Message::Ack { success: true, .. } => Ok(()),
303 Message::Ack { message, .. } => Err(Error::Subscription(
304 message.unwrap_or_else(|| "Failed to unsubscribe".to_string()),
305 )),
306 Message::Error { message, .. } => Err(Error::Subscription(message)),
307 _ => Err(Error::Protocol("Expected acknowledgement".to_string())),
308 }
309 }
310
311 pub fn message_stream(&mut self) -> Option<MessageStream> {
313 self.message_rx.take().map(MessageStream::new)
314 }
315
316 pub fn tile_stream(&mut self) -> TileStream {
318 let (tx, rx) = mpsc::unbounded_channel();
319
320 let message_rx = self.message_rx.take();
322 if let Some(mut message_rx) = message_rx {
323 tokio::spawn(async move {
324 while let Some(message) = message_rx.recv().await {
325 if let Message::TileData {
326 tile,
327 data,
328 mime_type,
329 ..
330 } = message
331 {
332 let tile_data = TileData::new(tile.0, tile.1, tile.2, data, mime_type);
333 if tx.send(tile_data).is_err() {
334 break;
335 }
336 }
337 }
338 });
339 }
340
341 TileStream::new(rx)
342 }
343
344 pub fn feature_stream(&mut self) -> FeatureStream {
346 let (tx, rx) = mpsc::unbounded_channel();
347
348 let message_rx = self.message_rx.take();
350 if let Some(mut message_rx) = message_rx {
351 tokio::spawn(async move {
352 while let Some(message) = message_rx.recv().await {
353 if let Message::FeatureData {
354 geojson,
355 change_type,
356 ..
357 } = message
358 {
359 let feature_data = FeatureData::new(geojson, change_type, None);
360 if tx.send(feature_data).is_err() {
361 break;
362 }
363 }
364 }
365 });
366 }
367
368 FeatureStream::new(rx)
369 }
370
371 pub fn event_stream(&mut self) -> EventStream {
373 let (tx, rx) = mpsc::unbounded_channel();
374
375 let message_rx = self.message_rx.take();
377 if let Some(mut message_rx) = message_rx {
378 tokio::spawn(async move {
379 while let Some(message) = message_rx.recv().await {
380 if let Message::Event {
381 event_type,
382 payload,
383 timestamp,
384 ..
385 } = message
386 {
387 if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(×tamp) {
388 let event_data = EventData::with_timestamp(
389 event_type,
390 payload,
391 ts.with_timezone(&chrono::Utc),
392 );
393 if tx.send(event_data).is_err() {
394 break;
395 }
396 }
397 }
398 }
399 });
400 }
401
402 EventStream::new(rx)
403 }
404
405 pub async fn ping(&mut self, id: u64) -> Result<()> {
407 self.send_message(Message::Ping { id }).await
408 }
409
410 pub async fn close(mut self) -> Result<()> {
412 if let Some(mut socket) = self.socket.take() {
413 socket
414 .close(None)
415 .await
416 .map_err(|e| Error::Connection(e.to_string()))?;
417 }
418 Ok(())
419 }
420}
421
422impl Default for WebSocketClient {
423 fn default() -> Self {
424 Self::new()
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented default, including timeouts and reconnect limits.
    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert_eq!(config.url, "ws://localhost:9001/ws");
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.message_timeout, Duration::from_secs(30));
        assert_eq!(config.format, MessageFormat::MessagePack);
        assert_eq!(config.compression, Compression::Zstd);
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
    }

    /// A fresh client is unconnected, owns its inbound channel, and uses the requested
    /// format/compression until a handshake says otherwise.
    #[test]
    fn test_client_creation() {
        let client = WebSocketClient::new();
        assert!(client.socket.is_none());
        assert!(client.message_rx.is_some());
        assert_eq!(client.format, MessageFormat::MessagePack);
        assert_eq!(client.compression, Compression::Zstd);
    }

    /// `with_config` keeps the supplied settings.
    #[test]
    fn test_with_config_keeps_settings() {
        let config = ClientConfig {
            url: "ws://example.invalid/ws".to_string(),
            message_timeout: Duration::from_secs(1),
            ..Default::default()
        };
        let client = WebSocketClient::with_config(config);
        assert_eq!(client.config.url, "ws://example.invalid/ws");
        assert_eq!(client.config.message_timeout, Duration::from_secs(1));
        assert!(client.socket.is_none());
    }

    /// The raw message stream can only be taken once.
    #[test]
    fn test_message_stream_taken_once() {
        let mut client = WebSocketClient::new();
        assert!(client.message_stream().is_some());
        assert!(client.message_stream().is_none());
    }
}