// ombrac_server/config.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::builder::Styles;
5use clap::builder::styling::{AnsiColor, Style};
6use clap::{Parser, ValueEnum};
7
8use serde::{Deserialize, Serialize};
9
10#[cfg(feature = "transport-quic")]
11use ombrac_transport::quic::Congestion;
12
13// CLI Args
14#[derive(Parser, Debug)]
15#[command(version, about, long_about = None, styles = styles())]
16pub struct Args {
17    /// Path to the JSON configuration file
18    #[clap(long, short = 'c', value_name = "FILE")]
19    pub config: Option<PathBuf>,
20
21    /// Protocol Secret
22    #[clap(
23        long,
24        short = 'k',
25        help_heading = "Required",
26        value_name = "STR",
27        required_unless_present = "config"
28    )]
29    pub secret: Option<String>,
30
31    /// The address to bind for transport
32    #[clap(
33        long,
34        short = 'l',
35        help_heading = "Required",
36        value_name = "ADDR",
37        required_unless_present = "config"
38    )]
39    pub listen: Option<SocketAddr>,
40
41    #[cfg(feature = "transport-quic")]
42    #[clap(flatten)]
43    pub transport: TransportConfig,
44
45    #[cfg(feature = "tracing")]
46    #[clap(flatten)]
47    pub logging: LoggingConfig,
48}
49
50// JSON Config File
51#[derive(Deserialize, Serialize, Debug, Default)]
52pub struct ConfigFile {
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub secret: Option<String>,
55
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub listen: Option<SocketAddr>,
58
59    #[cfg(feature = "transport-quic")]
60    pub transport: TransportConfig,
61
62    #[cfg(feature = "tracing")]
63    pub logging: LoggingConfig,
64}
65
// QUIC transport settings.
//
// One type serves three roles:
// - it is flattened into `Args`, where it appears as the "Transport" section of `--help`;
// - it is nested under `transport` in the JSON config file;
// - it is carried through to `ServiceConfig`.
// Every field is an `Option` and is left out of serialized output when `None`.
// These notes use `//` because clap turns `///` comments into help text.
//
// NOTE(review): `alpn_protocols` is `Vec<Vec<u8>>`, which serde encodes as nested arrays
// of numbers. In JSON each protocol therefore presumably has to be written as a byte
// array (e.g. `[104, 51]` for "h3") rather than a string — confirm.
// NOTE(review): `zero_rtt` is an `Option<bool>` with a bare `action`, so clap infers the
// action from the field type. That presumably makes it a value-taking option
// (`--zero-rtt true`) rather than a plain switch — confirm.
#[cfg(feature = "transport-quic")]
#[derive(Clone, Debug, Deserialize, Serialize, Parser)]
pub struct TransportConfig {
    /// Set the TLS mode for the connection
    /// tls: Standard TLS with server certificate verification
    /// m-tls: Mutual TLS with client and server certificate verification
    /// insecure: Generates a self-signed certificate for testing (SANs set to 'localhost')
    #[arg(long, value_enum, help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_mode: Option<TlsMode>,

    /// Path to the Certificate Authority (CA) certificate file for mTLS
    #[arg(long, value_name = "FILE", help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca_cert: Option<PathBuf>,

    /// Path to the TLS certificate file
    #[arg(long, value_name = "FILE", help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_cert: Option<PathBuf>,

    /// Path to the TLS private key file
    #[arg(long, value_name = "FILE", help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_key: Option<PathBuf>,

    /// Enable 0-RTT for faster connection establishment (may reduce security)
    #[arg(long, action, help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zero_rtt: Option<bool>,

    /// Application-Layer protocol negotiation (ALPN) protocols [default: h3]
    #[arg(long, value_name = "PROTOCOLS", value_delimiter = ',', help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    /// Congestion control algorithm to use (e.g. bbr, cubic, newreno) [default: bbr]
    #[arg(long, value_name = "ALGORITHM", help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub congestion: Option<Congestion>,

    /// Initial congestion window size in bytes
    #[arg(long, value_name = "NUM", help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwnd_init: Option<u64>,

    /// Maximum idle time (in milliseconds) before closing the connection [default: 30000]
    #[arg(long, value_name = "TIME", help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idle_timeout: Option<u64>,

    /// Keep-alive interval (in milliseconds) [default: 8000]
    #[arg(long, value_name = "TIME", help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<u64>,

