#[cfg(feature = "stream")]
use std::pin::Pin;
use std::{
collections::HashMap,
future::poll_fn,
io,
iter::Peekable,
net::{Ipv4Addr, Ipv6Addr},
sync::{
Arc,
atomic::{AtomicU16, Ordering},
},
task::{Context, Poll},
time::Duration,
};
use bytes::BytesMut;
#[cfg(feature = "stream")]
use futures_core::Stream;
use tokio::{
sync::mpsc::{self, error::TryRecvError},
time::Instant,
};
use crate::{IpVersion, packet::EchoRequestPacket, raw_pinger::RawPinger};
/// A [`Pinger`] that pings IPv4 addresses.
pub type V4Pinger = Pinger<Ipv4Addr>;
/// A [`Pinger`] that pings IPv6 addresses.
pub type V6Pinger = Pinger<Ipv6Addr>;
/// Sends ICMP echo requests and measures round-trip times to many addresses.
///
/// Created with [`Pinger::new`]. Each call to [`Pinger::measure_many`] claims a
/// new sequence number ("round") and returns a [`MeasureManyStream`] for it.
/// A background Tokio task spawned by `new` forwards incoming echo replies to
/// the stream subscribed under the reply's sequence number.
pub struct Pinger<V: IpVersion> {
/// State shared with the background receive task spawned in `new`.
inner: Arc<InnerPinger<V>>,
}
/// State shared between a [`Pinger`] and its background receive task.
struct InnerPinger<V: IpVersion> {
/// Raw ICMP endpoint: requests go out via `poll_send_to`, replies come in via `poll_recv`.
raw: RawPinger<V>,
/// Delivers subscribe/unsubscribe messages to the background receive task.
round_sender: mpsc::UnboundedSender<RoundMessage<V>>,
/// Random identifier written into every echo request this pinger builds.
identifier: u16,
/// Next sequence number to hand out; each `measure_many` call takes one.
sequence_number: AtomicU16,
}
/// Control messages sent from a round to the background receive task.
enum RoundMessage<V: IpVersion> {
/// Start forwarding `(source, receive instant)` for replies carrying
/// `sequence_number` to `sender`.
Subscribe {
sequence_number: u16,
sender: mpsc::UnboundedSender<(V, Instant)>,
},
/// Stop forwarding replies carrying `sequence_number`.
Unsubscribe {
sequence_number: u16,
},
}
/// One event produced by the background receive task's poll step.
enum PollResult<V: IpVersion> {
/// A subscribe/unsubscribe message received from a round.
Subscription(RoundMessage<V>),
/// An echo reply read from the raw endpoint.
Packet(crate::packet::EchoReplyPacket<V>),
}
impl<V: IpVersion> Pinger<V> {
    /// Creates a new pinger and spawns its background receive task.
    ///
    /// A random 16-bit identifier is chosen for this pinger and embedded in
    /// every echo request it sends. The spawned task reads echo replies and
    /// forwards `(source, receive instant)` pairs to whichever
    /// [`MeasureManyStream`] is subscribed under the reply's sequence number.
    ///
    /// The task keeps only a weak reference to the shared state, so once the
    /// `Pinger` is dropped the `RawPinger` is released and the task exits.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `RawPinger::new`.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it uses
    /// [`tokio::spawn`].
    pub fn new() -> io::Result<Self> {
        let raw = RawPinger::new()?;
        let identifier = rand::random::<u16>();
        let (sender, mut receiver) = mpsc::unbounded_channel();
        let inner = Arc::new(InnerPinger {
            raw,
            round_sender: sender,
            identifier,
            sequence_number: AtomicU16::new(0),
        });

        // Only a weak reference goes to the task. A strong `Arc` would keep
        // `round_sender` (and thus the channel) alive for as long as the task
        // runs. `receiver` would then never disconnect, and the task and the
        // raw endpoint would outlive every `Pinger`.
        let inner_recv = Arc::downgrade(&inner);
        tokio::spawn(async move {
            // Active rounds, keyed by the sequence number they were assigned.
            let mut subscribers: HashMap<u16, mpsc::UnboundedSender<(V, Instant)>> =
                HashMap::new();
            let mut recv_buf = BytesMut::new();
            loop {
                let result = poll_fn(|cx| {
                    // Prefer queued subscription changes over socket reads.
                    // A round's `Subscribe` is sent before any of its requests,
                    // so it gets processed before that round's replies.
                    match receiver.try_recv() {
                        Ok(msg) => return Poll::Ready(Some(PollResult::Subscription(msg))),
                        Err(TryRecvError::Empty) => {}
                        Err(TryRecvError::Disconnected) => return Poll::Ready(None),
                    }
                    // The shared state is gone (every `Pinger` dropped): shut down.
                    let pinger = match inner_recv.upgrade() {
                        Some(pinger) => pinger,
                        None => return Poll::Ready(None),
                    };
                    match pinger.raw.poll_recv(&mut recv_buf, cx) {
                        Poll::Ready(Ok(packet)) => {
                            return Poll::Ready(Some(PollResult::Packet(packet)));
                        }
                        // A `Ready(Err)` is not guaranteed to have registered
                        // the waker. Reschedule explicitly so the task keeps
                        // reading replies instead of sleeping until the next
                        // subscription message arrives.
                        Poll::Ready(Err(_)) => cx.waker().wake_by_ref(),
                        Poll::Pending => {}
                    }
                    // Also register for wake-ups on new subscription messages.
                    receiver
                        .poll_recv(cx)
                        .map(|msg| msg.map(PollResult::Subscription))
                })
                .await;
                match result {
                    Some(PollResult::Subscription(RoundMessage::Subscribe {
                        sequence_number,
                        sender,
                    })) => {
                        subscribers.insert(sequence_number, sender);
                    }
                    Some(PollResult::Subscription(RoundMessage::Unsubscribe {
                        sequence_number,
                    })) => {
                        subscribers.remove(&sequence_number);
                    }
                    Some(PollResult::Packet(packet)) => {
                        let recv_instant = Instant::now();
                        let packet_source = packet.source();
                        let packet_sequence_number = packet.sequence_number();
                        if let Some(subscriber) = subscribers.get(&packet_sequence_number) {
                            // The round's receiver is gone; stop routing to it.
                            if subscriber.send((packet_source, recv_instant)).is_err() {
                                subscribers.remove(&packet_sequence_number);
                            }
                        }
                    }
                    // Channel disconnected or the shared state was dropped.
                    None => return,
                }
            }
        });
        Ok(Self { inner })
    }

