ntpd 1.7.2

Full-featured implementation of NTP with NTS support
Documentation
use std::borrow::Cow;
use std::fmt::Display;
use std::ops::Deref;

use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tracing::warn;

use ntp_proto::{KeyExchangeClient, NtsClientConfig, NtsError, SourceConfig};

use super::super::config::NtsPoolSourceConfig;

use super::{SourceId, SourceRemovedEvent, SpawnAction, SpawnEvent, Spawner, SpawnerId};

use super::nts::resolve_addr;

struct PoolSource {
    id: SourceId,
    remote: String,
}

pub struct NtsPoolSpawner {
    config: NtsPoolSourceConfig,
    key_exchange_client: KeyExchangeClient,
    source_config: SourceConfig,
    id: SpawnerId,
    current_sources: Vec<PoolSource>,
}

#[derive(Debug)]
pub enum NtsPoolSpawnError {
    SendError(mpsc::error::SendError<SpawnEvent>),
}

impl std::error::Error for NtsPoolSpawnError {}

impl Display for NtsPoolSpawnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SendError(e) => write!(f, "Channel send error: {e}"),
        }
    }
}

impl From<mpsc::error::SendError<SpawnEvent>> for NtsPoolSpawnError {
    fn from(value: mpsc::error::SendError<SpawnEvent>) -> Self {
        Self::SendError(value)
    }
}

impl NtsPoolSpawner {
    pub fn new(
        config: NtsPoolSourceConfig,
        source_config: SourceConfig,
    ) -> Result<NtsPoolSpawner, NtsError> {
        let key_exchange_client = KeyExchangeClient::new(NtsClientConfig {
            certificates: config.certificate_authorities.clone(),
            protocol_version: config.ntp_version,
        })?;

        Ok(NtsPoolSpawner {
            config,
            key_exchange_client,
            source_config,
            id: SpawnerId::new(),
            current_sources: vec![],
        })
    }

    fn contains_source(&self, domain: &str) -> bool {
        self.current_sources
            .iter()
            .any(|source| source.remote == domain)
    }
}

impl Spawner for NtsPoolSpawner {
    type Error = NtsPoolSpawnError;

    async fn try_spawn(
        &mut self,
        action_tx: &mpsc::Sender<SpawnEvent>,
    ) -> Result<(), NtsPoolSpawnError> {
        for _ in 0..self.config.count.saturating_sub(self.current_sources.len()) {
            let io = match TcpStream::connect((
                self.config.addr.server_name.as_str(),
                self.config.addr.port,
            ))
            .await
            {
                Ok(io) => io,
                Err(e) => {
                    warn!(error = ?e, "error while attempting key exchange");
                    break;
                }
            };

            match tokio::time::timeout(
                super::NTS_TIMEOUT,
                self.key_exchange_client.exchange_keys(
                    io,
                    self.config.addr.server_name.clone(),
                    self.current_sources
                        .iter()
                        .map(|source| Cow::Borrowed(source.remote.as_str())),
                ),
            )
            .await
            {
                Ok(Ok(ke)) if !self.contains_source(&ke.remote) => {
                    if let Some(address) = resolve_addr((ke.remote.as_str(), ke.port)).await {
                        let id = SourceId::new();
                        self.current_sources.push(PoolSource {
                            id,
                            remote: ke.remote,
                        });
                        action_tx
                            .send(SpawnEvent::new(
                                self.id,
                                SpawnAction::create_ntp(
                                    id,
                                    address,
                                    self.config.addr.deref().clone(),
                                    ke.protocol_version,
                                    self.source_config,
                                    Some(ke.nts),
                                ),
                            ))
                            .await?;
                    }
                }
                Ok(Ok(_)) => {
                    warn!("received an address from pool-ke that we already had, ignoring");
                    continue;
                }
                Ok(Err(e)) => {
                    warn!(error = ?e, "error while attempting key exchange");
                    break;
                }
                Err(_) => {
                    warn!("timeout while attempting key exchange");
                }
            }
        }

        Ok(())
    }

    fn is_complete(&self) -> bool {
        self.current_sources.len() >= self.config.count
    }

    async fn handle_source_removed(
        &mut self,
        removed_source: SourceRemovedEvent,
    ) -> Result<(), NtsPoolSpawnError> {
        self.current_sources.retain(|p| p.id != removed_source.id);
        Ok(())
    }

    fn get_id(&self) -> SpawnerId {
        self.id
    }

    fn get_addr_description(&self) -> String {
        format!("{} ({})", self.config.addr.deref(), self.config.count)
    }

    fn get_description(&self) -> &str {
        "nts-pool"
    }
}