ntp_daemon/config/
peer.rs

1use std::{
2    fmt,
3    net::SocketAddr,
4    path::PathBuf,
5    sync::{Arc, Mutex},
6};
7
8use rustls::Certificate;
9use serde::{
10    de::{self, MapAccess, Visitor},
11    Deserialize, Deserializer,
12};
13
14use crate::keyexchange::certificates_from_file;
15
16#[derive(Deserialize, Debug, PartialEq, Eq, Clone, Copy, Default)]
17pub enum PeerHostMode {
18    #[serde(alias = "server")]
19    #[default]
20    Server,
21    #[serde(alias = "nts-server")]
22    NtsServer,
23    #[serde(alias = "pool")]
24    Pool,
25}
26
27#[derive(Deserialize, Debug, PartialEq, Eq, Clone)]
28#[serde(rename_all = "kebab-case", deny_unknown_fields)]
29pub struct StandardPeerConfig {
30    pub addr: NormalizedAddress,
31}
32
33#[derive(Debug, PartialEq, Eq, Clone)]
34pub struct NtsPeerConfig {
35    pub ke_addr: NormalizedAddress,
36    pub certificates: Arc<[Certificate]>,
37}
38
39#[derive(Deserialize, Debug, PartialEq, Eq, Clone)]
40#[serde(rename_all = "kebab-case", deny_unknown_fields)]
41pub struct PoolPeerConfig {
42    pub addr: NormalizedAddress,
43    pub max_peers: usize,
44}
45
46#[derive(Debug, PartialEq, Eq, Clone)]
47pub enum PeerConfig {
48    Standard(StandardPeerConfig),
49    Nts(NtsPeerConfig),
50    Pool(PoolPeerConfig),
51    // Consul(ConsulPeerConfig),
52}
53
54impl PeerConfig {
55    pub(crate) fn try_from_str(value: &str) -> Result<Self, std::io::Error> {
56        Self::try_from(value)
57    }
58}
59
60/// A normalized address has a host and a port part. However, the host may be
61/// invalid, we didn't yet perform a DNS lookup.
62#[derive(Deserialize, Debug, Clone)]
63#[serde(rename_all = "kebab-case", deny_unknown_fields)]
64pub struct NormalizedAddress {
65    pub(crate) server_name: String,
66    pub(crate) port: u16,
67
68    /// Used to inject socket addrs into the DNS lookup result
69    #[cfg(test)]
70    hardcoded_dns_resolve: HardcodedDnsResolve,
71}
72
73impl Eq for NormalizedAddress {}
74
75impl PartialEq for NormalizedAddress {
76    fn eq(&self, other: &Self) -> bool {
77        self.server_name == other.server_name && self.port == other.port
78    }
79}
80
81#[derive(Deserialize, Debug, Clone, Default)]
82struct HardcodedDnsResolve {
83    #[cfg_attr(not(test), allow(unused))]
84    #[serde(skip)]
85    addresses: Arc<Mutex<Vec<SocketAddr>>>,
86}
87
88impl From<Vec<SocketAddr>> for HardcodedDnsResolve {
89    fn from(value: Vec<SocketAddr>) -> Self {
90        Self {
91            addresses: Arc::new(Mutex::new(value)),
92        }
93    }
94}
95
96impl NormalizedAddress {
97    const NTP_DEFAULT_PORT: u16 = 123;
98    const NTS_KE_DEFAULT_PORT: u16 = 4460;
99
100    /// Specifically, this adds the `:123` port if no port is specified
101    pub(crate) fn from_string_ntp(address: String) -> std::io::Result<Self> {
102        let (server_name, port) = Self::from_string_help(address, Self::NTP_DEFAULT_PORT)?;
103
104        Ok(Self {
105            server_name,
106            port,
107
108            #[cfg(test)]
109            hardcoded_dns_resolve: HardcodedDnsResolve::default(),
110        })
111    }
112
113    /// Specifically, this adds the `:4460` port if no port is specified
114    fn from_string_nts_ke(address: String) -> std::io::Result<Self> {
115        let (server_name, port) = Self::from_string_help(address, Self::NTS_KE_DEFAULT_PORT)?;
116
117        Ok(Self {
118            server_name,
119            port,
120
121            #[cfg(test)]
122            hardcoded_dns_resolve: HardcodedDnsResolve::default(),
123        })
124    }
125
126    fn from_string_help(address: String, default_port: u16) -> std::io::Result<(String, u16)> {
127        if address.split(':').count() > 2 {
128            // IPv6, try to parse it as such
129            match address.parse::<SocketAddr>() {
130                Ok(socket_addr) => {
131                    // strip off the port
132                    let (server_name, _) = address.rsplit_once(':').unwrap();
133
134                    Ok((server_name.to_string(), socket_addr.port()))
135                }
136                Err(e) => {
137                    // Could be because of no port, add one and see
138                    let address_with_port = format!("[{address}]:{default_port}");
139                    if address_with_port.parse::<SocketAddr>().is_ok() {
140                        Ok((format!("[{address}]"), default_port))
141                    } else {
142                        Err(std::io::Error::new(std::io::ErrorKind::Other, e))
143                    }
144                }
145            }
146        } else if let Some((server_name, port)) = address.split_once(':') {
147            // Not ipv6, and we seem to have a port. We cant reasonably
148            // check whether the host is valid, but at least check that
149            // the port is.
150            match port.parse::<u16>() {
151                Ok(port) => Ok((server_name.to_string(), port)),
152                Err(e) => Err(std::io::Error::new(std::io::ErrorKind::Other, e)),
153            }
154        } else {
155            // Not ipv6 and no port. As we cant reasonably check host
156            // so just append a port
157            Ok((address, default_port))
158        }
159    }
160
161    #[cfg(test)]
162    pub(crate) fn new_unchecked(server_name: &str, port: u16) -> Self {
163        Self {
164            server_name: server_name.to_string(),
165            port,
166
167            #[cfg(test)]
168            hardcoded_dns_resolve: HardcodedDnsResolve::default(),
169        }
170    }
171
172    #[cfg(test)]
173    pub(crate) fn with_hardcoded_dns(
174        server_name: &str,
175        port: u16,
176        hardcoded_dns_resolve: Vec<SocketAddr>,
177    ) -> Self {
178        Self {
179            server_name: server_name.to_string(),
180            port,
181            hardcoded_dns_resolve: HardcodedDnsResolve::from(hardcoded_dns_resolve),
182        }
183    }
184
185    #[cfg(not(test))]
186    pub async fn lookup_host(&self) -> std::io::Result<impl Iterator<Item = SocketAddr> + '_> {
187        tokio::net::lookup_host((self.server_name.as_str(), self.port)).await
188    }
189
190    #[cfg(test)]
191    pub async fn lookup_host(&self) -> std::io::Result<impl Iterator<Item = SocketAddr> + '_> {
192        // We don't want to spam a real DNS server during testing. This is an attempt to randomize
193        // the returned addresses somewhat.
194        let mut addresses = self.hardcoded_dns_resolve.addresses.lock().unwrap();
195
196        if let Some(last) = addresses.pop() {
197            addresses.insert(0, last);
198        }
199
200        let addresses = addresses.to_vec();
201
202        Ok(addresses.into_iter())
203    }
204}
205
206impl std::fmt::Display for NormalizedAddress {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        write!(f, "{}:{}", self.server_name, self.port)
209    }
210}
211
212impl TryFrom<&str> for StandardPeerConfig {
213    type Error = std::io::Error;
214
215    fn try_from(value: &str) -> Result<Self, Self::Error> {
216        Ok(Self {
217            addr: NormalizedAddress::from_string_ntp(value.to_string())?,
218        })
219    }
220}
221
222impl<'a> TryFrom<&'a str> for PeerConfig {
223    type Error = std::io::Error;
224
225    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
226        StandardPeerConfig::try_from(value).map(Self::Standard)
227    }
228}
229
230// We have a custom deserializer for peerconfig because we
231// want to deserialize it from either a string or a map
232impl<'de> Deserialize<'de> for PeerConfig {
233    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
234    where
235        D: Deserializer<'de>,
236    {
237        struct PeerConfigVisitor;
238
239        impl<'de> Visitor<'de> for PeerConfigVisitor {
240            type Value = PeerConfig;
241
242            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
243                formatter.write_str("string or map")
244            }
245
246            fn visit_str<E: de::Error>(self, value: &str) -> Result<PeerConfig, E> {
247                TryFrom::try_from(value).map_err(de::Error::custom)
248            }
249
250            fn visit_map<M: MapAccess<'de>>(self, mut map: M) -> Result<PeerConfig, M::Error> {
251                let mut ke_addr = None;
252                let mut opt_certificate_path = None;
253                let mut addr = None;
254                let mut mode = None;
255                let mut max_peers = None;
256                while let Some(key) = map.next_key::<String>()? {
257                    match key.as_str() {
258                        "addr" => {
259                            if addr.is_some() {
260                                return Err(de::Error::duplicate_field("addr"));
261                            }
262                            let raw: String = map.next_value()?;
263
264                            let parsed_addr =
265                                NormalizedAddress::from_string_ntp(raw.as_str().to_string())
266                                    .map_err(de::Error::custom)?;
267
268                            addr = Some(parsed_addr);
269                        }
270                        "ke-addr" => {
271                            if ke_addr.is_some() {
272                                return Err(de::Error::duplicate_field("ke-addr"));
273                            }
274                            let raw: String = map.next_value()?;
275
276                            let parsed_addr =
277                                NormalizedAddress::from_string_nts_ke(raw.as_str().to_string())
278                                    .map_err(de::Error::custom)?;
279
280                            ke_addr = Some(parsed_addr);
281                        }
282                        "certificate" => {
283                            if opt_certificate_path.is_some() {
284                                return Err(de::Error::duplicate_field("certificate"));
285                            }
286                            let raw: String = map.next_value()?;
287
288                            opt_certificate_path = Some(PathBuf::from(raw));
289                        }
290                        "mode" => {
291                            if mode.is_some() {
292                                return Err(de::Error::duplicate_field("mode"));
293                            }
294                            mode = Some(map.next_value()?);
295                        }
296                        "max-peers" => {
297                            if max_peers.is_some() {
298                                return Err(de::Error::duplicate_field("max-peers"));
299                            }
300                            max_peers = Some(map.next_value()?);
301                        }
302                        _ => {
303                            return Err(de::Error::unknown_field(
304                                key.as_str(),
305                                &["addr", "ke-addr", "certificate", "mode", "max-peers"],
306                            ));
307                        }
308                    }
309                }
310
311                let mode = mode.unwrap_or_default();
312
313                let unknown_field =
314                    |field, valid_fields| Err(de::Error::unknown_field(field, valid_fields));
315
316                match mode {
317                    PeerHostMode::Server => {
318                        let addr = addr.ok_or_else(|| de::Error::missing_field("addr"))?;
319
320                        let valid_fields = &["addr", "mode"];
321                        if max_peers.is_some() {
322                            unknown_field("max-peers", valid_fields)
323                        } else if ke_addr.is_some() {
324                            unknown_field("ke-addr", valid_fields)
325                        } else if opt_certificate_path.is_some() {
326                            unknown_field("certificate", valid_fields)
327                        } else {
328                            Ok(PeerConfig::Standard(StandardPeerConfig { addr }))
329                        }
330                    }
331                    PeerHostMode::NtsServer => {
332                        let ke_addr = ke_addr.ok_or_else(|| de::Error::missing_field("ke-addr"))?;
333
334                        let valid_fields = &["mode", "ke-addr", "certificate"];
335                        if max_peers.is_some() {
336                            unknown_field("max-peers", valid_fields)
337                        } else {
338                            let certificates: Arc<[Certificate]> = if let Some(certificate_path) =
339                                opt_certificate_path
340                            {
341                                match certificates_from_file(&certificate_path) {
342                                    Ok(certificates) => Arc::from(certificates),
343                                    Err(io_error) => {
344                                        let msg = format!(
345                                                "error while parsing certificate file {certificate_path:?}: {io_error:?}"
346                                            );
347                                        return Err(de::Error::custom(msg));
348                                    }
349                                }
350                            } else {
351                                Arc::from([])
352                            };
353
354                            Ok(PeerConfig::Nts(NtsPeerConfig {
355                                ke_addr,
356                                certificates,
357                            }))
358                        }
359                    }
360                    PeerHostMode::Pool => {
361                        let addr = addr.ok_or_else(|| de::Error::missing_field("addr"))?;
362
363                        let valid_fields = &["addr", "mode", "max-peers"];
364                        if ke_addr.is_some() {
365                            unknown_field("ke-addr", valid_fields)
366                        } else if opt_certificate_path.is_some() {
367                            unknown_field("certificate", valid_fields)
368                        } else {
369                            let max_peers = max_peers.unwrap_or(1);
370
371                            Ok(PeerConfig::Pool(PoolPeerConfig { addr, max_peers }))
372                        }
373                    }
374                }
375            }
376        }
377
378        deserializer.deserialize_any(PeerConfigVisitor)
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config wrapper so a `PeerConfig` can be parsed from a TOML document.
    #[derive(Deserialize, Debug)]
    struct TestConfig {
        peer: PeerConfig,
    }

    /// Parses `input`, which must contain a `peer` key, and returns that peer.
    fn parse_peer(input: &str) -> PeerConfig {
        toml::from_str::<TestConfig>(input).unwrap().peer
    }

    /// The configured address of a peer: `addr` for standard and pool peers, `ke-addr`
    /// for NTS peers.
    fn peer_addr(config: &PeerConfig) -> String {
        match config {
            PeerConfig::Standard(c) => c.addr.to_string(),
            PeerConfig::Nts(c) => c.ke_addr.to_string(),
            PeerConfig::Pool(c) => c.addr.to_string(),
        }
    }

    #[test]
    fn test_deserialize_peer() {
        let peer = parse_peer("peer = \"example.com\"");
        assert_eq!(peer_addr(&peer), "example.com:123");
        assert!(matches!(peer, PeerConfig::Standard(_)));

        let peer = parse_peer("peer = \"example.com:5678\"");
        assert_eq!(peer_addr(&peer), "example.com:5678");
        assert!(matches!(peer, PeerConfig::Standard(_)));

        let peer = parse_peer("[peer]\naddr = \"example.com\"");
        assert_eq!(peer_addr(&peer), "example.com:123");
        assert!(matches!(peer, PeerConfig::Standard(_)));

        let peer = parse_peer("[peer]\naddr = \"example.com:5678\"");
        assert_eq!(peer_addr(&peer), "example.com:5678");
        assert!(matches!(peer, PeerConfig::Standard(_)));

        let peer = parse_peer(
            r#"
            [peer]
            addr = "example.com"
            mode = "Server"
            "#,
        );
        assert_eq!(peer_addr(&peer), "example.com:123");
        assert!(matches!(peer, PeerConfig::Standard(_)));

        match parse_peer(
            r#"
            [peer]
            addr = "example.com"
            mode = "Pool"
            "#,
        ) {
            PeerConfig::Pool(config) => {
                assert_eq!(config.addr.to_string(), "example.com:123");
                assert_eq!(config.max_peers, 1);
            }
            other => panic!("expected a pool peer, got {other:?}"),
        }

        match parse_peer(
            r#"
            [peer]
            addr = "example.com"
            mode = "Pool"
            max-peers = 42
            "#,
        ) {
            PeerConfig::Pool(config) => {
                assert_eq!(config.addr.to_string(), "example.com:123");
                assert_eq!(config.max_peers, 42);
            }
            other => panic!("expected a pool peer, got {other:?}"),
        }

        match parse_peer(
            r#"
            [peer]
            ke-addr = "example.com"
            mode = "NtsServer"
            "#,
        ) {
            PeerConfig::Nts(config) => {
                assert_eq!(config.ke_addr.to_string(), "example.com:4460");
            }
            other => panic!("expected an NTS peer, got {other:?}"),
        }
    }

    #[test]
    fn test_deserialize_peer_rejects_fields_of_other_modes() {
        let parse = |input: &str| toml::from_str::<TestConfig>(input);

        // Server mode (the default) only allows `addr` and `mode`.
        assert!(parse("[peer]\naddr = \"example.com\"\nmax-peers = 3").is_err());
        assert!(parse("[peer]\naddr = \"example.com\"\nke-addr = \"example.com\"").is_err());

        // Pool mode does not take a certificate, and requires `addr`.
        assert!(parse(
            r#"
            [peer]
            addr = "example.com"
            mode = "Pool"
            certificate = "/does/not/matter"
            "#
        )
        .is_err());
        assert!(parse("[peer]\nmode = \"Pool\"").is_err());

        // NTS mode does not take `max-peers`.
        assert!(parse(
            r#"
            [peer]
            ke-addr = "example.com"
            mode = "NtsServer"
            max-peers = 3
            "#
        )
        .is_err());

        // Unknown keys are always rejected.
        assert!(parse("[peer]\naddr = \"example.com\"\nfoo = 1").is_err());
    }

    #[test]
    fn test_deserialize_peer_pem_certificate() {
        let contents = include_bytes!("../../testdata/certificates/nos-nl.pem");
        let path = std::env::temp_dir().join("nos-nl.pem");
        std::fs::write(&path, contents).unwrap();

        // The path goes into a TOML *literal* string ('...'). Unlike a basic string
        // ("..."), it does not process backslash escapes, so Windows paths such as
        // `C:\Users\...` are passed through unchanged.
        let peer = parse_peer(&format!(
            r#"
                [peer]
                ke-addr = "example.com"
                certificate = '{}'
                mode = "NtsServer"
                "#,
            path.display()
        ));
        match peer {
            PeerConfig::Nts(config) => {
                assert_eq!(config.ke_addr.to_string(), "example.com:4460");
                assert!(!config.certificates.is_empty());
            }
            other => panic!("expected an NTS peer, got {other:?}"),
        }
    }

    #[test]
    fn test_peer_from_string() {
        let peer = PeerConfig::try_from("example.com").unwrap();
        assert_eq!(peer_addr(&peer), "example.com:123");
        assert!(matches!(peer, PeerConfig::Standard(_)));

        let peer = PeerConfig::try_from("example.com:5678").unwrap();
        assert_eq!(peer_addr(&peer), "example.com:5678");
        assert!(matches!(peer, PeerConfig::Standard(_)));
    }

    #[test]
    fn test_normalize_addr() {
        let addr = NormalizedAddress::from_string_ntp("[::1]:456".into()).unwrap();
        assert_eq!(addr.to_string(), "[::1]:456");
        let addr = NormalizedAddress::from_string_ntp("::1".into()).unwrap();
        assert_eq!(addr.to_string(), "[::1]:123");
        assert!(NormalizedAddress::from_string_ntp(":some:invalid:1".into()).is_err());
        let addr = NormalizedAddress::from_string_ntp("127.0.0.1:456".into()).unwrap();
        assert_eq!(addr.to_string(), "127.0.0.1:456");
        let addr = NormalizedAddress::from_string_ntp("127.0.0.1".into()).unwrap();
        assert_eq!(addr.to_string(), "127.0.0.1:123");
        let addr = NormalizedAddress::from_string_ntp("1234567890.example.com".into()).unwrap();
        assert_eq!(addr.to_string(), "1234567890.example.com:123");
    }
}