1use std::net::SocketAddr;
68use std::path::{Path, PathBuf};
69
70use serde::Deserialize;
71
72use crate::error::{NetError, NetResult};
73use crate::logging_layer::LogVerbosity;
74use crate::server::AqlServerBuilder;
75use amaters_core::traits::StorageEngine;
76
77#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
87pub struct NetConfig {
88 #[serde(default)]
90 pub net: NetSection,
91}
92
93#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
95pub struct NetSection {
96 pub bind_addr: Option<SocketAddr>,
98 #[serde(default)]
100 pub tls: TlsSection,
101 #[serde(default)]
103 pub metrics: MetricsSection,
104 #[serde(default)]
106 pub logging: LoggingSection,
107 #[serde(default)]
109 pub rate_limit: RateLimitSection,
110 #[serde(default)]
112 pub auth: AuthSection,
113}
114
115#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
117pub struct TlsSection {
118 pub enabled: Option<bool>,
120 pub cert_path: Option<PathBuf>,
122 pub key_path: Option<PathBuf>,
124}
125
126#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
128pub struct MetricsSection {
129 pub addr: Option<SocketAddr>,
131}
132
133#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
135pub struct LoggingSection {
136 pub verbosity: Option<LogVerbosityWire>,
138 pub slow_threshold_ms: Option<u64>,
141}
142
143#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
145pub struct RateLimitSection {
146 pub qps: Option<f64>,
148}
149
150#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
152pub struct AuthSection {
153 pub jwt_secret_path: Option<PathBuf>,
155}
156
157#[derive(Debug, Clone, Copy, PartialEq, Eq)]
164pub struct LogVerbosityWire(pub LogVerbosity);
165
166impl LogVerbosityWire {
167 pub fn parse(s: &str) -> NetResult<Self> {
169 match s.trim().to_ascii_lowercase().as_str() {
170 "off" => Ok(Self(LogVerbosity::Off)),
171 "brief" => Ok(Self(LogVerbosity::Brief)),
172 "detailed" => Ok(Self(LogVerbosity::Detailed)),
173 other => Err(NetError::InvalidRequest(format!(
174 "Invalid log verbosity '{other}': expected 'off', 'brief', or 'detailed'"
175 ))),
176 }
177 }
178}
179
180impl<'de> Deserialize<'de> for LogVerbosityWire {
181 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
182 where
183 D: serde::Deserializer<'de>,
184 {
185 let s = String::deserialize(deserializer)?;
186 Self::parse(&s).map_err(serde::de::Error::custom)
187 }
188}
189
190impl NetConfig {
195 pub fn from_path(path: impl AsRef<Path>) -> NetResult<Self> {
205 let path = path.as_ref();
206 let bytes = std::fs::read(path).map_err(|e| {
207 NetError::InvalidRequest(format!(
208 "Failed to read config file {}: {e}",
209 path.display()
210 ))
211 })?;
212 let text = std::str::from_utf8(&bytes).map_err(|e| {
213 NetError::InvalidRequest(format!(
214 "Config file {} is not valid UTF-8: {e}",
215 path.display()
216 ))
217 })?;
218 let mut cfg: Self = toml::from_str(text).map_err(|e| {
219 NetError::InvalidRequest(format!(
220 "Failed to parse config file {}: {e}",
221 path.display()
222 ))
223 })?;
224
225 if let Some(parent) = path.parent() {
227 cfg.resolve_paths_relative_to(parent);
228 }
229
230 Ok(cfg)
231 }
232
233 pub fn merge_env(mut self) -> NetResult<Self> {
242 if let Some(val) = read_env("AMATERS_NET_BIND_ADDR")? {
243 self.net.bind_addr = Some(parse_env::<SocketAddr>("AMATERS_NET_BIND_ADDR", &val)?);
244 }
245 if let Some(val) = read_env("AMATERS_NET_TLS_ENABLED")? {
246 self.net.tls.enabled = Some(parse_env::<bool>("AMATERS_NET_TLS_ENABLED", &val)?);
247 }
248 if let Some(val) = read_env("AMATERS_NET_TLS_CERT_PATH")? {
249 self.net.tls.cert_path = Some(PathBuf::from(val));
250 }
251 if let Some(val) = read_env("AMATERS_NET_TLS_KEY_PATH")? {
252 self.net.tls.key_path = Some(PathBuf::from(val));
253 }
254 if let Some(val) = read_env("AMATERS_NET_METRICS_ADDR")? {
255 self.net.metrics.addr =
256 Some(parse_env::<SocketAddr>("AMATERS_NET_METRICS_ADDR", &val)?);
257 }
258 if let Some(val) = read_env("AMATERS_NET_LOG_VERBOSITY")? {
259 self.net.logging.verbosity = Some(LogVerbosityWire::parse(&val)?);
260 }
261 if let Some(val) = read_env("AMATERS_NET_SLOW_THRESHOLD_MS")? {
262 self.net.logging.slow_threshold_ms =
263 Some(parse_env::<u64>("AMATERS_NET_SLOW_THRESHOLD_MS", &val)?);
264 }
265 if let Some(val) = read_env("AMATERS_NET_RATE_LIMIT_QPS")? {
266 self.net.rate_limit.qps = Some(parse_env::<f64>("AMATERS_NET_RATE_LIMIT_QPS", &val)?);
267 }
268 if let Some(val) = read_env("AMATERS_NET_JWT_SECRET_PATH")? {
269 self.net.auth.jwt_secret_path = Some(PathBuf::from(val));
270 }
271 Ok(self)
272 }
273
274 pub fn load_layered(path: impl AsRef<Path>) -> NetResult<Self> {
278 Self::from_path(path)?.merge_env()
279 }
280
281 pub fn apply_to<S>(&self, mut builder: AqlServerBuilder<S>) -> AqlServerBuilder<S>
286 where
287 S: StorageEngine + Send + Sync + 'static,
288 {
289 if let Some(verbosity) = self.net.logging.verbosity {
290 builder = builder.with_logging(verbosity.0);
291 }
292 if let Some(slow_ms) = self.net.logging.slow_threshold_ms {
293 builder = builder.with_slow_threshold_ms(slow_ms);
294 }
295 if let Some(addr) = self.net.metrics.addr {
296 builder = builder.with_metrics_addr(addr);
297 }
298 if let Some(addr) = self.net.bind_addr {
299 builder = builder.with_bind_addr(addr);
300 }
301 if let Some(qps) = self.net.rate_limit.qps {
302 builder = builder.with_rate_limit_qps(qps);
303 }
304 if let Some(ref path) = self.net.auth.jwt_secret_path {
305 builder = builder.with_jwt_secret_path(path.clone());
306 }
307 builder
308 }
309
310 fn resolve_paths_relative_to(&mut self, base: &Path) {
312 if let Some(p) = self.net.tls.cert_path.as_mut() {
313 if p.is_relative() {
314 *p = base.join(p.as_path());
315 }
316 }
317 if let Some(p) = self.net.tls.key_path.as_mut() {
318 if p.is_relative() {
319 *p = base.join(p.as_path());
320 }
321 }
322 if let Some(p) = self.net.auth.jwt_secret_path.as_mut() {
323 if p.is_relative() {
324 *p = base.join(p.as_path());
325 }
326 }
327 }
328}
329
330fn read_env(name: &str) -> NetResult<Option<String>> {
337 match std::env::var(name) {
338 Ok(v) => Ok(Some(v)),
339 Err(std::env::VarError::NotPresent) => Ok(None),
340 Err(std::env::VarError::NotUnicode(_)) => Err(NetError::InvalidRequest(format!(
341 "Env var {name} is not valid UTF-8"
342 ))),
343 }
344}
345
346fn parse_env<T: std::str::FromStr>(name: &str, raw: &str) -> NetResult<T>
348where
349 T::Err: std::fmt::Display,
350{
351 raw.parse::<T>()
352 .map_err(|e| NetError::InvalidRequest(format!("Invalid {name}={raw:?}: {e}")))
353}
354
#[cfg(test)]
mod tests {
    //! Tests for TOML loading, environment layering and builder application.
    //!
    //! Tests that write a config file hold a [`ScratchFile`]. Tests that
    //! mutate `AMATERS_NET_*` variables are `#[serial]` and hold an
    //! [`EnvGuard`]. Both guards clean up in `Drop`, so a failing assertion
    //! cannot leak a temp file or an env override into later tests.

