use std::fmt::Debug;
use std::{collections::HashMap, marker::PhantomData};
use ntp_proto::{
Measurement, NtpClock, NtpDuration, PollInterval, SourceConfig, SourceController,
TimeSyncController,
};
use serde::Deserialize;
use crate::daemon::spawn::SourceId;
/// A measurement from either kind of source, tagged by the delay type it carries.
///
/// This is the message type that source controllers in this file hand to the
/// `SingleShotController`.
#[derive(Debug, Clone)]
pub enum Measurements {
/// A measurement whose delay is an `NtpDuration`.
Ntp(Measurement<NtpDuration>),
/// A measurement that carries no delay value (delay type `()`).
Sock(Measurement<()>),
}
impl Measurements {
fn get_offset(&self) -> NtpDuration {
match self {
Measurements::Ntp(measurement) => measurement.offset,
Measurements::Sock(measurement) => measurement.offset,
}
}
}
/// Conversion from a concrete measurement into the [`Measurements`] enum.
///
/// The type parameter `D` is the delay type of the measurement the trait is
/// implemented for; it is not used by the method signature itself.
pub trait WrapMeasurements<D: Debug + Copy + Clone> {
/// Returns this measurement wrapped in a [`Measurements`] variant.
fn wrap(&self) -> Measurements;
}
impl WrapMeasurements<NtpDuration> for Measurement<NtpDuration> {
    /// Copies this measurement into the `Measurements::Ntp` variant.
    fn wrap(&self) -> Measurements {
        let measurement: Measurement<NtpDuration> = *self;
        Measurements::Ntp(measurement)
    }
}
impl WrapMeasurements<()> for Measurement<()> {
    /// Copies this measurement into the `Measurements::Sock` variant.
    fn wrap(&self) -> Measurements {
        let measurement: Measurement<()> = *self;
        Measurements::Sock(measurement)
    }
}
/// Time-sync controller that remembers the latest measurement of each source
/// and, once enough sources agree on an offset, offers one clock change and
/// terminates the process (see `try_steer`).
pub(crate) struct SingleShotController<C> {
/// The clock handed to `new`.
pub(super) clock: C,
/// Most recent measurement received from each source, keyed by source id.
sources: HashMap<SourceId, Measurements>,
/// Minimum number of sources whose offsets must agree before a clock
/// change is offered.
min_agreeing: usize,
}
/// Algorithm configuration for [`SingleShotController`].
#[derive(Debug, Copy, Clone, Deserialize)]
pub(crate) struct SingleShotControllerConfig {
/// Number of sources expected. Half of this value (integer division) acts
/// as a lower bound on the controller's `min_agreeing` threshold.
pub expected_sources: usize,
}
/// Per-source controller: asks for polling at the minimum interval until the
/// first measurement has been handled, and for no polling afterwards.
pub(crate) struct SingleShotSourceController<D: Debug + Copy + Clone> {
/// Marker for the measurement delay type `D`; no value of `D` is stored.
delay_type: PhantomData<D>,
/// Poll interval requested while no measurement has been handled yet.
min_poll_interval: PollInterval,
/// Set once a measurement has been handled; from then on the desired poll
/// interval is `PollInterval::NEVER`.
done: bool,
/// When true, handled measurements are dropped instead of being forwarded
/// to the controller. Set for one-way sources that were added with a
/// `period`.
ignore: bool,
}
/// Message type sent from the controller to its source controllers.
///
/// The enum has no variants, so no value of it can ever be constructed.
#[derive(Debug, Copy, Clone)]
pub(crate) enum SingleShotControllerMessage {}
impl<C: NtpClock> SingleShotController<C> {
    /// Half-width of the interval assumed around every measured offset.
    // NOTE(review): presumably 2^-1 seconds; confirm against
    // `NtpDuration::from_exponent`.
    const ASSUMED_UNCERTAINTY: NtpDuration = NtpDuration::from_exponent(-1);

