1use crate::error::{Error, Result};
4use crate::protocol::{Compression, Message, MessageFormat};
5use crate::subscription::{Subscription, SubscriptionManager};
6use axum::{
7 Router,
8 extract::{
9 State,
10 ws::{WebSocket, WebSocketUpgrade},
11 },
12 response::IntoResponse,
13 routing::get,
14};
15use dashmap::DashMap;
16use futures::{SinkExt, StreamExt};
17use std::net::SocketAddr;
18use std::sync::Arc;
19use tokio::sync::mpsc;
20use tower_http::cors::CorsLayer;
21use tracing::{debug, error, info, warn};
22use uuid::Uuid;
23
24#[derive(Debug, Clone)]
26pub struct ServerConfig {
27 pub bind_addr: SocketAddr,
29 pub max_connections: usize,
31 pub message_buffer_size: usize,
33 pub default_format: MessageFormat,
35 pub default_compression: Compression,
37 pub enable_cors: bool,
39}
40
41impl Default for ServerConfig {
42 fn default() -> Self {
43 Self {
44 bind_addr: SocketAddr::from(([0, 0, 0, 0], 9001)),
45 max_connections: 10000,
46 message_buffer_size: 1000,
47 default_format: MessageFormat::MessagePack,
48 default_compression: Compression::Zstd,
49 enable_cors: true,
50 }
51 }
52}
53
54struct ClientState {
56 id: String,
58 tx: mpsc::UnboundedSender<Message>,
60 format: MessageFormat,
62 compression: Compression,
64}
65
66impl ClientState {
67 fn send(&self, message: Message) -> Result<()> {
69 self.tx
70 .send(message)
71 .map_err(|_| Error::Send("Client disconnected".to_string()))
72 }
73}
74
75#[derive(Clone)]
77struct AppState {
78 clients: Arc<DashMap<String, ClientState>>,
80 subscriptions: Arc<SubscriptionManager>,
82 config: Arc<ServerConfig>,
84}
85
86impl AppState {
87 fn new(config: ServerConfig) -> Self {
88 Self {
89 clients: Arc::new(DashMap::new()),
90 subscriptions: Arc::new(SubscriptionManager::new()),
91 config: Arc::new(config),
92 }
93 }
94
95 fn broadcast(&self, message: Message) {
97 for client in self.clients.iter() {
98 if let Err(e) = client.send(message.clone()) {
99 warn!("Failed to send to client {}: {}", client.id, e);
100 }
101 }
102 }
103
104 fn send_to_client(&self, client_id: &str, message: Message) -> Result<()> {
106 if let Some(client) = self.clients.get(client_id) {
107 client.send(message)
108 } else {
109 Err(Error::NotFound(format!("Client not found: {}", client_id)))
110 }
111 }
112
113 #[allow(dead_code)]
115 fn send_to_subscribers(&self, subscription_id: &str, message: Message) {
116 if let Some(sub) = self.subscriptions.get(subscription_id) {
117 if let Err(e) = self.send_to_client(&sub.client_id, message) {
118 warn!("Failed to send to subscriber {}: {}", sub.client_id, e);
119 }
120 }
121 }
122}
123
124pub struct WebSocketServer {
126 state: AppState,
127}
128
129impl WebSocketServer {
130 pub fn new() -> Self {
132 Self::with_config(ServerConfig::default())
133 }
134
135 pub fn with_config(config: ServerConfig) -> Self {
137 Self {
138 state: AppState::new(config),
139 }
140 }
141
142 pub fn builder() -> ServerBuilder {
144 ServerBuilder::new()
145 }
146
147 pub async fn run(self) -> Result<()> {
149 let bind_addr = self.state.config.bind_addr;
150
151 let mut app = Router::new()
152 .route("/ws", get(ws_handler))
153 .route("/health", get(health_handler))
154 .with_state(self.state.clone());
155
156 if self.state.config.enable_cors {
157 app = app.layer(CorsLayer::permissive());
158 }
159
160 info!("WebSocket server listening on {}", bind_addr);
161
162 let listener = tokio::net::TcpListener::bind(bind_addr)
163 .await
164 .map_err(|e| Error::Server(format!("Failed to bind: {}", e)))?;
165
166 axum::serve(listener, app)
167 .await
168 .map_err(|e| Error::Server(format!("Server error: {}", e)))?;
169
170 Ok(())
171 }
172
173 pub fn stats(&self) -> ServerStats {
175 ServerStats {
176 active_connections: self.state.clients.len(),
177 total_subscriptions: self.state.subscriptions.count(),
178 unique_clients: self.state.subscriptions.client_count(),
179 }
180 }
181
182 pub fn broadcast(&self, message: Message) {
184 self.state.broadcast(message);
185 }
186
187 pub fn send_to_client(&self, client_id: &str, message: Message) -> Result<()> {
189 self.state.send_to_client(client_id, message)
190 }
191
192 pub fn subscriptions(&self) -> &SubscriptionManager {
194 &self.state.subscriptions
195 }
196}
197
198impl Default for WebSocketServer {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
/// Point-in-time counters reported by `WebSocketServer::stats`.
///
/// The type is plain data, so it also derives [`Default`] (all counters zero)
/// and [`PartialEq`]/[`Eq`], which lets callers and tests compare snapshots
/// directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerStats {
    /// Number of clients currently held in the connection registry.
    pub active_connections: usize,

    /// Number of subscriptions reported by the subscription manager.
    pub total_subscriptions: usize,

    /// Number of clients reported by the subscription manager.
    pub unique_clients: usize,
}
214
215pub struct ServerBuilder {
217 config: ServerConfig,
218}
219
220impl ServerBuilder {
221 pub fn new() -> Self {
223 Self {
224 config: ServerConfig::default(),
225 }
226 }
227
228 pub fn bind(mut self, addr: &str) -> Result<Self> {
230 self.config.bind_addr = addr
231 .parse()
232 .map_err(|e| Error::InvalidParameter(format!("Invalid address: {}", e)))?;
233 Ok(self)
234 }
235
236 pub fn max_connections(mut self, max: usize) -> Self {
238 self.config.max_connections = max;
239 self
240 }
241
242 pub fn message_buffer_size(mut self, size: usize) -> Self {
244 self.config.message_buffer_size = size;
245 self
246 }
247
248 pub fn default_format(mut self, format: MessageFormat) -> Self {
250 self.config.default_format = format;
251 self
252 }
253
254 pub fn default_compression(mut self, compression: Compression) -> Self {
256 self.config.default_compression = compression;
257 self
258 }
259
260 pub fn enable_cors(mut self, enable: bool) -> Self {
262 self.config.enable_cors = enable;
263 self
264 }
265
266 pub fn build(self) -> WebSocketServer {
268 WebSocketServer::with_config(self.config)
269 }
270}
271
272impl Default for ServerBuilder {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
/// Liveness probe for `GET /health`: always answers with the plain body `OK`.
async fn health_handler() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
282
283async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl IntoResponse {
285 ws.on_upgrade(|socket| handle_socket(socket, state))
286}
287
288async fn handle_socket(socket: WebSocket, state: AppState) {
290 let client_id = Uuid::new_v4().to_string();
291 info!("New WebSocket connection: {}", client_id);
292
293 let (mut sender, mut receiver) = socket.split();
294 let (tx, mut rx) = mpsc::unbounded_channel();
295
296 let mut format = state.config.default_format;
298 let mut compression = state.config.default_compression;
299
300 let client_state = ClientState {
302 id: client_id.clone(),
303 tx: tx.clone(),
304 format,
305 compression,
306 };
307 state.clients.insert(client_id.clone(), client_state);
308
309 let client_id_clone = client_id.clone();
311 tokio::spawn(async move {
312 while let Some(message) = rx.recv().await {
313 let data = match message.encode(format, compression) {
315 Ok(data) => data,
316 Err(e) => {
317 error!("Failed to encode message: {}", e);
318 continue;
319 }
320 };
321
322 if let Err(e) = sender
324 .send(axum::extract::ws::Message::Binary(data.into()))
325 .await
326 {
327 error!("Failed to send message to {}: {}", client_id_clone, e);
328 break;
329 }
330 }
331 });
332
333 while let Some(msg) = receiver.next().await {
335 let msg = match msg {
336 Ok(msg) => msg,
337 Err(e) => {
338 error!("WebSocket error for {}: {}", client_id, e);
339 break;
340 }
341 };
342
343 let data = match msg {
344 axum::extract::ws::Message::Binary(data) => data.to_vec(),
345 axum::extract::ws::Message::Text(text) => text.as_bytes().to_vec(),
346 axum::extract::ws::Message::Close(_) => {
347 info!("Client {} disconnected", client_id);
348 break;
349 }
350 axum::extract::ws::Message::Ping(_) | axum::extract::ws::Message::Pong(_) => {
351 continue;
352 }
353 };
354
355 let message = match Message::decode(&data, format, compression) {
357 Ok(msg) => msg,
358 Err(e) => {
359 error!("Failed to decode message from {}: {}", client_id, e);
360 continue;
361 }
362 };
363
364 if let Err(e) =
366 handle_message(message, &client_id, &state, &mut format, &mut compression).await
367 {
368 error!("Error handling message from {}: {}", client_id, e);
369 }
370 }
371
372 info!("Cleaning up client {}", client_id);
374 state.clients.remove(&client_id);
375 if let Err(e) = state.subscriptions.remove_client(&client_id) {
376 error!("Failed to remove client subscriptions: {}", e);
377 }
378}
379
380async fn handle_message(
382 message: Message,
383 client_id: &str,
384 state: &AppState,
385 format: &mut MessageFormat,
386 compression: &mut Compression,
387) -> Result<()> {
388 match message {
389 Message::Handshake {
390 version,
391 format: client_format,
392 compression: client_compression,
393 } => {
394 debug!("Handshake from {}: v{}", client_id, version);
395
396 *format = client_format;
398 *compression = client_compression;
399
400 if let Some(mut client) = state.clients.get_mut(client_id) {
402 client.format = *format;
403 client.compression = *compression;
404 }
405
406 state.send_to_client(
408 client_id,
409 Message::HandshakeAck {
410 version,
411 format: *format,
412 compression: *compression,
413 },
414 )?;
415 }
416
417 Message::SubscribeTiles {
418 subscription_id,
419 bbox,
420 zoom_range,
421 ..
422 } => {
423 debug!("Subscribe tiles from {}: {}", client_id, subscription_id);
424
425 let sub = Subscription::tiles(client_id.to_string(), bbox, zoom_range, None);
426 state.subscriptions.add(sub)?;
427
428 state.send_to_client(
429 client_id,
430 Message::Ack {
431 request_id: subscription_id,
432 success: true,
433 message: Some("Subscribed to tiles".to_string()),
434 },
435 )?;
436 }
437
438 Message::SubscribeFeatures {
439 subscription_id,
440 layer,
441 ..
442 } => {
443 debug!("Subscribe features from {}: {}", client_id, subscription_id);
444
445 let sub = Subscription::features(client_id.to_string(), layer, None);
446 state.subscriptions.add(sub)?;
447
448 state.send_to_client(
449 client_id,
450 Message::Ack {
451 request_id: subscription_id,
452 success: true,
453 message: Some("Subscribed to features".to_string()),
454 },
455 )?;
456 }
457
458 Message::SubscribeEvents {
459 subscription_id,
460 event_types,
461 } => {
462 debug!("Subscribe events from {}: {}", client_id, subscription_id);
463
464 let event_types_set = event_types.into_iter().collect();
465 let sub = Subscription::events(client_id.to_string(), event_types_set, None);
466 state.subscriptions.add(sub)?;
467
468 state.send_to_client(
469 client_id,
470 Message::Ack {
471 request_id: subscription_id,
472 success: true,
473 message: Some("Subscribed to events".to_string()),
474 },
475 )?;
476 }
477
478 Message::Unsubscribe { subscription_id } => {
479 debug!("Unsubscribe from {}: {}", client_id, subscription_id);
480
481 state.subscriptions.remove(&subscription_id)?;
482
483 state.send_to_client(
484 client_id,
485 Message::Ack {
486 request_id: subscription_id,
487 success: true,
488 message: Some("Unsubscribed".to_string()),
489 },
490 )?;
491 }
492
493 Message::Ping { id } => {
494 state.send_to_client(client_id, Message::Pong { id })?;
495 }
496
497 _ => {
498 warn!("Unexpected message type from {}", client_id);
499 }
500 }
501
502 Ok(())
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration carries the expected limits and enables CORS.
    #[test]
    fn test_server_config_default() {
        let ServerConfig {
            max_connections,
            message_buffer_size,
            enable_cors,
            ..
        } = ServerConfig::default();

        assert_eq!(max_connections, 10000);
        assert_eq!(message_buffer_size, 1000);
        assert!(enable_cors);
    }

    /// Every builder setter ends up in the configuration of the built server.
    #[test]
    fn test_server_builder() {
        let Ok(builder) = ServerBuilder::new().bind("127.0.0.1:8080") else {
            panic!("bind should accept a valid socket address");
        };

        let server = builder
            .max_connections(5000)
            .message_buffer_size(500)
            .default_format(MessageFormat::Json)
            .enable_cors(false)
            .build();

        let config = &server.state.config;
        assert_eq!(config.bind_addr.to_string(), "127.0.0.1:8080");
        assert_eq!(config.max_connections, 5000);
        assert_eq!(config.message_buffer_size, 500);
        assert_eq!(config.default_format, MessageFormat::Json);
        assert!(!config.enable_cors);
    }

    /// Freshly created state has no clients and no subscriptions.
    #[test]
    fn test_app_state() {
        let fresh = AppState::new(ServerConfig::default());

        assert_eq!(fresh.clients.len(), 0);
        assert_eq!(fresh.subscriptions.count(), 0);
    }
}