1#[cfg(feature = "stream")]
2use std::pin::Pin;
3use std::{
4 collections::HashMap,
5 future::poll_fn,
6 io,
7 iter::Peekable,
8 net::{Ipv4Addr, Ipv6Addr},
9 sync::{
10 Arc,
11 atomic::{AtomicU16, Ordering},
12 },
13 task::{Context, Poll},
14 time::Duration,
15};
16
17use bytes::BytesMut;
18#[cfg(feature = "stream")]
19use futures_core::Stream;
20use tokio::{
21 sync::mpsc::{self, error::TryRecvError},
22 time::Instant,
23};
24
25use crate::{IpVersion, packet::EchoRequestPacket, raw_pinger::RawPinger};
26
27pub type V4Pinger = Pinger<Ipv4Addr>;
29pub type V6Pinger = Pinger<Ipv6Addr>;
31
32pub struct Pinger<V: IpVersion> {
34 inner: Arc<InnerPinger<V>>,
35}
36
37struct InnerPinger<V: IpVersion> {
38 raw: RawPinger<V>,
39 round_sender: mpsc::UnboundedSender<RoundMessage<V>>,
40 identifier: u16,
41 sequence_number: AtomicU16,
42}
43
44enum RoundMessage<V: IpVersion> {
45 Subscribe {
46 sequence_number: u16,
47 sender: mpsc::UnboundedSender<(V, Instant)>,
48 },
49 Unsubscribe {
50 sequence_number: u16,
51 },
52}
53
54enum PollResult<V: IpVersion> {
55 Subscription(RoundMessage<V>),
56 Packet(crate::packet::EchoReplyPacket<V>),
57}
58
59impl<V: IpVersion> Pinger<V> {
60 pub fn new() -> io::Result<Self> {
67 let raw = RawPinger::new()?;
68
69 let identifier = rand::random::<u16>();
70
71 let (sender, mut receiver) = mpsc::unbounded_channel();
72
73 let inner = Arc::new(InnerPinger {
74 raw,
75 round_sender: sender,
76 identifier,
77 sequence_number: AtomicU16::new(0),
78 });
79
80 let inner_recv = Arc::clone(&inner);
82 tokio::spawn(async move {
83 let mut subscribers: HashMap<u16, mpsc::UnboundedSender<(V, Instant)>> = HashMap::new();
84 let mut recv_buf = BytesMut::new();
86
87 loop {
88 let result = poll_fn(|cx| {
96 match receiver.try_recv() {
98 Ok(msg) => return Poll::Ready(Some(PollResult::Subscription(msg))),
99 Err(TryRecvError::Empty) => {
100 }
102 Err(TryRecvError::Disconnected) => return Poll::Ready(None),
103 }
104
105 if let Poll::Ready(Ok(packet)) = inner_recv.raw.poll_recv(&mut recv_buf, cx) {
107 return Poll::Ready(Some(PollResult::Packet(packet)));
108 }
109 match receiver.poll_recv(cx) {
114 Poll::Ready(Some(msg)) => {
115 return Poll::Ready(Some(PollResult::Subscription(msg)));
116 }
117 Poll::Ready(None) => return Poll::Ready(None),
118 Poll::Pending => {}
119 }
120
121 Poll::Pending
122 })
123 .await;
124
125 match result {
126 Some(PollResult::Subscription(RoundMessage::Subscribe {
127 sequence_number,
128 sender,
129 })) => {
130 subscribers.insert(sequence_number, sender);
131 }
132 Some(PollResult::Subscription(RoundMessage::Unsubscribe {
133 sequence_number,
134 })) => {
135 subscribers.remove(&sequence_number);
136 }
137 Some(PollResult::Packet(packet)) => {
138 let recv_instant = Instant::now();
139
140 let packet_source = packet.source();
141 let packet_sequence_number = packet.sequence_number();
142
143 if let Some(subscriber) = subscribers.get(&packet_sequence_number) {
144 if subscriber.send((packet_source, recv_instant)).is_err() {
145 subscribers.remove(&packet_sequence_number);
146 }
147 }
148 }
149 None => return, }
151 }
152 });
153
154 Ok(Self { inner })
155 }
156
157 pub fn measure_many<I>(&self, addresses: I) -> MeasureManyStream<'_, V, I>
164 where
165 I: Iterator<Item = V>,
166 {
167 let (size_hint, _) = addresses.size_hint();
168 let send_queue = addresses.into_iter().peekable();
169 let (sender, receiver) = mpsc::unbounded_channel();
170
171 let sequence_number = self.inner.sequence_number.fetch_add(1, Ordering::AcqRel);
172 if self
173 .inner
174 .round_sender
175 .send(RoundMessage::Subscribe {
176 sequence_number,
177 sender,
178 })
179 .is_err()
180 {
181 panic!("Receiver closed");
182 }
183
184 MeasureManyStream {
185 pinger: self,
186 send_queue,
187 in_flight: HashMap::with_capacity(size_hint),
188 receiver,
189 sequence_number,
190 }
191 }
192}
193
194pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
205 pinger: &'a Pinger<V>,
206 send_queue: Peekable<I>,
207 in_flight: HashMap<V, Instant>,
208 receiver: mpsc::UnboundedReceiver<(V, Instant)>,
209 sequence_number: u16,
210}
211
212impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
213 pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<(V, Duration)>> {
214 if let Poll::Ready(maybe_reply) = self.poll_next_from_different_round(cx) {
216 return Poll::Ready(maybe_reply);
217 }
218
219 self.poll_next_icmp_replies(cx);
221
222 if self.send_queue.peek().is_none() && self.in_flight.is_empty() {
224 return Poll::Ready(None);
225 }
226
227 Poll::Pending
228 }
229
230 fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) {
231 while let Some(&addr) = self.send_queue.peek() {
232 let payload = rand::random::<[u8; 64]>();
233
234 let packet = EchoRequestPacket::new(
235 self.pinger.inner.identifier,
236 self.sequence_number,
237 &payload,
238 );
239 match self.pinger.inner.raw.poll_send_to(cx, addr, &packet) {
240 Poll::Ready(_) => {
241 let sent_at = Instant::now();
242
243 let taken_addr = self.send_queue.next();
244 debug_assert!(taken_addr.is_some());
245
246 self.in_flight.insert(addr, sent_at);
247 }
248 Poll::Pending => break,
249 }
250 }
251 }
252
253 fn poll_next_from_different_round(
254 &mut self,
255 cx: &mut Context<'_>,
256 ) -> Poll<Option<(V, Duration)>> {
257 loop {
258 match self.receiver.poll_recv(cx) {
259 Poll::Pending => return Poll::Pending,
260 Poll::Ready(Some((addr, recv_instant))) => {
261 if let Some(send_instant) = self.in_flight.remove(&addr) {
262 let rtt = recv_instant - send_instant;
263 return Poll::Ready(Some((addr, rtt)));
264 }
265 }
266 Poll::Ready(None) => return Poll::Ready(None),
267 }
268 }
269 }
270}
271
#[cfg(feature = "stream")]
impl<V, I> Stream for MeasureManyStream<'_, V, I>
where
    V: IpVersion,
    I: Iterator<Item = V> + Unpin,
{
    type Item = (V, Duration);

    /// Delegates to the inherent `MeasureManyStream::poll_next_unpin`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.poll_next_unpin(cx)
    }
}
280
281impl<V: IpVersion, I: Iterator<Item = V>> Drop for MeasureManyStream<'_, V, I> {
282 fn drop(&mut self) {
283 let _ = self
284 .pinger
285 .inner
286 .round_sender
287 .send(RoundMessage::Unsubscribe {
288 sequence_number: self.sequence_number,
289 });
290 }
291}