1use std::io;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4
5use clap::builder::Styles;
6use clap::builder::styling::{AnsiColor, Style};
7use clap::{Parser, ValueEnum};
8use figment::Figment;
9use figment::providers::{Format, Json, Serialized};
10use serde::{Deserialize, Serialize};
11
12#[cfg(feature = "transport-quic")]
13use ombrac_transport::quic::Congestion;
14
// Top-level command-line arguments.
//
// Comments in this struct are plain `//` on purpose: clap's derive turns `///`
// doc comments into `--help` text, so doc comments here would change the CLI
// output.
//
// Every value is optional at the type level; `load()` layers these values on
// top of the JSON config file and the built-in defaults.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None, styles = styles())]
pub struct Args {
    // `-c` / `--config FILE`: path to a JSON configuration file.
    #[clap(long, short = 'c', value_name = "FILE")]
    pub config: Option<PathBuf>,

    // `-k` / `--secret STR`: required unless `--config` is given, in which
    // case the value may come from the file instead.
    #[clap(
        long,
        short = 'k',
        help_heading = "Required",
        value_name = "STR",
        required_unless_present = "config"
    )]
    pub secret: Option<String>,

    // `-s` / `--server ADDR`: same requirement rule as `secret`. Kept as a
    // String rather than a SocketAddr (presumably so host names are accepted —
    // confirm with the transport code).
    #[clap(
        long,
        short = 's',
        help_heading = "Required",
        value_name = "ADDR",
        required_unless_present = "config"
    )]
    pub server: Option<String>,

    // Endpoint flags; which ones exist depends on the `endpoint-*` features.
    #[clap(flatten)]
    pub endpoint: EndpointConfig,

    // QUIC transport flags.
    #[cfg(feature = "transport-quic")]
    #[clap(flatten)]
    pub transport: TransportConfig,

    // Logging flags.
    #[cfg(feature = "tracing")]
    #[clap(flatten)]
    pub logging: LoggingConfig,
}
54
55#[derive(Deserialize, Serialize, Debug, Default)]
57pub struct ConfigFile {
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub secret: Option<String>,
60
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub server: Option<String>,
63
64 pub endpoint: EndpointConfig,
65
66 #[cfg(feature = "transport-quic")]
67 pub transport: TransportConfig,
68
69 #[cfg(feature = "tracing")]
70 pub logging: LoggingConfig,
71}
72
// Endpoint options. This struct is flattened into `Args` and also stored in
// `ConfigFile`. Which fields exist depends on the enabled `endpoint-*`
// features.
//
// Plain `//` comments on purpose: clap's derive turns `///` doc comments into
// `--help` text.
#[derive(Deserialize, Serialize, Debug, Parser, Clone, Default)]
pub struct EndpointConfig {
    // `--http ADDR`: socket address for the HTTP endpoint.
    #[cfg(feature = "endpoint-http")]
    #[clap(long, value_name = "ADDR", help_heading = "Endpoint")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub http: Option<SocketAddr>,

    // `--socks ADDR`: socket address for the SOCKS endpoint.
    #[cfg(feature = "endpoint-socks")]
    #[clap(long, value_name = "ADDR", help_heading = "Endpoint")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub socks: Option<SocketAddr>,

    // TUN options; their flags are flattened in from `TunConfig`. Presumably
    // `None` when no TUN flag is passed (clap `Option` flatten) — confirm with
    // the clap version in use. Skipped when `None` so it does not overwrite a
    // file-provided value.
    #[cfg(feature = "endpoint-tun")]
    #[clap(flatten)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun: Option<TunConfig>,
}
92
// QUIC transport options, shared by the CLI (flattened into `Args`) and the
// config file.
//
// Every field is an `Option` and is skipped when `None` on serialization, so a
// flag the user did not pass leaves lower-precedence layers untouched.
//
// Plain `//` comments on purpose: clap's derive turns `///` doc comments into
// `--help` text.
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
#[cfg(feature = "transport-quic")]
pub struct TransportConfig {
    // `--bind ADDR`: local socket address to bind (presumably the client's
    // UDP socket — confirm in ombrac_transport).
    #[clap(long, help_heading = "Transport", value_name = "ADDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bind: Option<SocketAddr>,

    // `--server-name STR`: presumably the TLS server name (SNI) used instead
    // of the host from `server` — TODO confirm.
    #[clap(long, help_heading = "Transport", value_name = "STR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server_name: Option<String>,

