Skip to main content

massping/
pinger.rs

1#[cfg(feature = "stream")]
2use std::pin::Pin;
3use std::{
4    collections::HashMap,
5    future::poll_fn,
6    io,
7    iter::Peekable,
8    net::{Ipv4Addr, Ipv6Addr},
9    sync::{
10        Arc,
11        atomic::{AtomicU16, Ordering},
12    },
13    task::{Context, Poll},
14    time::Duration,
15};
16
17use bytes::BytesMut;
18#[cfg(feature = "stream")]
19use futures_core::Stream;
20use tokio::{
21    sync::mpsc::{self, error::TryRecvError},
22    time::Instant,
23};
24
25use crate::{IpVersion, packet::EchoRequestPacket, raw_pinger::RawPinger};
26
/// A [`Pinger`] for IPv4 addresses ([`Ipv4Addr`]).
pub type V4Pinger = Pinger<Ipv4Addr>;
/// A [`Pinger`] for IPv6 addresses ([`Ipv6Addr`]).
pub type V6Pinger = Pinger<Ipv6Addr>;
31
/// A pinger for [`IpVersion`] (either [`Ipv4Addr`] or [`Ipv6Addr`]).
///
/// Constructing a `Pinger` spawns a background Tokio task that receives
/// echo replies and forwards them to the streams created by
/// `measure_many`.
pub struct Pinger<V: IpVersion> {
    // Shared state (socket, subscription channel, identifier, round counter).
    // The background receive task spawned in `new` refers to it as well.
    inner: Arc<InnerPinger<V>>,
}
36
/// State shared between a [`Pinger`] and its background receive task.
struct InnerPinger<V: IpVersion> {
    /// Socket wrapper used both for sending echo requests (`poll_send_to`)
    /// and for receiving echo replies (`poll_recv`).
    raw: RawPinger<V>,
    /// Channel to the receive task, used to subscribe and unsubscribe rounds.
    round_sender: mpsc::UnboundedSender<RoundMessage<V>>,
    /// Random identifier chosen at construction and put into every echo
    /// request built by this pinger.
    identifier: u16,
    /// Counter incremented by every `measure_many` call; the previous value
    /// becomes the sequence number of that round.
    sequence_number: AtomicU16,
}
43
/// Control message sent from a measuring round to the receive task.
enum RoundMessage<V: IpVersion> {
    /// Register `sender` as the destination for replies whose sequence
    /// number equals `sequence_number`. Each forwarded item is the reply's
    /// source address together with the instant it was received.
    Subscribe {
        sequence_number: u16,
        sender: mpsc::UnboundedSender<(V, Instant)>,
    },
    /// Remove the subscription registered for `sequence_number`.
    Unsubscribe {
        sequence_number: u16,
    },
}
53
/// The event produced by one wake-up of the receive task.
enum PollResult<V: IpVersion> {
    /// A subscribe/unsubscribe request arrived on the control channel.
    Subscription(RoundMessage<V>),
    /// An echo reply was read from the socket.
    Packet(crate::packet::EchoReplyPacket<V>),
}
58
59impl<V: IpVersion> Pinger<V> {
60    /// Construct a new `Pinger`.
61    ///
62    /// For maximum efficiency the same instance of `Pinger` should
63    /// be used for as long as possible, altough it might also
64    /// be beneficial to `Drop` the `Pinger` and recreate it if
65    /// you are not going to be sending pings for a long period of time.
66    pub fn new() -> io::Result<Self> {
67        let raw = RawPinger::new()?;
68
69        let identifier = rand::random::<u16>();
70
71        let (sender, mut receiver) = mpsc::unbounded_channel();
72
73        let inner = Arc::new(InnerPinger {
74            raw,
75            round_sender: sender,
76            identifier,
77            sequence_number: AtomicU16::new(0),
78        });
79
80        // Spawn async receive task using the same socket
81        let inner_recv = Arc::clone(&inner);
82        tokio::spawn(async move {
83            let mut subscribers: HashMap<u16, mpsc::UnboundedSender<(V, Instant)>> = HashMap::new();
84            // Buffer kept outside poll_fn so it persists across polls.
85            let mut recv_buf = BytesMut::new();
86
87            loop {
88                // Poll both subscription channel and socket in the same waker context.
89                // This ensures we wake on either event, which is required for
90                // single-threaded runtimes where we can't rely on concurrent execution.
91                //
92                // Note: We use try_recv() before poll_recv() as a fast path optimization.
93                // Benchmarks show this is ~2x faster when messages are already queued
94                // (~15ns vs ~25ns per iteration).
95                let result = poll_fn(|cx| {
96                    // Fast path: check for subscription changes (non-blocking, no waker)
97                    match receiver.try_recv() {
98                        Ok(msg) => return Poll::Ready(Some(PollResult::Subscription(msg))),
99                        Err(TryRecvError::Empty) => {
100                            // Continue - poll_recv() below will register the waker for this channel
101                        }
102                        Err(TryRecvError::Disconnected) => return Poll::Ready(None),
103                    }
104
105                    // Try to receive an ICMP packet
106                    if let Poll::Ready(Ok(packet)) = inner_recv.raw.poll_recv(&mut recv_buf, cx) {
107                        return Poll::Ready(Some(PollResult::Packet(packet)));
108                    }
109                    // Socket error or not ready - continue polling
110
111                    // Register waker for subscription channel
112                    // We need to wake up when new subscriptions arrive
113                    match receiver.poll_recv(cx) {
114                        Poll::Ready(Some(msg)) => {
115                            return Poll::Ready(Some(PollResult::Subscription(msg)));
116                        }
117                        Poll::Ready(None) => return Poll::Ready(None),
118                        Poll::Pending => {}
119                    }
120
121                    Poll::Pending
122                })
123                .await;
124
125                match result {
126                    Some(PollResult::Subscription(RoundMessage::Subscribe {
127                        sequence_number,
128                        sender,
129                    })) => {
130                        subscribers.insert(sequence_number, sender);
131                    }
132                    Some(PollResult::Subscription(RoundMessage::Unsubscribe {
133                        sequence_number,
134                    })) => {
135                        subscribers.remove(&sequence_number);
136                    }
137                    Some(PollResult::Packet(packet)) => {
138                        let recv_instant = Instant::now();
139
140                        let packet_source = packet.source();
141                        let packet_sequence_number = packet.sequence_number();
142
143                        if let Some(subscriber) = subscribers.get(&packet_sequence_number) {
144                            if subscriber.send((packet_source, recv_instant)).is_err() {
145                                subscribers.remove(&packet_sequence_number);
146                            }
147                        }
148                    }
149                    None => return, // Channel closed
150                }
151            }
152        });
153
154        Ok(Self { inner })
155    }
156
157    /// Ping `addresses`
158    ///
159    /// Creates [`MeasureManyStream`] which **lazily** sends ping
160    /// requests and [`Stream`]s the responses as they arrive.
161    ///
162    /// [`Stream`]: futures_core::Stream
163    pub fn measure_many<I>(&self, addresses: I) -> MeasureManyStream<'_, V, I>
164    where
165        I: Iterator<Item = V>,
166    {
167        let (size_hint, _) = addresses.size_hint();
168        let send_queue = addresses.into_iter().peekable();
169        let (sender, receiver) = mpsc::unbounded_channel();
170
171        let sequence_number = self.inner.sequence_number.fetch_add(1, Ordering::AcqRel);
172        if self
173            .inner
174            .round_sender
175            .send(RoundMessage::Subscribe {
176                sequence_number,
177                sender,
178            })
179            .is_err()
180        {
181            panic!("Receiver closed");
182        }
183
184        MeasureManyStream {
185            pinger: self,
186            send_queue,
187            in_flight: HashMap::with_capacity(size_hint),
188            receiver,
189            sequence_number,
190        }
191    }
192}
193
/// A [`Stream`] of ping responses.
///
/// No kind of `rtt` timeout is implemented, so an external mechanism
/// like [`tokio::time::timeout`] should be used to prevent the program
/// from hanging indefinitely.
///
/// Leaking this stream (i.e. never running its `Drop` implementation) might
/// create a slowly, forever-growing memory leak, because its subscription
/// is unregistered from the receive task when the stream is dropped.
///
/// [`Stream`]: futures_core::Stream
/// [`tokio::time::timeout`]: tokio::time::timeout
pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
    /// The pinger whose socket and identifier are used to send requests.
    pinger: &'a Pinger<V>,
    /// Addresses that still have to be pinged. Peekable so an address is
    /// only dequeued once its request has actually been handed to the socket.
    send_queue: Peekable<I>,
    /// Addresses already pinged, mapped to the instant recorded right after
    /// their request was sent.
    in_flight: HashMap<V, Instant>,
    /// Replies forwarded by the receive task for this round: the source
    /// address and the instant the reply was received.
    receiver: mpsc::UnboundedReceiver<(V, Instant)>,
    /// Sequence number put into every request of this round; also the key
    /// under which this stream is subscribed.
    sequence_number: u16,
}
211
212impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
213    pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<(V, Duration)>> {
214        // Try to receive a response (may be from a different round)
215        if let Poll::Ready(maybe_reply) = self.poll_next_from_different_round(cx) {
216            return Poll::Ready(maybe_reply);
217        }
218
219        // Try to send ICMP echo requests
220        self.poll_next_icmp_replies(cx);
221
222        // Check if we're done: no more addresses to send AND no responses pending
223        if self.send_queue.peek().is_none() && self.in_flight.is_empty() {
224            return Poll::Ready(None);
225        }
226
227        Poll::Pending
228    }
229
230    fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) {
231        while let Some(&addr) = self.send_queue.peek() {
232            let payload = rand::random::<[u8; 64]>();
233
234            let packet = EchoRequestPacket::new(
235                self.pinger.inner.identifier,
236                self.sequence_number,
237                &payload,
238            );
239            match self.pinger.inner.raw.poll_send_to(cx, addr, &packet) {
240                Poll::Ready(_) => {
241                    let sent_at = Instant::now();
242
243                    let taken_addr = self.send_queue.next();
244                    debug_assert!(taken_addr.is_some());
245
246                    self.in_flight.insert(addr, sent_at);
247                }
248                Poll::Pending => break,
249            }
250        }
251    }
252
253    fn poll_next_from_different_round(
254        &mut self,
255        cx: &mut Context<'_>,
256    ) -> Poll<Option<(V, Duration)>> {
257        loop {
258            match self.receiver.poll_recv(cx) {
259                Poll::Pending => return Poll::Pending,
260                Poll::Ready(Some((addr, recv_instant))) => {
261                    if let Some(send_instant) = self.in_flight.remove(&addr) {
262                        let rtt = recv_instant - send_instant;
263                        return Poll::Ready(Some((addr, rtt)));
264                    }
265                }
266                Poll::Ready(None) => return Poll::Ready(None),
267            }
268        }
269    }
270}
271
#[cfg(feature = "stream")]
impl<V: IpVersion, I: Iterator<Item = V> + Unpin> Stream for MeasureManyStream<'_, V, I> {
    type Item = (V, Duration);

    /// Delegates to [`MeasureManyStream::poll_next_unpin`]; the stream is
    /// `Unpin`, so the pin can simply be unwrapped.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.poll_next_unpin(cx)
    }
}
280
281impl<V: IpVersion, I: Iterator<Item = V>> Drop for MeasureManyStream<'_, V, I> {
282    fn drop(&mut self) {
283        let _ = self
284            .pinger
285            .inner
286            .round_sender
287            .send(RoundMessage::Unsubscribe {
288                sequence_number: self.sequence_number,
289            });
290    }
291}