aerosocket_server/
server.rs

1//! WebSocket server implementation
2//!
3//! This module provides the main server implementation for handling WebSocket connections.
4
5use crate::{
6    config::ServerConfig,
7    connection::{Connection, ConnectionHandle},
8    handler::{BoxedHandler, Handler},
9    rate_limit::RateLimitMiddleware,
10};
11use aerosocket_core::error::ConfigError;
12use aerosocket_core::handshake::{
13    create_server_handshake, parse_client_handshake, response_to_string, validate_client_handshake,
14    HandshakeConfig,
15};
16use aerosocket_core::transport::TransportStream;
17use aerosocket_core::{Error, Result, Transport};
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex};
22use tokio::time::{timeout, Duration};
23
24/// WebSocket server
25pub struct Server {
26    config: ServerConfig,
27    handler: BoxedHandler,
28    rate_limiter: Option<Arc<RateLimitMiddleware>>,
29}
30
31/// Connection manager for tracking active connections
32#[derive(Debug)]
33pub struct ConnectionManager {
34    connections: Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
35    next_id: Arc<Mutex<u64>>,
36}
37
38impl ConnectionManager {
39    /// Create a new connection manager
40    pub fn new() -> Self {
41        Self {
42            connections: Arc::new(Mutex::new(HashMap::new())),
43            next_id: Arc::new(Mutex::new(1)),
44        }
45    }
46
47    /// Add a new connection
48    pub async fn add_connection(&self, connection: Connection) -> u64 {
49        let mut next_id = self.next_id.lock().await;
50        let id = *next_id;
51        *next_id += 1;
52
53        let handle = ConnectionHandle::new(id, connection);
54        let mut connections = self.connections.lock().await;
55        connections.insert(id, handle);
56        id
57    }
58
59    /// Remove a connection
60    pub async fn remove_connection(&self, id: u64) -> Option<ConnectionHandle> {
61        let mut connections = self.connections.lock().await;
62        connections.remove(&id)
63    }
64
65    /// Get a connection by ID
66    pub async fn get_connection(&self, id: u64) -> Option<ConnectionHandle> {
67        let connections = self.connections.lock().await;
68        connections.get(&id).cloned()
69    }
70
71    /// Get all active connections
72    pub async fn get_all_connections(&self) -> Vec<ConnectionHandle> {
73        let connections = self.connections.lock().await;
74        connections.values().cloned().collect()
75    }
76
77    /// Get the number of active connections
78    pub async fn connection_count(&self) -> usize {
79        let connections = self.connections.lock().await;
80        connections.len()
81    }
82}
83
84impl std::fmt::Debug for Server {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_struct("Server")
87            .field("config", &self.config)
88            .field("handler", &"<handler>")
89            .finish()
90    }
91}
92
93impl Server {
94    /// Create a new server with the given config and handler
95    pub fn new(config: ServerConfig, handler: BoxedHandler) -> Self {
96        let rate_limiter = if config.backpressure.enabled {
97            Some(Arc::new(RateLimitMiddleware::new(
98                crate::rate_limit::RateLimitConfig {
99                    max_requests: config.backpressure.max_requests_per_minute,
100                    window: Duration::from_secs(60),
101                    max_connections: config.max_connections / 10, // 10% of max connections per IP
102                    connection_timeout: config.idle_timeout,
103                },
104            )))
105        } else {
106            None
107        };
108
109        Self {
110            config,
111            handler,
112            rate_limiter,
113        }
114    }
115
116    /// Create a server builder
117    pub fn builder() -> ServerBuilder {
118        ServerBuilder::new()
119    }
120
121    /// Start serving connections
122    pub async fn serve(self) -> Result<()> {
123        let connection_manager = Arc::new(ConnectionManager::new());
124        self.serve_with_connection_manager(connection_manager).await
125    }
126
127    /// Start serving with graceful shutdown
128    pub async fn serve_with_graceful_shutdown<F>(self, shutdown_signal: F) -> Result<()>
129    where
130        F: std::future::Future<Output = ()> + Send + 'static,
131    {
132        let connection_manager = Arc::new(ConnectionManager::new());
133        let shutdown_signal = Box::pin(shutdown_signal);
134        self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_signal)
135            .await
136    }
137
138    /// Internal serve method
139    async fn serve_with_connection_manager(
140        self,
141        connection_manager: Arc<ConnectionManager>,
142    ) -> Result<()> {
143        let (_shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
144        let shutdown_rx = Box::pin(async move {
145            let _ = shutdown_rx.recv().await;
146        });
147
148        self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_rx)
149            .await
150    }
151
152    /// Internal serve method with shutdown signal
153    async fn serve_with_connection_manager_and_shutdown<F>(
154        self,
155        _connection_manager: Arc<ConnectionManager>,
156        _shutdown_signal: F,
157    ) -> Result<()>
158    where
159        F: std::future::Future<Output = ()> + Send + Unpin + 'static,
160    {
161        // Create transport based on configuration
162        #[cfg(feature = "tcp-transport")]
163        {
164            if self.config.transport_type == crate::config::TransportType::Tcp {
165                let transport =
166                    crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
167                return self
168                    .serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal)
169                    .await;
170            }
171        }
172
173        #[cfg(feature = "tls-transport")]
174        {
175            if self.config.transport_type == crate::config::TransportType::Tls {
176                let _tls_config = self.config.tls.as_ref().ok_or_else(|| {
177                    Error::Other("TLS configuration required for TLS transport".to_string())
178                })?;
179                // Note: TLS transport is disabled in this release
180                return Err(Error::Other("TLS transport is not available in this release. Please enable the 'tls-transport' feature and implement proper TLS configuration.".to_string()));
181            }
182        }
183
184        // Fallback to TCP if no specific transport is configured
185        if self.config.transport_type == crate::config::TransportType::Tcp {
186            #[cfg(feature = "tcp-transport")]
187            {
188                let transport =
189                    crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
190                return self
191                    .serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal)
192                    .await;
193            }
194        }
195        Err(Error::Config(ConfigError::Validation(
196            "No transport available".to_string(),
197        )))
198    }
199
200    /// Serve with TCP transport
201    #[cfg(feature = "tcp-transport")]
202    async fn serve_with_tcp_transport<F>(
203        self,
204        transport: crate::tcp_transport::TcpTransport,
205        connection_manager: Arc<ConnectionManager>,
206        _shutdown_signal: F,
207    ) -> Result<()>
208    where
209        F: std::future::Future<Output = ()> + Send + 'static,
210    {
211        // Spawn connection handling task
212        let handler = self.handler;
213        let config = self.config.clone();
214        let manager = connection_manager.clone();
215        let rate_limiter = self.rate_limiter.clone();
216
217        let server_task = tokio::spawn(async move {
218            let mut connection_counter = 0u64;
219
220            loop {
221                // Check for shutdown
222                tokio::select! {
223                    result = transport.accept() => {
224                        match result {
225                            Ok(mut stream) => {
226                                // Get remote address for rate limiting
227                                let remote_addr = match stream.remote_addr() {
228                                    Ok(addr) => addr.ip(),
229                                    Err(e) => {
230                                        crate::log_error!("Failed to get remote address: {:?}", e);
231                                        let _ = stream.close().await;
232                                        continue;
233                                    }
234                                };
235
236                                // Check rate limiting if enabled
237                                if let Some(ref rate_limiter) = rate_limiter {
238                                    if !rate_limiter.check_connection(remote_addr).await.unwrap_or(true) {
239                                        crate::log_warn!("Rate limit exceeded for IP: {}", remote_addr);
240                                        let _ = stream.close().await;
241                                        continue;
242                                    }
243                                }
244
245                                // Check connection limit
246                                if manager.connection_count().await >= config.max_connections {
247                                    crate::log_warn!("Connection limit reached, rejecting connection from {}", remote_addr);
248                                    // Close the stream gracefully
249                                    let _ = stream.close().await;
250                                    continue;
251                                }
252
253                                connection_counter += 1;
254                                let manager = manager.clone();
255                                let handler = handler.clone();
256                                let config = config.clone();
257                                let rate_limiter = rate_limiter.clone();
258
259                                // Spawn connection handler
260                                tokio::spawn(async move {
261                                    if let Err(e) = Self::handle_connection(
262                                        stream,
263                                        handler,
264                                        config,
265                                        manager,
266                                        rate_limiter,
267                                    ).await {
268                                        crate::log_error!("Connection handling error: {:?}", e);
269                                    }
270                                });
271                            }
272                            Err(e) => {
273                                crate::log_error!("Accept error: {:?}", e);
274                                // Continue accepting other connections
275                            }
276                        }
277                    }
278                    _ = tokio::signal::ctrl_c() => {
279                        break;
280                    }
281                }
282            }
283        });
284
285        // Wait for server task completion
286        match server_task.await {
287            Ok(()) => Ok(()),
288            Err(e) => Err(Error::Other(format!("Server task panicked: {}", e))),
289        }
290    }
291
292    /// Serve with TLS transport
293    #[cfg(feature = "tls-transport")]
294    async fn serve_with_tls_transport<F>(
295        self,
296        _transport: crate::tls_transport::TlsTransport,
297        _connection_manager: Arc<ConnectionManager>,
298        _shutdown_signal: F,
299    ) -> Result<()>
300    where
301        F: std::future::Future<Output = ()> + Send + Unpin + 'static,
302    {
303        Err(Error::Other(
304            "TLS transport is not available in this release".to_string(),
305        ))
306    }
307
308    /// Handle a single TLS connection
309    #[cfg(feature = "tls-transport")]
310    async fn handle_tls_connection(
311        _stream: crate::tls_transport::TlsStreamWrapper,
312        _handler: BoxedHandler,
313        _config: ServerConfig,
314        _connection_manager: Arc<ConnectionManager>,
315        _rate_limiter: Option<Arc<RateLimitMiddleware>>,
316    ) -> Result<()> {
317        Err(Error::Other(
318            "TLS connection handling is not available in this release".to_string(),
319        ))
320    }
321
322    /// Handle a single connection
323    async fn handle_connection(
324        mut stream: crate::tcp_transport::TcpStream,
325        handler: BoxedHandler,
326        config: ServerConfig,
327        connection_manager: Arc<ConnectionManager>,
328        rate_limiter: Option<Arc<RateLimitMiddleware>>,
329    ) -> Result<()> {
330        // Perform WebSocket handshake
331        let (remote_addr, local_addr) = Self::perform_handshake(&mut stream, &config).await?;
332
333        // Convert to boxed transport stream
334        let boxed_stream: Box<dyn TransportStream> = Box::new(stream);
335
336        // Create connection with stream
337        let connection = Connection::with_stream(remote_addr, local_addr, boxed_stream);
338
339        // Add to connection manager
340        let connection_id = connection_manager.add_connection(connection).await;
341
342        // Get connection handle
343        let connection_handle = connection_manager
344            .get_connection(connection_id)
345            .await
346            .ok_or_else(|| Error::Other("Failed to get connection handle".to_string()))?;
347
348        // Call handler
349        if let Err(e) = handler.handle(connection_handle).await {
350            crate::log_error!("Handler error: {:?}", e);
351        }
352
353        // Remove connection from manager
354        connection_manager.remove_connection(connection_id).await;
355
356        // Clean up rate limiting
357        if let Some(ref rate_limiter) = rate_limiter {
358            rate_limiter.connection_closed(remote_addr.ip()).await;
359        }
360
361        Ok(())
362    }
363
364    /// Perform WebSocket handshake over TLS
365    #[cfg(feature = "tls-transport")]
366    async fn perform_tls_handshake(
367        stream: &mut crate::tls_transport::TlsStreamWrapper,
368        config: &ServerConfig,
369    ) -> Result<(SocketAddr, SocketAddr)> {
370        // Read HTTP request over TLS
371        let request_data =
372            Self::read_tls_handshake_request(stream, config.handshake_timeout).await?;
373        let request_str = String::from_utf8_lossy(&request_data);
374
375        // Parse handshake request
376        let request = parse_client_handshake(&request_str)?;
377
378        // Create handshake config
379        let handshake_config = HandshakeConfig {
380            protocols: config.supported_protocols.clone(),
381            extensions: config.supported_extensions.clone(),
382            origin: config.allowed_origin.clone(),
383            host: None,
384            extra_headers: config.extra_headers.clone(),
385        };
386
387        // Validate request
388        validate_client_handshake(&request, &handshake_config)?;
389
390        // Create response
391        let response = create_server_handshake(&request, &handshake_config)?;
392        let response_str = response_to_string(&response);
393
394        // Send response over TLS
395        stream.write_all(response_str.as_bytes()).await?;
396        stream.flush().await?;
397
398        // Get addresses
399        let remote_addr = stream.remote_addr()?;
400        let local_addr = stream.local_addr()?;
401
402        Ok((remote_addr, local_addr))
403    }
404
405    /// Read handshake request from TLS stream
406    #[cfg(feature = "tls-transport")]
407    async fn read_tls_handshake_request(
408        stream: &mut crate::tls_transport::TlsStreamWrapper,
409        timeout_duration: Duration,
410    ) -> Result<Vec<u8>> {
411        let mut buffer = Vec::new();
412        let mut temp_buffer = [0u8; 1024];
413
414        let read_result = timeout(timeout_duration, async {
415            loop {
416                let n = stream.read(&mut temp_buffer).await?;
417                if n == 0 {
418                    break;
419                }
420
421                buffer.extend_from_slice(&temp_buffer[..n]);
422
423                // Check for end of headers (double CRLF)
424                if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
425                    break;
426                }
427
428                // Prevent reading too much
429                if buffer.len() > 8192 {
430                    return Err(Error::Other("TLS handshake request too large".to_string()));
431                }
432            }
433
434            Ok::<(), Error>(())
435        })
436        .await;
437
438        match read_result {
439            Ok(result) => {
440                result?;
441                Ok(buffer)
442            }
443            Err(_) => Err(Error::Other("TLS handshake timeout".to_string())),
444        }
445    }
446
447    /// Perform WebSocket handshake
448    async fn perform_handshake(
449        stream: &mut crate::tcp_transport::TcpStream,
450        config: &ServerConfig,
451    ) -> Result<(SocketAddr, SocketAddr)> {
452        // Read HTTP request
453        let request_data = Self::read_handshake_request(stream, config.handshake_timeout).await?;
454        let request_str = String::from_utf8_lossy(&request_data);
455
456        // Parse handshake request
457        let request = parse_client_handshake(&request_str)?;
458
459        // Create handshake config
460        let handshake_config = HandshakeConfig {
461            protocols: config.supported_protocols.clone(),
462            extensions: config.supported_extensions.clone(),
463            origin: config.allowed_origin.clone(),
464            host: None,
465            extra_headers: config.extra_headers.clone(),
466        };
467
468        // Validate request
469        validate_client_handshake(&request, &handshake_config)?;
470
471        // Create response
472        let response = create_server_handshake(&request, &handshake_config)?;
473        let response_str = response_to_string(&response);
474
475        // Send response
476        stream.write_all(response_str.as_bytes()).await?;
477        stream.flush().await?;
478
479        // Get addresses
480        let remote_addr = stream.remote_addr()?;
481        let local_addr = stream.local_addr()?;
482
483        Ok((remote_addr, local_addr))
484    }
485
486    /// Read handshake request from stream
487    async fn read_handshake_request(
488        stream: &mut crate::tcp_transport::TcpStream,
489        timeout_duration: Duration,
490    ) -> Result<Vec<u8>> {
491        let mut buffer = Vec::new();
492        let mut temp_buffer = [0u8; 1024];
493
494        let read_result = timeout(timeout_duration, async {
495            loop {
496                let n = stream.read(&mut temp_buffer).await?;
497                if n == 0 {
498                    break;
499                }
500
501                buffer.extend_from_slice(&temp_buffer[..n]);
502
503                // Check for end of headers (double CRLF)
504                if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
505                    break;
506                }
507
508                // Prevent reading too much
509                if buffer.len() > 8192 {
510                    return Err(Error::Other("Handshake request too large".to_string()));
511                }
512            }
513
514            Ok::<(), Error>(())
515        })
516        .await;
517
518        match read_result {
519            Ok(result) => {
520                result?;
521                Ok(buffer)
522            }
523            Err(_) => Err(Error::Other("Handshake timeout".to_string())),
524        }
525    }
526
527    /// Perform graceful shutdown
528    async fn graceful_shutdown(&self, connection_manager: Arc<ConnectionManager>) -> Result<()> {
529        // Get all connections
530        let connections = connection_manager.get_all_connections().await;
531
532        // Send close frames to all connections
533        for handle in connections {
534            if let Ok(mut connection) = handle.try_lock().await {
535                let _ = connection.close(Some(1000), Some("Server shutdown")).await;
536            }
537        }
538
539        Ok(())
540    }
541}
542
543/// Server builder
544#[derive(Debug, Clone)]
545pub struct ServerBuilder {
546    config: ServerConfig,
547}
548
549impl ServerBuilder {
550    /// Create a new server builder
551    pub fn new() -> Self {
552        Self {
553            config: ServerConfig::default(),
554        }
555    }
556
557    /// Bind to the given address
558    pub fn bind<A: std::net::ToSocketAddrs>(mut self, addr: A) -> Result<Self> {
559        self.config.bind_address = addr.to_socket_addrs()?.next().ok_or_else(|| {
560            Error::Config(ConfigError::Validation("Invalid bind address".to_string()))
561        })?;
562        Ok(self)
563    }
564
565    /// Set maximum connections
566    pub fn max_connections(mut self, max: usize) -> Self {
567        self.config.max_connections = max;
568        self
569    }
570
571    /// Set maximum frame size
572    pub fn max_frame_size(mut self, size: usize) -> Self {
573        self.config.max_frame_size = size;
574        self
575    }
576
577    /// Set maximum message size
578    pub fn max_message_size(mut self, size: usize) -> Self {
579        self.config.max_message_size = size;
580        self
581    }
582
583    /// Set handshake timeout
584    pub fn handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
585        self.config.handshake_timeout = timeout;
586        self
587    }
588
589    /// Set idle timeout
590    pub fn idle_timeout(mut self, timeout: std::time::Duration) -> Self {
591        self.config.idle_timeout = timeout;
592        self
593    }
594
595    /// Enable/disable compression
596    pub fn compression(mut self, enabled: bool) -> Self {
597        self.config.compression.enabled = enabled;
598        self
599    }
600
601    /// Set backpressure strategy
602    pub fn backpressure(mut self, strategy: crate::config::BackpressureStrategy) -> Self {
603        self.config.backpressure.strategy = strategy;
604        self
605    }
606
607    /// Build the server
608    pub fn build(self) -> Result<Server> {
609        // Validate configuration
610        self.config.validate()?;
611
612        // Create default handler
613        let handler = Box::new(crate::handler::DefaultHandler::new());
614
615        Ok(Server::new(self.config, handler))
616    }
617
618    /// Build the server with a custom handler
619    pub fn build_with_handler<H>(self, handler: H) -> Result<Server>
620    where
621        H: Handler + Send + Sync + 'static,
622    {
623        // Validate configuration
624        self.config.validate()?;
625
626        Ok(Server::new(self.config, Box::new(handler)))
627    }
628}
629
630impl Default for ServerBuilder {
631    fn default() -> Self {
632        Self::new()
633    }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder with a literal bind address and explicit limits validates and builds.
    #[test]
    fn test_server_builder() {
        let configured = ServerBuilder::new().bind("127.0.0.1:8080").unwrap();
        let configured = configured
            .max_connections(1000)
            .max_frame_size(1024 * 1024)
            .compression(true);

        let built = configured.build();
        assert!(built.is_ok());
    }
}