aerosocket_server/
server.rs

1//! WebSocket server implementation
2//!
3//! This module provides the main server implementation for handling WebSocket connections.
4
5use crate::{
6    config::ServerConfig,
7    handler::{BoxedHandler, Handler},
8    connection::{Connection, ConnectionHandle},
9    rate_limit::RateLimitMiddleware,
10};
11use aerosocket_core::{Error, Result, Transport};
12use aerosocket_core::transport::TransportStream;
13use aerosocket_core::handshake::{HandshakeConfig, parse_client_handshake, validate_client_handshake, create_server_handshake, response_to_string};
14use aerosocket_core::error::ConfigError;
15use std::sync::Arc;
16use std::collections::HashMap;
17use std::net::SocketAddr;
18use tokio::sync::{Mutex, mpsc};
19use tokio::time::{timeout, Duration};
20
21/// WebSocket server
22pub struct Server {
23    config: ServerConfig,
24    handler: BoxedHandler,
25    rate_limiter: Option<Arc<RateLimitMiddleware>>,
26}
27
28/// Connection manager for tracking active connections
29#[derive(Debug)]
30pub struct ConnectionManager {
31    connections: Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
32    next_id: Arc<Mutex<u64>>,
33}
34
35impl ConnectionManager {
36    /// Create a new connection manager
37    pub fn new() -> Self {
38        Self {
39            connections: Arc::new(Mutex::new(HashMap::new())),
40            next_id: Arc::new(Mutex::new(1)),
41        }
42    }
43
44    /// Add a new connection
45    pub async fn add_connection(&self, connection: Connection) -> u64 {
46        let mut next_id = self.next_id.lock().await;
47        let id = *next_id;
48        *next_id += 1;
49
50        let handle = ConnectionHandle::new(id, connection);
51        let mut connections = self.connections.lock().await;
52        connections.insert(id, handle);
53        id
54    }
55
56    /// Remove a connection
57    pub async fn remove_connection(&self, id: u64) -> Option<ConnectionHandle> {
58        let mut connections = self.connections.lock().await;
59        connections.remove(&id)
60    }
61
62    /// Get a connection by ID
63    pub async fn get_connection(&self, id: u64) -> Option<ConnectionHandle> {
64        let connections = self.connections.lock().await;
65        connections.get(&id).cloned()
66    }
67
68    /// Get all active connections
69    pub async fn get_all_connections(&self) -> Vec<ConnectionHandle> {
70        let connections = self.connections.lock().await;
71        connections.values().cloned().collect()
72    }
73
74    /// Get the number of active connections
75    pub async fn connection_count(&self) -> usize {
76        let connections = self.connections.lock().await;
77        connections.len()
78    }
79}
80
81impl std::fmt::Debug for Server {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("Server")
84            .field("config", &self.config)
85            .field("handler", &"<handler>")
86            .finish()
87    }
88}
89
90impl Server {
91    /// Create a new server with the given config and handler
92    pub fn new(config: ServerConfig, handler: BoxedHandler) -> Self {
93        let rate_limiter = if config.backpressure.enabled {
94            Some(Arc::new(RateLimitMiddleware::new(crate::rate_limit::RateLimitConfig {
95                max_requests: config.backpressure.max_requests_per_minute,
96                window: Duration::from_secs(60),
97                max_connections: config.max_connections / 10, // 10% of max connections per IP
98                connection_timeout: config.idle_timeout,
99            })))
100        } else {
101            None
102        };
103
104        Self { 
105            config, 
106            handler,
107            rate_limiter,
108        }
109    }
110
111    /// Create a server builder
112    pub fn builder() -> ServerBuilder {
113        ServerBuilder::new()
114    }
115
116    /// Start serving connections
117    pub async fn serve(self) -> Result<()> {
118        let connection_manager = Arc::new(ConnectionManager::new());
119        self.serve_with_connection_manager(connection_manager).await
120    }
121
122    /// Start serving with graceful shutdown
123    pub async fn serve_with_graceful_shutdown<F>(self, shutdown_signal: F) -> Result<()>
124    where
125        F: std::future::Future<Output = ()> + Send + 'static,
126    {
127        let connection_manager = Arc::new(ConnectionManager::new());
128        let shutdown_signal = Box::pin(shutdown_signal);
129        self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_signal).await
130    }
131
132    /// Internal serve method
133    async fn serve_with_connection_manager(self, connection_manager: Arc<ConnectionManager>) -> Result<()> {
134        let (_shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
135        let shutdown_rx = Box::pin(async move {
136            let _ = shutdown_rx.recv().await;
137        });
138
139        self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_rx).await
140    }
141
142    /// Internal serve method with shutdown signal
143    async fn serve_with_connection_manager_and_shutdown<F>(
144        self,
145        _connection_manager: Arc<ConnectionManager>,
146        _shutdown_signal: F,
147    ) -> Result<()>
148    where
149        F: std::future::Future<Output = ()> + Send + Unpin + 'static,
150    {
151        // Create transport based on configuration
152        #[cfg(feature = "tcp-transport")]
153        {
154            if self.config.transport_type == crate::config::TransportType::Tcp {
155                let transport = crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
156                return self.serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal).await;
157            }
158        }
159        
160        #[cfg(feature = "tls-transport")]
161        {
162            if self.config.transport_type == crate::config::TransportType::Tls {
163                let _tls_config = self.config.tls.as_ref()
164                    .ok_or_else(|| Error::Other("TLS configuration required for TLS transport".to_string()))?;
165                // Note: TLS transport is disabled in this release
166                return Err(Error::Other("TLS transport is not available in this release. Please enable the 'tls-transport' feature and implement proper TLS configuration.".to_string()));
167            }
168        }
169
170        // Fallback to TCP if no specific transport is configured
171        if self.config.transport_type == crate::config::TransportType::Tcp {
172            #[cfg(feature = "tcp-transport")]
173            {
174                let transport = crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
175                return self.serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal).await;
176            }
177        }
178        Err(Error::Config(ConfigError::Validation("No transport available".to_string())))
179    }
180
181    /// Serve with TCP transport
182    #[cfg(feature = "tcp-transport")]
183    async fn serve_with_tcp_transport<F>(
184        self,
185        transport: crate::tcp_transport::TcpTransport,
186        connection_manager: Arc<ConnectionManager>,
187        _shutdown_signal: F,
188    ) -> Result<()>
189    where
190        F: std::future::Future<Output = ()> + Send + 'static,
191    {
192        // Spawn connection handling task
193        let handler = self.handler;
194        let config = self.config.clone();
195        let manager = connection_manager.clone();
196        let rate_limiter = self.rate_limiter.clone();
197        
198        let server_task = tokio::spawn(async move {
199            let mut connection_counter = 0u64;
200            
201            loop {
202                // Check for shutdown
203                tokio::select! {
204                    result = transport.accept() => {
205                        match result {
206                            Ok(mut stream) => {
207                                // Get remote address for rate limiting
208                                let remote_addr = match stream.remote_addr() {
209                                    Ok(addr) => addr.ip(),
210                                    Err(e) => {
211                                        crate::log_error!("Failed to get remote address: {:?}", e);
212                                        let _ = stream.close().await;
213                                        continue;
214                                    }
215                                };
216
217                                // Check rate limiting if enabled
218                                if let Some(ref rate_limiter) = rate_limiter {
219                                    if !rate_limiter.check_connection(remote_addr).await.unwrap_or(true) {
220                                        crate::log_warn!("Rate limit exceeded for IP: {}", remote_addr);
221                                        let _ = stream.close().await;
222                                        continue;
223                                    }
224                                }
225
226                                // Check connection limit
227                                if manager.connection_count().await >= config.max_connections {
228                                    crate::log_warn!("Connection limit reached, rejecting connection from {}", remote_addr);
229                                    // Close the stream gracefully
230                                    let _ = stream.close().await;
231                                    continue;
232                                }
233
234                                connection_counter += 1;
235                                let manager = manager.clone();
236                                let handler = handler.clone();
237                                let config = config.clone();
238                                let rate_limiter = rate_limiter.clone();
239                                
240                                // Spawn connection handler
241                                tokio::spawn(async move {
242                                    if let Err(e) = Self::handle_connection(
243                                        stream,
244                                        handler,
245                                        config,
246                                        manager,
247                                        rate_limiter,
248                                    ).await {
249                                        crate::log_error!("Connection handling error: {:?}", e);
250                                    }
251                                });
252                            }
253                            Err(e) => {
254                                crate::log_error!("Accept error: {:?}", e);
255                                // Continue accepting other connections
256                            }
257                        }
258                    }
259                    _ = tokio::signal::ctrl_c() => {
260                        break;
261                    }
262                }
263            }
264        });
265
266        // Wait for server task completion
267        match server_task.await {
268            Ok(()) => Ok(()),
269            Err(e) => Err(Error::Other(format!("Server task panicked: {}", e))),
270        }
271    }
272
273    /// Serve with TLS transport
274    #[cfg(feature = "tls-transport")]
275    async fn serve_with_tls_transport<F>(
276        self,
277        _transport: crate::tls_transport::TlsTransport,
278        _connection_manager: Arc<ConnectionManager>,
279        _shutdown_signal: F,
280    ) -> Result<()>
281    where
282        F: std::future::Future<Output = ()> + Send + Unpin + 'static,
283    {
284        Err(Error::Other("TLS transport is not available in this release".to_string()))
285    }
286
287    /// Handle a single TLS connection
288    #[cfg(feature = "tls-transport")]
289    async fn handle_tls_connection(
290        _stream: crate::tls_transport::TlsStreamWrapper,
291        _handler: BoxedHandler,
292        _config: ServerConfig,
293        _connection_manager: Arc<ConnectionManager>,
294        _rate_limiter: Option<Arc<RateLimitMiddleware>>,
295    ) -> Result<()> {
296        Err(Error::Other("TLS connection handling is not available in this release".to_string()))
297    }
298
299    /// Handle a single connection
300    async fn handle_connection(
301        mut stream: crate::tcp_transport::TcpStream,
302        handler: BoxedHandler,
303        config: ServerConfig,
304        connection_manager: Arc<ConnectionManager>,
305        rate_limiter: Option<Arc<RateLimitMiddleware>>,
306    ) -> Result<()> {
307        // Perform WebSocket handshake
308        let (remote_addr, local_addr) = Self::perform_handshake(&mut stream, &config).await?;
309        
310        // Convert to boxed transport stream
311        let boxed_stream: Box<dyn TransportStream> = Box::new(stream);
312        
313        // Create connection with stream
314        let connection = Connection::with_stream(remote_addr, local_addr, boxed_stream);
315        
316        // Add to connection manager
317        let connection_id = connection_manager.add_connection(connection).await;
318        
319        // Get connection handle
320        let connection_handle = connection_manager.get_connection(connection_id).await
321            .ok_or_else(|| Error::Other("Failed to get connection handle".to_string()))?;
322
323        // Call handler
324        if let Err(e) = handler.handle(connection_handle).await {
325            crate::log_error!("Handler error: {:?}", e);
326        }
327
328        // Remove connection from manager
329        connection_manager.remove_connection(connection_id).await;
330        
331        // Clean up rate limiting
332        if let Some(ref rate_limiter) = rate_limiter {
333            rate_limiter.connection_closed(remote_addr.ip()).await;
334        }
335        
336        Ok(())
337    }
338
339    /// Perform WebSocket handshake over TLS
340    #[cfg(feature = "tls-transport")]
341    async fn perform_tls_handshake(
342        stream: &mut crate::tls_transport::TlsStreamWrapper,
343        config: &ServerConfig,
344    ) -> Result<(SocketAddr, SocketAddr)> {
345        // Read HTTP request over TLS
346        let request_data = Self::read_tls_handshake_request(stream, config.handshake_timeout).await?;
347        let request_str = String::from_utf8_lossy(&request_data);
348        
349        // Parse handshake request
350        let request = parse_client_handshake(&request_str)?;
351        
352        // Create handshake config
353        let handshake_config = HandshakeConfig {
354            protocols: config.supported_protocols.clone(),
355            extensions: config.supported_extensions.clone(),
356            origin: config.allowed_origin.clone(),
357            host: None,
358            extra_headers: config.extra_headers.clone(),
359        };
360        
361        // Validate request
362        validate_client_handshake(&request, &handshake_config)?;
363        
364        // Create response
365        let response = create_server_handshake(&request, &handshake_config)?;
366        let response_str = response_to_string(&response);
367        
368        // Send response over TLS
369        stream.write_all(response_str.as_bytes()).await?;
370        stream.flush().await?;
371        
372        // Get addresses
373        let remote_addr = stream.remote_addr()?;
374        let local_addr = stream.local_addr()?;
375        
376        Ok((remote_addr, local_addr))
377    }
378
379    /// Read handshake request from TLS stream
380    #[cfg(feature = "tls-transport")]
381    async fn read_tls_handshake_request(
382        stream: &mut crate::tls_transport::TlsStreamWrapper,
383        timeout_duration: Duration,
384    ) -> Result<Vec<u8>> {
385        let mut buffer = Vec::new();
386        let mut temp_buffer = [0u8; 1024];
387        
388        let read_result = timeout(timeout_duration, async {
389            loop {
390                let n = stream.read(&mut temp_buffer).await?;
391                if n == 0 {
392                    break;
393                }
394                
395                buffer.extend_from_slice(&temp_buffer[..n]);
396                
397                // Check for end of headers (double CRLF)
398                if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
399                    break;
400                }
401                
402                // Prevent reading too much
403                if buffer.len() > 8192 {
404                    return Err(Error::Other("TLS handshake request too large".to_string()));
405                }
406            }
407            
408            Ok::<(), Error>(())
409        }).await;
410        
411        match read_result {
412            Ok(result) => {
413                result?;
414                Ok(buffer)
415            },
416            Err(_) => Err(Error::Other("TLS handshake timeout".to_string())),
417        }
418    }
419
420    /// Perform WebSocket handshake
421    async fn perform_handshake(
422        stream: &mut crate::tcp_transport::TcpStream,
423        config: &ServerConfig,
424    ) -> Result<(SocketAddr, SocketAddr)> {
425        // Read HTTP request
426        let request_data = Self::read_handshake_request(stream, config.handshake_timeout).await?;
427        let request_str = String::from_utf8_lossy(&request_data);
428        
429        // Parse handshake request
430        let request = parse_client_handshake(&request_str)?;
431        
432        // Create handshake config
433        let handshake_config = HandshakeConfig {
434            protocols: config.supported_protocols.clone(),
435            extensions: config.supported_extensions.clone(),
436            origin: config.allowed_origin.clone(),
437            host: None,
438            extra_headers: config.extra_headers.clone(),
439        };
440        
441        // Validate request
442        validate_client_handshake(&request, &handshake_config)?;
443        
444        // Create response
445        let response = create_server_handshake(&request, &handshake_config)?;
446        let response_str = response_to_string(&response);
447        
448        // Send response
449        stream.write_all(response_str.as_bytes()).await?;
450        stream.flush().await?;
451        
452        // Get addresses
453        let remote_addr = stream.remote_addr()?;
454        let local_addr = stream.local_addr()?;
455        
456        Ok((remote_addr, local_addr))
457    }
458
459    /// Read handshake request from stream
460    async fn read_handshake_request(
461        stream: &mut crate::tcp_transport::TcpStream,
462        timeout_duration: Duration,
463    ) -> Result<Vec<u8>> {
464        let mut buffer = Vec::new();
465        let mut temp_buffer = [0u8; 1024];
466        
467        let read_result = timeout(timeout_duration, async {
468            loop {
469                let n = stream.read(&mut temp_buffer).await?;
470                if n == 0 {
471                    break;
472                }
473                
474                buffer.extend_from_slice(&temp_buffer[..n]);
475                
476                // Check for end of headers (double CRLF)
477                if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
478                    break;
479                }
480                
481                // Prevent reading too much
482                if buffer.len() > 8192 {
483                    return Err(Error::Other("Handshake request too large".to_string()));
484                }
485            }
486            
487            Ok::<(), Error>(())
488        }).await;
489        
490        match read_result {
491            Ok(result) => {
492                result?;
493                Ok(buffer)
494            },
495            Err(_) => Err(Error::Other("Handshake timeout".to_string())),
496        }
497    }
498
499    /// Perform graceful shutdown
500    async fn graceful_shutdown(&self, connection_manager: Arc<ConnectionManager>) -> Result<()> {
501        // Get all connections
502        let connections = connection_manager.get_all_connections().await;
503        
504        // Send close frames to all connections
505        for handle in connections {
506            if let Ok(mut connection) = handle.try_lock().await {
507                let _ = connection.close(Some(1000), Some("Server shutdown")).await;
508            }
509        }
510        
511        Ok(())
512    }
513}
514
515/// Server builder
516#[derive(Debug, Clone)]
517pub struct ServerBuilder {
518    config: ServerConfig,
519}
520
521impl ServerBuilder {
522    /// Create a new server builder
523    pub fn new() -> Self {
524        Self {
525            config: ServerConfig::default(),
526        }
527    }
528
529    /// Bind to the given address
530    pub fn bind<A: std::net::ToSocketAddrs>(mut self, addr: A) -> Result<Self> {
531        self.config.bind_address = addr.to_socket_addrs()?.next().ok_or_else(|| {
532            Error::Config(ConfigError::Validation("Invalid bind address".to_string()))
533        })?;
534        Ok(self)
535    }
536
537    /// Set maximum connections
538    pub fn max_connections(mut self, max: usize) -> Self {
539        self.config.max_connections = max;
540        self
541    }
542
543    /// Set maximum frame size
544    pub fn max_frame_size(mut self, size: usize) -> Self {
545        self.config.max_frame_size = size;
546        self
547    }
548
549    /// Set maximum message size
550    pub fn max_message_size(mut self, size: usize) -> Self {
551        self.config.max_message_size = size;
552        self
553    }
554
555    /// Set handshake timeout
556    pub fn handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
557        self.config.handshake_timeout = timeout;
558        self
559    }
560
561    /// Set idle timeout
562    pub fn idle_timeout(mut self, timeout: std::time::Duration) -> Self {
563        self.config.idle_timeout = timeout;
564        self
565    }
566
567    /// Enable/disable compression
568    pub fn compression(mut self, enabled: bool) -> Self {
569        self.config.compression.enabled = enabled;
570        self
571    }
572
573    /// Set backpressure strategy
574    pub fn backpressure(mut self, strategy: crate::config::BackpressureStrategy) -> Self {
575        self.config.backpressure.strategy = strategy;
576        self
577    }
578
579    /// Build the server
580    pub fn build(self) -> Result<Server> {
581        // Validate configuration
582        self.config.validate()?;
583
584        // Create default handler
585        let handler = Box::new(crate::handler::DefaultHandler::new());
586
587        Ok(Server::new(self.config, handler))
588    }
589
590    /// Build the server with a custom handler
591    pub fn build_with_handler<H>(self, handler: H) -> Result<Server>
592    where
593        H: Handler + Send + Sync + 'static,
594    {
595        // Validate configuration
596        self.config.validate()?;
597
598        Ok(Server::new(self.config, Box::new(handler)))
599    }
600}
601
602impl Default for ServerBuilder {
603    fn default() -> Self {
604        Self::new()
605    }
606}
607
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder chain should succeed, and every setter should land in the
    /// configuration that is later validated.
    #[test]
    fn test_server_builder() {
        let builder = ServerBuilder::new()
            .bind("127.0.0.1:8080")
            .unwrap()
            .max_connections(1000)
            .max_frame_size(1024 * 1024)
            .compression(true);

        assert_eq!(
            builder.config.bind_address,
            "127.0.0.1:8080".parse::<SocketAddr>().unwrap()
        );
        assert_eq!(builder.config.max_connections, 1000);
        assert_eq!(builder.config.max_frame_size, 1024 * 1024);
        assert!(builder.config.compression.enabled);

        assert!(builder.build().is_ok());
    }

    /// An address that cannot be parsed or resolved must be reported, not ignored.
    #[test]
    fn test_bind_rejects_unparseable_address() {
        assert!(ServerBuilder::new().bind("not a socket address").is_err());
    }

    /// The timeout setters should store exactly the durations they are given.
    #[test]
    fn test_timeouts_are_applied() {
        let builder = ServerBuilder::new()
            .handshake_timeout(std::time::Duration::from_secs(3))
            .idle_timeout(std::time::Duration::from_secs(30));

        assert_eq!(
            builder.config.handshake_timeout,
            std::time::Duration::from_secs(3)
        );
        assert_eq!(
            builder.config.idle_timeout,
            std::time::Duration::from_secs(30)
        );
    }

    /// A fresh manager should contain nothing, and lookups and removals of
    /// unknown ids should yield `None`.
    #[test]
    fn test_connection_manager_starts_empty() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build Tokio runtime");

        runtime.block_on(async {
            let manager = ConnectionManager::new();
            assert_eq!(manager.connection_count().await, 0);
            assert!(manager.get_all_connections().await.is_empty());
            assert!(manager.get_connection(1).await.is_none());
            assert!(manager.remove_connection(1).await.is_none());
        });
    }
}