ombrac_client/
config.rs

1use std::io;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4
5use clap::builder::Styles;
6use clap::builder::styling::{AnsiColor, Style};
7use clap::{Parser, ValueEnum};
8use figment::Figment;
9use figment::providers::{Format, Json, Serialized};
10use serde::{Deserialize, Serialize};
11
12#[cfg(feature = "transport-quic")]
13use ombrac_transport::quic::Congestion;
14
15// CLI Args
// Command-line arguments of the client, parsed through clap's derive API.
//
// `load` parses this struct, then repackages every field except `config` into a
// `ConfigFile` that is merged on top of the defaults and the optional JSON file.
//
// NOTE: the explanatory notes in this struct are plain `//` comments on purpose;
// clap turns `///` doc comments into `--help` text, so those are left untouched.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None, styles = styles())]
pub struct Args {
    /// Path to the JSON configuration file
    #[clap(long, short = 'c', value_name = "FILE")]
    pub config: Option<PathBuf>,

    // Only mandatory on the command line when `--config` is absent
    // (`required_unless_present`); `load` still rejects a final config without it.
    /// Protocol Secret
    #[clap(
        long,
        short = 'k',
        help_heading = "Required",
        value_name = "STR",
        required_unless_present = "config"
    )]
    pub secret: Option<String>,

    // Same requirement rule as `secret`: may come from the config file instead.
    /// Address of the server to connect to
    #[clap(
        long,
        short = 's',
        help_heading = "Required",
        value_name = "ADDR",
        required_unless_present = "config"
    )]
    pub server: Option<String>,

    /// Extended parameter of the protocol, used for handshake related information
    #[clap(long, help_heading = "Protocol", value_name = "STR")]
    pub handshake_option: Option<String>,

    // Local endpoint options (HTTP / SOCKS / TUN, depending on enabled features),
    // flattened so they appear as top-level flags.
    #[clap(flatten)]
    pub endpoint: EndpointConfig,

    // QUIC transport options; only compiled in with the `transport-quic` feature.
    #[cfg(feature = "transport-quic")]
    #[clap(flatten)]
    pub transport: TransportConfig,

    // Logging options; only compiled in with the `tracing` feature.
    #[cfg(feature = "tracing")]
    #[clap(flatten)]
    pub logging: LoggingConfig,
}
58
59// JSON Config File
/// Serializable shape of the configuration, shared by every layer that `load` merges.
///
/// `load` uses this type three ways: `ConfigFile::default()` as the base layer,
/// as the container for command-line values, and as the extraction target of the
/// merged result. The JSON file passed via `--config` is merged between the first
/// two and is therefore expected to follow the same field layout.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct ConfigFile {
    /// Protocol secret. Left out of the serialized data when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub secret: Option<String>,

    /// Address of the server to connect to. Left out of the serialized data when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server: Option<String>,

    /// Optional extended handshake parameter. Left out of the serialized data when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub handshake_option: Option<String>,

    /// Local endpoint settings (which of HTTP / SOCKS / TUN exist depends on features).
    pub endpoint: EndpointConfig,

    /// QUIC transport settings; present only with the `transport-quic` feature.
    #[cfg(feature = "transport-quic")]
    pub transport: TransportConfig,

    /// Logging settings; present only with the `tracing` feature.
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}
79
// Local endpoints the client can expose. Each field exists only when its Cargo
// feature is enabled, and every field is optional: `None` values are skipped
// during serialization, and the derived `Default` leaves all of them `None`.
//
// NOTE: plain `//` comments are used for these notes because clap turns `///`
// doc comments into `--help` text; the existing doc comments are unchanged.
#[derive(Deserialize, Serialize, Debug, Parser, Clone, Default)]
pub struct EndpointConfig {
    /// The address to bind for the HTTP/HTTPS server
    #[cfg(feature = "endpoint-http")]
    #[clap(long, value_name = "ADDR", help_heading = "Endpoint")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub http: Option<SocketAddr>,

    /// The address to bind for the SOCKS server
    #[cfg(feature = "endpoint-socks")]
    #[clap(long, value_name = "ADDR", help_heading = "Endpoint")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub socks: Option<SocketAddr>,

