Skip to main content

oxigdal_websocket/server/
ws_server.rs

1//! WebSocket server implementation
2
3use crate::error::{Error, Result};
4use crate::protocol::{MessageFormat, ProtocolCodec, ProtocolConfig};
5use crate::server::connection::Connection;
6use crate::server::heartbeat::{HeartbeatConfig, HeartbeatMonitor};
7use crate::server::manager::ConnectionManager;
8use crate::server::pool::{ConnectionPool, PoolConfig};
9use crate::server::{DEFAULT_MAX_CONNECTIONS, DEFAULT_MAX_MESSAGE_SIZE};
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::RwLock;
14use tokio_tungstenite::accept_async;
15
16/// Server configuration
17#[derive(Debug, Clone)]
18pub struct ServerConfig {
19    /// Bind address
20    pub bind_addr: SocketAddr,
21    /// Maximum connections
22    pub max_connections: usize,
23    /// Maximum message size
24    pub max_message_size: usize,
25    /// Protocol configuration
26    pub protocol: ProtocolConfig,
27    /// Heartbeat configuration
28    pub heartbeat: HeartbeatConfig,
29    /// Pool configuration
30    pub pool: PoolConfig,
31}
32
33impl Default for ServerConfig {
34    fn default() -> Self {
35        Self {
36            bind_addr: SocketAddr::from(([0, 0, 0, 0], 9001)),
37            max_connections: DEFAULT_MAX_CONNECTIONS,
38            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
39            protocol: ProtocolConfig::default(),
40            heartbeat: HeartbeatConfig::default(),
41            pool: PoolConfig::default(),
42        }
43    }
44}
45
46/// WebSocket server
47pub struct Server {
48    config: ServerConfig,
49    manager: Arc<ConnectionManager>,
50    pool: Arc<ConnectionPool>,
51    heartbeat: Arc<HeartbeatMonitor>,
52    shutdown: Arc<RwLock<bool>>,
53}
54
55impl Server {
56    /// Create a new server
57    pub fn new(config: ServerConfig) -> Self {
58        let manager = Arc::new(ConnectionManager::new(config.max_connections));
59        let pool = Arc::new(ConnectionPool::new(config.pool.clone()));
60        let heartbeat = Arc::new(HeartbeatMonitor::new(config.heartbeat.clone()));
61
62        Self {
63            config,
64            manager,
65            pool,
66            heartbeat,
67            shutdown: Arc::new(RwLock::new(false)),
68        }
69    }
70
71    /// Create a server builder
72    pub fn builder() -> ServerBuilder {
73        ServerBuilder::new()
74    }
75
76    /// Start the server
77    pub async fn start(self: Arc<Self>) -> Result<()> {
78        tracing::info!("Starting WebSocket server on {}", self.config.bind_addr);
79
80        // Start heartbeat monitor
81        let heartbeat = self.heartbeat.clone();
82        tokio::spawn(async move {
83            if let Err(e) = heartbeat.start().await {
84                tracing::error!("Heartbeat monitor error: {}", e);
85            }
86        });
87
88        // Start pool maintenance task
89        let pool = self.pool.clone();
90        let shutdown = self.shutdown.clone();
91        tokio::spawn(async move {
92            let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
93            loop {
94                if *shutdown.read().await {
95                    break;
96                }
97                interval.tick().await;
98                if let Err(e) = pool.maintain().await {
99                    tracing::error!("Pool maintenance error: {}", e);
100                }
101            }
102        });
103
104        // Bind listener
105        let listener = TcpListener::bind(&self.config.bind_addr)
106            .await
107            .map_err(|e| Error::Connection(format!("Failed to bind: {}", e)))?;
108
109        tracing::info!("Server listening on {}", self.config.bind_addr);
110
111        // Accept connections
112        loop {
113            // Check for shutdown
114            if *self.shutdown.read().await {
115                break;
116            }
117
118            match listener.accept().await {
119                Ok((stream, addr)) => {
120                    let server = self.clone();
121                    tokio::spawn(async move {
122                        if let Err(e) = server.handle_connection(stream, addr).await {
123                            tracing::error!("Connection error from {}: {}", addr, e);
124                        }
125                    });
126                }
127                Err(e) => {
128                    tracing::error!("Accept error: {}", e);
129                }
130            }
131        }
132
133        tracing::info!("Server stopped");
134        Ok(())
135    }
136
137    /// Handle a new connection
138    async fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) -> Result<()> {
139        tracing::info!("New connection from {}", addr);
140
141        // Perform WebSocket handshake
142        let ws_stream = accept_async(stream)
143            .await
144            .map_err(|e| Error::WebSocket(format!("Handshake failed: {}", e)))?;
145
146        // Create protocol codec
147        let codec = ProtocolCodec::new(self.config.protocol.clone());
148
149        // Create connection
150        let (connection, rx) = Connection::new(ws_stream, addr, codec);
151        let connection = Arc::new(connection);
152
153        // Add to manager
154        self.manager.add(connection.clone())?;
155
156        // Add to heartbeat monitor
157        self.heartbeat.add_connection(connection.clone()).await;
158
159        // Spawn outgoing message handler
160        let conn = connection.clone();
161        tokio::spawn(async move {
162            if let Err(e) = conn.process_outgoing(rx).await {
163                tracing::error!("Outgoing message handler error: {}", e);
164            }
165        });
166
167        // Handle incoming messages
168        loop {
169            match connection.receive().await {
170                Ok(Some(message)) => {
171                    tracing::debug!(
172                        "Received message from {}: {:?}",
173                        addr,
174                        message.message_type()
175                    );
176                    // Handle message (can be extended with custom logic)
177                }
178                Ok(None) => {
179                    // No message, continue
180                }
181                Err(e) => {
182                    tracing::error!("Receive error from {}: {}", addr, e);
183                    break;
184                }
185            }
186
187            // Check connection state
188            let state = connection.state().await;
189            if state != crate::server::connection::ConnectionState::Connected {
190                break;
191            }
192        }
193
194        // Cleanup
195        self.manager.remove(&connection.id());
196        self.heartbeat.remove_connection(&connection.id()).await;
197
198        tracing::info!("Connection from {} closed", addr);
199        Ok(())
200    }
201
202    /// Shutdown the server
203    pub async fn shutdown(&self) -> Result<()> {
204        tracing::info!("Shutting down server");
205
206        // Set shutdown flag
207        let mut shutdown = self.shutdown.write().await;
208        *shutdown = true;
209        drop(shutdown);
210
211        // Stop heartbeat monitor
212        self.heartbeat.shutdown().await;
213
214        // Close all connections
215        self.manager.close_all().await?;
216
217        // Shutdown pool
218        self.pool.shutdown().await?;
219
220        tracing::info!("Server shutdown complete");
221        Ok(())
222    }
223
224    /// Get connection manager
225    pub fn manager(&self) -> &Arc<ConnectionManager> {
226        &self.manager
227    }
228
229    /// Get connection pool
230    pub fn pool(&self) -> &Arc<ConnectionPool> {
231        &self.pool
232    }
233
234    /// Get heartbeat monitor
235    pub fn heartbeat(&self) -> &Arc<HeartbeatMonitor> {
236        &self.heartbeat
237    }
238
239    /// Get server statistics
240    pub async fn stats(&self) -> ServerStats {
241        let manager_stats = self.manager.stats();
242        let pool_stats = self.pool.stats();
243        let heartbeat_stats = self.heartbeat.stats().await;
244
245        ServerStats {
246            active_connections: manager_stats.active_connections,
247            total_connections: manager_stats.total_connections,
248            messages_sent: manager_stats.messages_sent,
249            messages_received: manager_stats.messages_received,
250            bytes_sent: manager_stats.bytes_sent,
251            bytes_received: manager_stats.bytes_received,
252            errors: manager_stats.errors,
253            pool_idle: pool_stats.idle_connections,
254            pool_active: pool_stats.active_connections,
255            heartbeat_monitored: heartbeat_stats.monitored_connections,
256        }
257    }
258}
259
/// Point-in-time snapshot of server metrics, as produced by `Server::stats`.
///
/// Connection, message, byte and error counters come from the connection
/// manager. The `pool_*` fields come from the connection pool, and
/// `heartbeat_monitored` comes from the heartbeat monitor.
///
/// `Default` yields an all-zero snapshot; `PartialEq`/`Eq` allow snapshots to
/// be compared directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerStats {
    /// Active connections
    pub active_connections: usize,
    /// Total connections served
    pub total_connections: u64,
    /// Messages sent
    pub messages_sent: u64,
    /// Messages received
    pub messages_received: u64,
    /// Bytes sent
    pub bytes_sent: u64,
    /// Bytes received
    pub bytes_received: u64,
    /// Errors
    pub errors: u64,
    /// Pool idle connections
    pub pool_idle: usize,
    /// Pool active connections
    pub pool_active: usize,
    /// Heartbeat monitored connections
    pub heartbeat_monitored: usize,
}
284
285/// Server builder
286pub struct ServerBuilder {
287    config: ServerConfig,
288}
289
290impl ServerBuilder {
291    /// Create a new server builder
292    pub fn new() -> Self {
293        Self {
294            config: ServerConfig::default(),
295        }
296    }
297
298    /// Set bind address
299    pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
300        self.config.bind_addr = addr;
301        self
302    }
303
304    /// Set max connections
305    pub fn max_connections(mut self, max: usize) -> Self {
306        self.config.max_connections = max;
307        self
308    }
309
310    /// Set max message size
311    pub fn max_message_size(mut self, size: usize) -> Self {
312        self.config.max_message_size = size;
313        self
314    }
315
316    /// Set message format
317    pub fn message_format(mut self, format: MessageFormat) -> Self {
318        self.config.protocol.format = format;
319        self
320    }
321
322    /// Set heartbeat interval
323    pub fn heartbeat_interval(mut self, secs: u64) -> Self {
324        self.config.heartbeat.interval_secs = secs;
325        self
326    }
327
328    /// Enable/disable heartbeat
329    pub fn enable_heartbeat(mut self, enabled: bool) -> Self {
330        self.config.heartbeat.enabled = enabled;
331        self
332    }
333
334    /// Set pool config
335    pub fn pool_config(mut self, config: PoolConfig) -> Self {
336        self.config.pool = config;
337        self
338    }
339
340    /// Build the server
341    pub fn build(self) -> Arc<Server> {
342        Arc::new(Server::new(self.config))
343    }
344}
345
346impl Default for ServerBuilder {
347    fn default() -> Self {
348        Self::new()
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.bind_addr, SocketAddr::from(([0, 0, 0, 0], 9001)));
        assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
    }

    #[test]
    fn test_server_builder() {
        let server = Server::builder()
            .max_connections(5000)
            .max_message_size(8 * 1024 * 1024)
            .heartbeat_interval(60)
            .build();

        // Every builder setting must reach the server's configuration
        // (the previous `strong_count >= 1` check could never fail).
        assert_eq!(server.config.max_connections, 5000);
        assert_eq!(server.config.max_message_size, 8 * 1024 * 1024);
        assert_eq!(server.config.heartbeat.interval_secs, 60);
    }

    #[tokio::test]
    async fn test_server_stats() {
        let server = Server::builder().build();
        let stats = server.stats().await;

        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.total_connections, 0);
    }
}
383}