massping/pinger.rs
#[cfg(feature = "stream")]
use std::pin::Pin;
use std::{
    collections::{HashMap, HashSet},
    future::poll_fn,
    io,
    iter::Peekable,
    net::{Ipv4Addr, Ipv6Addr},
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering},
    },
    task::{Context, Poll},
    time::Duration,
};

use bytes::{Bytes, BytesMut};
#[cfg(feature = "stream")]
use futures_core::Stream;
use tokio::{
    sync::mpsc::{self, error::TryRecvError},
    time::Instant,
};

use crate::{IpVersion, packet::EchoRequestPacket, raw_pinger::RawPinger};
26
27/// A pinger for IPv4 addresses
28pub type V4Pinger = Pinger<Ipv4Addr>;
29/// A pinger for IPv6 addresses
30pub type V6Pinger = Pinger<Ipv6Addr>;
31
32/// A pinger for [`IpVersion`] (either [`Ipv4Addr`] or [`Ipv6Addr`]).
33///
34/// Cloning is cheap: clones share the same socket and background
35/// receive task, which shut down when the last clone is dropped.
36pub struct Pinger<V: IpVersion> {
37 inner: Arc<InnerPinger<V>>,
38 // Kept out of `InnerPinger` (which the background receive task holds)
39 // so that dropping the last `Pinger` clone disconnects the channel,
40 // telling the background task to shut down and release the socket.
41 round_sender: mpsc::UnboundedSender<RoundMessage<V>>,
42}
43
44impl<V: IpVersion> Clone for Pinger<V> {
45 fn clone(&self) -> Self {
46 Self {
47 inner: Arc::clone(&self.inner),
48 round_sender: self.round_sender.clone(),
49 }
50 }
51}
52
53struct InnerPinger<V: IpVersion> {
54 raw: RawPinger<V>,
55 next_round_id: AtomicU64,
56}
57
58// Each `measure_many` round gets a unique `u64` id; the wire sequence
59// number is its lower 16 bits. The full id lets the receive task tell
60// rounds apart after the sequence number wraps around.
61enum RoundMessage<V: IpVersion> {
62 Subscribe {
63 round_id: u64,
64 expected_payload: Bytes,
65 sender: mpsc::UnboundedSender<(V, Instant)>,
66 },
67 Unsubscribe {
68 round_id: u64,
69 },
70}
71
72struct Subscriber<V: IpVersion> {
73 round_id: u64,
74 expected_payload: Bytes,
75 sender: mpsc::UnboundedSender<(V, Instant)>,
76}
77
78enum PollResult<V: IpVersion> {
79 Subscription(RoundMessage<V>),
80 Packet(crate::packet::EchoReplyPacket<V>),
81}
82
83impl<V: IpVersion> Pinger<V> {
84 /// Construct a new `Pinger`.
85 ///
86 /// For maximum efficiency the same instance of `Pinger` should
87 /// be used for as long as possible, although it might also
88 /// be beneficial to `Drop` the `Pinger` and recreate it if
89 /// you are not going to be sending pings for a long period of time.
90 ///
91 /// # Panics
92 ///
93 /// Panics if called from outside a tokio runtime, as it spawns a
94 /// background receive task.
95 pub fn new() -> io::Result<Self> {
96 let raw = RawPinger::new()?;
97
98 let (sender, mut receiver) = mpsc::unbounded_channel();
99
100 let inner = Arc::new(InnerPinger {
101 raw,
102 next_round_id: AtomicU64::new(0),
103 });
104
105 // Spawn async receive task using the same socket.
106 // It runs until `receiver` disconnects, which happens when the
107 // `Pinger` holding the only sender is dropped.
108 let inner_recv = Arc::clone(&inner);
109 tokio::spawn(async move {
110 let mut subscribers: HashMap<u16, Subscriber<V>> = HashMap::new();
111 // Buffer kept outside poll_fn so it persists across polls.
112 let mut recv_buf = BytesMut::new();
113
114 loop {
115 // Poll both subscription channel and socket in the same waker context.
116 // This ensures we wake on either event, which is required for
117 // single-threaded runtimes where we can't rely on concurrent execution.
118 //
119 // Note: We use try_recv() before poll_recv() as a fast path optimization.
120 // Benchmarks show this is ~2x faster when messages are already queued
121 // (~15ns vs ~25ns per iteration).
122 let result = poll_fn(|cx| {
123 // Fast path: check for subscription changes (non-blocking, no waker)
124 match receiver.try_recv() {
125 Ok(msg) => return Poll::Ready(Some(PollResult::Subscription(msg))),
126 Err(TryRecvError::Empty) => {
127 // Continue - poll_recv() below will register the waker for this channel
128 }
129 Err(TryRecvError::Disconnected) => return Poll::Ready(None),
130 }
131
132 // Try to receive an ICMP packet
133 match inner_recv.raw.poll_recv(&mut recv_buf, cx) {
134 Poll::Ready(Ok(packet)) => {
135 return Poll::Ready(Some(PollResult::Packet(packet)));
136 }
137 Poll::Ready(Err(_)) => {
138 // Receiving failed (typically a transient kernel
139 // resource error). The socket readiness was
140 // consumed without registering a waker, so ask to
141 // be polled again right away; parking here would
142 // suspend reply processing until an unrelated
143 // subscription message wakes the task.
144 cx.waker().wake_by_ref();
145 }
146 Poll::Pending => {}
147 }
148
149 // Register waker for subscription channel
150 // We need to wake up when new subscriptions arrive
151 match receiver.poll_recv(cx) {
152 Poll::Ready(Some(msg)) => {
153 return Poll::Ready(Some(PollResult::Subscription(msg)));
154 }
155 Poll::Ready(None) => return Poll::Ready(None),
156 Poll::Pending => {}
157 }
158
159 Poll::Pending
160 })
161 .await;
162
163 match result {
164 Some(PollResult::Subscription(RoundMessage::Subscribe {
165 round_id,
166 expected_payload,
167 sender,
168 })) => {
169 // A new round may displace a still-subscribed round
170 // whose sequence number collided after wraparound;
171 // the displaced round could not be served anyway as
172 // replies can only be told apart by sequence number.
173 subscribers.insert(
174 round_id as u16,
175 Subscriber {
176 round_id,
177 expected_payload,
178 sender,
179 },
180 );
181 }
182 Some(PollResult::Subscription(RoundMessage::Unsubscribe { round_id })) => {
183 let sequence_number = round_id as u16;
184 // Only unsubscribe if the slot still belongs to this
185 // round: after sequence number wraparound it may have
186 // been taken over by a newer round, which must keep
187 // receiving replies.
188 if subscribers
189 .get(&sequence_number)
190 .is_some_and(|subscriber| subscriber.round_id == round_id)
191 {
192 subscribers.remove(&sequence_number);
193 }
194 }
195 Some(PollResult::Packet(packet)) => {
196 let recv_instant = Instant::now();
197
198 let packet_source = packet.source();
199 let packet_sequence_number = packet.sequence_number();
200
201 if let Some(subscriber) = subscribers.get(&packet_sequence_number) {
202 // An echo reply mirrors the request's payload, so
203 // a mismatch means the reply wasn't produced by
204 // this round (e.g. a reply to an older round whose
205 // sequence number collided after wraparound, or
206 // blindly spoofed cross-traffic). Discard it.
207 let payload_matches =
208 packet.payload() == &subscriber.expected_payload[..];
209
210 if payload_matches
211 && subscriber
212 .sender
213 .send((packet_source, recv_instant))
214 .is_err()
215 {
216 subscribers.remove(&packet_sequence_number);
217 }
218 }
219 }
220 None => return, // Channel closed
221 }
222 }
223 });
224
225 Ok(Self {
226 inner,
227 round_sender: sender,
228 })
229 }
230
231 /// Ping `addresses`
232 ///
233 /// Creates [`MeasureManyStream`] which **lazily** sends ping
234 /// requests and [`Stream`]s the responses as they arrive.
235 ///
236 /// Replies are matched by source address, so an address that appears
237 /// multiple times is only pinged once per round and yields a single
238 /// measurement.
239 ///
240 /// # Panics
241 ///
242 /// Panics if the background receive task has terminated, which only
243 /// happens when the runtime the `Pinger` was created on has been
244 /// shut down.
245 ///
246 /// [`Stream`]: futures_core::Stream
247 pub fn measure_many<I>(&self, addresses: I) -> MeasureManyStream<'_, V, I>
248 where
249 I: Iterator<Item = V>,
250 {
251 let (size_hint, _) = addresses.size_hint();
252 let send_queue = addresses.into_iter().peekable();
253 let (sender, receiver) = mpsc::unbounded_channel();
254
255 // Relaxed is enough: the counter is a pure id allocator, no other
256 // memory is synchronized through it.
257 let round_id = self.inner.next_round_id.fetch_add(1, Ordering::Relaxed);
258
259 // The same packet is reused for every address of the round. Its
260 // random payload lets the receive task discard replies that don't
261 // belong to this round.
262 //
263 // The identifier is irrelevant: the kernel overwrites it with the
264 // socket's own identifier, which it also uses to route echo replies
265 // back to this socket.
266 let payload = rand::random::<[u8; 64]>();
267 let packet = EchoRequestPacket::new(0, round_id as u16, &payload);
268
269 if self
270 .round_sender
271 .send(RoundMessage::Subscribe {
272 round_id,
273 expected_payload: packet.payload(),
274 sender,
275 })
276 .is_err()
277 {
278 panic!("Receiver closed");
279 }
280
281 MeasureManyStream {
282 pinger: self,
283 packet,
284 send_queue,
285 in_flight: HashMap::with_capacity(size_hint),
286 receiver,
287 round_id,
288 }
289 }
290}
291
292/// A [`Stream`] of ping responses.
293///
294/// No kind of `rtt` timeout is implemented, so an external mechanism
295/// like [`tokio::time::timeout`] should be used to prevent the program
296/// from hanging indefinitely.
297///
298/// Leaking this stream may create a memory leak that lasts until the
299/// [`Pinger`] is dropped.
300///
301/// [`Stream`]: futures_core::Stream
302/// [`tokio::time::timeout`]: tokio::time::timeout
303#[must_use = "streams do nothing unless polled"]
304pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
305 pinger: &'a Pinger<V>,
306 packet: EchoRequestPacket<V>,
307 send_queue: Peekable<I>,
308 in_flight: HashMap<V, Instant>,
309 receiver: mpsc::UnboundedReceiver<(V, Instant)>,
310 round_id: u64,
311}
312
313impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
314 pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<(V, Duration)>> {
315 // Try to receive a response (may be from a different round)
316 if let Poll::Ready(maybe_reply) = self.poll_next_from_different_round(cx) {
317 return Poll::Ready(maybe_reply);
318 }
319
320 // Try to send ICMP echo requests
321 self.poll_next_icmp_replies(cx);
322
323 // Check if we're done: no more addresses to send AND no responses pending
324 if self.send_queue.peek().is_none() && self.in_flight.is_empty() {
325 return Poll::Ready(None);
326 }
327
328 Poll::Pending
329 }
330
331 fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) {
332 while let Some(&addr) = self.send_queue.peek() {
333 // Replies are matched by source address within a round, so a
334 // second ping to an address that is still awaiting its reply
335 // could never produce a second measurement; it would only
336 // clobber the first ping's start time. Skip the duplicate.
337 if self.in_flight.contains_key(&addr) {
338 self.send_queue.next();
339 continue;
340 }
341
342 match self.pinger.inner.raw.poll_send_to(cx, addr, &self.packet) {
343 Poll::Ready(result) => {
344 let sent_at = Instant::now();
345
346 let taken_addr = self.send_queue.next();
347 debug_assert!(taken_addr.is_some());
348
349 // If the send failed (e.g. no route to host) no reply
350 // can ever arrive, so don't track the address as
351 // in-flight or the stream would never terminate.
352 if result.is_ok() {
353 self.in_flight.insert(addr, sent_at);
354 }
355 }
356 Poll::Pending => {
357 // The socket only remembers the most recent waker per
358 // direction (`AsyncFd` semantics), so with multiple
359 // streams sharing the socket another stream could
360 // overwrite ours and we'd never be woken again. Sends
361 // only return `Pending` while the send buffer is full,
362 // which clears up quickly, so schedule an immediate
363 // re-poll instead of parking.
364 cx.waker().wake_by_ref();
365 break;
366 }
367 }
368 }
369 }
370
371 fn poll_next_from_different_round(
372 &mut self,
373 cx: &mut Context<'_>,
374 ) -> Poll<Option<(V, Duration)>> {
375 loop {
376 match self.receiver.poll_recv(cx) {
377 Poll::Pending => return Poll::Pending,
378 Poll::Ready(Some((addr, recv_instant))) => {
379 if let Some(send_instant) = self.in_flight.remove(&addr) {
380 let rtt = recv_instant - send_instant;
381 return Poll::Ready(Some((addr, rtt)));
382 }
383 }
384 Poll::Ready(None) => return Poll::Ready(None),
385 }
386 }
387 }
388}
389
#[cfg(feature = "stream")]
impl<V: IpVersion, I: Iterator<Item = V> + Unpin> Stream for MeasureManyStream<'_, V, I> {
    type Item = (V, Duration);

    /// Delegates to [`MeasureManyStream::poll_next_unpin`].
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Unwrapping the pin requires `Self: Unpin`, which the `I: Unpin`
        // bound on this impl is there to provide.
        let this = Pin::get_mut(self);
        this.poll_next_unpin(cx)
    }
}
398
399impl<V: IpVersion, I: Iterator<Item = V>> Drop for MeasureManyStream<'_, V, I> {
400 fn drop(&mut self) {
401 let _ = self.pinger.round_sender.send(RoundMessage::Unsubscribe {
402 round_id: self.round_id,
403 });
404 }
405}