1use std::{
2 fmt,
3 net::SocketAddr,
4 path::PathBuf,
5 sync::{Arc, Mutex},
6};
7
8use rustls::Certificate;
9use serde::{
10 de::{self, MapAccess, Visitor},
11 Deserialize, Deserializer,
12};
13
14use crate::keyexchange::certificates_from_file;
15
16#[derive(Deserialize, Debug, PartialEq, Eq, Clone, Copy, Default)]
17pub enum PeerHostMode {
18 #[serde(alias = "server")]
19 #[default]
20 Server,
21 #[serde(alias = "nts-server")]
22 NtsServer,
23 #[serde(alias = "pool")]
24 Pool,
25}
26
27#[derive(Deserialize, Debug, PartialEq, Eq, Clone)]
28#[serde(rename_all = "kebab-case", deny_unknown_fields)]
29pub struct StandardPeerConfig {
30 pub addr: NormalizedAddress,
31}
32
33#[derive(Debug, PartialEq, Eq, Clone)]
34pub struct NtsPeerConfig {
35 pub ke_addr: NormalizedAddress,
36 pub certificates: Arc<[Certificate]>,
37}
38
39#[derive(Deserialize, Debug, PartialEq, Eq, Clone)]
40#[serde(rename_all = "kebab-case", deny_unknown_fields)]
41pub struct PoolPeerConfig {
42 pub addr: NormalizedAddress,
43 pub max_peers: usize,
44}
45
46#[derive(Debug, PartialEq, Eq, Clone)]
47pub enum PeerConfig {
48 Standard(StandardPeerConfig),
49 Nts(NtsPeerConfig),
50 Pool(PoolPeerConfig),
51 }
53
54impl PeerConfig {
55 pub(crate) fn try_from_str(value: &str) -> Result<Self, std::io::Error> {
56 Self::try_from(value)
57 }
58}
59
60#[derive(Deserialize, Debug, Clone)]
63#[serde(rename_all = "kebab-case", deny_unknown_fields)]
64pub struct NormalizedAddress {
65 pub(crate) server_name: String,
66 pub(crate) port: u16,
67
68 #[cfg(test)]
70 hardcoded_dns_resolve: HardcodedDnsResolve,
71}
72
73impl Eq for NormalizedAddress {}
74
75impl PartialEq for NormalizedAddress {
76 fn eq(&self, other: &Self) -> bool {
77 self.server_name == other.server_name && self.port == other.port
78 }
79}
80
81#[derive(Deserialize, Debug, Clone, Default)]
82struct HardcodedDnsResolve {
83 #[cfg_attr(not(test), allow(unused))]
84 #[serde(skip)]
85 addresses: Arc<Mutex<Vec<SocketAddr>>>,
86}
87
88impl From<Vec<SocketAddr>> for HardcodedDnsResolve {
89 fn from(value: Vec<SocketAddr>) -> Self {
90 Self {
91 addresses: Arc::new(Mutex::new(value)),
92 }
93 }
94}
95
96impl NormalizedAddress {
97 const NTP_DEFAULT_PORT: u16 = 123;
98 const NTS_KE_DEFAULT_PORT: u16 = 4460;
99
100 pub(crate) fn from_string_ntp(address: String) -> std::io::Result<Self> {
102 let (server_name, port) = Self::from_string_help(address, Self::NTP_DEFAULT_PORT)?;
103
104 Ok(Self {
105 server_name,
106 port,
107
108 #[cfg(test)]
109 hardcoded_dns_resolve: HardcodedDnsResolve::default(),
110 })
111 }
112
113 fn from_string_nts_ke(address: String) -> std::io::Result<Self> {
115 let (server_name, port) = Self::from_string_help(address, Self::NTS_KE_DEFAULT_PORT)?;
116
117 Ok(Self {
118 server_name,
119 port,
120
121 #[cfg(test)]
122 hardcoded_dns_resolve: HardcodedDnsResolve::default(),
123 })
124 }
125
126 fn from_string_help(address: String, default_port: u16) -> std::io::Result<(String, u16)> {
127 if address.split(':').count() > 2 {
128 match address.parse::<SocketAddr>() {
130 Ok(socket_addr) => {
131 let (server_name, _) = address.rsplit_once(':').unwrap();
133
134 Ok((server_name.to_string(), socket_addr.port()))
135 }
136 Err(e) => {
137 let address_with_port = format!("[{address}]:{default_port}");
139 if address_with_port.parse::<SocketAddr>().is_ok() {
140 Ok((format!("[{address}]"), default_port))
141 } else {
142 Err(std::io::Error::new(std::io::ErrorKind::Other, e))
143 }
144 }
145 }
146 } else if let Some((server_name, port)) = address.split_once(':') {
147 match port.parse::<u16>() {
151 Ok(port) => Ok((server_name.to_string(), port)),
152 Err(e) => Err(std::io::Error::new(std::io::ErrorKind::Other, e)),
153 }
154 } else {
155 Ok((address, default_port))
158 }
159 }
160
161 #[cfg(test)]
162 pub(crate) fn new_unchecked(server_name: &str, port: u16) -> Self {
163 Self {
164 server_name: server_name.to_string(),
165 port,
166
167 #[cfg(test)]
168 hardcoded_dns_resolve: HardcodedDnsResolve::default(),
169 }
170 }
171
172 #[cfg(test)]
173 pub(crate) fn with_hardcoded_dns(
174 server_name: &str,
175 port: u16,
176 hardcoded_dns_resolve: Vec<SocketAddr>,
177 ) -> Self {
178 Self {
179 server_name: server_name.to_string(),
180 port,
181 hardcoded_dns_resolve: HardcodedDnsResolve::from(hardcoded_dns_resolve),
182 }
183 }
184
185 #[cfg(not(test))]
186 pub async fn lookup_host(&self) -> std::io::Result<impl Iterator<Item = SocketAddr> + '_> {
187 tokio::net::lookup_host((self.server_name.as_str(), self.port)).await
188 }
189
190 #[cfg(test)]
191 pub async fn lookup_host(&self) -> std::io::Result<impl Iterator<Item = SocketAddr> + '_> {
192 let mut addresses = self.hardcoded_dns_resolve.addresses.lock().unwrap();
195
196 if let Some(last) = addresses.pop() {
197 addresses.insert(0, last);
198 }
199
200 let addresses = addresses.to_vec();
201
202 Ok(addresses.into_iter())
203 }
204}
205
206impl std::fmt::Display for NormalizedAddress {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 write!(f, "{}:{}", self.server_name, self.port)
209 }
210}
211
212impl TryFrom<&str> for StandardPeerConfig {
213 type Error = std::io::Error;
214
215 fn try_from(value: &str) -> Result<Self, Self::Error> {
216 Ok(Self {
217 addr: NormalizedAddress::from_string_ntp(value.to_string())?,
218 })
219 }
220}
221
222impl<'a> TryFrom<&'a str> for PeerConfig {
223 type Error = std::io::Error;
224
225 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
226 StandardPeerConfig::try_from(value).map(Self::Standard)
227 }
228}
229
230impl<'de> Deserialize<'de> for PeerConfig {
233 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
234 where
235 D: Deserializer<'de>,
236 {
237 struct PeerConfigVisitor;
238
239 impl<'de> Visitor<'de> for PeerConfigVisitor {
240 type Value = PeerConfig;
241
242 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
243 formatter.write_str("string or map")
244 }
245
246 fn visit_str<E: de::Error>(self, value: &str) -> Result<PeerConfig, E> {
247 TryFrom::try_from(value).map_err(de::Error::custom)
248 }
249
250 fn visit_map<M: MapAccess<'de>>(self, mut map: M) -> Result<PeerConfig, M::Error> {
251 let mut ke_addr = None;
252 let mut opt_certificate_path = None;
253 let mut addr = None;
254 let mut mode = None;
255 let mut max_peers = None;
256 while let Some(key) = map.next_key::<String>()? {
257 match key.as_str() {
258 "addr" => {
259 if addr.is_some() {
260 return Err(de::Error::duplicate_field("addr"));
261 }
262 let raw: String = map.next_value()?;
263
264 let parsed_addr =
265 NormalizedAddress::from_string_ntp(raw.as_str().to_string())
266 .map_err(de::Error::custom)?;
267
268 addr = Some(parsed_addr);
269 }
270 "ke-addr" => {
271 if ke_addr.is_some() {
272 return Err(de::Error::duplicate_field("ke-addr"));
273 }
274 let raw: String = map.next_value()?;
275
276 let parsed_addr =
277 NormalizedAddress::from_string_nts_ke(raw.as_str().to_string())
278 .map_err(de::Error::custom)?;
279
280 ke_addr = Some(parsed_addr);
281 }
282 "certificate" => {
283 if opt_certificate_path.is_some() {
284 return Err(de::Error::duplicate_field("certificate"));
285 }
286 let raw: String = map.next_value()?;
287
288 opt_certificate_path = Some(PathBuf::from(raw));
289 }
290 "mode" => {
291 if mode.is_some() {
292 return Err(de::Error::duplicate_field("mode"));
293 }
294 mode = Some(map.next_value()?);
295 }
296 "max-peers" => {
297 if max_peers.is_some() {
298 return Err(de::Error::duplicate_field("max-peers"));
299 }
300 max_peers = Some(map.next_value()?);
301 }
302 _ => {
303 return Err(de::Error::unknown_field(
304 key.as_str(),
305 &["addr", "ke-addr", "certificate", "mode", "max-peers"],
306 ));
307 }
308 }
309 }
310
311 let mode = mode.unwrap_or_default();
312
313 let unknown_field =
314 |field, valid_fields| Err(de::Error::unknown_field(field, valid_fields));
315
316 match mode {
317 PeerHostMode::Server => {
318 let addr = addr.ok_or_else(|| de::Error::missing_field("addr"))?;
319
320 let valid_fields = &["addr", "mode"];
321 if max_peers.is_some() {
322 unknown_field("max-peers", valid_fields)
323 } else if ke_addr.is_some() {
324 unknown_field("ke-addr", valid_fields)
325 } else if opt_certificate_path.is_some() {
326 unknown_field("certificate", valid_fields)
327 } else {
328 Ok(PeerConfig::Standard(StandardPeerConfig { addr }))
329 }
330 }
331 PeerHostMode::NtsServer => {
332 let ke_addr = ke_addr.ok_or_else(|| de::Error::missing_field("ke-addr"))?;
333
334 let valid_fields = &["mode", "ke-addr", "certificate"];
335 if max_peers.is_some() {
336 unknown_field("max-peers", valid_fields)
337 } else {
338 let certificates: Arc<[Certificate]> = if let Some(certificate_path) =
339 opt_certificate_path
340 {
341 match certificates_from_file(&certificate_path) {
342 Ok(certificates) => Arc::from(certificates),
343 Err(io_error) => {
344 let msg = format!(
345 "error while parsing certificate file {certificate_path:?}: {io_error:?}"
346 );
347 return Err(de::Error::custom(msg));
348 }
349 }
350 } else {
351 Arc::from([])
352 };
353
354 Ok(PeerConfig::Nts(NtsPeerConfig {
355 ke_addr,
356 certificates,
357 }))
358 }
359 }
360 PeerHostMode::Pool => {
361 let addr = addr.ok_or_else(|| de::Error::missing_field("addr"))?;
362
363 let valid_fields = &["addr", "mode", "max-peers"];
364 if ke_addr.is_some() {
365 unknown_field("ke-addr", valid_fields)
366 } else if opt_certificate_path.is_some() {
367 unknown_field("certificate", valid_fields)
368 } else {
369 let max_peers = max_peers.unwrap_or(1);
370
371 Ok(PeerConfig::Pool(PoolPeerConfig { addr, max_peers }))
372 }
373 }
374 }
375 }
376 }
377
378 deserializer.deserialize_any(PeerConfigVisitor)
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration wrapper holding a single `peer` entry.
    #[derive(Deserialize, Debug)]
    struct TestConfig {
        peer: PeerConfig,
    }

    /// Deserializes `toml_src` and returns its `peer`, panicking if parsing fails.
    fn parse_peer(toml_src: &str) -> PeerConfig {
        toml::from_str::<TestConfig>(toml_src).unwrap().peer
    }

    /// Asserts that `toml_src` is rejected by the `PeerConfig` deserializer.
    fn assert_rejected(toml_src: &str) {
        assert!(
            toml::from_str::<TestConfig>(toml_src).is_err(),
            "expected config to be rejected:\n{toml_src}"
        );
    }

    fn peer_addr(config: &PeerConfig) -> String {
        match config {
            PeerConfig::Standard(c) => c.addr.to_string(),
            PeerConfig::Nts(c) => c.ke_addr.to_string(),
            PeerConfig::Pool(c) => c.addr.to_string(),
        }
    }

    #[test]
    fn test_deserialize_peer() {
        // Bare strings, `addr`-only tables and explicit `Server` mode are plain NTP servers.
        for (src, expected) in [
            ("peer = \"example.com\"", "example.com:123"),
            ("peer = \"example.com:5678\"", "example.com:5678"),
            ("[peer]\naddr = \"example.com\"", "example.com:123"),
            ("[peer]\naddr = \"example.com:5678\"", "example.com:5678"),
            (
                r#"
                [peer]
                addr = "example.com"
                mode = "Server"
                "#,
                "example.com:123",
            ),
        ] {
            let peer = parse_peer(src);
            assert_eq!(peer_addr(&peer), expected);
            assert!(matches!(peer, PeerConfig::Standard(_)));
        }

        match parse_peer(
            r#"
            [peer]
            addr = "example.com"
            mode = "Pool"
            "#,
        ) {
            PeerConfig::Pool(config) => {
                assert_eq!(config.addr.to_string(), "example.com:123");
                assert_eq!(config.max_peers, 1);
            }
            other => panic!("expected a pool peer, got {other:?}"),
        }

        match parse_peer(
            r#"
            [peer]
            addr = "example.com"
            mode = "Pool"
            max-peers = 42
            "#,
        ) {
            PeerConfig::Pool(config) => {
                assert_eq!(config.addr.to_string(), "example.com:123");
                assert_eq!(config.max_peers, 42);
            }
            other => panic!("expected a pool peer, got {other:?}"),
        }

        match parse_peer(
            r#"
            [peer]
            ke-addr = "example.com"
            mode = "NtsServer"
            "#,
        ) {
            PeerConfig::Nts(config) => {
                assert_eq!(config.ke_addr.to_string(), "example.com:4460");
                assert!(config.certificates.is_empty());
            }
            other => panic!("expected an NTS peer, got {other:?}"),
        }
    }

    #[test]
    fn test_deserialize_peer_mode_aliases() {
        assert!(matches!(
            parse_peer("[peer]\naddr = \"example.com\"\nmode = \"server\""),
            PeerConfig::Standard(_)
        ));
        assert!(matches!(
            parse_peer("[peer]\naddr = \"example.com\"\nmode = \"pool\""),
            PeerConfig::Pool(_)
        ));
        assert!(matches!(
            parse_peer("[peer]\nke-addr = \"example.com\"\nmode = \"nts-server\""),
            PeerConfig::Nts(_)
        ));
    }

    #[test]
    fn test_deserialize_peer_rejects_invalid_tables() {
        // Keys that do not belong to the selected mode.
        assert_rejected("[peer]\naddr = \"example.com\"\nke-addr = \"example.com\"");
        assert_rejected("[peer]\naddr = \"example.com\"\nmax-peers = 3");
        assert_rejected("[peer]\naddr = \"example.com\"\ncertificate = \"/tmp/x.pem\"");
        assert_rejected("[peer]\naddr = \"example.com\"\nmode = \"Pool\"\nke-addr = \"example.com\"");
        assert_rejected(
            "[peer]\naddr = \"example.com\"\nmode = \"Pool\"\ncertificate = \"/tmp/x.pem\"",
        );
        assert_rejected("[peer]\nke-addr = \"example.com\"\nmode = \"NtsServer\"\nmax-peers = 3");

        // Missing required address keys.
        assert_rejected("[peer]\nmode = \"Server\"");
        assert_rejected("[peer]\nmode = \"Pool\"");
        assert_rejected("[peer]\nmode = \"NtsServer\"");

        // Unknown keys and unparsable addresses.
        assert_rejected("[peer]\naddr = \"example.com\"\nfoo = 1");
        assert_rejected("peer = \"example.com:notaport\"");
        assert_rejected("[peer]\naddr = \"example.com:99999\"");

        // A certificate file that cannot be read.
        assert_rejected(
            "[peer]\nke-addr = \"example.com\"\nmode = \"NtsServer\"\ncertificate = \"/nonexistent/ntpd-rs-test/missing.pem\"",
        );
    }

    #[test]
    fn test_deserialize_peer_pem_certificate() {
        let contents = include_bytes!("../../testdata/certificates/nos-nl.pem");
        // A per-process file name keeps concurrently running test binaries from clobbering
        // each other's fixture.
        let path = std::env::temp_dir().join(format!("nos-nl-{}.pem", std::process::id()));
        std::fs::write(&path, contents).unwrap();

        // A TOML literal string ('...') is used so that backslashes in Windows paths are not
        // interpreted as escape sequences.
        let result = toml::from_str::<TestConfig>(&format!(
            r#"
            [peer]
            ke-addr = "example.com"
            certificate = '{}'
            mode = "NtsServer"
            "#,
            path.display()
        ));

        // Best-effort cleanup; a leftover file in the temp directory is harmless.
        let _ = std::fs::remove_file(&path);

        match result.unwrap().peer {
            PeerConfig::Nts(config) => {
                assert_eq!(config.ke_addr.to_string(), "example.com:4460");
                assert!(!config.certificates.is_empty());
            }
            other => panic!("expected an NTS peer, got {other:?}"),
        }
    }

    #[test]
    fn test_peer_from_string() {
        let peer = PeerConfig::try_from("example.com").unwrap();
        assert_eq!(peer_addr(&peer), "example.com:123");
        assert!(matches!(peer, PeerConfig::Standard(_)));

        let peer = PeerConfig::try_from("example.com:5678").unwrap();
        assert_eq!(peer_addr(&peer), "example.com:5678");
        assert!(matches!(peer, PeerConfig::Standard(_)));

        let peer = PeerConfig::try_from("::1").unwrap();
        assert_eq!(peer_addr(&peer), "[::1]:123");
        assert!(matches!(peer, PeerConfig::Standard(_)));
    }

    #[test]
    fn test_normalize_addr() {
        let addr = NormalizedAddress::from_string_ntp("[::1]:456".into()).unwrap();
        assert_eq!(addr.to_string(), "[::1]:456");
        let addr = NormalizedAddress::from_string_ntp("::1".into()).unwrap();
        assert_eq!(addr.to_string(), "[::1]:123");
        assert!(NormalizedAddress::from_string_ntp(":some:invalid:1".into()).is_err());
        let addr = NormalizedAddress::from_string_ntp("127.0.0.1:456".into()).unwrap();
        assert_eq!(addr.to_string(), "127.0.0.1:456");
        let addr = NormalizedAddress::from_string_ntp("127.0.0.1".into()).unwrap();
        assert_eq!(addr.to_string(), "127.0.0.1:123");
        let addr = NormalizedAddress::from_string_ntp("1234567890.example.com".into()).unwrap();
        assert_eq!(addr.to_string(), "1234567890.example.com:123");

        // Empty or out-of-range ports are errors.
        assert!(NormalizedAddress::from_string_ntp("example.com:".into()).is_err());
        assert!(NormalizedAddress::from_string_ntp("example.com:99999".into()).is_err());

        // The NTS-KE variant defaults to port 4460 but respects an explicit port.
        let addr = NormalizedAddress::from_string_nts_ke("example.com".into()).unwrap();
        assert_eq!(addr.to_string(), "example.com:4460");
        let addr = NormalizedAddress::from_string_nts_ke("example.com:1234".into()).unwrap();
        assert_eq!(addr.to_string(), "example.com:1234");
    }
}