1use aerosocket_core::error::{ConfigError, Error};
6use std::time::Duration;
7
8#[cfg(feature = "tls-transport")]
9use rustls::pki_types::{CertificateDer, PrivateKeyDer};
10#[cfg(feature = "tls-transport")]
11use rustls::ServerConfig as RustlsServerConfig;
12#[cfg(feature = "tls-transport")]
13use rustls_pemfile::{certs, pkcs8_private_keys};
14#[cfg(feature = "tls-transport")]
15use std::fs::File;
16#[cfg(feature = "tls-transport")]
17use std::io::BufReader;
18
/// Top-level server configuration.
///
/// `ServerConfig::default()` provides a baseline (TCP on `0.0.0.0:8080`, no TLS,
/// default rate limiting); `ServerConfig::validate()` checks the numeric limits.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address to bind. Default: `0.0.0.0:8080`.
    pub bind_address: std::net::SocketAddr,
    /// Connection limit; `validate` requires it to be greater than 0. Default: `10_000`.
    pub max_connections: usize,
    /// Frame size limit (presumably bytes — TODO confirm); must be greater than 0 and
    /// not exceed `max_message_size`. Default: `protocol::constants::DEFAULT_MAX_FRAME_SIZE`.
    pub max_frame_size: usize,
    /// Message size limit (presumably bytes — TODO confirm); must be greater than 0 and
    /// at least `max_frame_size`. Default: `protocol::constants::DEFAULT_MAX_MESSAGE_SIZE`.
    pub max_message_size: usize,
    /// Handshake timeout. Default: `protocol::constants::DEFAULT_HANDSHAKE_TIMEOUT`.
    pub handshake_timeout: Duration,
    /// Idle timeout. Default: `protocol::constants::DEFAULT_IDLE_TIMEOUT`.
    pub idle_timeout: Duration,
    /// Compression settings. Default: `CompressionConfig::default()` (disabled).
    pub compression: CompressionConfig,
    /// Backpressure settings. Default: `BackpressureConfig::default()`.
    pub backpressure: BackpressureConfig,
    /// TLS certificate/key settings. Default: `None`.
    /// NOTE(review): presumably required when `transport_type` is `Tls` — confirm.
    pub tls: Option<TlsConfig>,
    /// Transport selection. Default: `TransportType::Tcp`.
    pub transport_type: TransportType,
    /// Supported subprotocol names. Default: empty.
    pub supported_protocols: Vec<String>,
    /// Supported extension names. Default: empty.
    pub supported_extensions: Vec<String>,
    /// Allowed origins. Default: empty.
    /// NOTE(review): whether empty means "allow all" or "allow none" is decided elsewhere — confirm.
    pub allowed_origins: Vec<String>,
    /// Extra headers (name -> value). Default: empty.
    pub extra_headers: std::collections::HashMap<String, String>,
    /// Rate limiting settings. Default: `Some(RateLimitConfig::default())`;
    /// presumably `None` disables rate limiting — TODO confirm.
    pub rate_limit: Option<crate::rate_limit::RateLimitConfig>,
}
55
/// Transport selection for the server.
///
/// `ServerConfig::default()` uses `TransportType::Tcp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportType {
    /// Plain (unencrypted) transport.
    Tcp,
    /// TLS transport; presumably configured through `ServerConfig::tls` — TODO confirm.
    Tls,
}
64
65impl Default for ServerConfig {
66 fn default() -> Self {
67 Self {
68 bind_address: "0.0.0.0:8080".parse().unwrap(),
69 max_connections: 10_000,
70 max_frame_size: aerosocket_core::protocol::constants::DEFAULT_MAX_FRAME_SIZE,
71 max_message_size: aerosocket_core::protocol::constants::DEFAULT_MAX_MESSAGE_SIZE,
72 handshake_timeout: aerosocket_core::protocol::constants::DEFAULT_HANDSHAKE_TIMEOUT,
73 idle_timeout: aerosocket_core::protocol::constants::DEFAULT_IDLE_TIMEOUT,
74 compression: CompressionConfig::default(),
75 backpressure: BackpressureConfig::default(),
76 tls: None,
77 transport_type: TransportType::Tcp,
78 supported_protocols: vec![],
79 supported_extensions: vec![],
80 allowed_origins: vec![],
81 extra_headers: std::collections::HashMap::new(),
82 rate_limit: Some(crate::rate_limit::RateLimitConfig::default()),
84 }
85 }
86}
87
88impl ServerConfig {
89 pub fn validate(&self) -> aerosocket_core::Result<()> {
91 if self.max_connections == 0 {
92 return Err(Error::Config(ConfigError::Validation(
93 "max_connections must be greater than 0".to_string(),
94 )));
95 }
96
97 if self.max_frame_size == 0 {
98 return Err(Error::Config(ConfigError::Validation(
99 "max_frame_size must be greater than 0".to_string(),
100 )));
101 }
102
103 if self.max_message_size == 0 {
104 return Err(Error::Config(ConfigError::Validation(
105 "max_message_size must be greater than 0".to_string(),
106 )));
107 }
108
109 if self.max_message_size < self.max_frame_size {
110 return Err(Error::Config(ConfigError::Validation(
111 "max_message_size must be greater than or equal to max_frame_size".to_string(),
112 )));
113 }
114
115 Ok(())
116 }
117}
118
/// Compression settings.
///
/// The field names mirror the parameters of the WebSocket permessage-deflate
/// extension (RFC 7692); presumably they are negotiated through it — TODO confirm.
/// Defaults (see the `Default` impl): disabled, level 6, context takeover on for
/// both sides, no window-bit values.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Whether compression is enabled. Default: `false`.
    pub enabled: bool,
    /// Compression level. Default: `6`.
    /// NOTE(review): the accepted range depends on the compressor used elsewhere; not checked here.
    pub level: u8,
    /// Server context takeover flag. Default: `true`.
    pub server_context_takeover: bool,
    /// Client context takeover flag. Default: `true`.
    pub client_context_takeover: bool,
    /// Server max window bits; presumably `None` means "not requested". Default: `None`.
    pub server_max_window_bits: Option<u8>,
    /// Client max window bits; presumably `None` means "not requested". Default: `None`.
    pub client_max_window_bits: Option<u8>,
}
135
136impl Default for CompressionConfig {
137 fn default() -> Self {
138 Self {
139 enabled: false,
140 level: 6,
141 server_context_takeover: true,
142 client_context_takeover: true,
143 server_max_window_bits: None,
144 client_max_window_bits: None,
145 }
146 }
147}
148
/// Backpressure settings.
///
/// Defaults (see the `Default` impl): enabled, 60 requests per minute,
/// `BackpressureStrategy::Buffer`, 64 KiB buffer, 48 KiB high-water mark,
/// 16 KiB low-water mark.
#[derive(Debug, Clone)]
pub struct BackpressureConfig {
    /// Whether backpressure handling is enabled. Default: `true`.
    pub enabled: bool,
    /// Request-rate limit per minute. Default: `60`.
    pub max_requests_per_minute: usize,
    /// Strategy applied under pressure. Default: `BackpressureStrategy::Buffer`.
    pub strategy: BackpressureStrategy,
    /// Buffer size (presumably bytes — TODO confirm). Default: `64 * 1024`.
    pub buffer_size: usize,
    /// High-water mark (presumably bytes). Default: `48 * 1024`.
    pub high_water_mark: usize,
    /// Low-water mark (presumably bytes). Default: `16 * 1024`.
    pub low_water_mark: usize,
}
165
166impl Default for BackpressureConfig {
167 fn default() -> Self {
168 Self {
169 enabled: true,
170 max_requests_per_minute: 60,
171 strategy: BackpressureStrategy::Buffer,
172 buffer_size: 64 * 1024, high_water_mark: 48 * 1024, low_water_mark: 16 * 1024, }
176 }
177}
178
/// Strategy applied when backpressure kicks in.
///
/// `BackpressureConfig::default()` selects `Buffer`. The handling of each
/// variant lives outside this file; the notes below are presumed from the
/// names — TODO confirm against the implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackpressureStrategy {
    /// Presumably queues pending data up to `BackpressureConfig::buffer_size`.
    Buffer,
    /// Presumably discards the oldest queued data to make room.
    DropOldest,
    /// Presumably rejects new data while under pressure.
    Reject,
    /// Presumably pauses reading until the low-water mark is reached.
    FlowControl,
}
191
/// TLS settings: file paths for the server certificate/key plus optional
/// chain and client-authentication settings.
///
/// Build with `TlsConfig::new` and the builder-style setters.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to the PEM certificate file.
    pub cert_file: String,
    /// Path to the PEM private key file.
    pub key_file: String,
    /// Optional path to an additional certificate chain file. Default: `None`.
    pub cert_chain_file: Option<String>,
    /// Whether client certificate authentication is requested. Default: `false`.
    pub client_auth: bool,
    /// Optional path to a CA certificate file; presumably the trust anchors for
    /// client certificates — TODO confirm. Default: `None`.
    pub ca_file: Option<String>,
}
206
207impl TlsConfig {
208 pub fn new(cert_file: String, key_file: String) -> Self {
210 Self {
211 cert_file,
212 key_file,
213 cert_chain_file: None,
214 client_auth: false,
215 ca_file: None,
216 }
217 }
218
219 pub fn cert_chain_file(mut self, file: String) -> Self {
221 self.cert_chain_file = Some(file);
222 self
223 }
224
225 pub fn client_auth(mut self, enabled: bool) -> Self {
227 self.client_auth = enabled;
228 self
229 }
230
231 pub fn ca_file(mut self, file: String) -> Self {
233 self.ca_file = Some(file);
234 self
235 }
236}
237
/// Reads every PEM-encoded certificate from `path`, in file order.
///
/// # Errors
///
/// Returns a `ConfigError::Validation` error if the file cannot be opened, if a
/// PEM section cannot be parsed, or if the file contains no certificates at all.
/// The last case is reported here so a wrong path or format is not carried
/// forward as an empty certificate list.
#[cfg(feature = "tls-transport")]
fn load_certs(path: &str) -> aerosocket_core::Result<Vec<CertificateDer<'static>>> {
    let file = File::open(path).map_err(|e| {
        Error::Config(ConfigError::Validation(format!(
            "Failed to open certificate file {}: {}",
            path, e
        )))
    })?;
    let mut reader = BufReader::new(file);

    let loaded = certs(&mut reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| {
            Error::Config(ConfigError::Validation(format!(
                "Failed to parse certificate file {}: {}",
                path, e
            )))
        })?;

    if loaded.is_empty() {
        return Err(Error::Config(ConfigError::Validation(format!(
            "No certificates found in {}",
            path
        ))));
    }

    Ok(loaded)
}
255
/// Loads the first private key from the PEM file at `path`.
///
/// Formats are tried in order of preference:
/// 1. PKCS#8 (`BEGIN PRIVATE KEY`);
/// 2. PKCS#1 RSA (`BEGIN RSA PRIVATE KEY`);
/// 3. SEC1 EC (`BEGIN EC PRIVATE KEY`).
///
/// The first key of the first format present is returned. The file is read
/// once; every format is scanned from the in-memory bytes.
///
/// # Errors
///
/// Returns a `ConfigError::Validation` error if the file cannot be read, if the
/// first matching PEM item fails to parse, or if no key of any supported format
/// is present.
#[cfg(feature = "tls-transport")]
fn load_private_key(path: &str) -> aerosocket_core::Result<PrivateKeyDer<'static>> {
    let pem = std::fs::read(path).map_err(|e| {
        Error::Config(ConfigError::Validation(format!(
            "Failed to open private key file {}: {}",
            path, e
        )))
    })?;

    let parse_error = |kind: &str, e: std::io::Error| {
        Error::Config(ConfigError::Validation(format!(
            "Failed to parse {} private key in {}: {}",
            kind, path, e
        )))
    };

    // `&[u8]` implements `BufRead`, so each scan gets its own fresh cursor over `pem`.
    if let Some(item) = pkcs8_private_keys(&mut pem.as_slice()).next() {
        return item
            .map(PrivateKeyDer::Pkcs8)
            .map_err(|e| parse_error("PKCS8", e));
    }

    if let Some(item) = rustls_pemfile::rsa_private_keys(&mut pem.as_slice()).next() {
        return item
            .map(PrivateKeyDer::Pkcs1)
            .map_err(|e| parse_error("RSA", e));
    }

    if let Some(item) = rustls_pemfile::ec_private_keys(&mut pem.as_slice()).next() {
        return item
            .map(PrivateKeyDer::Sec1)
            .map_err(|e| parse_error("EC", e));
    }

    Err(Error::Config(ConfigError::Validation(format!(
        "No private keys found in {}",
        path
    ))))
}
301
/// Builds a rustls `ServerConfig` from a [`TlsConfig`].
///
/// Behaviour:
/// - The leaf certificates come from `cert_file`. If `cert_chain_file` is set,
///   its certificates are appended to form the presented chain.
/// - The key comes from `key_file`.
/// - When `client_auth` is `true`, client certificates are required and
///   verified against the CA certificates in `ca_file`.
/// - Otherwise no client authentication is performed.
///
/// # Errors
///
/// Returns a `ConfigError::Validation` error if:
/// - any file cannot be loaded;
/// - `client_auth` is enabled without a `ca_file`;
/// - a CA certificate is rejected;
/// - the client verifier cannot be built;
/// - rustls rejects the certificate/key pair.
#[cfg(feature = "tls-transport")]
pub fn build_rustls_server_config(tls: &TlsConfig) -> aerosocket_core::Result<RustlsServerConfig> {
    // Wraps a message in the error variant used throughout this module.
    fn invalid(msg: String) -> Error {
        Error::Config(ConfigError::Validation(msg))
    }

    let mut cert_chain = load_certs(&tls.cert_file)?;
    if let Some(chain_file) = &tls.cert_chain_file {
        cert_chain.extend(load_certs(chain_file)?);
    }
    let key = load_private_key(&tls.key_file)?;

    let builder = RustlsServerConfig::builder();
    // Honour `client_auth` instead of silently serving without mTLS.
    let builder = if tls.client_auth {
        let ca_file = tls.ca_file.as_deref().ok_or_else(|| {
            invalid("TLS client_auth is enabled but no ca_file was configured".to_string())
        })?;

        let mut roots = rustls::RootCertStore::empty();
        for ca_cert in load_certs(ca_file)? {
            roots.add(ca_cert).map_err(|e| {
                invalid(format!("Invalid CA certificate in {}: {}", ca_file, e))
            })?;
        }

        let verifier = rustls::server::WebPkiClientVerifier::builder(std::sync::Arc::new(roots))
            .build()
            .map_err(|e| invalid(format!("Failed to build TLS client verifier: {}", e)))?;

        builder.with_client_cert_verifier(verifier)
    } else {
        builder.with_no_client_auth()
    };

    builder
        .with_single_cert(cert_chain, key)
        .map_err(|e| invalid(format!("Invalid TLS certificate/key: {}", e)))
}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration validates and exposes the expected defaults.
    #[test]
    fn test_server_config_default() {
        let defaults = ServerConfig::default();
        assert_eq!(defaults.bind_address.port(), 8080);
        assert_eq!(defaults.max_connections, 10_000);
        assert!(defaults.validate().is_ok());
    }

    /// Zero limits and a message limit smaller than the frame limit are rejected.
    #[test]
    fn test_server_config_validation() {
        let rejected = |cfg: ServerConfig| cfg.validate().is_err();

        assert!(rejected(ServerConfig {
            max_connections: 0,
            ..ServerConfig::default()
        }));

        assert!(rejected(ServerConfig {
            max_connections: 1000,
            max_frame_size: 0,
            ..ServerConfig::default()
        }));

        assert!(rejected(ServerConfig {
            max_connections: 1000,
            max_frame_size: 1024,
            max_message_size: 512,
            ..ServerConfig::default()
        }));
    }

    /// Every builder setter lands in the corresponding field.
    #[test]
    fn test_tls_config() {
        let tls = TlsConfig::new(String::from("cert.pem"), String::from("key.pem"))
            .cert_chain_file(String::from("chain.pem"))
            .client_auth(true)
            .ca_file(String::from("ca.pem"));

        assert_eq!(tls.cert_file, "cert.pem");
        assert_eq!(tls.key_file, "key.pem");
        assert_eq!(tls.cert_chain_file, Some(String::from("chain.pem")));
        assert!(tls.client_auth);
        assert_eq!(tls.ca_file, Some(String::from("ca.pem")));
    }
}