irithyll/stream/
channel.rs1use std::task::{Context, Poll};
12
13use crate::error::{IrithyllError, Result};
14use crate::sample::Sample;
15use tokio::sync::mpsc;
16
17#[derive(Clone, Debug)]
33pub struct SampleSender {
34 inner: mpsc::Sender<Sample>,
35}
36
37impl SampleSender {
38 pub(crate) fn new(inner: mpsc::Sender<Sample>) -> Self {
40 Self { inner }
41 }
42
43 pub async fn send(&self, sample: Sample) -> Result<()> {
48 self.inner
49 .send(sample)
50 .await
51 .map_err(|_| IrithyllError::ChannelClosed)
52 }
53
54 pub async fn send_batch(&self, samples: Vec<Sample>) -> Result<()> {
60 for sample in samples {
61 self.send(sample).await?;
62 }
63 Ok(())
64 }
65
66 pub fn is_closed(&self) -> bool {
68 self.inner.is_closed()
69 }
70}
71
72#[derive(Debug)]
78pub struct SampleReceiver {
79 inner: mpsc::Receiver<Sample>,
80}
81
82impl SampleReceiver {
83 pub(crate) fn new(inner: mpsc::Receiver<Sample>) -> Self {
85 Self { inner }
86 }
87
88 pub async fn recv(&mut self) -> Option<Sample> {
93 self.inner.recv().await
94 }
95
96 pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<Sample>> {
102 self.inner.poll_recv(cx)
103 }
104}
105
106pub(crate) fn bounded(capacity: usize) -> (SampleSender, SampleReceiver) {
111 let (tx, rx) = mpsc::channel(capacity);
112 (SampleSender::new(tx), SampleReceiver::new(rx))
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Builds a one-feature sample whose target is twice the feature value.
    fn sample(x: f64) -> Sample {
        Sample::new(vec![x], x * 2.0)
    }

    /// A sent sample arrives intact on the receiving side.
    #[tokio::test]
    async fn send_recv_round_trip() {
        let (tx, mut rx) = bounded(16);
        tx.send(sample(1.0)).await.unwrap();
        let s = rx.recv().await.unwrap();
        assert!((s.features[0] - 1.0).abs() < f64::EPSILON);
        assert!((s.target - 2.0).abs() < f64::EPSILON);
    }

    /// Dropping every sender ends the stream with `None`.
    #[tokio::test]
    async fn recv_none_on_closed() {
        let (tx, mut rx) = bounded(16);
        drop(tx);
        assert!(rx.recv().await.is_none());
    }

    /// Sending after the receiver is gone yields `ChannelClosed`.
    #[tokio::test]
    async fn send_fails_on_closed_receiver() {
        let (tx, rx) = bounded(16);
        drop(rx);
        let result = tx.send(sample(1.0)).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            IrithyllError::ChannelClosed => {}
            other => panic!("expected ChannelClosed, got {:?}", other),
        }
    }

    /// `is_closed` flips to `true` once the receiver is dropped.
    #[tokio::test]
    async fn is_closed_reflects_state() {
        let (tx, rx) = bounded(16);
        assert!(!tx.is_closed());
        drop(rx);
        assert!(tx.is_closed());
    }

    /// Samples in a batch are received in the order they were supplied.
    #[tokio::test]
    async fn send_batch_preserves_order() {
        let (tx, mut rx) = bounded(16);
        let batch: Vec<Sample> = (0..5).map(|i| sample(i as f64)).collect();
        tx.send_batch(batch).await.unwrap();
        drop(tx);

        let mut received = Vec::new();
        while let Some(s) = rx.recv().await {
            received.push(s);
        }
        assert_eq!(received.len(), 5);
        for (i, s) in received.iter().enumerate() {
            assert!((s.features[0] - i as f64).abs() < f64::EPSILON);
        }
    }

    /// A batch sent after the receiver is already gone fails with `ChannelClosed`.
    #[tokio::test]
    async fn send_batch_fails_on_closed_receiver() {
        let (tx, rx) = bounded(2);
        drop(rx);
        let batch: Vec<Sample> = (0..10).map(|i| sample(i as f64)).collect();
        let result = tx.send_batch(batch).await;
        assert!(
            matches!(result, Err(IrithyllError::ChannelClosed)),
            "expected Err(ChannelClosed), got {:?}",
            result
        );
    }

    /// Dropping the receiver while a batch is still in flight aborts it with
    /// `ChannelClosed`.
    #[tokio::test]
    async fn send_batch_fails_on_mid_drop() {
        let (tx, mut rx) = bounded(2);
        let batch: Vec<Sample> = (0..10).map(|i| sample(i as f64)).collect();
        // Reduce the outcome to a bool inside the task so the error type does
        // not have to cross the `JoinHandle`.
        let handle = tokio::spawn(async move {
            matches!(
                tx.send_batch(batch).await,
                Err(IrithyllError::ChannelClosed)
            )
        });

        // Accept the first sample so the batch is demonstrably in progress,
        // then hang up. With capacity 2, the remaining samples cannot all fit.
        let first = rx.recv().await.unwrap();
        assert!(first.features[0].abs() < f64::EPSILON);
        drop(rx);

        assert!(
            handle.await.unwrap(),
            "expected ChannelClosed from an interrupted batch"
        );
    }

    /// Clones of a sender feed the same receiver.
    #[tokio::test]
    async fn cloned_sender_works() {
        let (tx, mut rx) = bounded(16);
        let tx2 = tx.clone();

        tx.send(sample(1.0)).await.unwrap();
        tx2.send(sample(2.0)).await.unwrap();

        let s1 = rx.recv().await.unwrap();
        let s2 = rx.recv().await.unwrap();
        assert!((s1.features[0] - 1.0).abs() < f64::EPSILON);
        assert!((s2.features[0] - 2.0).abs() < f64::EPSILON);
    }

    /// A send into a full buffer stays pending until the receiver frees a slot.
    #[tokio::test]
    async fn bounded_backpressure() {
        let (tx, mut rx) = bounded(1);

        // Occupy the only slot in the buffer.
        tx.send(sample(1.0)).await.unwrap();

        let delivered = Arc::new(AtomicBool::new(false));
        let delivered_flag = Arc::clone(&delivered);
        let tx_clone = tx.clone();
        let handle = tokio::spawn(async move {
            tx_clone.send(sample(2.0)).await.unwrap();
            delivered_flag.store(true, Ordering::SeqCst);
        });

        // `#[tokio::test]` runs on a single-threaded runtime, so the spawned
        // task only makes progress while this task is yielded. Whether or not
        // it got to run, its send cannot complete while the buffer is full.
        tokio::task::yield_now().await;
        assert!(
            !delivered.load(Ordering::SeqCst),
            "second send completed while the buffer was full"
        );

        // Draining one sample frees the slot and unblocks the pending send.
        let s1 = rx.recv().await.unwrap();
        assert!((s1.features[0] - 1.0).abs() < f64::EPSILON);

        handle.await.unwrap();
        assert!(delivered.load(Ordering::SeqCst));

        let s2 = rx.recv().await.unwrap();
        assert!((s2.features[0] - 2.0).abs() < f64::EPSILON);
    }

    /// A burst below capacity is delivered in full.
    #[tokio::test]
    async fn large_burst() {
        let (tx, mut rx) = bounded(1024);
        for i in 0..1000 {
            tx.send(sample(i as f64)).await.unwrap();
        }
        drop(tx);

        let mut count = 0u64;
        while rx.recv().await.is_some() {
            count += 1;
        }
        assert_eq!(count, 1000);
    }

    /// Senders can be shared across tasks and threads.
    #[test]
    fn sender_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SampleSender>();
    }

    /// The receiver can be moved into another task or thread.
    #[test]
    fn receiver_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<SampleReceiver>();
    }
}