ombrac_client/
config.rs

1use std::io;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4
5use clap::builder::Styles;
6use clap::builder::styling::{AnsiColor, Style};
7use clap::{Parser, ValueEnum};
8use figment::Figment;
9use figment::providers::{Format, Json, Serialized};
10use serde::{Deserialize, Serialize};
11
12#[cfg(feature = "transport-quic")]
13use ombrac_transport::quic::Congestion;
14
15// CLI Args
16#[derive(Parser, Debug)]
17#[command(version, about, long_about = None, styles = styles())]
18pub struct Args {
19    /// Path to the JSON configuration file
20    #[clap(long, short = 'c', value_name = "FILE")]
21    pub config: Option<PathBuf>,
22
23    /// Protocol Secret
24    #[clap(
25        long,
26        short = 'k',
27        help_heading = "Required",
28        value_name = "STR",
29        required_unless_present = "config"
30    )]
31    pub secret: Option<String>,
32
33    /// Address of the server to connect to
34    #[clap(
35        long,
36        short = 's',
37        help_heading = "Required",
38        value_name = "ADDR",
39        required_unless_present = "config"
40    )]
41    pub server: Option<String>,
42
43    #[clap(flatten)]
44    pub endpoint: EndpointConfig,
45
46    #[cfg(feature = "transport-quic")]
47    #[clap(flatten)]
48    pub transport: TransportConfig,
49
50    #[cfg(feature = "tracing")]
51    #[clap(flatten)]
52    pub logging: LoggingConfig,
53}
54
55// JSON Config File
56#[derive(Deserialize, Serialize, Debug, Default)]
57pub struct ConfigFile {
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub secret: Option<String>,
60
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub server: Option<String>,
63
64    pub endpoint: EndpointConfig,
65
66    #[cfg(feature = "transport-quic")]
67    pub transport: TransportConfig,
68
69    #[cfg(feature = "tracing")]
70    pub logging: LoggingConfig,
71}
72
73#[derive(Deserialize, Serialize, Debug, Parser, Clone, Default)]
74pub struct EndpointConfig {
75    /// The address to bind for the HTTP/HTTPS server
76    #[cfg(feature = "endpoint-http")]
77    #[clap(long, value_name = "ADDR", help_heading = "Endpoint")]
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub http: Option<SocketAddr>,
80
81    /// The address to bind for the SOCKS server
82    #[cfg(feature = "endpoint-socks")]
83    #[clap(long, value_name = "ADDR", help_heading = "Endpoint")]
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub socks: Option<SocketAddr>,
86
87    #[cfg(feature = "endpoint-tun")]
88    #[clap(flatten)]
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub tun: Option<TunConfig>,
91}
92
// QUIC transport options. clap renders each field's `///` comment as help
// text: the first paragraph is the short `-h` help, and all paragraphs
// together form the long `--help`. Consecutive `///` lines are joined into a
// single paragraph, so multi-part descriptions are separated by blank `///`
// lines. The `[default: ...]` values mirror the `Default` impl for this type.
#[cfg(feature = "transport-quic")]
#[derive(Debug, Clone, Serialize, Deserialize, Parser)]
pub struct TransportConfig {
    /// The address to bind for transport
    #[arg(long, help_heading = "Transport", value_name = "ADDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bind: Option<SocketAddr>,

    /// Name of the server to connect to (derived from `server` if not provided)
    #[arg(long, help_heading = "Transport", value_name = "STR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server_name: Option<String>,

    /// Set the TLS mode for the connection [default: tls]
    ///
    /// tls: Standard TLS with server certificate verification
    ///
    /// m-tls: Mutual TLS with client and server certificate verification
    ///
    /// insecure: Skip server certificate verification (for testing only)
    #[arg(long, value_enum, help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_mode: Option<TlsMode>,

    /// Path to the Certificate Authority (CA) certificate file
    ///
    /// In `tls` mode, the system's default root certificates are used when this is not provided
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca_cert: Option<PathBuf>,

