1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::net::{IpAddr, SocketAddr};
4use std::sync::Arc;
5use std::time::Duration;
6use std::{fmt::Debug, hash::Hash};
7
8use crate::packet::v5::server_reference_id::{BloomFilter, ServerId};
9use crate::source::{NtpSourceUpdate, SourceSnapshot};
10use crate::{NtpTimestamp, OneWaySource, OneWaySourceUpdate};
11use crate::{
12 algorithm::{StateUpdate, TimeSyncController},
13 clock::NtpClock,
14 config::{SourceConfig, SynchronizationConfig},
15 identifiers::ReferenceId,
16 packet::NtpLeapIndicator,
17 source::{NtpSource, NtpSourceActionIterator, ProtocolVersion, SourceNtsData},
18 time_types::NtpDuration,
19};
20
/// Timing properties of this host's current synchronization state.
///
/// The `root_variance_*` fields describe a cubic polynomial in the time elapsed
/// since `root_variance_base_time`; see [`TimeSnapshot::root_dispersion`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub struct TimeSnapshot {
    /// Precision of the local clock (`NtpDuration::from_exponent(-18)` by default).
    pub precision: NtpDuration,
    /// Root delay reported for this host (zero by default). Presumably the
    /// accumulated delay to the primary reference, as in NTP's `rootdelay` — confirm.
    pub root_delay: NtpDuration,
    /// Point in time at which the root variance polynomial is evaluated with `t = 0`.
    pub root_variance_base_time: NtpTimestamp,
    /// Constant term of the root variance polynomial.
    pub root_variance_base: f64,
    /// Coefficient of `t` (seconds since `root_variance_base_time`).
    pub root_variance_linear: f64,
    /// Coefficient of `t^2`.
    pub root_variance_quadratic: f64,
    /// Coefficient of `t^3`.
    pub root_variance_cubic: f64,
    /// Leap second indicator (`Unknown` by default).
    pub leap_indicator: NtpLeapIndicator,
    /// NOTE(review): presumably the running total of clock steps taken, to be
    /// compared against `SystemSnapshot::accumulated_steps_threshold` — confirm.
    pub accumulated_steps: NtpDuration,
}
42
43impl TimeSnapshot {
44 pub fn root_dispersion(&self, now: NtpTimestamp) -> NtpDuration {
45 let t = (now - self.root_variance_base_time).to_seconds();
46 NtpDuration::from_seconds(
48 (self.root_variance_base
49 + t * self.root_variance_linear
50 + t.powi(2) * self.root_variance_quadratic
51 + t.powi(3) * self.root_variance_cubic)
52 .sqrt(),
53 )
54 }
55}
56
57impl Default for TimeSnapshot {
58 fn default() -> Self {
59 Self {
60 precision: NtpDuration::from_exponent(-18),
61 root_delay: NtpDuration::ZERO,
62 root_variance_base_time: NtpTimestamp::default(),
63 root_variance_base: 0.0,
64 root_variance_linear: 0.0,
65 root_variance_quadratic: 0.0,
66 root_variance_cubic: 0.0,
67 leap_indicator: NtpLeapIndicator::Unknown,
68 accumulated_steps: NtpDuration::ZERO,
69 }
70 }
71}
72
/// Description of this host as a time source: stratum, reference id, timing data,
/// and the NTPv5 bloom filter / server id (from `packet::v5::server_reference_id`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct SystemSnapshot {
    /// Stratum of this host. `update_used_sources` sets it to the first used
    /// source's stratum plus one (saturating).
    pub stratum: u8,
    /// Reference id. `update_used_sources` sets it to the first used source's id.
    pub reference_id: ReferenceId,
    /// Copied from `SynchronizationConfig::accumulated_step_panic_threshold` by
    /// `update_timedata`.
    pub accumulated_steps_threshold: Option<NtpDuration>,
    /// Timing data; serde serializes its fields inline in this struct.
    #[serde(flatten)]
    pub time_snapshot: TimeSnapshot,
    /// Filter rebuilt by `update_used_sources` from the used NTP sources' bloom
    /// filters plus `server_id`. Not serialized.
    #[serde(skip)]
    pub bloom_filter: BloomFilter,
    /// Identifier of this server; always added to `bloom_filter`. Not serialized.
    #[serde(skip)]
    pub server_id: ServerId,
}
91
92impl SystemSnapshot {
93 pub fn update_timedata(&mut self, timedata: TimeSnapshot, config: &SynchronizationConfig) {
94 self.time_snapshot = timedata;
95 self.accumulated_steps_threshold = config.accumulated_step_panic_threshold;
96 }
97
98 pub fn update_used_sources(&mut self, used_sources: impl Iterator<Item = SourceSnapshot>) {
99 let mut used_sources = used_sources.peekable();
100 if let Some(system_source_snapshot) = used_sources.peek() {
101 let (stratum, source_id) = match system_source_snapshot {
102 SourceSnapshot::Ntp(snapshot) => (snapshot.stratum, snapshot.source_id),
103 SourceSnapshot::OneWay(snapshot) => (snapshot.stratum, snapshot.source_id),
104 };
105
106 self.stratum = stratum.saturating_add(1);
107 self.reference_id = source_id;
108 }
109
110 self.bloom_filter = BloomFilter::new();
111 for source in used_sources {
112 if let SourceSnapshot::Ntp(source) = source {
113 if let Some(bf) = &source.bloom_filter {
114 self.bloom_filter.add(bf);
115 } else if let ProtocolVersion::V5 = source.protocol_version {
116 tracing::warn!("Using NTPv5 source without a bloom filter!");
117 }
118 }
119 }
120 self.bloom_filter.add_id(&self.server_id);
121 }
122}
123
124impl Default for SystemSnapshot {
125 fn default() -> Self {
126 Self {
127 stratum: 16,
128 reference_id: ReferenceId::NONE,
129 accumulated_steps_threshold: None,
130 time_snapshot: TimeSnapshot::default(),
131 bloom_filter: BloomFilter::new(),
132 server_id: ServerId::default(),
133 }
134 }
135}
136
/// A controller message that the system passes on to its sources
/// (carried by `SystemAction::UpdateSources`).
///
/// `Debug` and `Clone` are derived. They are available exactly when
/// `ControllerMessage` implements the respective trait.
#[derive(Debug, Clone)]
pub struct SystemSourceUpdate<ControllerMessage> {
    /// The message produced by the time synchronization controller.
    pub message: ControllerMessage,
}
156
/// An action that the caller of `System` is asked to perform.
#[derive(Debug, Clone)]
pub enum SystemAction<ControllerMessage> {
    /// Deliver this controller message to the sources. NOTE(review): presumably
    /// broadcast to every source — confirm in the caller.
    UpdateSources(SystemSourceUpdate<ControllerMessage>),
    /// (Re)arm the system timer with this duration. `System::handle_timer` is
    /// expected to be called when it fires. It is produced from
    /// `StateUpdate::next_update`.
    SetTimer(Duration),
}
162
/// Owned iterator over the `SystemAction`s produced by one call into `System`.
#[derive(Debug)]
pub struct SystemActionIterator<ControllerMessage> {
    // Backed by a `Vec`, so actions are yielded in the order they were pushed.
    iter: <Vec<SystemAction<ControllerMessage>> as IntoIterator>::IntoIter,
}
167
168impl<ControllerMessage> Default for SystemActionIterator<ControllerMessage> {
169 fn default() -> Self {
170 Self {
171 iter: vec![].into_iter(),
172 }
173 }
174}
175
176impl<ControllerMessage> From<Vec<SystemAction<ControllerMessage>>>
177 for SystemActionIterator<ControllerMessage>
178{
179 fn from(value: Vec<SystemAction<ControllerMessage>>) -> Self {
180 Self {
181 iter: value.into_iter(),
182 }
183 }
184}
185
186impl<ControllerMessage> Iterator for SystemActionIterator<ControllerMessage> {
187 type Item = SystemAction<ControllerMessage>;
188
189 fn next(&mut self) -> Option<Self::Item> {
190 self.iter.next()
191 }
192}
193
// Builds a `SystemActionIterator` from a comma-separated list of `SystemAction`
// expressions. `actions!()` yields an empty iterator, and a trailing comma is accepted.
macro_rules! actions {
    [$($action:expr),* $(,)?] => {
        {
            SystemActionIterator::from(vec![$($action),*])
        }
    };
}
201
/// System-level synchronization state.
///
/// It owns the time synchronization controller, the latest snapshot of every
/// registered source, and the `SystemSnapshot` describing this host.
pub struct System<SourceId, Controller> {
    // Configuration passed to `new`. `local_stratum` is used when vetting NTP
    // source updates, and the whole config is passed to `update_timedata`.
    synchronization_config: SynchronizationConfig,
    // Current description of this host, returned by `system_snapshot`.
    system: SystemSnapshot,
    // Passed to `accept_synchronization` for every NTP source update and replaced
    // by `update_ip_list`. NOTE(review): presumably this host's own addresses,
    // used for loop detection — confirm.
    ip_list: Arc<[IpAddr]>,