    // TUN device settings, flattened into the same flag namespace (`--tun-*`,
    // `--fake-dns`); see `TunConfig` for the individual options.
    #[cfg(feature = "endpoint-tun")]
    #[clap(flatten)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun: Option<TunConfig>,
}
99
// QUIC transport settings, compiled in only with the `transport-quic` feature.
//
// Every field is an `Option` that is skipped during serialization when `None`.
// The `[default: ...]` values quoted in the help text below are not clap
// defaults; they are supplied by `impl Default for TransportConfig`, which
// `load` merges as the lowest-priority layer.
//
// NOTE: plain `//` comments are used for these notes because clap turns `///`
// doc comments into `--help` text; the existing doc comments are unchanged.
#[cfg(feature = "transport-quic")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct TransportConfig {
    /// The address to bind for transport
    #[clap(long, help_heading = "Transport", value_name = "ADDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bind: Option<SocketAddr>,

    /// Name of the server to connect (derived from `server` if not provided)
    #[clap(long, help_heading = "Transport", value_name = "STR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server_name: Option<String>,

    /// Set the TLS mode for the connection
    /// tls: Standard TLS with server certificate verification
    /// m-tls: Mutual TLS with client and server certificate verification
    /// insecure: Skip server certificate verification (for testing only)
    #[clap(long, value_enum, help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_mode: Option<TlsMode>,

    /// Path to the Certificate Authority (CA) certificate file
    /// in 'TLS' mode, if not provided, the system's default root certificates are used
    #[clap(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca_cert: Option<PathBuf>,

    /// Path to the client's TLS certificate for mTLS
    #[clap(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_cert: Option<PathBuf>,

    /// Path to the client's TLS private key for mTLS
    #[clap(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_key: Option<PathBuf>,

    /// Enable 0-RTT for faster connection establishment (may reduce security)
    #[clap(long, help_heading = "Transport", action)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zero_rtt: Option<bool>,

    // Each protocol is stored as raw bytes; several may be given on the command
    // line separated by commas (`value_delimiter = ','`).
    /// Application-Layer protocol negotiation (ALPN) protocols [default: h3]
    #[clap(
        long,
        help_heading = "Transport",
        value_name = "PROTOCOLS",
        value_delimiter = ','
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    /// Congestion control algorithm to use (e.g. bbr, cubic, newreno) [default: bbr]
    #[clap(long, help_heading = "Transport", value_name = "ALGORITHM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub congestion: Option<Congestion>,

    // No default is provided for this one (`Default` leaves it `None`).
    /// Initial congestion window size in bytes
    #[clap(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwnd_init: Option<u64>,

    /// Maximum idle time (in milliseconds) before closing the connection [default: 30000]
    /// 30 second default recommended by RFC 9308
    #[clap(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idle_timeout: Option<u64>,

    /// Keep-alive interval (in milliseconds) [default: 8000]
    #[clap(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<u64>,

    /// Maximum number of bidirectional streams that can be open simultaneously [default: 100]
    #[clap(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_streams: Option<u64>,
}
178
// Logging settings, compiled in only with the `tracing` feature.
//
// The level is kept as a free-form string here; the `INFO` default quoted in the
// help text comes from `impl Default for LoggingConfig`, not from clap.
//
// NOTE: plain `//` comments are used for these notes because clap turns `///`
// doc comments into `--help` text; the existing doc comment is unchanged.
#[cfg(feature = "tracing")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct LoggingConfig {
    /// Logging level (e.g., INFO, WARN, ERROR) [default: INFO]
    #[clap(long, help_heading = "Logging", value_name = "LEVEL")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
