// ntp_daemon/config/server.rs

1use std::{
2    fmt,
3    net::{AddrParseError, SocketAddr},
4    path::PathBuf,
5    str::FromStr,
6    time::Duration,
7};
8
9use serde::{
10    de::{self, MapAccess, Visitor},
11    Deserialize, Deserializer,
12};
13
14use crate::{config::subnet::IpSubnet, ipfilter::IpFilter};
15
16#[derive(Debug, PartialEq, Eq, Clone, Deserialize)]
17#[serde(rename_all = "kebab-case", deny_unknown_fields)]
18pub struct KeysetConfig {
19    /// Number of old keys to keep around
20    #[serde(default = "default_old_keys")]
21    pub old_keys: usize,
22    /// How often to rotate keys (seconds between rotations)
23    #[serde(default = "default_rotation_interval")]
24    pub rotation_interval: usize,
25    #[serde(default)]
26    pub storage_path: Option<String>,
27}
28
29impl Default for KeysetConfig {
30    fn default() -> Self {
31        Self {
32            old_keys: default_old_keys(),
33            rotation_interval: default_rotation_interval(),
34            storage_path: None,
35        }
36    }
37}
38
/// Default number of seconds between key rotations: one day.
fn default_rotation_interval() -> usize {
    const SECS_PER_DAY: usize = 24 * 60 * 60;
    SECS_PER_DAY
}
43
/// Default number of old keys to retain: one week's worth at one key per day.
fn default_old_keys() -> usize {
    const DAYS_PER_WEEK: usize = 7;
    DAYS_PER_WEEK
}
48
49#[derive(Debug, PartialEq, Eq, Copy, Clone, Deserialize)]
50pub enum FilterAction {
51    Ignore,
52    Deny,
53}
54
55#[derive(Debug, PartialEq, Eq, Clone)]
56pub struct ServerConfig {
57    pub addr: SocketAddr,
58    pub denylist: IpFilter,
59    pub denylist_action: FilterAction,
60    pub allowlist: IpFilter,
61    pub allowlist_action: FilterAction,
62    pub rate_limiting_cache_size: usize,
63    pub rate_limiting_cutoff: Duration,
64}
65
66impl ServerConfig {
67    pub(crate) fn try_from_str(value: &str) -> Result<Self, <Self as TryFrom<&str>>::Error> {
68        Self::try_from(value)
69    }
70}
71
72impl TryFrom<&str> for ServerConfig {
73    type Error = AddrParseError;
74
75    fn try_from(value: &str) -> Result<Self, Self::Error> {
76        Ok(ServerConfig {
77            addr: SocketAddr::from_str(value)?,
78            denylist: IpFilter::none(),
79            denylist_action: FilterAction::Ignore,
80            allowlist: IpFilter::all(),
81            allowlist_action: FilterAction::Ignore,
82            rate_limiting_cache_size: Default::default(),
83            rate_limiting_cutoff: Default::default(),
84        })
85    }
86}
87
88// We have a custom deserializer for serverconfig because we
89// want to deserialize it from either a string or a map
90impl<'de> Deserialize<'de> for ServerConfig {
91    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
92    where
93        D: Deserializer<'de>,
94    {
95        struct ServerConfigVisitor;
96
97        impl<'de> Visitor<'de> for ServerConfigVisitor {
98            type Value = ServerConfig;
99
100            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
101                formatter.write_str("string or map")
102            }
103
104            fn visit_str<E: de::Error>(self, value: &str) -> Result<ServerConfig, E> {
105                TryFrom::try_from(value).map_err(de::Error::custom)
106            }
107
108            fn visit_map<M: MapAccess<'de>>(self, mut map: M) -> Result<ServerConfig, M::Error> {
109                let mut addr = None;
110                let mut rate_limiting_cache_size = None;
111                let mut rate_limiting_cutoff = None;
112                let mut allowlist = None;
113                let mut allowlist_action = None;
114                let mut denylist = None;
115                let mut denylist_action = None;
116                while let Some(key) = map.next_key::<String>()? {
117                    match key.as_str() {
118                        "addr" => {
119                            if addr.is_some() {
120                                return Err(de::Error::duplicate_field("addr"));
121                            }
122                            addr = Some(map.next_value::<SocketAddr>()?);
123                        }
124                        "allowlist" => {
125                            if allowlist.is_some() {
126                                return Err(de::Error::duplicate_field("allowlist"));
127                            }
128                            let list: Vec<IpSubnet> = map.next_value()?;
129                            allowlist = Some(IpFilter::new(&list));
130                        }
131                        "allowlist-action" => {
132                            if allowlist_action.is_some() {
133                                return Err(de::Error::duplicate_field("allowlist-action"));
134                            }
135                            allowlist_action = Some(map.next_value::<FilterAction>()?);
136                        }
137                        "denylist" => {
138                            if denylist.is_some() {
139                                return Err(de::Error::duplicate_field("denylist"));
140                            }
141                            let list: Vec<IpSubnet> = map.next_value()?;
142                            denylist = Some(IpFilter::new(&list));
143                        }
144                        "denylist-action" => {
145                            if denylist_action.is_some() {
146                                return Err(de::Error::duplicate_field("denylist-action"));
147                            }
148                            denylist_action = Some(map.next_value::<FilterAction>()?);
149                        }
150                        "rate-limiting-cache-size" => {
151                            if rate_limiting_cache_size.is_some() {
152                                return Err(de::Error::duplicate_field("rate-limiting-cache-size"));
153                            }
154
155                            rate_limiting_cache_size = Some(map.next_value()?);
156                        }
157                        "rate-limiting-cutoff-ms" => {
158                            if rate_limiting_cutoff.is_some() {
159                                return Err(de::Error::duplicate_field("rate-limiting-cutoff-ms"));
160                            }
161
162                            rate_limiting_cutoff = Some(Duration::from_millis(map.next_value()?));
163                        }
164                        _ => {
165                            return Err(de::Error::unknown_field(
166                                key.as_str(),
167                                &[
168                                    "addr",
169                                    "allowlist",
170                                    "allowlist-action",
171                                    "denylist",
172                                    "denylist-action",
173                                    "rate-limiting-cache-size",
174                                    "rate-limiting-cutoff-ms",
175                                ],
176                            ));
177                        }
178                    }
179                }
180
181                let addr = addr.ok_or_else(|| de::Error::missing_field("addr"))?;
182                let (allowlist, allowlist_action) = match allowlist {
183                    Some(allowlist) => (
184                        allowlist,
185                        allowlist_action
186                            .ok_or_else(|| de::Error::missing_field("allowlist-action"))?,
187                    ),
188                    None => (IpFilter::all(), FilterAction::Ignore),
189                };
190                let (denylist, denylist_action) = match denylist {
191                    Some(denylist) => (
192                        denylist,
193                        denylist_action
194                            .ok_or_else(|| de::Error::missing_field("denylist-action"))?,
195                    ),
196                    None => (IpFilter::none(), FilterAction::Ignore),
197                };
198
199                let rate_limiting_cache_size = rate_limiting_cache_size.unwrap_or_default();
200                let rate_limiting_cutoff = rate_limiting_cutoff.unwrap_or_default();
201
202                Ok(ServerConfig {
203                    addr,
204                    allowlist,
205                    allowlist_action,
206                    denylist,
207                    denylist_action,
208                    rate_limiting_cache_size,
209                    rate_limiting_cutoff,
210                })
211            }
212        }
213
214        deserializer.deserialize_any(ServerConfigVisitor)
215    }
216}
217
218#[derive(Debug, PartialEq, Eq, Clone, Deserialize)]
219#[serde(rename_all = "kebab-case", deny_unknown_fields)]
220pub struct NtsKeConfig {
221    pub cert_chain_path: PathBuf,
222    pub key_der_path: PathBuf,
223    #[serde(default = "default_nts_ke_timeout")]
224    pub timeout_ms: u64,
225    pub addr: SocketAddr,
226}
227
/// Default NTS-KE timeout in milliseconds: one second.
fn default_nts_ke_timeout() -> u64 {
    const ONE_SECOND_MS: u64 = 1_000;
    ONE_SECOND_MS
}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Places a `ServerConfig` under the `server` key, as in a config file.
    #[derive(Deserialize, Debug)]
    struct ServerTestConfig {
        server: ServerConfig,
    }

    /// Places an `NtsKeConfig` under the `nts-ke-server` table, as in a config file.
    #[derive(Deserialize, Debug)]
    #[serde(rename_all = "kebab-case", deny_unknown_fields)]
    struct NtsKeTestConfig {
        nts_ke_server: NtsKeConfig,
    }

    #[test]
    fn test_deserialize_server() {
        let test: ServerTestConfig = toml::from_str(
            r#"
            [server]
            addr = "0.0.0.0:123"
            "#,
        )
        .unwrap();
        assert_eq!(test.server.addr, "0.0.0.0:123".parse().unwrap());

        let test: ServerTestConfig = toml::from_str(
            r#"
            [server]
            addr = "127.0.0.1:123"
            rate-limiting-cutoff-ms = 1000
            rate-limiting-cache-size = 32
            "#,
        )
        .unwrap();
        assert_eq!(test.server.addr, "127.0.0.1:123".parse().unwrap());
        assert_eq!(test.server.rate_limiting_cache_size, 32);
        assert_eq!(test.server.rate_limiting_cutoff, Duration::from_millis(1000));
    }

    #[test]
    fn test_deserialize_server_from_string() {
        // The bare-string form must yield exactly what `TryFrom<&str>` yields.
        let test: ServerTestConfig = toml::from_str(r#"server = "127.0.0.1:123""#).unwrap();
        assert_eq!(
            test.server,
            ServerConfig::try_from_str("127.0.0.1:123").unwrap()
        );
        assert_eq!(test.server.rate_limiting_cache_size, 0);
        assert_eq!(test.server.rate_limiting_cutoff, Duration::ZERO);
    }

    #[test]
    fn test_deserialize_server_rejects_bad_input() {
        // Unparseable address in the string form.
        assert!(toml::from_str::<ServerTestConfig>(r#"server = "not-an-address""#).is_err());

        // Missing `addr` in the map form.
        assert!(toml::from_str::<ServerTestConfig>(
            r#"
            [server]
            rate-limiting-cache-size = 32
            "#,
        )
        .is_err());

        // Unknown key in the map form.
        assert!(toml::from_str::<ServerTestConfig>(
            r#"
            [server]
            addr = "127.0.0.1:123"
            bogus = 1
            "#,
        )
        .is_err());
    }

    #[test]
    fn test_deserialize_keyset_defaults() {
        // Every KeysetConfig field has a serde default, so an empty document is valid.
        let keyset: KeysetConfig = toml::from_str("").unwrap();
        assert_eq!(keyset, KeysetConfig::default());
    }

    #[test]
    fn test_deserialize_nts_ke() {
        let test: NtsKeTestConfig = toml::from_str(
            r#"
            [nts-ke-server]
            addr = "0.0.0.0:4460"
            cert-chain-path = "/foo/bar/baz.pem"
            key-der-path = "spam.der"
            "#,
        )
        .unwrap();

        let pem = PathBuf::from("/foo/bar/baz.pem");
        assert_eq!(test.nts_ke_server.cert_chain_path, pem);
        assert_eq!(test.nts_ke_server.key_der_path, PathBuf::from("spam.der"));
        assert_eq!(test.nts_ke_server.timeout_ms, 1000);
        assert_eq!(test.nts_ke_server.addr, "0.0.0.0:4460".parse().unwrap());
    }

    #[test]
    fn test_deserialize_nts_ke_custom_timeout() {
        let test: NtsKeTestConfig = toml::from_str(
            r#"
            [nts-ke-server]
            addr = "0.0.0.0:4460"
            cert-chain-path = "/foo/bar/baz.pem"
            key-der-path = "spam.der"
            timeout-ms = 2500
            "#,
        )
        .unwrap();
        assert_eq!(test.nts_ke_server.timeout_ms, 2500);
    }
}