use std::net::SocketAddr;
use std::path::PathBuf;
use clap::ValueEnum;
use serde::{Deserialize, Serialize};
use ombrac_transport::quic::Congestion;
pub mod cli;
pub mod json;
/// QUIC/TLS transport settings.
///
/// Every field is optional so that partial configurations (JSON file, CLI flags) can be
/// layered over one another; unset fields are omitted when serialized. For most fields the
/// accessor methods on this type substitute a built-in default when the value is `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TransportConfig {
    /// TLS mode; serialized as `"tls"`, `"m-tls"` or `"insecure"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_mode: Option<TlsMode>,

    /// Path to a CA certificate file. No default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca_cert: Option<PathBuf>,

    /// Path to the TLS certificate file. No default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_cert: Option<PathBuf>,

    /// Path to the TLS private-key file. No default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_key: Option<PathBuf>,

    /// Whether 0-RTT is enabled; the accessor falls back to `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zero_rtt: Option<bool>,

    /// ALPN protocol identifiers as raw bytes (in JSON each entry is an array of byte
    /// values); the accessor falls back to `["h3"]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    /// Congestion-control algorithm; the accessor falls back to BBR.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub congestion: Option<Congestion>,

    /// Initial congestion window. No default is applied in this module.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwnd_init: Option<u64>,

    /// Idle timeout; the accessor falls back to 30 000 (presumably milliseconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idle_timeout: Option<u64>,

    /// Keep-alive setting; the accessor falls back to 8 000 (presumably milliseconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<u64>,

    /// Stream limit; the accessor falls back to 1 000.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_streams: Option<u64>,
}
impl TransportConfig {
pub fn tls_mode(&self) -> TlsMode {
self.tls_mode.unwrap_or_default()
}
pub fn zero_rtt(&self) -> bool {
self.zero_rtt.unwrap_or(false)
}
pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
self.alpn_protocols
.clone()
.unwrap_or_else(|| vec!["h3".into()])
}
pub fn congestion(&self) -> Congestion {
self.congestion.unwrap_or(Congestion::Bbr)
}
pub fn idle_timeout(&self) -> u64 {
self.idle_timeout.unwrap_or(30000)
}
pub fn keep_alive(&self) -> u64 {
self.keep_alive.unwrap_or(8000)
}
pub fn max_streams(&self) -> u64 {
self.max_streams.unwrap_or(1000)
}
}
impl Default for TransportConfig {
fn default() -> Self {
Self {
tls_mode: Some(TlsMode::Tls),
ca_cert: None,
tls_cert: None,
tls_key: None,
zero_rtt: Some(false),
alpn_protocols: Some(vec!["h3".into()]),
congestion: Some(Congestion::Bbr),
cwnd_init: None,
idle_timeout: Some(30000),
keep_alive: Some(8000),
max_streams: Some(1000),
}
}
}
/// Per-server connection limits.
///
/// All fields are optional so JSON overrides can be partial; unset fields are omitted when
/// serialized and the accessor methods substitute built-in defaults.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ConnectionConfig {
    /// Connection limit; the accessor falls back to 10 000.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_connections: Option<usize>,

    /// Authentication timeout (the field name indicates seconds); the accessor falls back to 10.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_timeout_secs: Option<u64>,

    /// Concurrent-stream limit; the accessor falls back to 4 096.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_concurrent_streams: Option<usize>,

    /// Concurrent-datagram limit; the accessor falls back to 4 096.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_concurrent_datagrams: Option<usize>,
}
impl ConnectionConfig {
    /// Connection limit, or `10_000` when unset.
    pub fn max_connections(&self) -> usize {
        match self.max_connections {
            Some(limit) => limit,
            None => 10_000,
        }
    }

    /// Authentication timeout in seconds, or `10` when unset.
    pub fn auth_timeout_secs(&self) -> u64 {
        match self.auth_timeout_secs {
            Some(secs) => secs,
            None => 10,
        }
    }

    /// Concurrent-stream limit, or `4_096` when unset.
    pub fn max_concurrent_streams(&self) -> usize {
        match self.max_concurrent_streams {
            Some(limit) => limit,
            None => 4_096,
        }
    }

