aerosocket_client/
config.rs

1//! Client configuration for AeroSocket
2//!
3//! This module provides configuration options for WebSocket clients.
4
5use aerosocket_core::error::ConfigError;
6use aerosocket_core::Error;
7use std::time::Duration;
8
9/// Client configuration
10#[derive(Debug, Clone)]
11pub struct ClientConfig {
12    /// Maximum frame size in bytes
13    pub max_frame_size: usize,
14    /// Maximum message size in bytes
15    pub max_message_size: usize,
16    /// Handshake timeout
17    pub handshake_timeout: Duration,
18    /// Idle timeout
19    pub idle_timeout: Duration,
20    /// Compression configuration
21    pub compression: CompressionConfig,
22    /// TLS configuration
23    pub tls: Option<TlsConfig>,
24    /// User agent string
25    pub user_agent: String,
26    /// Origin header
27    pub origin: Option<String>,
28    /// WebSocket subprotocols
29    pub protocols: Vec<String>,
30    /// Custom headers
31    pub headers: Vec<(String, String)>,
32}
33
34impl Default for ClientConfig {
35    fn default() -> Self {
36        Self {
37            max_frame_size: aerosocket_core::protocol::constants::DEFAULT_MAX_FRAME_SIZE,
38            max_message_size: aerosocket_core::protocol::constants::DEFAULT_MAX_MESSAGE_SIZE,
39            handshake_timeout: aerosocket_core::protocol::constants::DEFAULT_HANDSHAKE_TIMEOUT,
40            idle_timeout: aerosocket_core::protocol::constants::DEFAULT_IDLE_TIMEOUT,
41            compression: CompressionConfig::default(),
42            tls: None,
43            user_agent: format!("aerosocket-client/{}", env!("CARGO_PKG_VERSION")),
44            origin: None,
45            protocols: Vec::new(),
46            headers: Vec::new(),
47        }
48    }
49}
50
51impl ClientConfig {
52    /// Validate the configuration
53    pub fn validate(&self) -> aerosocket_core::Result<()> {
54        if self.max_frame_size == 0 {
55            return Err(Error::Config(ConfigError::Validation(
56                "max_frame_size must be greater than 0".to_string(),
57            )));
58        }
59
60        if self.max_message_size == 0 {
61            return Err(Error::Config(ConfigError::Validation(
62                "max_message_size must be greater than 0".to_string(),
63            )));
64        }
65
66        if self.max_message_size < self.max_frame_size {
67            return Err(Error::Config(ConfigError::Validation(
68                "max_message_size must be greater than or equal to max_frame_size".to_string(),
69            )));
70        }
71
72        if self.handshake_timeout.is_zero() {
73            return Err(Error::Config(ConfigError::Validation(
74                "handshake_timeout must be greater than 0".to_string(),
75            )));
76        }
77
78        Ok(())
79    }
80
81    /// Set maximum frame size
82    pub fn max_frame_size(mut self, size: usize) -> Self {
83        self.max_frame_size = size;
84        self
85    }
86
87    /// Set maximum message size
88    pub fn max_message_size(mut self, size: usize) -> Self {
89        self.max_message_size = size;
90        self
91    }
92
93    /// Set handshake timeout
94    pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
95        self.handshake_timeout = timeout;
96        self
97    }
98
99    /// Set idle timeout
100    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
101        self.idle_timeout = timeout;
102        self
103    }
104
105    /// Set user agent
106    pub fn user_agent(mut self, agent: String) -> Self {
107        self.user_agent = agent;
108        self
109    }
110
111    /// Set origin
112    pub fn origin(mut self, origin: String) -> Self {
113        self.origin = Some(origin);
114        self
115    }
116
117    /// Add a subprotocol
118    pub fn add_protocol(mut self, protocol: String) -> Self {
119        self.protocols.push(protocol);
120        self
121    }
122
123    /// Add a custom header
124    pub fn add_header(mut self, name: String, value: String) -> Self {
125        self.headers.push((name, value));
126        self
127    }
128
129    /// Set TLS configuration
130    pub fn tls(mut self, config: TlsConfig) -> Self {
131        self.tls = Some(config);
132        self
133    }
134}
135
/// Compression configuration.
///
/// The field names mirror the parameters of the WebSocket permessage-deflate
/// extension (RFC 7692). Presumably they are negotiated through that extension;
/// confirm this against the handshake code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionConfig {
    /// Enable compression
    pub enabled: bool,
    /// Compression level (0-9)
    pub level: u8,
    /// Server context takeover
    pub server_context_takeover: bool,
    /// Client context takeover
    pub client_context_takeover: bool,
    /// Server max window bits (RFC 7692 permits 8 through 15); `None` leaves it unset.
    pub server_max_window_bits: Option<u8>,
    /// Client max window bits (RFC 7692 permits 8 through 15); `None` leaves it unset.
    pub client_max_window_bits: Option<u8>,
}
152
153impl Default for CompressionConfig {
154    fn default() -> Self {
155        Self {
156            enabled: false,
157            level: 6,
158            server_context_takeover: true,
159            client_context_takeover: true,
160            server_max_window_bits: None,
161            client_max_window_bits: None,
162        }
163    }
164}
165
166/// TLS configuration
167#[derive(Debug, Clone)]
168pub struct TlsConfig {
169    /// Enable TLS verification
170    pub verify: bool,
171    /// Path to CA certificate file
172    pub ca_file: Option<String>,
173    /// Path to client certificate file
174    pub cert_file: Option<String>,
175    /// Path to client private key file
176    pub key_file: Option<String>,
177    /// Server name for SNI
178    pub server_name: Option<String>,
179    /// Minimum TLS version
180    pub min_version: Option<TlsVersion>,
181    /// Maximum TLS version
182    pub max_version: Option<TlsVersion>,
183}
184
/// TLS protocol version.
///
/// Variants are declared oldest-first, so the derived `PartialOrd`/`Ord`
/// order them by protocol age (`V1_0 < V1_1 < V1_2 < V1_3`). This makes
/// checks such as "minimum version is not newer than maximum" straightforward.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TlsVersion {
    /// TLS 1.0
    V1_0,
    /// TLS 1.1
    V1_1,
    /// TLS 1.2
    V1_2,
    /// TLS 1.3
    V1_3,
}
197
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_frame_size, 16 * 1024 * 1024); // 16MB
        assert_eq!(config.max_message_size, 64 * 1024 * 1024); // 64MB
    }

    #[test]
    fn test_client_config_validation() {
        let config = ClientConfig {
            max_frame_size: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());

        let config = ClientConfig {
            max_frame_size: 1024,
            max_message_size: 512,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_rejects_zero_message_size() {
        let config = ClientConfig::default().max_message_size(0);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_rejects_zero_handshake_timeout() {
        let config = ClientConfig::default().handshake_timeout(Duration::from_secs(0));
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_accepts_equal_frame_and_message_size() {
        // The size check is `message < frame`, so equal limits must pass.
        let config = ClientConfig::default()
            .max_frame_size(1024)
            .max_message_size(1024);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_client_config_builder() {
        let config = ClientConfig::default()
            .max_frame_size(2048)
            .user_agent("test-agent".to_string())
            .origin("https://example.com".to_string())
            .add_protocol("chat".to_string())
            .add_header("X-Custom".to_string(), "value".to_string());

        assert_eq!(config.max_frame_size, 2048);
        assert_eq!(config.user_agent, "test-agent");
        assert_eq!(config.origin.as_deref(), Some("https://example.com"));
        assert_eq!(config.protocols.len(), 1);
        assert_eq!(config.protocols[0], "chat");
        assert_eq!(config.headers.len(), 1);
        assert_eq!(
            config.headers[0],
            ("X-Custom".to_string(), "value".to_string())
        );
    }
}