use crate::{McrxError, Packet, PacketWithMetadata, Subscription};
use std::io;
#[cfg(not(unix))]
use std::time::Duration;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum TokioReceiveError {
#[error("MCRX: tokio readiness failed: {0}")]
Readiness(io::Error),
#[error(transparent)]
Receive(#[from] McrxError),
}
/// Async (Tokio) wrapper around a [`Subscription`].
///
/// On Unix the subscription is registered with Tokio's reactor through
/// `tokio::io::unix::AsyncFd`. On other targets it is held directly and polled
/// on a fixed interval.
#[derive(Debug)]
pub struct TokioSubscription {
    // Unix: reactor registration that owns the subscription.
    #[cfg(unix)]
    inner: tokio::io::unix::AsyncFd<Subscription>,
    // Non-Unix: the subscription itself, polled via its `try_recv*` methods.
    #[cfg(not(unix))]
    inner: Subscription,
    // Non-Unix: how long to sleep after a poll that found nothing queued.
    #[cfg(not(unix))]
    poll_interval: Duration,
}
impl TokioSubscription {
    /// Poll interval used on non-Unix targets unless overridden with
    /// `with_poll_interval`.
    #[cfg(not(unix))]
    const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(10);

    /// Wraps a [`Subscription`] so it can be received from asynchronously.
    ///
    /// On Unix the subscription is registered with Tokio's I/O reactor through
    /// [`tokio::io::unix::AsyncFd`], so receives wake on socket readiness. On
    /// other targets nothing is registered. Receives instead poll the
    /// subscription every 10 ms by default (see `with_poll_interval`).
    ///
    /// # Errors
    ///
    /// On Unix, returns the error from registering with the reactor. The
    /// subscription is dropped in that case. Never fails on other targets.
    ///
    /// # Panics
    ///
    /// On Unix, panics when called outside a Tokio runtime with I/O enabled.
    /// This behavior is inherited from `AsyncFd::new`.
    pub fn new(subscription: Subscription) -> io::Result<Self> {
        #[cfg(unix)]
        {
            Ok(Self {
                inner: tokio::io::unix::AsyncFd::new(subscription)?,
            })
        }
        #[cfg(not(unix))]
        {
            Ok(Self {
                inner: subscription,
                poll_interval: Self::DEFAULT_POLL_INTERVAL,
            })
        }
    }

    /// Borrows the wrapped subscription, e.g. to inspect its state or metrics.
    pub fn subscription(&self) -> &Subscription {
        #[cfg(unix)]
        {
            self.inner.get_ref()
        }
        #[cfg(not(unix))]
        {
            &self.inner
        }
    }

    /// Consumes the adapter and returns the wrapped subscription.
    ///
    /// On Unix this also deregisters the subscription from Tokio's reactor.
    pub fn into_subscription(self) -> Subscription {
        #[cfg(unix)]
        {
            self.inner.into_inner()
        }
        #[cfg(not(unix))]
        {
            self.inner
        }
    }

    /// Sets how long receives sleep between polls that find nothing queued.
    ///
    /// Only available on non-Unix targets, where the default is 10 ms. Shorter
    /// intervals reduce receive latency at the cost of more wake-ups.
    #[cfg(not(unix))]
    #[must_use]
    pub fn with_poll_interval(mut self, poll_interval: Duration) -> Self {
        self.poll_interval = poll_interval;
        self
    }

    /// Waits for and returns the next packet.
    ///
    /// # Errors
    ///
    /// Returns [`TokioReceiveError::Readiness`] if waiting for readability
    /// fails (Unix only). Returns the subscription's receive error, converted
    /// into [`TokioReceiveError`], if a receive attempt fails.
    ///
    /// # Cancel safety
    ///
    /// Every `.await` happens before a receive attempt, and a received packet
    /// is returned without awaiting again. Dropping this future (e.g. in
    /// `tokio::select!` or `tokio::time::timeout`) therefore does not discard
    /// a packet.
    pub async fn recv(&mut self) -> Result<Packet, TokioReceiveError> {
        self.recv_via(|subscription: &mut Subscription| subscription.try_recv())
            .await
    }

    /// Waits for and returns the next packet together with its metadata.
    ///
    /// Errors and cancel safety are the same as for `recv`.
    pub async fn recv_with_metadata(&mut self) -> Result<PacketWithMetadata, TokioReceiveError> {
        self.recv_via(|subscription: &mut Subscription| subscription.try_recv_with_metadata())
            .await
    }

    /// Shared wait-then-try loop behind `recv` and `recv_with_metadata`.
    ///
    /// `try_once` performs one receive attempt that does not wait. The
    /// outcomes are:
    /// - `Ok(Some(_))` ends the loop.
    /// - `Ok(None)` means nothing is queued yet.
    /// - `Err(_)` is converted into a [`TokioReceiveError`] and returned.
    ///
    /// NOTE(review): on Unix this relies on `Subscription::try_recv*`
    /// returning `Ok(None)` instead of blocking when no datagram is queued,
    /// presumably via a non-blocking socket. Confirm in `Subscription`.
    async fn recv_via<T, E, F>(&mut self, mut try_once: F) -> Result<T, TokioReceiveError>
    where
        F: FnMut(&mut Subscription) -> Result<Option<T>, E>,
        TokioReceiveError: From<E>,
    {
        #[cfg(unix)]
        {
            loop {
                let mut guard = self
                    .inner
                    .readable_mut()
                    .await
                    .map_err(TokioReceiveError::Readiness)?;
                match try_once(guard.get_inner_mut())? {
                    Some(item) => return Ok(item),
                    // Nothing queued: drop the cached readiness so the next
                    // `readable_mut` waits for a fresh reactor notification
                    // instead of completing immediately.
                    None => guard.clear_ready(),
                }
            }
        }
        #[cfg(not(unix))]
        {
            loop {
                match try_once(&mut self.inner)? {
                    Some(item) => return Ok(item),
                    None => tokio::time::sleep(self.poll_interval).await,
                }
            }
        }
    }
}
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use super::*;
    use crate::{Context, SubscriptionConfig};
    use std::net::{Ipv4Addr, SocketAddrV4};
    use tokio::time::{Duration, timeout};

    // These tests send real UDP datagrams to the group 239.1.2.3 and expect to
    // receive them on the same host. Each test uses its own destination port,
    // presumably so tests running in parallel don't see each other's packets.

    /// How long a test waits for its own datagram before failing.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    fn sample_config(port: u16) -> SubscriptionConfig {
        SubscriptionConfig::asm(Ipv4Addr::new(239, 1, 2, 3), port)
    }

    fn ipv4_group(config: &SubscriptionConfig) -> Ipv4Addr {
        config.ipv4_membership().unwrap().group
    }

    fn make_multicast_sender() -> std::net::UdpSocket {
        std::net::UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)).unwrap()
    }

    /// Sends `payload` from a fresh ephemeral socket to the group and
    /// destination port described by `config`.
    fn send_to_group(config: &SubscriptionConfig, payload: &[u8]) {
        let destination = SocketAddrV4::new(ipv4_group(config), config.dst_port);
        make_multicast_sender()
            .send_to(payload, destination)
            .unwrap();
    }

    #[tokio::test]
    async fn tokio_subscription_receives_metadata_packet() {
        let mut context = Context::new();
        let config = sample_config(55110);
        let id = context.add_subscription(config.clone()).unwrap();
        context.join_subscription(id).unwrap();
        let subscription = context.take_subscription(id).unwrap();
        let mut subscription = TokioSubscription::new(subscription).unwrap();
        let payload = b"tokio adapter packet";
        send_to_group(&config, payload);
        let packet = timeout(RECV_TIMEOUT, subscription.recv_with_metadata())
            .await
            .expect("timed out waiting for tokio packet")
            .unwrap();
        assert_eq!(packet.packet.subscription_id, id);
        assert_eq!(&packet.packet.payload[..], payload);
    }

    /// Exercises the plain `recv` path, which shares its wait loop with
    /// `recv_with_metadata`.
    #[tokio::test]
    async fn tokio_subscription_receives_plain_packet() {
        let mut context = Context::new();
        let config = sample_config(55112);
        let id = context.add_subscription(config.clone()).unwrap();
        context.join_subscription(id).unwrap();
        let subscription = context.take_subscription(id).unwrap();
        let mut subscription = TokioSubscription::new(subscription).unwrap();
        let payload = b"tokio adapter plain packet";
        send_to_group(&config, payload);
        let packet = timeout(RECV_TIMEOUT, subscription.recv())
            .await
            .expect("timed out waiting for tokio packet")
            .unwrap();
        assert_eq!(packet.subscription_id, id);
        assert_eq!(&packet.payload[..], payload);
    }

    /// The adapter (with metrics enabled) must be movable into a spawned task,
    /// and receiving there must update the subscription's metrics.
    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn tokio_subscription_with_metrics_is_spawn_safe() {
        let mut context = Context::new();
        let config = sample_config(55111);
        let id = context.add_subscription(config.clone()).unwrap();
        context.join_subscription(id).unwrap();
        let subscription = context.take_subscription(id).unwrap();
        let payload = b"tokio metrics packet";
        let handle = tokio::spawn(async move {
            let mut subscription = TokioSubscription::new(subscription).unwrap();
            let packet = timeout(RECV_TIMEOUT, subscription.recv_with_metadata())
                .await
                .expect("timed out waiting for spawned tokio packet")
                .unwrap();
            let metrics = subscription.subscription().metrics_snapshot();
            (packet, metrics)
        });
        // The membership was joined above, so the datagram is queued even if
        // the spawned task has not started waiting yet.
        send_to_group(&config, payload);
        let (packet, metrics) = handle.await.unwrap();
        assert_eq!(packet.packet.subscription_id, id);
        assert_eq!(&packet.packet.payload[..], payload);
        assert_eq!(metrics.packets_received, 1);
        assert_eq!(metrics.bytes_received, payload.len() as u64);
        assert_eq!(metrics.last_payload_len, Some(payload.len()));
    }
}