1use aerosocket_core::error::{Error, ConfigError};
6use std::time::Duration;
7
8#[derive(Debug, Clone)]
10pub struct ServerConfig {
11 pub bind_address: std::net::SocketAddr,
13 pub max_connections: usize,
15 pub max_frame_size: usize,
17 pub max_message_size: usize,
19 pub handshake_timeout: Duration,
21 pub idle_timeout: Duration,
23 pub compression: CompressionConfig,
25 pub backpressure: BackpressureConfig,
27 pub tls: Option<TlsConfig>,
29 pub transport_type: TransportType,
31 pub supported_protocols: Vec<String>,
33 pub supported_extensions: Vec<String>,
35 pub allowed_origin: Option<String>,
37 pub extra_headers: std::collections::HashMap<String, String>,
39}
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq)]
43pub enum TransportType {
44 Tcp,
46 Tls,
48}
49
50impl Default for ServerConfig {
51 fn default() -> Self {
52 Self {
53 bind_address: "0.0.0.0:8080".parse().unwrap(),
54 max_connections: 10_000,
55 max_frame_size: aerosocket_core::protocol::constants::DEFAULT_MAX_FRAME_SIZE,
56 max_message_size: aerosocket_core::protocol::constants::DEFAULT_MAX_MESSAGE_SIZE,
57 handshake_timeout: aerosocket_core::protocol::constants::DEFAULT_HANDSHAKE_TIMEOUT,
58 idle_timeout: aerosocket_core::protocol::constants::DEFAULT_IDLE_TIMEOUT,
59 compression: CompressionConfig::default(),
60 backpressure: BackpressureConfig::default(),
61 tls: None,
62 transport_type: TransportType::Tcp,
63 supported_protocols: vec![],
64 supported_extensions: vec![],
65 allowed_origin: None,
66 extra_headers: std::collections::HashMap::new(),
67 }
68 }
69}
70
71impl ServerConfig {
72 pub fn validate(&self) -> aerosocket_core::Result<()> {
74 if self.max_connections == 0 {
75 return Err(Error::Config(ConfigError::Validation("max_connections must be greater than 0".to_string())));
76 }
77
78 if self.max_frame_size == 0 {
79 return Err(Error::Config(ConfigError::Validation("max_frame_size must be greater than 0".to_string())));
80 }
81
82 if self.max_message_size == 0 {
83 return Err(Error::Config(ConfigError::Validation("max_message_size must be greater than 0".to_string())));
84 }
85
86 if self.max_message_size < self.max_frame_size {
87 return Err(Error::Config(ConfigError::Validation("max_message_size must be greater than or equal to max_frame_size".to_string())));
88 }
89
90 Ok(())
91 }
92}
93
/// Per-message compression settings.
///
/// The field names mirror the permessage-deflate extension parameters
/// (RFC 7692). Compression is turned off by default.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Master switch for compression; `false` by default.
    pub enabled: bool,
    /// Compression level; `6` by default.
    pub level: u8,
    /// Whether the server may reuse its compression context across messages.
    pub server_context_takeover: bool,
    /// Whether the client may reuse its compression context across messages.
    pub client_context_takeover: bool,
    /// Optional limit on the server's window size, as a base-2 exponent.
    pub server_max_window_bits: Option<u8>,
    /// Optional limit on the client's window size, as a base-2 exponent.
    pub client_max_window_bits: Option<u8>,
}

impl Default for CompressionConfig {
    /// Disabled, level 6, context takeover allowed on both sides, and no
    /// window-size limits.
    fn default() -> Self {
        let takeover = true;
        let no_limit: Option<u8> = None;
        Self {
            enabled: false,
            level: 6,
            server_context_takeover: takeover,
            client_context_takeover: takeover,
            server_max_window_bits: no_limit,
            client_max_window_bits: no_limit,
        }
    }
}
123
/// Backpressure settings.
///
/// How each field is enforced is decided by the connection-handling code; this
/// type only carries the values.
#[derive(Debug, Clone)]
pub struct BackpressureConfig {
    /// Whether backpressure is applied; `true` by default.
    pub enabled: bool,
    /// Request-rate limit per minute; `60` by default.
    pub max_requests_per_minute: usize,
    /// Strategy to use under pressure; [`BackpressureStrategy::Buffer`] by default.
    pub strategy: BackpressureStrategy,
    /// Buffer size; 64 KiB by default.
    pub buffer_size: usize,
    /// Upper water mark; 48 KiB by default.
    pub high_water_mark: usize,
    /// Lower water mark; 16 KiB by default.
    pub low_water_mark: usize,
}

