use std::task::{Context, Poll};
use crate::error::{IrithyllError, Result};
use crate::sample::Sample;
use tokio::sync::mpsc;
/// Cloneable producer handle that pushes [`Sample`]s into a bounded channel.
///
/// Wraps a `tokio::sync::mpsc::Sender<Sample>`; all clones feed the same
/// [`SampleReceiver`]. Any send failure is reported as
/// [`IrithyllError::ChannelClosed`].
#[derive(Debug, Clone)]
pub struct SampleSender {
    /// Underlying tokio sender.
    inner: mpsc::Sender<Sample>,
}

impl SampleSender {
    /// Wraps an existing tokio sender.
    pub(crate) fn new(inner: mpsc::Sender<Sample>) -> Self {
        SampleSender { inner }
    }

    /// Sends a single sample, waiting for buffer space if the channel is full.
    ///
    /// # Errors
    ///
    /// Returns [`IrithyllError::ChannelClosed`] if the receiving half has been
    /// dropped or closed; the sample is discarded in that case.
    pub async fn send(&self, sample: Sample) -> Result<()> {
        match self.inner.send(sample).await {
            Ok(()) => Ok(()),
            // tokio only fails a send when the receiving half is gone; the
            // rejected sample carried inside the error is intentionally dropped.
            Err(_) => Err(IrithyllError::ChannelClosed),
        }
    }

    /// Sends every sample in `samples`, in order, one at a time.
    ///
    /// Each element is sent individually. The batch is therefore not atomic:
    /// samples from other clones of this sender may be interleaved with it.
    ///
    /// # Errors
    ///
    /// Stops at the first failed send and returns
    /// [`IrithyllError::ChannelClosed`]; samples not yet sent are dropped.
    pub async fn send_batch(&self, samples: Vec<Sample>) -> Result<()> {
        for item in samples {
            self.send(item).await?;
        }
        Ok(())
    }

    /// Returns `true` once the receiving half has been dropped or closed,
    /// after which every send fails.
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }
}
/// Consuming end of the sample channel.
///
/// Yields samples in the order the channel accepted them. Once every sender
/// has been dropped and the buffer is drained, receiving yields `None`.
#[derive(Debug)]
pub struct SampleReceiver {
    /// Underlying tokio receiver.
    inner: mpsc::Receiver<Sample>,
}

impl SampleReceiver {
    /// Wraps an existing tokio receiver.
    pub(crate) fn new(inner: mpsc::Receiver<Sample>) -> Self {
        SampleReceiver { inner }
    }

    /// Waits for the next sample.
    ///
    /// Returns `None` once every [`SampleSender`] has been dropped and all
    /// buffered samples have been received.
    pub async fn recv(&mut self) -> Option<Sample> {
        self.inner.recv().await
    }

