1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::ValueEnum;
5use serde::{Deserialize, Serialize};
6
7use ombrac_transport::quic::Congestion;
8
9pub mod cli;
10pub mod json;
11
12#[derive(Deserialize, Serialize, Debug, Clone, Default)]
13pub struct EndpointConfig {
14 #[cfg(feature = "endpoint-http")]
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub http: Option<SocketAddr>,
18
19 #[cfg(feature = "endpoint-socks")]
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub socks: Option<SocketAddr>,
23
24 #[cfg(feature = "endpoint-tun")]
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub tun: Option<TunConfig>,
27}
28
29#[derive(Deserialize, Serialize, Debug, Clone)]
30#[serde(rename_all = "snake_case")]
31pub struct TransportConfig {
32 #[serde(skip_serializing_if = "Option::is_none")]
34 pub bind: Option<SocketAddr>,
35
36 #[serde(skip_serializing_if = "Option::is_none")]
38 pub server_name: Option<String>,
39
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub tls_mode: Option<TlsMode>,
43
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub ca_cert: Option<PathBuf>,
47
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub client_cert: Option<PathBuf>,
51
52 #[serde(skip_serializing_if = "Option::is_none")]
54 pub client_key: Option<PathBuf>,
55
56 #[serde(skip_serializing_if = "Option::is_none")]
58 pub zero_rtt: Option<bool>,
59
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub alpn_protocols: Option<Vec<Vec<u8>>>,
63
64 #[serde(skip_serializing_if = "Option::is_none")]
66 pub congestion: Option<Congestion>,
67
68 #[serde(skip_serializing_if = "Option::is_none")]
70 pub cwnd_init: Option<u64>,
71
72 #[serde(skip_serializing_if = "Option::is_none")]
74 pub idle_timeout: Option<u64>,
75
76 #[serde(skip_serializing_if = "Option::is_none")]
78 pub keep_alive: Option<u64>,
79
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub max_streams: Option<u64>,
83}
84
85impl Default for TransportConfig {
86 fn default() -> Self {
87 Self {
88 bind: None,
89 server_name: None,
90 tls_mode: Some(TlsMode::Tls),
91 ca_cert: None,
92 client_cert: None,
93 client_key: None,
94 zero_rtt: Some(false),
95 alpn_protocols: Some(vec!["h3".into()]),
96 congestion: Some(Congestion::Bbr),
97 cwnd_init: None,
98 idle_timeout: Some(30000),
99 keep_alive: Some(8000),
100 max_streams: Some(100),
101 }
102 }
103}
104
/// Logging settings (only present with the `tracing` feature).
#[cfg(feature = "tracing")]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct LoggingConfig {
    /// Log level name, e.g. `"INFO"`; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