    /// Maximum number of bidirectional streams that can be open simultaneously [default: 1000]
    #[arg(long, value_name = "NUM", help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_streams: Option<u64>,
}
132
// Logging options.
//
// Flattened into `Args` under the "Logging" help heading and stored under `logging`
// in the JSON config file. `None` is omitted from serialized output.
// (`//` rather than `///` here: clap turns doc comments into help text.)
#[cfg(feature = "tracing")]
#[derive(Clone, Debug, Deserialize, Serialize, Parser)]
pub struct LoggingConfig {
    /// Logging level (e.g., INFO, WARN, ERROR) [default: INFO]
    #[arg(long, value_name = "LEVEL", help_heading = "Logging")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
141
// TLS operating mode for the QUIC transport.
//
// The same spellings are accepted on the CLI (via `ValueEnum`) and in JSON (via
// `kebab-case`): `tls`, `m-tls`, `insecure`. Variant order is the order clap lists the
// possible values in. Variants carry `//` rather than `///` comments so that
// `--help` output is unaffected.
#[cfg(feature = "transport-quic")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, ValueEnum, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Standard TLS with server certificate verification; the default mode.
    #[default]
    Tls,
    // Mutual TLS: both client and server certificates are verified.
    MTls,
    // Self-signed certificate, intended for testing only.
    Insecure,
}
151
/// Fully resolved server configuration, as produced by `load()`.
///
/// Unlike `ConfigFile`, `secret` and `listen` are mandatory here. Every layer
/// (defaults, JSON file, CLI) has already been merged.
#[derive(Clone, Debug)]
pub struct ServiceConfig {
    /// Protocol secret.
    pub secret: String,
    /// Address the transport binds to.
    pub listen: SocketAddr,
    /// QUIC transport settings.
    #[cfg(feature = "transport-quic")]
    pub transport: TransportConfig,
    /// Logging settings.
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}
161
#[cfg(feature = "transport-quic")]
impl Default for TransportConfig {
    /// Built-in transport defaults. The numeric values match the `[default: ...]`
    /// notes in the CLI help. Certificate paths and the initial congestion window
    /// are left unset.
    fn default() -> Self {
        let idle_timeout_ms: u64 = 30_000;
        let keep_alive_ms: u64 = 8_000;
        let stream_limit: u64 = 1_000;

        Self {
            tls_mode: Some(TlsMode::Tls),
            zero_rtt: Some(false),
            alpn_protocols: Some(vec![b"h3".to_vec()]),
            congestion: Some(Congestion::Bbr),
            idle_timeout: Some(idle_timeout_ms),
            keep_alive: Some(keep_alive_ms),
            max_streams: Some(stream_limit),
            // No built-in value for these.
            ca_cert: None,
            tls_cert: None,
            tls_key: None,
            cwnd_init: None,
        }
    }
}
180
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Default logging configuration: level `INFO`, matching the CLI help text.
    fn default() -> Self {
        let level = String::from("INFO");
        Self {
            log_level: Some(level),
        }
    }
}
189
/// Resolves the server configuration from defaults, an optional JSON file and the CLI.
///
/// Sources are layered with figment's `merge`, so each later source overrides the
/// earlier ones:
///
/// 1. `ConfigFile::default()` — the built-in defaults;
/// 2. the JSON file named by `--config`, when one is given;
/// 3. options passed on the command line.
///
/// CLI options that were not supplied are `None` and are skipped when serialized
/// (`skip_serializing_if = "Option::is_none"`). They therefore do not mask values
/// coming from the file or the defaults.
///
/// # Errors
///
/// Returns an error when:
/// - `--config` names a path that does not exist;
/// - the merged data cannot be extracted into [`ConfigFile`];
/// - no source supplies `secret` or `listen`.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<figment::Error>> {
    use figment::Figment;
    use figment::providers::{Format, Json, Serialized};

    let args = Args::parse();

    let mut figment = Figment::new().merge(Serialized::defaults(ConfigFile::default()));

    if let Some(config_path) = &args.config {
        // Fail fast with a clear message when an explicitly requested file is absent.
        if !config_path.exists() {
            return Err(Box::new(figment::Error::from(format!(
                "Configuration file not found: {}",
                config_path.display()
            ))));
        }

        figment = figment.merge(Json::file(config_path));
    }

    // Merged last so that explicitly passed CLI values win over file and defaults.
    let cli_overrides = ConfigFile {
        secret: args.secret,
        listen: args.listen,
        #[cfg(feature = "transport-quic")]
        transport: args.transport,
        #[cfg(feature = "tracing")]
        logging: args.logging,
    };

    figment = figment.merge(Serialized::defaults(cli_overrides));

    let config: ConfigFile = figment.extract()?;

    // `Args` only requires these when `--config` is absent, so a config file that
    // omits them is caught here.
    let secret = config
        .secret
        .ok_or_else(|| figment::Error::from("missing field `secret`"))?;
    let listen = config
        .listen
        .ok_or_else(|| figment::Error::from("missing field `listen`"))?;

    Ok(ServiceConfig {
        secret,
        listen,
        #[cfg(feature = "transport-quic")]
        transport: config.transport,
        #[cfg(feature = "tracing")]
        logging: config.logging,
    })
}
240
241fn styles() -> Styles {
242    Styles::styled()
243        .header(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
244        .usage(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
245        .literal(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
246        .placeholder(Style::new().fg_color(Some(AnsiColor::Cyan.into())))
247        .valid(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
248        .invalid(Style::new().bold().fg_color(Some(AnsiColor::Yellow.into())))
249        .error(Style::new().bold().fg_color(Some(AnsiColor::Red.into())))
250}