1use std::{
2 collections::HashMap,
3 io,
4 iter::Peekable,
5 net::{Ipv4Addr, Ipv6Addr},
6 sync::{
7 Arc,
8 atomic::{AtomicU16, Ordering},
9 },
10 task::{Context, Poll},
11 time::Duration,
12};
13#[cfg(feature = "stream")]
14use std::{pin::Pin, task::ready};
15
16#[cfg(feature = "stream")]
17use futures_core::Stream;
18use tokio::{
19 sync::mpsc::{self, error::TryRecvError},
20 time::Instant,
21};
22
23use crate::{IpVersion, packet::EchoRequestPacket, raw_pinger::RawPinger};
24
25pub type V4Pinger = Pinger<Ipv4Addr>;
27pub type V6Pinger = Pinger<Ipv6Addr>;
29
30pub struct Pinger<V: IpVersion> {
32 inner: Arc<InnerPinger<V>>,
33}
34
35struct InnerPinger<V: IpVersion> {
36 raw: RawPinger<V>,
37 round_sender: mpsc::UnboundedSender<RoundMessage<V>>,
38 identifier: u16,
39 sequence_number: AtomicU16,
40}
41
42enum RoundMessage<V: IpVersion> {
43 Subscribe {
44 sequence_number: u16,
45 sender: mpsc::UnboundedSender<(V, Instant)>,
46 },
47 Unsubscribe {
48 sequence_number: u16,
49 },
50}
51
52impl<V: IpVersion> Pinger<V> {
53 pub fn new() -> io::Result<Self> {
60 let raw = RawPinger::new()?;
61
62 let identifier = rand::random::<u16>();
63
64 let (sender, mut receiver) = mpsc::unbounded_channel();
65
66 let inner = Arc::new(InnerPinger {
67 raw,
68 round_sender: sender,
69 identifier,
70 sequence_number: AtomicU16::new(0),
71 });
72
73 let inner_recv = Arc::clone(&inner);
75 tokio::spawn(async move {
76 let mut subscribers: HashMap<u16, mpsc::UnboundedSender<(V, Instant)>> = HashMap::new();
77
78 loop {
79 loop {
81 match receiver.try_recv() {
82 Ok(RoundMessage::Subscribe {
83 sequence_number,
84 sender,
85 }) => {
86 subscribers.insert(sequence_number, sender);
87 }
88 Ok(RoundMessage::Unsubscribe { sequence_number }) => {
89 drop(subscribers.remove(&sequence_number));
90 }
91 Err(TryRecvError::Empty) => break,
92 Err(TryRecvError::Disconnected) => return,
93 }
94 }
95
96 let packet = match inner_recv.raw.recv().await {
98 Ok(packet) => packet,
99 Err(_) => continue,
100 };
101
102 let recv_instant = Instant::now();
103
104 let packet_source = packet.source();
105 let packet_sequence_number = packet.sequence_number();
106
107 if let Some(subscriber) = subscribers.get(&packet_sequence_number) {
108 if subscriber.send((packet_source, recv_instant)).is_err() {
109 subscribers.remove(&packet_sequence_number);
110 }
111 }
112 }
113 });
114
115 Ok(Self { inner })
116 }
117
118 pub fn measure_many<I>(&self, addresses: I) -> MeasureManyStream<'_, V, I>
125 where
126 I: Iterator<Item = V>,
127 {
128 let (size_hint, _) = addresses.size_hint();
129 let send_queue = addresses.into_iter().peekable();
130 let (sender, receiver) = mpsc::unbounded_channel();
131
132 let sequence_number = self.inner.sequence_number.fetch_add(1, Ordering::AcqRel);
133 if self
134 .inner
135 .round_sender
136 .send(RoundMessage::Subscribe {
137 sequence_number,
138 sender,
139 })
140 .is_err()
141 {
142 panic!("Receiver closed");
143 }
144
145 MeasureManyStream {
146 pinger: self,
147 send_queue,
148 in_flight: HashMap::with_capacity(size_hint),
149 receiver,
150 sequence_number,
151 }
152 }
153}
154
155pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
166 pinger: &'a Pinger<V>,
167 send_queue: Peekable<I>,
168 in_flight: HashMap<V, Instant>,
169 receiver: mpsc::UnboundedReceiver<(V, Instant)>,
170 sequence_number: u16,
171}
172
173impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
174 pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(V, Duration)> {
175 if let Poll::Ready(Some((addr, rtt))) = self.poll_next_from_different_round(cx) {
177 return Poll::Ready((addr, rtt));
178 }
179
180 self.poll_next_icmp_replies(cx);
182
183 Poll::Pending
184 }
185
186 fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) {
187 while let Some(&addr) = self.send_queue.peek() {
188 let payload = rand::random::<[u8; 64]>();
189
190 let packet = EchoRequestPacket::new(
191 self.pinger.inner.identifier,
192 self.sequence_number,
193 &payload,
194 );
195 match self.pinger.inner.raw.poll_send_to(cx, addr, &packet) {
196 Poll::Ready(_) => {
197 let sent_at = Instant::now();
198
199 let taken_addr = self.send_queue.next();
200 debug_assert!(taken_addr.is_some());
201
202 self.in_flight.insert(addr, sent_at);
203 }
204 Poll::Pending => break,
205 }
206 }
207 }
208
209 fn poll_next_from_different_round(
210 &mut self,
211 cx: &mut Context<'_>,
212 ) -> Poll<Option<(V, Duration)>> {
213 loop {
214 match self.receiver.poll_recv(cx) {
215 Poll::Pending => return Poll::Pending,
216 Poll::Ready(Some((addr, recv_instant))) => {
217 if let Some(send_instant) = self.in_flight.remove(&addr) {
218 let rtt = recv_instant - send_instant;
219 return Poll::Ready(Some((addr, rtt)));
220 }
221 }
222 Poll::Ready(None) => return Poll::Ready(None),
223 }
224 }
225 }
226}
227
#[cfg(feature = "stream")]
impl<V: IpVersion, I: Iterator<Item = V> + Unpin> Stream for MeasureManyStream<'_, V, I> {
    type Item = (V, Duration);

    /// Delegates to `MeasureManyStream::poll_next_unpin`. This stream never yields `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let measurement = ready!(this.poll_next_unpin(cx));
        Poll::Ready(Some(measurement))
    }
}
237
238impl<V: IpVersion, I: Iterator<Item = V>> Drop for MeasureManyStream<'_, V, I> {
239 fn drop(&mut self) {
240 let _ = self
241 .pinger
242 .inner
243 .round_sender
244 .send(RoundMessage::Unsubscribe {
245 sequence_number: self.sequence_number,
246 });
247 }
248}