Skip to main content

ntp_proto/
system.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::net::{IpAddr, SocketAddr};
4use std::sync::Arc;
5use std::time::Duration;
6use std::{fmt::Debug, hash::Hash};
7
8use crate::packet::v5::server_reference_id::{BloomFilter, ServerId};
9use crate::source::{NtpSourceUpdate, SourceSnapshot};
10use crate::{NtpTimestamp, OneWaySource, OneWaySourceUpdate};
11use crate::{
12    algorithm::{StateUpdate, TimeSyncController},
13    clock::NtpClock,
14    config::{SourceConfig, SynchronizationConfig},
15    identifiers::ReferenceId,
16    packet::NtpLeapIndicator,
17    source::{NtpSource, NtpSourceActionIterator, ProtocolVersion, SourceNtsData},
18    time_types::NtpDuration,
19};
20
21#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
22pub struct TimeSnapshot {
23    /// Precision of the local clock
24    pub precision: NtpDuration,
25    /// Current root delay
26    pub root_delay: NtpDuration,
27    /// t=0 for root variance calculation
28    pub root_variance_base_time: NtpTimestamp,
29    /// Constant contribution for root variance
30    pub root_variance_base: f64,
31    /// Linear (*t) contribution for root variance
32    pub root_variance_linear: f64,
33    /// Quadratic (*t*t) contribution for root variance
34    pub root_variance_quadratic: f64,
35    /// Cubic (*t*t*t) contribution for root variance
36    pub root_variance_cubic: f64,
37    /// Current leap indicator state
38    pub leap_indicator: NtpLeapIndicator,
39    /// Total amount that the clock has stepped
40    pub accumulated_steps: NtpDuration,
41}
42
43impl TimeSnapshot {
44    pub fn root_dispersion(&self, now: NtpTimestamp) -> NtpDuration {
45        let t = (now - self.root_variance_base_time).to_seconds();
46        // Note: dispersion is the standard deviation, so we need a sqrt here.
47        NtpDuration::from_seconds(
48            (self.root_variance_base
49                + t * self.root_variance_linear
50                + t.powi(2) * self.root_variance_quadratic
51                + t.powi(3) * self.root_variance_cubic)
52                .sqrt(),
53        )
54    }
55}
56
57impl Default for TimeSnapshot {
58    fn default() -> Self {
59        Self {
60            precision: NtpDuration::from_exponent(-18),
61            root_delay: NtpDuration::ZERO,
62            root_variance_base_time: NtpTimestamp::default(),
63            root_variance_base: 0.0,
64            root_variance_linear: 0.0,
65            root_variance_quadratic: 0.0,
66            root_variance_cubic: 0.0,
67            leap_indicator: NtpLeapIndicator::Unknown,
68            accumulated_steps: NtpDuration::ZERO,
69        }
70    }
71}
72
73#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
74pub struct SystemSnapshot {
75    /// Log of the precision of the local clock
76    pub stratum: u8,
77    /// Reference ID of current primary time source
78    pub reference_id: ReferenceId,
79    /// Crossing this amount of stepping will cause a Panic
80    pub accumulated_steps_threshold: Option<NtpDuration>,
81    /// Timekeeping data
82    #[serde(flatten)]
83    pub time_snapshot: TimeSnapshot,
84    /// Bloom filter that contains all currently used time sources
85    #[serde(skip)]
86    pub bloom_filter: BloomFilter,
87    /// NTPv5 reference ID for this instance
88    #[serde(skip)]
89    pub server_id: ServerId,
90}
91
92impl SystemSnapshot {
93    pub fn update_timedata(&mut self, timedata: TimeSnapshot, config: &SynchronizationConfig) {
94        self.time_snapshot = timedata;
95        self.accumulated_steps_threshold = config.accumulated_step_panic_threshold;
96    }
97
98    pub fn update_used_sources(&mut self, used_sources: impl Iterator<Item = SourceSnapshot>) {
99        let mut used_sources = used_sources.peekable();
100        if let Some(system_source_snapshot) = used_sources.peek() {
101            let (stratum, source_id) = match system_source_snapshot {
102                SourceSnapshot::Ntp(snapshot) => (snapshot.stratum, snapshot.source_id),
103                SourceSnapshot::OneWay(snapshot) => (snapshot.stratum, snapshot.source_id),
104            };
105
106            self.stratum = stratum.saturating_add(1);
107            self.reference_id = source_id;
108        }
109
110        self.bloom_filter = BloomFilter::new();
111        for source in used_sources {
112            if let SourceSnapshot::Ntp(source) = source {
113                if let Some(bf) = &source.bloom_filter {
114                    self.bloom_filter.add(bf);
115                } else if let ProtocolVersion::V5 = source.protocol_version {
116                    tracing::warn!("Using NTPv5 source without a bloom filter!");
117                }
118            }
119        }
120        self.bloom_filter.add_id(&self.server_id);
121    }
122}
123
124impl Default for SystemSnapshot {
125    fn default() -> Self {
126        Self {
127            stratum: 16,
128            reference_id: ReferenceId::NONE,
129            accumulated_steps_threshold: None,
130            time_snapshot: TimeSnapshot::default(),
131            bloom_filter: BloomFilter::new(),
132            server_id: ServerId::default(),
133        }
134    }
135}
136
/// Controller message carried by `SystemAction::UpdateSources`, to be delivered to the sources.
///
/// `Debug` and `Clone` are derived. This yields exactly the bounds
/// (`ControllerMessage: Debug` / `ControllerMessage: Clone`) and the `Debug` output
/// (`SystemSourceUpdate { message: .. }`) of the previous hand-written impls.
#[derive(Debug, Clone)]
pub struct SystemSourceUpdate<ControllerMessage> {
    /// The message produced by the synchronization controller.
    pub message: ControllerMessage,
}
156
157#[derive(Debug, Clone)]
158pub enum SystemAction<ControllerMessage> {
159    UpdateSources(SystemSourceUpdate<ControllerMessage>),
160    SetTimer(Duration),
161}
162
163#[derive(Debug)]
164pub struct SystemActionIterator<ControllerMessage> {
165    iter: <Vec<SystemAction<ControllerMessage>> as IntoIterator>::IntoIter,
166}
167
168impl<ControllerMessage> Default for SystemActionIterator<ControllerMessage> {
169    fn default() -> Self {
170        Self {
171            iter: vec![].into_iter(),
172        }
173    }
174}
175
176impl<ControllerMessage> From<Vec<SystemAction<ControllerMessage>>>
177    for SystemActionIterator<ControllerMessage>
178{
179    fn from(value: Vec<SystemAction<ControllerMessage>>) -> Self {
180        Self {
181            iter: value.into_iter(),
182        }
183    }
184}
185
186impl<ControllerMessage> Iterator for SystemActionIterator<ControllerMessage> {
187    type Item = SystemAction<ControllerMessage>;
188
189    fn next(&mut self) -> Option<Self::Item> {
190        self.iter.next()
191    }
192}
193
// Builds a `SystemActionIterator` from a comma-separated list of `SystemAction`
// expressions, e.g. `actions!()` or `actions!(SystemAction::SetTimer(d))`.
// A trailing comma is accepted, as with `vec!`.
macro_rules! actions {
    [$($action:expr),* $(,)?] => {
        {
            SystemActionIterator::from(vec![$($action),*])
        }
    };
}
201
202pub struct System<SourceId, Controller> {
203    synchronization_config: SynchronizationConfig,
204    system: SystemSnapshot,
205    ip_list: Arc<[IpAddr]>,
206
207    sources: HashMap<SourceId, Option<SourceSnapshot>>,
208
209    controller: Controller,
210    controller_took_control: bool,
211}
212
213impl<SourceId: Hash + Eq + Copy + Debug, Controller: TimeSyncController<SourceId = SourceId>>
214    System<SourceId, Controller>
215{
216    pub fn new(
217        clock: Controller::Clock,
218        synchronization_config: SynchronizationConfig,
219        algorithm_config: Controller::AlgorithmConfig,
220        ip_list: Arc<[IpAddr]>,
221    ) -> Result<Self, <Controller::Clock as NtpClock>::Error> {
222        // Setup system snapshot
223        let mut system = SystemSnapshot {
224            stratum: synchronization_config.local_stratum,
225            ..Default::default()
226        };
227
228        if synchronization_config.local_stratum == 1 {
229            // We are a stratum 1 server so mark our selves synchronized.
230            system.time_snapshot.leap_indicator = NtpLeapIndicator::NoWarning;
231            // Set the reference id for the system
232            system.reference_id = synchronization_config.reference_id.to_reference_id();
233        }
234
235        Ok(System {
236            synchronization_config,
237            system,
238            ip_list,
239            sources: HashMap::new(),
240            controller: Controller::new(clock, synchronization_config, algorithm_config)?,
241            controller_took_control: false,
242        })
243    }
244
245    pub fn system_snapshot(&self) -> SystemSnapshot {
246        self.system
247    }
248
249    pub fn check_clock_access(&mut self) -> Result<(), <Controller::Clock as NtpClock>::Error> {
250        self.ensure_controller_control()
251    }
252
253    fn ensure_controller_control(&mut self) -> Result<(), <Controller::Clock as NtpClock>::Error> {
254        if !self.controller_took_control {
255            self.controller.take_control()?;
256            self.controller_took_control = true;
257        }
258        Ok(())
259    }
260
261    pub fn create_sock_source(
262        &mut self,
263        id: SourceId,
264        source_config: SourceConfig,
265        measurement_noise_estimate: f64,
266    ) -> Result<
267        OneWaySource<Controller::OneWaySourceController>,
268        <Controller::Clock as NtpClock>::Error,
269    > {
270        self.ensure_controller_control()?;
271        let controller =
272            self.controller
273                .add_one_way_source(id, source_config, measurement_noise_estimate, None);
274        self.sources.insert(id, None);
275        Ok(OneWaySource::new(controller))
276    }
277
278    pub fn create_pps_source(
279        &mut self,
280        id: SourceId,
281        source_config: SourceConfig,
282        measurement_noise_estimate: f64,
283        period: f64,
284    ) -> Result<
285        OneWaySource<Controller::OneWaySourceController>,
286        <Controller::Clock as NtpClock>::Error,
287    > {
288        self.ensure_controller_control()?;
289        let controller = self.controller.add_one_way_source(
290            id,
291            source_config,
292            measurement_noise_estimate,
293            Some(period),
294        );
295        self.sources.insert(id, None);
296        Ok(OneWaySource::new(controller))
297    }
298
299    #[expect(clippy::type_complexity)]
300    pub fn create_ntp_source(
301        &mut self,
302        id: SourceId,
303        source_config: SourceConfig,
304        source_addr: SocketAddr,
305        protocol_version: ProtocolVersion,
306        nts: Option<Box<SourceNtsData>>,
307    ) -> Result<
308        (
309            NtpSource<Controller::NtpSourceController>,
310            NtpSourceActionIterator<Controller::SourceMessage>,
311        ),
312        <Controller::Clock as NtpClock>::Error,
313    > {
314        self.ensure_controller_control()?;
315        let controller = self.controller.add_source(id, source_config);
316        self.sources.insert(id, None);
317        Ok(NtpSource::new(
318            source_addr,
319            source_config,
320            protocol_version,
321            controller,
322            nts,
323        ))
324    }
325
326    pub fn handle_source_remove(
327        &mut self,
328        id: SourceId,
329    ) -> Result<(), <Controller::Clock as NtpClock>::Error> {
330        self.controller.remove_source(id);
331        self.sources.remove(&id);
332        Ok(())
333    }
334
335    pub fn handle_source_update(
336        &mut self,
337        id: SourceId,
338        update: NtpSourceUpdate<Controller::SourceMessage>,
339    ) -> Result<
340        SystemActionIterator<Controller::ControllerMessage>,
341        <Controller::Clock as NtpClock>::Error,
342    > {
343        let usable = update
344            .snapshot
345            .accept_synchronization(
346                self.synchronization_config.local_stratum,
347                self.ip_list.as_ref(),
348                &self.system,
349            )
350            .is_ok();
351        self.controller.source_update(id, usable);
352        *self.sources.get_mut(&id).unwrap() = Some(SourceSnapshot::Ntp(update.snapshot));
353        if let Some(message) = update.message {
354            let update = self.controller.source_message(id, message);
355            Ok(self.handle_algorithm_state_update(update))
356        } else {
357            Ok(actions!())
358        }
359    }
360
361    pub fn handle_one_way_source_update(
362        &mut self,
363        id: SourceId,
364        update: OneWaySourceUpdate<Controller::SourceMessage>,
365    ) -> Result<
366        SystemActionIterator<Controller::ControllerMessage>,
367        <Controller::Clock as NtpClock>::Error,
368    > {
369        self.controller.source_update(id, true);
370        *self.sources.get_mut(&id).unwrap() = Some(SourceSnapshot::OneWay(update.snapshot));
371        if let Some(message) = update.message {
372            let update = self.controller.source_message(id, message);
373            Ok(self.handle_algorithm_state_update(update))
374        } else {
375            Ok(actions!())
376        }
377    }
378
379    fn handle_algorithm_state_update(
380        &mut self,
381        update: StateUpdate<SourceId, Controller::ControllerMessage>,
382    ) -> SystemActionIterator<Controller::ControllerMessage> {
383        let mut actions = vec![];
384        if let Some(ref used_sources) = update.used_sources {
385            self.system
386                .update_used_sources(used_sources.iter().map(|v| {
387                    self.sources.get(v).and_then(|snapshot| *snapshot).expect(
388                    "Critical error: Source used for synchronization that is not known to system",
389                )
390                }));
391        }
392        if let Some(time_snapshot) = update.time_snapshot {
393            self.system
394                .update_timedata(time_snapshot, &self.synchronization_config);
395        }
396        if let Some(timeout) = update.next_update {
397            actions.push(SystemAction::SetTimer(timeout));
398        }
399        if let Some(message) = update.source_message {
400            actions.push(SystemAction::UpdateSources(SystemSourceUpdate { message }));
401        }
402        actions.into()
403    }
404
405    pub fn handle_timer(&mut self) -> SystemActionIterator<Controller::ControllerMessage> {
406        tracing::debug!("Timer expired");
407        let update = self.controller.time_update();
408        self.handle_algorithm_state_update(update)
409    }
410
411    pub fn update_ip_list(&mut self, ip_list: Arc<[IpAddr]>) {
412        self.ip_list = ip_list;
413    }
414}
415
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddr};

    use crate::{NtpSourceSnapshot, Reach, time_types::PollIntervalLimits};

    use super::*;

    /// Builds an NTP source snapshot; only the source id and stratum vary between tests.
    fn ntp_source(source_id: ReferenceId, stratum: u8) -> SourceSnapshot {
        SourceSnapshot::Ntp(NtpSourceSnapshot {
            source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
            source_id,
            poll_interval: PollIntervalLimits::default().max,
            reach: Reach::never(),
            stratum,
            reference_id: ReferenceId::NONE,
            protocol_version: ProtocolVersion::v4_upgrading_to_v5_with_default_tries(),
            bloom_filter: None,
        })
    }

    #[test]
    fn test_empty_source_update() {
        let mut system = SystemSnapshot::default();

        // With no used sources, stratum and reference id must stay untouched.
        system.update_used_sources(std::iter::empty());

        assert_eq!(system.stratum, 16);
        assert_eq!(system.reference_id, ReferenceId::NONE);
    }

    #[test]
    fn test_source_update() {
        let mut system = SystemSnapshot::default();
        let used = [
            ntp_source(ReferenceId::KISS_DENY, 2),
            ntp_source(ReferenceId::KISS_RATE, 3),
        ];

        // The first source decides stratum (its stratum + 1) and reference id.
        system.update_used_sources(used.into_iter());

        assert_eq!(system.stratum, 3);
        assert_eq!(system.reference_id, ReferenceId::KISS_DENY);
    }

    #[test]
    fn test_timedata_update() {
        let root_delay = NtpDuration::from_seconds(1.0);
        let threshold = NtpDuration::from_seconds(2.0);

        let config = SynchronizationConfig {
            accumulated_step_panic_threshold: Some(threshold),
            ..Default::default()
        };
        let snapshot = TimeSnapshot {
            root_delay,
            ..Default::default()
        };

        let mut system = SystemSnapshot::default();
        system.update_timedata(snapshot, &config);

        assert_eq!(system.time_snapshot, snapshot);
        assert_eq!(system.accumulated_steps_threshold, Some(threshold));
    }
}