    /// Offers a clock change and exits the process once enough sources agree.
    ///
    /// Every stored measurement is treated as the closed interval
    /// `[offset - ASSUMED_UNCERTAINTY, offset + ASSUMED_UNCERTAINTY]`. A sweep
    /// over the interval endpoints finds the point covered by the largest
    /// number of intervals. If at least `min_agreeing` intervals cover that
    /// point, the offsets of exactly those sources are averaged, the average
    /// is passed to `offer_clock_change` (defined outside this block), and the
    /// process terminates via `std::process::exit(0)`.
    ///
    /// Returns without doing anything when there are no sources, when fewer
    /// than `min_agreeing` sources have reported, or when no sufficiently
    /// large agreeing group exists.
    fn try_steer(&self) {
        // The emptiness check keeps `events[0]` below in bounds even when
        // `min_agreeing` is zero.
        if self.sources.is_empty() || self.sources.len() < self.min_agreeing {
            return;
        }

        /// Endpoint of a source's uncertainty interval: `count` is `+1` where
        /// the interval starts and `-1` where it ends.
        struct Event {
            offset: NtpDuration,
            count: isize,
        }

        let mut events: Vec<_> = self
            .sources
            .values()
            .flat_map(|m| {
                [
                    Event {
                        offset: m.get_offset() - Self::ASSUMED_UNCERTAINTY,
                        count: 1,
                    },
                    Event {
                        offset: m.get_offset() + Self::ASSUMED_UNCERTAINTY,
                        count: -1,
                    },
                ]
                .into_iter()
            })
            .collect();

        // Sort by offset, and at equal offsets put interval starts (+1)
        // before interval ends (-1). Without the tie-break the relative order
        // of coinciding endpoints would depend on `HashMap` iteration order,
        // so whether two exactly touching intervals count as overlapping
        // would vary between runs. Starts-first treats the intervals as
        // closed, which matches the `<=` membership test further down.
        events.sort_unstable_by(|a, b| {
            a.offset
                .cmp(&b.offset)
                .then_with(|| b.count.cmp(&a.count))
        });

        // Sweep: `cur` is the number of intervals covering the current
        // position; remember the first position where it reaches its maximum.
        let mut peak: isize = 0;
        let mut peak_offset = events[0].offset;
        let mut cur: isize = 0;
        for ev in events {
            cur += ev.count;
            if cur > peak {
                peak = cur;
                peak_offset = ev.offset;
            }
        }

        // `peak` starts at zero and only ever grows, so the cast is lossless.
        if peak as usize >= self.min_agreeing {
            // Average the offsets of exactly those sources whose interval
            // contains the peak position.
            let mut sum = 0.0;
            let mut count = 0;
            for source in self.sources.values() {
                let offset = source.get_offset();
                if offset.abs_diff(peak_offset) <= Self::ASSUMED_UNCERTAINTY {
                    count += 1;
                    sum += offset.to_seconds();
                }
            }

            let avg_offset = NtpDuration::from_seconds(sum / (count as f64));
            self.offer_clock_change(avg_offset);

            // Single shot: the work is done, so end the whole process here.
            std::process::exit(0);
        }
    }
}
impl<C: NtpClock> TimeSyncController for SingleShotController<C> {
    type Clock = C;
    type SourceId = SourceId;
    type AlgorithmConfig = SingleShotControllerConfig;
    type ControllerMessage = SingleShotControllerMessage;
    type SourceMessage = Measurements;
    type NtpSourceController = SingleShotSourceController<NtpDuration>;
    type OneWaySourceController = SingleShotSourceController<()>;

    /// Builds a controller with no known sources.
    ///
    /// The agreement threshold is the larger of the configured
    /// `minimum_agreeing_sources` and half of `expected_sources` (integer
    /// division). This implementation never returns `Err`.
    fn new(
        clock: Self::Clock,
        synchronization_config: ntp_proto::SynchronizationConfig,
        algorithm_config: Self::AlgorithmConfig,
    ) -> Result<Self, <Self::Clock as ntp_proto::NtpClock>::Error> {
        let half_expected = algorithm_config.expected_sources / 2;
        let min_agreeing = std::cmp::max(
            synchronization_config.minimum_agreeing_sources,
            half_expected,
        );

        Ok(Self {
            clock,
            sources: HashMap::new(),
            min_agreeing,
        })
    }

