1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::builder::Styles;
5use clap::builder::styling::{AnsiColor, Style};
6use clap::{Parser, ValueEnum};
7
8use serde::{Deserialize, Serialize};
9
10#[cfg(feature = "transport-quic")]
11use ombrac_transport::quic::Congestion;
12
13#[derive(Parser, Debug)]
15#[command(version, about, long_about = None, styles = styles())]
16pub struct Args {
17 #[clap(long, short = 'c', value_name = "FILE")]
19 pub config: Option<PathBuf>,
20
21 #[clap(
23 long,
24 short = 'k',
25 help_heading = "Required",
26 value_name = "STR",
27 required_unless_present = "config"
28 )]
29 pub secret: Option<String>,
30
31 #[clap(
33 long,
34 short = 'l',
35 help_heading = "Required",
36 value_name = "ADDR",
37 required_unless_present = "config"
38 )]
39 pub listen: Option<SocketAddr>,
40
41 #[cfg(feature = "transport-quic")]
42 #[clap(flatten)]
43 pub transport: TransportConfig,
44
45 #[cfg(feature = "tracing")]
46 #[clap(flatten)]
47 pub logging: LoggingConfig,
48}
49
50#[derive(Deserialize, Serialize, Debug, Default)]
52pub struct ConfigFile {
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub secret: Option<String>,
55
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub listen: Option<SocketAddr>,
58
59 #[cfg(feature = "transport-quic")]
60 pub transport: TransportConfig,
61
62 #[cfg(feature = "tracing")]
63 pub logging: LoggingConfig,
64}
65
// QUIC transport options. They are used both as CLI flags (flattened into
// `Args`) and as the `transport` section of the config file.
//
// Every field is optional and is skipped when serializing if unset. During
// merging, only the options a layer actually sets override lower layers.
// Baseline values come from the `Default` impl for this type.
//
// `//` rather than `///` comments: clap would turn doc comments into help text.
#[cfg(feature = "transport-quic")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct TransportConfig {
    // TLS mode (see `TlsMode`); `Default` selects `TlsMode::Tls`.
    #[arg(long, value_enum, help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_mode: Option<TlsMode>,

    // CA certificate file.
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca_cert: Option<PathBuf>,

    // TLS certificate file.
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_cert: Option<PathBuf>,

    // TLS key file.
    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_key: Option<PathBuf>,

    // 0-RTT toggle; `Default` sets `false`.
    #[arg(long, help_heading = "Transport", action)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zero_rtt: Option<bool>,

    // ALPN protocol list, comma-separated on the CLI; each entry is kept as raw bytes.
    // NOTE(review): serde encodes `Vec<u8>` as an array of numbers, so the JSON
    // file presumably needs `[[104, 51]]` rather than `["h3"]` — confirm intended.
    #[arg(
        long,
        help_heading = "Transport",
        value_name = "PROTOCOLS",
        value_delimiter = ','
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    // Congestion-control algorithm; `Default` selects `Congestion::Bbr`.
    #[arg(long, help_heading = "Transport", value_name = "ALGORITHM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub congestion: Option<Congestion>,

    // Initial congestion window; no default is set.
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwnd_init: Option<u64>,

    // Idle timeout; `Default` is 30000 (unit not established here —
    // presumably milliseconds, TODO confirm).
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idle_timeout: Option<u64>,

    // Keep-alive interval; `Default` is 8000 (same unit caveat as above).
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<u64>,

    // Maximum number of streams; `Default` is 1000.
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_streams: Option<u64>,
}
132
// Logging options. They are used both as CLI flags (flattened into `Args`)
// and as the `logging` section of the config file.
//
// `//` rather than `///` comments: clap would turn doc comments into help text.
#[cfg(feature = "tracing")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct LoggingConfig {
    // Log level as a free-form string (`--log-level LEVEL`); `Default` supplies "INFO".
    #[arg(long, help_heading = "Logging", value_name = "LEVEL")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
141
/// TLS mode selector for the QUIC transport (`--tls-mode` / `transport.tls_mode`).
///
/// Values are spelled in kebab-case both on the CLI (`ValueEnum`) and in the
/// config file (serde). `Tls` is the default variant.
#[cfg(feature = "transport-quic")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, ValueEnum, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Variant notes use `//` so `ValueEnum` does not pick them up as help text.
    #[default]
    Tls,
    // Kebab-case spelling of this variant is "m-tls".
    MTls,
    Insecure,
}
152
/// Fully resolved service configuration, as returned by `load()`.
///
/// Unlike `ConfigFile`, the required `secret` and `listen` values are plain
/// (non-`Option`) fields, so they are always present.
///
/// `Debug` is written by hand so the secret never ends up in logs or panic
/// messages. It is rendered as `"<redacted>"`, and all other fields print normally.
#[derive(Clone)]
pub struct ServiceConfig {
    pub secret: String,
    pub listen: SocketAddr,
    #[cfg(feature = "transport-quic")]
    pub transport: TransportConfig,
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}