    // Latest snapshot per registered source. It is `None` from registration
    // until the first update for that source arrives.
    sources: HashMap<SourceId, Option<SourceSnapshot>>,

    controller: Controller,
    // Set once `controller.take_control()` has succeeded, so it is called at most once.
    controller_took_control: bool,
}
212
213impl<SourceId: Hash + Eq + Copy + Debug, Controller: TimeSyncController<SourceId = SourceId>>
214 System<SourceId, Controller>
215{
216 pub fn new(
217 clock: Controller::Clock,
218 synchronization_config: SynchronizationConfig,
219 algorithm_config: Controller::AlgorithmConfig,
220 ip_list: Arc<[IpAddr]>,
221 ) -> Result<Self, <Controller::Clock as NtpClock>::Error> {
222 let mut system = SystemSnapshot {
224 stratum: synchronization_config.local_stratum,
225 ..Default::default()
226 };
227
228 if synchronization_config.local_stratum == 1 {
229 system.time_snapshot.leap_indicator = NtpLeapIndicator::NoWarning;
231 system.reference_id = synchronization_config.reference_id.to_reference_id();
233 }
234
235 Ok(System {
236 synchronization_config,
237 system,
238 ip_list,
239 sources: HashMap::new(),
240 controller: Controller::new(clock, synchronization_config, algorithm_config)?,
241 controller_took_control: false,
242 })
243 }
244
245 pub fn system_snapshot(&self) -> SystemSnapshot {
246 self.system
247 }
248
249 pub fn check_clock_access(&mut self) -> Result<(), <Controller::Clock as NtpClock>::Error> {
250 self.ensure_controller_control()
251 }
252
253 fn ensure_controller_control(&mut self) -> Result<(), <Controller::Clock as NtpClock>::Error> {
254 if !self.controller_took_control {
255 self.controller.take_control()?;
256 self.controller_took_control = true;
257 }
258 Ok(())
259 }
260
261 pub fn create_sock_source(
262 &mut self,
263 id: SourceId,
264 source_config: SourceConfig,
265 measurement_noise_estimate: f64,
266 ) -> Result<
267 OneWaySource<Controller::OneWaySourceController>,
268 <Controller::Clock as NtpClock>::Error,
269 > {
270 self.ensure_controller_control()?;
271 let controller =
272 self.controller
273 .add_one_way_source(id, source_config, measurement_noise_estimate, None);
274 self.sources.insert(id, None);
275 Ok(OneWaySource::new(controller))
276 }
277
278 pub fn create_pps_source(
279 &mut self,
280 id: SourceId,
281 source_config: SourceConfig,
282 measurement_noise_estimate: f64,
283 period: f64,
284 ) -> Result<
285 OneWaySource<Controller::OneWaySourceController>,
286 <Controller::Clock as NtpClock>::Error,
287 > {
288 self.ensure_controller_control()?;
289 let controller = self.controller.add_one_way_source(
290 id,
291 source_config,
292 measurement_noise_estimate,
293 Some(period),
294 );
295 self.sources.insert(id, None);
296 Ok(OneWaySource::new(controller))
297 }
298
299 #[expect(clippy::type_complexity)]
300 pub fn create_ntp_source(
301 &mut self,
302 id: SourceId,
303 source_config: SourceConfig,
304 source_addr: SocketAddr,
305 protocol_version: ProtocolVersion,
306 nts: Option<Box<SourceNtsData>>,
307 ) -> Result<
308 (
309 NtpSource<Controller::NtpSourceController>,
310 NtpSourceActionIterator<Controller::SourceMessage>,
311 ),
312 <Controller::Clock as NtpClock>::Error,
313 > {
314 self.ensure_controller_control()?;
315 let controller = self.controller.add_source(id, source_config);
316 self.sources.insert(id, None);
317 Ok(NtpSource::new(
318 source_addr,
319 source_config,
320 protocol_version,
321 controller,
322 nts,
323 ))
324 }
325
326 pub fn handle_source_remove(
327 &mut self,
328 id: SourceId,
329 ) -> Result<(), <Controller::Clock as NtpClock>::Error> {
330 self.controller.remove_source(id);
331 self.sources.remove(&id);
332 Ok(())
333 }
334
335 pub fn handle_source_update(
336 &mut self,
337 id: SourceId,
338 update: NtpSourceUpdate<Controller::SourceMessage>,
339 ) -> Result<
340 SystemActionIterator<Controller::ControllerMessage>,
341 <Controller::Clock as NtpClock>::Error,
342 > {
343 let usable = update
344 .snapshot
345 .accept_synchronization(
346 self.synchronization_config.local_stratum,
347 self.ip_list.as_ref(),
348 &self.system,
349 )
350 .is_ok();
351 self.controller.source_update(id, usable);
352 *self.sources.get_mut(&id).unwrap() = Some(SourceSnapshot::Ntp(update.snapshot));
353 if let Some(message) = update.message {
354 let update = self.controller.source_message(id, message);
355 Ok(self.handle_algorithm_state_update(update))
356 } else {
357 Ok(actions!())
358 }
359 }
360
361 pub fn handle_one_way_source_update(
362 &mut self,
363 id: SourceId,
364 update: OneWaySourceUpdate<Controller::SourceMessage>,
365 ) -> Result<
366 SystemActionIterator<Controller::ControllerMessage>,
367 <Controller::Clock as NtpClock>::Error,
368 > {
369 self.controller.source_update(id, true);
370 *self.sources.get_mut(&id).unwrap() = Some(SourceSnapshot::OneWay(update.snapshot));
371 if let Some(message) = update.message {
372 let update = self.controller.source_message(id, message);
373 Ok(self.handle_algorithm_state_update(update))
374 } else {
375 Ok(actions!())
376 }
377 }
378
379 fn handle_algorithm_state_update(
380 &mut self,
381 update: StateUpdate<SourceId, Controller::ControllerMessage>,
382 ) -> SystemActionIterator<Controller::ControllerMessage> {
383 let mut actions = vec![];
384 if let Some(ref used_sources) = update.used_sources {
385 self.system
386 .update_used_sources(used_sources.iter().map(|v| {
387 self.sources.get(v).and_then(|snapshot| *snapshot).expect(
388 "Critical error: Source used for synchronization that is not known to system",
389 )
390 }));
391 }
392 if let Some(time_snapshot) = update.time_snapshot {
393 self.system
394 .update_timedata(time_snapshot, &self.synchronization_config);
395 }
396 if let Some(timeout) = update.next_update {
397 actions.push(SystemAction::SetTimer(timeout));
398 }
399 if let Some(message) = update.source_message {
400 actions.push(SystemAction::UpdateSources(SystemSourceUpdate { message }));
401 }
402 actions.into()
403 }
404
405 pub fn handle_timer(&mut self) -> SystemActionIterator<Controller::ControllerMessage> {
406 tracing::debug!("Timer expired");
407 let update = self.controller.time_update();
408 self.handle_algorithm_state_update(update)
409 }
410
411 pub fn update_ip_list(&mut self, ip_list: Arc<[IpAddr]>) {
412 self.ip_list = ip_list;
413 }
414}
415
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddr};

    use crate::{NtpSourceSnapshot, Reach, time_types::PollIntervalLimits};

    use super::*;

    // Builds an NTP source snapshot that varies only in its source id and stratum.
    fn ntp_snapshot(source_id: ReferenceId, stratum: u8) -> SourceSnapshot {
        SourceSnapshot::Ntp(NtpSourceSnapshot {
            source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
            source_id,
            poll_interval: PollIntervalLimits::default().max,
            reach: Reach::never(),
            stratum,
            reference_id: ReferenceId::NONE,
            protocol_version: ProtocolVersion::v4_upgrading_to_v5_with_default_tries(),
            bloom_filter: None,
        })
    }

    // With no used sources, the stratum and reference id keep their defaults.
    #[test]
    fn test_empty_source_update() {
        let mut system = SystemSnapshot::default();
        system.update_used_sources(std::iter::empty());

        assert_eq!(system.stratum, 16);
        assert_eq!(system.reference_id, ReferenceId::NONE);
    }

    // Only the first used source determines the stratum (+1) and the reference id.
    #[test]
    fn test_source_update() {
        let used = [
            ntp_snapshot(ReferenceId::KISS_DENY, 2),
            ntp_snapshot(ReferenceId::KISS_RATE, 3),
        ];

        let mut system = SystemSnapshot::default();
        system.update_used_sources(used.into_iter());

        assert_eq!(system.stratum, 3);
        assert_eq!(system.reference_id, ReferenceId::KISS_DENY);
    }

    // Timing data is copied verbatim and the threshold comes from the config.
    #[test]
    fn test_timedata_update() {
        let mut system = SystemSnapshot::default();

        let root_delay = NtpDuration::from_seconds(1.0);
        let threshold = NtpDuration::from_seconds(2.0);
        let snapshot = TimeSnapshot {
            root_delay,
            ..Default::default()
        };
        let config = SynchronizationConfig {
            accumulated_step_panic_threshold: Some(threshold),
            ..Default::default()
        };

        system.update_timedata(snapshot, &config);

        assert_eq!(system.time_snapshot, snapshot);
        assert_eq!(system.accumulated_steps_threshold, Some(threshold));
    }
}