aerosocket_server/
config.rs

//! Server configuration
//!
//! This module provides the configuration types for the WebSocket server:
//! connection and size limits, timeouts, compression, backpressure, and TLS settings.

5use aerosocket_core::error::{ConfigError, Error};
6use std::time::Duration;
7
8/// Server configuration
9#[derive(Debug, Clone)]
10pub struct ServerConfig {
11    /// Bind address
12    pub bind_address: std::net::SocketAddr,
13    /// Maximum concurrent connections
14    pub max_connections: usize,
15    /// Maximum frame size in bytes
16    pub max_frame_size: usize,
17    /// Maximum message size in bytes
18    pub max_message_size: usize,
19    /// Handshake timeout
20    pub handshake_timeout: Duration,
21    /// Idle timeout
22    pub idle_timeout: Duration,
23    /// Compression configuration
24    pub compression: CompressionConfig,
25    /// Backpressure configuration
26    pub backpressure: BackpressureConfig,
27    /// TLS configuration
28    pub tls: Option<TlsConfig>,
29    /// Transport type
30    pub transport_type: TransportType,
31    /// Supported WebSocket subprotocols
32    pub supported_protocols: Vec<String>,
33    /// Supported WebSocket extensions
34    pub supported_extensions: Vec<String>,
35    /// Allowed origin (for CORS)
36    pub allowed_origin: Option<String>,
37    /// Extra headers to send in handshake response
38    pub extra_headers: std::collections::HashMap<String, String>,
39}
40
41/// Transport type
42#[derive(Debug, Clone, Copy, PartialEq, Eq)]
43pub enum TransportType {
44    /// TCP transport
45    Tcp,
46    /// TLS transport
47    Tls,
48}
49
50impl Default for ServerConfig {
51    fn default() -> Self {
52        Self {
53            bind_address: "0.0.0.0:8080".parse().unwrap(),
54            max_connections: 10_000,
55            max_frame_size: aerosocket_core::protocol::constants::DEFAULT_MAX_FRAME_SIZE,
56            max_message_size: aerosocket_core::protocol::constants::DEFAULT_MAX_MESSAGE_SIZE,
57            handshake_timeout: aerosocket_core::protocol::constants::DEFAULT_HANDSHAKE_TIMEOUT,
58            idle_timeout: aerosocket_core::protocol::constants::DEFAULT_IDLE_TIMEOUT,
59            compression: CompressionConfig::default(),
60            backpressure: BackpressureConfig::default(),
61            tls: None,
62            transport_type: TransportType::Tcp,
63            supported_protocols: vec![],
64            supported_extensions: vec![],
65            allowed_origin: None,
66            extra_headers: std::collections::HashMap::new(),
67        }
68    }
69}
70
71impl ServerConfig {
72    /// Validate the configuration
73    pub fn validate(&self) -> aerosocket_core::Result<()> {
74        if self.max_connections == 0 {
75            return Err(Error::Config(ConfigError::Validation(
76                "max_connections must be greater than 0".to_string(),
77            )));
78        }
79
80        if self.max_frame_size == 0 {
81            return Err(Error::Config(ConfigError::Validation(
82                "max_frame_size must be greater than 0".to_string(),
83            )));
84        }
85
86        if self.max_message_size == 0 {
87            return Err(Error::Config(ConfigError::Validation(
88                "max_message_size must be greater than 0".to_string(),
89            )));
90        }
91
92        if self.max_message_size < self.max_frame_size {
93            return Err(Error::Config(ConfigError::Validation(
94                "max_message_size must be greater than or equal to max_frame_size".to_string(),
95            )));
96        }
97
98        Ok(())
99    }
100}
101
/// Compression settings for the server.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Whether compression is enabled.
    pub enabled: bool,
    /// Compression level, documented range 0-9.
    pub level: u8,
    /// Server context takeover.
    pub server_context_takeover: bool,
    /// Client context takeover.
    pub client_context_takeover: bool,
    /// Server max window bits; `None` leaves it unset.
    pub server_max_window_bits: Option<u8>,
    /// Client max window bits; `None` leaves it unset.
    pub client_max_window_bits: Option<u8>,
}

impl Default for CompressionConfig {
    /// Compression off at level 6, context takeover allowed on both sides, and
    /// both window-bits settings left unset (`None`).
    fn default() -> Self {
        let takeover = true;

        CompressionConfig {
            enabled: false,
            level: 6,
            server_context_takeover: takeover,
            client_context_takeover: takeover,
            server_max_window_bits: None,
            client_max_window_bits: None,
        }
    }
}
131
/// Backpressure settings for the server.
#[derive(Debug, Clone)]
pub struct BackpressureConfig {
    /// Whether backpressure is enabled.
    pub enabled: bool,
    /// Maximum requests per minute per IP.
    pub max_requests_per_minute: usize,
    /// Strategy to apply; see [`BackpressureStrategy`].
    pub strategy: BackpressureStrategy,
    /// Buffer size, in bytes.
    pub buffer_size: usize,
    /// High water mark, in bytes.
    pub high_water_mark: usize,
    /// Low water mark, in bytes.
    pub low_water_mark: usize,
}

