ntp_udp/
socket.rs

1#![forbid(unsafe_code)]
2
3use std::{io, net::SocketAddr};
4
5use ntp_proto::NtpTimestamp;
6use tokio::io::{unix::AsyncFd, Interest};
7use tracing::instrument;
8
9use crate::{
10    interface::InterfaceName,
11    raw_socket::{
12        control_message_space, receive_message, set_timestamping_options, ControlMessage,
13        MessageQueue, TimestampMethod,
14    },
15    EnableTimestamps,
16};
17
18pub struct UdpSocket {
19    io: AsyncFd<std::net::UdpSocket>,
20    send_counter: u32,
21    timestamping: EnableTimestamps,
22}
23
/// Timestamping mechanism used by the public constructors when none is chosen explicitly.
///
/// On Linux this is `TimestampMethod::SoTimestamping`.
#[cfg(target_os = "linux")]
const DEFAULT_TIMESTAMP_METHOD: TimestampMethod = TimestampMethod::SoTimestamping;

/// Timestamping mechanism used by the public constructors when none is chosen explicitly.
///
/// On Unix platforms other than Linux this falls back to `TimestampMethod::SoTimestamp`.
#[cfg(all(unix, not(target_os = "linux")))]
const DEFAULT_TIMESTAMP_METHOD: TimestampMethod = TimestampMethod::SoTimestamp;
29
30impl UdpSocket {
31    #[instrument(level = "debug", skip(peer_addr))]
32    pub async fn client(listen_addr: SocketAddr, peer_addr: SocketAddr) -> io::Result<UdpSocket> {
33        Self::client_with_timestamping(
34            listen_addr,
35            peer_addr,
36            InterfaceName::DEFAULT,
37            EnableTimestamps::default(),
38        )
39        .await
40    }
41
42    pub async fn client_with_timestamping(
43        listen_addr: SocketAddr,
44        peer_addr: SocketAddr,
45        interface: Option<InterfaceName>,
46        timestamping: EnableTimestamps,
47    ) -> io::Result<UdpSocket> {
48        Self::client_with_timestamping_internal(
49            listen_addr,
50            peer_addr,
51            interface,
52            DEFAULT_TIMESTAMP_METHOD,
53            timestamping,
54        )
55        .await
56    }
57
58    async fn client_with_timestamping_internal(
59        listen_addr: SocketAddr,
60        peer_addr: SocketAddr,
61        interface: Option<InterfaceName>,
62        method: TimestampMethod,
63        timestamping: EnableTimestamps,
64    ) -> io::Result<UdpSocket> {
65        let socket = tokio::net::UdpSocket::bind(listen_addr).await?;
66        tracing::debug!(
67            local_addr = ?socket.local_addr().unwrap(),
68            "client socket bound"
69        );
70
71        // bind the socket to a specific interface. This is relevant for hardware timestamping,
72        // because the interface determines which clock is used to produce the timestamps.
73        if let Some(_interface) = interface {
74            #[cfg(target_os = "linux")]
75            socket.bind_device(Some(&_interface)).unwrap();
76        }
77
78        socket.connect(peer_addr).await?;
79        tracing::debug!(
80            local_addr = ?socket.local_addr().unwrap(),
81            peer_addr = ?socket.peer_addr().unwrap(),
82            "client socket connected"
83        );
84
85        let socket = socket.into_std()?;
86
87        set_timestamping_options(&socket, method, timestamping)?;
88
89        Ok(UdpSocket {
90            io: AsyncFd::new(socket)?,
91            send_counter: 0,
92            timestamping,
93        })
94    }
95
96    #[instrument(level = "debug")]
97    pub async fn server(
98        listen_addr: SocketAddr,
99        interface: Option<InterfaceName>,
100    ) -> io::Result<UdpSocket> {
101        let socket = tokio::net::UdpSocket::bind(listen_addr).await?;
102        tracing::debug!(
103            local_addr = ?socket.local_addr().unwrap(),
104            "server socket bound"
105        );
106
107        // bind the socket to a specific interface. This is relevant for hardware timestamping,
108        // because the interface determines which clock is used to produce the timestamps.
109        if let Some(_interface) = interface {
110            #[cfg(target_os = "linux")]
111            socket.bind_device(Some(&_interface)).unwrap();
112        }
113
114        let socket = socket.into_std()?;
115
116        // our supported kernel versions always have receive timestamping. Send timestamping for a
117        // server connection is not relevant, so we don't even bother with checking if it is supported
118        let timestamping = EnableTimestamps {
119            rx_software: true,
120            tx_software: false,
121            rx_hardware: false,
122            tx_hardware: false,
123        };
124
125        set_timestamping_options(&socket, DEFAULT_TIMESTAMP_METHOD, timestamping)?;
126
127        Ok(UdpSocket {
128            io: AsyncFd::new(socket)?,
129            send_counter: 0,
130            timestamping,
131        })
132    }
133
134    #[instrument(level = "trace", skip(self, buf), fields(
135        local_addr = debug(self.as_ref().local_addr().unwrap()),
136        peer_addr = debug(self.as_ref().peer_addr()),
137        buf_size = buf.len(),
138    ))]
139    pub async fn send(&mut self, buf: &[u8]) -> io::Result<(usize, Option<NtpTimestamp>)> {
140        tracing::trace!(size = buf.len(), "sending bytes");
141
142        let result = self
143            .io
144            .async_io(Interest::WRITABLE, |inner| inner.send(buf))
145            .await;
146
147        let send_size = match result {
148            Ok(size) => {
149                tracing::trace!(sent = size, "sent bytes");
150                size
151            }
152            Err(e) => {
153                tracing::debug!(error = debug(&e), "error sending data");
154                return Err(e);
155            }
156        };
157
158        debug_assert_eq!(buf.len(), send_size);
159
160        let expected_counter = self.send_counter;
161        self.send_counter = self.send_counter.wrapping_add(1);
162
163        if self.timestamping.tx_software {
164            #[cfg(target_os = "linux")]
165            {
166                // the send timestamp may never come set a very short timeout to prevent hanging forever.
167                // We automatically fall back to a less accurate timestamp when this function returns None
168                let timeout = std::time::Duration::from_millis(10);
169                match tokio::time::timeout(timeout, self.fetch_send_timestamp(expected_counter))
170                    .await
171                {
172                    Err(_) => {
173                        tracing::warn!("Packet without timestamp");
174                        Ok((send_size, None))
175                    }
176                    Ok(send_timestamp) => Ok((send_size, Some(send_timestamp?))),
177                }
178            }
179
180            #[cfg(any(target_os = "macos", target_os = "freebsd"))]
181            {
182                let _ = expected_counter;
183                Ok((send_size, None))
184            }
185        } else {
186            tracing::trace!("send timestamping not supported");
187            Ok((send_size, None))
188        }
189    }
190
191    #[cfg(target_os = "linux")]
192    async fn fetch_send_timestamp(&self, expected_counter: u32) -> io::Result<NtpTimestamp> {
193        let msg = "waiting for timestamp socket to become readable to fetch a send timestamp";
194        tracing::trace!(msg);
195
196        let try_read = |udp_socket: &std::net::UdpSocket| {
197            fetch_send_timestamp_help(udp_socket, expected_counter)
198        };
199
200        loop {
201            // the timestamp being available triggers the error interest
202            match self.io.async_io(Interest::ERROR, try_read).await? {
203                Some(timestamp) => return Ok(timestamp),
204                None => continue,
205            };
206        }
207    }
208
209    #[instrument(level = "trace", skip(self, buf), fields(
210        local_addr = debug(self.as_ref().local_addr().unwrap()),
211        buf_size = buf.len(),
212    ))]
213    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> io::Result<usize> {
214        tracing::trace!(size = buf.len(), ?addr, "sending bytes");
215
216        let result = self
217            .io
218            .async_io(Interest::WRITABLE, |inner| inner.send_to(buf, addr))
219            .await;
220
221        match &result {
222            Ok(size) => tracing::trace!(sent = size, "sent bytes"),
223            Err(e) => tracing::debug!(error = debug(e), "error sending data"),
224        }
225
226        result
227    }
228
229    #[instrument(level = "trace", skip(self, buf), fields(
230        local_addr = debug(self.as_ref().local_addr().unwrap()),
231        peer_addr = debug(self.as_ref().peer_addr().ok()),
232        buf_size = buf.len(),
233    ))]
234    pub async fn recv(
235        &self,
236        buf: &mut [u8],
237    ) -> io::Result<(usize, SocketAddr, Option<NtpTimestamp>)> {
238        tracing::trace!("waiting for socket to become readable");
239
240        let result = self
241            .io
242            .async_io(Interest::READABLE, |inner| recv(inner, buf))
243            .await;
244
245        match &result {
246            Ok((size, addr, ts)) => {
247                tracing::trace!(size, ts = ?ts, addr = ?addr, "received message");
248            }
249            Err(e) => tracing::debug!(error = ?e, "error receiving data"),
250        }
251
252        result
253    }
254}
255
256impl AsRef<std::net::UdpSocket> for UdpSocket {
257    fn as_ref(&self) -> &std::net::UdpSocket {
258        self.io.get_ref()
259    }
260}
261
262fn recv(
263    socket: &std::net::UdpSocket,
264    buf: &mut [u8],
265) -> io::Result<(usize, SocketAddr, Option<NtpTimestamp>)> {
266    let mut control_buf = [0; control_message_space::<[libc::timespec; 3]>()];
267
268    // loops for when we receive an interrupt during the recv
269    let (bytes_read, control_messages, sock_addr) =
270        receive_message(socket, buf, &mut control_buf, MessageQueue::Normal)?;
271    let sock_addr =
272        sock_addr.unwrap_or_else(|| unreachable!("We never constructed a non-ip socket"));
273
274    // Loops through the control messages, but we should only get a single message in practice
275    for msg in control_messages {
276        match msg {
277            ControlMessage::Timestamping(libc_timestamp) => {
278                let ntp_timestamp = libc_timestamp.into_ntp_timestamp();
279                return Ok((bytes_read as usize, sock_addr, Some(ntp_timestamp)));
280            }
281
282            #[cfg(target_os = "linux")]
283            ControlMessage::ReceiveError(_error) => {
284                tracing::warn!("unexpected error message on the MSG_ERRQUEUE");
285            }
286
287            ControlMessage::Other(msg) => {
288                tracing::warn!(
289                    "weird control message {:?} {:?}",
290                    msg.cmsg_level,
291                    msg.cmsg_type
292                );
293            }
294        }
295    }
296
297    Ok((bytes_read as usize, sock_addr, None))
298}
299
300#[cfg(target_os = "linux")]
301fn fetch_send_timestamp_help(
302    socket: &std::net::UdpSocket,
303    expected_counter: u32,
304) -> io::Result<Option<NtpTimestamp>> {
305    // we get back two control messages: one with the timestamp (just like a receive timestamp),
306    // and one error message with no error reason. The payload for this second message is kind of
307    // undocumented.
308    //
309    // section 2.1.1 of https://www.kernel.org/doc/Documentation/networking/timestamping.txt says that
310    // a `sock_extended_err` is returned, but in practice we also see a socket address. The linux
311    // kernel also has this https://github.com/torvalds/linux/blob/master/tools/testing/selftests/net/so_txtime.c#L153=
312    //
313    // sockaddr_storage is bigger than we need, but sockaddr is too small for ipv6
314    const CONTROL_SIZE: usize = control_message_space::<[libc::timespec; 3]>()
315        + control_message_space::<(libc::sock_extended_err, libc::sockaddr_storage)>();
316
317    let mut control_buf = [0; CONTROL_SIZE];
318
319    let (_, control_messages, _) =
320        receive_message(socket, &mut [], &mut control_buf, MessageQueue::Error)?;
321
322    let mut send_ts = None;
323    for msg in control_messages {
324        match msg {
325            ControlMessage::Timestamping(timestamp) => {
326                send_ts = Some(timestamp);
327            }
328
329            ControlMessage::ReceiveError(error) => {
330                // the timestamping does not set a message; if there is a message, that means
331                // something else is wrong, and we want to know about it.
332                if error.ee_errno as libc::c_int != libc::ENOMSG {
333                    tracing::warn!(
334                        expected_counter,
335                        error.ee_data,
336                        "error message on the MSG_ERRQUEUE"
337                    );
338                }
339
340                // Check that this message belongs to the send we are interested in
341                if error.ee_data != expected_counter {
342                    tracing::debug!(
343                        error.ee_data,
344                        expected_counter,
345                        "Timestamp for unrelated packet"
346                    );
347                    return Ok(None);
348                }
349            }
350
351            ControlMessage::Other(msg) => {
352                tracing::warn!(
353                    msg.cmsg_level,
354                    msg.cmsg_type,
355                    "unexpected message on the MSG_ERRQUEUE",
356                );
357            }
358        }
359    }
360
361    Ok(send_ts.map(|ts| ts.into_ntp_timestamp()))
362}
363
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;

    #[tokio::test]
    async fn test_client_basic_ipv4() {
        let mut a = UdpSocket::client(
            "127.0.0.1:10000".parse().unwrap(),
            "127.0.0.1:10001".parse().unwrap(),
        )
        .await
        .unwrap();
        let mut b = UdpSocket::client(
            "127.0.0.1:10001".parse().unwrap(),
            "127.0.0.1:10000".parse().unwrap(),
        )
        .await
        .unwrap();

        a.send(&[1; 48]).await.unwrap();
        let mut buf = [0; 48];
        let (size, addr, _) = b.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, "127.0.0.1:10000".parse().unwrap());
        assert_eq!(buf, [1; 48]);

        b.send(&[2; 48]).await.unwrap();
        let (size, addr, _) = a.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, "127.0.0.1:10001".parse().unwrap());
        assert_eq!(buf, [2; 48]);
    }

    #[tokio::test]
    async fn test_client_basic_ipv6() {
        let mut a = UdpSocket::client(
            "[::1]:10000".parse().unwrap(),
            "[::1]:10001".parse().unwrap(),
        )
        .await
        .unwrap();
        let mut b = UdpSocket::client(
            "[::1]:10001".parse().unwrap(),
            "[::1]:10000".parse().unwrap(),
        )
        .await
        .unwrap();

        a.send(&[1; 48]).await.unwrap();
        let mut buf = [0; 48];
        let (size, addr, _) = b.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, "[::1]:10000".parse().unwrap());
        assert_eq!(buf, [1; 48]);

        b.send(&[2; 48]).await.unwrap();
        let (size, addr, _) = a.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, "[::1]:10001".parse().unwrap());
        assert_eq!(buf, [2; 48]);
    }

    #[tokio::test]
    async fn test_server_basic_ipv4() {
        let a = UdpSocket::server("127.0.0.1:10002".parse().unwrap(), InterfaceName::DEFAULT)
            .await
            .unwrap();
        let mut b = UdpSocket::client(
            "127.0.0.1:10003".parse().unwrap(),
            "127.0.0.1:10002".parse().unwrap(),
        )
        .await
        .unwrap();

        b.send(&[1; 48]).await.unwrap();
        let mut buf = [0; 48];
        let (size, addr, _) = a.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, "127.0.0.1:10003".parse().unwrap());
        assert_eq!(buf, [1; 48]);

        a.send_to(&[2; 48], addr).await.unwrap();
        let (size, addr, _) = b.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, "127.0.0.1:10002".parse().unwrap());
        assert_eq!(buf, [2; 48]);
    }

    #[tokio::test]
    async fn test_server_basic_ipv6() {
        let a = UdpSocket::server("[::1]:10002".parse().unwrap(), InterfaceName::DEFAULT)
            .await
            .unwrap();
        let mut b = UdpSocket::client(
            "[::1]:10003".parse().unwrap(),
            "[::1]:10002".parse().unwrap(),
        )
        .await
        .unwrap();

        b.send(&[1; 48]).await.unwrap();
        let mut buf = [0; 48];
        let (size, addr, _) = a.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, "[::1]:10003".parse().unwrap());
        assert_eq!(buf, [1; 48]);

        a.send_to(&[2; 48], addr).await.unwrap();
        let (size, addr, _) = b.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, "[::1]:10002".parse().unwrap());
        assert_eq!(buf, [2; 48]);
    }

    /// Sends two packets 200ms apart from a default client to a client using `method` for
    /// receive timestamping, and checks that the receive timestamps are roughly 200ms apart.
    async fn timestamping_reasonable(method: TimestampMethod, p1: u16, p2: u16) {
        let mut a = UdpSocket::client(
            SocketAddr::from((Ipv4Addr::LOCALHOST, p1)),
            SocketAddr::from((Ipv4Addr::LOCALHOST, p2)),
        )
        .await
        .unwrap();
        let b = UdpSocket::client_with_timestamping_internal(
            SocketAddr::from((Ipv4Addr::LOCALHOST, p2)),
            SocketAddr::from((Ipv4Addr::LOCALHOST, p1)),
            InterfaceName::DEFAULT,
            method,
            EnableTimestamps {
                rx_software: true,
                tx_software: true,
                rx_hardware: false,
                tx_hardware: false,
            },
        )
        .await
        .unwrap();

        // Drive the sender and the receiver concurrently on this task rather than detaching the
        // sender with `tokio::spawn`. That way a panic in the sender fails the test, instead of
        // being swallowed while the receiver waits forever for a packet that never comes.
        let sender = async move {
            a.send(&[1; 48]).await.unwrap();
            tokio::time::sleep(std::time::Duration::from_millis(200)).await;
            a.send(&[2; 48]).await.unwrap();
        };
        let receiver = async {
            let mut buf = [0; 48];
            let first = b.recv(&mut buf).await.unwrap();
            let second = b.recv(&mut buf).await.unwrap();
            (first, second)
        };
        let ((), ((s1, _, t1), (s2, _, t2))) = tokio::join!(sender, receiver);

        assert_eq!(s1, 48);
        assert_eq!(s2, 48);

        let t1 = t1.unwrap();
        let t2 = t2.unwrap();
        let delta = t2 - t1;

        // this can be flaky on freebsd
        assert!(
            delta.to_seconds() > 0.15 && delta.to_seconds() < 0.25,
            "delta was {}s",
            delta.to_seconds()
        );
    }

    #[tokio::test]
    #[cfg(target_os = "linux")]
    async fn timestamping_reasonable_so_timestamping() {
        timestamping_reasonable(TimestampMethod::SoTimestamping, 8000, 8001).await;
    }

    #[tokio::test]
    #[cfg(target_os = "linux")]
    async fn timestamping_reasonable_so_timestampns() {
        timestamping_reasonable(TimestampMethod::SoTimestampns, 8002, 8003).await;
    }

    #[tokio::test]
    #[cfg(unix)]
    async fn timestamping_reasonable_so_timestamp() {
        timestamping_reasonable(TimestampMethod::SoTimestamp, 8004, 8005).await;
    }

    #[tokio::test]
    #[cfg_attr(
        any(target_os = "macos", target_os = "freebsd"),
        ignore = "send timestamps are not supported"
    )]
    async fn test_send_timestamp() {
        let mut a = UdpSocket::client_with_timestamping(
            SocketAddr::from((Ipv4Addr::LOCALHOST, 8012)),
            SocketAddr::from((Ipv4Addr::LOCALHOST, 8013)),
            InterfaceName::DEFAULT,
            EnableTimestamps {
                rx_software: true,
                tx_software: true,
                rx_hardware: false,
                tx_hardware: false,
            },
        )
        .await
        .unwrap();
        let b = UdpSocket::client(
            SocketAddr::from((Ipv4Addr::LOCALHOST, 8013)),
            SocketAddr::from((Ipv4Addr::LOCALHOST, 8012)),
        )
        .await
        .unwrap();

        let (ssend, tsend) = a.send(&[1; 48]).await.unwrap();
        let mut buf = [0; 48];
        let (srecv, _, trecv) = b.recv(&mut buf).await.unwrap();

        assert_eq!(ssend, 48);
        assert_eq!(srecv, 48);

        let tsend = tsend.unwrap();
        let trecv = trecv.unwrap();
        let delta = trecv - tsend;
        assert!(delta.to_seconds().abs() < 0.2);
    }
}