1use std::io;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4
5use clap::builder::Styles;
6use clap::builder::styling::{AnsiColor, Style};
7use clap::{Parser, ValueEnum};
8use figment::Figment;
9use figment::providers::{Format, Json, Serialized};
10use serde::{Deserialize, Serialize};
11
12#[cfg(feature = "transport-quic")]
13use ombrac_transport::quic::Congestion;
14
15#[derive(Parser, Debug)]
17#[command(version, about, long_about = None, styles = styles())]
18pub struct Args {
19 #[clap(long, short = 'c', value_name = "FILE")]
21 pub config: Option<PathBuf>,
22
23 #[clap(
25 long,
26 short = 'k',
27 help_heading = "Required",
28 value_name = "STR",
29 required_unless_present = "config"
30 )]
31 pub secret: Option<String>,
32
33 #[clap(
35 long,
36 short = 's',
37 help_heading = "Required",
38 value_name = "ADDR",
39 required_unless_present = "config"
40 )]
41 pub server: Option<String>,
42
43 #[clap(long, help_heading = "Protocol", value_name = "STR")]
45 pub handshake_option: Option<String>,
46
47 #[clap(flatten)]
48 pub endpoint: EndpointConfig,
49
50 #[cfg(feature = "transport-quic")]
51 #[clap(flatten)]
52 pub transport: TransportConfig,
53
54 #[cfg(feature = "tracing")]
55 #[clap(flatten)]
56 pub logging: LoggingConfig,
57}
58
/// Layout of the JSON configuration file.
///
/// `load` also uses this type for every merge layer: the built-in defaults, the JSON file
/// and the command-line overrides. Top-level `Option` fields are skipped when `None`
/// during serialization, so an unset value in a higher-precedence layer does not replace
/// a value from a lower one.
#[derive(Deserialize, Serialize, Debug, Default)]
pub struct ConfigFile {
    /// Mirrors `--secret`; `load` rejects the final configuration if it is still missing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub secret: Option<String>,

    /// Mirrors `--server`; `load` rejects the final configuration if it is still missing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server: Option<String>,

    /// Mirrors `--handshake-option`; optional and passed through as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub handshake_option: Option<String>,

    /// Endpoint settings (always serialized; its own fields are individually optional).
    pub endpoint: EndpointConfig,

    /// QUIC transport settings.
    #[cfg(feature = "transport-quic")]
    pub transport: TransportConfig,

