1use std::net::{SocketAddr, UdpSocket};
6use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
7
8use super::common::{map_io_err, system_time_to_us};
9use crate::time_src::{OffsetMicros, TimeSource, TimeSourceError};
10
/// Time source that measures the local clock offset with a single SNTP query.
///
/// The type carries no state; all work happens in its `TimeSource` implementation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NtpSource;
12
/// Seconds between the NTP epoch (1900-01-01) and the Unix epoch (1970-01-01).
///
/// Added to Unix seconds to produce NTP seconds, and subtracted to go back.
const NTP_TO_UNIX: u64 = 2_208_988_800;
15
16impl TimeSource for NtpSource {
17 fn name(&self) -> &'static str {
18 "ntp"
19 }
20
21 fn fetch(
22 &self,
23 target: SocketAddr,
24 timeout: Duration,
25 ) -> Result<OffsetMicros, TimeSourceError> {
26 let ntp_addr: SocketAddr = (target.ip(), 123).into();
27 fetch_ntp(ntp_addr, timeout)
28 }
29}
30
31fn fetch_ntp(addr: SocketAddr, timeout: Duration) -> Result<OffsetMicros, TimeSourceError> {
32 let socket = UdpSocket::bind(if addr.is_ipv4() {
33 "0.0.0.0:0"
34 } else {
35 "[::]:0"
36 })
37 .map_err(|e| TimeSourceError::Protocol(e.to_string()))?;
38 socket
39 .set_read_timeout(Some(timeout))
40 .map_err(|e| TimeSourceError::Protocol(e.to_string()))?;
41
42 let mut req = [0u8; 48];
44 req[0] = 0b00_100_011; let t1_sys = SystemTime::now();
48 let t1_ntp = system_time_to_ntp(t1_sys);
49 req[40..44].copy_from_slice(&t1_ntp.0.to_be_bytes());
50 req[44..48].copy_from_slice(&t1_ntp.1.to_be_bytes());
51
52 socket.connect(addr).map_err(|e| map_io_err(e, "connect"))?;
53
54 let t_send = Instant::now();
55 socket.send(&req).map_err(|e| map_io_err(e, "send"))?;
56
57 let mut buf = [0u8; 48];
58 let n = socket.recv(&mut buf).map_err(|e| map_io_err(e, "recv"))?;
59 let rtt = t_send.elapsed();
60
61 if n < 48 {
62 return Err(TimeSourceError::Parse(format!(
63 "short NTP response: {} bytes",
64 n
65 )));
66 }
67
68 let mode = buf[0] & 0x07;
69 if mode != 4 && mode != 5 {
70 return Err(TimeSourceError::Protocol(format!(
71 "unexpected NTP mode: {}",
72 mode
73 )));
74 }
75
76 let t2 = parse_ntp_timestamp(&buf[32..40])?;
78 let t3 = parse_ntp_timestamp(&buf[40..48])?;
80
81 let t4_us = system_time_to_us(t1_sys)? + rtt.as_micros() as i64;
83
84 let t1_us = system_time_to_us(t1_sys)?;
86 let offset_us = ((t2 - t1_us) + (t3 - t4_us)) / 2;
87
88 Ok(offset_us)
89}
90
91fn parse_ntp_timestamp(b: &[u8]) -> Result<i64, TimeSourceError> {
96 if b.len() < 8 {
97 return Err(TimeSourceError::Parse("NTP timestamp too short".into()));
98 }
99 let secs = u32::from_be_bytes([b[0], b[1], b[2], b[3]]) as u64;
100 let frac = u32::from_be_bytes([b[4], b[5], b[6], b[7]]);
101
102 if secs < NTP_TO_UNIX {
103 return Err(TimeSourceError::Parse(format!(
104 "NTP seconds {} predates Unix epoch",
105 secs
106 )));
107 }
108 let unix_secs = secs - NTP_TO_UNIX;
109 let frac_us = (frac as u64 * 1_000_000) >> 32;
111 i64::try_from(unix_secs * 1_000_000 + frac_us)
112 .map_err(|_| TimeSourceError::Parse("NTP timestamp overflows i64 (post-2262)".into()))
113}
114
115fn system_time_to_ntp(t: SystemTime) -> (u32, u32) {
117 let dur = t.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO);
118 let ntp_secs = (dur.as_secs() + NTP_TO_UNIX) as u32;
119 let frac = ((dur.subsec_nanos() as u64) << 32) / 1_000_000_000;
121 (ntp_secs, frac as u32)
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Packs a seconds/fraction pair into the 8-byte big-endian wire layout.
    fn wire(secs: u32, frac: u32) -> [u8; 8] {
        let mut out = [0u8; 8];
        out[0..4].copy_from_slice(&secs.to_be_bytes());
        out[4..8].copy_from_slice(&frac.to_be_bytes());
        out
    }

    /// A whole-second NTP value maps onto the expected Unix microsecond count.
    #[test]
    fn parse_known_ntp_timestamp() {
        let encoded = wire(3_913_056_000, 0);
        let us = parse_ntp_timestamp(&encoded).unwrap();
        assert_eq!(us, 1_704_067_200 * 1_000_000);
    }

    /// A fraction with only the top bit set is exactly half a second.
    #[test]
    fn parse_ntp_with_fraction() {
        let encoded = wire(NTP_TO_UNIX as u32, 1 << 31);
        let us = parse_ntp_timestamp(&encoded).unwrap();
        assert_eq!(us, 500_000);
    }

    /// Encoding the current time and decoding it again loses under a millisecond.
    #[test]
    fn roundtrip_ntp_conversion() {
        let now = SystemTime::now();
        let (secs, frac) = system_time_to_ntp(now);
        let encoded = wire(secs, frac);
        let us = parse_ntp_timestamp(&encoded).unwrap();
        let expected = system_time_to_us(now).unwrap();
        assert!(
            (us - expected).abs() < 1000,
            "roundtrip error: {}us",
            us - expected
        );
    }

    proptest! {
        /// Arbitrary input of 0 to 15 bytes must yield a `Result`, never a panic.
        #[test]
        fn parse_ntp_timestamp_never_panics(data in proptest::collection::vec(any::<u8>(), 0..16)) {
            let _ = parse_ntp_timestamp(&data);
        }
    }
}