187
// TUN endpoint settings, compiled in only with the `endpoint-tun` feature and
// embedded in `EndpointConfig::tun`.
//
// Every field is an `Option` that is skipped during serialization when `None`.
// The CIDR values are stored as unparsed strings in this struct.
//
// NOTE: plain `//` comments are used for these notes because clap turns `///`
// doc comments into `--help` text; the existing doc comments are unchanged.
#[cfg(feature = "endpoint-tun")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct TunConfig {
    /// Use a pre-existing TUN device by providing its file descriptor.  
    /// `tun_ipv4`, `tun_ipv6`, and `tun_mtu` will be ignored.
    #[clap(long, help_heading = "Endpoint", value_name = "FD")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_fd: Option<i32>,

    /// The IPv4 address and subnet for the TUN device, in CIDR notation (e.g., 198.19.0.1/24).
    #[clap(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv4: Option<String>,

    /// The IPv6 address and subnet for the TUN device, in CIDR notation (e.g., fd00::1/64).
    #[clap(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv6: Option<String>,

    /// The Maximum Transmission Unit (MTU) for the TUN device. [default: 1500]
    #[clap(long, help_heading = "Endpoint", value_name = "U16")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_mtu: Option<u16>,

    /// The IPv4 address pool for the built-in fake DNS server, in CIDR notation. [default: 198.18.0.0/16]
    #[clap(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fake_dns: Option<String>,
}
217
// TLS verification mode selected through `TransportConfig::tls_mode`.
//
// Serde (de)serializes the variants in kebab-case (`rename_all`), and the
// type also derives clap's `ValueEnum` so it can be used with `value_enum`.
// `Tls` is the `Default` variant.
//
// NOTE: variant notes are plain `//` comments on purpose; with `ValueEnum`,
// `///` doc comments on variants would be picked up as help text.
#[cfg(feature = "transport-quic")]
#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Standard TLS with server certificate verification.
    #[default]
    Tls,
    // Mutual TLS: client and server certificates are both verified.
    MTls,
    // Server certificate verification is skipped (for testing only).
    Insecure,
}
227
/// Fully resolved configuration returned by `load`.
///
/// Unlike `ConfigFile`, `secret` and `server` are plain `String`s here: `load`
/// fails with a "missing field" error instead of producing a `ServiceConfig`
/// when either one is absent after all layers have been merged.
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// Protocol secret (guaranteed to be present).
    pub secret: String,
    /// Address of the server to connect to (guaranteed to be present).
    pub server: String,
    /// Optional extended handshake parameter.
    pub handshake_option: Option<String>,
    /// Local endpoint settings.
    pub endpoint: EndpointConfig,
    /// QUIC transport settings; present only with the `transport-quic` feature.
    #[cfg(feature = "transport-quic")]
    pub transport: TransportConfig,
    /// Logging settings; present only with the `tracing` feature.
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}
239
#[cfg(feature = "transport-quic")]
impl Default for TransportConfig {
    /// Returns the baseline transport settings.
    ///
    /// Options with a documented default (`tls_mode`, `zero_rtt`, `alpn_protocols`,
    /// `congestion`, `idle_timeout`, `keep_alive`, `max_streams`) are pre-filled;
    /// everything else stays `None`.
    fn default() -> Self {
        // Security-related defaults: verified TLS, no 0-RTT, ALPN "h3".
        let tls_mode = Some(TlsMode::Tls);
        let zero_rtt = Some(false);
        let alpn_protocols = Some(vec![b"h3".to_vec()]);

        // Connection tuning defaults; both timers are expressed in milliseconds.
        let congestion = Some(Congestion::Bbr);
        let idle_timeout = Some(30000);
        let keep_alive = Some(8000);
        let max_streams = Some(100);

        Self {
            tls_mode,
            zero_rtt,
            alpn_protocols,
            congestion,
            idle_timeout,
            keep_alive,
            max_streams,
            // No built-in value for the remaining options.
            bind: None,
            server_name: None,
            ca_cert: None,
            client_cert: None,
            client_key: None,
            cwnd_init: None,
        }
    }
}
260
// Gated on `endpoint-tun` only, matching the gate on `TunConfig` itself, so the
// `Default` impl is available in every build in which the type exists.
#[cfg(feature = "endpoint-tun")]
impl Default for TunConfig {
    /// Returns the baseline TUN settings.
    ///
    /// Only the options with a documented default are pre-filled: an MTU of
    /// 1500 and the `198.18.0.0/16` fake-DNS pool. The file descriptor and both
    /// interface addresses stay `None`.
    fn default() -> Self {
        Self {
            tun_fd: None,
            tun_ipv4: None,
            tun_ipv6: None,
            tun_mtu: Some(1500),
            fake_dns: Some("198.18.0.0/16".to_string()),
        }
    }
}
274
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Returns the baseline logging settings: level `INFO`.
    fn default() -> Self {
        let log_level = Some(String::from("INFO"));
        Self { log_level }
    }
}
283
284pub fn load() -> Result<ServiceConfig, Box<figment::Error>> {
285    let args = Args::parse();
286
287    let mut figment = Figment::new().merge(Serialized::defaults(ConfigFile::default()));
288
289    if let Some(config_path) = &args.config {
290        if !config_path.exists() {
291            let err = io::Error::new(
292                io::ErrorKind::NotFound,
293                format!("Configuration file not found: {}", config_path.display()),
294            );
295            return Err(Box::new(figment::Error::from(err.to_string())));
296        }
297
298        figment = figment.merge(Json::file(config_path));
299    }
300
301    let cli_overrides = ConfigFile {
302        secret: args.secret,
303        server: args.server,
304        handshake_option: args.handshake_option,
305        endpoint: args.endpoint,
306        #[cfg(feature = "transport-quic")]
307        transport: args.transport,
308        #[cfg(feature = "tracing")]
309        logging: args.logging,
310    };
311
312    figment = figment.merge(Serialized::defaults(cli_overrides));
313
314    let config: ConfigFile = figment.extract()?;
315
316    let secret = config
317        .secret
318        .ok_or_else(|| figment::Error::from("missing field `secret`"))?;
319    let server = config
320        .server
321        .ok_or_else(|| figment::Error::from("missing field `server`"))?;
322
323    Ok(ServiceConfig {
324        secret,
325        server,
326        handshake_option: config.handshake_option,
327        endpoint: config.endpoint,
328        #[cfg(feature = "transport-quic")]
329        transport: config.transport,
330        #[cfg(feature = "tracing")]
331        logging: config.logging,
332    })
333}
334
335fn styles() -> Styles {
336    Styles::styled()
337        .header(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
338        .usage(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
339        .literal(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
340        .placeholder(Style::new().fg_color(Some(AnsiColor::Cyan.into())))
341        .valid(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
342        .invalid(Style::new().bold().fg_color(Some(AnsiColor::Yellow.into())))
343        .error(Style::new().bold().fg_color(Some(AnsiColor::Red.into())))
344}