impl Default for BackpressureConfig {
    /// Enabled, 60 requests per minute, buffering strategy, and a 64 KiB buffer
    /// with 48 KiB / 16 KiB water marks.
    fn default() -> Self {
        const KIB: usize = 1024;
        Self {
            enabled: true,
            max_requests_per_minute: 60,
            strategy: BackpressureStrategy::Buffer,
            buffer_size: 64 * KIB,
            high_water_mark: 48 * KIB,
            low_water_mark: 16 * KIB,
        }
    }
}

/// Strategy selected when backpressure applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackpressureStrategy {
    /// Buffer data (the default).
    Buffer,
    /// Drop the oldest data.
    DropOldest,
    /// Reject new data.
    Reject,
    /// Apply flow control.
    FlowControl,
}
166
/// TLS settings: certificate and key files plus optional client authentication.
///
/// Construct with [`TlsConfig::new`] and refine with the consuming builder
/// methods, e.g. `TlsConfig::new(cert, key).client_auth(true)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConfig {
    /// Path to the server certificate file.
    pub cert_file: String,
    /// Path to the private key file.
    pub key_file: String,
    /// Optional path to a certificate chain file.
    pub cert_chain_file: Option<String>,
    /// Whether client certificate authentication is enabled; `false` by default.
    pub client_auth: bool,
    /// Optional path to a CA file. Presumably used to verify client
    /// certificates when `client_auth` is on — confirm in the TLS setup code.
    pub ca_file: Option<String>,
}

impl TlsConfig {
    /// Creates a config from a certificate file and a key file.
    ///
    /// The new config has no chain file, no CA file, and client
    /// authentication disabled.
    #[must_use]
    pub fn new(cert_file: String, key_file: String) -> Self {
        Self {
            cert_file,
            key_file,
            cert_chain_file: None,
            client_auth: false,
            ca_file: None,
        }
    }

    /// Sets the certificate chain file and returns the updated config.
    // `#[must_use]`: the method consumes `self`, so ignoring the result drops the config.
    #[must_use]
    pub fn cert_chain_file(mut self, file: String) -> Self {
        self.cert_chain_file = Some(file);
        self
    }

    /// Enables or disables client authentication and returns the updated config.
    #[must_use]
    pub fn client_auth(mut self, enabled: bool) -> Self {
        self.client_auth = enabled;
        self
    }

    /// Sets the CA file and returns the updated config.
    #[must_use]
    pub fn ca_file(mut self, file: String) -> Self {
        self.ca_file = Some(file);
        self
    }
}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_connections, 10_000);
        assert_eq!(config.bind_address.port(), 8080);
        assert!(config.bind_address.ip().is_unspecified());
        assert_eq!(config.transport_type, TransportType::Tcp);
        assert!(config.tls.is_none());
        assert!(config.supported_protocols.is_empty());
        assert!(config.extra_headers.is_empty());
    }

    #[test]
    fn test_server_config_validation() {
        let mut config = ServerConfig::default();
        config.max_connections = 0;
        assert!(config.validate().is_err());

        config.max_connections = 1000;
        config.max_frame_size = 0;
        assert!(config.validate().is_err());

        config.max_frame_size = 1024;
        config.max_message_size = 512;
        assert!(config.validate().is_err());

        config.max_message_size = 0;
        assert!(config.validate().is_err());

        // Equal frame and message limits are the boundary that must still pass.
        config.max_message_size = 1024;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_compression_config_default() {
        let config = CompressionConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.level, 6);
        assert!(config.server_context_takeover);
        assert!(config.client_context_takeover);
        assert_eq!(config.server_max_window_bits, None);
        assert_eq!(config.client_max_window_bits, None);
    }

    #[test]
    fn test_backpressure_config_default() {
        let config = BackpressureConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_requests_per_minute, 60);
        assert_eq!(config.strategy, BackpressureStrategy::Buffer);
        assert_eq!(config.buffer_size, 64 * 1024);
        assert!(config.low_water_mark <= config.high_water_mark);
        assert!(config.high_water_mark <= config.buffer_size);
    }

    #[test]
    fn test_tls_config() {
        let config = TlsConfig::new("cert.pem".to_string(), "key.pem".to_string())
            .cert_chain_file("chain.pem".to_string())
            .client_auth(true)
            .ca_file("ca.pem".to_string());

        assert_eq!(config.cert_file, "cert.pem");
        assert_eq!(config.key_file, "key.pem");
        assert_eq!(config.cert_chain_file, Some("chain.pem".to_string()));
        assert!(config.client_auth);
        assert_eq!(config.ca_file, Some("ca.pem".to_string()));
    }

    #[test]
    fn test_tls_config_new_defaults() {
        let config = TlsConfig::new("cert.pem".to_string(), "key.pem".to_string());
        assert_eq!(config.cert_chain_file, None);
        assert!(!config.client_auth);
        assert_eq!(config.ca_file, None);
    }
}