Skip to main content

ombrac_client/config/
mod.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::ValueEnum;
5use serde::{Deserialize, Serialize};
6
7use ombrac_transport::quic::Congestion;
8
9pub mod cli;
10pub mod json;
11
12#[derive(Deserialize, Serialize, Debug, Clone, Default)]
13pub struct EndpointConfig {
14    /// The address to bind for the HTTP/HTTPS server
15    #[cfg(feature = "endpoint-http")]
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub http: Option<SocketAddr>,
18
19    /// The address to bind for the SOCKS server
20    #[cfg(feature = "endpoint-socks")]
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub socks: Option<SocketAddr>,
23
24    #[cfg(feature = "endpoint-tun")]
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub tun: Option<TunConfig>,
27}
28
29#[derive(Deserialize, Serialize, Debug, Clone)]
30#[serde(rename_all = "snake_case")]
31pub struct TransportConfig {
32    /// The address to bind for transport
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub bind: Option<SocketAddr>,
35
36    /// Name of the server to connect (derived from `server` if not provided)
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub server_name: Option<String>,
39
40    /// Set the TLS mode for the connection
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub tls_mode: Option<TlsMode>,
43
44    /// Path to the Certificate Authority (CA) certificate file
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub ca_cert: Option<PathBuf>,
47
48    /// Path to the client's TLS certificate for mTLS
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub client_cert: Option<PathBuf>,
51
52    /// Path to the client's TLS private key for mTLS
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub client_key: Option<PathBuf>,
55
56    /// Enable 0-RTT for faster connection establishment
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub zero_rtt: Option<bool>,
59
60    /// Application-Layer protocol negotiation (ALPN) protocols [default: h3]
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub alpn_protocols: Option<Vec<Vec<u8>>>,
63
64    /// Congestion control algorithm to use (e.g. bbr, cubic, newreno) [default: bbr]
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub congestion: Option<Congestion>,
67
68    /// Initial congestion window size in bytes
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub cwnd_init: Option<u64>,
71
72    /// Maximum idle time (in milliseconds) before closing the connection [default: 30000]
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub idle_timeout: Option<u64>,
75
76    /// Keep-alive interval (in milliseconds) [default: 8000]
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub keep_alive: Option<u64>,
79
80    /// Maximum number of bidirectional streams that can be open simultaneously [default: 100]
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub max_streams: Option<u64>,
83}
84
85impl Default for TransportConfig {
86    fn default() -> Self {
87        Self {
88            bind: None,
89            server_name: None,
90            tls_mode: Some(TlsMode::Tls),
91            ca_cert: None,
92            client_cert: None,
93            client_key: None,
94            zero_rtt: Some(false),
95            alpn_protocols: Some(vec!["h3".into()]),
96            congestion: Some(Congestion::Bbr),
97            cwnd_init: None,
98            idle_timeout: Some(30000),
99            keep_alive: Some(8000),
100            max_streams: Some(100),
101        }
102    }
103}
104
/// Logging settings; only compiled with the `tracing` feature.
#[cfg(feature = "tracing")]
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub struct LoggingConfig {
    /// Log level to use, such as INFO, WARN or ERROR [default: INFO]
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
113
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Defaults to the `INFO` log level.
    fn default() -> Self {
        let level = String::from("INFO");
        LoggingConfig {
            log_level: Some(level),
        }
    }
}
122
/// TUN device endpoint settings; only compiled with the `endpoint-tun` feature.
///
/// `None` fields are omitted when serialized. The values shown as
/// `[default: …]` are the ones set by `TunConfig::default()`.
#[cfg(feature = "endpoint-tun")]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct TunConfig {
    /// File descriptor of a pre-existing TUN device to use.
    /// When this is set, `tun_ipv4`, `tun_ipv6` and `tun_mtu` are ignored.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_fd: Option<i32>,

    /// IPv4 address and subnet of the TUN device in CIDR notation, e.g. 198.19.0.1/24.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv4: Option<String>,

    /// IPv6 address and subnet of the TUN device in CIDR notation, e.g. fd00::1/64.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv6: Option<String>,

    /// Maximum Transmission Unit (MTU) of the TUN device [default: 1500]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_mtu: Option<u16>,

    /// IPv4 pool, in CIDR notation, for the built-in fake DNS server [default: 198.18.0.0/16]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fake_dns: Option<String>,

    /// When `true`, UDP traffic to port 443 is disabled [default: false]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub disable_udp_443: Option<bool>,
}
151
#[cfg(feature = "endpoint-tun")]
impl Default for TunConfig {
    /// Built-in TUN defaults.
    ///
    /// No file descriptor or interface addresses are set. The MTU is 1500, the
    /// fake-DNS pool is `198.18.0.0/16`, and UDP to port 443 is not disabled.
    fn default() -> Self {
        let mtu: u16 = 1500;
        let fake_dns_pool = "198.18.0.0/16".to_owned();

        TunConfig {
            tun_mtu: Some(mtu),
            fake_dns: Some(fake_dns_pool),
            disable_udp_443: Some(false),
            tun_fd: None,
            tun_ipv4: None,
            tun_ipv6: None,
        }
    }
}
165
166#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
167#[serde(rename_all = "kebab-case")]
168pub enum TlsMode {
169    #[default]
170    Tls,
171    MTls,
172    Insecure,
173}
174
175/// Final service configuration with all defaults applied
176#[derive(Debug, Clone)]
177pub struct ServiceConfig {
178    pub secret: String,
179    pub server: String,
180    pub auth_option: Option<String>,
181    pub endpoint: EndpointConfig,
182    pub transport: TransportConfig,
183    #[cfg(feature = "tracing")]
184    pub logging: LoggingConfig,
185}
186
187/// Configuration builder that merges different configuration sources
188/// and applies defaults in a clear, predictable order
189pub struct ConfigBuilder {
190    secret: Option<String>,
191    server: Option<String>,
192    auth_option: Option<String>,
193    endpoint: EndpointConfig,
194    transport: TransportConfig,
195    #[cfg(feature = "tracing")]
196    logging: LoggingConfig,
197}
198
199impl ConfigBuilder {
200    /// Create a new builder with default values
201    pub fn new() -> Self {
202        Self {
203            secret: None,
204            server: None,
205            auth_option: None,
206            endpoint: EndpointConfig::default(),
207            transport: TransportConfig::default(),
208            #[cfg(feature = "tracing")]
209            logging: LoggingConfig::default(),
210        }
211    }
212
213    /// Merge JSON configuration into the builder
214    pub fn merge_json(mut self, json_config: json::JsonConfig) -> Self {
215        if let Some(secret) = json_config.secret {
216            self.secret = Some(secret);
217        }
218        if let Some(server) = json_config.server {
219            self.server = Some(server);
220        }
221        if let Some(auth_option) = json_config.auth_option {
222            self.auth_option = Some(auth_option);
223        }
224        if let Some(endpoint) = json_config.endpoint {
225            self.endpoint = Self::merge_endpoint(self.endpoint, endpoint);
226        }
227        if let Some(transport) = json_config.transport {
228            self.transport = Self::merge_transport(self.transport, transport);
229        }
230        #[cfg(feature = "tracing")]
231        {
232            if let Some(logging) = json_config.logging {
233                self.logging = Self::merge_logging(self.logging, logging);
234            }
235        }
236        self
237    }
238
239    /// Merge CLI configuration into the builder (CLI overrides JSON)
240    pub fn merge_cli(mut self, cli_config: cli::CliConfig) -> Self {
241        if let Some(secret) = cli_config.secret {
242            self.secret = Some(secret);
243        }
244        if let Some(server) = cli_config.server {
245            self.server = Some(server);
246        }
247        if let Some(auth_option) = cli_config.auth_option {
248            self.auth_option = Some(auth_option);
249        }
250        self.endpoint = Self::merge_endpoint(self.endpoint, cli_config.endpoint);
251        self.transport = Self::merge_transport(self.transport, cli_config.transport);
252        #[cfg(feature = "tracing")]
253        {
254            self.logging = Self::merge_logging(self.logging, cli_config.logging);
255        }
256        self
257    }
258
259    /// Build the final ServiceConfig, validating required fields
260    pub fn build(self) -> Result<ServiceConfig, String> {
261        let secret = self
262            .secret
263            .ok_or_else(|| "missing required field: secret".to_string())?;
264        let server = self
265            .server
266            .ok_or_else(|| "missing required field: server".to_string())?;
267
268        Ok(ServiceConfig {
269            secret,
270            server,
271            auth_option: self.auth_option,
272            endpoint: self.endpoint,
273            transport: self.transport,
274            #[cfg(feature = "tracing")]
275            logging: self.logging,
276        })
277    }
278
279    fn merge_endpoint(_base: EndpointConfig, _override_config: EndpointConfig) -> EndpointConfig {
280        EndpointConfig {
281            #[cfg(feature = "endpoint-http")]
282            http: _override_config.http.or(_base.http),
283            #[cfg(feature = "endpoint-socks")]
284            socks: _override_config.socks.or(_base.socks),
285            #[cfg(feature = "endpoint-tun")]
286            tun: Self::merge_tun(_base.tun, _override_config.tun),
287        }
288    }
289
290    #[cfg(feature = "endpoint-tun")]
291    fn merge_tun(base: Option<TunConfig>, override_config: Option<TunConfig>) -> Option<TunConfig> {
292        match (base, override_config) {
293            (None, None) => None,
294            (Some(base), None) => Some(base),
295            (None, Some(override_config)) => Some(override_config),
296            (Some(base), Some(override_config)) => Some(TunConfig {
297                tun_fd: override_config.tun_fd.or(base.tun_fd),
298                tun_ipv4: override_config.tun_ipv4.or(base.tun_ipv4),
299                tun_ipv6: override_config.tun_ipv6.or(base.tun_ipv6),
300                tun_mtu: override_config.tun_mtu.or(base.tun_mtu),
301                fake_dns: override_config.fake_dns.or(base.fake_dns),
302                disable_udp_443: override_config.disable_udp_443.or(base.disable_udp_443),
303            }),
304        }
305    }
306
307    fn merge_transport(base: TransportConfig, override_config: TransportConfig) -> TransportConfig {
308        TransportConfig {
309            bind: override_config.bind.or(base.bind),
310            server_name: override_config.server_name.or(base.server_name),
311            tls_mode: override_config.tls_mode.or(base.tls_mode),
312            ca_cert: override_config.ca_cert.or(base.ca_cert),
313            client_cert: override_config.client_cert.or(base.client_cert),
314            client_key: override_config.client_key.or(base.client_key),
315            zero_rtt: override_config.zero_rtt.or(base.zero_rtt),
316            alpn_protocols: override_config.alpn_protocols.or(base.alpn_protocols),
317            congestion: override_config.congestion.or(base.congestion),
318            cwnd_init: override_config.cwnd_init.or(base.cwnd_init),
319            idle_timeout: override_config.idle_timeout.or(base.idle_timeout),
320            keep_alive: override_config.keep_alive.or(base.keep_alive),
321            max_streams: override_config.max_streams.or(base.max_streams),
322        }
323    }
324
325    #[cfg(feature = "tracing")]
326    fn merge_logging(base: LoggingConfig, override_config: LoggingConfig) -> LoggingConfig {
327        LoggingConfig {
328            log_level: override_config.log_level.or(base.log_level),
329        }
330    }
331}
332
333impl Default for ConfigBuilder {
334    fn default() -> Self {
335        Self::new()
336    }
337}
338
/// Resolves the client configuration from the command line and an optional JSON file.
///
/// Precedence, from lowest to highest:
/// 1. built-in defaults,
/// 2. the JSON file named by the `config` command-line argument, if given,
/// 3. values passed explicitly on the command line.
///
/// # Errors
///
/// Fails if the JSON file cannot be loaded. Also fails if a required field
/// (`secret` or `server`) is still missing after all sources are merged.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    use clap::Parser;

    let args = cli::Args::parse();

    let from_file = match args.config.as_ref() {
        Some(path) => Some(json::JsonConfig::from_file(path)?),
        None => None,
    };

    let mut builder = ConfigBuilder::new();
    if let Some(json_config) = from_file {
        builder = builder.merge_json(json_config);
    }

    let overrides = cli::CliConfig {
        secret: args.secret,
        server: args.server,
        auth_option: args.auth_option,
        endpoint: args.endpoint.into_endpoint_config(),
        transport: args.transport.into_transport_config(),
        #[cfg(feature = "tracing")]
        logging: args.logging.into_logging_config(),
    };

    let config = builder.merge_cli(overrides).build()?;
    Ok(config)
}
375
376/// Loads configuration from a JSON string.
377///
378/// This function is useful for programmatic configuration or when loading
379/// from external sources (e.g., environment variables, API responses).
380///
381/// # Arguments
382///
383/// * `json_str` - A JSON string containing the configuration
384///
385/// # Returns
386///
387/// A `ServiceConfig` ready to use, or an error if parsing fails or required fields are missing.
388pub fn load_from_json(json_str: &str) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
389    let json_config = json::JsonConfig::from_json_str(json_str)?;
390    ConfigBuilder::new()
391        .merge_json(json_config)
392        .build()
393        .map_err(|e| e.into())
394}
395
396/// Loads configuration from a JSON file.
397///
398/// This function reads configuration from a file path.
399///
400/// # Arguments
401///
402/// * `config_path` - Path to the JSON configuration file
403///
404/// # Returns
405///
406/// A `ServiceConfig` ready to use, or an error if the file doesn't exist,
407/// parsing fails, or required fields are missing.
408pub fn load_from_file(
409    config_path: &std::path::Path,
410) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
411    let json_config = json::JsonConfig::from_file(config_path)?;
412    ConfigBuilder::new()
413        .merge_json(json_config)
414        .build()
415        .map_err(|e| e.into())
416}