1#![forbid(unsafe_code)]
2
3use std::{io, net::SocketAddr};
4
5use ntp_proto::NtpTimestamp;
6use tokio::io::{unix::AsyncFd, Interest};
7use tracing::instrument;
8
9use crate::{
10 interface::InterfaceName,
11 raw_socket::{
12 control_message_space, receive_message, set_timestamping_options, ControlMessage,
13 MessageQueue, TimestampMethod,
14 },
15 EnableTimestamps,
16};
17
18pub struct UdpSocket {
19 io: AsyncFd<std::net::UdpSocket>,
20 send_counter: u32,
21 timestamping: EnableTimestamps,
22}
23
24#[cfg(target_os = "linux")]
25const DEFAULT_TIMESTAMP_METHOD: TimestampMethod = TimestampMethod::SoTimestamping;
26
27#[cfg(all(unix, not(target_os = "linux")))]
28const DEFAULT_TIMESTAMP_METHOD: TimestampMethod = TimestampMethod::SoTimestamp;
29
30impl UdpSocket {
31 #[instrument(level = "debug", skip(peer_addr))]
32 pub async fn client(listen_addr: SocketAddr, peer_addr: SocketAddr) -> io::Result<UdpSocket> {
33 Self::client_with_timestamping(
34 listen_addr,
35 peer_addr,
36 InterfaceName::DEFAULT,
37 EnableTimestamps::default(),
38 )
39 .await
40 }
41
42 pub async fn client_with_timestamping(
43 listen_addr: SocketAddr,
44 peer_addr: SocketAddr,
45 interface: Option<InterfaceName>,
46 timestamping: EnableTimestamps,
47 ) -> io::Result<UdpSocket> {
48 Self::client_with_timestamping_internal(
49 listen_addr,
50 peer_addr,
51 interface,
52 DEFAULT_TIMESTAMP_METHOD,
53 timestamping,
54 )
55 .await
56 }
57
58 async fn client_with_timestamping_internal(
59 listen_addr: SocketAddr,
60 peer_addr: SocketAddr,
61 interface: Option<InterfaceName>,
62 method: TimestampMethod,
63 timestamping: EnableTimestamps,
64 ) -> io::Result<UdpSocket> {
65 let socket = tokio::net::UdpSocket::bind(listen_addr).await?;
66 tracing::debug!(
67 local_addr = ?socket.local_addr().unwrap(),
68 "client socket bound"
69 );
70
71 if let Some(_interface) = interface {
74 #[cfg(target_os = "linux")]
75 socket.bind_device(Some(&_interface)).unwrap();
76 }
77
78 socket.connect(peer_addr).await?;
79 tracing::debug!(
80 local_addr = ?socket.local_addr().unwrap(),
81 peer_addr = ?socket.peer_addr().unwrap(),
82 "client socket connected"
83 );
84
85 let socket = socket.into_std()?;
86
87 set_timestamping_options(&socket, method, timestamping)?;
88
89 Ok(UdpSocket {
90 io: AsyncFd::new(socket)?,
91 send_counter: 0,
92 timestamping,
93 })
94 }
95
96 #[instrument(level = "debug")]
97 pub async fn server(
98 listen_addr: SocketAddr,
99 interface: Option<InterfaceName>,
100 ) -> io::Result<UdpSocket> {
101 let socket = tokio::net::UdpSocket::bind(listen_addr).await?;
102 tracing::debug!(
103 local_addr = ?socket.local_addr().unwrap(),
104 "server socket bound"
105 );
106
107 if let Some(_interface) = interface {
110 #[cfg(target_os = "linux")]
111 socket.bind_device(Some(&_interface)).unwrap();
112 }
113
114 let socket = socket.into_std()?;
115
116 let timestamping = EnableTimestamps {
119 rx_software: true,
120 tx_software: false,
121 rx_hardware: false,
122 tx_hardware: false,
123 };
124
125 set_timestamping_options(&socket, DEFAULT_TIMESTAMP_METHOD, timestamping)?;
126
127 Ok(UdpSocket {
128 io: AsyncFd::new(socket)?,
129 send_counter: 0,
130 timestamping,
131 })
132 }
133
134 #[instrument(level = "trace", skip(self, buf), fields(
135 local_addr = debug(self.as_ref().local_addr().unwrap()),
136 peer_addr = debug(self.as_ref().peer_addr()),
137 buf_size = buf.len(),
138 ))]
139 pub async fn send(&mut self, buf: &[u8]) -> io::Result<(usize, Option<NtpTimestamp>)> {
140 tracing::trace!(size = buf.len(), "sending bytes");
141
142 let result = self
143 .io
144 .async_io(Interest::WRITABLE, |inner| inner.send(buf))
145 .await;
146
147 let send_size = match result {
148 Ok(size) => {
149 tracing::trace!(sent = size, "sent bytes");
150 size
151 }
152 Err(e) => {
153 tracing::debug!(error = debug(&e), "error sending data");
154 return Err(e);
155 }
156 };
157
158 debug_assert_eq!(buf.len(), send_size);
159
160 let expected_counter = self.send_counter;
161 self.send_counter = self.send_counter.wrapping_add(1);
162
163 if self.timestamping.tx_software {
164 #[cfg(target_os = "linux")]
165 {
166 let timeout = std::time::Duration::from_millis(10);
169 match tokio::time::timeout(timeout, self.fetch_send_timestamp(expected_counter))
170 .await
171 {
172 Err(_) => {
173 tracing::warn!("Packet without timestamp");
174 Ok((send_size, None))
175 }
176 Ok(send_timestamp) => Ok((send_size, Some(send_timestamp?))),
177 }
178 }
179
180 #[cfg(any(target_os = "macos", target_os = "freebsd"))]
181 {
182 let _ = expected_counter;
183 Ok((send_size, None))
184 }
185 } else {
186 tracing::trace!("send timestamping not supported");
187 Ok((send_size, None))
188 }
189 }
190
191 #[cfg(target_os = "linux")]
192 async fn fetch_send_timestamp(&self, expected_counter: u32) -> io::Result<NtpTimestamp> {
193 let msg = "waiting for timestamp socket to become readable to fetch a send timestamp";
194 tracing::trace!(msg);
195
196 let try_read = |udp_socket: &std::net::UdpSocket| {
197 fetch_send_timestamp_help(udp_socket, expected_counter)
198 };
199
200 loop {
201 match self.io.async_io(Interest::ERROR, try_read).await? {
203 Some(timestamp) => return Ok(timestamp),
204 None => continue,
205 };
206 }
207 }
208
209 #[instrument(level = "trace", skip(self, buf), fields(
210 local_addr = debug(self.as_ref().local_addr().unwrap()),
211 buf_size = buf.len(),
212 ))]
213 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> io::Result<usize> {
214 tracing::trace!(size = buf.len(), ?addr, "sending bytes");
215
216 let result = self
217 .io
218 .async_io(Interest::WRITABLE, |inner| inner.send_to(buf, addr))
219 .await;
220
221 match &result {
222 Ok(size) => tracing::trace!(sent = size, "sent bytes"),
223 Err(e) => tracing::debug!(error = debug(e), "error sending data"),
224 }
225
226 result
227 }
228
229 #[instrument(level = "trace", skip(self, buf), fields(
230 local_addr = debug(self.as_ref().local_addr().unwrap()),
231 peer_addr = debug(self.as_ref().peer_addr().ok()),
232 buf_size = buf.len(),
233 ))]
234 pub async fn recv(
235 &self,
236 buf: &mut [u8],
237 ) -> io::Result<(usize, SocketAddr, Option<NtpTimestamp>)> {
238 tracing::trace!("waiting for socket to become readable");
239
240 let result = self
241 .io
242 .async_io(Interest::READABLE, |inner| recv(inner, buf))
243 .await;
244
245 match &result {
246 Ok((size, addr, ts)) => {
247 tracing::trace!(size, ts = ?ts, addr = ?addr, "received message");
248 }
249 Err(e) => tracing::debug!(error = ?e, "error receiving data"),
250 }
251
252 result
253 }
254}
255
256impl AsRef<std::net::UdpSocket> for UdpSocket {
257 fn as_ref(&self) -> &std::net::UdpSocket {
258 self.io.get_ref()
259 }
260}
261
262fn recv(
263 socket: &std::net::UdpSocket,
264 buf: &mut [u8],
265) -> io::Result<(usize, SocketAddr, Option<NtpTimestamp>)> {
266 let mut control_buf = [0; control_message_space::<[libc::timespec; 3]>()];
267
268 let (bytes_read, control_messages, sock_addr) =
270 receive_message(socket, buf, &mut control_buf, MessageQueue::Normal)?;
271 let sock_addr =
272 sock_addr.unwrap_or_else(|| unreachable!("We never constructed a non-ip socket"));
273
274 for msg in control_messages {
276 match msg {
277 ControlMessage::Timestamping(libc_timestamp) => {
278 let ntp_timestamp = libc_timestamp.into_ntp_timestamp();
279 return Ok((bytes_read as usize, sock_addr, Some(ntp_timestamp)));
280 }
281
282 #[cfg(target_os = "linux")]
283 ControlMessage::ReceiveError(_error) => {
284 tracing::warn!("unexpected error message on the MSG_ERRQUEUE");
285 }
286
287 ControlMessage::Other(msg) => {
288 tracing::warn!(
289 "weird control message {:?} {:?}",
290 msg.cmsg_level,
291 msg.cmsg_type
292 );
293 }
294 }
295 }
296
297 Ok((bytes_read as usize, sock_addr, None))
298}
299
300#[cfg(target_os = "linux")]
301fn fetch_send_timestamp_help(
302 socket: &std::net::UdpSocket,
303 expected_counter: u32,
304) -> io::Result<Option<NtpTimestamp>> {
305 const CONTROL_SIZE: usize = control_message_space::<[libc::timespec; 3]>()
315 + control_message_space::<(libc::sock_extended_err, libc::sockaddr_storage)>();
316
317 let mut control_buf = [0; CONTROL_SIZE];
318
319 let (_, control_messages, _) =
320 receive_message(socket, &mut [], &mut control_buf, MessageQueue::Error)?;
321
322 let mut send_ts = None;
323 for msg in control_messages {
324 match msg {
325 ControlMessage::Timestamping(timestamp) => {
326 send_ts = Some(timestamp);
327 }
328
329 ControlMessage::ReceiveError(error) => {
330 if error.ee_errno as libc::c_int != libc::ENOMSG {
333 tracing::warn!(
334 expected_counter,
335 error.ee_data,
336 "error message on the MSG_ERRQUEUE"
337 );
338 }
339
340 if error.ee_data != expected_counter {
342 tracing::debug!(
343 error.ee_data,
344 expected_counter,
345 "Timestamp for unrelated packet"
346 );
347 return Ok(None);
348 }
349 }
350
351 ControlMessage::Other(msg) => {
352 tracing::warn!(
353 msg.cmsg_level,
354 msg.cmsg_type,
355 "unexpected message on the MSG_ERRQUEUE",
356 );
357 }
358 }
359 }
360
361 Ok(send_ts.map(|ts| ts.into_ntp_timestamp()))
362}
363
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;

    /// Receives one datagram on `socket` and asserts that it is 48 bytes, every
    /// byte equal to `fill`, sent from `expected_from`.
    async fn expect_packet(socket: &UdpSocket, expected_from: SocketAddr, fill: u8) {
        let mut buf = [0; 48];
        let (size, addr, _) = socket.recv(&mut buf).await.unwrap();
        assert_eq!(size, 48);
        assert_eq!(addr, expected_from);
        assert_eq!(buf, [fill; 48]);
    }

    /// Connects two clients to each other and exchanges one datagram in each
    /// direction.
    async fn client_round_trip(addr_a: &str, addr_b: &str) {
        let addr_a: SocketAddr = addr_a.parse().unwrap();
        let addr_b: SocketAddr = addr_b.parse().unwrap();

        let mut a = UdpSocket::client(addr_a, addr_b).await.unwrap();
        let mut b = UdpSocket::client(addr_b, addr_a).await.unwrap();

        a.send(&[1; 48]).await.unwrap();
        expect_packet(&b, addr_a, 1).await;

        b.send(&[2; 48]).await.unwrap();
        expect_packet(&a, addr_b, 2).await;
    }

    /// Starts a server and a client connected to it. The client sends one
    /// datagram, and the server answers it with `send_to`.
    async fn server_round_trip(server_addr: &str, client_addr: &str) {
        let server_addr: SocketAddr = server_addr.parse().unwrap();
        let client_addr: SocketAddr = client_addr.parse().unwrap();

        let a = UdpSocket::server(server_addr, InterfaceName::DEFAULT)
            .await
            .unwrap();
        let mut b = UdpSocket::client(client_addr, server_addr).await.unwrap();

        b.send(&[1; 48]).await.unwrap();
        expect_packet(&a, client_addr, 1).await;

        a.send_to(&[2; 48], client_addr).await.unwrap();
        expect_packet(&b, server_addr, 2).await;
    }

    #[tokio::test]
    async fn test_client_basic_ipv4() {
        client_round_trip("127.0.0.1:10000", "127.0.0.1:10001").await;
    }

    #[tokio::test]
    async fn test_client_basic_ipv6() {
        client_round_trip("[::1]:10000", "[::1]:10001").await;
    }

    #[tokio::test]
    async fn test_server_basic_ipv4() {
        server_round_trip("127.0.0.1:10002", "127.0.0.1:10003").await;
    }

    #[tokio::test]
    async fn test_server_basic_ipv6() {
        server_round_trip("[::1]:10002", "[::1]:10003").await;
    }

    /// Sends two datagrams about 200ms apart and checks that the receive
    /// timestamps produced by `method` differ by roughly that amount.
    async fn timestamping_reasonable(method: TimestampMethod, p1: u16, p2: u16) {
        let mut a = UdpSocket::client(
            SocketAddr::from((Ipv4Addr::LOCALHOST, p1)),
            SocketAddr::from((Ipv4Addr::LOCALHOST, p2)),
        )
        .await
        .unwrap();
        let b = UdpSocket::client_with_timestamping_internal(
            SocketAddr::from((Ipv4Addr::LOCALHOST, p2)),
            SocketAddr::from((Ipv4Addr::LOCALHOST, p1)),
            InterfaceName::DEFAULT,
            method,
            EnableTimestamps {
                rx_software: true,
                tx_software: true,
                rx_hardware: false,
                tx_hardware: false,
            },
        )
        .await
        .unwrap();

        let sender = tokio::spawn(async move {
            a.send(&[1; 48]).await.unwrap();
            tokio::time::sleep(std::time::Duration::from_millis(200)).await;
            a.send(&[2; 48]).await.unwrap();
        });

        let mut buf = [0; 48];
        let (s1, _, t1) = b.recv(&mut buf).await.unwrap();
        let (s2, _, t2) = b.recv(&mut buf).await.unwrap();
        // Propagate a panic from the sender task instead of silently dropping it.
        sender.await.unwrap();

        assert_eq!(s1, 48);
        assert_eq!(s2, 48);

        let t1 = t1.unwrap();
        let t2 = t2.unwrap();
        let delta = t2 - t1;

        assert!(
            delta.to_seconds() > 0.15 && delta.to_seconds() < 0.25,
            "delta was {}s",
            delta.to_seconds()
        );
    }

    #[tokio::test]
    #[cfg(target_os = "linux")]
    async fn timestamping_reasonable_so_timestamping() {
        timestamping_reasonable(TimestampMethod::SoTimestamping, 8000, 8001).await;
    }

    #[tokio::test]
    #[cfg(target_os = "linux")]
    async fn timestamping_reasonable_so_timestampns() {
        timestamping_reasonable(TimestampMethod::SoTimestampns, 8002, 8003).await;
    }

    #[tokio::test]
    #[cfg(unix)]
    async fn timestamping_reasonable_so_timestamp() {
        timestamping_reasonable(TimestampMethod::SoTimestamp, 8004, 8005).await;
    }

    // `UdpSocket::send` only fetches send timestamps on Linux.
    #[tokio::test]
    #[cfg_attr(
        not(target_os = "linux"),
        ignore = "send timestamps are not supported"
    )]
    async fn test_send_timestamp() {
        let mut a = UdpSocket::client_with_timestamping(
            SocketAddr::from((Ipv4Addr::LOCALHOST, 8012)),
            SocketAddr::from((Ipv4Addr::LOCALHOST, 8013)),
            InterfaceName::DEFAULT,
            EnableTimestamps {
                rx_software: true,
                tx_software: true,
                rx_hardware: false,
                tx_hardware: false,
            },
        )
        .await
        .unwrap();
        let b = UdpSocket::client(
            SocketAddr::from((Ipv4Addr::LOCALHOST, 8013)),
            SocketAddr::from((Ipv4Addr::LOCALHOST, 8012)),
        )
        .await
        .unwrap();

        let (ssend, tsend) = a.send(&[1; 48]).await.unwrap();
        let mut buf = [0; 48];
        let (srecv, _, trecv) = b.recv(&mut buf).await.unwrap();

        assert_eq!(ssend, 48);
        assert_eq!(srecv, 48);

        let tsend = tsend.unwrap();
        let trecv = trecv.unwrap();
        let delta = trecv - tsend;
        assert!(delta.to_seconds().abs() < 0.2);
    }
}