Skip to main content

irithyll/stream/
channel.rs

1//! Bounded channel adapters for backpressure-aware sample ingestion.
2//!
3//! Provides [`SampleSender`] and [`SampleReceiver`], thin wrappers around
4//! tokio's bounded mpsc channel that surface irithyll-typed errors and
5//! integrate cleanly with the async training loop in [`super::AsyncSGBT`].
6//!
7//! The bounded channel enforces backpressure: if the training loop falls
8//! behind, senders will await until capacity is available, preventing
9//! unbounded memory growth from fast data sources.
10
11use std::task::{Context, Poll};
12
13use crate::error::{IrithyllError, Result};
14use crate::sample::Sample;
15use tokio::sync::mpsc;
16
17/// A clonable sender handle for streaming [`Sample`]s into the training loop.
18///
19/// Wraps a [`tokio::sync::mpsc::Sender<Sample>`] and maps channel errors to
20/// [`IrithyllError::ChannelClosed`]. Clone this freely to feed samples from
21/// multiple async tasks.
22///
23/// # Example
24///
25/// ```no_run
26/// # use irithyll::sample::Sample;
27/// # async fn example(sender: irithyll::stream::SampleSender) -> irithyll::error::Result<()> {
28/// sender.send(Sample::new(vec![1.0, 2.0], 3.0)).await?;
29/// # Ok(())
30/// # }
31/// ```
32#[derive(Clone, Debug)]
33pub struct SampleSender {
34    inner: mpsc::Sender<Sample>,
35}
36
37impl SampleSender {
38    /// Create a new sender from a tokio mpsc sender.
39    pub(crate) fn new(inner: mpsc::Sender<Sample>) -> Self {
40        Self { inner }
41    }
42
43    /// Send a single sample into the training channel.
44    ///
45    /// Awaits if the channel is full (backpressure). Returns
46    /// [`IrithyllError::ChannelClosed`] if the receiver has been dropped.
47    pub async fn send(&self, sample: Sample) -> Result<()> {
48        self.inner
49            .send(sample)
50            .await
51            .map_err(|_| IrithyllError::ChannelClosed)
52    }
53
54    /// Send a batch of samples sequentially into the training channel.
55    ///
56    /// Each sample is sent in order; backpressure applies per-sample.
57    /// Returns [`IrithyllError::ChannelClosed`] if the receiver is dropped
58    /// before all samples are sent.
59    pub async fn send_batch(&self, samples: Vec<Sample>) -> Result<()> {
60        for sample in samples {
61            self.send(sample).await?;
62        }
63        Ok(())
64    }
65
66    /// Returns `true` if the receiver has been dropped.
67    pub fn is_closed(&self) -> bool {
68        self.inner.is_closed()
69    }
70}
71
72/// The receiving end of the sample channel, consumed by the training loop.
73///
74/// Wraps a [`tokio::sync::mpsc::Receiver<Sample>`]. Not clonable — only
75/// one consumer (the [`AsyncSGBT`](super::AsyncSGBT) training loop) should
76/// own this.
77#[derive(Debug)]
78pub struct SampleReceiver {
79    inner: mpsc::Receiver<Sample>,
80}
81
82impl SampleReceiver {
83    /// Create a new receiver from a tokio mpsc receiver.
84    pub(crate) fn new(inner: mpsc::Receiver<Sample>) -> Self {
85        Self { inner }
86    }
87
88    /// Receive the next sample from the channel.
89    ///
90    /// Returns `None` when all senders have been dropped and the channel
91    /// is drained — this is the clean shutdown signal for the training loop.
92    pub async fn recv(&mut self) -> Option<Sample> {
93        self.inner.recv().await
94    }
95
96    /// Poll for the next sample without creating an intermediate future.
97    ///
98    /// This is the non-async counterpart of [`recv`](Self::recv), designed
99    /// for use inside manual [`Stream`](futures_core::Stream) implementations
100    /// like [`PredictionStream`](super::adapters::PredictionStream).
101    pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<Sample>> {
102        self.inner.poll_recv(cx)
103    }
104}
105
106/// Create a bounded sample channel with the given capacity.
107///
108/// Returns a `(SampleSender, SampleReceiver)` pair. The sender is clonable;
109/// the receiver is not.
110pub(crate) fn bounded(capacity: usize) -> (SampleSender, SampleReceiver) {
111    let (tx, rx) = mpsc::channel(capacity);
112    (SampleSender::new(tx), SampleReceiver::new(rx))
113}
114
115// ---------------------------------------------------------------------------
116// Tests
117// ---------------------------------------------------------------------------
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;

    /// Build a one-feature sample whose target is twice the feature value.
    fn sample(x: f64) -> Sample {
        Sample::new(vec![x], x * 2.0)
    }

    // 1. Basic send/recv round-trip.
    #[tokio::test]
    async fn send_recv_round_trip() {
        let (tx, mut rx) = bounded(16);
        tx.send(sample(1.0)).await.unwrap();
        let s = rx.recv().await.unwrap();
        assert!((s.features[0] - 1.0).abs() < f64::EPSILON);
        assert!((s.target - 2.0).abs() < f64::EPSILON);
    }

    // 2. Recv returns None after all senders are dropped.
    #[tokio::test]
    async fn recv_none_on_closed() {
        let (tx, mut rx) = bounded(16);
        drop(tx);
        assert!(rx.recv().await.is_none());
    }

    // 3. Send fails with ChannelClosed after the receiver is dropped.
    #[tokio::test]
    async fn send_fails_on_closed_receiver() {
        let (tx, rx) = bounded(16);
        drop(rx);
        match tx.send(sample(1.0)).await {
            Err(IrithyllError::ChannelClosed) => {}
            other => panic!("expected ChannelClosed, got {:?}", other),
        }
    }

    // 4. is_closed reflects receiver state.
    #[tokio::test]
    async fn is_closed_reflects_state() {
        let (tx, rx) = bounded(16);
        assert!(!tx.is_closed());
        drop(rx);
        assert!(tx.is_closed());
    }

    // 5. send_batch sends all samples in order.
    #[tokio::test]
    async fn send_batch_preserves_order() {
        let (tx, mut rx) = bounded(16);
        let batch: Vec<Sample> = (0..5).map(|i| sample(i as f64)).collect();
        tx.send_batch(batch).await.unwrap();
        drop(tx);

        let mut received = Vec::new();
        while let Some(s) = rx.recv().await {
            received.push(s);
        }
        assert_eq!(received.len(), 5);
        for (i, s) in received.iter().enumerate() {
            assert!((s.features[0] - i as f64).abs() < f64::EPSILON);
        }
    }

    // 6. send_batch fails partway if the receiver drops mid-batch.
    #[tokio::test]
    async fn send_batch_fails_on_mid_drop() {
        let (tx, mut rx) = bounded(2);
        let batch: Vec<Sample> = (0..10).map(|i| sample(i as f64)).collect();

        // The batch (10) is much larger than the capacity (2), so the
        // producer has to wait on backpressure before it can finish.
        let produce = tx.send_batch(batch);

        // Receive exactly one sample, proving the batch is in flight, then
        // drop the receiver while most of the batch is still unsent.
        let consume = async move {
            let first = rx.recv().await;
            drop(rx);
            first
        };

        // Drive both halves concurrently on this task; no `Send` bound is
        // needed on the error type because nothing is spawned.
        let (result, first) = tokio::join!(produce, consume);

        let first = first.expect("the first sample of the batch should be delivered");
        assert!((first.features[0] - 0.0).abs() < f64::EPSILON);
        assert!(
            matches!(result, Err(IrithyllError::ChannelClosed)),
            "expected ChannelClosed, got {:?}",
            result
        );
    }

    // 7. Cloned sender works independently.
    #[tokio::test]
    async fn cloned_sender_works() {
        let (tx, mut rx) = bounded(16);
        let tx2 = tx.clone();

        tx.send(sample(1.0)).await.unwrap();
        tx2.send(sample(2.0)).await.unwrap();

        let s1 = rx.recv().await.unwrap();
        let s2 = rx.recv().await.unwrap();
        assert!((s1.features[0] - 1.0).abs() < f64::EPSILON);
        assert!((s2.features[0] - 2.0).abs() < f64::EPSILON);
    }

    // 8. Bounded channel applies backpressure (capacity=1).
    #[tokio::test]
    async fn bounded_backpressure() {
        let (tx, mut rx) = bounded(1);

        // Fill the single slot.
        tx.send(sample(1.0)).await.unwrap();

        // Spawn a task that sends another — it should block until we recv.
        let tx_clone = tx.clone();
        let handle = tokio::spawn(async move {
            tx_clone.send(sample(2.0)).await.unwrap();
        });

        // Small yield to let the spawned task attempt the send.
        tokio::task::yield_now().await;

        // Nothing has been received yet, so the channel is still full and
        // the second send cannot have completed.
        assert!(!handle.is_finished());

        // Drain the first, unblocking the sender.
        let s1 = rx.recv().await.unwrap();
        assert!((s1.features[0] - 1.0).abs() < f64::EPSILON);

        handle.await.unwrap();

        let s2 = rx.recv().await.unwrap();
        assert!((s2.features[0] - 2.0).abs() < f64::EPSILON);
    }

    // 9. Channel with large capacity handles burst.
    #[tokio::test]
    async fn large_burst() {
        let (tx, mut rx) = bounded(1024);
        for i in 0..1000 {
            tx.send(sample(i as f64)).await.unwrap();
        }
        drop(tx);

        let mut count = 0u64;
        while rx.recv().await.is_some() {
            count += 1;
        }
        assert_eq!(count, 1000);
    }

    // 10. SampleSender is Send + Sync (compile-time check).
    #[test]
    fn sender_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SampleSender>();
    }
}