    // `--tls-mode`: one of `tls`, `m-tls`, `insecure` (kebab-case, from
    // `TlsMode`). The Default impl sets `Some(TlsMode::Tls)`.
    #[clap(long, value_enum, help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_mode: Option<TlsMode>,

    // `--ca-cert FILE`: CA certificate file path.
    #[clap(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca_cert: Option<PathBuf>,

    // `--client-cert FILE`: client certificate file path (presumably used
    // with `TlsMode::MTls`).
    #[clap(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_cert: Option<PathBuf>,

    // `--client-key FILE`: private key file matching `client_cert`.
    #[clap(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_key: Option<PathBuf>,

    // `--zero-rtt`: kept as `Option<bool>` so an absent flag stays `None` and
    // does not overwrite a file-provided value. The Default impl sets
    // `Some(false)`.
    #[clap(long, help_heading = "Transport", action)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zero_rtt: Option<bool>,

    // `--alpn-protocols PROTOCOLS`: comma-separated on the command line
    // (`value_delimiter = ','`); each protocol is stored as raw bytes. The
    // Default impl sets `["h3"]`.
    #[clap(
        long,
        help_heading = "Transport",
        value_name = "PROTOCOLS",
        value_delimiter = ','
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    // `--congestion ALGORITHM`: congestion controller from ombrac_transport.
    // The Default impl sets `Congestion::Bbr`.
    #[clap(long, help_heading = "Transport", value_name = "ALGORITHM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub congestion: Option<Congestion>,

    // `--cwnd-init NUM`: presumably the initial congestion window — confirm
    // the unit in ombrac_transport.
    #[clap(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwnd_init: Option<u64>,

    // `--idle-timeout TIME`: the unit is not shown here. The default of 30000
    // suggests milliseconds — TODO confirm.
    #[clap(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idle_timeout: Option<u64>,

    // `--keep-alive TIME`: same unit caveat as `idle_timeout` (default 8000).
    #[clap(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<u64>,

    // `--max-streams NUM`: the Default impl sets 100.
    #[clap(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_streams: Option<u64>,
}
171
// Logging options, shared by the CLI (flattened into `Args`) and the config
// file.
//
// Plain `//` comments on purpose: clap's derive turns `///` doc comments into
// `--help` text.
#[cfg(feature = "tracing")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct LoggingConfig {
    // `--log-level LEVEL`: free-form string. The Default impl sets "INFO".
    #[clap(long, help_heading = "Logging", value_name = "LEVEL")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