impl Default for BackpressureConfig {
    /// Backpressure on, 60 requests per minute per IP,
    /// [`BackpressureStrategy::Buffer`], and a 64 KiB buffer with the high water
    /// mark at 48 KiB and the low water mark at 16 KiB.
    fn default() -> Self {
        const KIB: usize = 1024;

        BackpressureConfig {
            enabled: true,
            max_requests_per_minute: 60,
            strategy: BackpressureStrategy::Buffer,
            buffer_size: 64 * KIB,
            high_water_mark: 48 * KIB,
            low_water_mark: 16 * KIB,
        }
    }
}

/// How the server reacts under backpressure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackpressureStrategy {
    /// Buffer messages (the default).
    Buffer,
    /// Drop the oldest messages when the buffer is full.
    DropOldest,
    /// Reject new messages when the buffer is full.
    Reject,
    /// Apply flow control to the sender.
    FlowControl,
}
174
/// TLS configuration.
///
/// Create one with [`TlsConfig::new`], then refine it with the chained setters
/// [`TlsConfig::cert_chain_file`], [`TlsConfig::client_auth`] and
/// [`TlsConfig::ca_file`]. Every path parameter accepts anything convertible
/// into a `String`, such as `&str` or `String`.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to certificate file
    pub cert_file: String,
    /// Path to private key file
    pub key_file: String,
    /// Certificate chain file (optional)
    pub cert_chain_file: Option<String>,
    /// Enable client authentication
    pub client_auth: bool,
    /// CA file for client authentication
    pub ca_file: Option<String>,
}

impl TlsConfig {
    /// Create a TLS configuration from a certificate file and a private key file.
    ///
    /// The chain file and CA file start unset, and client authentication starts
    /// disabled.
    #[must_use]
    pub fn new(cert_file: impl Into<String>, key_file: impl Into<String>) -> Self {
        Self {
            cert_file: cert_file.into(),
            key_file: key_file.into(),
            cert_chain_file: None,
            client_auth: false,
            ca_file: None,
        }
    }

    /// Set the certificate chain file.
    #[must_use]
    pub fn cert_chain_file(mut self, file: impl Into<String>) -> Self {
        self.cert_chain_file = Some(file.into());
        self
    }

    /// Enable or disable client authentication.
    #[must_use]
    pub fn client_auth(mut self, enabled: bool) -> Self {
        self.client_auth = enabled;
        self
    }

    /// Set the CA file used for client authentication.
    #[must_use]
    pub fn ca_file(mut self, file: impl Into<String>) -> Self {
        self.ca_file = Some(file.into());
        self
    }
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration validates and carries the expected defaults.
    #[test]
    fn test_server_config_default() {
        let defaults = ServerConfig::default();

        assert!(defaults.validate().is_ok());
        assert_eq!(defaults.max_connections, 10_000);
        assert_eq!(defaults.bind_address.port(), 8080);
    }

    /// Zero connection/frame limits and a message limit below the frame limit
    /// are all rejected.
    #[test]
    fn test_server_config_validation() {
        let no_connections = ServerConfig {
            max_connections: 0,
            ..ServerConfig::default()
        };
        assert!(no_connections.validate().is_err());

        let empty_frames = ServerConfig {
            max_connections: 1000,
            max_frame_size: 0,
            ..ServerConfig::default()
        };
        assert!(empty_frames.validate().is_err());

        let message_below_frame = ServerConfig {
            max_connections: 1000,
            max_frame_size: 1024,
            max_message_size: 512,
            ..ServerConfig::default()
        };
        assert!(message_below_frame.validate().is_err());
    }

    /// Every builder setter is reflected in the resulting TLS configuration.
    #[test]
    fn test_tls_config() {
        let tls = TlsConfig::new(String::from("cert.pem"), String::from("key.pem"))
            .cert_chain_file(String::from("chain.pem"))
            .client_auth(true)
            .ca_file(String::from("ca.pem"));

        assert_eq!(tls.cert_file, "cert.pem");
        assert_eq!(tls.key_file, "key.pem");
        assert_eq!(tls.cert_chain_file.as_deref(), Some("chain.pem"));
        assert!(tls.client_auth);
        assert_eq!(tls.ca_file.as_deref(), Some("ca.pem"));
    }
}