1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::ValueEnum;
5use serde::{Deserialize, Serialize};
6
7use ombrac_transport::quic::Congestion;
8
9pub mod cli;
10pub mod json;
11
12#[derive(Deserialize, Serialize, Debug, Clone)]
14#[serde(rename_all = "snake_case")]
15pub struct TransportConfig {
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub tls_mode: Option<TlsMode>,
18
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub ca_cert: Option<PathBuf>,
21
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub tls_cert: Option<PathBuf>,
24
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub tls_key: Option<PathBuf>,
27
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub zero_rtt: Option<bool>,
30
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub alpn_protocols: Option<Vec<Vec<u8>>>,
33
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub congestion: Option<Congestion>,
36
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub cwnd_init: Option<u64>,
39
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub idle_timeout: Option<u64>,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub keep_alive: Option<u64>,
45
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub max_streams: Option<u64>,
48}
49
50impl TransportConfig {
51 pub fn tls_mode(&self) -> TlsMode {
53 self.tls_mode.unwrap_or_default()
54 }
55
56 pub fn zero_rtt(&self) -> bool {
58 self.zero_rtt.unwrap_or(false)
59 }
60
61 pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
63 self.alpn_protocols
64 .clone()
65 .unwrap_or_else(|| vec!["h3".into()])
66 }
67
68 pub fn congestion(&self) -> Congestion {
70 self.congestion.unwrap_or(Congestion::Bbr)
71 }
72
73 pub fn idle_timeout(&self) -> u64 {
75 self.idle_timeout.unwrap_or(30000)
76 }
77
78 pub fn keep_alive(&self) -> u64 {
80 self.keep_alive.unwrap_or(8000)
81 }
82
83 pub fn max_streams(&self) -> u64 {
85 self.max_streams.unwrap_or(1000)
86 }
87}
88
89impl Default for TransportConfig {
90 fn default() -> Self {
91 Self {
92 tls_mode: Some(TlsMode::Tls),
93 ca_cert: None,
94 tls_cert: None,
95 tls_key: None,
96 zero_rtt: Some(false),
97 alpn_protocols: Some(vec!["h3".into()]),
98 congestion: Some(Congestion::Bbr),
99 cwnd_init: None,
100 idle_timeout: Some(30000),
101 keep_alive: Some(8000),
102 max_streams: Some(1000),
103 }
104 }
105}
106
107#[derive(Deserialize, Serialize, Debug, Clone)]
109#[serde(rename_all = "snake_case")]
110pub struct ConnectionConfig {
111 #[serde(skip_serializing_if = "Option::is_none")]
113 pub max_connections: Option<usize>,
114
115 #[serde(skip_serializing_if = "Option::is_none")]
117 pub auth_timeout_secs: Option<u64>,
118
119 #[serde(skip_serializing_if = "Option::is_none")]
121 pub max_concurrent_streams: Option<usize>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub max_concurrent_datagrams: Option<usize>,
126}
127
128impl ConnectionConfig {
129 pub fn max_connections(&self) -> usize {
131 self.max_connections.unwrap_or(10000)
132 }
133
134 pub fn auth_timeout_secs(&self) -> u64 {
136 self.auth_timeout_secs.unwrap_or(10)
137 }
138
139 pub fn max_concurrent_streams(&self) -> usize {
141 self.max_concurrent_streams.unwrap_or(4096)
142 }
143
144 pub fn max_concurrent_datagrams(&self) -> usize {
146 self.max_concurrent_datagrams.unwrap_or(4096)
147 }
148}
149
150impl Default for ConnectionConfig {
151 fn default() -> Self {
152 Self {
153 max_connections: Some(10000),
154 auth_timeout_secs: Some(10),
155 max_concurrent_streams: Some(4096),
156 max_concurrent_datagrams: Some(4096),
157 }
158 }
159}
160
/// Logging settings; this type only exists when the `tracing` feature is enabled.
#[cfg(feature = "tracing")]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct LoggingConfig {
    /// Log level name; the accessor falls back to `"INFO"`. NOTE(review): which
    /// spellings are accepted is decided by whatever consumes this string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
