1use aerosocket_core::error::{ConfigError, Error};
6use std::time::Duration;
7
/// Configuration for a server instance.
///
/// [`Default`] provides a complete starting point, and
/// [`ServerConfig::validate`] checks the connection and size limits.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Bind address; the default is `0.0.0.0:8080`.
    pub bind_address: std::net::SocketAddr,
    /// Connection limit; `validate` rejects `0`. The default is 10 000.
    pub max_connections: usize,
    /// Frame size limit; `validate` requires it to be non-zero and no larger
    /// than `max_message_size`.
    // NOTE(review): the unit is presumably bytes — confirm against aerosocket_core.
    pub max_frame_size: usize,
    /// Message size limit; `validate` requires it to be non-zero and at least
    /// `max_frame_size`.
    pub max_message_size: usize,
    /// Handshake timeout; the default comes from `DEFAULT_HANDSHAKE_TIMEOUT`
    /// in `aerosocket_core::protocol::constants`.
    pub handshake_timeout: Duration,
    /// Idle timeout; the default comes from `DEFAULT_IDLE_TIMEOUT` in
    /// `aerosocket_core::protocol::constants`.
    pub idle_timeout: Duration,
    /// Compression settings.
    pub compression: CompressionConfig,
    /// Backpressure settings.
    pub backpressure: BackpressureConfig,
    /// TLS settings; `None` by default.
    // NOTE(review): `validate` does not check this against `transport_type`.
    pub tls: Option<TlsConfig>,
    /// Selected transport; the default is `TransportType::Tcp`.
    pub transport_type: TransportType,
    /// Supported protocol names; empty by default.
    // NOTE(review): presumably WebSocket subprotocols — confirm against the handshake code.
    pub supported_protocols: Vec<String>,
    /// Supported extension names; empty by default.
    pub supported_extensions: Vec<String>,
    /// Allowed origin, if any; `None` by default.
    // NOTE(review): how `None` is interpreted is decided outside this file.
    pub allowed_origin: Option<String>,
    /// Extra headers as name/value pairs; empty by default.
    // NOTE(review): presumably added to the handshake response — confirm.
    pub extra_headers: std::collections::HashMap<String, String>,
}
40
/// Transport selected for a server (stored in `ServerConfig::transport_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportType {
    /// TCP; this is what `ServerConfig::default()` selects.
    Tcp,
    /// TLS.
    // NOTE(review): presumably meant to be used together with `ServerConfig::tls` — confirm.
    Tls,
}
49
50impl Default for ServerConfig {
51 fn default() -> Self {
52 Self {
53 bind_address: "0.0.0.0:8080".parse().unwrap(),
54 max_connections: 10_000,
55 max_frame_size: aerosocket_core::protocol::constants::DEFAULT_MAX_FRAME_SIZE,
56 max_message_size: aerosocket_core::protocol::constants::DEFAULT_MAX_MESSAGE_SIZE,
57 handshake_timeout: aerosocket_core::protocol::constants::DEFAULT_HANDSHAKE_TIMEOUT,
58 idle_timeout: aerosocket_core::protocol::constants::DEFAULT_IDLE_TIMEOUT,
59 compression: CompressionConfig::default(),
60 backpressure: BackpressureConfig::default(),
61 tls: None,
62 transport_type: TransportType::Tcp,
63 supported_protocols: vec![],
64 supported_extensions: vec![],
65 allowed_origin: None,
66 extra_headers: std::collections::HashMap::new(),
67 }
68 }
69}
70
71impl ServerConfig {
72 pub fn validate(&self) -> aerosocket_core::Result<()> {
74 if self.max_connections == 0 {
75 return Err(Error::Config(ConfigError::Validation(
76 "max_connections must be greater than 0".to_string(),
77 )));
78 }
79
80 if self.max_frame_size == 0 {
81 return Err(Error::Config(ConfigError::Validation(
82 "max_frame_size must be greater than 0".to_string(),
83 )));
84 }
85
86 if self.max_message_size == 0 {
87 return Err(Error::Config(ConfigError::Validation(
88 "max_message_size must be greater than 0".to_string(),
89 )));
90 }
91
92 if self.max_message_size < self.max_frame_size {
93 return Err(Error::Config(ConfigError::Validation(
94 "max_message_size must be greater than or equal to max_frame_size".to_string(),
95 )));
96 }
97
98 Ok(())
99 }
100}
101
/// Compression settings (stored in `ServerConfig::compression`).
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Whether compression is enabled; off by default.
    pub enabled: bool,
    /// Compression level; the default is 6.
    // NOTE(review): no range check is applied in this file.
    pub level: u8,
    /// Server context takeover flag; on by default.
    // NOTE(review): presumably the permessage-deflate parameter of the same name — confirm.
    pub server_context_takeover: bool,
    /// Client context takeover flag; on by default.
    pub client_context_takeover: bool,
    /// Server maximum window bits; unset by default.
    pub server_max_window_bits: Option<u8>,
    /// Client maximum window bits; unset by default.
    pub client_max_window_bits: Option<u8>,
}

