// ntpd 1.8.0
//
// Full-featured implementation of NTP with NTS support
// Documentation
use std::borrow::Cow;
use std::collections::VecDeque;
use std::fmt::Display;
use std::ops::Deref;

use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tracing::warn;

use ntp_proto::{KeyExchangeClient, NtsClientConfig, NtsError, SourceConfig};

use crate::daemon::config::{NormalizedAddress, NtpAddress};
use crate::daemon::dns::{KeResolutionResult, resolve_ke};
use crate::daemon::spawn::resolve_single_ntp_server;

use super::super::config::NtsPoolSourceConfig;

use super::{SourceId, SourceRemovedEvent, SpawnAction, SpawnEvent, Spawner, SpawnerId};

/// Bookkeeping entry for one source that this spawner has created and not yet
/// seen removed.
struct PoolSource {
    /// Identifier assigned when the source was spawned; removal events are
    /// matched against it.
    id: SourceId,
    /// Name used for duplicate detection: the SRV record name when the lookup
    /// produced one, otherwise the remote name returned by the key exchange.
    remote: String,
}

/// Spawner that fills up to `config.count` sources from a single NTS pool
/// key-exchange address.
pub struct NtsPoolSpawner {
    /// Pool settings: key-exchange address, desired source count, certificate
    /// authorities, NTP version and whether SRV resolution is enabled.
    config: NtsPoolSourceConfig,
    /// Key-exchange client built once in `new` and reused for every exchange.
    key_exchange_client: KeyExchangeClient,
    /// Source configuration handed to every source this spawner creates.
    source_config: SourceConfig,
    /// Identifier of this spawner, attached to every emitted `SpawnEvent`.
    id: SpawnerId,
    /// Sources created by this spawner that have not been reported as removed.
    current_sources: Vec<PoolSource>,
    /// Key-exchange resolutions that have been looked up but not tried yet.
    /// Only used when SRV resolution is enabled; refilled once it runs empty.
    known_resolutions: VecDeque<KeResolutionResult>,
}

/// Errors that `NtsPoolSpawner` can return from its `Spawner` methods.
#[derive(Debug)]
pub enum NtsPoolSpawnError {
    /// Sending a `SpawnEvent` over the action channel failed.
    SendError(mpsc::error::SendError<SpawnEvent>),
}

// Relies entirely on the default `Error` methods; the user-facing message is
// produced by the `Display` implementation.
impl std::error::Error for NtsPoolSpawnError {}

impl Display for NtsPoolSpawnError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SendError(e) => write!(f, "Channel send error: {e}"),
        }
    }
}

impl From<mpsc::error::SendError<SpawnEvent>> for NtsPoolSpawnError {
    fn from(value: mpsc::error::SendError<SpawnEvent>) -> Self {
        Self::SendError(value)
    }
}

impl NtsPoolSpawner {
    pub fn new(
        config: NtsPoolSourceConfig,
        source_config: SourceConfig,
    ) -> Result<NtsPoolSpawner, NtsError> {
        let key_exchange_client = KeyExchangeClient::new(&NtsClientConfig {
            certificates: config.certificate_authorities.clone(),
            protocol_version: config.ntp_version,
        })?;

        Ok(NtsPoolSpawner {
            config,
            key_exchange_client,
            source_config,
            id: SpawnerId::new(),
            current_sources: vec![],
            known_resolutions: VecDeque::new(),
        })
    }

    fn contains_source(&self, domain: &str) -> bool {
        self.current_sources
            .iter()
            .any(|source| source.remote == domain)
    }

    async fn lookup(&mut self) -> Option<(TcpStream, String, Option<String>)> {
        if self.config.enable_srv_resolution {
            let mut did_resolution = false;
            loop {
                if self.known_resolutions.is_empty() {
                    if did_resolution {
                        warn!(
                            "Could not find more sources for pool at {}",
                            self.config.addr.server_name
                        );
                        return None;
                    }
                    did_resolution = true;
                    match resolve_ke(&self.config.addr).await {
                        Ok(resolutions) => self.known_resolutions.extend(resolutions),
                        Err(e) => {
                            warn!(error=?e, "Error trying to resolve ke server domain name.");
                            return None;
                        }
                    }

                    if self.known_resolutions.is_empty() {
                        warn!("Unresolvable domain name {}", self.config.addr.server_name);
                        return None;
                    }
                }

                let mut last_error = None;
                while let Some(addr) = self.known_resolutions.pop_front() {
                    if let Some(name) = &addr.srv_record_name
                        && self.contains_source(name)
                    {
                        continue;
                    }
                    let io = match TcpStream::connect(addr.addr).await {
                        Ok(io) => io,
                        Err(e) => {
                            last_error = Some(e);
                            continue;
                        }
                    };
                    return Some((
                        io,
                        addr.srv_record_name
                            .clone()
                            .unwrap_or_else(|| self.config.addr.server_name.clone()),
                        addr.srv_record_name,
                    ));
                }

                if let Some(e) = last_error {
                    warn!(error = ?e, "error while attempting key exchange");
                    return None;
                }
            }
        } else {
            let io = match TcpStream::connect((
                self.config.addr.server_name.as_str(),
                self.config.addr.port,
            ))
            .await
            {
                Ok(io) => io,
                Err(e) => {
                    warn!(error = ?e, "error while attempting key exchange");
                    return None;
                }
            };
            Some((io, self.config.addr.server_name.clone(), None))
        }
    }
}