    /// Logging settings.
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}
79
80#[derive(Deserialize, Serialize, Debug, Parser, Clone, Default)]
81pub struct EndpointConfig {
82 #[cfg(feature = "endpoint-http")]
84 #[clap(long, value_name = "ADDR", help_heading = "Endpoint")]
85 #[serde(skip_serializing_if = "Option::is_none")]
86 pub http: Option<SocketAddr>,
87
88 #[cfg(feature = "endpoint-socks")]
90 #[clap(long, value_name = "ADDR", help_heading = "Endpoint")]
91 #[serde(skip_serializing_if = "Option::is_none")]
92 pub socks: Option<SocketAddr>,
93
94 #[cfg(feature = "endpoint-tun")]
95 #[clap(flatten)]
96 #[serde(skip_serializing_if = "Option::is_none")]
97 pub tun: Option<TunConfig>,
98}
99
// QUIC transport settings, settable from the CLI ("Transport" help section) or the JSON
// file. Every field is optional; `None` fields are omitted when serialized. (`//`
// comments: clap would turn `///` into `--help` text.)
#[cfg(feature = "transport-quic")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct TransportConfig {
    // `--bind ADDR`: local socket address to bind.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "ADDR")]
    pub bind: Option<SocketAddr>,

    // `--server-name STR`: name presumably used for TLS server verification — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "STR")]
    pub server_name: Option<String>,

    // `--tls-mode`: one of the `TlsMode` values.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, value_enum, help_heading = "Transport")]
    pub tls_mode: Option<TlsMode>,

    // `--ca-cert FILE`: CA certificate path.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    pub ca_cert: Option<PathBuf>,

    // `--client-cert FILE`: client certificate path (presumably for `TlsMode::MTls`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    pub client_cert: Option<PathBuf>,

    // `--client-key FILE`: client private-key path (presumably for `TlsMode::MTls`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    pub client_key: Option<PathBuf>,

    // `--zero-rtt`: toggles 0-RTT.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", action)]
    pub zero_rtt: Option<bool>,

    // `--alpn-protocols P1,P2`: comma-separated on the CLI, each stored as raw bytes.
    // NOTE(review): serde's default form for `Vec<u8>` is a sequence of numbers, so
    // the JSON file presumably needs byte arrays here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(
        long,
        help_heading = "Transport",
        value_name = "PROTOCOLS",
        value_delimiter = ','
    )]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    // `--congestion ALGORITHM`: congestion-control algorithm.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "ALGORITHM")]
    pub congestion: Option<Congestion>,

    // `--cwnd-init NUM`: initial congestion window (units not visible here — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    pub cwnd_init: Option<u64>,

    // `--idle-timeout TIME`: units not visible here; the 30000 default suggests ms.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    pub idle_timeout: Option<u64>,

    // `--keep-alive TIME`: units not visible here; the 8000 default suggests ms.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    pub keep_alive: Option<u64>,

    // `--max-streams NUM`: stream limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    pub max_streams: Option<u64>,
}
178
// Logging settings ("Logging" help section). (`//` comments: clap would turn `///`
// into `--help` text.)
#[cfg(feature = "tracing")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct LoggingConfig {
    // `--log-level LEVEL`: kept as a free-form string; not validated here.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Logging", value_name = "LEVEL")]
    pub log_level: Option<String>,
}
187
// TUN endpoint settings, listed under the "Endpoint" help section. The address fields
// are plain strings, so CIDR syntax is not validated here. (`//` comments: clap would
// turn `///` into `--help` text.)
#[cfg(feature = "endpoint-tun")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct TunConfig {
    // `--tun-fd FD`: file descriptor, presumably of an already-open TUN device.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Endpoint", value_name = "FD")]
    pub tun_fd: Option<i32>,

    // `--tun-ipv4 CIDR`: IPv4 address/prefix for the interface.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Endpoint", value_name = "CIDR")]
    pub tun_ipv4: Option<String>,

    // `--tun-ipv6 CIDR`: IPv6 address/prefix for the interface.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Endpoint", value_name = "CIDR")]
    pub tun_ipv6: Option<String>,

    // `--tun-mtu U16`: interface MTU.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Endpoint", value_name = "U16")]
    pub tun_mtu: Option<u16>,

    // `--fake-dns CIDR`: address range for fake-DNS answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long, help_heading = "Endpoint", value_name = "CIDR")]
    pub fake_dns: Option<String>,
}
217
// TLS mode for the QUIC transport. Both clap's `ValueEnum` naming and
// `rename_all = "kebab-case"` map the variants to "tls", "m-tls" and "insecure".
// (`//` comments: clap turns `///` on `ValueEnum` variants into `--help` text.)
#[cfg(feature = "transport-quic")]
#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Default variant.
    #[default]
    Tls,
    // Presumably mutual TLS (client certificate + key) — confirm with transport code.
    MTls,
    // Presumably disables certificate verification — confirm with transport code.
    Insecure,
}
227
/// Fully resolved configuration returned by `load`.
///
/// Unlike `ConfigFile`, `secret` and `server` are guaranteed to be present.
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// Value of `secret` after merging defaults, file and CLI.
    pub secret: String,
    /// Value of `server` after merging defaults, file and CLI.
    pub server: String,
    /// Optional handshake option, passed through unchanged.
    pub handshake_option: Option<String>,
    /// Merged endpoint settings.
    pub endpoint: EndpointConfig,
    /// Merged QUIC transport settings.
    #[cfg(feature = "transport-quic")]
    pub transport: TransportConfig,
    /// Merged logging settings.
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}
239
#[cfg(feature = "transport-quic")]
impl Default for TransportConfig {
    /// Built-in transport defaults.
    ///
    /// TLS mode is `Tls`, 0-RTT is off, ALPN is `h3`, congestion control is BBR, the idle
    /// timeout is 30000, keep-alive is 8000 and the stream limit is 100. Units are not
    /// visible here; they are presumably milliseconds. Every other field is left unset.
    fn default() -> Self {
        let alpn = vec![b"h3".to_vec()];

        TransportConfig {
            // Explicit defaults.
            tls_mode: Some(TlsMode::Tls),
            zero_rtt: Some(false),
            alpn_protocols: Some(alpn),
            congestion: Some(Congestion::Bbr),
            idle_timeout: Some(30_000),
            keep_alive: Some(8_000),
            max_streams: Some(100),
            // No default: must come from the file or CLI if needed.
            bind: None,
            server_name: None,
            ca_cert: None,
            client_cert: None,
            client_key: None,
            cwnd_init: None,
        }
    }
}
260
// Gated only on `endpoint-tun`, exactly like the `TunConfig` definition itself.
// Previously the impl also required `transport-quic`, which left `TunConfig` without
// a `Default` impl in builds that enable the TUN endpoint but not the QUIC transport.
#[cfg(feature = "endpoint-tun")]
impl Default for TunConfig {
    /// Built-in TUN defaults: MTU 1500 and fake-DNS range `198.18.0.0/16`; the file
    /// descriptor and interface addresses are left unset.
    fn default() -> Self {
        Self {
            tun_fd: None,
            tun_ipv4: None,
            tun_ipv6: None,
            tun_mtu: Some(1500),
            fake_dns: Some("198.18.0.0/16".to_string()),
        }
    }
}
274
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Default log level string is `"INFO"`.
    fn default() -> Self {
        let level = String::from("INFO");
        LoggingConfig {
            log_level: Some(level),
        }
    }
}
283
284pub fn load() -> Result<ServiceConfig, Box<figment::Error>> {
285 let args = Args::parse();
286
287 let mut figment = Figment::new().merge(Serialized::defaults(ConfigFile::default()));
288
289 if let Some(config_path) = &args.config {
290 if !config_path.exists() {
291 let err = io::Error::new(
292 io::ErrorKind::NotFound,
293 format!("Configuration file not found: {}", config_path.display()),
294 );
295 return Err(Box::new(figment::Error::from(err.to_string())));
296 }
297
298 figment = figment.merge(Json::file(config_path));
299 }
300
301 let cli_overrides = ConfigFile {
302 secret: args.secret,
303 server: args.server,
304 handshake_option: args.handshake_option,
305 endpoint: args.endpoint,
306 #[cfg(feature = "transport-quic")]
307 transport: args.transport,
308 #[cfg(feature = "tracing")]
309 logging: args.logging,
310 };
311
312 figment = figment.merge(Serialized::defaults(cli_overrides));
313
314 let config: ConfigFile = figment.extract()?;
315
316 let secret = config
317 .secret
318 .ok_or_else(|| figment::Error::from("missing field `secret`"))?;
319 let server = config
320 .server
321 .ok_or_else(|| figment::Error::from("missing field `server`"))?;
322
323 Ok(ServiceConfig {
324 secret,
325 server,
326 handshake_option: config.handshake_option,
327 endpoint: config.endpoint,
328 #[cfg(feature = "transport-quic")]
329 transport: config.transport,
330 #[cfg(feature = "tracing")]
331 logging: config.logging,
332 })
333}
334
335fn styles() -> Styles {
336 Styles::styled()
337 .header(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
338 .usage(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
339 .literal(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
340 .placeholder(Style::new().fg_color(Some(AnsiColor::Cyan.into())))
341 .valid(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
342 .invalid(Style::new().bold().fg_color(Some(AnsiColor::Yellow.into())))
343 .error(Style::new().bold().fg_color(Some(AnsiColor::Red.into())))
344}