180
// TUN endpoint options, flattened into `EndpointConfig::tun`. The flags are
// listed under the "Endpoint" help heading.
//
// Plain `//` comments on purpose: clap's derive turns `///` doc comments into
// `--help` text.
#[cfg(feature = "endpoint-tun")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct TunConfig {
    // `--tun-fd FD`: presumably an already-open TUN device file descriptor
    // handed in by the caller — confirm.
    #[clap(long, help_heading = "Endpoint", value_name = "FD")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_fd: Option<i32>,

    // `--tun-ipv4 CIDR`: IPv4 address/prefix as an unparsed string.
    #[clap(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv4: Option<String>,

    // `--tun-ipv6 CIDR`: IPv6 address/prefix as an unparsed string.
    #[clap(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv6: Option<String>,

    // `--tun-mtu U16`: the Default impl sets 1500.
    #[clap(long, help_heading = "Endpoint", value_name = "U16")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_mtu: Option<u16>,

    // `--fake-dns CIDR`: address range as an unparsed string. The Default impl
    // sets "198.18.0.0/16".
    #[clap(long, help_heading = "Endpoint", value_name = "CIDR")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fake_dns: Option<String>,
}
210
// TLS mode for the QUIC transport.
//
// Both the CLI (`ValueEnum`) and the config file (`rename_all`) spell the
// variants in kebab-case: `tls`, `m-tls`, `insecure`. `Tls` is the default.
//
// Plain `//` comments on purpose: clap's `ValueEnum` derive would turn `///`
// doc comments into per-value help text.
#[cfg(feature = "transport-quic")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize, ValueEnum)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Standard TLS; the default variant.
    #[default]
    Tls,
    // Mutual TLS (presumably uses `client_cert` / `client_key`).
    MTls,
    // Presumably skips certificate verification — confirm in the transport.
    Insecure,
}
221
222#[derive(Debug, Clone)]
223pub struct ServiceConfig {
224 pub secret: String,
225 pub server: String,
226 pub endpoint: EndpointConfig,
227 #[cfg(feature = "transport-quic")]
228 pub transport: TransportConfig,
229 #[cfg(feature = "tracing")]
230 pub logging: LoggingConfig,
231}
232
/// Built-in QUIC transport defaults, used as the lowest-precedence config layer.
///
/// Presets:
/// - TLS mode `Tls`, 0-RTT off, ALPN `h3`, BBR congestion control;
/// - `idle_timeout` 30000 and `keep_alive` 8000 (presumably milliseconds —
///   TODO confirm against the transport);
/// - at most 100 streams.
///
/// Bind address, server name, certificates and the initial congestion window
/// are left unset.
#[cfg(feature = "transport-quic")]
impl Default for TransportConfig {
    fn default() -> Self {
        Self {
            // Security / protocol negotiation.
            tls_mode: Some(TlsMode::Tls),
            zero_rtt: Some(false),
            alpn_protocols: Some(vec![b"h3".to_vec()]),

            // Connection tuning.
            congestion: Some(Congestion::Bbr),
            idle_timeout: Some(30_000),
            keep_alive: Some(8_000),
            max_streams: Some(100),

            // Deliberately left for the user / config file to provide.
            bind: None,
            server_name: None,
            ca_cert: None,
            client_cert: None,
            client_key: None,
            cwnd_init: None,
        }
    }
}
253
/// Built-in TUN defaults: MTU 1500 and fake-DNS range `198.18.0.0/16`.
///
/// No file descriptor or interface addresses are preset.
///
/// Gated only on `endpoint-tun`, the same feature that declares `TunConfig`.
/// The defaults do not depend on the QUIC transport, so they must also exist
/// in TUN builds without `transport-quic`.
#[cfg(feature = "endpoint-tun")]
impl Default for TunConfig {
    fn default() -> Self {
        Self {
            tun_fd: None,
            tun_ipv4: None,
            tun_ipv6: None,
            tun_mtu: Some(1500),
            fake_dns: Some("198.18.0.0/16".to_string()),
        }
    }
}
267
/// Built-in logging defaults: the log level is `"INFO"`.
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    fn default() -> Self {
        let log_level = Some(String::from("INFO"));
        Self { log_level }
    }
}
276
277pub fn load() -> Result<ServiceConfig, Box<figment::Error>> {
278 let args = Args::parse();
279
280 let mut figment = Figment::new().merge(Serialized::defaults(ConfigFile::default()));
281
282 if let Some(config_path) = &args.config {
283 if !config_path.exists() {
284 let err = io::Error::new(
285 io::ErrorKind::NotFound,
286 format!("Configuration file not found: {}", config_path.display()),
287 );
288 return Err(Box::new(figment::Error::from(err.to_string())));
289 }
290
291 figment = figment.merge(Json::file(config_path));
292 }
293
294 let cli_overrides = ConfigFile {
295 secret: args.secret,
296 server: args.server,
297 endpoint: args.endpoint,
298 #[cfg(feature = "transport-quic")]
299 transport: args.transport,
300 #[cfg(feature = "tracing")]
301 logging: args.logging,
302 };
303
304 figment = figment.merge(Serialized::defaults(cli_overrides));
305
306 let config: ConfigFile = figment.extract()?;
307
308 let secret = config
309 .secret
310 .ok_or_else(|| figment::Error::from("missing field `secret`"))?;
311 let server = config
312 .server
313 .ok_or_else(|| figment::Error::from("missing field `server`"))?;
314
315 Ok(ServiceConfig {
316 secret,
317 server,
318 endpoint: config.endpoint,
319 #[cfg(feature = "transport-quic")]
320 transport: config.transport,
321 #[cfg(feature = "tracing")]
322 logging: config.logging,
323 })
324}
325
326fn styles() -> Styles {
327 Styles::styled()
328 .header(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
329 .usage(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
330 .literal(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
331 .placeholder(Style::new().fg_color(Some(AnsiColor::Cyan.into())))
332 .valid(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
333 .invalid(Style::new().bold().fg_color(Some(AnsiColor::Yellow.into())))
334 .error(Style::new().bold().fg_color(Some(AnsiColor::Red.into())))
335}