    /// Concurrent-datagram limit, or `4_096` when unset.
    pub fn max_concurrent_datagrams(&self) -> usize {
        match self.max_concurrent_datagrams {
            Some(limit) => limit,
            None => 4_096,
        }
    }
}
impl Default for ConnectionConfig {
fn default() -> Self {
Self {
max_connections: Some(10000),
auth_timeout_secs: Some(10),
max_concurrent_streams: Some(4096),
max_concurrent_datagrams: Some(4096),
}
}
}
/// Logging settings; only compiled with the `tracing` feature.
#[cfg(feature = "tracing")]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct LoggingConfig {
    /// Log level name as a string; the accessor falls back to `"INFO"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
#[cfg(feature = "tracing")]
impl LoggingConfig {
    /// Borrowed log level, or `"INFO"` when none was configured.
    pub fn log_level(&self) -> &str {
        match &self.log_level {
            Some(level) => level.as_str(),
            None => "INFO",
        }
    }
}
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Sets the level explicitly to `"INFO"`, the same value the accessor falls back to.
    fn default() -> Self {
        LoggingConfig {
            log_level: Some(String::from("INFO")),
        }
    }
}
// TLS mode for the transport. Serialized in kebab-case ("tls", "m-tls", "insecure") and
// also accepted as a clap value enum. Plain `//` comments are used deliberately: clap's
// `ValueEnum` derive turns `///` docs on variants into CLI help text.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize, ValueEnum)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Default variant ("tls").
    #[default]
    Tls,
    // "m-tls": presumably mutual TLS with client certificates — confirm in the transport.
    MTls,
    // "insecure": presumably relaxes certificate checks — confirm before use.
    Insecure,
}
/// Fully resolved server configuration produced by [`ConfigBuilder::build`].
///
/// `Debug` is implemented by hand so the `secret` is never written out through `{:?}`
/// (logs, panic messages, test failures); every other field is printed as usual.
#[derive(Clone)]
pub struct ServiceConfig {
    /// Secret value; required (`ConfigBuilder::build` fails without it).
    pub secret: String,
    /// Address to listen on; required (`ConfigBuilder::build` fails without it).
    pub listen: SocketAddr,
    /// Transport settings after defaults, JSON and CLI have been layered.
    pub transport: TransportConfig,
    /// Connection limits after defaults and JSON have been layered.
    pub connection: ConnectionConfig,
    /// Logging settings (only with the `tracing` feature).
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}

