1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::ValueEnum;
5use serde::{Deserialize, Serialize};
6
7use ombrac_transport::quic::Congestion;
8
9pub mod cli;
10pub mod json;
11
12#[derive(Deserialize, Serialize, Debug, Clone)]
14#[serde(rename_all = "snake_case")]
15pub struct TransportConfig {
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub tls_mode: Option<TlsMode>,
18
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub ca_cert: Option<PathBuf>,
21
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub tls_cert: Option<PathBuf>,
24
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub tls_key: Option<PathBuf>,
27
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub zero_rtt: Option<bool>,
30
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub alpn_protocols: Option<Vec<Vec<u8>>>,
33
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub congestion: Option<Congestion>,
36
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub cwnd_init: Option<u64>,
39
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub idle_timeout: Option<u64>,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub keep_alive: Option<u64>,
45
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub max_streams: Option<u64>,
48}
49
50impl TransportConfig {
51 pub fn tls_mode(&self) -> TlsMode {
53 self.tls_mode.unwrap_or_default()
54 }
55
56 pub fn zero_rtt(&self) -> bool {
58 self.zero_rtt.unwrap_or(false)
59 }
60
61 pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
63 self.alpn_protocols
64 .clone()
65 .unwrap_or_else(|| vec!["h3".into()])
66 }
67
68 pub fn congestion(&self) -> Congestion {
70 self.congestion.unwrap_or(Congestion::Bbr)
71 }
72
73 pub fn idle_timeout(&self) -> u64 {
75 self.idle_timeout.unwrap_or(30000)
76 }
77
78 pub fn keep_alive(&self) -> u64 {
80 self.keep_alive.unwrap_or(8000)
81 }
82
83 pub fn max_streams(&self) -> u64 {
85 self.max_streams.unwrap_or(1000)
86 }
87}
88
89impl Default for TransportConfig {
90 fn default() -> Self {
91 Self {
92 tls_mode: Some(TlsMode::Tls),
93 ca_cert: None,
94 tls_cert: None,
95 tls_key: None,
96 zero_rtt: Some(false),
97 alpn_protocols: Some(vec!["h3".into()]),
98 congestion: Some(Congestion::Bbr),
99 cwnd_init: None,
100 idle_timeout: Some(30000),
101 keep_alive: Some(8000),
102 max_streams: Some(1000),
103 }
104 }
105}
106
107#[derive(Deserialize, Serialize, Debug, Clone)]
109#[serde(rename_all = "snake_case")]
110pub struct ConnectionConfig {
111 #[serde(skip_serializing_if = "Option::is_none")]
113 pub max_connections: Option<usize>,
114
115 #[serde(skip_serializing_if = "Option::is_none")]
117 pub auth_timeout_secs: Option<u64>,
118
119 #[serde(skip_serializing_if = "Option::is_none")]
121 pub max_concurrent_streams: Option<usize>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub max_concurrent_datagrams: Option<usize>,
126}
127
128impl ConnectionConfig {
129 pub fn max_connections(&self) -> usize {
131 self.max_connections.unwrap_or(10000)
132 }
133
134 pub fn auth_timeout_secs(&self) -> u64 {
136 self.auth_timeout_secs.unwrap_or(10)
137 }
138
139 #[deprecated(note = "Use auth_timeout_secs instead")]
143 pub fn handshake_timeout_secs(&self) -> u64 {
144 self.auth_timeout_secs()
145 }
146
147 pub fn max_concurrent_streams(&self) -> usize {
149 self.max_concurrent_streams.unwrap_or(4096)
150 }
151
152 pub fn max_concurrent_datagrams(&self) -> usize {
154 self.max_concurrent_datagrams.unwrap_or(4096)
155 }
156}
157
158impl Default for ConnectionConfig {
159 fn default() -> Self {
160 Self {
161 max_connections: Some(10000),
162 auth_timeout_secs: Some(10),
163 max_concurrent_streams: Some(4096),
164 max_concurrent_datagrams: Some(4096),
165 }
166 }
167}
168
/// Logging settings, compiled in only with the `tracing` feature.
#[cfg(feature = "tracing")]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct LoggingConfig {
    /// Level name as a string; the accessor falls back to "INFO".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