170
#[cfg(feature = "tracing")]
impl LoggingConfig {
    /// Level used by `log_level()` and `Default` when none is configured.
    pub const DEFAULT_LOG_LEVEL: &str = "INFO";

    /// Effective log level; [`Self::DEFAULT_LOG_LEVEL`] when unset.
    pub fn log_level(&self) -> &str {
        self.log_level.as_deref().unwrap_or(Self::DEFAULT_LOG_LEVEL)
    }
}

#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Uses the same level the accessor falls back to, keeping both in sync.
    fn default() -> Self {
        Self {
            log_level: Some(Self::DEFAULT_LOG_LEVEL.to_string()),
        }
    }
}
187
188#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
189#[serde(rename_all = "kebab-case")]
190pub enum TlsMode {
191 #[default]
192 Tls,
193 MTls,
194 Insecure,
195}
196
197#[derive(Debug, Clone)]
199pub struct ServiceConfig {
200 pub secret: String,
201 pub listen: SocketAddr,
202 pub transport: TransportConfig,
203 pub connection: ConnectionConfig,
204 #[cfg(feature = "tracing")]
205 pub logging: LoggingConfig,
206}
207
208pub struct ConfigBuilder {
211 secret: Option<String>,
212 listen: Option<SocketAddr>,
213 transport: TransportConfig,
214 connection: ConnectionConfig,
215 #[cfg(feature = "tracing")]
216 logging: LoggingConfig,
217}
218
219impl ConfigBuilder {
220 pub fn new() -> Self {
222 Self {
223 secret: None,
224 listen: None,
225 transport: TransportConfig::default(),
226 connection: ConnectionConfig::default(),
227 #[cfg(feature = "tracing")]
228 logging: LoggingConfig::default(),
229 }
230 }
231
232 pub fn merge_json(mut self, json_config: json::JsonConfig) -> Self {
234 if let Some(secret) = json_config.secret {
235 self.secret = Some(secret);
236 }
237 if let Some(listen) = json_config.listen {
238 self.listen = Some(listen);
239 }
240 if let Some(transport) = json_config.transport {
241 self.transport = Self::merge_transport(self.transport, transport);
242 }
243 if let Some(conn) = json_config.connection {
244 self.connection = Self::merge_connection(self.connection, conn);
245 }
246 #[cfg(feature = "tracing")]
247 {
248 if let Some(logging) = json_config.logging {
249 self.logging = Self::merge_logging(self.logging, logging);
250 }
251 }
252 self
253 }
254
255 pub fn merge_cli(mut self, cli_config: cli::CliConfig) -> Self {
257 if let Some(secret) = cli_config.secret {
258 self.secret = Some(secret);
259 }
260 if let Some(listen) = cli_config.listen {
261 self.listen = Some(listen);
262 }
263 self.transport = Self::merge_transport(self.transport, cli_config.transport);
264 #[cfg(feature = "tracing")]
265 {
266 self.logging = Self::merge_logging(self.logging, cli_config.logging);
267 }
268 self
269 }
270
271 pub fn build(self) -> Result<ServiceConfig, String> {
273 let secret = self
274 .secret
275 .ok_or_else(|| "missing required field: secret".to_string())?;
276 let listen = self
277 .listen
278 .ok_or_else(|| "missing required field: listen".to_string())?;
279
280 Ok(ServiceConfig {
281 secret,
282 listen,
283 transport: self.transport,
284 connection: self.connection,
285 #[cfg(feature = "tracing")]
286 logging: self.logging,
287 })
288 }
289
290 fn merge_transport(base: TransportConfig, override_config: TransportConfig) -> TransportConfig {
291 TransportConfig {
292 tls_mode: override_config.tls_mode.or(base.tls_mode),
293 ca_cert: override_config.ca_cert.or(base.ca_cert),
294 tls_cert: override_config.tls_cert.or(base.tls_cert),
295 tls_key: override_config.tls_key.or(base.tls_key),
296 zero_rtt: override_config.zero_rtt.or(base.zero_rtt),
297 alpn_protocols: override_config.alpn_protocols.or(base.alpn_protocols),
298 congestion: override_config.congestion.or(base.congestion),
299 cwnd_init: override_config.cwnd_init.or(base.cwnd_init),
300 idle_timeout: override_config.idle_timeout.or(base.idle_timeout),
301 keep_alive: override_config.keep_alive.or(base.keep_alive),
302 max_streams: override_config.max_streams.or(base.max_streams),
303 }
304 }
305
306 fn merge_connection(
307 base: ConnectionConfig,
308 override_config: ConnectionConfig,
309 ) -> ConnectionConfig {
310 ConnectionConfig {
311 max_connections: override_config.max_connections.or(base.max_connections),
312 auth_timeout_secs: override_config.auth_timeout_secs.or(base.auth_timeout_secs),
313 max_concurrent_streams: override_config
314 .max_concurrent_streams
315 .or(base.max_concurrent_streams),
316 max_concurrent_datagrams: override_config
317 .max_concurrent_datagrams
318 .or(base.max_concurrent_datagrams),
319 }
320 }
321
322 #[cfg(feature = "tracing")]
323 fn merge_logging(base: LoggingConfig, override_config: LoggingConfig) -> LoggingConfig {
324 LoggingConfig {
325 log_level: override_config.log_level.or(base.log_level),
326 }
327 }
328}
329
330impl Default for ConfigBuilder {
331 fn default() -> Self {
332 Self::new()
333 }
334}
335
/// Loads the server configuration for the binary entry point.
///
/// Layers are applied in increasing precedence:
/// built-in defaults, then the JSON file named by `--config` (if given), then CLI flags.
///
/// # Errors
///
/// Fails if the JSON file cannot be loaded or parsed, or if a required field
/// (`secret`, `listen`) is still missing after all layers are merged.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    use clap::Parser;

    let args = cli::Args::parse();

    // The borrow of `args.config` ends here, before fields of `args` are moved below.
    let file_layer = match &args.config {
        Some(path) => Some(json::JsonConfig::from_file(path)?),
        None => None,
    };

    let base = match file_layer {
        Some(parsed) => ConfigBuilder::new().merge_json(parsed),
        None => ConfigBuilder::new(),
    };

    let overrides = cli::CliConfig {
        secret: args.secret,
        listen: args.listen,
        transport: args.transport.into_transport_config(),
        #[cfg(feature = "tracing")]
        logging: args.logging.into_logging_config(),
    };

    base.merge_cli(overrides).build().map_err(Into::into)
}
370
371pub fn load_from_json(json_str: &str) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
384 let json_config = json::JsonConfig::from_json_str(json_str)?;
385 ConfigBuilder::new()
386 .merge_json(json_config)
387 .build()
388 .map_err(|e| e.into())
389}
390
391pub fn load_from_file(
404 config_path: &std::path::Path,
405) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
406 let json_config = json::JsonConfig::from_file(config_path)?;
407 ConfigBuilder::new()
408 .merge_json(json_config)
409 .build()
410 .map_err(|e| e.into())
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes the wrapped path on drop so temp files are cleaned up even when an
    /// assertion panics partway through a test.
    struct TempFile(std::path::PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup: a file that is already gone is not a failure.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn load_from_json_minimal_uses_defaults() {
        let json = r#"{
            "secret": "k",
            "listen": "0.0.0.0:443"
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.secret, "k");
        assert_eq!(cfg.listen.to_string(), "0.0.0.0:443");
        assert_eq!(cfg.transport.tls_mode, Some(TlsMode::Tls));
        assert_eq!(cfg.transport.idle_timeout, Some(30000));
        assert_eq!(cfg.connection.max_connections, Some(10000));
        assert_eq!(cfg.connection.auth_timeout_secs, Some(10));
        assert_eq!(cfg.connection.max_concurrent_streams, Some(4096));
    }

    #[test]
    fn load_from_json_missing_secret_fails() {
        let json = r#"{ "listen": "0.0.0.0:443" }"#;
        let err = load_from_json(json).unwrap_err();
        assert!(err.to_string().contains("secret"));
    }

    #[test]
    fn load_from_json_missing_listen_fails() {
        let json = r#"{ "secret": "k" }"#;
        let err = load_from_json(json).unwrap_err();
        assert!(err.to_string().contains("listen"));
    }

    #[test]
    fn load_from_json_invalid_listen_address_fails() {
        let json = r#"{ "secret": "k", "listen": "not-an-address" }"#;
        assert!(load_from_json(json).is_err());
    }

    #[test]
    fn load_from_json_overrides_transport() {
        let json = r#"{
            "secret": "k",
            "listen": "127.0.0.1:443",
            "transport": {
                "tls_mode": "m-tls",
                "idle_timeout": 12345,
                "max_streams": 999
            }
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.transport.tls_mode, Some(TlsMode::MTls));
        assert_eq!(cfg.transport.idle_timeout, Some(12345));
        assert_eq!(cfg.transport.max_streams, Some(999));
    }

    #[test]
    fn load_from_json_overrides_connection_limits() {
        let json = r#"{
            "secret": "k",
            "listen": "127.0.0.1:443",
            "connection": {
                "max_connections": 500,
                "auth_timeout_secs": 5,
                "max_concurrent_streams": 100,
                "max_concurrent_datagrams": 200
            }
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.connection.max_connections, Some(500));
        assert_eq!(cfg.connection.auth_timeout_secs, Some(5));
        assert_eq!(cfg.connection.max_concurrent_streams, Some(100));
        assert_eq!(cfg.connection.max_concurrent_datagrams, Some(200));
    }

    #[test]
    fn cli_overrides_json_in_merge_order() {
        let json = json::JsonConfig {
            secret: Some("from_json".into()),
            listen: Some("0.0.0.0:5555".parse().unwrap()),
            transport: Some(TransportConfig {
                idle_timeout: Some(11111),
                keep_alive: Some(2222),
                ..Default::default()
            }),
            connection: None,
            #[cfg(feature = "tracing")]
            logging: None,
        };

        let cli = cli::CliConfig {
            secret: None,
            listen: Some("127.0.0.1:6666".parse().unwrap()),
            transport: TransportConfig {
                idle_timeout: Some(99999),
                keep_alive: None,
                ..Default::default()
            },
            #[cfg(feature = "tracing")]
            logging: LoggingConfig::default(),
        };

        let cfg = ConfigBuilder::new()
            .merge_json(json)
            .merge_cli(cli)
            .build()
            .unwrap();

        // CLI `None` must not clobber a value supplied by JSON.
        assert_eq!(cfg.secret, "from_json");
        assert_eq!(cfg.listen.to_string(), "127.0.0.1:6666");
        assert_eq!(cfg.transport.idle_timeout, Some(99999));
        assert_eq!(cfg.transport.keep_alive, Some(2222));
    }

    #[test]
    fn transport_config_accessors_apply_defaults_on_none() {
        let cfg = TransportConfig {
            tls_mode: None,
            ca_cert: None,
            tls_cert: None,
            tls_key: None,
            zero_rtt: None,
            alpn_protocols: None,
            congestion: None,
            cwnd_init: None,
            idle_timeout: None,
            keep_alive: None,
            max_streams: None,
        };
        assert_eq!(cfg.tls_mode(), TlsMode::Tls);
        assert!(!cfg.zero_rtt());
        assert_eq!(cfg.idle_timeout(), 30000);
        assert_eq!(cfg.keep_alive(), 8000);
        assert_eq!(cfg.max_streams(), 1000);
        assert_eq!(cfg.alpn_protocols(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn connection_config_accessors_apply_defaults_on_none() {
        let cfg = ConnectionConfig {
            max_connections: None,
            auth_timeout_secs: None,
            max_concurrent_streams: None,
            max_concurrent_datagrams: None,
        };
        assert_eq!(cfg.max_connections(), 10000);
        assert_eq!(cfg.auth_timeout_secs(), 10);
        assert_eq!(cfg.max_concurrent_streams(), 4096);
        assert_eq!(cfg.max_concurrent_datagrams(), 4096);
    }

    #[test]
    fn default_values_agree_with_accessor_fallbacks() {
        // `Default` and the accessors each encode the fallbacks; pin that they agree.
        let unset = TransportConfig {
            tls_mode: None,
            ca_cert: None,
            tls_cert: None,
            tls_key: None,
            zero_rtt: None,
            alpn_protocols: None,
            congestion: None,
            cwnd_init: None,
            idle_timeout: None,
            keep_alive: None,
            max_streams: None,
        };
        let defaults = TransportConfig::default();
        assert_eq!(defaults.tls_mode, Some(unset.tls_mode()));
        assert_eq!(defaults.zero_rtt, Some(unset.zero_rtt()));
        assert_eq!(defaults.alpn_protocols, Some(unset.alpn_protocols()));
        assert_eq!(defaults.idle_timeout, Some(unset.idle_timeout()));
        assert_eq!(defaults.keep_alive, Some(unset.keep_alive()));
        assert_eq!(defaults.max_streams, Some(unset.max_streams()));
        assert_eq!(defaults.cwnd_init, None);

        let unset = ConnectionConfig {
            max_connections: None,
            auth_timeout_secs: None,
            max_concurrent_streams: None,
            max_concurrent_datagrams: None,
        };
        let defaults = ConnectionConfig::default();
        assert_eq!(defaults.max_connections, Some(unset.max_connections()));
        assert_eq!(defaults.auth_timeout_secs, Some(unset.auth_timeout_secs()));
        assert_eq!(
            defaults.max_concurrent_streams,
            Some(unset.max_concurrent_streams())
        );
        assert_eq!(
            defaults.max_concurrent_datagrams,
            Some(unset.max_concurrent_datagrams())
        );
    }

    #[test]
    fn tls_mode_kebab_case_serialization() {
        assert_eq!(serde_json::to_string(&TlsMode::Tls).unwrap(), "\"tls\"");
        assert_eq!(serde_json::to_string(&TlsMode::MTls).unwrap(), "\"m-tls\"");
        assert_eq!(
            serde_json::to_string(&TlsMode::Insecure).unwrap(),
            "\"insecure\""
        );
        assert_eq!(TlsMode::default(), TlsMode::Tls);
    }

    #[test]
    fn json_config_roundtrips() {
        let original = r#"{
            "secret": "abc",
            "listen": "0.0.0.0:443",
            "transport": { "tls_mode": "insecure", "max_streams": 50 },
            "connection": { "max_connections": 100 }
        }"#;
        let parsed = json::JsonConfig::from_json_str(original).unwrap();
        let serialized = serde_json::to_string(&parsed).unwrap();
        let reparsed = json::JsonConfig::from_json_str(&serialized).unwrap();

        assert_eq!(reparsed.secret.as_deref(), Some("abc"));
        assert_eq!(reparsed.listen, Some("0.0.0.0:443".parse().unwrap()));

        let transport = reparsed
            .transport
            .as_ref()
            .expect("transport section survives the roundtrip");
        assert_eq!(transport.tls_mode, Some(TlsMode::Insecure));
        assert_eq!(transport.max_streams, Some(50));
        // Unset fields are skipped on serialization and must come back as `None`.
        assert_eq!(transport.idle_timeout, None);

        let connection = reparsed
            .connection
            .as_ref()
            .expect("connection section survives the roundtrip");
        assert_eq!(connection.max_connections, Some(100));
        assert_eq!(connection.auth_timeout_secs, None);
    }

    #[test]
    fn load_from_file_missing_path_returns_error() {
        let p = std::path::Path::new("/no/such/file/srvcfg.json");
        assert!(load_from_file(p).is_err());
    }

    #[test]
    fn load_from_file_reads_real_file() {
        let file = TempFile(
            std::env::temp_dir().join(format!("ombrac-server-cfg-{}.json", std::process::id())),
        );
        std::fs::write(&file.0, r#"{"secret":"abc","listen":"127.0.0.1:9999"}"#).unwrap();

        let cfg = load_from_file(&file.0).unwrap();
        assert_eq!(cfg.secret, "abc");
        assert_eq!(cfg.listen.to_string(), "127.0.0.1:9999");
    }
}