Skip to main content

ombrac_server/config/
mod.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::ValueEnum;
5use serde::{Deserialize, Serialize};
6
7use ombrac_transport::quic::Congestion;
8
9pub mod cli;
10pub mod json;
11
12/// Transport configuration for QUIC connections
13#[derive(Deserialize, Serialize, Debug, Clone)]
14#[serde(rename_all = "snake_case")]
15pub struct TransportConfig {
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub tls_mode: Option<TlsMode>,
18
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub ca_cert: Option<PathBuf>,
21
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tls_cert: Option<PathBuf>,
24
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub tls_key: Option<PathBuf>,
27
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub zero_rtt: Option<bool>,
30
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub alpn_protocols: Option<Vec<Vec<u8>>>,
33
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub congestion: Option<Congestion>,
36
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub cwnd_init: Option<u64>,
39
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub idle_timeout: Option<u64>,
42
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub keep_alive: Option<u64>,
45
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub max_streams: Option<u64>,
48}
49
50impl TransportConfig {
51    /// Get TLS mode with default
52    pub fn tls_mode(&self) -> TlsMode {
53        self.tls_mode.unwrap_or_default()
54    }
55
56    /// Get zero_rtt with default
57    pub fn zero_rtt(&self) -> bool {
58        self.zero_rtt.unwrap_or(false)
59    }
60
61    /// Get ALPN protocols with default
62    pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
63        self.alpn_protocols
64            .clone()
65            .unwrap_or_else(|| vec!["h3".into()])
66    }
67
68    /// Get congestion control with default
69    pub fn congestion(&self) -> Congestion {
70        self.congestion.unwrap_or(Congestion::Bbr)
71    }
72
73    /// Get idle timeout with default (in milliseconds)
74    pub fn idle_timeout(&self) -> u64 {
75        self.idle_timeout.unwrap_or(30000)
76    }
77
78    /// Get keep-alive interval with default (in milliseconds)
79    pub fn keep_alive(&self) -> u64 {
80        self.keep_alive.unwrap_or(8000)
81    }
82
83    /// Get max streams with default
84    pub fn max_streams(&self) -> u64 {
85        self.max_streams.unwrap_or(1000)
86    }
87}
88
89impl Default for TransportConfig {
90    fn default() -> Self {
91        Self {
92            tls_mode: Some(TlsMode::Tls),
93            ca_cert: None,
94            tls_cert: None,
95            tls_key: None,
96            zero_rtt: Some(false),
97            alpn_protocols: Some(vec!["h3".into()]),
98            congestion: Some(Congestion::Bbr),
99            cwnd_init: None,
100            idle_timeout: Some(30000),
101            keep_alive: Some(8000),
102            max_streams: Some(1000),
103        }
104    }
105}
106
107/// Connection-level configuration for managing connection lifecycle and resource limits
108#[derive(Deserialize, Serialize, Debug, Clone)]
109#[serde(rename_all = "snake_case")]
110pub struct ConnectionConfig {
111    /// Maximum number of concurrent connections [default: 10000]
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub max_connections: Option<usize>,
114
115    /// Authentication timeout in seconds [default: 10]
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub auth_timeout_secs: Option<u64>,
118
119    /// Maximum concurrent stream connections per client connection [default: 4096]
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub max_concurrent_streams: Option<usize>,
122
123    /// Maximum concurrent datagram handlers per client connection [default: 4096]
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub max_concurrent_datagrams: Option<usize>,
126}
127
128impl ConnectionConfig {
129    /// Get max connections with default
130    pub fn max_connections(&self) -> usize {
131        self.max_connections.unwrap_or(10000)
132    }
133
134    /// Get authentication timeout with default (in seconds)
135    pub fn auth_timeout_secs(&self) -> u64 {
136        self.auth_timeout_secs.unwrap_or(10)
137    }
138
139    /// Get max concurrent streams with default
140    pub fn max_concurrent_streams(&self) -> usize {
141        self.max_concurrent_streams.unwrap_or(4096)
142    }
143
144    /// Get max concurrent datagrams with default
145    pub fn max_concurrent_datagrams(&self) -> usize {
146        self.max_concurrent_datagrams.unwrap_or(4096)
147    }
148}
149
150impl Default for ConnectionConfig {
151    fn default() -> Self {
152        Self {
153            max_connections: Some(10000),
154            auth_timeout_secs: Some(10),
155            max_concurrent_streams: Some(4096),
156            max_concurrent_datagrams: Some(4096),
157        }
158    }
159}
160
/// Logging settings; only compiled with the `tracing` feature.
#[cfg(feature = "tracing")]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct LoggingConfig {
    /// Log level name such as INFO, WARN or ERROR [default: INFO]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
