#[cfg(feature = "stream")]
use std::pin::Pin;
use std::{
collections::HashMap,
future::poll_fn,
io,
iter::Peekable,
net::{Ipv4Addr, Ipv6Addr},
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
task::{Context, Poll},
time::Duration,
};
use bytes::{Bytes, BytesMut};
#[cfg(feature = "stream")]
use futures_core::Stream;
use tokio::{
sync::mpsc::{self, error::TryRecvError},
time::Instant,
};
use crate::{IpVersion, packet::EchoRequestPacket, raw_pinger::RawPinger};
/// A [`Pinger`] whose address type is [`Ipv4Addr`].
pub type V4Pinger = Pinger<Ipv4Addr>;
/// A [`Pinger`] whose address type is [`Ipv6Addr`].
pub type V6Pinger = Pinger<Ipv6Addr>;
/// Handle used to start ping measurement rounds for addresses of type `V`.
///
/// The handle holds only an [`Arc`] and a channel sender, so a clone refers to
/// the same shared state and talks to the same background task.
pub struct Pinger<V: IpVersion> {
/// State shared between every clone of this handle and the background task spawned by `new`.
inner: Arc<InnerPinger<V>>,
/// Channel over which subscribe/unsubscribe requests reach the background task.
round_sender: mpsc::UnboundedSender<RoundMessage<V>>,
}
impl<V: IpVersion> Clone for Pinger<V> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
round_sender: self.round_sender.clone(),
}
}
}
/// State shared (behind an `Arc`) by all `Pinger` clones and the background receive task.
struct InnerPinger<V: IpVersion> {
/// Raw pinger used both to send echo requests and to receive echo replies.
raw: RawPinger<V>,
/// Counter from which `measure_many` draws a fresh round identifier via `fetch_add`.
next_round_id: AtomicU64,
}
/// Request sent over the subscription channel to the background receive task.
enum RoundMessage<V: IpVersion> {
/// Registers a round so that matching replies are forwarded to `sender`.
Subscribe {
/// Full round identifier; the background task keys the entry by `round_id as u16`.
round_id: u64,
/// Payload a reply must carry in order to be forwarded to this round.
expected_payload: Bytes,
/// Channel that receives the source address and receive time of each matching reply.
sender: mpsc::UnboundedSender<(V, Instant)>,
},
/// Removes a round, provided the stored entry still carries this exact `round_id`.
Unsubscribe {
/// Identifier of the round to remove.
round_id: u64,
},
}
/// Entry kept by the background task for one active round.
struct Subscriber<V: IpVersion> {
/// Full round identifier, used to tell apart rounds whose `u16` truncations collide.
round_id: u64,
/// Payload a reply must carry to be attributed to this round.
expected_payload: Bytes,
/// Channel on which `(source, receive time)` pairs are delivered to the round's stream.
sender: mpsc::UnboundedSender<(V, Instant)>,
}
/// Single event produced by one poll of the background task's event sources.
enum PollResult<V: IpVersion> {
/// A subscribe/unsubscribe request arrived on the subscription channel.
Subscription(RoundMessage<V>),
/// An echo reply was received from the raw pinger.
Packet(crate::packet::EchoReplyPacket<V>),
}
impl<V: IpVersion> Pinger<V> {
    /// Creates a pinger and spawns the background task that receives echo replies.
    ///
    /// The task owns the table of active rounds. Every loop iteration handles exactly
    /// one event: a subscribe/unsubscribe request, or one received packet. The task
    /// exits once every sender of the subscription channel has been dropped.
    ///
    /// # Errors
    ///
    /// Returns the error reported by [`RawPinger::new`].
    ///
    /// # Panics
    ///
    /// Calls [`tokio::spawn`], which panics when invoked outside of a Tokio runtime.
    pub fn new() -> io::Result<Self> {
        let raw = RawPinger::new()?;
        let (sender, mut receiver) = mpsc::unbounded_channel();
        let inner = Arc::new(InnerPinger {
            raw,
            next_round_id: AtomicU64::new(0),
        });
        let inner_recv = Arc::clone(&inner);
        tokio::spawn(async move {
            // Active rounds, keyed by the round id truncated to `u16`; that truncated
            // value is what gets compared with a reply's sequence number below.
            let mut subscribers: HashMap<u16, Subscriber<V>> = HashMap::new();
            let mut recv_buf = BytesMut::new();
            loop {
                let event = poll_fn(|cx| {
                    // Drain pending subscription changes first (without registering a
                    // waker) so a round is known before its replies are examined.
                    match receiver.try_recv() {
                        Ok(msg) => return Poll::Ready(Some(PollResult::Subscription(msg))),
                        Err(TryRecvError::Empty) => {}
                        Err(TryRecvError::Disconnected) => return Poll::Ready(None),
                    }
                    match inner_recv.raw.poll_recv(&mut recv_buf, cx) {
                        Poll::Ready(Ok(packet)) => {
                            return Poll::Ready(Some(PollResult::Packet(packet)));
                        }
                        // The receive error is discarded. The poll completed, so ask
                        // explicitly to be polled again in order to retry the receive.
                        Poll::Ready(Err(_)) => cx.waker().wake_by_ref(),
                        Poll::Pending => {}
                    }
                    // Nothing ready yet: poll the channel so its waker is registered.
                    match receiver.poll_recv(cx) {
                        Poll::Ready(Some(msg)) => {
                            Poll::Ready(Some(PollResult::Subscription(msg)))
                        }
                        Poll::Ready(None) => Poll::Ready(None),
                        Poll::Pending => Poll::Pending,
                    }
                })
                .await;

                // `None` means the subscription channel is closed: stop the task.
                let Some(event) = event else {
                    break;
                };

                match event {
                    PollResult::Subscription(RoundMessage::Subscribe {
                        round_id,
                        expected_payload,
                        sender,
                    }) => {
                        // Intentional truncation: the map key is the 16-bit value.
                        // A newer round replaces an older one with the same key.
                        subscribers.insert(
                            round_id as u16,
                            Subscriber {
                                round_id,
                                expected_payload,
                                sender,
                            },
                        );
                    }
                    PollResult::Subscription(RoundMessage::Unsubscribe { round_id }) => {
                        // Remove only if the slot still belongs to this exact round;
                        // a later round may have reused the same truncated key.
                        if let std::collections::hash_map::Entry::Occupied(slot) =
                            subscribers.entry(round_id as u16)
                        {
                            if slot.get().round_id == round_id {
                                slot.remove();
                            }
                        }
                    }
                    PollResult::Packet(packet) => {
                        let recv_instant = Instant::now();
                        let packet_source = packet.source();
                        let packet_sequence_number = packet.sequence_number();
                        if let Some(subscriber) = subscribers.get(&packet_sequence_number) {
                            let payload_matches =
                                packet.payload() == &subscriber.expected_payload[..];
                            // A failed send means the round's receiver is gone, so
                            // the stale subscription is dropped.
                            if payload_matches
                                && subscriber
                                    .sender
                                    .send((packet_source, recv_instant))
                                    .is_err()
                            {
                                subscribers.remove(&packet_sequence_number);
                            }
                        }
                    }
                }
            }
        });
        Ok(Self {
            inner,
            round_sender: sender,
        })
    }