178
#[cfg(feature = "tracing")]
impl LoggingConfig {
    /// Borrowed log level, or "INFO" when unset.
    pub fn log_level(&self) -> &str {
        match &self.log_level {
            Some(level) => level.as_str(),
            None => "INFO",
        }
    }
}
186
/// Defaults the log level to the same "INFO" the accessor falls back to.
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    fn default() -> Self {
        let level = String::from("INFO");
        LoggingConfig {
            log_level: Some(level),
        }
    }
}
195
196#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
197#[serde(rename_all = "kebab-case")]
198pub enum TlsMode {
199 #[default]
200 Tls,
201 MTls,
202 Insecure,
203}
204
205#[derive(Debug, Clone)]
207pub struct ServiceConfig {
208 pub secret: String,
209 pub listen: SocketAddr,
210 pub transport: TransportConfig,
211 pub connection: ConnectionConfig,
212 #[cfg(feature = "tracing")]
213 pub logging: LoggingConfig,
214}
215
216pub struct ConfigBuilder {
219 secret: Option<String>,
220 listen: Option<SocketAddr>,
221 transport: TransportConfig,
222 connection: ConnectionConfig,
223 #[cfg(feature = "tracing")]
224 logging: LoggingConfig,
225}
226
227impl ConfigBuilder {
228 pub fn new() -> Self {
230 Self {
231 secret: None,
232 listen: None,
233 transport: TransportConfig::default(),
234 connection: ConnectionConfig::default(),
235 #[cfg(feature = "tracing")]
236 logging: LoggingConfig::default(),
237 }
238 }
239
240 pub fn merge_json(mut self, json_config: json::JsonConfig) -> Self {
242 if let Some(secret) = json_config.secret {
243 self.secret = Some(secret);
244 }
245 if let Some(listen) = json_config.listen {
246 self.listen = Some(listen);
247 }
248 if let Some(transport) = json_config.transport {
249 self.transport = Self::merge_transport(self.transport, transport);
250 }
251 if let Some(conn) = json_config.connection {
252 self.connection = Self::merge_connection(self.connection, conn);
253 }
254 #[cfg(feature = "tracing")]
255 {
256 if let Some(logging) = json_config.logging {
257 self.logging = Self::merge_logging(self.logging, logging);
258 }
259 }
260 self
261 }
262
263 pub fn merge_cli(mut self, cli_config: cli::CliConfig) -> Self {
265 if let Some(secret) = cli_config.secret {
266 self.secret = Some(secret);
267 }
268 if let Some(listen) = cli_config.listen {
269 self.listen = Some(listen);
270 }
271 self.transport = Self::merge_transport(self.transport, cli_config.transport);
272 #[cfg(feature = "tracing")]
273 {
274 self.logging = Self::merge_logging(self.logging, cli_config.logging);
275 }
276 self
277 }
278
279 pub fn build(self) -> Result<ServiceConfig, String> {
281 let secret = self
282 .secret
283 .ok_or_else(|| "missing required field: secret".to_string())?;
284 let listen = self
285 .listen
286 .ok_or_else(|| "missing required field: listen".to_string())?;
287
288 Ok(ServiceConfig {
289 secret,
290 listen,
291 transport: self.transport,
292 connection: self.connection,
293 #[cfg(feature = "tracing")]
294 logging: self.logging,
295 })
296 }
297
298 fn merge_transport(base: TransportConfig, override_config: TransportConfig) -> TransportConfig {
299 TransportConfig {
300 tls_mode: override_config.tls_mode.or(base.tls_mode),
301 ca_cert: override_config.ca_cert.or(base.ca_cert),
302 tls_cert: override_config.tls_cert.or(base.tls_cert),
303 tls_key: override_config.tls_key.or(base.tls_key),
304 zero_rtt: override_config.zero_rtt.or(base.zero_rtt),
305 alpn_protocols: override_config.alpn_protocols.or(base.alpn_protocols),
306 congestion: override_config.congestion.or(base.congestion),
307 cwnd_init: override_config.cwnd_init.or(base.cwnd_init),
308 idle_timeout: override_config.idle_timeout.or(base.idle_timeout),
309 keep_alive: override_config.keep_alive.or(base.keep_alive),
310 max_streams: override_config.max_streams.or(base.max_streams),
311 }
312 }
313
314 fn merge_connection(
315 base: ConnectionConfig,
316 override_config: ConnectionConfig,
317 ) -> ConnectionConfig {
318 ConnectionConfig {
319 max_connections: override_config.max_connections.or(base.max_connections),
320 auth_timeout_secs: override_config.auth_timeout_secs.or(base.auth_timeout_secs),
321 max_concurrent_streams: override_config
322 .max_concurrent_streams
323 .or(base.max_concurrent_streams),
324 max_concurrent_datagrams: override_config
325 .max_concurrent_datagrams
326 .or(base.max_concurrent_datagrams),
327 }
328 }
329
330 #[cfg(feature = "tracing")]
331 fn merge_logging(base: LoggingConfig, override_config: LoggingConfig) -> LoggingConfig {
332 LoggingConfig {
333 log_level: override_config.log_level.or(base.log_level),
334 }
335 }
336}
337
338impl Default for ConfigBuilder {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
/// Builds the service configuration from the process's command-line arguments.
///
/// Precedence, lowest to highest:
/// 1. built-in defaults;
/// 2. the JSON file named by the `config` argument, if given;
/// 3. values passed on the command line.
///
/// # Errors
///
/// Fails if the JSON file cannot be read or parsed, or if the merged result
/// is missing required fields.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<dyn std::error::Error>> {
    use clap::Parser;

    let args = cli::Args::parse();

    let mut builder = ConfigBuilder::default();
    if let Some(path) = args.config.as_ref() {
        let file_layer = json::JsonConfig::from_file(path)?;
        builder = builder.merge_json(file_layer);
    }

    let cli_layer = cli::CliConfig {
        secret: args.secret,
        listen: args.listen,
        transport: args.transport.into_transport_config(),
        #[cfg(feature = "tracing")]
        logging: args.logging.into_logging_config(),
    };

    let config = builder.merge_cli(cli_layer).build()?;
    Ok(config)
}
378
379pub fn load_from_json(json_str: &str) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
392 let json_config = json::JsonConfig::from_json_str(json_str)?;
393 ConfigBuilder::new()
394 .merge_json(json_config)
395 .build()
396 .map_err(|e| e.into())
397}
398
399pub fn load_from_file(
412 config_path: &std::path::Path,
413) -> Result<ServiceConfig, Box<dyn std::error::Error>> {
414 let json_config = json::JsonConfig::from_file(config_path)?;
415 ConfigBuilder::new()
416 .merge_json(json_config)
417 .build()
418 .map_err(|e| e.into())
419}