    /// Path to the client's TLS certificate for mTLS
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_cert: Option<PathBuf>,

    /// Path to the client's TLS private key for mTLS
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_key: Option<PathBuf>,

    /// Enable 0-RTT for faster connection establishment (may reduce security)
    #[arg(long, help_heading = "Transport", action)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zero_rtt: Option<bool>,

    /// Comma-separated Application-Layer Protocol Negotiation (ALPN) protocols [default: h3]
    #[arg(
        long,
        help_heading = "Transport",
        value_name = "PROTOCOLS",
        value_delimiter = ','
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    /// Congestion control algorithm to use (e.g. bbr, cubic, newreno) [default: bbr]
    #[arg(long, help_heading = "Transport", value_name = "ALGORITHM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub congestion: Option<Congestion>,

    /// Initial congestion window size in bytes
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwnd_init: Option<u64>,

    /// Maximum idle time (in milliseconds) before closing the connection [default: 30000]
    ///
    /// The 30 second default is recommended by RFC 9308
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idle_timeout: Option<u64>,

    /// Keep-alive interval (in milliseconds) [default: 8000]
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<u64>,

    /// Maximum number of bidirectional streams that can be open simultaneously [default: 100]
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_streams: Option<u64>,
}
171
// Logging options; compiled only with the `tracing` feature. The `///` text
// on the field is the `--log-level` help output.
#[cfg(feature = "tracing")]
#[derive(Debug, Clone, Serialize, Deserialize, Parser)]
pub struct LoggingConfig {
    // Kept as a free-form string; this module performs no validation on it.
    /// Logging level (e.g., INFO, WARN, ERROR) [default: INFO]
    #[arg(long, help_heading = "Logging", value_name = "LEVEL")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
180
// TUN endpoint options, flattened into `EndpointConfig` as an optional group.
// Each field's `///` comment is user-facing `--help` text.
#[cfg(feature = "endpoint-tun")]
#[derive(Debug, Clone, Serialize, Deserialize, Parser)]
pub struct TunConfig {
    /// Use a pre-existing TUN device by providing its file descriptor.  
    /// `tun_ipv4`, `tun_ipv6`, and `tun_mtu` will be ignored.
    #[arg(long, help_heading = "Endpoint", value_name = "FD")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_fd: Option<i32>,

    /// The IPv4 address and subnet for the TUN device, in CIDR notation (e.g., 198.19.0.1/24).
    #[arg(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv4: Option<String>,

    /// The IPv6 address and subnet for the TUN device, in CIDR notation (e.g., fd00::1/64).
    #[arg(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv6: Option<String>,

    /// The Maximum Transmission Unit (MTU) for the TUN device. [default: 1500]
    #[arg(long, help_heading = "Endpoint", value_name = "U16")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_mtu: Option<u16>,