impl Spawner for NtsPoolSpawner {
    type Error = NtsPoolSpawnError;

    /// Tries to create as many sources as are still missing to reach
    /// `config.count`.
    ///
    /// For every missing source a key-exchange connection is looked up, the
    /// key exchange is run under `NTS_TIMEOUT`, the returned NTP server is
    /// resolved, and a create action is sent over `action_tx`. Lookup
    /// failures, key-exchange errors and key-exchange timeouts are logged and
    /// end the current round without reporting an error.
    ///
    /// # Errors
    ///
    /// Returns `NtsPoolSpawnError::SendError` when sending the spawn event
    /// over `action_tx` fails.
    async fn try_spawn(
        &mut self,
        action_tx: &mpsc::Sender<SpawnEvent>,
    ) -> Result<(), NtsPoolSpawnError> {
        for _ in 0..self.config.count.saturating_sub(self.current_sources.len()) {
            // `lookup` has already logged the reason when it yields nothing.
            let Some((io, name, remote_name)) = self.lookup().await else {
                return Ok(());
            };

            match tokio::time::timeout(
                super::NTS_TIMEOUT,
                self.key_exchange_client.exchange_keys(
                    io,
                    name,
                    // Names of the sources we already have.
                    // NOTE(review): presumably lets the pool avoid handing out
                    // servers that are already in use — confirm against
                    // `exchange_keys`.
                    self.current_sources
                        .iter()
                        .map(|source| Cow::Borrowed(source.remote.as_str())),
                ),
            )
            .await
            {
                // Accept the result only if the server is not already in use.
                // The SRV record name takes precedence over the name reported
                // by the key exchange.
                Ok(Ok(ke))
                    if !self.contains_source(remote_name.as_deref().unwrap_or(&ke.remote)) =>
                {
                    if let Some(address) = resolve_single_ntp_server(NtpAddress(
                        NormalizedAddress::new_from_parts(ke.remote.as_str(), ke.port),
                    ))
                    .await
                    {
                        let id = SourceId::new();
                        self.current_sources.push(PoolSource {
                            id,
                            remote: remote_name.unwrap_or(ke.remote),
                        });
                        action_tx
                            .send(SpawnEvent::new(
                                self.id,
                                SpawnAction::create_ntp(
                                    id,
                                    address,
                                    self.config.addr.deref().clone(),
                                    ke.protocol_version,
                                    self.source_config,
                                    Some(ke.nts),
                                ),
                            ))
                            .await?;
                    }
                }
                Ok(Ok(_)) => {
                    warn!("received an address from pool-ke that we already had, ignoring");
                }
                Ok(Err(e)) => {
                    warn!(error = ?e, "error while attempting key exchange");
                    break;
                }
                Err(_) => {
                    warn!("timeout while attempting key exchange");
                    // Stop this round, exactly as for a key-exchange error.
                    // Retrying right away against a slow pool could stall this
                    // call for up to `count` times `NTS_TIMEOUT`.
                    break;
                }
            }
        }

        Ok(())
    }

    /// Returns `true` once at least `config.count` sources are active.
    fn is_complete(&self) -> bool {
        self.current_sources.len() >= self.config.count
    }

    /// Forgets the source named by the removal event, so that a later
    /// `try_spawn` can replace it. Always returns `Ok(())`.
    async fn handle_source_removed(
        &mut self,
        removed_source: SourceRemovedEvent,
    ) -> Result<(), NtsPoolSpawnError> {
        self.current_sources.retain(|p| p.id != removed_source.id);
        Ok(())
    }

    /// Returns the identifier of this spawner.
    fn get_id(&self) -> SpawnerId {
        self.id
    }

    /// Returns the pool address followed by the configured source count in
    /// parentheses.
    fn get_addr_description(&self) -> String {
        format!("{} ({})", self.config.addr.deref(), self.config.count)
    }

    /// Returns the fixed label identifying this kind of spawner.
    fn get_description(&self) -> &str {
        "nts-pool"
    }
}