rsntp/
core_logic.rs

1use crate::error::{KissCode, ProtocolError, SynchronizationError};
2use crate::packet::{LeapIndicator, Mode, Packet, ReferenceIdentifier, SntpTimestamp};
3use crate::result::SynchronizationResult;
4use std::time::SystemTime;
5
6pub struct Request {
7    packet: Packet,
8}
9
10impl Request {
11    pub fn new() -> Request {
12        Self::new_with_transmit_time(SystemTime::now())
13    }
14
15    pub fn new_with_transmit_time(transmit_time: SystemTime) -> Request {
16        Request {
17            packet: Packet {
18                li: LeapIndicator::NoWarning,
19                mode: Mode::Client,
20                stratum: 0,
21                reference_identifier: ReferenceIdentifier::Empty,
22                reference_timestamp: SntpTimestamp::zero(),
23                originate_timestamp: SntpTimestamp::zero(),
24                receive_timestamp: SntpTimestamp::zero(),
25                transmit_timestamp: SntpTimestamp::from_systemtime(transmit_time),
26            },
27        }
28    }
29
30    pub fn as_bytes(&self) -> [u8; Packet::ENCODED_LEN] {
31        self.packet.to_bytes()
32    }
33
34    fn into_packet(self) -> Packet {
35        self.packet
36    }
37}
38
39pub struct Reply {
40    request: Packet,
41    reply: Packet,
42    reply_timestamp: SntpTimestamp,
43}
44
45impl Reply {
46    pub fn new(request: Request, reply: Packet) -> Reply {
47        Self::new_with_reply_time(request, reply, SystemTime::now())
48    }
49
50    pub fn new_with_reply_time(request: Request, reply: Packet, reply_time: SystemTime) -> Reply {
51        Reply {
52            request: request.into_packet(),
53            reply,
54            reply_timestamp: SntpTimestamp::from_systemtime(reply_time),
55        }
56    }
57
58    fn check(&self) -> Result<(), ProtocolError> {
59        if self.reply.stratum == 0 {
60            return Err(ProtocolError::KissODeath(KissCode::new(
61                &self.reply.reference_identifier,
62            )));
63        }
64
65        if self.reply.originate_timestamp != self.request.transmit_timestamp {
66            return Err(ProtocolError::InvalidOriginateTimestamp);
67        }
68
69        if self.reply.transmit_timestamp.is_zero() {
70            return Err(ProtocolError::InvalidTransmitTimestamp);
71        }
72
73        if self.reply.mode != Mode::Server && self.reply.mode != Mode::Broadcast {
74            return Err(ProtocolError::InvalidMode);
75        }
76        Ok(())
77    }
78
79    pub fn process(self) -> Result<SynchronizationResult, SynchronizationError> {
80        self.check()?;
81
82        let originate_ts = self.reply.originate_timestamp;
83        let transmit_ts = self.reply.transmit_timestamp;
84        let receive_ts = self.reply.receive_timestamp;
85        let round_trip_delay_s = (self.reply_timestamp - originate_ts) - (transmit_ts - receive_ts);
86        let clock_offset_s =
87            ((receive_ts - originate_ts) + (transmit_ts - self.reply_timestamp)) / 2.0;
88        Ok(SynchronizationResult::new(
89            clock_offset_s,
90            round_trip_delay_s,
91            self.reply.reference_identifier.clone(),
92            self.reply.li,
93            self.reply.stratum,
94        ))
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Panics unless `$lower <= $var <= $upper`.
    ///
    /// Each argument is evaluated exactly once, so expressions with side
    /// effects or non-trivial cost are safe to pass.
    macro_rules! assert_between {
        ($var: expr, $lower: expr, $upper: expr) => {{
            let (value, lower, upper) = ($var, $lower, $upper);
            if value < lower || value > upper {
                panic!(
                    "Assertion failed, {:?} is not between {:?} and {:?}",
                    value, lower, upper
                );
            }
        }};
    }

    /// Builds a reply packet that passes every validation check for `request`.
    ///
    /// The packet is a stratum 1, server-mode reply with reference identifier
    /// "LOCL". Its receive and transmit timestamps are both `now - server_lag`
    /// and its originate timestamp echoes the request's transmit timestamp.
    /// Tests overwrite individual fields to provoke one specific failure.
    fn well_formed_reply(request: &Request, now: SystemTime, server_lag: Duration) -> Packet {
        let server_time = SntpTimestamp::from_systemtime(now - server_lag);

        Packet {
            li: LeapIndicator::NoWarning,
            mode: Mode::Server,
            stratum: 1,
            // ASCII "LOCL"
            reference_identifier: ReferenceIdentifier::new_ascii([0x4c, 0x4f, 0x43, 0x4c]).unwrap(),
            reference_timestamp: SntpTimestamp::from_systemtime(now - Duration::from_secs(86400)),
            originate_timestamp: request.packet.transmit_timestamp,
            receive_timestamp: server_time,
            transmit_timestamp: server_time,
        }
    }

    /// Processes `reply` and returns the protocol error it was rejected with.
    ///
    /// Panics if processing succeeds or fails with anything other than a
    /// protocol error, so a test cannot pass for the wrong reason.
    fn protocol_error_from(reply: Reply) -> ProtocolError {
        match reply.process() {
            Err(SynchronizationError::ProtocolError(error)) => error,
            _ => panic!("Expected synchronization to fail with a protocol error"),
        }
    }

    #[test]
    fn basic_synchronization_works() {
        let now = SystemTime::now();
        let request = Request::new_with_transmit_time(now);

        // Server clock is 400 ms behind and the reply arrives 200 ms after the
        // request was sent: expected offset -0.5 s, round-trip delay 0.2 s.
        let reply_packet = well_formed_reply(&request, now, Duration::from_millis(400));
        let reply =
            Reply::new_with_reply_time(request, reply_packet, now + Duration::from_millis(200));

        let result = reply.process().unwrap();

        assert_between!(result.clock_offset().as_secs_f64(), -0.51, -0.49);
        assert_between!(result.round_trip_delay().as_secs_f64(), 0.19, 0.21);

        assert_eq!(result.reference_identifier().to_string(), "LOCL");
        assert_eq!(result.leap_indicator(), LeapIndicator::NoWarning);
        assert_eq!(result.stratum(), 1);
    }

    #[test]
    fn sync_fails_if_reply_originate_ts_does_not_match_request_transmit_ts() {
        let now = SystemTime::now();
        let request = Request::new_with_transmit_time(now);

        let mut reply_packet = well_formed_reply(&request, now, Duration::from_millis(500));
        // A full second of skew guarantees a mismatch; relying on two
        // back-to-back clock reads to differ would make the test depend on the
        // clock's granularity.
        reply_packet.originate_timestamp =
            SntpTimestamp::from_systemtime(now + Duration::from_secs(1));

        let reply = Reply::new(request, reply_packet);

        match protocol_error_from(reply) {
            ProtocolError::InvalidOriginateTimestamp => {}
            _ => panic!("Wrong error received"),
        }
    }

    #[test]
    fn sync_fails_if_reply_contains_zero_transmit_timestamp() {
        let request = Request::new();
        let now = SystemTime::now();

        let mut reply_packet = well_formed_reply(&request, now, Duration::from_millis(500));
        reply_packet.transmit_timestamp = SntpTimestamp::zero();

        let reply = Reply::new(request, reply_packet);

        match protocol_error_from(reply) {
            ProtocolError::InvalidTransmitTimestamp => {}
            _ => panic!("Wrong error received"),
        }
    }

    #[test]
    fn sync_fails_if_reply_contains_wrong_mode() {
        let request = Request::new();
        let now = SystemTime::now();

        let mut reply_packet = well_formed_reply(&request, now, Duration::from_millis(500));
        reply_packet.mode = Mode::Client;

        let reply = Reply::new(request, reply_packet);

        match protocol_error_from(reply) {
            ProtocolError::InvalidMode => {}
            _ => panic!("Wrong error received"),
        }
    }

    #[test]
    fn sync_fails_if_kiss_o_death_received() {
        let request = Request::new();
        let now = SystemTime::now();

        let mut reply_packet = well_formed_reply(&request, now, Duration::from_millis(500));
        reply_packet.stratum = 0;
        // ASCII "RATE"
        reply_packet.reference_identifier =
            ReferenceIdentifier::new_ascii([0x52, 0x41, 0x54, 0x45]).unwrap();

        let reply = Reply::new(request, reply_packet);

        match protocol_error_from(reply) {
            ProtocolError::KissODeath(KissCode::RateExceeded) => {}
            _ => panic!("Wrong error received"),
        }
    }
}