    // CIDR strings are stored verbatim; parsing happens outside this module.
    /// The IPv4 address pool for the built-in fake DNS server, in CIDR notation. [default: 198.18.0.0/16]
    #[arg(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fake_dns: Option<String>,
}
210
/// TLS verification mode for the QUIC transport.
///
/// Written in kebab-case both on the command line and in JSON
/// (`tls`, `m-tls`, `insecure`); `Tls` is the default.
#[cfg(feature = "transport-quic")]
#[derive(ValueEnum, Serialize, Deserialize, Clone, Copy, Debug, Default, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Standard TLS with server certificate verification.
    #[default]
    Tls,
    // Mutual TLS: client and server certificates are both verified.
    MTls,
    // Skip server certificate verification (testing only).
    Insecure,
}
221
222#[derive(Debug, Clone)]
223pub struct ServiceConfig {
224    pub secret: String,
225    pub server: String,
226    pub endpoint: EndpointConfig,
227    #[cfg(feature = "transport-quic")]
228    pub transport: TransportConfig,
229    #[cfg(feature = "tracing")]
230    pub logging: LoggingConfig,
231}
232
/// Built-in transport defaults. These are the values quoted as `[default: ...]`
/// in the transport `--help` text and form the lowest-precedence layer in `load`.
#[cfg(feature = "transport-quic")]
impl Default for TransportConfig {
    fn default() -> Self {
        // 30 s idle timeout, the value RFC 9308 recommends.
        const IDLE_TIMEOUT_MS: u64 = 30_000;
        const KEEP_ALIVE_MS: u64 = 8_000;
        const MAX_STREAMS: u64 = 100;

        Self {
            // Local bind address and SNI name are left unset.
            bind: None,
            server_name: None,
            // Standard TLS, no custom CA and no client identity.
            tls_mode: Some(TlsMode::Tls),
            ca_cert: None,
            client_cert: None,
            client_key: None,
            // Connection behaviour.
            zero_rtt: Some(false),
            alpn_protocols: Some(vec![b"h3".to_vec()]),
            congestion: Some(Congestion::Bbr),
            cwnd_init: None,
            idle_timeout: Some(IDLE_TIMEOUT_MS),
            keep_alive: Some(KEEP_ALIVE_MS),
            max_streams: Some(MAX_STREAMS),
        }
    }
}
253
/// Built-in TUN defaults: no pre-existing device, no addresses, a 1500-byte
/// MTU and `198.18.0.0/16` as the fake-DNS address pool.
// Gated only on `endpoint-tun`, exactly like the `TunConfig` definition. The
// previous extra `transport-quic` gate made `TunConfig: Default` vanish in
// builds that enable the TUN endpoint without the QUIC transport.
#[cfg(feature = "endpoint-tun")]
impl Default for TunConfig {
    fn default() -> Self {
        Self {
            tun_fd: None,
            tun_ipv4: None,
            tun_ipv6: None,
            tun_mtu: Some(1500),
            fake_dns: Some("198.18.0.0/16".to_string()),
        }
    }
}
267
/// Logs at `INFO` unless the config file or `--log-level` says otherwise.
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    fn default() -> Self {
        let log_level = Some(String::from("INFO"));
        Self { log_level }
    }
}
276
277pub fn load() -> Result<ServiceConfig, Box<figment::Error>> {
278    let args = Args::parse();
279
280    let mut figment = Figment::new().merge(Serialized::defaults(ConfigFile::default()));
281
282    if let Some(config_path) = &args.config {
283        if !config_path.exists() {
284            let err = io::Error::new(
285                io::ErrorKind::NotFound,
286                format!("Configuration file not found: {}", config_path.display()),
287            );
288            return Err(Box::new(figment::Error::from(err.to_string())));
289        }
290
291        figment = figment.merge(Json::file(config_path));
292    }
293
294    let cli_overrides = ConfigFile {
295        secret: args.secret,
296        server: args.server,
297        endpoint: args.endpoint,
298        #[cfg(feature = "transport-quic")]
299        transport: args.transport,
300        #[cfg(feature = "tracing")]
301        logging: args.logging,
302    };
303
304    figment = figment.merge(Serialized::defaults(cli_overrides));
305
306    let config: ConfigFile = figment.extract()?;
307
308    let secret = config
309        .secret
310        .ok_or_else(|| figment::Error::from("missing field `secret`"))?;
311    let server = config
312        .server
313        .ok_or_else(|| figment::Error::from("missing field `server`"))?;
314
315    Ok(ServiceConfig {
316        secret,
317        server,
318        endpoint: config.endpoint,
319        #[cfg(feature = "transport-quic")]
320        transport: config.transport,
321        #[cfg(feature = "tracing")]
322        logging: config.logging,
323    })
324}
325
326fn styles() -> Styles {
327    Styles::styled()
328        .header(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
329        .usage(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
330        .literal(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
331        .placeholder(Style::new().fg_color(Some(AnsiColor::Cyan.into())))
332        .valid(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
333        .invalid(Style::new().bold().fg_color(Some(AnsiColor::Yellow.into())))
334        .error(Style::new().bold().fg_color(Some(AnsiColor::Red.into())))
335}