1use std::{
2 fmt,
3 net::{AddrParseError, SocketAddr},
4 path::PathBuf,
5 str::FromStr,
6 time::Duration,
7};
8
9use serde::{
10 de::{self, MapAccess, Visitor},
11 Deserialize, Deserializer,
12};
13
14use crate::{config::subnet::IpSubnet, ipfilter::IpFilter};
15
16#[derive(Debug, PartialEq, Eq, Clone, Deserialize)]
17#[serde(rename_all = "kebab-case", deny_unknown_fields)]
18pub struct KeysetConfig {
19 #[serde(default = "default_old_keys")]
21 pub old_keys: usize,
22 #[serde(default = "default_rotation_interval")]
24 pub rotation_interval: usize,
25 #[serde(default)]
26 pub storage_path: Option<String>,
27}
28
29impl Default for KeysetConfig {
30 fn default() -> Self {
31 Self {
32 old_keys: default_old_keys(),
33 rotation_interval: default_rotation_interval(),
34 storage_path: None,
35 }
36 }
37}
38
/// Fallback for `KeysetConfig::rotation_interval` (serde default and `Default` impl).
fn default_rotation_interval() -> usize {
    // 86400 — one day if the unit is seconds.
    // NOTE(review): the unit is not established in this file; confirm at the use site.
    24 * 60 * 60
}

