Skip to main content

massping/
pinger.rs

1#[cfg(feature = "stream")]
2use std::pin::Pin;
3use std::{
4    collections::HashMap,
5    future::poll_fn,
6    io,
7    iter::Peekable,
8    net::{Ipv4Addr, Ipv6Addr},
9    sync::{
10        Arc,
11        atomic::{AtomicU64, Ordering},
12    },
13    task::{Context, Poll},
14    time::Duration,
15};
16
17use bytes::{Bytes, BytesMut};
18#[cfg(feature = "stream")]
19use futures_core::Stream;
20use tokio::{
21    sync::mpsc::{self, error::TryRecvError},
22    time::Instant,
23};
24
25use crate::{IpVersion, packet::EchoRequestPacket, raw_pinger::RawPinger};
26
27/// A pinger for IPv4 addresses
28pub type V4Pinger = Pinger<Ipv4Addr>;
29/// A pinger for IPv6 addresses
30pub type V6Pinger = Pinger<Ipv6Addr>;
31
32/// A pinger for [`IpVersion`] (either [`Ipv4Addr`] or [`Ipv6Addr`]).
33///
34/// Cloning is cheap: clones share the same socket and background
35/// receive task, which shut down when the last clone is dropped.
36pub struct Pinger<V: IpVersion> {
37    inner: Arc<InnerPinger<V>>,
38    // Kept out of `InnerPinger` (which the background receive task holds)
39    // so that dropping the last `Pinger` clone disconnects the channel,
40    // telling the background task to shut down and release the socket.
41    round_sender: mpsc::UnboundedSender<RoundMessage<V>>,
42}
43
44impl<V: IpVersion> Clone for Pinger<V> {
45    fn clone(&self) -> Self {
46        Self {
47            inner: Arc::clone(&self.inner),
48            round_sender: self.round_sender.clone(),
49        }
50    }
51}
52
53struct InnerPinger<V: IpVersion> {
54    raw: RawPinger<V>,
55    next_round_id: AtomicU64,
56}
57
58// Each `measure_many` round gets a unique `u64` id; the wire sequence
59// number is its lower 16 bits. The full id lets the receive task tell
60// rounds apart after the sequence number wraps around.
61enum RoundMessage<V: IpVersion> {
62    Subscribe {
63        round_id: u64,
64        expected_payload: Bytes,
65        sender: mpsc::UnboundedSender<(V, Instant)>,
66    },
67    Unsubscribe {
68        round_id: u64,
69    },
70}
71
72struct Subscriber<V: IpVersion> {
73    round_id: u64,
74    expected_payload: Bytes,
75    sender: mpsc::UnboundedSender<(V, Instant)>,
76}
77
78enum PollResult<V: IpVersion> {
79    Subscription(RoundMessage<V>),
80    Packet(crate::packet::EchoReplyPacket<V>),
81}
82
83impl<V: IpVersion> Pinger<V> {
84    /// Construct a new `Pinger`.
85    ///
86    /// For maximum efficiency the same instance of `Pinger` should
87    /// be used for as long as possible, although it might also
88    /// be beneficial to `Drop` the `Pinger` and recreate it if
89    /// you are not going to be sending pings for a long period of time.
90    ///
91    /// # Panics
92    ///
93    /// Panics if called from outside a tokio runtime, as it spawns a
94    /// background receive task.
95    pub fn new() -> io::Result<Self> {
96        let raw = RawPinger::new()?;
97
98        let (sender, mut receiver) = mpsc::unbounded_channel();
99
100        let inner = Arc::new(InnerPinger {
101            raw,
102            next_round_id: AtomicU64::new(0),
103        });
104
105        // Spawn async receive task using the same socket.
106        // It runs until `receiver` disconnects, which happens when the
107        // `Pinger` holding the only sender is dropped.
108        let inner_recv = Arc::clone(&inner);
109        tokio::spawn(async move {
110            let mut subscribers: HashMap<u16, Subscriber<V>> = HashMap::new();
111            // Buffer kept outside poll_fn so it persists across polls.
112            let mut recv_buf = BytesMut::new();
113
114            loop {
115                // Poll both subscription channel and socket in the same waker context.
116                // This ensures we wake on either event, which is required for
117                // single-threaded runtimes where we can't rely on concurrent execution.
118                //
119                // Note: We use try_recv() before poll_recv() as a fast path optimization.
120                // Benchmarks show this is ~2x faster when messages are already queued
121                // (~15ns vs ~25ns per iteration).
122                let result = poll_fn(|cx| {
123                    // Fast path: check for subscription changes (non-blocking, no waker)
124                    match receiver.try_recv() {
125                        Ok(msg) => return Poll::Ready(Some(PollResult::Subscription(msg))),
126                        Err(TryRecvError::Empty) => {
127                            // Continue - poll_recv() below will register the waker for this channel
128                        }
129                        Err(TryRecvError::Disconnected) => return Poll::Ready(None),
130                    }
131
132                    // Try to receive an ICMP packet
133                    match inner_recv.raw.poll_recv(&mut recv_buf, cx) {
134                        Poll::Ready(Ok(packet)) => {
135                            return Poll::Ready(Some(PollResult::Packet(packet)));
136                        }
137                        Poll::Ready(Err(_)) => {
138                            // Receiving failed (typically a transient kernel
139                            // resource error). The socket readiness was
140                            // consumed without registering a waker, so ask to
141                            // be polled again right away; parking here would
142                            // suspend reply processing until an unrelated
143                            // subscription message wakes the task.
144                            cx.waker().wake_by_ref();
145                        }
146                        Poll::Pending => {}
147                    }
148
149                    // Register waker for subscription channel
150                    // We need to wake up when new subscriptions arrive
151                    match receiver.poll_recv(cx) {
152                        Poll::Ready(Some(msg)) => {
153                            return Poll::Ready(Some(PollResult::Subscription(msg)));
154                        }
155                        Poll::Ready(None) => return Poll::Ready(None),
156                        Poll::Pending => {}
157                    }
158
159                    Poll::Pending
160                })
161                .await;
162
163                match result {
164                    Some(PollResult::Subscription(RoundMessage::Subscribe {
165                        round_id,
166                        expected_payload,
167                        sender,
168                    })) => {
169                        // A new round may displace a still-subscribed round
170                        // whose sequence number collided after wraparound;
171                        // the displaced round could not be served anyway as
172                        // replies can only be told apart by sequence number.
173                        subscribers.insert(
174                            round_id as u16,
175                            Subscriber {
176                                round_id,
177                                expected_payload,
178                                sender,
179                            },
180                        );
181                    }
182                    Some(PollResult::Subscription(RoundMessage::Unsubscribe { round_id })) => {
183                        let sequence_number = round_id as u16;
184                        // Only unsubscribe if the slot still belongs to this
185                        // round: after sequence number wraparound it may have
186                        // been taken over by a newer round, which must keep
187                        // receiving replies.
188                        if subscribers
189                            .get(&sequence_number)
190                            .is_some_and(|subscriber| subscriber.round_id == round_id)
191                        {
192                            subscribers.remove(&sequence_number);
193                        }
194                    }
195                    Some(PollResult::Packet(packet)) => {
196                        let recv_instant = Instant::now();
197
198                        let packet_source = packet.source();
199                        let packet_sequence_number = packet.sequence_number();
200
201                        if let Some(subscriber) = subscribers.get(&packet_sequence_number) {
202                            // An echo reply mirrors the request's payload, so
203                            // a mismatch means the reply wasn't produced by
204                            // this round (e.g. a reply to an older round whose
205                            // sequence number collided after wraparound, or
206                            // blindly spoofed cross-traffic). Discard it.
207                            let payload_matches =
208                                packet.payload() == &subscriber.expected_payload[..];
209
210                            if payload_matches
211                                && subscriber
212                                    .sender
213                                    .send((packet_source, recv_instant))
214                                    .is_err()
215                            {
216                                subscribers.remove(&packet_sequence_number);
217                            }
218                        }
219                    }
220                    None => return, // Channel closed
221                }
222            }
223        });
224
225        Ok(Self {
226            inner,
227            round_sender: sender,
228        })
229    }
230
231    /// Ping `addresses`
232    ///
233    /// Creates [`MeasureManyStream`] which **lazily** sends ping
234    /// requests and [`Stream`]s the responses as they arrive.
235    ///
236    /// Replies are matched by source address, so an address that appears
237    /// multiple times is only pinged once per round and yields a single
238    /// measurement.
239    ///
240    /// # Panics
241    ///
242    /// Panics if the background receive task has terminated, which only
243    /// happens when the runtime the `Pinger` was created on has been
244    /// shut down.
245    ///
246    /// [`Stream`]: futures_core::Stream
247    pub fn measure_many<I>(&self, addresses: I) -> MeasureManyStream<'_, V, I>
248    where
249        I: Iterator<Item = V>,
250    {
251        let (size_hint, _) = addresses.size_hint();
252        let send_queue = addresses.into_iter().peekable();
253        let (sender, receiver) = mpsc::unbounded_channel();
254
255        // Relaxed is enough: the counter is a pure id allocator, no other
256        // memory is synchronized through it.
257        let round_id = self.inner.next_round_id.fetch_add(1, Ordering::Relaxed);
258
259        // The same packet is reused for every address of the round. Its
260        // random payload lets the receive task discard replies that don't
261        // belong to this round.
262        //
263        // The identifier is irrelevant: the kernel overwrites it with the
264        // socket's own identifier, which it also uses to route echo replies
265        // back to this socket.
266        let payload = rand::random::<[u8; 64]>();
267        let packet = EchoRequestPacket::new(0, round_id as u16, &payload);
268
269        if self
270            .round_sender
271            .send(RoundMessage::Subscribe {
272                round_id,
273                expected_payload: packet.payload(),
274                sender,
275            })
276            .is_err()
277        {
278            panic!("Receiver closed");
279        }
280
281        MeasureManyStream {
282            pinger: self,
283            packet,
284            send_queue,
285            in_flight: HashMap::with_capacity(size_hint),
286            receiver,
287            round_id,
288        }
289    }
290}
291
292/// A [`Stream`] of ping responses.
293///
294/// No kind of `rtt` timeout is implemented, so an external mechanism
295/// like [`tokio::time::timeout`] should be used to prevent the program
296/// from hanging indefinitely.
297///
298/// Leaking this stream may create a memory leak that lasts until the
299/// [`Pinger`] is dropped.
300///
301/// [`Stream`]: futures_core::Stream
302/// [`tokio::time::timeout`]: tokio::time::timeout
303#[must_use = "streams do nothing unless polled"]
304pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
305    pinger: &'a Pinger<V>,
306    packet: EchoRequestPacket<V>,
307    send_queue: Peekable<I>,
308    in_flight: HashMap<V, Instant>,
309    receiver: mpsc::UnboundedReceiver<(V, Instant)>,
310    round_id: u64,
311}
312
313impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
314    pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<(V, Duration)>> {
315        // Try to receive a response (may be from a different round)
316        if let Poll::Ready(maybe_reply) = self.poll_next_from_different_round(cx) {
317            return Poll::Ready(maybe_reply);
318        }
319
320        // Try to send ICMP echo requests
321        self.poll_next_icmp_replies(cx);
322
323        // Check if we're done: no more addresses to send AND no responses pending
324        if self.send_queue.peek().is_none() && self.in_flight.is_empty() {
325            return Poll::Ready(None);
326        }
327
328        Poll::Pending
329    }
330
331    fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) {
332        while let Some(&addr) = self.send_queue.peek() {
333            // Replies are matched by source address within a round, so a
334            // second ping to an address that is still awaiting its reply
335            // could never produce a second measurement; it would only
336            // clobber the first ping's start time. Skip the duplicate.
337            if self.in_flight.contains_key(&addr) {
338                self.send_queue.next();
339                continue;
340            }
341
342            match self.pinger.inner.raw.poll_send_to(cx, addr, &self.packet) {
343                Poll::Ready(result) => {
344                    let sent_at = Instant::now();
345
346                    let taken_addr = self.send_queue.next();
347                    debug_assert!(taken_addr.is_some());
348
349                    // If the send failed (e.g. no route to host) no reply
350                    // can ever arrive, so don't track the address as
351                    // in-flight or the stream would never terminate.
352                    if result.is_ok() {
353                        self.in_flight.insert(addr, sent_at);
354                    }
355                }
356                Poll::Pending => {
357                    // The socket only remembers the most recent waker per
358                    // direction (`AsyncFd` semantics), so with multiple
359                    // streams sharing the socket another stream could
360                    // overwrite ours and we'd never be woken again. Sends
361                    // only return `Pending` while the send buffer is full,
362                    // which clears up quickly, so schedule an immediate
363                    // re-poll instead of parking.
364                    cx.waker().wake_by_ref();
365                    break;
366                }
367            }
368        }
369    }
370
371    fn poll_next_from_different_round(
372        &mut self,
373        cx: &mut Context<'_>,
374    ) -> Poll<Option<(V, Duration)>> {
375        loop {
376            match self.receiver.poll_recv(cx) {
377                Poll::Pending => return Poll::Pending,
378                Poll::Ready(Some((addr, recv_instant))) => {
379                    if let Some(send_instant) = self.in_flight.remove(&addr) {
380                        let rtt = recv_instant - send_instant;
381                        return Poll::Ready(Some((addr, rtt)));
382                    }
383                }
384                Poll::Ready(None) => return Poll::Ready(None),
385            }
386        }
387    }
388}
389
#[cfg(feature = "stream")]
impl<V: IpVersion, I: Iterator<Item = V> + Unpin> Stream for MeasureManyStream<'_, V, I> {
    type Item = (V, Duration);

    /// Delegates to [`MeasureManyStream::poll_next_unpin`]; the stream is
    /// `Unpin`, so it can be unpinned freely.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = Pin::get_mut(self);
        this.poll_next_unpin(cx)
    }
}
398
399impl<V: IpVersion, I: Iterator<Item = V>> Drop for MeasureManyStream<'_, V, I> {
400    fn drop(&mut self) {
401        let _ = self.pinger.round_sender.send(RoundMessage::Unsubscribe {
402            round_id: self.round_id,
403        });
404    }
405}