    /// Starts a new round that pings every address in `addresses`.
    ///
    /// Returns a stream of `(address, round-trip time)` pairs, yielded as
    /// replies arrive. Each call claims the next sequence number of this
    /// pinger and subscribes it with the background receive task.
    ///
    /// Requests are sent lazily while the stream is polled. The stream ends
    /// once every address has been sent a request and has replied. Addresses
    /// that never reply keep it pending, so callers should apply their own
    /// timeout.
    ///
    /// # Panics
    ///
    /// Panics with `"Receiver closed"` if the background receive task is no
    /// longer accepting subscription messages.
    pub fn measure_many<I>(&self, addresses: I) -> MeasureManyStream<'_, V, I>
    where
        I: Iterator<Item = V>,
    {
        let (size_hint, _) = addresses.size_hint();
        // `addresses` is already an iterator; no `into_iter()` needed.
        let send_queue = addresses.peekable();
        let (sender, receiver) = mpsc::unbounded_channel();
        let sequence_number = self.inner.sequence_number.fetch_add(1, Ordering::AcqRel);
        if self
            .inner
            .round_sender
            .send(RoundMessage::Subscribe {
                sequence_number,
                sender,
            })
            .is_err()
        {
            panic!("Receiver closed");
        }
        MeasureManyStream {
            pinger: self,
            send_queue,
            in_flight: HashMap::with_capacity(size_hint),
            receiver,
            sequence_number,
        }
    }
}
/// Stream of `(address, round-trip time)` results for one round started by
/// [`Pinger::measure_many`].
///
/// Echo requests are sent lazily while the stream is polled. The stream ends
/// once every address has been sent a request and has replied. Addresses that
/// never reply keep it pending, so callers should apply their own timeout.
/// Dropping the stream unsubscribes its round from the receive task.
pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
/// Pinger whose raw endpoint and identifier this round uses.
pinger: &'a Pinger<V>,
/// Addresses that have not had a request sent yet.
send_queue: Peekable<I>,
/// Addresses with a request sent and no reply yet, mapped to the send time.
in_flight: HashMap<V, Instant>,
/// `(source, receive instant)` pairs forwarded by the receive task for this round.
receiver: mpsc::UnboundedReceiver<(V, Instant)>,
/// Sequence number shared by this round's requests and replies.
sequence_number: u16,
}
impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
    /// Polls for the next `(address, round-trip time)` result of this round.
    ///
    /// Replies already forwarded to this round are returned first. Otherwise
    /// as many queued requests as the endpoint accepts are sent. Resolves to
    /// `None` when the reply channel is closed, or when nothing is left to
    /// send and no request is awaiting a reply.
    pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<(V, Duration)>> {
        let reply = self.poll_next_from_different_round(cx);
        if reply.is_ready() {
            return reply;
        }

        self.poll_next_icmp_replies(cx);

        let queue_drained = self.send_queue.peek().is_none();
        let nothing_outstanding = self.in_flight.is_empty();
        if queue_drained && nothing_outstanding {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }

    /// Sends echo requests for queued addresses until the endpoint reports
    /// `Pending` or the queue is empty. Each address is recorded as in flight
    /// with the instant its send completed.
    fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) {
        let identifier = self.pinger.inner.identifier;
        loop {
            let target = match self.send_queue.peek().copied() {
                Some(target) => target,
                None => break,
            };

            let payload = rand::random::<[u8; 64]>();
            let request = EchoRequestPacket::new(identifier, self.sequence_number, &payload);
            if self
                .pinger
                .inner
                .raw
                .poll_send_to(cx, target, &request)
                .is_pending()
            {
                break;
            }

            let sent_at = Instant::now();
            let popped = self.send_queue.next();
            debug_assert!(popped.is_some());
            self.in_flight.insert(target, sent_at);
        }
    }

    /// Drains forwarded replies until one matches an in-flight address and
    /// returns that address with its round-trip time. Replies for addresses
    /// that are not in flight are discarded.
    fn poll_next_from_different_round(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<(V, Duration)>> {
        while let Poll::Ready(message) = self.receiver.poll_recv(cx) {
            match message {
                None => return Poll::Ready(None),
                Some((source, received_at)) => {
                    if let Some(sent_at) = self.in_flight.remove(&source) {
                        return Poll::Ready(Some((source, received_at - sent_at)));
                    }
                }
            }
        }
        Poll::Pending
    }
}
#[cfg(feature = "stream")]
impl<V: IpVersion, I: Iterator<Item = V> + Unpin> Stream for MeasureManyStream<'_, V, I> {
    type Item = (V, Duration);

    /// Delegates to [`MeasureManyStream::poll_next_unpin`]. The stream is
    /// `Unpin`, so it can be unwrapped from its `Pin` safely.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.poll_next_unpin(cx)
    }
}
impl<V: IpVersion, I: Iterator<Item = V>> Drop for MeasureManyStream<'_, V, I> {
    /// Unsubscribes this round from the background receive task.
    ///
    /// Best effort: if the task has already stopped, the send error is
    /// ignored because there is no subscription left to remove.
    fn drop(&mut self) {
        let round_sender = &self.pinger.inner.round_sender;
        let unsubscribe = RoundMessage::Unsubscribe {
            sequence_number: self.sequence_number,
        };
        let _ = round_sender.send(unsubscribe);
    }
}