impl std::fmt::Debug for ServiceConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ServiceConfig");
        // Never reveal the secret itself in diagnostic output.
        out.field("secret", &"<redacted>");
        out.field("listen", &self.listen);
        #[cfg(feature = "transport-quic")]
        out.field("transport", &self.transport);
        #[cfg(feature = "tracing")]
        out.field("logging", &self.logging);
        out.finish()
    }
}
162
/// Baseline QUIC transport settings. In `load()` they form the
/// lowest-precedence layer, beneath the config file and CLI flags.
#[cfg(feature = "transport-quic")]
impl Default for TransportConfig {
    fn default() -> Self {
        // "h3" is the ALPN identifier for HTTP/3; it is the only protocol offered by default.
        let alpn = vec![b"h3".to_vec()];

        TransportConfig {
            // TLS material: standard TLS, with no certificate/key paths preset.
            tls_mode: Some(TlsMode::Tls),
            ca_cert: None,
            tls_cert: None,
            tls_key: None,

            // Handshake behaviour.
            zero_rtt: Some(false),
            alpn_protocols: Some(alpn),

            // Congestion control; the initial window is left to the transport.
            congestion: Some(Congestion::Bbr),
            cwnd_init: None,

            // Timers (presumably milliseconds — TODO confirm) and stream limit.
            idle_timeout: Some(30_000),
            keep_alive: Some(8_000),
            max_streams: Some(1_000),
        }
    }
}
181
/// Default logging settings: log level `"INFO"`.
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    fn default() -> Self {
        let log_level = Some(String::from("INFO"));
        LoggingConfig { log_level }
    }
}
190
/// Builds the effective [`ServiceConfig`] from every configuration source.
///
/// Sources are layered from lowest to highest precedence:
/// 1. built-in defaults (`ConfigFile::default()`),
/// 2. the JSON file named by `--config`, if one was given,
/// 3. values passed explicitly on the command line.
///
/// Unset optional values are skipped when a layer is serialized, so each
/// layer only overrides the keys it actually provides.
///
/// Command-line parsing uses `Args::parse()`. On invalid arguments it exits
/// the process (clap's standard behaviour) instead of returning an error.
///
/// # Errors
///
/// Returns an error if:
/// - `--config` names a path that does not exist;
/// - the merged sources cannot be extracted into `ConfigFile` (for example
///   malformed JSON or a value of the wrong type);
/// - `secret` or `listen` is still missing after merging.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<figment::Error>> {
    use figment::Figment;
    use figment::providers::{Format, Json, Serialized};

    let args = Args::parse();

    // Layer 1 (lowest precedence): built-in defaults.
    let mut figment = Figment::new().merge(Serialized::defaults(ConfigFile::default()));

    // Layer 2: the JSON config file. The existence check is explicit so a
    // mistyped `--config` path is reported rather than silently ignored
    // (figment's `Json::file` presumably tolerates a missing file — see its docs).
    if let Some(config_path) = &args.config {
        if !config_path.exists() {
            let msg = format!("Configuration file not found: {}", config_path.display());
            return Err(Box::new(figment::Error::from(msg)));
        }

        figment = figment.merge(Json::file(config_path));
    }

    // Layer 3 (highest precedence): explicit CLI values. `Serialized::defaults`
    // only selects the default profile; precedence comes from `merge` order.
    let cli_overrides = ConfigFile {
        secret: args.secret,
        listen: args.listen,
        #[cfg(feature = "transport-quic")]
        transport: args.transport,
        #[cfg(feature = "tracing")]
        logging: args.logging,
    };

    figment = figment.merge(Serialized::defaults(cli_overrides));

    let config: ConfigFile = figment.extract()?;

    // Each individual layer may leave these unset, but the merged result must have them.
    let secret = config
        .secret
        .ok_or_else(|| figment::Error::from("missing field `secret`"))?;
    let listen = config
        .listen
        .ok_or_else(|| figment::Error::from("missing field `listen`"))?;

    Ok(ServiceConfig {
        secret,
        listen,
        #[cfg(feature = "transport-quic")]
        transport: config.transport,
        #[cfg(feature = "tracing")]
        logging: config.logging,
    })
}
241
242fn styles() -> Styles {
243 Styles::styled()
244 .header(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
245 .usage(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
246 .literal(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
247 .placeholder(Style::new().fg_color(Some(AnsiColor::Cyan.into())))
248 .valid(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
249 .invalid(Style::new().bold().fg_color(Some(AnsiColor::Yellow.into())))
250 .error(Style::new().bold().fg_color(Some(AnsiColor::Red.into())))
251}