170
#[cfg(feature = "tracing")]
impl LoggingConfig {
    /// Returns the configured log level, or `"INFO"` when none is set.
    pub fn log_level(&self) -> &str {
        match self.log_level.as_deref() {
            Some(level) => level,
            None => "INFO",
        }
    }
}
178
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Returns a logging config with the level explicitly set to `"INFO"`.
    fn default() -> Self {
        let level = String::from("INFO");
        LoggingConfig {
            log_level: Some(level),
        }
    }
}
187
188#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
189#[serde(rename_all = "kebab-case")]
190pub enum TlsMode {
191    #[default]
192    Tls,
193    MTls,
194    Insecure,
195}
196
197/// Final service configuration with all defaults applied
198#[derive(Debug, Clone)]
199pub struct ServiceConfig {
200    pub secret: String,
201    pub listen: SocketAddr,
202    pub transport: TransportConfig,
203    pub connection: ConnectionConfig,
204    #[cfg(feature = "tracing")]
205    pub logging: LoggingConfig,
206}
207
208/// Configuration builder that merges different configuration sources
209/// and applies defaults in a clear, predictable order
210pub struct ConfigBuilder {
211    secret: Option<String>,
212    listen: Option<SocketAddr>,
213    transport: TransportConfig,
214    connection: ConnectionConfig,
215    #[cfg(feature = "tracing")]
216    logging: LoggingConfig,
217}
218
219impl ConfigBuilder {
220    /// Create a new builder with default values
221    pub fn new() -> Self {
222        Self {
223            secret: None,
224            listen: None,
225            transport: TransportConfig::default(),
226            connection: ConnectionConfig::default(),
227            #[cfg(feature = "tracing")]
228            logging: LoggingConfig::default(),
229        }
230    }
231
232    /// Merge JSON configuration into the builder
233    pub fn merge_json(mut self, json_config: json::JsonConfig) -> Self {
234        if let Some(secret) = json_config.secret {
235            self.secret = Some(secret);
236        }
237        if let Some(listen) = json_config.listen {
238            self.listen = Some(listen);
239        }
240        if let Some(transport) = json_config.transport {
241            self.transport = Self::merge_transport(self.transport, transport);
242        }
243        if let Some(conn) = json_config.connection {
244            self.connection = Self::merge_connection(self.connection, conn);
245        }
246        #[cfg(feature = "tracing")]
247        {
248            if let Some(logging) = json_config.logging {
249                self.logging = Self::merge_logging(self.logging, logging);
250            }
251        }
252        self
253    }
254
255    /// Merge CLI configuration into the builder (CLI overrides JSON)
256    pub fn merge_cli(mut self, cli_config: cli::CliConfig) -> Self {
257        if let Some(secret) = cli_config.secret {
258            self.secret = Some(secret);
259        }
260        if let Some(listen) = cli_config.listen {
261            self.listen = Some(listen);
262        }
263        self.transport = Self::merge_transport(self.transport, cli_config.transport);
264        #[cfg(feature = "tracing")]
265        {
266            self.logging = Self::merge_logging(self.logging, cli_config.logging);
267        }
268        self
269    }
270
271    /// Build the final ServiceConfig, validating required fields
272    pub fn build(self) -> Result<ServiceConfig, String> {
273        let secret = self
274            .secret
275            .ok_or_else(|| "missing required field: secret".to_string())?;
276        let listen = self
277            .listen
278            .ok_or_else(|| "missing required field: listen".to_string())?;
279
280        Ok(ServiceConfig {
281            secret,
282            listen,
283            transport: self.transport,
284            connection: self.connection,
285            #[cfg(feature = "tracing")]
286            logging: self.logging,
287        })
288    }
289
290    fn merge_transport(base: TransportConfig, override_config: TransportConfig) -> TransportConfig {
291        TransportConfig {
292            tls_mode: override_config.tls_mode.or(base.tls_mode),
293            ca_cert: override_config.ca_cert.or(base.ca_cert),
294            tls_cert: override_config.tls_cert.or(base.tls_cert),
295            tls_key: override_config.tls_key.or(base.tls_key),
296            zero_rtt: override_config.zero_rtt.or(base.zero_rtt),
297            alpn_protocols: override_config.alpn_protocols.or(base.alpn_protocols),
298            congestion: override_config.congestion.or(base.congestion),
299            cwnd_init: override_config.cwnd_init.or(base.cwnd_init),
300            idle_timeout: override_config.idle_timeout.or(base.idle_timeout),
301            keep_alive: override_config.keep_alive.or(base.keep_alive),
302            max_streams: override_config.max_streams.or(base.max_streams),
303        }
304    }
305
306    fn merge_connection(
307        base: ConnectionConfig,
308        override_config: ConnectionConfig,
309    ) -> ConnectionConfig {
310        ConnectionConfig {
311            max_connections: override_config.max_connections.or(base.max_connections),
312            auth_timeout_secs: override_config.auth_timeout_secs.or(base.auth_timeout_secs),
313            max_concurrent_streams: override_config
314                .max_concurrent_streams
315                .or(base.max_concurrent_streams),
316            max_concurrent_datagrams: override_config
317                .max_concurrent_datagrams
318                .or(base.max_concurrent_datagrams),
319        }
320    }
321
322    #[cfg(feature = "tracing")]
323    fn merge_logging(base: LoggingConfig, override_config: LoggingConfig) -> LoggingConfig {
324        LoggingConfig {
325            log_level: override_config.log_level.or(base.log_level),
326        }
327    }
328}
329
330impl Default for ConfigBuilder {
331    fn default() -> Self {
332        Self::new()
333    }
334}
335
/// Builds the service configuration from the process's command line.
///
/// Sources are layered from lowest to highest priority:
/// 1. built-in defaults,
/// 2. the JSON file referenced by the parsed CLI arguments, if any,
/// 3. values given directly on the command line.
///
/// # Errors
///
/// Fails if the JSON file cannot be loaded or parsed, or if `secret` or
/// `listen` is still missing after all layers are applied.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    use clap::Parser;

    let args = cli::Args::parse();

    let builder = match &args.config {
        Some(config_path) => {
            ConfigBuilder::new().merge_json(json::JsonConfig::from_file(config_path)?)
        }
        None => ConfigBuilder::new(),
    };

    let overrides = cli::CliConfig {
        secret: args.secret,
        listen: args.listen,
        transport: args.transport.into_transport_config(),
        #[cfg(feature = "tracing")]
        logging: args.logging.into_logging_config(),
    };

    // `?` boxes the builder's `String` error into `Box<dyn Error>`.
    Ok(builder.merge_cli(overrides).build()?)
}
370
371/// Loads configuration from a JSON string.
372///
373/// This function is useful for programmatic configuration or when loading
374/// from external sources (e.g., environment variables, API responses).
375///
376/// # Arguments
377///
378/// * `json_str` - A JSON string containing the configuration
379///
380/// # Returns
381///
382/// A `ServiceConfig` ready to use, or an error if parsing fails or required fields are missing.
383pub fn load_from_json(json_str: &str) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
384    let json_config = json::JsonConfig::from_json_str(json_str)?;
385    ConfigBuilder::new()
386        .merge_json(json_config)
387        .build()
388        .map_err(|e| e.into())
389}
390
391/// Loads configuration from a JSON file.
392///
393/// This function reads configuration from a file path.
394///
395/// # Arguments
396///
397/// * `config_path` - Path to the JSON configuration file
398///
399/// # Returns
400///
401/// A `ServiceConfig` ready to use, or an error if the file doesn't exist,
402/// parsing fails, or required fields are missing.
403pub fn load_from_file(
404    config_path: &std::path::Path,
405) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
406    let json_config = json::JsonConfig::from_file(config_path)?;
407    ConfigBuilder::new()
408        .merge_json(json_config)
409        .build()
410        .map_err(|e| e.into())
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Deletes the wrapped path on drop, so a failing assertion cannot leave a
    /// stray file behind in the system temp directory.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn load_from_json_minimal_uses_defaults() {
        let json = r#"{
            "secret": "k",
            "listen": "0.0.0.0:443"
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.secret, "k");
        assert_eq!(cfg.listen.to_string(), "0.0.0.0:443");
        assert_eq!(cfg.transport.tls_mode, Some(TlsMode::Tls));
        assert_eq!(cfg.transport.idle_timeout, Some(30000));
        assert_eq!(cfg.transport.keep_alive, Some(8000));
        assert_eq!(cfg.transport.max_streams, Some(1000));
        assert_eq!(cfg.connection.max_connections, Some(10000));
        assert_eq!(cfg.connection.auth_timeout_secs, Some(10));
        assert_eq!(cfg.connection.max_concurrent_streams, Some(4096));
        assert_eq!(cfg.connection.max_concurrent_datagrams, Some(4096));
    }

    #[test]
    fn load_from_json_missing_secret_fails() {
        let json = r#"{ "listen": "0.0.0.0:443" }"#;
        let err = load_from_json(json).unwrap_err();
        assert!(err.to_string().contains("secret"));
    }

    #[test]
    fn load_from_json_missing_listen_fails() {
        let json = r#"{ "secret": "k" }"#;
        let err = load_from_json(json).unwrap_err();
        assert!(err.to_string().contains("listen"));
    }

    #[test]
    fn load_from_json_invalid_listen_address_fails() {
        let json = r#"{ "secret": "k", "listen": "not-an-address" }"#;
        let result = load_from_json(json);
        assert!(result.is_err());
    }

    #[test]
    fn load_from_json_overrides_transport() {
        let json = r#"{
            "secret": "k",
            "listen": "127.0.0.1:443",
            "transport": {
                "tls_mode": "m-tls",
                "idle_timeout": 12345,
                "max_streams": 999
            }
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.transport.tls_mode, Some(TlsMode::MTls));
        assert_eq!(cfg.transport.idle_timeout, Some(12345));
        assert_eq!(cfg.transport.max_streams, Some(999));
    }

    #[test]
    fn load_from_json_accepts_alpn_byte_arrays() {
        // The byte-array spelling is what serde emits for `Vec<u8>`; it must keep working.
        let json = r#"{
            "secret": "k",
            "listen": "127.0.0.1:443",
            "transport": { "alpn_protocols": [[104, 51]] }
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.transport.alpn_protocols, Some(vec![b"h3".to_vec()]));
    }

    #[test]
    fn load_from_json_overrides_connection_limits() {
        let json = r#"{
            "secret": "k",
            "listen": "127.0.0.1:443",
            "connection": {
                "max_connections": 500,
                "auth_timeout_secs": 5,
                "max_concurrent_streams": 100,
                "max_concurrent_datagrams": 200
            }
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.connection.max_connections, Some(500));
        assert_eq!(cfg.connection.auth_timeout_secs, Some(5));
        assert_eq!(cfg.connection.max_concurrent_streams, Some(100));
        assert_eq!(cfg.connection.max_concurrent_datagrams, Some(200));
    }

    #[test]
    fn cli_overrides_json_in_merge_order() {
        let json = json::JsonConfig {
            secret: Some("from_json".into()),
            listen: Some("0.0.0.0:5555".parse().unwrap()),
            transport: Some(TransportConfig {
                idle_timeout: Some(11111),
                keep_alive: Some(2222),
                ..Default::default()
            }),
            connection: None,
            #[cfg(feature = "tracing")]
            logging: None,
        };

        let cli = cli::CliConfig {
            secret: None, // JSON wins
            listen: Some("127.0.0.1:6666".parse().unwrap()), // CLI wins
            transport: TransportConfig {
                idle_timeout: Some(99999), // CLI wins
                keep_alive: None,          // JSON wins
                ..Default::default()
            },
            #[cfg(feature = "tracing")]
            logging: LoggingConfig::default(),
        };

        let cfg = ConfigBuilder::new()
            .merge_json(json)
            .merge_cli(cli)
            .build()
            .unwrap();

        assert_eq!(cfg.secret, "from_json");
        assert_eq!(cfg.listen.to_string(), "127.0.0.1:6666");
        assert_eq!(cfg.transport.idle_timeout, Some(99999));
        assert_eq!(cfg.transport.keep_alive, Some(2222));
    }

    #[test]
    fn transport_config_accessors_apply_defaults_on_none() {
        let cfg = TransportConfig {
            tls_mode: None,
            ca_cert: None,
            tls_cert: None,
            tls_key: None,
            zero_rtt: None,
            alpn_protocols: None,
            congestion: None,
            cwnd_init: None,
            idle_timeout: None,
            keep_alive: None,
            max_streams: None,
        };
        assert_eq!(cfg.tls_mode(), TlsMode::Tls);
        assert!(!cfg.zero_rtt());
        assert_eq!(cfg.idle_timeout(), 30000);
        assert_eq!(cfg.keep_alive(), 8000);
        assert_eq!(cfg.max_streams(), 1000);
        assert_eq!(cfg.alpn_protocols(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn connection_config_accessors_apply_defaults_on_none() {
        let cfg = ConnectionConfig {
            max_connections: None,
            auth_timeout_secs: None,
            max_concurrent_streams: None,
            max_concurrent_datagrams: None,
        };
        assert_eq!(cfg.max_connections(), 10000);
        assert_eq!(cfg.auth_timeout_secs(), 10);
        assert_eq!(cfg.max_concurrent_streams(), 4096);
        assert_eq!(cfg.max_concurrent_datagrams(), 4096);
    }

    #[test]
    fn tls_mode_kebab_case_serialization() {
        assert_eq!(serde_json::to_string(&TlsMode::Tls).unwrap(), "\"tls\"");
        assert_eq!(
            serde_json::to_string(&TlsMode::MTls).unwrap(),
            "\"m-tls\""
        );
        assert_eq!(
            serde_json::to_string(&TlsMode::Insecure).unwrap(),
            "\"insecure\""
        );
        assert_eq!(TlsMode::default(), TlsMode::Tls);
    }

    #[test]
    fn json_config_roundtrips() {
        let original = r#"{
            "secret": "abc",
            "listen": "0.0.0.0:443",
            "transport": { "tls_mode": "insecure", "max_streams": 50 },
            "connection": { "max_connections": 100 }
        }"#;
        let parsed = json::JsonConfig::from_json_str(original).unwrap();
        let s = serde_json::to_string(&parsed).unwrap();
        let reparsed = json::JsonConfig::from_json_str(&s).unwrap();

        // Check every section the input specified, not just the secret.
        assert_eq!(reparsed.secret.as_deref(), Some("abc"));
        assert_eq!(
            reparsed.listen,
            Some("0.0.0.0:443".parse::<std::net::SocketAddr>().unwrap())
        );

        let transport = reparsed
            .transport
            .as_ref()
            .expect("transport section lost in roundtrip");
        assert_eq!(transport.tls_mode, Some(TlsMode::Insecure));
        assert_eq!(transport.max_streams, Some(50));

        let connection = reparsed
            .connection
            .as_ref()
            .expect("connection section lost in roundtrip");
        assert_eq!(connection.max_connections, Some(100));
    }

    #[test]
    fn load_from_file_missing_path_returns_error() {
        let p = std::path::Path::new("/no/such/file/srvcfg.json");
        assert!(load_from_file(p).is_err());
    }

    #[test]
    fn load_from_file_reads_real_file() {
        // The guard removes the file even if an assertion below panics.
        let file = TempFile(
            std::env::temp_dir().join(format!("ombrac-server-cfg-{}.json", std::process::id())),
        );
        std::fs::write(&file.0, r#"{"secret":"abc","listen":"127.0.0.1:9999"}"#).unwrap();

        let cfg = load_from_file(&file.0).unwrap();
        assert_eq!(cfg.secret, "abc");
        assert_eq!(cfg.listen.to_string(), "127.0.0.1:9999");
    }
}