    /// Starts a measurement round over `addresses` and returns the stream driving it.
    ///
    /// A fresh round id is drawn from the shared counter, a request packet carrying
    /// 64 random payload bytes is built, and the round is registered with the
    /// background task. No packet is sent here; sending happens when the returned
    /// stream is polled.
    ///
    /// # Panics
    ///
    /// Panics with `"Receiver closed"` if the background task is no longer receiving
    /// subscription messages.
    pub fn measure_many<I>(&self, addresses: I) -> MeasureManyStream<'_, V, I>
    where
        I: Iterator<Item = V>,
    {
        // Pre-size the in-flight table from the iterator's lower size bound.
        let (size_hint, _) = addresses.size_hint();
        let send_queue = addresses.peekable();
        let (sender, receiver) = mpsc::unbounded_channel();
        let round_id = self.inner.next_round_id.fetch_add(1, Ordering::Relaxed);
        let payload = rand::random::<[u8; 64]>();
        // Intentional truncation to the same 16-bit value the background task uses
        // as its lookup key (presumably the ICMP sequence number — confirm against
        // `EchoRequestPacket::new`).
        let packet = EchoRequestPacket::new(0, round_id as u16, &payload);
        let subscribe = RoundMessage::Subscribe {
            round_id,
            expected_payload: packet.payload(),
            sender,
        };
        if self.round_sender.send(subscribe).is_err() {
            panic!("Receiver closed");
        }
        MeasureManyStream {
            pinger: self,
            packet,
            send_queue,
            in_flight: HashMap::with_capacity(size_hint),
            receiver,
            round_id,
        }
    }
}
/// One measurement round: sends an echo request to each queued address and yields
/// `(address, round-trip time)` pairs as replies are routed back to it.
///
/// Dropping the value unsubscribes the round from the background task.
#[must_use = "streams do nothing unless polled"]
pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
/// Pinger that created this round; used to send packets and to unsubscribe on drop.
pinger: &'a Pinger<V>,
/// The echo request sent, unchanged, to every address of the round.
packet: EchoRequestPacket<V>,
/// Addresses that have not been sent to yet.
send_queue: Peekable<I>,
/// Addresses awaiting a reply, mapped to the instant recorded right after sending.
in_flight: HashMap<V, Instant>,
/// Receives `(source, receive time)` pairs forwarded by the background task.
receiver: mpsc::UnboundedReceiver<(V, Instant)>,
/// Identifier of this round, sent back in the unsubscribe message on drop.
round_id: u64,
}
impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
    /// Polls for the next `(address, round-trip time)` measurement.
    ///
    /// Order of operations:
    /// 1. deliver a reply already routed to this round (or end the stream if the
    ///    reply channel is closed);
    /// 2. otherwise send as many queued echo requests as the socket accepts;
    /// 3. finish with `None` once nothing is queued and nothing is in flight.
    ///
    /// Addresses that never reply stay in flight, so the stream stays pending for
    /// them indefinitely.
    pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<(V, Duration)>> {
        if let Poll::Ready(maybe_reply) = self.poll_next_from_different_round(cx) {
            return Poll::Ready(maybe_reply);
        }
        self.poll_next_icmp_replies(cx);
        if self.send_queue.peek().is_none() && self.in_flight.is_empty() {
            return Poll::Ready(None);
        }
        Poll::Pending
    }

    /// Sends echo requests for queued addresses until the queue is empty or the
    /// socket reports `Pending`.
    ///
    /// Despite its name, this method only *sends*; replies are consumed in
    /// `poll_next_from_different_round`.
    fn poll_next_icmp_replies(&mut self, cx: &mut Context<'_>) {
        while let Some(&addr) = self.send_queue.peek() {
            // An address that is already awaiting a reply is not pinged twice.
            if self.in_flight.contains_key(&addr) {
                self.send_queue.next();
                continue;
            }
            match self.pinger.inner.raw.poll_send_to(cx, addr, &self.packet) {
                Poll::Ready(result) => {
                    let sent_at = Instant::now();
                    let taken_addr = self.send_queue.next();
                    debug_assert!(taken_addr.is_some());
                    // A failed send is dropped from the queue without being tracked,
                    // so it can never keep the stream waiting for a reply.
                    if result.is_ok() {
                        self.in_flight.insert(addr, sent_at);
                    }
                }
                Poll::Pending => {
                    // NOTE(review): this requests an immediate re-poll. If
                    // `poll_send_to` already registers the waker, this is a busy
                    // retry while the socket is not writable — confirm it is needed.
                    cx.waker().wake_by_ref();
                    break;
                }
            }
        }
    }

    /// Receives replies forwarded by the background task and turns the first one
    /// that matches an in-flight address into a measurement.
    ///
    /// Replies from addresses that are not in flight (for example duplicates) are
    /// skipped. Returns `Ready(None)` when the reply channel is closed.
    fn poll_next_from_different_round(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<(V, Duration)>> {
        loop {
            match self.receiver.poll_recv(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some((addr, recv_instant))) => {
                    if let Some(send_instant) = self.in_flight.remove(&addr) {
                        // The receive time is taken in the background task and the
                        // send time only after the send call returned, so the
                        // receive time may be the earlier one. Saturate to zero
                        // explicitly instead of relying on `Instant` subtraction.
                        let rtt = recv_instant.saturating_duration_since(send_instant);
                        return Poll::Ready(Some((addr, rtt)));
                    }
                }
            }
        }
    }
}
#[cfg(feature = "stream")]
impl<V: IpVersion, I: Iterator<Item = V> + Unpin> Stream for MeasureManyStream<'_, V, I> {
    type Item = (V, Duration);

    /// Forwards to the inherent `poll_next_unpin`; the stream is `Unpin`, so the pin
    /// can simply be unwrapped.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.poll_next_unpin(cx)
    }
}
impl<V: IpVersion, I: Iterator<Item = V>> Drop for MeasureManyStream<'_, V, I> {
    /// Asks the background task to stop routing replies to this round.
    fn drop(&mut self) {
        let unsubscribe = RoundMessage::Unsubscribe {
            round_id: self.round_id,
        };
        // Best effort: a send error only means the background task is already gone.
        let _ = self.pinger.round_sender.send(unsubscribe);
    }
}