Skip to main content

oxigdal_ws/
server.rs

1//! WebSocket server implementation.
2
3use crate::error::{Error, Result};
4use crate::protocol::{Compression, Message, MessageFormat};
5use crate::subscription::{Subscription, SubscriptionManager};
6use axum::{
7    Router,
8    extract::{
9        State,
10        ws::{WebSocket, WebSocketUpgrade},
11    },
12    response::IntoResponse,
13    routing::get,
14};
15use dashmap::DashMap;
16use futures::{SinkExt, StreamExt};
17use std::net::SocketAddr;
18use std::sync::Arc;
19use tokio::sync::mpsc;
20use tower_http::cors::CorsLayer;
21use tracing::{debug, error, info, warn};
22use uuid::Uuid;
23
/// WebSocket server configuration.
///
/// Built via `ServerConfig::default()` or `ServerBuilder`, then stored behind
/// an `Arc` in the shared server state and read by every connection handler.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address the TCP listener binds to in `WebSocketServer::run`.
    pub bind_addr: SocketAddr,
    /// Maximum number of concurrent WebSocket connections.
    pub max_connections: usize,
    /// Message buffer size per client.
    ///
    /// NOTE(review): not read anywhere in this file; per-client outbound
    /// queues are created with `mpsc::unbounded_channel`. Confirm whether this
    /// value is meant to bound them.
    pub message_buffer_size: usize,
    /// Wire format used for a client until it negotiates another in a handshake.
    pub default_format: MessageFormat,
    /// Compression used for a client until it negotiates another in a handshake.
    pub default_compression: Compression,
    /// When `true`, `WebSocketServer::run` wraps the router in a permissive CORS layer.
    pub enable_cors: bool,
}
40
41impl Default for ServerConfig {
42    fn default() -> Self {
43        Self {
44            bind_addr: SocketAddr::from(([0, 0, 0, 0], 9001)),
45            max_connections: 10000,
46            message_buffer_size: 1000,
47            default_format: MessageFormat::MessagePack,
48            default_compression: Compression::Zstd,
49            enable_cors: true,
50        }
51    }
52}
53
/// Per-connection state stored in the shared client map.
struct ClientState {
    /// Client ID (UUID string generated when the socket is accepted).
    id: String,
    /// Sending half of this client's outbound message queue; the receiving
    /// half is drained by the writer task spawned in `handle_socket`.
    tx: mpsc::UnboundedSender<Message>,
    /// Message format preference; overwritten by `handle_message` on handshake.
    format: MessageFormat,
    /// Compression preference; overwritten by `handle_message` on handshake.
    compression: Compression,
}
65
66impl ClientState {
67    /// Send a message to the client.
68    fn send(&self, message: Message) -> Result<()> {
69        self.tx
70            .send(message)
71            .map_err(|_| Error::Send("Client disconnected".to_string()))
72    }
73}
74
/// Shared server state.
///
/// Handed to every axum handler via `State`; all fields are `Arc`s, so
/// clones share the same underlying client map, subscriptions and config.
#[derive(Clone)]
struct AppState {
    /// Active clients, keyed by the UUID string assigned in `handle_socket`.
    clients: Arc<DashMap<String, ClientState>>,
    /// Subscription manager
    subscriptions: Arc<SubscriptionManager>,
    /// Server configuration (shared read-only behind an `Arc`).
    config: Arc<ServerConfig>,
}
85
86impl AppState {
87    fn new(config: ServerConfig) -> Self {
88        Self {
89            clients: Arc::new(DashMap::new()),
90            subscriptions: Arc::new(SubscriptionManager::new()),
91            config: Arc::new(config),
92        }
93    }
94
95    /// Broadcast message to all clients.
96    fn broadcast(&self, message: Message) {
97        for client in self.clients.iter() {
98            if let Err(e) = client.send(message.clone()) {
99                warn!("Failed to send to client {}: {}", client.id, e);
100            }
101        }
102    }
103
104    /// Send message to specific client.
105    fn send_to_client(&self, client_id: &str, message: Message) -> Result<()> {
106        if let Some(client) = self.clients.get(client_id) {
107            client.send(message)
108        } else {
109            Err(Error::NotFound(format!("Client not found: {}", client_id)))
110        }
111    }
112
113    /// Send message to all subscribers of a subscription type.
114    #[allow(dead_code)]
115    fn send_to_subscribers(&self, subscription_id: &str, message: Message) {
116        if let Some(sub) = self.subscriptions.get(subscription_id) {
117            if let Err(e) = self.send_to_client(&sub.client_id, message) {
118                warn!("Failed to send to subscriber {}: {}", sub.client_id, e);
119            }
120        }
121    }
122}
123
/// WebSocket server.
///
/// Owns the shared server state. Construct it with `new`, `with_config` or
/// `builder`, then call `run` to accept connections.
pub struct WebSocketServer {
    state: AppState,
}
128
129impl WebSocketServer {
130    /// Create a new WebSocket server with default configuration.
131    pub fn new() -> Self {
132        Self::with_config(ServerConfig::default())
133    }
134
135    /// Create a new WebSocket server with custom configuration.
136    pub fn with_config(config: ServerConfig) -> Self {
137        Self {
138            state: AppState::new(config),
139        }
140    }
141
142    /// Create a builder for the server.
143    pub fn builder() -> ServerBuilder {
144        ServerBuilder::new()
145    }
146
147    /// Run the WebSocket server.
148    pub async fn run(self) -> Result<()> {
149        let bind_addr = self.state.config.bind_addr;
150
151        let mut app = Router::new()
152            .route("/ws", get(ws_handler))
153            .route("/health", get(health_handler))
154            .with_state(self.state.clone());
155
156        if self.state.config.enable_cors {
157            app = app.layer(CorsLayer::permissive());
158        }
159
160        info!("WebSocket server listening on {}", bind_addr);
161
162        let listener = tokio::net::TcpListener::bind(bind_addr)
163            .await
164            .map_err(|e| Error::Server(format!("Failed to bind: {}", e)))?;
165
166        axum::serve(listener, app)
167            .await
168            .map_err(|e| Error::Server(format!("Server error: {}", e)))?;
169
170        Ok(())
171    }
172
173    /// Get server statistics.
174    pub fn stats(&self) -> ServerStats {
175        ServerStats {
176            active_connections: self.state.clients.len(),
177            total_subscriptions: self.state.subscriptions.count(),
178            unique_clients: self.state.subscriptions.client_count(),
179        }
180    }
181
182    /// Broadcast a message to all connected clients.
183    pub fn broadcast(&self, message: Message) {
184        self.state.broadcast(message);
185    }
186
187    /// Send a message to a specific client.
188    pub fn send_to_client(&self, client_id: &str, message: Message) -> Result<()> {
189        self.state.send_to_client(client_id, message)
190    }
191
192    /// Get subscription manager.
193    pub fn subscriptions(&self) -> &SubscriptionManager {
194        &self.state.subscriptions
195    }
196}
197
198impl Default for WebSocketServer {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
/// Server statistics, computed on demand by `WebSocketServer::stats`.
#[derive(Debug, Clone)]
pub struct ServerStats {
    /// Number of active WebSocket connections (entries in the client map)
    pub active_connections: usize,
    /// Total number of subscriptions, as counted by the subscription manager
    pub total_subscriptions: usize,
    /// Number of unique clients with subscriptions, as reported by the subscription manager
    pub unique_clients: usize,
}
214
/// Builder for WebSocket server.
///
/// Starts from `ServerConfig::default()`. Each setter overrides one field, and
/// `build` produces a `WebSocketServer`.
pub struct ServerBuilder {
    config: ServerConfig,
}
219
220impl ServerBuilder {
221    /// Create a new server builder.
222    pub fn new() -> Self {
223        Self {
224            config: ServerConfig::default(),
225        }
226    }
227
228    /// Set bind address.
229    pub fn bind(mut self, addr: &str) -> Result<Self> {
230        self.config.bind_addr = addr
231            .parse()
232            .map_err(|e| Error::InvalidParameter(format!("Invalid address: {}", e)))?;
233        Ok(self)
234    }
235
236    /// Set maximum connections.
237    pub fn max_connections(mut self, max: usize) -> Self {
238        self.config.max_connections = max;
239        self
240    }
241
242    /// Set message buffer size.
243    pub fn message_buffer_size(mut self, size: usize) -> Self {
244        self.config.message_buffer_size = size;
245        self
246    }
247
248    /// Set default message format.
249    pub fn default_format(mut self, format: MessageFormat) -> Self {
250        self.config.default_format = format;
251        self
252    }
253
254    /// Set default compression.
255    pub fn default_compression(mut self, compression: Compression) -> Self {
256        self.config.default_compression = compression;
257        self
258    }
259
260    /// Enable CORS.
261    pub fn enable_cors(mut self, enable: bool) -> Self {
262        self.config.enable_cors = enable;
263        self
264    }
265
266    /// Build the server.
267    pub fn build(self) -> WebSocketServer {
268        WebSocketServer::with_config(self.config)
269    }
270}
271
272impl Default for ServerBuilder {
273    fn default() -> Self {
274        Self::new()
275    }
276}
277
/// Liveness probe mounted at `GET /health`; always answers with the plain-text body `OK`.
async fn health_handler() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
282
283/// WebSocket upgrade handler.
284async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl IntoResponse {
285    ws.on_upgrade(|socket| handle_socket(socket, state))
286}
287
288/// Handle WebSocket connection.
289async fn handle_socket(socket: WebSocket, state: AppState) {
290    let client_id = Uuid::new_v4().to_string();
291    info!("New WebSocket connection: {}", client_id);
292
293    let (mut sender, mut receiver) = socket.split();
294    let (tx, mut rx) = mpsc::unbounded_channel();
295
296    // Default protocol settings
297    let mut format = state.config.default_format;
298    let mut compression = state.config.default_compression;
299
300    // Add client to state
301    let client_state = ClientState {
302        id: client_id.clone(),
303        tx: tx.clone(),
304        format,
305        compression,
306    };
307    state.clients.insert(client_id.clone(), client_state);
308
309    // Spawn task to send messages to client
310    let client_id_clone = client_id.clone();
311    tokio::spawn(async move {
312        while let Some(message) = rx.recv().await {
313            // Encode message
314            let data = match message.encode(format, compression) {
315                Ok(data) => data,
316                Err(e) => {
317                    error!("Failed to encode message: {}", e);
318                    continue;
319                }
320            };
321
322            // Send as binary message
323            if let Err(e) = sender
324                .send(axum::extract::ws::Message::Binary(data.into()))
325                .await
326            {
327                error!("Failed to send message to {}: {}", client_id_clone, e);
328                break;
329            }
330        }
331    });
332
333    // Handle incoming messages
334    while let Some(msg) = receiver.next().await {
335        let msg = match msg {
336            Ok(msg) => msg,
337            Err(e) => {
338                error!("WebSocket error for {}: {}", client_id, e);
339                break;
340            }
341        };
342
343        let data = match msg {
344            axum::extract::ws::Message::Binary(data) => data.to_vec(),
345            axum::extract::ws::Message::Text(text) => text.as_bytes().to_vec(),
346            axum::extract::ws::Message::Close(_) => {
347                info!("Client {} disconnected", client_id);
348                break;
349            }
350            axum::extract::ws::Message::Ping(_) | axum::extract::ws::Message::Pong(_) => {
351                continue;
352            }
353        };
354
355        // Decode message
356        let message = match Message::decode(&data, format, compression) {
357            Ok(msg) => msg,
358            Err(e) => {
359                error!("Failed to decode message from {}: {}", client_id, e);
360                continue;
361            }
362        };
363
364        // Handle message
365        if let Err(e) =
366            handle_message(message, &client_id, &state, &mut format, &mut compression).await
367        {
368            error!("Error handling message from {}: {}", client_id, e);
369        }
370    }
371
372    // Cleanup on disconnect
373    info!("Cleaning up client {}", client_id);
374    state.clients.remove(&client_id);
375    if let Err(e) = state.subscriptions.remove_client(&client_id) {
376        error!("Failed to remove client subscriptions: {}", e);
377    }
378}
379
380/// Handle a received message.
381async fn handle_message(
382    message: Message,
383    client_id: &str,
384    state: &AppState,
385    format: &mut MessageFormat,
386    compression: &mut Compression,
387) -> Result<()> {
388    match message {
389        Message::Handshake {
390            version,
391            format: client_format,
392            compression: client_compression,
393        } => {
394            debug!("Handshake from {}: v{}", client_id, version);
395
396            // Negotiate protocol
397            *format = client_format;
398            *compression = client_compression;
399
400            // Update client state
401            if let Some(mut client) = state.clients.get_mut(client_id) {
402                client.format = *format;
403                client.compression = *compression;
404            }
405
406            // Send acknowledgement
407            state.send_to_client(
408                client_id,
409                Message::HandshakeAck {
410                    version,
411                    format: *format,
412                    compression: *compression,
413                },
414            )?;
415        }
416
417        Message::SubscribeTiles {
418            subscription_id,
419            bbox,
420            zoom_range,
421            ..
422        } => {
423            debug!("Subscribe tiles from {}: {}", client_id, subscription_id);
424
425            let sub = Subscription::tiles(client_id.to_string(), bbox, zoom_range, None);
426            state.subscriptions.add(sub)?;
427
428            state.send_to_client(
429                client_id,
430                Message::Ack {
431                    request_id: subscription_id,
432                    success: true,
433                    message: Some("Subscribed to tiles".to_string()),
434                },
435            )?;
436        }
437
438        Message::SubscribeFeatures {
439            subscription_id,
440            layer,
441            ..
442        } => {
443            debug!("Subscribe features from {}: {}", client_id, subscription_id);
444
445            let sub = Subscription::features(client_id.to_string(), layer, None);
446            state.subscriptions.add(sub)?;
447
448            state.send_to_client(
449                client_id,
450                Message::Ack {
451                    request_id: subscription_id,
452                    success: true,
453                    message: Some("Subscribed to features".to_string()),
454                },
455            )?;
456        }
457
458        Message::SubscribeEvents {
459            subscription_id,
460            event_types,
461        } => {
462            debug!("Subscribe events from {}: {}", client_id, subscription_id);
463
464            let event_types_set = event_types.into_iter().collect();
465            let sub = Subscription::events(client_id.to_string(), event_types_set, None);
466            state.subscriptions.add(sub)?;
467
468            state.send_to_client(
469                client_id,
470                Message::Ack {
471                    request_id: subscription_id,
472                    success: true,
473                    message: Some("Subscribed to events".to_string()),
474                },
475            )?;
476        }
477
478        Message::Unsubscribe { subscription_id } => {
479            debug!("Unsubscribe from {}: {}", client_id, subscription_id);
480
481            state.subscriptions.remove(&subscription_id)?;
482
483            state.send_to_client(
484                client_id,
485                Message::Ack {
486                    request_id: subscription_id,
487                    success: true,
488                    message: Some("Unsubscribed".to_string()),
489                },
490            )?;
491        }
492
493        Message::Ping { id } => {
494            state.send_to_client(client_id, Message::Pong { id })?;
495        }
496
497        _ => {
498            warn!("Unexpected message type from {}", client_id);
499        }
500    }
501
502    Ok(())
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// The documented defaults are what `ServerConfig::default()` produces.
    #[test]
    fn test_server_config_default() {
        let ServerConfig {
            max_connections,
            message_buffer_size,
            enable_cors,
            ..
        } = ServerConfig::default();

        assert_eq!(max_connections, 10000);
        assert_eq!(message_buffer_size, 1000);
        assert!(enable_cors);
    }

    /// Every builder setter lands in the built server's configuration.
    #[test]
    fn test_server_builder() {
        let builder = match ServerBuilder::new().bind("127.0.0.1:8080") {
            Ok(builder) => builder,
            Err(e) => panic!("valid bind address rejected: {}", e),
        };

        let server = builder
            .max_connections(5000)
            .message_buffer_size(500)
            .default_format(MessageFormat::Json)
            .enable_cors(false)
            .build();

        let config = &server.state.config;
        assert_eq!(config.bind_addr.to_string(), "127.0.0.1:8080");
        assert_eq!(config.max_connections, 5000);
        assert_eq!(config.message_buffer_size, 500);
        assert_eq!(config.default_format, MessageFormat::Json);
        assert!(!config.enable_cors);
    }

    /// Freshly created shared state has no clients and no subscriptions.
    #[test]
    fn test_app_state() {
        let state = AppState::new(ServerConfig::default());

        assert!(state.clients.is_empty());
        assert_eq!(state.subscriptions.count(), 0);
    }
}