113
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Defaults the log level to `"INFO"`.
    fn default() -> Self {
        let level = String::from("INFO");
        LoggingConfig {
            log_level: Some(level),
        }
    }
}
122
/// TUN device settings (only present with the `endpoint-tun` feature).
///
/// All fields are optional; `None` fields are omitted from serialized output.
#[cfg(feature = "endpoint-tun")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TunConfig {
    /// File descriptor of an already-open TUN device, if supplied
    /// (presumably provided by a host application — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_fd: Option<i32>,

    /// IPv4 address for the TUN interface; kept as a string, not parsed here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv4: Option<String>,

    /// IPv6 address for the TUN interface; kept as a string, not parsed here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_ipv6: Option<String>,

    /// Interface MTU.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tun_mtu: Option<u16>,

    /// Address range used for fake DNS, written in CIDR form (e.g. `198.18.0.0/16`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fake_dns: Option<String>,

    /// Whether UDP traffic to port 443 should be disabled; the exact effect is
    /// decided by the consumer of this config.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub disable_udp_443: Option<bool>,
}
151
#[cfg(feature = "endpoint-tun")]
impl Default for TunConfig {
    /// MTU 1500, fake-DNS range `198.18.0.0/16` and UDP/443 not disabled;
    /// no descriptor or interface addresses are preset.
    fn default() -> Self {
        TunConfig {
            tun_mtu: Some(1500),
            fake_dns: Some(String::from("198.18.0.0/16")),
            disable_udp_443: Some(false),
            // Device handle and addresses must come from the user.
            tun_fd: None,
            tun_ipv4: None,
            tun_ipv6: None,
        }
    }
}
165
166#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
167#[serde(rename_all = "kebab-case")]
168pub enum TlsMode {
169 #[default]
170 Tls,
171 MTls,
172 Insecure,
173}
174
175#[derive(Debug, Clone)]
177pub struct ServiceConfig {
178 pub secret: String,
179 pub server: String,
180 pub auth_option: Option<String>,
181 pub endpoint: EndpointConfig,
182 pub transport: TransportConfig,
183 #[cfg(feature = "tracing")]
184 pub logging: LoggingConfig,
185}
186
187pub struct ConfigBuilder {
190 secret: Option<String>,
191 server: Option<String>,
192 auth_option: Option<String>,
193 endpoint: EndpointConfig,
194 transport: TransportConfig,
195 #[cfg(feature = "tracing")]
196 logging: LoggingConfig,
197}
198
199impl ConfigBuilder {
200 pub fn new() -> Self {
202 Self {
203 secret: None,
204 server: None,
205 auth_option: None,
206 endpoint: EndpointConfig::default(),
207 transport: TransportConfig::default(),
208 #[cfg(feature = "tracing")]
209 logging: LoggingConfig::default(),
210 }
211 }
212
213 pub fn merge_json(mut self, json_config: json::JsonConfig) -> Self {
215 if let Some(secret) = json_config.secret {
216 self.secret = Some(secret);
217 }
218 if let Some(server) = json_config.server {
219 self.server = Some(server);
220 }
221 if let Some(auth_option) = json_config.auth_option {
222 self.auth_option = Some(auth_option);
223 }
224 if let Some(endpoint) = json_config.endpoint {
225 self.endpoint = Self::merge_endpoint(self.endpoint, endpoint);
226 }
227 if let Some(transport) = json_config.transport {
228 self.transport = Self::merge_transport(self.transport, transport);
229 }
230 #[cfg(feature = "tracing")]
231 {
232 if let Some(logging) = json_config.logging {
233 self.logging = Self::merge_logging(self.logging, logging);
234 }
235 }
236 self
237 }
238
239 pub fn merge_cli(mut self, cli_config: cli::CliConfig) -> Self {
241 if let Some(secret) = cli_config.secret {
242 self.secret = Some(secret);
243 }
244 if let Some(server) = cli_config.server {
245 self.server = Some(server);
246 }
247 if let Some(auth_option) = cli_config.auth_option {
248 self.auth_option = Some(auth_option);
249 }
250 self.endpoint = Self::merge_endpoint(self.endpoint, cli_config.endpoint);
251 self.transport = Self::merge_transport(self.transport, cli_config.transport);
252 #[cfg(feature = "tracing")]
253 {
254 self.logging = Self::merge_logging(self.logging, cli_config.logging);
255 }
256 self
257 }
258
259 pub fn build(self) -> Result<ServiceConfig, String> {
261 let secret = self
262 .secret
263 .ok_or_else(|| "missing required field: secret".to_string())?;
264 let server = self
265 .server
266 .ok_or_else(|| "missing required field: server".to_string())?;
267
268 Ok(ServiceConfig {
269 secret,
270 server,
271 auth_option: self.auth_option,
272 endpoint: self.endpoint,
273 transport: self.transport,
274 #[cfg(feature = "tracing")]
275 logging: self.logging,
276 })
277 }
278
279 fn merge_endpoint(_base: EndpointConfig, _override_config: EndpointConfig) -> EndpointConfig {
280 EndpointConfig {
281 #[cfg(feature = "endpoint-http")]
282 http: _override_config.http.or(_base.http),
283 #[cfg(feature = "endpoint-socks")]
284 socks: _override_config.socks.or(_base.socks),
285 #[cfg(feature = "endpoint-tun")]
286 tun: Self::merge_tun(_base.tun, _override_config.tun),
287 }
288 }
289
290 #[cfg(feature = "endpoint-tun")]
291 fn merge_tun(base: Option<TunConfig>, override_config: Option<TunConfig>) -> Option<TunConfig> {
292 match (base, override_config) {
293 (None, None) => None,
294 (Some(base), None) => Some(base),
295 (None, Some(override_config)) => Some(override_config),
296 (Some(base), Some(override_config)) => Some(TunConfig {
297 tun_fd: override_config.tun_fd.or(base.tun_fd),
298 tun_ipv4: override_config.tun_ipv4.or(base.tun_ipv4),
299 tun_ipv6: override_config.tun_ipv6.or(base.tun_ipv6),
300 tun_mtu: override_config.tun_mtu.or(base.tun_mtu),
301 fake_dns: override_config.fake_dns.or(base.fake_dns),
302 disable_udp_443: override_config.disable_udp_443.or(base.disable_udp_443),
303 }),
304 }
305 }
306
307 fn merge_transport(base: TransportConfig, override_config: TransportConfig) -> TransportConfig {
308 TransportConfig {
309 bind: override_config.bind.or(base.bind),
310 server_name: override_config.server_name.or(base.server_name),
311 tls_mode: override_config.tls_mode.or(base.tls_mode),
312 ca_cert: override_config.ca_cert.or(base.ca_cert),
313 client_cert: override_config.client_cert.or(base.client_cert),
314 client_key: override_config.client_key.or(base.client_key),
315 zero_rtt: override_config.zero_rtt.or(base.zero_rtt),
316 alpn_protocols: override_config.alpn_protocols.or(base.alpn_protocols),
317 congestion: override_config.congestion.or(base.congestion),
318 cwnd_init: override_config.cwnd_init.or(base.cwnd_init),
319 idle_timeout: override_config.idle_timeout.or(base.idle_timeout),
320 keep_alive: override_config.keep_alive.or(base.keep_alive),
321 max_streams: override_config.max_streams.or(base.max_streams),
322 }
323 }
324
325 #[cfg(feature = "tracing")]
326 fn merge_logging(base: LoggingConfig, override_config: LoggingConfig) -> LoggingConfig {
327 LoggingConfig {
328 log_level: override_config.log_level.or(base.log_level),
329 }
330 }
331}
332
333impl Default for ConfigBuilder {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
/// Builds the service configuration for the command-line binary.
///
/// Precedence, lowest to highest:
/// 1. built-in defaults,
/// 2. the JSON file named by the `config` argument (if one was given),
/// 3. values passed explicitly on the command line.
///
/// # Errors
///
/// Fails if `JsonConfig::from_file` fails, or if `secret` or `server` is
/// still missing after all sources are merged.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    use clap::Parser;

    let args = cli::Args::parse();

    let defaults = ConfigBuilder::new();
    let builder = match &args.config {
        Some(path) => defaults.merge_json(json::JsonConfig::from_file(path)?),
        None => defaults,
    };

    let overrides = cli::CliConfig {
        secret: args.secret,
        server: args.server,
        auth_option: args.auth_option,
        endpoint: args.endpoint.into_endpoint_config(),
        transport: args.transport.into_transport_config(),
        #[cfg(feature = "tracing")]
        logging: args.logging.into_logging_config(),
    };

    // `?` boxes the builder's `String` error via `From<String> for Box<dyn Error>`.
    Ok(builder.merge_cli(overrides).build()?)
}
375
376pub fn load_from_json(json_str: &str) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
389 let json_config = json::JsonConfig::from_json_str(json_str)?;
390 ConfigBuilder::new()
391 .merge_json(json_config)
392 .build()
393 .map_err(|e| e.into())
394}
395
396pub fn load_from_file(
409 config_path: &std::path::Path,
410) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
411 let json_config = json::JsonConfig::from_file(config_path)?;
412 ConfigBuilder::new()
413 .merge_json(json_config)
414 .build()
415 .map_err(|e| e.into())
416}