impl Default for CompressionConfig {
    /// Compression disabled, level 6, context takeover on for both peers and
    /// both window-bit settings unset.
    fn default() -> Self {
        const DEFAULT_LEVEL: u8 = 6;
        // Both peers start with the same takeover setting.
        let takeover = true;

        CompressionConfig {
            enabled: false,
            level: DEFAULT_LEVEL,
            server_context_takeover: takeover,
            client_context_takeover: takeover,
            server_max_window_bits: None,
            client_max_window_bits: None,
        }
    }
}
131
/// Backpressure settings (stored in `ServerConfig::backpressure`).
#[derive(Debug, Clone)]
pub struct BackpressureConfig {
    /// Whether backpressure handling is enabled; on by default.
    pub enabled: bool,
    /// Requests-per-minute limit; the default is 60.
    pub max_requests_per_minute: usize,
    /// Strategy to apply; the default is [`BackpressureStrategy::Buffer`].
    pub strategy: BackpressureStrategy,
    /// Buffer size; the default is `64 * 1024`.
    // NOTE(review): the unit is presumably bytes — confirm against the consumer.
    pub buffer_size: usize,
    /// High water mark; the default is `48 * 1024`.
    pub high_water_mark: usize,
    /// Low water mark; the default is `16 * 1024`.
    pub low_water_mark: usize,
}

impl Default for BackpressureConfig {
    /// Backpressure enabled with the `Buffer` strategy, 60 requests per
    /// minute, a `64 * 1024` buffer and water marks at `48 * 1024` (high) and
    /// `16 * 1024` (low).
    fn default() -> Self {
        const KIB: usize = 1024;

        BackpressureConfig {
            enabled: true,
            max_requests_per_minute: 60,
            strategy: BackpressureStrategy::Buffer,
            buffer_size: 64 * KIB,
            high_water_mark: 48 * KIB,
            low_water_mark: 16 * KIB,
        }
    }
}

/// Strategy selected by [`BackpressureConfig::strategy`].
// NOTE(review): the behaviour of each strategy is implemented outside this
// file; the variant notes below only restate the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackpressureStrategy {
    /// Buffer; the strategy `BackpressureConfig::default()` selects.
    Buffer,
    /// Drop-oldest strategy.
    DropOldest,
    /// Reject strategy.
    Reject,
    /// Flow-control strategy.
    FlowControl,
}
174
/// TLS settings (stored in `ServerConfig::tls`).
// NOTE(review): the `String` fields are presumably file-system paths; nothing
// in this file opens or checks them.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate file, as passed to [`TlsConfig::new`].
    pub cert_file: String,
    /// Key file, as passed to [`TlsConfig::new`].
    pub key_file: String,
    /// Certificate chain file; `None` until [`TlsConfig::cert_chain_file`] is called.
    pub cert_chain_file: Option<String>,
    /// Client authentication flag; `false` until [`TlsConfig::client_auth`] sets it.
    pub client_auth: bool,
    /// CA file; `None` until [`TlsConfig::ca_file`] is called.
    pub ca_file: Option<String>,
}

impl TlsConfig {
    /// Creates a configuration from a certificate file and a key file.
    ///
    /// The chain file and CA file start unset and client authentication
    /// starts disabled.
    pub fn new(cert_file: String, key_file: String) -> Self {
        TlsConfig {
            cert_file,
            key_file,
            cert_chain_file: None,
            client_auth: false,
            ca_file: None,
        }
    }

    /// Returns the configuration with the certificate chain file set to `file`.
    pub fn cert_chain_file(self, file: String) -> Self {
        Self {
            cert_chain_file: Some(file),
            ..self
        }
    }

    /// Returns the configuration with the client authentication flag set to `enabled`.
    pub fn client_auth(self, enabled: bool) -> Self {
        Self {
            client_auth: enabled,
            ..self
        }
    }

    /// Returns the configuration with the CA file set to `file`.
    pub fn ca_file(self, file: String) -> Self {
        Self {
            ca_file: Some(file),
            ..self
        }
    }
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration validates and keeps its well-known values.
    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_connections, 10_000);
        assert_eq!(config.bind_address.port(), 8080);
        assert!(config.bind_address.ip().is_unspecified());
    }

    /// Each validation rule rejects a configuration that violates it.
    #[test]
    fn test_server_config_validation() {
        // Struct-update syntax keeps every case independent of the others.
        let zero_connections = ServerConfig {
            max_connections: 0,
            ..ServerConfig::default()
        };
        assert!(zero_connections.validate().is_err());

        let zero_frame_size = ServerConfig {
            max_connections: 1000,
            max_frame_size: 0,
            ..ServerConfig::default()
        };
        assert!(zero_frame_size.validate().is_err());

        let message_smaller_than_frame = ServerConfig {
            max_connections: 1000,
            max_frame_size: 1024,
            max_message_size: 512,
            ..ServerConfig::default()
        };
        assert!(message_smaller_than_frame.validate().is_err());
    }

    /// A zero message size is rejected.
    #[test]
    fn test_server_config_rejects_zero_message_size() {
        let config = ServerConfig {
            max_message_size: 0,
            ..ServerConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// The message size may equal the frame size; only a smaller value is rejected.
    #[test]
    fn test_server_config_accepts_equal_frame_and_message_size() {
        let config = ServerConfig {
            max_frame_size: 1024,
            max_message_size: 1024,
            ..ServerConfig::default()
        };
        assert!(config.validate().is_ok());
    }

    /// The builder methods set every optional TLS field.
    #[test]
    fn test_tls_config() {
        let config = TlsConfig::new("cert.pem".to_string(), "key.pem".to_string())
            .cert_chain_file("chain.pem".to_string())
            .client_auth(true)
            .ca_file("ca.pem".to_string());

        assert_eq!(config.cert_file, "cert.pem");
        assert_eq!(config.key_file, "key.pem");
        assert_eq!(config.cert_chain_file, Some("chain.pem".to_string()));
        assert!(config.client_auth);
        assert_eq!(config.ca_file, Some("ca.pem".to_string()));
    }
}