/// Fallback for `KeysetConfig::old_keys` (serde default and `Default` impl).
fn default_old_keys() -> usize {
    const DEFAULT_OLD_KEYS: usize = 7;
    DEFAULT_OLD_KEYS
}
48
49#[derive(Debug, PartialEq, Eq, Copy, Clone, Deserialize)]
50pub enum FilterAction {
51 Ignore,
52 Deny,
53}
54
55#[derive(Debug, PartialEq, Eq, Clone)]
56pub struct ServerConfig {
57 pub addr: SocketAddr,
58 pub denylist: IpFilter,
59 pub denylist_action: FilterAction,
60 pub allowlist: IpFilter,
61 pub allowlist_action: FilterAction,
62 pub rate_limiting_cache_size: usize,
63 pub rate_limiting_cutoff: Duration,
64}
65
66impl ServerConfig {
67 pub(crate) fn try_from_str(value: &str) -> Result<Self, <Self as TryFrom<&str>>::Error> {
68 Self::try_from(value)
69 }
70}
71
72impl TryFrom<&str> for ServerConfig {
73 type Error = AddrParseError;
74
75 fn try_from(value: &str) -> Result<Self, Self::Error> {
76 Ok(ServerConfig {
77 addr: SocketAddr::from_str(value)?,
78 denylist: IpFilter::none(),
79 denylist_action: FilterAction::Ignore,
80 allowlist: IpFilter::all(),
81 allowlist_action: FilterAction::Ignore,
82 rate_limiting_cache_size: Default::default(),
83 rate_limiting_cutoff: Default::default(),
84 })
85 }
86}
87
88impl<'de> Deserialize<'de> for ServerConfig {
91 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
92 where
93 D: Deserializer<'de>,
94 {
95 struct ServerConfigVisitor;
96
97 impl<'de> Visitor<'de> for ServerConfigVisitor {
98 type Value = ServerConfig;
99
100 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
101 formatter.write_str("string or map")
102 }
103
104 fn visit_str<E: de::Error>(self, value: &str) -> Result<ServerConfig, E> {
105 TryFrom::try_from(value).map_err(de::Error::custom)
106 }
107
108 fn visit_map<M: MapAccess<'de>>(self, mut map: M) -> Result<ServerConfig, M::Error> {
109 let mut addr = None;
110 let mut rate_limiting_cache_size = None;
111 let mut rate_limiting_cutoff = None;
112 let mut allowlist = None;
113 let mut allowlist_action = None;
114 let mut denylist = None;
115 let mut denylist_action = None;
116 while let Some(key) = map.next_key::<String>()? {
117 match key.as_str() {
118 "addr" => {
119 if addr.is_some() {
120 return Err(de::Error::duplicate_field("addr"));
121 }
122 addr = Some(map.next_value::<SocketAddr>()?);
123 }
124 "allowlist" => {
125 if allowlist.is_some() {
126 return Err(de::Error::duplicate_field("allowlist"));
127 }
128 let list: Vec<IpSubnet> = map.next_value()?;
129 allowlist = Some(IpFilter::new(&list));
130 }
131 "allowlist-action" => {
132 if allowlist_action.is_some() {
133 return Err(de::Error::duplicate_field("allowlist-action"));
134 }
135 allowlist_action = Some(map.next_value::<FilterAction>()?);
136 }
137 "denylist" => {
138 if denylist.is_some() {
139 return Err(de::Error::duplicate_field("denylist"));
140 }
141 let list: Vec<IpSubnet> = map.next_value()?;
142 denylist = Some(IpFilter::new(&list));
143 }
144 "denylist-action" => {
145 if denylist_action.is_some() {
146 return Err(de::Error::duplicate_field("denylist-action"));
147 }
148 denylist_action = Some(map.next_value::<FilterAction>()?);
149 }
150 "rate-limiting-cache-size" => {
151 if rate_limiting_cache_size.is_some() {
152 return Err(de::Error::duplicate_field("rate-limiting-cache-size"));
153 }
154
155 rate_limiting_cache_size = Some(map.next_value()?);
156 }
157 "rate-limiting-cutoff-ms" => {
158 if rate_limiting_cutoff.is_some() {
159 return Err(de::Error::duplicate_field("rate-limiting-cutoff-ms"));
160 }
161
162 rate_limiting_cutoff = Some(Duration::from_millis(map.next_value()?));
163 }
164 _ => {
165 return Err(de::Error::unknown_field(
166 key.as_str(),
167 &[
168 "addr",
169 "allowlist",
170 "allowlist-action",
171 "denylist",
172 "denylist-action",
173 "rate-limiting-cache-size",
174 "rate-limiting-cutoff-ms",
175 ],
176 ));
177 }
178 }
179 }
180
181 let addr = addr.ok_or_else(|| de::Error::missing_field("addr"))?;
182 let (allowlist, allowlist_action) = match allowlist {
183 Some(allowlist) => (
184 allowlist,
185 allowlist_action
186 .ok_or_else(|| de::Error::missing_field("allowlist-action"))?,
187 ),
188 None => (IpFilter::all(), FilterAction::Ignore),
189 };
190 let (denylist, denylist_action) = match denylist {
191 Some(denylist) => (
192 denylist,
193 denylist_action
194 .ok_or_else(|| de::Error::missing_field("denylist-action"))?,
195 ),
196 None => (IpFilter::none(), FilterAction::Ignore),
197 };
198
199 let rate_limiting_cache_size = rate_limiting_cache_size.unwrap_or_default();
200 let rate_limiting_cutoff = rate_limiting_cutoff.unwrap_or_default();
201
202 Ok(ServerConfig {
203 addr,
204 allowlist,
205 allowlist_action,
206 denylist,
207 denylist_action,
208 rate_limiting_cache_size,
209 rate_limiting_cutoff,
210 })
211 }
212 }
213
214 deserializer.deserialize_any(ServerConfigVisitor)
215 }
216}
217
218#[derive(Debug, PartialEq, Eq, Clone, Deserialize)]
219#[serde(rename_all = "kebab-case", deny_unknown_fields)]
220pub struct NtsKeConfig {
221 pub cert_chain_path: PathBuf,
222 pub key_der_path: PathBuf,
223 #[serde(default = "default_nts_ke_timeout")]
224 pub timeout_ms: u64,
225 pub addr: SocketAddr,
226}
227
/// Serde fallback for `NtsKeConfig::timeout_ms`: 1000 milliseconds.
fn default_nts_ke_timeout() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 1_000;
    DEFAULT_TIMEOUT_MS
}
231
/// Deserialization tests for the configuration types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Document shape used by the server tests: a single `server` entry.
    #[derive(Deserialize, Debug)]
    struct ServerOnly {
        server: ServerConfig,
    }

    /// Parses `input` as TOML and returns its `server` entry.
    fn parse_server(input: &str) -> Result<ServerConfig, toml::de::Error> {
        let doc: ServerOnly = toml::from_str(input)?;
        Ok(doc.server)
    }

    #[test]
    fn test_deserialize_peer() {
        #[derive(Deserialize, Debug)]
        struct TestConfig {
            server: ServerConfig,
        }

        let test: TestConfig = toml::from_str(
            r#"
            [server]
            addr = "0.0.0.0:123"
            "#,
        )
        .unwrap();
        assert_eq!(test.server.addr, "0.0.0.0:123".parse().unwrap());

        let test: TestConfig = toml::from_str(
            r#"
            [server]
            addr = "127.0.0.1:123"
            rate-limiting-cutoff-ms = 1000
            rate-limiting-cache-size = 32
            "#,
        )
        .unwrap();
        assert_eq!(test.server.addr, "127.0.0.1:123".parse().unwrap());
        assert_eq!(test.server.rate_limiting_cache_size, 32);
        assert_eq!(
            test.server.rate_limiting_cutoff,
            Duration::from_millis(1000)
        );
    }

    /// The bare-string form goes through `visit_str` and uses the permissive defaults.
    #[test]
    fn test_deserialize_server_from_string() {
        let server = parse_server(r#"server = "127.0.0.1:123""#).unwrap();
        assert_eq!(server.addr, "127.0.0.1:123".parse().unwrap());
        assert_eq!(server.allowlist_action, FilterAction::Ignore);
        assert_eq!(server.denylist_action, FilterAction::Ignore);
        assert_eq!(server.rate_limiting_cache_size, 0);
        assert_eq!(server.rate_limiting_cutoff, Duration::default());
    }

    #[test]
    fn test_deserialize_server_rejects_unknown_field() {
        let result = parse_server(
            r#"
            [server]
            addr = "127.0.0.1:123"
            not-a-field = 1
            "#,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_server_requires_addr() {
        let result = parse_server(
            r#"
            [server]
            rate-limiting-cache-size = 32
            "#,
        );
        assert!(result.is_err());
    }

    /// A configured filter list must come with its action; variants are spelled verbatim.
    #[test]
    fn test_deserialize_filter_list_requires_action() {
        let result = parse_server(
            r#"
            [server]
            addr = "127.0.0.1:123"
            allowlist = []
            "#,
        );
        assert!(result.is_err());

        let server = parse_server(
            r#"
            [server]
            addr = "127.0.0.1:123"
            allowlist = []
            allowlist-action = "Deny"
            denylist = []
            denylist-action = "Ignore"
            "#,
        )
        .unwrap();
        assert_eq!(server.allowlist_action, FilterAction::Deny);
        assert_eq!(server.denylist_action, FilterAction::Ignore);
    }

    /// An empty keyset table yields exactly the `Default` values.
    #[test]
    fn test_deserialize_keyset_defaults() {
        let keyset: KeysetConfig = toml::from_str("").unwrap();
        assert_eq!(keyset, KeysetConfig::default());
        assert_eq!(keyset.old_keys, 7);
        assert_eq!(keyset.rotation_interval, 86400);
        assert_eq!(keyset.storage_path, None);
    }

    #[test]
    fn test_deserialize_nts_ke() {
        #[derive(Deserialize, Debug)]
        #[serde(rename_all = "kebab-case", deny_unknown_fields)]
        struct TestConfig {
            nts_ke_server: NtsKeConfig,
        }

        let test: TestConfig = toml::from_str(
            r#"
            [nts-ke-server]
            addr = "0.0.0.0:4460"
            cert-chain-path = "/foo/bar/baz.pem"
            key-der-path = "spam.der"
            "#,
        )
        .unwrap();

        let pem = PathBuf::from("/foo/bar/baz.pem");
        assert_eq!(test.nts_ke_server.cert_chain_path, pem);
        assert_eq!(test.nts_ke_server.key_der_path, PathBuf::from("spam.der"));
        assert_eq!(test.nts_ke_server.timeout_ms, 1000,);
        assert_eq!(test.nts_ke_server.addr, "0.0.0.0:4460".parse().unwrap(),);
    }

    /// An explicit `timeout-ms` overrides the 1000 ms default.
    #[test]
    fn test_deserialize_nts_ke_timeout_override() {
        let config: NtsKeConfig = toml::from_str(
            r#"
            addr = "0.0.0.0:4460"
            cert-chain-path = "chain.pem"
            key-der-path = "key.der"
            timeout-ms = 250
            "#,
        )
        .unwrap();
        assert_eq!(config.timeout_ms, 250);
    }
}