impl std::fmt::Debug for ServiceConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ServiceConfig");
        // The derived impl would print the secret verbatim; redact it instead.
        out.field("secret", &"<redacted>")
            .field("listen", &self.listen)
            .field("transport", &self.transport)
            .field("connection", &self.connection);
        #[cfg(feature = "tracing")]
        {
            out.field("logging", &self.logging);
        }
        out.finish()
    }
}
/// Incrementally assembles a [`ServiceConfig`] from built-in defaults, a JSON file and
/// CLI flags.
///
/// The `merge_*` methods consume the builder and return the updated one. `#[must_use]`
/// makes a discarded result a compiler warning instead of silently losing the merged
/// values.
#[must_use = "a ConfigBuilder must be merged further or built; discarding it drops the accumulated configuration"]
pub struct ConfigBuilder {
    // Required; `build` fails while this is still `None`.
    secret: Option<String>,
    // Required; `build` fails while this is still `None`.
    listen: Option<SocketAddr>,
    transport: TransportConfig,
    connection: ConnectionConfig,
    #[cfg(feature = "tracing")]
    logging: LoggingConfig,
}
impl ConfigBuilder {
pub fn new() -> Self {
Self {
secret: None,
listen: None,
transport: TransportConfig::default(),
connection: ConnectionConfig::default(),
#[cfg(feature = "tracing")]
logging: LoggingConfig::default(),
}
}
pub fn merge_json(mut self, json_config: json::JsonConfig) -> Self {
if let Some(secret) = json_config.secret {
self.secret = Some(secret);
}
if let Some(listen) = json_config.listen {
self.listen = Some(listen);
}
if let Some(transport) = json_config.transport {
self.transport = Self::merge_transport(self.transport, transport);
}
if let Some(conn) = json_config.connection {
self.connection = Self::merge_connection(self.connection, conn);
}
#[cfg(feature = "tracing")]
{
if let Some(logging) = json_config.logging {
self.logging = Self::merge_logging(self.logging, logging);
}
}
self
}
pub fn merge_cli(mut self, cli_config: cli::CliConfig) -> Self {
if let Some(secret) = cli_config.secret {
self.secret = Some(secret);
}
if let Some(listen) = cli_config.listen {
self.listen = Some(listen);
}
self.transport = Self::merge_transport(self.transport, cli_config.transport);
#[cfg(feature = "tracing")]
{
self.logging = Self::merge_logging(self.logging, cli_config.logging);
}
self
}
pub fn build(self) -> Result<ServiceConfig, String> {
let secret = self
.secret
.ok_or_else(|| "missing required field: secret".to_string())?;
let listen = self
.listen
.ok_or_else(|| "missing required field: listen".to_string())?;
Ok(ServiceConfig {
secret,
listen,
transport: self.transport,
connection: self.connection,
#[cfg(feature = "tracing")]
logging: self.logging,
})
}
fn merge_transport(base: TransportConfig, override_config: TransportConfig) -> TransportConfig {
TransportConfig {
tls_mode: override_config.tls_mode.or(base.tls_mode),
ca_cert: override_config.ca_cert.or(base.ca_cert),
tls_cert: override_config.tls_cert.or(base.tls_cert),
tls_key: override_config.tls_key.or(base.tls_key),
zero_rtt: override_config.zero_rtt.or(base.zero_rtt),
alpn_protocols: override_config.alpn_protocols.or(base.alpn_protocols),
congestion: override_config.congestion.or(base.congestion),
cwnd_init: override_config.cwnd_init.or(base.cwnd_init),
idle_timeout: override_config.idle_timeout.or(base.idle_timeout),
keep_alive: override_config.keep_alive.or(base.keep_alive),
max_streams: override_config.max_streams.or(base.max_streams),
}
}
fn merge_connection(
base: ConnectionConfig,
override_config: ConnectionConfig,
) -> ConnectionConfig {
ConnectionConfig {
max_connections: override_config.max_connections.or(base.max_connections),
auth_timeout_secs: override_config.auth_timeout_secs.or(base.auth_timeout_secs),
max_concurrent_streams: override_config
.max_concurrent_streams
.or(base.max_concurrent_streams),
max_concurrent_datagrams: override_config
.max_concurrent_datagrams
.or(base.max_concurrent_datagrams),
}
}
#[cfg(feature = "tracing")]
fn merge_logging(base: LoggingConfig, override_config: LoggingConfig) -> LoggingConfig {
LoggingConfig {
log_level: override_config.log_level.or(base.log_level),
}
}
}
impl Default for ConfigBuilder {
fn default() -> Self {
Self::new()
}
}
/// Builds the service configuration from the process's command line.
///
/// Layering order is built-in defaults, then the JSON file named by the `config`
/// argument (if given), then explicit CLI values, so CLI values win.
///
/// # Errors
///
/// Propagates JSON loading errors and the missing-field error from
/// [`ConfigBuilder::build`].
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    use clap::Parser;

    let args = cli::Args::parse();

    let builder = match args.config.as_ref() {
        Some(path) => ConfigBuilder::new().merge_json(json::JsonConfig::from_file(path)?),
        None => ConfigBuilder::new(),
    };

    let cli_config = cli::CliConfig {
        secret: args.secret,
        listen: args.listen,
        transport: args.transport.into_transport_config(),
        #[cfg(feature = "tracing")]
        logging: args.logging.into_logging_config(),
    };

    Ok(builder.merge_cli(cli_config).build()?)
}
/// Builds the service configuration from a JSON string layered over built-in defaults.
///
/// # Errors
///
/// Fails if the JSON cannot be parsed or a required field (`secret`, `listen`) is missing.
pub fn load_from_json(json_str: &str) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    let parsed = json::JsonConfig::from_json_str(json_str)?;
    let config = ConfigBuilder::new().merge_json(parsed).build()?;
    Ok(config)
}
/// Builds the service configuration from a JSON file layered over built-in defaults.
///
/// # Errors
///
/// Fails if the file cannot be read or parsed, or a required field (`secret`, `listen`)
/// is missing.
pub fn load_from_file(
    config_path: &std::path::Path,
) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    let parsed = json::JsonConfig::from_file(config_path)?;
    let config = ConfigBuilder::new().merge_json(parsed).build()?;
    Ok(config)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn load_from_json_minimal_uses_defaults() {
        let json = r#"{
            "secret": "k",
            "listen": "0.0.0.0:443"
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.secret, "k");
        assert_eq!(cfg.listen.to_string(), "0.0.0.0:443");
        assert_eq!(cfg.transport.tls_mode, Some(TlsMode::Tls));
        assert_eq!(cfg.transport.idle_timeout, Some(30000));
        assert_eq!(cfg.connection.max_connections, Some(10000));
        assert_eq!(cfg.connection.auth_timeout_secs, Some(10));
        assert_eq!(cfg.connection.max_concurrent_streams, Some(4096));
    }

    #[test]
    fn load_from_json_missing_secret_fails() {
        let json = r#"{ "listen": "0.0.0.0:443" }"#;
        let err = load_from_json(json).unwrap_err();
        assert!(err.to_string().contains("secret"));
    }

    #[test]
    fn load_from_json_missing_listen_fails() {
        let json = r#"{ "secret": "k" }"#;
        let err = load_from_json(json).unwrap_err();
        assert!(err.to_string().contains("listen"));
    }

    #[test]
    fn load_from_json_invalid_listen_address_fails() {
        let json = r#"{ "secret": "k", "listen": "not-an-address" }"#;
        let result = load_from_json(json);
        assert!(result.is_err());
    }

    #[test]
    fn load_from_json_overrides_transport() {
        let json = r#"{
            "secret": "k",
            "listen": "127.0.0.1:443",
            "transport": {
                "tls_mode": "m-tls",
                "idle_timeout": 12345,
                "max_streams": 999
            }
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.transport.tls_mode, Some(TlsMode::MTls));
        assert_eq!(cfg.transport.idle_timeout, Some(12345));
        assert_eq!(cfg.transport.max_streams, Some(999));
    }

    #[test]
    fn load_from_json_overrides_connection_limits() {
        let json = r#"{
            "secret": "k",
            "listen": "127.0.0.1:443",
            "connection": {
                "max_connections": 500,
                "auth_timeout_secs": 5,
                "max_concurrent_streams": 100,
                "max_concurrent_datagrams": 200
            }
        }"#;
        let cfg = load_from_json(json).unwrap();
        assert_eq!(cfg.connection.max_connections, Some(500));
        assert_eq!(cfg.connection.auth_timeout_secs, Some(5));
        assert_eq!(cfg.connection.max_concurrent_streams, Some(100));
        assert_eq!(cfg.connection.max_concurrent_datagrams, Some(200));
    }

    #[test]
    fn cli_overrides_json_in_merge_order() {
        let json = json::JsonConfig {
            secret: Some("from_json".into()),
            listen: Some("0.0.0.0:5555".parse().unwrap()),
            transport: Some(TransportConfig {
                idle_timeout: Some(11111),
                keep_alive: Some(2222),
                ..Default::default()
            }),
            connection: None,
            #[cfg(feature = "tracing")]
            logging: None,
        };
        let cli = cli::CliConfig {
            secret: None,
            listen: Some("127.0.0.1:6666".parse().unwrap()),
            transport: TransportConfig {
                idle_timeout: Some(99999),
                keep_alive: None,
                ..Default::default()
            },
            #[cfg(feature = "tracing")]
            logging: LoggingConfig::default(),
        };
        let cfg = ConfigBuilder::new()
            .merge_json(json)
            .merge_cli(cli)
            .build()
            .unwrap();
        assert_eq!(cfg.secret, "from_json");
        assert_eq!(cfg.listen.to_string(), "127.0.0.1:6666");
        assert_eq!(cfg.transport.idle_timeout, Some(99999));
        // The CLI left keep_alive unset, so the JSON value must survive.
        assert_eq!(cfg.transport.keep_alive, Some(2222));
    }

    #[test]
    fn transport_config_accessors_apply_defaults_on_none() {
        let cfg = TransportConfig {
            tls_mode: None,
            ca_cert: None,
            tls_cert: None,
            tls_key: None,
            zero_rtt: None,
            alpn_protocols: None,
            congestion: None,
            cwnd_init: None,
            idle_timeout: None,
            keep_alive: None,
            max_streams: None,
        };
        assert_eq!(cfg.tls_mode(), TlsMode::Tls);
        assert!(!cfg.zero_rtt());
        assert_eq!(cfg.idle_timeout(), 30000);
        assert_eq!(cfg.keep_alive(), 8000);
        assert_eq!(cfg.max_streams(), 1000);
        assert_eq!(cfg.alpn_protocols(), vec![b"h3".to_vec()]);
    }

    #[test]
    fn connection_config_accessors_apply_defaults_on_none() {
        let cfg = ConnectionConfig {
            max_connections: None,
            auth_timeout_secs: None,
            max_concurrent_streams: None,
            max_concurrent_datagrams: None,
        };
        assert_eq!(cfg.max_connections(), 10000);
        assert_eq!(cfg.auth_timeout_secs(), 10);
        assert_eq!(cfg.max_concurrent_streams(), 4096);
        assert_eq!(cfg.max_concurrent_datagrams(), 4096);
    }

    #[test]
    fn tls_mode_kebab_case_serialization() {
        assert_eq!(serde_json::to_string(&TlsMode::Tls).unwrap(), "\"tls\"");
        assert_eq!(serde_json::to_string(&TlsMode::MTls).unwrap(), "\"m-tls\"");
        assert_eq!(
            serde_json::to_string(&TlsMode::Insecure).unwrap(),
            "\"insecure\""
        );
        assert_eq!(TlsMode::default(), TlsMode::Tls);
    }

    #[test]
    fn json_config_roundtrips() {
        let original = r#"{
            "secret": "abc",
            "listen": "0.0.0.0:443",
            "transport": { "tls_mode": "insecure", "max_streams": 50 },
            "connection": { "max_connections": 100 }
        }"#;
        let parsed = json::JsonConfig::from_json_str(original).unwrap();
        let s = serde_json::to_string(&parsed).unwrap();
        let reparsed = json::JsonConfig::from_json_str(&s).unwrap();

        // Check every section, not just the secret, so a lossy serializer is caught.
        assert_eq!(reparsed.secret.as_deref(), Some("abc"));
        assert_eq!(reparsed.listen, Some("0.0.0.0:443".parse().unwrap()));

        let transport = reparsed
            .transport
            .expect("transport section survives the roundtrip");
        assert_eq!(transport.tls_mode, Some(TlsMode::Insecure));
        assert_eq!(transport.max_streams, Some(50));

        let connection = reparsed
            .connection
            .expect("connection section survives the roundtrip");
        assert_eq!(connection.max_connections, Some(100));
    }

    #[test]
    fn load_from_file_missing_path_returns_error() {
        let p = std::path::Path::new("/no/such/file/srvcfg.json");
        assert!(load_from_file(p).is_err());
    }

    #[test]
    fn load_from_file_reads_real_file() {
        let path =
            std::env::temp_dir().join(format!("ombrac-server-cfg-{}.json", std::process::id()));
        std::fs::write(&path, r#"{"secret":"abc","listen":"127.0.0.1:9999"}"#).unwrap();

        // Clean up before asserting so a failing load or assertion does not leak the file.
        let result = load_from_file(&path);
        std::fs::remove_file(&path).ok();

        let cfg = result.unwrap();
        assert_eq!(cfg.secret, "abc");
        assert_eq!(cfg.listen.to_string(), "127.0.0.1:9999");
    }
}