1use std::net::{SocketAddr, UdpSocket};
6use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
7
8use super::common::system_time_to_us;
9use crate::time_src::{OffsetMicros, TimeSource, TimeSourceError};
10
/// [`TimeSource`] that estimates the local clock offset with a single SNTP
/// (NTP client-mode) request sent to UDP port 123 of the target host.
#[derive(Debug, Clone, Copy, Default)]
pub struct NtpSource;

/// Seconds between the NTP epoch (1900-01-01T00:00:00Z) and the Unix epoch
/// (1970-01-01T00:00:00Z): 70 years of 365 days plus 17 leap days.
const NTP_TO_UNIX: u64 = 2_208_988_800;
15
16impl TimeSource for NtpSource {
17 fn name(&self) -> &'static str {
18 "ntp"
19 }
20
21 fn fetch(
22 &self,
23 target: SocketAddr,
24 timeout: Duration,
25 ) -> Result<OffsetMicros, TimeSourceError> {
26 let ntp_addr: SocketAddr = (target.ip(), 123).into();
27 fetch_ntp(ntp_addr, timeout)
28 }
29}
30
31fn fetch_ntp(addr: SocketAddr, timeout: Duration) -> Result<OffsetMicros, TimeSourceError> {
32 let socket = UdpSocket::bind(if addr.is_ipv4() {
33 "0.0.0.0:0"
34 } else {
35 "[::]:0"
36 })
37 .map_err(|e| TimeSourceError::Protocol(e.to_string()))?;
38 socket
39 .set_read_timeout(Some(timeout))
40 .map_err(|e| TimeSourceError::Protocol(e.to_string()))?;
41
42 let mut req = [0u8; 48];
44 req[0] = 0b00_100_011; let t1_sys = SystemTime::now();
48 let t1_ntp = system_time_to_ntp(t1_sys);
49 req[40..44].copy_from_slice(&t1_ntp.0.to_be_bytes());
50 req[44..48].copy_from_slice(&t1_ntp.1.to_be_bytes());
51
52 socket.connect(addr).map_err(|e| map_io_err(e, "connect"))?;
53
54 let t_send = Instant::now();
55 socket.send(&req).map_err(|e| map_io_err(e, "send"))?;
56
57 let mut buf = [0u8; 48];
58 let n = socket.recv(&mut buf).map_err(|e| map_io_err(e, "recv"))?;
59 let rtt = t_send.elapsed();
60
61 if n < 48 {
62 return Err(TimeSourceError::Parse(format!(
63 "short NTP response: {} bytes",
64 n
65 )));
66 }
67
68 let mode = buf[0] & 0x07;
69 if mode != 4 && mode != 5 {
70 return Err(TimeSourceError::Protocol(format!(
71 "unexpected NTP mode: {}",
72 mode
73 )));
74 }
75
76 let t2 = parse_ntp_timestamp(&buf[32..40])?;
78 let t3 = parse_ntp_timestamp(&buf[40..48])?;
80
81 let t4_us = system_time_to_us(t1_sys)? + rtt.as_micros() as i64;
83
84 let t1_us = system_time_to_us(t1_sys)?;
86 let offset_us = ((t2 - t1_us) + (t3 - t4_us)) / 2;
87
88 Ok(offset_us)
89}
90
91fn parse_ntp_timestamp(b: &[u8]) -> Result<i64, TimeSourceError> {
96 if b.len() < 8 {
97 return Err(TimeSourceError::Parse("NTP timestamp too short".into()));
98 }
99 let secs = u32::from_be_bytes([b[0], b[1], b[2], b[3]]) as u64;
100 let frac = u32::from_be_bytes([b[4], b[5], b[6], b[7]]);
101
102 if secs < NTP_TO_UNIX {
103 return Err(TimeSourceError::Parse(format!(
104 "NTP seconds {} predates Unix epoch",
105 secs
106 )));
107 }
108 let unix_secs = secs - NTP_TO_UNIX;
109 let frac_us = (frac as u64 * 1_000_000) >> 32;
111 i64::try_from(unix_secs * 1_000_000 + frac_us)
112 .map_err(|_| TimeSourceError::Parse("NTP timestamp overflows i64 (post-2262)".into()))
113}
114
115fn system_time_to_ntp(t: SystemTime) -> (u32, u32) {
117 let dur = t.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO);
118 let ntp_secs = (dur.as_secs() + NTP_TO_UNIX) as u32;
119 let frac = ((dur.subsec_nanos() as u64) << 32) / 1_000_000_000;
121 (ntp_secs, frac as u32)
122}
123
124fn map_io_err(e: std::io::Error, op: &str) -> TimeSourceError {
125 use std::io::ErrorKind::*;
126 match e.kind() {
127 TimedOut | WouldBlock => TimeSourceError::Timeout,
128 ConnectionRefused => TimeSourceError::Refused,
129 _ => TimeSourceError::Protocol(format!("{}: {}", op, e)),
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes an NTP `(seconds, fraction)` pair as the 8 big-endian bytes seen on the wire.
    fn encode(secs: u32, frac: u32) -> [u8; 8] {
        let mut b = [0u8; 8];
        b[0..4].copy_from_slice(&secs.to_be_bytes());
        b[4..8].copy_from_slice(&frac.to_be_bytes());
        b
    }

    #[test]
    fn parse_known_ntp_timestamp() {
        // 3_913_958_400 NTP seconds minus NTP_TO_UNIX is 1_704_969_600 Unix seconds.
        let us = parse_ntp_timestamp(&encode(3_913_958_400, 0)).unwrap();
        assert_eq!(us, 1_704_969_600 * 1_000_000);
    }

    #[test]
    fn parse_ntp_with_fraction() {
        // The Unix epoch plus half a second (fraction 2^31 / 2^32).
        let us = parse_ntp_timestamp(&encode(NTP_TO_UNIX as u32, 1 << 31)).unwrap();
        assert_eq!(us, 500_000);
    }

    #[test]
    fn parse_rejects_short_input() {
        assert!(parse_ntp_timestamp(&[0u8; 7]).is_err());
        assert!(parse_ntp_timestamp(&[]).is_err());
    }

    #[test]
    fn parse_rejects_zero_timestamp() {
        assert!(parse_ntp_timestamp(&[0u8; 8]).is_err());
    }

    #[test]
    fn parse_rejects_pre_unix_epoch() {
        // One second before the Unix epoch, still within NTP era 0.
        assert!(parse_ntp_timestamp(&encode(NTP_TO_UNIX as u32 - 1, 0)).is_err());
    }

    #[test]
    fn unix_epoch_to_ntp() {
        assert_eq!(system_time_to_ntp(UNIX_EPOCH), (NTP_TO_UNIX as u32, 0));
        let half_second = UNIX_EPOCH + Duration::from_millis(500);
        assert_eq!(system_time_to_ntp(half_second), (NTP_TO_UNIX as u32, 1 << 31));
    }

    #[test]
    fn roundtrip_ntp_conversion() {
        let now = SystemTime::now();
        let (secs, frac) = system_time_to_ntp(now);
        let us = parse_ntp_timestamp(&encode(secs, frac)).unwrap();
        let expected = system_time_to_us(now).unwrap();
        assert!(
            (us - expected).abs() < 1000,
            "roundtrip error: {}us",
            us - expected
        );
    }

    #[test]
    fn io_errors_map_to_expected_variants() {
        use std::io::{Error, ErrorKind};

        assert!(matches!(
            map_io_err(Error::from(ErrorKind::TimedOut), "recv"),
            TimeSourceError::Timeout
        ));
        assert!(matches!(
            map_io_err(Error::from(ErrorKind::WouldBlock), "recv"),
            TimeSourceError::Timeout
        ));
        assert!(matches!(
            map_io_err(Error::from(ErrorKind::ConnectionRefused), "recv"),
            TimeSourceError::Refused
        ));
        match map_io_err(Error::from(ErrorKind::InvalidData), "send") {
            TimeSourceError::Protocol(msg) => assert!(msg.starts_with("send: "), "{}", msg),
            _ => panic!("expected a Protocol error"),
        }
    }
}