    /// No-op; always succeeds.
    fn take_control(&mut self) -> Result<(), <Self::Clock as ntp_proto::NtpClock>::Error> {
        Ok(())
    }

    /// Creates the controller for an NTP source. Its measurements are never
    /// ignored.
    fn add_source(
        &mut self,
        _id: Self::SourceId,
        config: SourceConfig,
    ) -> Self::NtpSourceController {
        let min_poll_interval = config.poll_interval_limits.min;
        SingleShotSourceController {
            delay_type: PhantomData,
            min_poll_interval,
            done: false,
            ignore: false,
        }
    }

    /// Creates the controller for a one-way source. When a `period` is given,
    /// the source's measurements are ignored.
    fn add_one_way_source(
        &mut self,
        _id: Self::SourceId,
        config: SourceConfig,
        _measurement_noise_estimate: f64,
        _measurement_accuracy_estimate: f64,
        period: Option<f64>,
    ) -> Self::OneWaySourceController {
        let ignore = period.is_some();
        let min_poll_interval = config.poll_interval_limits.min;
        SingleShotSourceController {
            delay_type: PhantomData,
            min_poll_interval,
            done: false,
            ignore,
        }
    }

    /// Forgets any measurement stored for `id`.
    fn remove_source(&mut self, id: Self::SourceId) {
        self.sources.remove(&id);
    }

    /// Drops the stored measurement of a source that is no longer usable;
    /// does nothing for usable sources.
    fn source_update(&mut self, id: Self::SourceId, usable: bool) {
        if usable {
            return;
        }
        self.sources.remove(&id);
    }

    /// Stores the newest measurement of `id` (replacing any earlier one) and
    /// checks whether the clock can be steered now.
    fn source_message(
        &mut self,
        id: Self::SourceId,
        message: Self::SourceMessage,
    ) -> ntp_proto::StateUpdate<Self::SourceId, Self::ControllerMessage> {
        self.sources.insert(id, message);
        // `try_steer` exits the process when it succeeds, so execution only
        // continues past this call while agreement is still lacking.
        self.try_steer();

        ntp_proto::StateUpdate::default()
    }

    /// Nothing happens on a timer; returns the default state update.
    fn time_update(&mut self) -> ntp_proto::StateUpdate<Self::SourceId, Self::ControllerMessage> {
        ntp_proto::StateUpdate::default()
    }
}
impl<D: Debug + Copy + Clone + Send + 'static> SourceController for SingleShotSourceController<D>
where
    Measurement<D>: WrapMeasurements<D>,
{
    type ControllerMessage = SingleShotControllerMessage;
    type MeasurementDelay = D;
    type SourceMessage = Measurements;

    /// Does nothing; the controller message type has no variants.
    fn handle_message(&mut self, _message: Self::ControllerMessage) {}

    /// Marks this source as done and forwards the wrapped measurement, unless
    /// this source is configured to be ignored, in which case `None` is
    /// returned.
    fn handle_measurement(
        &mut self,
        measurement: Measurement<Self::MeasurementDelay>,
    ) -> Option<Self::SourceMessage> {
        // The source counts as done even when its measurement is discarded.
        self.done = true;

        let forward = !self.ignore;
        forward.then(|| measurement.wrap())
    }

    /// Requests the minimum poll interval until a measurement has been
    /// handled, and `PollInterval::NEVER` from then on.
    fn desired_poll_interval(&self) -> ntp_proto::PollInterval {
        if !self.done {
            return self.min_poll_interval;
        }
        PollInterval::NEVER
    }

    /// Returns default-valued observable time data.
    fn observe(&self) -> ntp_proto::ObservableSourceTimedata {
        ntp_proto::ObservableSourceTimedata::default()
    }
}