1use std::net::SocketAddr;
2use std::path::PathBuf;
3
4use clap::builder::Styles;
5use clap::builder::styling::{AnsiColor, Style};
6use clap::{Parser, ValueEnum};
7
8use serde::{Deserialize, Serialize};
9
10#[cfg(feature = "transport-quic")]
11use ombrac_transport::quic::Congestion;
12
13#[derive(Parser, Debug)]
15#[command(version, about, long_about = None, styles = styles())]
16pub struct Args {
17 #[clap(long, short = 'c', value_name = "FILE")]
19 pub config: Option<PathBuf>,
20
21 #[clap(
23 long,
24 short = 'k',
25 help_heading = "Required",
26 value_name = "STR",
27 required_unless_present = "config"
28 )]
29 pub secret: Option<String>,
30
31 #[clap(
33 long,
34 short = 'l',
35 help_heading = "Required",
36 value_name = "ADDR",
37 required_unless_present = "config"
38 )]
39 pub listen: Option<SocketAddr>,
40
41 #[cfg(feature = "transport-quic")]
42 #[clap(flatten)]
43 pub transport: TransportConfig,
44
45 #[cfg(feature = "tracing")]
46 #[clap(flatten)]
47 pub logging: LoggingConfig,
48}
49
50#[derive(Deserialize, Serialize, Debug, Default)]
52pub struct ConfigFile {
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub secret: Option<String>,
55
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub listen: Option<SocketAddr>,
58
59 #[cfg(feature = "transport-quic")]
60 pub transport: TransportConfig,
61
62 #[cfg(feature = "tracing")]
63 pub logging: LoggingConfig,
64}
65
// QUIC transport options. The same struct serves the CLI (flattened into `Args` under the
// "Transport" help heading) and the config file's `transport` object.
//
// Every field is an `Option` that is skipped on serialization when `None`, so an option
// that was not given does not override a lower-priority layer. Built-in values come from
// `impl Default for TransportConfig`.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments into `--help` text.
#[cfg(feature = "transport-quic")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct TransportConfig {
    // Spelled `tls`, `m-tls` or `insecure` on the CLI and in JSON.
    #[arg(long, value_enum, help_heading = "Transport")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_mode: Option<TlsMode>,

    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ca_cert: Option<PathBuf>,

    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_cert: Option<PathBuf>,

    #[arg(long, help_heading = "Transport", value_name = "FILE")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_key: Option<PathBuf>,

    // For an `Option<bool>` field, clap infers `ArgAction::Set`. A bare `action` therefore
    // did not make this a switch, and `--zero-rtt` on its own was rejected for lacking a
    // value. The value is now optional:
    // - `--zero-rtt` means `true`;
    // - `--zero-rtt false` / `--zero-rtt true` still work;
    // - omitting the flag leaves `None`, so the file or default value applies.
    #[arg(
        long,
        help_heading = "Transport",
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zero_rtt: Option<bool>,

    // Comma-separated on the command line.
    // NOTE(review): serde encodes each `Vec<u8>` as an array of numbers, so the JSON file
    // must write byte arrays (e.g. `[[104, 51]]` for "h3").
    // NOTE(review): also verify how the clap version in use maps `Option<Vec<Vec<u8>>>`.
    // Recent clap derive treats `Vec<Vec<T>>` as values grouped per occurrence, with each
    // element parsed as `u8`, which would reject `--alpn-protocols h3`.
    #[arg(
        long,
        help_heading = "Transport",
        value_name = "PROTOCOLS",
        value_delimiter = ','
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub alpn_protocols: Option<Vec<Vec<u8>>>,

    // Congestion-control algorithm from `ombrac_transport::quic::Congestion`.
    #[arg(long, help_heading = "Transport", value_name = "ALGORITHM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub congestion: Option<Congestion>,

    // Initial congestion window; there is no built-in default (stays `None` unless given).
    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cwnd_init: Option<u64>,

    // NOTE(review): the unit is not visible here (built-in default 30000). It is presumably
    // milliseconds; confirm against the transport crate.
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idle_timeout: Option<u64>,

    // NOTE(review): the unit is not visible here (built-in default 8000). It is presumably
    // milliseconds; confirm against the transport crate.
    #[arg(long, help_heading = "Transport", value_name = "TIME")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<u64>,

    #[arg(long, help_heading = "Transport", value_name = "NUM")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_streams: Option<u64>,
}
132
// Logging options. The same struct serves the CLI (flattened into `Args` under the
// "Logging" help heading) and the config file's `logging` object.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments into `--help` text.
#[cfg(feature = "tracing")]
#[derive(Deserialize, Serialize, Debug, Parser, Clone)]
pub struct LoggingConfig {
    // Log level name as free-form text; it is not validated here. `impl Default` uses
    // "INFO". The field is skipped when `None`, so an omitted `--log-level` does not
    // override the file or default value.
    #[arg(long, help_heading = "Logging", value_name = "LEVEL")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_level: Option<String>,
}
141
/// TLS mode selected via `--tls-mode` or the config file's `transport.tls_mode`.
///
/// clap's `ValueEnum` and serde's `rename_all = "kebab-case"` both spell the variants
/// `tls`, `m-tls` and `insecure`.
#[cfg(feature = "transport-quic")]
#[derive(ValueEnum, Clone, Debug, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum TlsMode {
    // Variant notes are plain `//` comments on purpose: clap would show `///` variant docs
    // as help text for each possible value.
    //
    // Returned by `TlsMode::default()`.
    #[default]
    Tls,
    // NOTE(review): presumably mutual TLS (client certificates required) — confirm.
    MTls,
    // NOTE(review): presumably skips certificate verification — confirm before relying on it.
    Insecure,
}
151
/// Fully resolved configuration produced by `load`.
///
/// Unlike [`ConfigFile`], `secret` and `listen` are mandatory here: `load` returns an error
/// rather than a `ServiceConfig` when either is missing after all layers are merged.
///
/// NOTE(review): the derived `Debug` prints `secret` verbatim; avoid `{:?}`-logging this value.
#[derive(Debug, Clone)]
pub struct ServiceConfig {
    /// Secret taken from `--secret` or the config file.
    pub secret: String,
    /// Address taken from `--listen` or the config file.
    pub listen: SocketAddr,
    /// Merged transport settings (defaults, then file, then CLI).
    #[cfg(feature = "transport-quic")]
    pub transport: TransportConfig,
    /// Merged logging settings (defaults, then file, then CLI).
    #[cfg(feature = "tracing")]
    pub logging: LoggingConfig,
}
161
#[cfg(feature = "transport-quic")]
impl Default for TransportConfig {
    /// Built-in transport settings.
    ///
    /// `ConfigFile::default()` picks these up, making them the lowest-priority layer when
    /// configuration is merged. The certificate/key paths and `cwnd_init` have no built-in
    /// value and stay `None`.
    fn default() -> Self {
        // ALPN ids are raw bytes; "h3" is the HTTP/3 token.
        let alpn = vec![b"h3".to_vec()];
        // NOTE(review): units are not visible here; presumably milliseconds.
        let (idle_timeout, keep_alive) = (30_000, 8_000);

        TransportConfig {
            tls_mode: Some(TlsMode::Tls),
            zero_rtt: Some(false),
            alpn_protocols: Some(alpn),
            congestion: Some(Congestion::Bbr),
            idle_timeout: Some(idle_timeout),
            keep_alive: Some(keep_alive),
            max_streams: Some(1_000),
            ca_cert: None,
            tls_cert: None,
            tls_key: None,
            cwnd_init: None,
        }
    }
}
180
#[cfg(feature = "tracing")]
impl Default for LoggingConfig {
    /// Logs at `INFO` unless the config file or `--log-level` says otherwise.
    fn default() -> Self {
        let level = String::from("INFO");
        LoggingConfig {
            log_level: Some(level),
        }
    }
}
189
/// Parses the process command line and resolves the effective [`ServiceConfig`].
///
/// See `resolve_config` for how the configuration sources are layered.
///
/// # Errors
///
/// Propagates every error from `resolve_config`. Invalid command-line arguments never reach
/// this point, because `Args::parse()` reports them and exits the process.
#[cfg(feature = "binary")]
pub fn load() -> Result<ServiceConfig, Box<figment::Error>> {
    resolve_config(Args::parse())
}

