massping/
pinger.rs

1use std::{
2    collections::HashMap,
3    io,
4    iter::Peekable,
5    net::{Ipv4Addr, Ipv6Addr},
6    sync::{
7        Arc,
8        atomic::{AtomicU16, Ordering},
9    },
10    task::{Context, Poll},
11    time::Duration,
12};
13#[cfg(feature = "stream")]
14use std::{pin::Pin, task::ready};
15
16#[cfg(feature = "stream")]
17use futures_core::Stream;
18use tokio::{
19    sync::mpsc::{self, error::TryRecvError},
20    time::Instant,
21};
22
23use crate::{IpVersion, packet::EchoRequestPacket, raw_pinger::RawPinger};
24
25/// A pinger for IPv4 addresses
26pub type V4Pinger = Pinger<Ipv4Addr>;
27/// A pinger for IPv6 addresses
28pub type V6Pinger = Pinger<Ipv6Addr>;
29
30/// A pinger for [`IpVersion`] (either [`Ipv4Addr`] or [`Ipv6Addr`]).
31pub struct Pinger<V: IpVersion> {
32    inner: Arc<InnerPinger<V>>,
33}
34
35struct InnerPinger<V: IpVersion> {
36    raw: RawPinger<V>,
37    round_sender: mpsc::UnboundedSender<RoundMessage<V>>,
38    identifier: u16,
39    sequence_number: AtomicU16,
40}
41
42enum RoundMessage<V: IpVersion> {
43    Subscribe {
44        sequence_number: u16,
45        sender: mpsc::UnboundedSender<(V, Instant)>,
46    },
47    Unsubscribe {
48        sequence_number: u16,
49    },
50}
51
52impl<V: IpVersion> Pinger<V> {
53    /// Construct a new `Pinger`.
54    ///
55    /// For maximum efficiency the same instance of `Pinger` should
56    /// be used for as long as possible, altough it might also
57    /// be beneficial to `Drop` the `Pinger` and recreate it if
58    /// you are not going to be sending pings for a long period of time.
59    pub fn new() -> io::Result<Self> {
60        let raw = RawPinger::new()?;
61
62        let identifier = rand::random::<u16>();
63
64        let (sender, mut receiver) = mpsc::unbounded_channel();
65
66        let inner = Arc::new(InnerPinger {
67            raw,
68            round_sender: sender,
69            identifier,
70            sequence_number: AtomicU16::new(0),
71        });
72
73        // Spawn async receive task using the same socket
74        let inner_recv = Arc::clone(&inner);
75        tokio::spawn(async move {
76            let mut subscribers: HashMap<u16, mpsc::UnboundedSender<(V, Instant)>> = HashMap::new();
77
78            loop {
79                // Process any pending subscription changes
80                loop {
81                    match receiver.try_recv() {
82                        Ok(RoundMessage::Subscribe {
83                            sequence_number,
84                            sender,
85                        }) => {
86                            subscribers.insert(sequence_number, sender);
87                        }
88                        Ok(RoundMessage::Unsubscribe { sequence_number }) => {
89                            drop(subscribers.remove(&sequence_number));
90                        }
91                        Err(TryRecvError::Empty) => break,
92                        Err(TryRecvError::Disconnected) => return,
93                    }
94                }
95
96                // Receive next packet (with DGRAM sockets, kernel handles routing)
97                let packet = match inner_recv.raw.recv().await {
98                    Ok(packet) => packet,
99                    Err(_) => continue,
100                };
101
102                let recv_instant = Instant::now();
103
104                let packet_source = packet.source();
105                let packet_sequence_number = packet.sequence_number();
106
107                if let Some(subscriber) = subscribers.get(&packet_sequence_number) {
108                    if subscriber.send((packet_source, recv_instant)).is_err() {
109                        subscribers.remove(&packet_sequence_number);
110                    }
111                }
112            }
113        });
114
115        Ok(Self { inner })
116    }
117
118    /// Ping `addresses`
119    ///
120    /// Creates [`MeasureManyStream`] which **lazily** sends ping
121    /// requests and [`Stream`]s the responses as they arrive.
122    ///
123    /// [`Stream`]: futures_core::Stream
124    pub fn measure_many<I>(&self, addresses: I) -> MeasureManyStream<'_, V, I>
125    where
126        I: Iterator<Item = V>,
127    {
128        let (size_hint, _) = addresses.size_hint();
129        let send_queue = addresses.into_iter().peekable();
130        let (sender, receiver) = mpsc::unbounded_channel();
131
132        let sequence_number = self.inner.sequence_number.fetch_add(1, Ordering::AcqRel);
133        if self
134            .inner
135            .round_sender
136            .send(RoundMessage::Subscribe {
137                sequence_number,
138                sender,
139            })
140            .is_err()
141        {
142            panic!("Receiver closed");
143        }
144
145        MeasureManyStream {
146            pinger: self,
147            send_queue,
148            in_flight: HashMap::with_capacity(size_hint),
149            receiver,
150            sequence_number,
151        }
152    }
153}
154
155/// A [`Stream`] of ping responses.
156///
157/// No kind of `rtt` timeout is implemented, so an external mechanism
158/// like [`tokio::time::timeout`] should be used to prevent the program
159/// from hanging indefinitely.
160///
161/// Leaking this method might crate a slowly forever growing memory leak.
162///
163/// [`Stream`]: futures_core::Stream
164/// [`tokio::time::timeout`]: tokio::time::timeout
165pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
166    pinger: &'a Pinger<V>,
167    send_queue: Peekable<I>,
168    in_flight: HashMap<V, Instant>,
169    receiver: mpsc::UnboundedReceiver<(V, Instant)>,
170    sequence_number: u16,
171}
172
173impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
174    pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(V, Duration)> {
175        // Try to see if another `MeasureManyStream` got it
176        if let Poll::Ready(Some((addr, rtt))) = self.poll_next_from_different_round(cx) {
177            return Poll::Ready((addr, rtt));
178        }
179
180        // Try to send ICMP echo requests
181        self.poll_next_icmp_replies(cx);
182
183        Poll::Pending
184    }
185
186    fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) {
187        while let Some(&addr) = self.send_queue.peek() {
188            let payload = rand::random::<[u8; 64]>();
189
190            let packet = EchoRequestPacket::new(
191                self.pinger.inner.identifier,
192                self.sequence_number,
193                &payload,
194            );
195            match self.pinger.inner.raw.poll_send_to(cx, addr, &packet) {
196                Poll::Ready(_) => {
197                    let sent_at = Instant::now();
198
199                    let taken_addr = self.send_queue.next();
200                    debug_assert!(taken_addr.is_some());
201
202                    self.in_flight.insert(addr, sent_at);
203                }
204                Poll::Pending => break,
205            }
206        }
207    }
208
209    fn poll_next_from_different_round(
210        &mut self,
211        cx: &mut Context<'_>,
212    ) -> Poll<Option<(V, Duration)>> {
213        loop {
214            match self.receiver.poll_recv(cx) {
215                Poll::Pending => return Poll::Pending,
216                Poll::Ready(Some((addr, recv_instant))) => {
217                    if let Some(send_instant) = self.in_flight.remove(&addr) {
218                        let rtt = recv_instant - send_instant;
219                        return Poll::Ready(Some((addr, rtt)));
220                    }
221                }
222                Poll::Ready(None) => return Poll::Ready(None),
223            }
224        }
225    }
226}
227
#[cfg(feature = "stream")]
impl<V: IpVersion, I: Iterator<Item = V> + Unpin> Stream for MeasureManyStream<'_, V, I> {
    type Item = (V, Duration);

    /// Delegates to [`MeasureManyStream::poll_next_unpin`], wrapping every
    /// reply in `Some`; the stream never yields `None` on its own.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Poll::Ready(Some(ready!(this.poll_next_unpin(cx))))
    }
}
237
238impl<V: IpVersion, I: Iterator<Item = V>> Drop for MeasureManyStream<'_, V, I> {
239    fn drop(&mut self) {
240        let _ = self
241            .pinger
242            .inner
243            .round_sender
244            .send(RoundMessage::Unsubscribe {
245                sequence_number: self.sequence_number,
246            });
247    }
248}