ombrac_server/config.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::builder::Styles;
5use clap::builder::styling::{AnsiColor, Style};
6use clap::{Parser, ValueEnum};
7
8use serde::{Deserialize, Serialize};
9
10#[cfg(feature = "transport-quic")]
11use ombrac_transport::quic::Congestion;
12
13// CLI Args
14#[derive(Parser, Debug)]
15#[command(version, about, long_about = None, styles = styles())]
16pub struct Args {
17    /// Path to the JSON configuration file
18    #[clap(long, short = 'c', value_name = "FILE")]
19    pub config: Option<PathBuf>,
20
21    /// Protocol Secret
22    #[clap(
23        long,
24        short = 'k',
25        help_heading = "Required",
26        value_name = "STR",
27        required_unless_present = "config"
28    )]
29    pub secret: Option<String>,
30
31    /// The address to bind for transport
32    #[clap(
33        long,
34        short = 'l',
35        help_heading = "Required",
36        value_name = "ADDR",
37        required_unless_present = "config"
38    )]
39    pub listen: Option<SocketAddr>,
40
41    #[cfg(feature = "transport-quic")]
42    #[clap(flatten)]
43    pub transport: TransportConfig,
44
45    #[cfg(feature = "tracing")]
46    #[clap(flatten)]
47    pub logging: LoggingConfig,
48}
49
50// JSON Config File
51#[derive(Deserialize, Serialize, Debug, Default)]
52pub struct ConfigFile {
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub secret: Option<String>,
55
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub listen: Option<SocketAddr>,
58
59    #[cfg(feature = "transport-quic")]
60    pub transport: TransportConfig,
61
62    #[cfg(feature = "tracing")]
63    pub logging: LoggingConfig,
64}
65
// QUIC transport settings, shared by the CLI (flattened into `Args`) and the
// JSON file (`ConfigFile::transport`). Every field is optional and skipped
// when `None` on serialization, so an unset value never masks a value coming
// from another configuration layer.
//
// The `///` comments on the fields are clap help text, not only rustdoc:
// editing them changes `--help` output.
#[cfg(feature = "transport-quic")]
#[derive(Clone, Debug, Parser, Serialize, Deserialize)]
pub struct TransportConfig {
    /// Set the TLS mode for the connection
    /// tls: Standard TLS with server certificate verification
    /// m-tls: Mutual TLS with client and server certificate verification
    /// insecure: Generates a self-signed certificate for testing (SANs set to 'localhost')
    #[arg(long, value_enum, help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_mode: Option<TlsMode>,

    /// Path to the Certificate Authority (CA) certificate file for mTLS
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca_cert: Option<PathBuf>,

    /// Path to the TLS certificate file
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_cert: Option<PathBuf>,

    /// Path to the TLS private key file
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_key: Option<PathBuf>,

    /// Enable 0-RTT for faster connection establishment (may reduce security)
    #[arg(long, help_heading = "Transport", action)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zero_rtt: Option<bool>,

    // NOTE(review): serde encodes `Vec<u8>` as a sequence of numbers, so in the
    // JSON file this presumably has to be written as [[104, 51]] rather than
    // ["h3"] — confirm that is intended.
    /// Application-Layer protocol negotiation (ALPN) protocols [default: h3]
    #[arg(
        long,
        help_heading = "Transport",
        value_name = "PROTOCOLS",
        value_delimiter = ','
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    /// Congestion control algorithm to use (e.g. bbr, cubic, newreno) [default: bbr]
    #[arg(long, help_heading = "Transport", value_name = "ALGORITHM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub congestion: Option<Congestion>,

    /// Initial congestion window size in bytes
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwnd_init: Option<u64>,

    /// Maximum idle time (in milliseconds) before closing the connection [default: 30000]
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idle_timeout: Option<u64>,

    /// Keep-alive interval (in milliseconds) [default: 8000]
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<u64>,

    /// Maximum number of bidirectional streams that can be open simultaneously [default: 1000]
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_streams: Option<u64>,
}

// Logging settings, shared by the CLI (flattened into `Args`) and the JSON
// file (`ConfigFile::logging`). The field's `///` comment is its `--help` text.
#[cfg(feature = "tracing")]
#[derive(Clone, Debug, Parser, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Logging level (e.g., INFO, WARN, ERROR) [default: INFO]
    #[arg(long, help_heading = "Logging", value_name = "LEVEL")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}