/// Layers the configuration sources for already-parsed `args` into a [`ServiceConfig`].
///
/// Priority, lowest to highest:
/// 1. built-in defaults (`ConfigFile::default()`),
/// 2. the JSON file named by `--config`, if any,
/// 3. values given explicitly on the command line.
///
/// Optional fields of [`ConfigFile`] are skipped when `None`, so an omitted flag does not
/// erase a value supplied by a lower layer.
///
/// This is kept separate from [`load`] so the layering can be exercised without the
/// process's real argv.
///
/// # Errors
///
/// - `--config` names a path that is not an existing regular file;
/// - the merged sources cannot be extracted into a [`ConfigFile`] (e.g. malformed JSON or a
///   value of the wrong type);
/// - `secret` or `listen` is still unset after all layers are merged.
#[cfg(feature = "binary")]
fn resolve_config(args: Args) -> Result<ServiceConfig, Box<figment::Error>> {
    use figment::Figment;
    use figment::providers::{Format, Json, Serialized};

    let mut figment = Figment::new().merge(Serialized::defaults(ConfigFile::default()));

    if let Some(config_path) = &args.config {
        // Check up front so a mistyped `--config` is a hard error rather than a silent
        // fallback to defaults. `is_file()` (unlike `exists()`) also rejects a directory.
        // NOTE(review): this assumes figment's `Json::file` yields no data, rather than
        // failing, for a path it cannot use. Confirm for the figment version in use.
        if !config_path.is_file() {
            return Err(Box::new(figment::Error::from(format!(
                "Configuration file not found: {}",
                config_path.display()
            ))));
        }
        figment = figment.merge(Json::file(config_path));
    }

    // Merged last so explicit command-line values win over the file and the defaults.
    let cli_overrides = ConfigFile {
        secret: args.secret,
        listen: args.listen,
        #[cfg(feature = "transport-quic")]
        transport: args.transport,
        #[cfg(feature = "tracing")]
        logging: args.logging,
    };
    figment = figment.merge(Serialized::defaults(cli_overrides));

    let config: ConfigFile = figment.extract()?;

    // clap only enforces these when `--config` is absent; the file may still omit them.
    let secret = config
        .secret
        .ok_or_else(|| figment::Error::from("missing field `secret`"))?;
    let listen = config
        .listen
        .ok_or_else(|| figment::Error::from("missing field `listen`"))?;

    Ok(ServiceConfig {
        secret,
        listen,
        #[cfg(feature = "transport-quic")]
        transport: config.transport,
        #[cfg(feature = "tracing")]
        logging: config.logging,
    })
}
240
241fn styles() -> Styles {
242 Styles::styled()
243 .header(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
244 .usage(Style::new().bold().fg_color(Some(AnsiColor::Green.into())))
245 .literal(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
246 .placeholder(Style::new().fg_color(Some(AnsiColor::Cyan.into())))
247 .valid(Style::new().bold().fg_color(Some(AnsiColor::Cyan.into())))
248 .invalid(Style::new().bold().fg_color(Some(AnsiColor::Yellow.into())))
249 .error(Style::new().bold().fg_color(Some(AnsiColor::Red.into())))
250}