armature_websocket/
server.rs

1//! WebSocket server implementation.
2
3use crate::connection::{Connection, ConnectionWriter};
4use crate::error::{WebSocketError, WebSocketResult};
5use crate::handler::WebSocketHandler;
6use crate::message::Message;
7use crate::room::RoomManager;
8use futures_util::StreamExt;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::mpsc;
14use tokio_tungstenite::accept_async;
15
/// WebSocket server configuration.
///
/// The [`Default`] implementation binds to `0.0.0.0:9001`, caps messages at 64 KiB, and uses a
/// 30 second heartbeat interval and a 60 second connection timeout.
#[derive(Debug, Clone)]
pub struct WebSocketServerConfig {
    /// Socket address the listener binds to.
    pub bind_addr: SocketAddr,
    /// Maximum size of a single WebSocket message, in bytes.
    pub max_message_size: usize,
    /// Interval between heartbeats.
    pub heartbeat_interval: Duration,
    /// Connection timeout.
    pub connection_timeout: Duration,
}

impl Default for WebSocketServerConfig {
    fn default() -> Self {
        Self {
            // Built directly from octets and port: no string parsing, so no panic path.
            bind_addr: SocketAddr::from(([0, 0, 0, 0], 9001)),
            max_message_size: 64 * 1024, // 64 KiB
            heartbeat_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(60),
        }
    }
}
39
40/// Builder for WebSocket server configuration.
41#[derive(Debug, Default)]
42pub struct WebSocketServerBuilder {
43    config: WebSocketServerConfig,
44}
45
46impl WebSocketServerBuilder {
47    /// Create a new builder.
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Set the bind address.
53    pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
54        self.config.bind_addr = addr;
55        self
56    }
57
58    /// Set the bind address from a string.
59    pub fn bind(mut self, addr: &str) -> WebSocketResult<Self> {
60        self.config.bind_addr = addr
61            .parse()
62            .map_err(|e| WebSocketError::Server(format!("Invalid address: {}", e)))?;
63        Ok(self)
64    }
65
66    /// Set the maximum message size.
67    pub fn max_message_size(mut self, size: usize) -> Self {
68        self.config.max_message_size = size;
69        self
70    }
71
72    /// Set the heartbeat interval.
73    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
74        self.config.heartbeat_interval = interval;
75        self
76    }
77
78    /// Set the connection timeout.
79    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
80        self.config.connection_timeout = timeout;
81        self
82    }
83
84    /// Build the server with the given handler.
85    pub fn build<H: WebSocketHandler>(self, handler: H) -> WebSocketServer<H> {
86        WebSocketServer::new(self.config, handler)
87    }
88}
89
90/// WebSocket server.
91pub struct WebSocketServer<H: WebSocketHandler> {
92    config: WebSocketServerConfig,
93    handler: Arc<H>,
94    room_manager: Arc<RoomManager>,
95}
96
97impl<H: WebSocketHandler> WebSocketServer<H> {
98    /// Create a new WebSocket server.
99    pub fn new(config: WebSocketServerConfig, handler: H) -> Self {
100        Self {
101            config,
102            handler: Arc::new(handler),
103            room_manager: Arc::new(RoomManager::new()),
104        }
105    }
106
107    /// Create a builder for the server.
108    pub fn builder() -> WebSocketServerBuilder {
109        WebSocketServerBuilder::new()
110    }
111
112    /// Get a reference to the room manager.
113    pub fn room_manager(&self) -> &Arc<RoomManager> {
114        &self.room_manager
115    }
116
117    /// Run the server.
118    pub async fn run(&self) -> WebSocketResult<()> {
119        let listener = TcpListener::bind(self.config.bind_addr).await?;
120        tracing::info!(addr = %self.config.bind_addr, "WebSocket server listening");
121
122        loop {
123            match listener.accept().await {
124                Ok((stream, addr)) => {
125                    let handler = Arc::clone(&self.handler);
126                    let room_manager = Arc::clone(&self.room_manager);
127                    let config = self.config.clone();
128
129                    tokio::spawn(async move {
130                        if let Err(e) =
131                            Self::handle_connection(stream, addr, handler, room_manager, config)
132                                .await
133                        {
134                            tracing::error!(addr = %addr, error = %e, "Connection error");
135                        }
136                    });
137                }
138                Err(e) => {
139                    tracing::error!(error = %e, "Failed to accept connection");
140                }
141            }
142        }
143    }
144
145    /// Handle a single connection.
146    async fn handle_connection(
147        stream: TcpStream,
148        addr: SocketAddr,
149        handler: Arc<H>,
150        room_manager: Arc<RoomManager>,
151        _config: WebSocketServerConfig,
152    ) -> WebSocketResult<()> {
153        let ws_stream = accept_async(stream).await?;
154        let connection_id = uuid::Uuid::new_v4().to_string();
155
156        tracing::debug!(connection_id = %connection_id, addr = %addr, "WebSocket connection established");
157
158        // Split the WebSocket stream
159        let (write, mut read) = ws_stream.split();
160
161        // Create message channel
162        let (tx, rx) = mpsc::unbounded_channel();
163
164        // Create connection object
165        let connection = Connection::new(connection_id.clone(), Some(addr), tx);
166
167        // Register connection
168        room_manager.register_connection(connection.clone());
169
170        // Notify handler of connection
171        handler.on_connect(&connection_id).await;
172
173        // Spawn writer task
174        let writer = ConnectionWriter::new(write, rx);
175        let writer_handle = tokio::spawn(async move { writer.run().await });
176
177        // Read messages
178        while let Some(result) = read.next().await {
179            match result {
180                Ok(msg) => {
181                    if msg.is_close() {
182                        break;
183                    }
184
185                    let message: Message = msg.into();
186
187                    // Handle ping/pong
188                    if message.is_ping() {
189                        let pong_payload = handler.on_ping(&connection_id, message.as_bytes()).await;
190                        let _ = connection.send(Message::pong(pong_payload));
191                        continue;
192                    }
193
194                    if message.is_pong() {
195                        handler.on_pong(&connection_id, message.as_bytes()).await;
196                        continue;
197                    }
198
199                    // Handle regular message
200                    handler.on_message(&connection_id, message).await;
201                }
202                Err(e) => {
203                    let ws_error = WebSocketError::Protocol(e);
204                    handler.on_error(&connection_id, &ws_error).await;
205                    break;
206                }
207            }
208        }
209
210        // Close connection
211        connection.close();
212
213        // Wait for writer to finish
214        let _ = writer_handle.await;
215
216        // Notify handler of disconnection
217        handler.on_disconnect(&connection_id).await;
218
219        // Unregister connection
220        room_manager.unregister_connection(&connection_id);
221
222        tracing::debug!(connection_id = %connection_id, "WebSocket connection closed");
223
224        Ok(())
225    }
226}
227