// TLS mode of the QUIC transport. The names accepted on the command line
// (clap `ValueEnum`) and in JSON (serde kebab-case) are both "tls", "m-tls"
// and "insecure". Only `//` comments are used on the variants, because `///`
// on a `ValueEnum` variant would become possible-value help text.
#[cfg(feature = "transport-quic")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, ValueEnum, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Standard TLS; the default mode.
    #[default]
    Tls,
    // Mutual TLS.
    MTls,
    // Self-signed certificate, intended for testing.
    Insecure,
}

/// Fully resolved server configuration, as returned by `load()`.
///
/// Unlike [`ConfigFile`], `secret` and `listen` are mandatory here.
#[derive(Clone, Debug)]
pub struct ServiceConfig {
    /// Protocol secret.
    pub secret: String,
    /// Address the transport binds to.
    pub listen: SocketAddr,
    /// QUIC transport settings.
    #[cfg(feature = "transport-quic")]
    pub transport: TransportConfig,
    /// Logging settings.
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}

// Built-in QUIC transport defaults. The ALPN, congestion, idle-timeout,
// keep-alive and stream-limit values mirror the `[default: ...]` notes in the
// field help text. `cwnd_init` and the certificate paths are left unset.
#[cfg(feature = "transport-quic")]
impl Default for TransportConfig {
    fn default() -> Self {
        let alpn_h3 = b"h3".to_vec();

        TransportConfig {
            tls_mode: Some(TlsMode::Tls),
            ca_cert: None,
            tls_cert: None,
            tls_key: None,
            zero_rtt: Some(false),
            alpn_protocols: Some(vec![alpn_h3]),
            congestion: Some(Congestion::Bbr),
            cwnd_init: None,
            idle_timeout: Some(30_000),
            keep_alive: Some(8_000),
            max_streams: Some(1_000),
        }
    }
}

// Default logging level, matching the `[default: INFO]` note in the help text.
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    fn default() -> Self {
        let level = String::from("INFO");
        LoggingConfig {
            log_level: Some(level),
        }
    }
}

/// Builds the effective [`ServiceConfig`] from the command line and an
/// optional JSON configuration file.
///
/// Layers are merged in increasing priority:
/// 1. built-in defaults (`ConfigFile::default()`),
/// 2. the JSON file passed with `--config`, if any,
/// 3. options given on the command line. Unset CLI options are `None` and are
///    skipped during serialization, so they never mask values from the file.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `--config` points to a path that does not exist or is not a regular file.
/// - The merged configuration cannot be extracted into a [`ConfigFile`].
/// - `secret` or `listen` is not provided by any layer.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<figment::Error>> {
    use figment::Figment;
    use figment::providers::{Format, Json, Serialized};

    let args = Args::parse();

    let mut figment = Figment::new().merge(Serialized::defaults(ConfigFile::default()));

    if let Some(config_path) = &args.config {
        // figment's `Json::file` does not fail on a missing file (it yields an
        // empty source), so the path is validated up front. Use `is_file`
        // rather than `exists`: a directory passes `exists` but cannot be
        // loaded, and would otherwise be ignored silently.
        if !config_path.is_file() {
            let message = if config_path.exists() {
                format!("Configuration path is not a file: {}", config_path.display())
            } else {
                format!("Configuration file not found: {}", config_path.display())
            };
            return Err(Box::new(figment::Error::from(message)));
        }

        figment = figment.merge(Json::file(config_path));
    }

    // CLI values are merged last so they take precedence over the file.
    let cli_overrides = ConfigFile {
        secret: args.secret,
        listen: args.listen,
        #[cfg(feature = "transport-quic")]
        transport: args.transport,
        #[cfg(feature = "tracing")]
        logging: args.logging,
    };

    figment = figment.merge(Serialized::defaults(cli_overrides));

    let config: ConfigFile = figment.extract()?;

    // These are optional in `ConfigFile` but mandatory for the service.
    let secret = config
        .secret
        .ok_or_else(|| figment::Error::from("missing field `secret`"))?;
    let listen = config
        .listen
        .ok_or_else(|| figment::Error::from("missing field `listen`"))?;

    Ok(ServiceConfig {
        secret,
        listen,
        #[cfg(feature = "transport-quic")]
        transport: config.transport,
        #[cfg(feature = "tracing")]
        logging: config.logging,
    })
}

242fn styles() -> Styles {
243    Styles::styled()
244        .header(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
245        .usage(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
246        .literal(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
247        .placeholder(Style::new().fg_color(Some(AnsiColor::Cyan.into())))
248        .valid(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
249        .invalid(Style::new().bold().fg_color(Some(AnsiColor::Yellow.into())))
250        .error(Style::new().bold().fg_color(Some(AnsiColor::Red.into())))
251}