    /// Poll-based counterpart of [`recv`](Self::recv).
    ///
    /// The result is `Poll::Pending` (after registering the waker in `cx`)
    /// when no sample is available yet, and `Poll::Ready(Some(_))` with the
    /// next sample otherwise. It is `Poll::Ready(None)` once the channel is
    /// closed and drained.
    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<Sample>> {
        self.inner.poll_recv(cx)
    }
}
pub(crate) fn bounded(capacity: usize) -> (SampleSender, SampleReceiver) {
let (tx, rx) = mpsc::channel(capacity);
(SampleSender::new(tx), SampleReceiver::new(rx))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;

    /// Builds a one-feature sample whose target is twice the feature value.
    fn sample(x: f64) -> Sample {
        Sample::new(vec![x], x * 2.0)
    }

    #[tokio::test]
    async fn send_recv_round_trip() {
        let (tx, mut rx) = bounded(16);
        tx.send(sample(1.0)).await.unwrap();
        let s = rx.recv().await.unwrap();
        assert!((s.features[0] - 1.0).abs() < f64::EPSILON);
        assert!((s.target - 2.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn recv_none_on_closed() {
        let (tx, mut rx) = bounded(16);
        drop(tx);
        assert!(rx.recv().await.is_none());
    }

    #[tokio::test]
    async fn send_fails_on_closed_receiver() {
        let (tx, rx) = bounded(16);
        drop(rx);
        let result = tx.send(sample(1.0)).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            IrithyllError::ChannelClosed => {}
            other => panic!("expected ChannelClosed, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn is_closed_reflects_state() {
        let (tx, rx) = bounded(16);
        assert!(!tx.is_closed());
        drop(rx);
        assert!(tx.is_closed());
    }

    #[tokio::test]
    async fn send_batch_preserves_order() {
        let (tx, mut rx) = bounded(16);
        let batch: Vec<Sample> = (0..5).map(|i| sample(i as f64)).collect();
        tx.send_batch(batch).await.unwrap();
        drop(tx);
        let mut received = Vec::new();
        while let Some(s) = rx.recv().await {
            received.push(s);
        }
        assert_eq!(received.len(), 5);
        for (i, s) in received.iter().enumerate() {
            assert!((s.features[0] - i as f64).abs() < f64::EPSILON);
        }
    }

    #[tokio::test]
    async fn send_batch_fails_on_closed_receiver() {
        // Receiver is gone before the first send: the very first element fails.
        let (tx, rx) = bounded(2);
        drop(rx);
        let batch: Vec<Sample> = (0..10).map(|i| sample(i as f64)).collect();
        let result = tx.send_batch(batch).await;
        assert!(matches!(result, Err(IrithyllError::ChannelClosed)));
    }

    #[tokio::test]
    async fn send_batch_fails_on_mid_drop() {
        // Capacity 2 makes the producer block part-way through the batch, so
        // the receiver is genuinely dropped while the batch is in flight.
        let (tx, mut rx) = bounded(2);
        let batch: Vec<Sample> = (0..10).map(|i| sample(i as f64)).collect();
        // Return a bool rather than the `Result` so the spawned task's output
        // does not need `IrithyllError: Send`.
        let producer = tokio::spawn(async move {
            matches!(tx.send_batch(batch).await, Err(IrithyllError::ChannelClosed))
        });
        for expected in 0..3 {
            let s = rx.recv().await.unwrap();
            assert!((s.features[0] - expected as f64).abs() < f64::EPSILON);
        }
        drop(rx);
        // At most 3 received + 2 buffered samples can have been accepted, so a
        // later send in the 10-element batch must observe the closed channel.
        assert!(producer.await.unwrap());
    }

    #[tokio::test]
    async fn cloned_sender_works() {
        let (tx, mut rx) = bounded(16);
        let tx2 = tx.clone();
        tx.send(sample(1.0)).await.unwrap();
        tx2.send(sample(2.0)).await.unwrap();
        let s1 = rx.recv().await.unwrap();
        let s2 = rx.recv().await.unwrap();
        assert!((s1.features[0] - 1.0).abs() < f64::EPSILON);
        assert!((s2.features[0] - 2.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn bounded_backpressure() {
        let (tx, mut rx) = bounded(1);
        tx.send(sample(1.0)).await.unwrap();
        // The single slot is occupied, so a non-blocking send must be refused.
        assert!(matches!(
            tx.inner.try_send(sample(9.0)),
            Err(mpsc::error::TrySendError::Full(_))
        ));
        let tx_clone = tx.clone();
        let handle = tokio::spawn(async move {
            tx_clone.send(sample(2.0)).await.unwrap();
        });
        tokio::task::yield_now().await;
        let s1 = rx.recv().await.unwrap();
        assert!((s1.features[0] - 1.0).abs() < f64::EPSILON);
        // Freeing the slot lets the blocked producer complete.
        handle.await.unwrap();
        let s2 = rx.recv().await.unwrap();
        assert!((s2.features[0] - 2.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn large_burst() {
        let (tx, mut rx) = bounded(1024);
        for i in 0..1000 {
            tx.send(sample(i as f64)).await.unwrap();
        }
        drop(tx);
        let mut count = 0u64;
        while rx.recv().await.is_some() {
            count += 1;
        }
        assert_eq!(count, 1000);
    }

    #[test]
    fn sender_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SampleSender>();
    }
}