aerosocket_server/
config.rs

//! Server configuration
//!
//! This module defines [`ServerConfig`] together with the compression,
//! backpressure, and TLS option types it aggregates for the WebSocket server.
4
5use aerosocket_core::error::{Error, ConfigError};
6use std::time::Duration;
7
8/// Server configuration
9#[derive(Debug, Clone)]
10pub struct ServerConfig {
11    /// Bind address
12    pub bind_address: std::net::SocketAddr,
13    /// Maximum concurrent connections
14    pub max_connections: usize,
15    /// Maximum frame size in bytes
16    pub max_frame_size: usize,
17    /// Maximum message size in bytes
18    pub max_message_size: usize,
19    /// Handshake timeout
20    pub handshake_timeout: Duration,
21    /// Idle timeout
22    pub idle_timeout: Duration,
23    /// Compression configuration
24    pub compression: CompressionConfig,
25    /// Backpressure configuration
26    pub backpressure: BackpressureConfig,
27    /// TLS configuration
28    pub tls: Option<TlsConfig>,
29    /// Transport type
30    pub transport_type: TransportType,
31    /// Supported WebSocket subprotocols
32    pub supported_protocols: Vec<String>,
33    /// Supported WebSocket extensions
34    pub supported_extensions: Vec<String>,
35    /// Allowed origin (for CORS)
36    pub allowed_origin: Option<String>,
37    /// Extra headers to send in handshake response
38    pub extra_headers: std::collections::HashMap<String, String>,
39}
40
/// Transport used by the server.
///
/// The default is [`TransportType::Tcp`], the same value that
/// `ServerConfig::default()` selects. The type is `Hash` so it can be used as
/// a map or set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TransportType {
    /// TCP transport
    #[default]
    Tcp,
    /// TLS transport
    Tls,
}
49
50impl Default for ServerConfig {
51    fn default() -> Self {
52        Self {
53            bind_address: "0.0.0.0:8080".parse().unwrap(),
54            max_connections: 10_000,
55            max_frame_size: aerosocket_core::protocol::constants::DEFAULT_MAX_FRAME_SIZE,
56            max_message_size: aerosocket_core::protocol::constants::DEFAULT_MAX_MESSAGE_SIZE,
57            handshake_timeout: aerosocket_core::protocol::constants::DEFAULT_HANDSHAKE_TIMEOUT,
58            idle_timeout: aerosocket_core::protocol::constants::DEFAULT_IDLE_TIMEOUT,
59            compression: CompressionConfig::default(),
60            backpressure: BackpressureConfig::default(),
61            tls: None,
62            transport_type: TransportType::Tcp,
63            supported_protocols: vec![],
64            supported_extensions: vec![],
65            allowed_origin: None,
66            extra_headers: std::collections::HashMap::new(),
67        }
68    }
69}
70
71impl ServerConfig {
72    /// Validate the configuration
73    pub fn validate(&self) -> aerosocket_core::Result<()> {
74        if self.max_connections == 0 {
75            return Err(Error::Config(ConfigError::Validation("max_connections must be greater than 0".to_string())));
76        }
77
78        if self.max_frame_size == 0 {
79            return Err(Error::Config(ConfigError::Validation("max_frame_size must be greater than 0".to_string())));
80        }
81
82        if self.max_message_size == 0 {
83            return Err(Error::Config(ConfigError::Validation("max_message_size must be greater than 0".to_string())));
84        }
85
86        if self.max_message_size < self.max_frame_size {
87            return Err(Error::Config(ConfigError::Validation("max_message_size must be greater than or equal to max_frame_size".to_string())));
88        }
89
90        Ok(())
91    }
92}
93
/// Compression configuration.
///
/// Derives `PartialEq`/`Eq` so two configurations can be compared directly,
/// for example in tests or when detecting a configuration change.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionConfig {
    /// Enable compression
    pub enabled: bool,
    /// Compression level (0-9)
    pub level: u8,
    /// Server context takeover
    pub server_context_takeover: bool,
    /// Client context takeover
    pub client_context_takeover: bool,
    /// Server max window bits; `None` leaves the value unset.
    pub server_max_window_bits: Option<u8>,
    /// Client max window bits; `None` leaves the value unset.
    pub client_max_window_bits: Option<u8>,
}
110
111impl Default for CompressionConfig {
112    fn default() -> Self {
113        Self {
114            enabled: false,
115            level: 6,
116            server_context_takeover: true,
117            client_context_takeover: true,
118            server_max_window_bits: None,
119            client_max_window_bits: None,
120        }
121    }
122}
123
124/// Backpressure configuration
125#[derive(Debug, Clone)]
126pub struct BackpressureConfig {
127    /// Whether backpressure is enabled
128    pub enabled: bool,
129    /// Maximum requests per minute per IP
130    pub max_requests_per_minute: usize,
131    /// Backpressure strategy
132    pub strategy: BackpressureStrategy,
133    /// Buffer size in bytes
134    pub buffer_size: usize,
135    /// High water mark in bytes
136    pub high_water_mark: usize,
137    /// Low water mark in bytes
138    pub low_water_mark: usize,
139}
140
141impl Default for BackpressureConfig {
142    fn default() -> Self {
143        Self {
144            enabled: true,
145            max_requests_per_minute: 60,
146            strategy: BackpressureStrategy::Buffer,
147            buffer_size: 64 * 1024, // 64KB
148            high_water_mark: 48 * 1024, // 48KB
149            low_water_mark: 16 * 1024, // 16KB
150        }
151    }
152}
153
/// Backpressure strategy.
///
/// [`BackpressureStrategy::Buffer`] is the `Default`, matching both the
/// variant's documentation and `BackpressureConfig::default()`. The type is
/// `Hash` so it can be used as a map or set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum BackpressureStrategy {
    /// Buffer messages (default)
    #[default]
    Buffer,
    /// Drop oldest messages when buffer is full
    DropOldest,
    /// Reject new messages when buffer is full
    Reject,
    /// Apply flow control to sender
    FlowControl,
}
166
/// TLS configuration.
///
/// Holds file paths and flags only. Derives `PartialEq`/`Eq` so two
/// configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConfig {
    /// Path to certificate file
    pub cert_file: String,
    /// Path to private key file
    pub key_file: String,
    /// Certificate chain file (optional)
    pub cert_chain_file: Option<String>,
    /// Enable client authentication
    pub client_auth: bool,
    /// CA file for client authentication
    pub ca_file: Option<String>,
}
181
182impl TlsConfig {
183    /// Create a new TLS configuration
184    pub fn new(cert_file: String, key_file: String) -> Self {
185        Self {
186            cert_file,
187            key_file,
188            cert_chain_file: None,
189            client_auth: false,
190            ca_file: None,
191        }
192    }
193
194    /// Set certificate chain file
195    pub fn cert_chain_file(mut self, file: String) -> Self {
196        self.cert_chain_file = Some(file);
197        self
198    }
199
200    /// Enable client authentication
201    pub fn client_auth(mut self, enabled: bool) -> Self {
202        self.client_auth = enabled;
203        self
204    }
205
206    /// Set CA file for client authentication
207    pub fn ca_file(mut self, file: String) -> Self {
208        self.ca_file = Some(file);
209        self
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration passes validation and has the documented
    /// connection limit and port.
    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_connections, 10_000);
        assert_eq!(config.bind_address.port(), 8080);
    }

    /// Each size/connection rule is exercised against an otherwise-default
    /// configuration, so a failure points at exactly one rule.
    #[test]
    fn test_server_config_validation() {
        let no_connections = ServerConfig {
            max_connections: 0,
            ..ServerConfig::default()
        };
        assert!(no_connections.validate().is_err());

        let zero_frame = ServerConfig {
            max_frame_size: 0,
            ..ServerConfig::default()
        };
        assert!(zero_frame.validate().is_err());

        let zero_message = ServerConfig {
            max_message_size: 0,
            ..ServerConfig::default()
        };
        assert!(zero_message.validate().is_err());

        let message_smaller_than_frame = ServerConfig {
            max_frame_size: 1024,
            max_message_size: 512,
            ..ServerConfig::default()
        };
        assert!(message_smaller_than_frame.validate().is_err());

        // Boundary: a message limit equal to the frame limit is accepted.
        let message_equals_frame = ServerConfig {
            max_frame_size: 1024,
            max_message_size: 1024,
            ..ServerConfig::default()
        };
        assert!(message_equals_frame.validate().is_ok());
    }

    /// The builder methods set exactly the field they are named after.
    #[test]
    fn test_tls_config() {
        let config = TlsConfig::new("cert.pem".to_string(), "key.pem".to_string())
            .cert_chain_file("chain.pem".to_string())
            .client_auth(true)
            .ca_file("ca.pem".to_string());

        assert_eq!(config.cert_file, "cert.pem");
        assert_eq!(config.key_file, "key.pem");
        assert_eq!(config.cert_chain_file.as_deref(), Some("chain.pem"));
        assert!(config.client_auth);
        assert_eq!(config.ca_file.as_deref(), Some("ca.pem"));
    }
}