aerosocket_server/
server.rs

1//! WebSocket server implementation
2//!
3//! This module provides the main server implementation for handling WebSocket connections.
4
5use crate::{
6    config::ServerConfig,
7    connection::{Connection, ConnectionHandle},
8    handler::{BoxedHandler, Handler},
9    rate_limit::RateLimitMiddleware,
10};
11use aerosocket_core::error::ConfigError;
12use aerosocket_core::handshake::{
13    create_server_handshake, parse_client_handshake, response_to_string, validate_client_handshake,
14    HandshakeConfig,
15};
16use aerosocket_core::transport::TransportStream;
17use aerosocket_core::{Error, Result, Transport};
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::sync::Arc;
21use tokio::sync::{mpsc, Mutex};
22use tokio::time::{timeout, Duration};
23
24/// WebSocket server
25pub struct Server {
26    config: ServerConfig,
27    handler: BoxedHandler,
28    rate_limiter: Option<Arc<RateLimitMiddleware>>,
29}
30
31/// Connection manager for tracking active connections
32#[derive(Debug)]
33pub struct ConnectionManager {
34    connections: Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
35    next_id: Arc<Mutex<u64>>,
36}
37
38impl ConnectionManager {
39    /// Create a new connection manager
40    pub fn new() -> Self {
41        Self {
42            connections: Arc::new(Mutex::new(HashMap::new())),
43            next_id: Arc::new(Mutex::new(1)),
44        }
45    }
46
47    /// Add a new connection
48    pub async fn add_connection(&self, connection: Connection) -> u64 {
49        let mut next_id = self.next_id.lock().await;
50        let id = *next_id;
51        *next_id += 1;
52
53        let handle = ConnectionHandle::new(id, connection);
54        let mut connections = self.connections.lock().await;
55        connections.insert(id, handle);
56        id
57    }
58
59    /// Remove a connection
60    pub async fn remove_connection(&self, id: u64) -> Option<ConnectionHandle> {
61        let mut connections = self.connections.lock().await;
62        connections.remove(&id)
63    }
64
65    /// Get a connection by ID
66    pub async fn get_connection(&self, id: u64) -> Option<ConnectionHandle> {
67        let connections = self.connections.lock().await;
68        connections.get(&id).cloned()
69    }
70
71    /// Get all active connections
72    pub async fn get_all_connections(&self) -> Vec<ConnectionHandle> {
73        let connections = self.connections.lock().await;
74        connections.values().cloned().collect()
75    }
76
77    /// Get the number of active connections
78    pub async fn connection_count(&self) -> usize {
79        let connections = self.connections.lock().await;
80        connections.len()
81    }
82}
83
84impl std::fmt::Debug for Server {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_struct("Server")
87            .field("config", &self.config)
88            .field("handler", &"<handler>")
89            .finish()
90    }
91}
92
93impl Server {
94    /// Create a new server with the given config and handler
95    pub fn new(config: ServerConfig, handler: BoxedHandler) -> Self {
96        let rate_limiter = if config.backpressure.enabled {
97            Some(Arc::new(RateLimitMiddleware::new(
98                crate::rate_limit::RateLimitConfig {
99                    max_requests: config.backpressure.max_requests_per_minute,
100                    window: Duration::from_secs(60),
101                    max_connections: config.max_connections / 10, // 10% of max connections per IP
102                    connection_timeout: config.idle_timeout,
103                },
104            )))
105        } else {
106            None
107        };
108
109        Self {
110            config,
111            handler,
112            rate_limiter,
113        }
114    }
115
116    /// Create a server builder
117    pub fn builder() -> ServerBuilder {
118        ServerBuilder::new()
119    }
120
121    /// Start serving connections
122    pub async fn serve(self) -> Result<()> {
123        let connection_manager = Arc::new(ConnectionManager::new());
124        self.serve_with_connection_manager(connection_manager).await
125    }
126
127    /// Start serving with graceful shutdown
128    pub async fn serve_with_graceful_shutdown<F>(self, shutdown_signal: F) -> Result<()>
129    where
130        F: std::future::Future<Output = ()> + Send + 'static,
131    {
132        let connection_manager = Arc::new(ConnectionManager::new());
133        let shutdown_signal = Box::pin(shutdown_signal);
134        self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_signal)
135            .await
136    }
137
138    /// Internal serve method
139    async fn serve_with_connection_manager(
140        self,
141        connection_manager: Arc<ConnectionManager>,
142    ) -> Result<()> {
143        let (_shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
144        let shutdown_rx = Box::pin(async move {
145            let _ = shutdown_rx.recv().await;
146        });
147
148        self.serve_with_connection_manager_and_shutdown(connection_manager, shutdown_rx)
149            .await
150    }
151
152    /// Internal serve method with shutdown signal
153    async fn serve_with_connection_manager_and_shutdown<F>(
154        self,
155        _connection_manager: Arc<ConnectionManager>,
156        _shutdown_signal: F,
157    ) -> Result<()>
158    where
159        F: std::future::Future<Output = ()> + Send + Unpin + 'static,
160    {
161        // Create transport based on configuration
162        #[cfg(feature = "tcp-transport")]
163        {
164            if self.config.transport_type == crate::config::TransportType::Tcp {
165                let transport =
166                    crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
167                return self
168                    .serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal)
169                    .await;
170            }
171        }
172
173        #[cfg(feature = "tls-transport")]
174        {
175            if self.config.transport_type == crate::config::TransportType::Tls {
176                let _tls_config = self.config.tls.as_ref().ok_or_else(|| {
177                    Error::Other("TLS configuration required for TLS transport".to_string())
178                })?;
179                // Note: TLS transport is disabled in this release
180                return Err(Error::Other("TLS transport is not available in this release. Please enable the 'tls-transport' feature and implement proper TLS configuration.".to_string()));
181            }
182        }
183
184        // Fallback to TCP if no specific transport is configured
185        if self.config.transport_type == crate::config::TransportType::Tcp {
186            #[cfg(feature = "tcp-transport")]
187            {
188                let transport =
189                    crate::tcp_transport::TcpTransport::bind(self.config.bind_address).await?;
190                return self
191                    .serve_with_tcp_transport(transport, _connection_manager, _shutdown_signal)
192                    .await;
193            }
194        }
195        Err(Error::Config(ConfigError::Validation(
196            "No transport available".to_string(),
197        )))
198    }
199
200    /// Serve with TCP transport
201    #[cfg(feature = "tcp-transport")]
202    async fn serve_with_tcp_transport<F>(
203        self,
204        transport: crate::tcp_transport::TcpTransport,
205        connection_manager: Arc<ConnectionManager>,
206        _shutdown_signal: F,
207    ) -> Result<()>
208    where
209        F: std::future::Future<Output = ()> + Send + 'static,
210    {
211        // Spawn connection handling task
212        let handler = self.handler;
213        let config = self.config.clone();
214        let manager = connection_manager.clone();
215        let rate_limiter = self.rate_limiter.clone();
216
217        let server_task = tokio::spawn(async move {
218            let mut connection_counter = 0u64;
219
220            loop {
221                // Check for shutdown
222                tokio::select! {
223                    result = transport.accept() => {
224                        match result {
225                            Ok(mut stream) => {
226                                // Get remote address for rate limiting
227                                let remote_addr = match stream.remote_addr() {
228                                    Ok(addr) => addr.ip(),
229                                    Err(e) => {
230                                        crate::log_error!("Failed to get remote address: {:?}", e);
231                                        let _ = stream.close().await;
232                                        continue;
233                                    }
234                                };
235
236                                // Check rate limiting if enabled
237                                if let Some(ref rate_limiter) = rate_limiter {
238                                    if !rate_limiter.check_connection(remote_addr).await.unwrap_or(true) {
239                                        crate::log_warn!("Rate limit exceeded for IP: {}", remote_addr);
240                                        let _ = stream.close().await;
241                                        continue;
242                                    }
243                                }
244
245                                // Check connection limit
246                                if manager.connection_count().await >= config.max_connections {
247                                    crate::log_warn!("Connection limit reached, rejecting connection from {}", remote_addr);
248                                    // Close the stream gracefully
249                                    let _ = stream.close().await;
250                                    continue;
251                                }
252
253                                connection_counter += 1;
254                                crate::log_debug!("Accepted connection #{} from {}", connection_counter, remote_addr);
255                                let manager = manager.clone();
256                                let handler = handler.clone();
257                                let config = config.clone();
258                                let rate_limiter = rate_limiter.clone();
259
260                                // Spawn connection handler
261                                tokio::spawn(async move {
262                                    if let Err(e) = Self::handle_connection(
263                                        stream,
264                                        handler,
265                                        config,
266                                        manager,
267                                        rate_limiter,
268                                    ).await {
269                                        crate::log_error!("Connection handling error: {:?}", e);
270                                    }
271                                });
272                            }
273                            Err(e) => {
274                                crate::log_error!("Accept error: {:?}", e);
275                                // Continue accepting other connections
276                            }
277                        }
278                    }
279                    _ = tokio::signal::ctrl_c() => {
280                        break;
281                    }
282                }
283            }
284        });
285
286        // Wait for server task completion
287        match server_task.await {
288            Ok(()) => Ok(()),
289            Err(e) => Err(Error::Other(format!("Server task panicked: {}", e))),
290        }
291    }
292
293    /// Serve with TLS transport
294    #[cfg(feature = "tls-transport")]
295    async fn serve_with_tls_transport<F>(
296        self,
297        _transport: crate::tls_transport::TlsTransport,
298        _connection_manager: Arc<ConnectionManager>,
299        _shutdown_signal: F,
300    ) -> Result<()>
301    where
302        F: std::future::Future<Output = ()> + Send + Unpin + 'static,
303    {
304        Err(Error::Other(
305            "TLS transport is not available in this release".to_string(),
306        ))
307    }
308
309    /// Handle a single TLS connection
310    #[cfg(feature = "tls-transport")]
311    async fn handle_tls_connection(
312        _stream: crate::tls_transport::TlsStreamWrapper,
313        _handler: BoxedHandler,
314        _config: ServerConfig,
315        _connection_manager: Arc<ConnectionManager>,
316        _rate_limiter: Option<Arc<RateLimitMiddleware>>,
317    ) -> Result<()> {
318        Err(Error::Other(
319            "TLS connection handling is not available in this release".to_string(),
320        ))
321    }
322
323    /// Handle a single connection
324    async fn handle_connection(
325        mut stream: crate::tcp_transport::TcpStream,
326        handler: BoxedHandler,
327        config: ServerConfig,
328        connection_manager: Arc<ConnectionManager>,
329        rate_limiter: Option<Arc<RateLimitMiddleware>>,
330    ) -> Result<()> {
331        // Perform WebSocket handshake
332        let (remote_addr, local_addr) = Self::perform_handshake(&mut stream, &config).await?;
333
334        // Convert to boxed transport stream
335        let boxed_stream: Box<dyn TransportStream> = Box::new(stream);
336
337        // Create connection with stream
338        let connection = Connection::with_stream(remote_addr, local_addr, boxed_stream);
339
340        // Add to connection manager
341        let connection_id = connection_manager.add_connection(connection).await;
342
343        // Get connection handle
344        let connection_handle = connection_manager
345            .get_connection(connection_id)
346            .await
347            .ok_or_else(|| Error::Other("Failed to get connection handle".to_string()))?;
348
349        // Call handler
350        if let Err(e) = handler.handle(connection_handle).await {
351            crate::log_error!("Handler error: {:?}", e);
352        }
353
354        // Remove connection from manager
355        connection_manager.remove_connection(connection_id).await;
356
357        // Clean up rate limiting
358        if let Some(ref rate_limiter) = rate_limiter {
359            rate_limiter.connection_closed(remote_addr.ip()).await;
360        }
361
362        Ok(())
363    }
364
365    /// Perform WebSocket handshake over TLS
366    #[cfg(feature = "tls-transport")]
367    async fn perform_tls_handshake(
368        stream: &mut crate::tls_transport::TlsStreamWrapper,
369        config: &ServerConfig,
370    ) -> Result<(SocketAddr, SocketAddr)> {
371        // Read HTTP request over TLS
372        let request_data =
373            Self::read_tls_handshake_request(stream, config.handshake_timeout).await?;
374        let request_str = String::from_utf8_lossy(&request_data);
375
376        // Parse handshake request
377        let request = parse_client_handshake(&request_str)?;
378
379        // Create handshake config
380        let handshake_config = HandshakeConfig {
381            protocols: config.supported_protocols.clone(),
382            extensions: config.supported_extensions.clone(),
383            origin: config.allowed_origin.clone(),
384            host: None,
385            extra_headers: config.extra_headers.clone(),
386        };
387
388        // Validate request
389        validate_client_handshake(&request, &handshake_config)?;
390
391        // Create response
392        let response = create_server_handshake(&request, &handshake_config)?;
393        let response_str = response_to_string(&response);
394
395        // Send response over TLS
396        stream.write_all(response_str.as_bytes()).await?;
397        stream.flush().await?;
398
399        // Get addresses
400        let remote_addr = stream.remote_addr()?;
401        let local_addr = stream.local_addr()?;
402
403        Ok((remote_addr, local_addr))
404    }
405
406    /// Read handshake request from TLS stream
407    #[cfg(feature = "tls-transport")]
408    async fn read_tls_handshake_request(
409        stream: &mut crate::tls_transport::TlsStreamWrapper,
410        timeout_duration: Duration,
411    ) -> Result<Vec<u8>> {
412        let mut buffer = Vec::new();
413        let mut temp_buffer = [0u8; 1024];
414
415        let read_result = timeout(timeout_duration, async {
416            loop {
417                let n = stream.read(&mut temp_buffer).await?;
418                if n == 0 {
419                    break;
420                }
421
422                buffer.extend_from_slice(&temp_buffer[..n]);
423
424                // Check for end of headers (double CRLF)
425                if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
426                    break;
427                }
428
429                // Prevent reading too much
430                if buffer.len() > 8192 {
431                    return Err(Error::Other("TLS handshake request too large".to_string()));
432                }
433            }
434
435            Ok::<(), Error>(())
436        })
437        .await;
438
439        match read_result {
440            Ok(result) => {
441                result?;
442                Ok(buffer)
443            }
444            Err(_) => Err(Error::Other("TLS handshake timeout".to_string())),
445        }
446    }
447
448    /// Perform WebSocket handshake
449    async fn perform_handshake(
450        stream: &mut crate::tcp_transport::TcpStream,
451        config: &ServerConfig,
452    ) -> Result<(SocketAddr, SocketAddr)> {
453        // Read HTTP request
454        let request_data = Self::read_handshake_request(stream, config.handshake_timeout).await?;
455        let request_str = String::from_utf8_lossy(&request_data);
456
457        // Parse handshake request
458        let request = parse_client_handshake(&request_str)?;
459
460        // Create handshake config
461        let handshake_config = HandshakeConfig {
462            protocols: config.supported_protocols.clone(),
463            extensions: config.supported_extensions.clone(),
464            origin: config.allowed_origin.clone(),
465            host: None,
466            extra_headers: config.extra_headers.clone(),
467        };
468
469        // Validate request
470        validate_client_handshake(&request, &handshake_config)?;
471
472        // Create response
473        let response = create_server_handshake(&request, &handshake_config)?;
474        let response_str = response_to_string(&response);
475
476        // Send response
477        stream.write_all(response_str.as_bytes()).await?;
478        stream.flush().await?;
479
480        // Get addresses
481        let remote_addr = stream.remote_addr()?;
482        let local_addr = stream.local_addr()?;
483
484        Ok((remote_addr, local_addr))
485    }
486
487    /// Read handshake request from stream
488    async fn read_handshake_request(
489        stream: &mut crate::tcp_transport::TcpStream,
490        timeout_duration: Duration,
491    ) -> Result<Vec<u8>> {
492        let mut buffer = Vec::new();
493        let mut temp_buffer = [0u8; 1024];
494
495        let read_result = timeout(timeout_duration, async {
496            loop {
497                let n = stream.read(&mut temp_buffer).await?;
498                if n == 0 {
499                    break;
500                }
501
502                buffer.extend_from_slice(&temp_buffer[..n]);
503
504                // Check for end of headers (double CRLF)
505                if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
506                    break;
507                }
508
509                // Prevent reading too much
510                if buffer.len() > 8192 {
511                    return Err(Error::Other("Handshake request too large".to_string()));
512                }
513            }
514
515            Ok::<(), Error>(())
516        })
517        .await;
518
519        match read_result {
520            Ok(result) => {
521                result?;
522                Ok(buffer)
523            }
524            Err(_) => Err(Error::Other("Handshake timeout".to_string())),
525        }
526    }
527
528    /// Perform graceful shutdown
529    async fn graceful_shutdown(&self, connection_manager: Arc<ConnectionManager>) -> Result<()> {
530        // Get all connections
531        let connections = connection_manager.get_all_connections().await;
532
533        // Send close frames to all connections
534        for handle in connections {
535            if let Ok(mut connection) = handle.try_lock().await {
536                let _ = connection.close(Some(1000), Some("Server shutdown")).await;
537            }
538        }
539
540        Ok(())
541    }
542}
543
544/// Server builder
545#[derive(Debug, Clone)]
546pub struct ServerBuilder {
547    config: ServerConfig,
548}
549
550impl ServerBuilder {
551    /// Create a new server builder
552    pub fn new() -> Self {
553        Self {
554            config: ServerConfig::default(),
555        }
556    }
557
558    /// Bind to the given address
559    pub fn bind<A: std::net::ToSocketAddrs>(mut self, addr: A) -> Result<Self> {
560        self.config.bind_address = addr.to_socket_addrs()?.next().ok_or_else(|| {
561            Error::Config(ConfigError::Validation("Invalid bind address".to_string()))
562        })?;
563        Ok(self)
564    }
565
566    /// Set maximum connections
567    pub fn max_connections(mut self, max: usize) -> Self {
568        self.config.max_connections = max;
569        self
570    }
571
572    /// Set maximum frame size
573    pub fn max_frame_size(mut self, size: usize) -> Self {
574        self.config.max_frame_size = size;
575        self
576    }
577
578    /// Set maximum message size
579    pub fn max_message_size(mut self, size: usize) -> Self {
580        self.config.max_message_size = size;
581        self
582    }
583
584    /// Set handshake timeout
585    pub fn handshake_timeout(mut self, timeout: std::time::Duration) -> Self {
586        self.config.handshake_timeout = timeout;
587        self
588    }
589
590    /// Set idle timeout
591    pub fn idle_timeout(mut self, timeout: std::time::Duration) -> Self {
592        self.config.idle_timeout = timeout;
593        self
594    }
595
596    /// Enable/disable compression
597    pub fn compression(mut self, enabled: bool) -> Self {
598        self.config.compression.enabled = enabled;
599        self
600    }
601
602    /// Set backpressure strategy
603    pub fn backpressure(mut self, strategy: crate::config::BackpressureStrategy) -> Self {
604        self.config.backpressure.strategy = strategy;
605        self
606    }
607
608    /// Build the server
609    pub fn build(self) -> Result<Server> {
610        // Validate configuration
611        self.config.validate()?;
612
613        // Create default handler
614        let handler = Box::new(crate::handler::DefaultHandler::new());
615
616        Ok(Server::new(self.config, handler))
617    }
618
619    /// Build the server with a custom handler
620    pub fn build_with_handler<H>(self, handler: H) -> Result<Server>
621    where
622        H: Handler + Send + Sync + 'static,
623    {
624        // Validate configuration
625        self.config.validate()?;
626
627        Ok(Server::new(self.config, Box::new(handler)))
628    }
629}
630
631impl Default for ServerBuilder {
632    fn default() -> Self {
633        Self::new()
634    }
635}
636
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder configured with valid settings produces a server.
    #[test]
    fn test_server_builder() {
        let builder = ServerBuilder::new()
            .bind("127.0.0.1:8080")
            .unwrap()
            .max_connections(1000)
            .max_frame_size(1024 * 1024)
            .compression(true);

        assert!(builder.build().is_ok());
    }

    /// `bind` reports an error instead of panicking when the address cannot be resolved.
    #[test]
    fn test_bind_rejects_unparseable_address() {
        // No `host:port` separator, so resolution fails locally without a DNS lookup.
        assert!(ServerBuilder::new().bind("not an address").is_err());
    }
}