Skip to main content

ombrac_server/config/
mod.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::ValueEnum;
5use serde::{Deserialize, Serialize};
6
7use ombrac_transport::quic::Congestion;
8
9pub mod cli;
10pub mod json;
11
12/// Transport configuration for QUIC connections
13#[derive(Deserialize, Serialize, Debug, Clone)]
14#[serde(rename_all = "snake_case")]
15pub struct TransportConfig {
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub tls_mode: Option<TlsMode>,
18
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub ca_cert: Option<PathBuf>,
21
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tls_cert: Option<PathBuf>,
24
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub tls_key: Option<PathBuf>,
27
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub zero_rtt: Option<bool>,
30
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub alpn_protocols: Option<Vec<Vec<u8>>>,
33
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub congestion: Option<Congestion>,
36
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub cwnd_init: Option<u64>,
39
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub idle_timeout: Option<u64>,
42
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub keep_alive: Option<u64>,
45
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub max_streams: Option<u64>,
48}
49
50impl TransportConfig {
51    /// Get TLS mode with default
52    pub fn tls_mode(&self) -> TlsMode {
53        self.tls_mode.unwrap_or_default()
54    }
55
56    /// Get zero_rtt with default
57    pub fn zero_rtt(&self) -> bool {
58        self.zero_rtt.unwrap_or(false)
59    }
60
61    /// Get ALPN protocols with default
62    pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
63        self.alpn_protocols
64            .clone()
65            .unwrap_or_else(|| vec!["h3".into()])
66    }
67
68    /// Get congestion control with default
69    pub fn congestion(&self) -> Congestion {
70        self.congestion.unwrap_or(Congestion::Bbr)
71    }
72
73    /// Get idle timeout with default (in milliseconds)
74    pub fn idle_timeout(&self) -> u64 {
75        self.idle_timeout.unwrap_or(30000)
76    }
77
78    /// Get keep-alive interval with default (in milliseconds)
79    pub fn keep_alive(&self) -> u64 {
80        self.keep_alive.unwrap_or(8000)
81    }
82
83    /// Get max streams with default
84    pub fn max_streams(&self) -> u64 {
85        self.max_streams.unwrap_or(1000)
86    }
87}
88
89impl Default for TransportConfig {
90    fn default() -> Self {
91        Self {
92            tls_mode: Some(TlsMode::Tls),
93            ca_cert: None,
94            tls_cert: None,
95            tls_key: None,
96            zero_rtt: Some(false),
97            alpn_protocols: Some(vec!["h3".into()]),
98            congestion: Some(Congestion::Bbr),
99            cwnd_init: None,
100            idle_timeout: Some(30000),
101            keep_alive: Some(8000),
102            max_streams: Some(1000),
103        }
104    }
105}
106
107/// Connection-level configuration for managing connection lifecycle and resource limits
108#[derive(Deserialize, Serialize, Debug, Clone)]
109#[serde(rename_all = "snake_case")]
110pub struct ConnectionConfig {
111    /// Maximum number of concurrent connections [default: 10000]
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub max_connections: Option<usize>,
114
115    /// Authentication timeout in seconds [default: 10]
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub auth_timeout_secs: Option<u64>,
118
119    /// Maximum concurrent stream connections per client connection [default: 4096]
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub max_concurrent_streams: Option<usize>,
122
123    /// Maximum concurrent datagram handlers per client connection [default: 4096]
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub max_concurrent_datagrams: Option<usize>,
126}
127
128impl ConnectionConfig {
129    /// Get max connections with default
130    pub fn max_connections(&self) -> usize {
131        self.max_connections.unwrap_or(10000)
132    }
133
134    /// Get authentication timeout with default (in seconds)
135    pub fn auth_timeout_secs(&self) -> u64 {
136        self.auth_timeout_secs.unwrap_or(10)
137    }
138
139    /// Get handshake timeout with default (in seconds)
140    ///
141    /// Deprecated: Use `auth_timeout_secs` instead
142    #[deprecated(note = "Use auth_timeout_secs instead")]
143    pub fn handshake_timeout_secs(&self) -> u64 {
144        self.auth_timeout_secs()
145    }
146
147    /// Get max concurrent streams with default
148    pub fn max_concurrent_streams(&self) -> usize {
149        self.max_concurrent_streams.unwrap_or(4096)
150    }
151
152    /// Get max concurrent datagrams with default
153    pub fn max_concurrent_datagrams(&self) -> usize {
154        self.max_concurrent_datagrams.unwrap_or(4096)
155    }
156}
157
158impl Default for ConnectionConfig {
159    fn default() -> Self {
160        Self {
161            max_connections: Some(10000),
162            auth_timeout_secs: Some(10),
163            max_concurrent_streams: Some(4096),
164            max_concurrent_datagrams: Some(4096),
165        }
166    }
167}
168
/// Logging settings; only compiled when the `tracing` feature is enabled.
#[cfg(feature = "tracing")]
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub struct LoggingConfig {
    /// Log level name such as `INFO`, `WARN` or `ERROR` [default: INFO]
    ///
    /// Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
178
#[cfg(feature = "tracing")]
impl LoggingConfig {
    /// Returns the configured log level as a string slice, or `"INFO"` when unset.
    pub fn log_level(&self) -> &str {
        self.log_level.as_ref().map_or("INFO", String::as_str)
    }
}
186
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Produces logging settings with the level explicitly set to `INFO`.
    fn default() -> Self {
        let log_level = Some(String::from("INFO"));
        Self { log_level }
    }
}
195
196#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
197#[serde(rename_all = "kebab-case")]
198pub enum TlsMode {
199    #[default]
200    Tls,
201    MTls,
202    Insecure,
203}
204
205/// Final service configuration with all defaults applied
206#[derive(Debug, Clone)]
207pub struct ServiceConfig {
208    pub secret: String,
209    pub listen: SocketAddr,
210    pub transport: TransportConfig,
211    pub connection: ConnectionConfig,
212    #[cfg(feature = "tracing")]
213    pub logging: LoggingConfig,
214}
215
216/// Configuration builder that merges different configuration sources
217/// and applies defaults in a clear, predictable order
218pub struct ConfigBuilder {
219    secret: Option<String>,
220    listen: Option<SocketAddr>,
221    transport: TransportConfig,
222    connection: ConnectionConfig,
223    #[cfg(feature = "tracing")]
224    logging: LoggingConfig,
225}
226
227impl ConfigBuilder {
228    /// Create a new builder with default values
229    pub fn new() -> Self {
230        Self {
231            secret: None,
232            listen: None,
233            transport: TransportConfig::default(),
234            connection: ConnectionConfig::default(),
235            #[cfg(feature = "tracing")]
236            logging: LoggingConfig::default(),
237        }
238    }
239
240    /// Merge JSON configuration into the builder
241    pub fn merge_json(mut self, json_config: json::JsonConfig) -> Self {
242        if let Some(secret) = json_config.secret {
243            self.secret = Some(secret);
244        }
245        if let Some(listen) = json_config.listen {
246            self.listen = Some(listen);
247        }
248        if let Some(transport) = json_config.transport {
249            self.transport = Self::merge_transport(self.transport, transport);
250        }
251        if let Some(conn) = json_config.connection {
252            self.connection = Self::merge_connection(self.connection, conn);
253        }
254        #[cfg(feature = "tracing")]
255        {
256            if let Some(logging) = json_config.logging {
257                self.logging = Self::merge_logging(self.logging, logging);
258            }
259        }
260        self
261    }
262
263    /// Merge CLI configuration into the builder (CLI overrides JSON)
264    pub fn merge_cli(mut self, cli_config: cli::CliConfig) -> Self {
265        if let Some(secret) = cli_config.secret {
266            self.secret = Some(secret);
267        }
268        if let Some(listen) = cli_config.listen {
269            self.listen = Some(listen);
270        }
271        self.transport = Self::merge_transport(self.transport, cli_config.transport);
272        #[cfg(feature = "tracing")]
273        {
274            self.logging = Self::merge_logging(self.logging, cli_config.logging);
275        }
276        self
277    }
278
279    /// Build the final ServiceConfig, validating required fields
280    pub fn build(self) -> Result<ServiceConfig, String> {
281        let secret = self
282            .secret
283            .ok_or_else(|| "missing required field: secret".to_string())?;
284        let listen = self
285            .listen
286            .ok_or_else(|| "missing required field: listen".to_string())?;
287
288        Ok(ServiceConfig {
289            secret,
290            listen,
291            transport: self.transport,
292            connection: self.connection,
293            #[cfg(feature = "tracing")]
294            logging: self.logging,
295        })
296    }
297
298    fn merge_transport(base: TransportConfig, override_config: TransportConfig) -> TransportConfig {
299        TransportConfig {
300            tls_mode: override_config.tls_mode.or(base.tls_mode),
301            ca_cert: override_config.ca_cert.or(base.ca_cert),
302            tls_cert: override_config.tls_cert.or(base.tls_cert),
303            tls_key: override_config.tls_key.or(base.tls_key),
304            zero_rtt: override_config.zero_rtt.or(base.zero_rtt),
305            alpn_protocols: override_config.alpn_protocols.or(base.alpn_protocols),
306            congestion: override_config.congestion.or(base.congestion),
307            cwnd_init: override_config.cwnd_init.or(base.cwnd_init),
308            idle_timeout: override_config.idle_timeout.or(base.idle_timeout),
309            keep_alive: override_config.keep_alive.or(base.keep_alive),
310            max_streams: override_config.max_streams.or(base.max_streams),
311        }
312    }
313
314    fn merge_connection(
315        base: ConnectionConfig,
316        override_config: ConnectionConfig,
317    ) -> ConnectionConfig {
318        ConnectionConfig {
319            max_connections: override_config.max_connections.or(base.max_connections),
320            auth_timeout_secs: override_config.auth_timeout_secs.or(base.auth_timeout_secs),
321            max_concurrent_streams: override_config
322                .max_concurrent_streams
323                .or(base.max_concurrent_streams),
324            max_concurrent_datagrams: override_config
325                .max_concurrent_datagrams
326                .or(base.max_concurrent_datagrams),
327        }
328    }
329
330    #[cfg(feature = "tracing")]
331    fn merge_logging(base: LoggingConfig, override_config: LoggingConfig) -> LoggingConfig {
332        LoggingConfig {
333            log_level: override_config.log_level.or(base.log_level),
334        }
335    }
336}
337
338impl Default for ConfigBuilder {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
/// Builds the service configuration from the process command line and,
/// optionally, a JSON file named on it.
///
/// Layers are applied from lowest to highest precedence:
/// 1. built-in defaults,
/// 2. the JSON file given through the CLI `config` argument (when present),
/// 3. values passed directly on the command line.
///
/// # Errors
///
/// Propagates any error from reading/parsing the JSON file, and fails when a
/// required field is still missing after all layers were merged.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    use clap::Parser;

    let cli_args = cli::Args::parse();

    // Start from the defaults and layer the JSON file on top when one was given.
    let builder = match &cli_args.config {
        Some(config_path) => {
            let json_config = json::JsonConfig::from_file(config_path)?;
            ConfigBuilder::new().merge_json(json_config)
        }
        None => ConfigBuilder::new(),
    };

    // Command-line values are merged last so they take precedence.
    let cli_config = cli::CliConfig {
        secret: cli_args.secret,
        listen: cli_args.listen,
        transport: cli_args.transport.into_transport_config(),
        #[cfg(feature = "tracing")]
        logging: cli_args.logging.into_logging_config(),
    };

    // `?` boxes the builder's `String` error into `Box<dyn Error>`.
    Ok(builder.merge_cli(cli_config).build()?)
}
378
379/// Loads configuration from a JSON string.
380///
381/// This function is useful for programmatic configuration or when loading
382/// from external sources (e.g., environment variables, API responses).
383///
384/// # Arguments
385///
386/// * `json_str` - A JSON string containing the configuration
387///
388/// # Returns
389///
390/// A `ServiceConfig` ready to use, or an error if parsing fails or required fields are missing.
391pub fn load_from_json(json_str: &str) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
392    let json_config = json::JsonConfig::from_json_str(json_str)?;
393    ConfigBuilder::new()
394        .merge_json(json_config)
395        .build()
396        .map_err(|e| e.into())
397}
398
399/// Loads configuration from a JSON file.
400///
401/// This function reads configuration from a file path.
402///
403/// # Arguments
404///
405/// * `config_path` - Path to the JSON configuration file
406///
407/// # Returns
408///
409/// A `ServiceConfig` ready to use, or an error if the file doesn't exist,
410/// parsing fails, or required fields are missing.
411pub fn load_from_file(
412    config_path: &std::path::Path,
413) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
414    let json_config = json::JsonConfig::from_file(config_path)?;
415    ConfigBuilder::new()
416        .merge_json(json_config)
417        .build()
418        .map_err(|e| e.into())
419}