    use super::*;
    use amaters_core::storage::MemoryStorage;
    use serial_test::serial;
    use std::sync::Arc;

    /// Every variable consulted by `NetConfig::merge_env`.
    const ENV_VARS: [&str; 9] = [
        "AMATERS_NET_BIND_ADDR",
        "AMATERS_NET_TLS_ENABLED",
        "AMATERS_NET_TLS_CERT_PATH",
        "AMATERS_NET_TLS_KEY_PATH",
        "AMATERS_NET_METRICS_ADDR",
        "AMATERS_NET_LOG_VERBOSITY",
        "AMATERS_NET_SLOW_THRESHOLD_MS",
        "AMATERS_NET_RATE_LIMIT_QPS",
        "AMATERS_NET_JWT_SECRET_PATH",
    ];

    /// Returns a unique path in the system temp dir, tagged with `name`.
    fn scratch_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "amaters_net_config_test_{name}_{}.toml",
            uuid::Uuid::new_v4()
        ))
    }

    /// A scratch config file that is deleted when dropped, including on panic.
    struct ScratchFile {
        path: PathBuf,
    }

    impl ScratchFile {
        /// Writes `contents` to a fresh scratch path tagged with `name`.
        fn new(name: &str, contents: &str) -> Self {
            let path = scratch_path(name);
            std::fs::write(&path, contents).expect("write toml");
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for ScratchFile {
        fn drop(&mut self) {
            // Best-effort: a cleanup failure must not mask the test's own outcome.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Removes every variable in [`ENV_VARS`].
    fn clear_env_vars() {
        for v in ENV_VARS {
            // SAFETY: only reached from `#[serial]` tests (directly or via
            // `EnvGuard`), so no other test in this module mutates the
            // environment at the same time. NOTE(review): non-serial tests may
            // still run in parallel; this is sound only if they touch the
            // environment solely through `std::env`. Confirm when adding tests.
            unsafe { std::env::remove_var(v) };
        }
    }

    /// Sets `name` to `value`; callers must be `#[serial]` tests.
    fn set_env(name: &str, value: &str) {
        // SAFETY: same invariant as `clear_env_vars`: every caller is a
        // `#[serial]` test.
        unsafe { std::env::set_var(name, value) };
    }

    /// Clears [`ENV_VARS`] on creation and again on drop, even if the test panics.
    struct EnvGuard;

    impl EnvGuard {
        fn new() -> Self {
            clear_env_vars();
            Self
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            clear_env_vars();
        }
    }

    #[test]
    fn test_net_config_load_from_toml_file() {
        let file = ScratchFile::new(
            "full",
            r#"
[net]
bind_addr = "127.0.0.1:50051"

[net.tls]
enabled = true
cert_path = "certs/server.pem"
key_path = "certs/server.key"

[net.metrics]
addr = "127.0.0.1:9091"

[net.logging]
verbosity = "brief"
slow_threshold_ms = 250

[net.rate_limit]
qps = 1500.0

[net.auth]
jwt_secret_path = "secrets/jwt.key"
"#,
        );

        let cfg = NetConfig::from_path(file.path()).expect("load config");
        assert_eq!(
            cfg.net.bind_addr,
            Some("127.0.0.1:50051".parse().expect("addr"))
        );
        assert_eq!(cfg.net.tls.enabled, Some(true));
        // Relative paths are rebased onto the config file's directory.
        let scratch_parent = file.path().parent().expect("parent");
        assert_eq!(
            cfg.net.tls.cert_path,
            Some(scratch_parent.join("certs/server.pem"))
        );
        assert_eq!(
            cfg.net.tls.key_path,
            Some(scratch_parent.join("certs/server.key"))
        );
        assert_eq!(
            cfg.net.metrics.addr,
            Some("127.0.0.1:9091".parse().expect("metrics addr"))
        );
        assert_eq!(
            cfg.net.logging.verbosity.map(|v| v.0),
            Some(LogVerbosity::Brief)
        );
        assert_eq!(cfg.net.logging.slow_threshold_ms, Some(250));
        assert_eq!(cfg.net.rate_limit.qps, Some(1500.0));
        assert_eq!(
            cfg.net.auth.jwt_secret_path,
            Some(scratch_parent.join("secrets/jwt.key"))
        );
    }

    #[test]
    fn test_net_config_partial_toml_uses_builder_defaults() {
        // Only one sub-table present: everything else must stay unset.
        let file = ScratchFile::new(
            "partial",
            r#"
[net.metrics]
addr = "127.0.0.1:9092"
"#,
        );

        let cfg = NetConfig::from_path(file.path()).expect("load config");
        assert_eq!(cfg.net.bind_addr, None);
        assert_eq!(cfg.net.tls.enabled, None);
        assert_eq!(cfg.net.tls.cert_path, None);
        assert_eq!(cfg.net.logging.verbosity, None);
        assert_eq!(
            cfg.net.metrics.addr,
            Some("127.0.0.1:9092".parse().expect("metrics addr"))
        );
    }

    #[test]
    fn test_net_config_apply_to_builder_overrides() {
        let file = ScratchFile::new(
            "apply",
            r#"
[net.logging]
verbosity = "detailed"
slow_threshold_ms = 50

[net.metrics]
addr = "127.0.0.1:9093"

[net.rate_limit]
qps = 250.0
"#,
        );

        let cfg = NetConfig::from_path(file.path()).expect("load config");
        let storage = Arc::new(MemoryStorage::new());
        let builder = cfg.apply_to(AqlServerBuilder::new(storage));

        assert_eq!(builder.logging_verbosity(), Some(LogVerbosity::Detailed));
        assert_eq!(builder.slow_threshold_ms(), Some(50));
        assert_eq!(
            builder.metrics_addr(),
            Some("127.0.0.1:9093".parse().expect("metrics addr"))
        );
        assert_eq!(builder.rate_limit_qps(), Some(250.0));
    }

    #[test]
    fn test_net_config_invalid_toml_returns_error() {
        let file = ScratchFile::new("invalid", "this is not [net.tls valid toml = yes");

        let result = NetConfig::from_path(file.path());
        assert!(matches!(result, Err(NetError::InvalidRequest(_))));
    }

    #[test]
    fn test_net_config_full_round_trip() {
        let file = ScratchFile::new(
            "roundtrip",
            r#"
[net]
bind_addr = "0.0.0.0:50052"

[net.tls]
enabled = false

[net.metrics]
addr = "0.0.0.0:9094"

[net.logging]
verbosity = "off"
slow_threshold_ms = 1000

[net.rate_limit]
qps = 5000.5
"#,
        );

        let cfg = NetConfig::from_path(file.path()).expect("load config");
        let storage = Arc::new(MemoryStorage::new());
        let builder = cfg.apply_to(AqlServerBuilder::new(storage));

        assert_eq!(
            builder.bind_addr(),
            Some("0.0.0.0:50052".parse().expect("bind addr"))
        );
        assert_eq!(builder.logging_verbosity(), Some(LogVerbosity::Off));
        assert_eq!(builder.slow_threshold_ms(), Some(1000));
        assert_eq!(
            builder.metrics_addr(),
            Some("0.0.0.0:9094".parse().expect("metrics addr"))
        );
        assert_eq!(builder.rate_limit_qps(), Some(5000.5));
    }

    #[test]
    fn test_net_config_invalid_log_verbosity_returns_error() {
        let file = ScratchFile::new(
            "invalid_verb",
            r#"
[net.logging]
verbosity = "loud"
"#,
        );

        let result = NetConfig::from_path(file.path());
        assert!(matches!(result, Err(NetError::InvalidRequest(_))));
    }

    #[test]
    #[serial]
    fn test_env_override_bind_addr() {
        let _env = EnvGuard::new();
        set_env("AMATERS_NET_BIND_ADDR", "127.0.0.1:60001");

        let cfg = NetConfig::default().merge_env().expect("merge_env");

        assert_eq!(
            cfg.net.bind_addr,
            Some("127.0.0.1:60001".parse().expect("addr"))
        );
    }

    #[test]
    #[serial]
    fn test_env_override_tls_enabled_true() {
        let _env = EnvGuard::new();
        set_env("AMATERS_NET_TLS_ENABLED", "true");

        let cfg = NetConfig::default().merge_env().expect("merge_env");
        assert_eq!(cfg.net.tls.enabled, Some(true));
    }

    #[test]
    #[serial]
    fn test_env_override_invalid_value_returns_error() {
        let _env = EnvGuard::new();
        set_env("AMATERS_NET_RATE_LIMIT_QPS", "not-a-number");

        let result = NetConfig::default().merge_env();
        assert!(matches!(result, Err(NetError::InvalidRequest(_))));
    }

    #[test]
    #[serial]
    fn test_env_does_not_override_when_unset() {
        let _env = EnvGuard::new();

        let mut cfg = NetConfig::default();
        cfg.net.bind_addr = Some("10.0.0.1:50051".parse().expect("addr"));
        cfg.net.tls.enabled = Some(false);

        let cfg = cfg.merge_env().expect("merge_env");
        assert_eq!(
            cfg.net.bind_addr,
            Some("10.0.0.1:50051".parse().expect("addr"))
        );
        assert_eq!(cfg.net.tls.enabled, Some(false));
    }

    #[test]
    #[serial]
    fn test_layered_load_combines_toml_and_env() {
        let _env = EnvGuard::new();
        let file = ScratchFile::new(
            "layered",
            r#"
[net]
bind_addr = "127.0.0.1:50051"

[net.metrics]
addr = "127.0.0.1:9090"

[net.logging]
verbosity = "off"
"#,
        );

        // The env layer overrides two keys; metrics.addr must come from the file.
        set_env("AMATERS_NET_BIND_ADDR", "127.0.0.1:50099");
        set_env("AMATERS_NET_LOG_VERBOSITY", "detailed");

        let cfg = NetConfig::load_layered(file.path()).expect("layered");
        assert_eq!(
            cfg.net.bind_addr,
            Some("127.0.0.1:50099".parse().expect("addr"))
        );
        assert_eq!(
            cfg.net.metrics.addr,
            Some("127.0.0.1:9090".parse().expect("metrics addr"))
        );
        assert_eq!(
            cfg.net.logging.verbosity.map(|v| v.0),
            Some(LogVerbosity::Detailed)
        );
    }

    #[test]
    #[serial]
    fn test_env_override_log_verbosity_invalid() {
        let _env = EnvGuard::new();
        set_env("AMATERS_NET_LOG_VERBOSITY", "loud");

        let result = NetConfig::default().merge_env();
        assert!(matches!(result, Err(NetError::InvalidRequest(_))));
    }
}