1use std::{
8 cell::RefCell,
9 collections::VecDeque,
10 future::poll_fn,
11 pin::Pin,
12 rc::Rc,
13 task::{Context, Poll, Waker},
14};
15
16use derive_more::{Display, Error};
17use futures::Stream;
18
19struct Inner<T> {
20 queue: VecDeque<T>,
21 rx_alive: bool,
23 rx_waker: Option<Waker>,
25 tx_wakers: Vec<Waker>,
27}
28
29pub struct BoundedSender<T> {
31 inner: Rc<RefCell<Inner<T>>>,
32}
33
34impl<T> Clone for BoundedSender<T> {
35 fn clone(&self) -> Self {
36 Self {
37 inner: self.inner.clone(),
38 }
39 }
40}
41
42pub struct Receiver<T> {
44 inner: Rc<RefCell<Inner<T>>>,
45}
46
47pub fn bounded<T>(cap: usize) -> (BoundedSender<T>, Receiver<T>) {
51 assert_ne!(cap, 0, "a bounded channel with capacity 0 does not work");
52 let mut queue = VecDeque::new();
53 queue.reserve_exact(cap);
54 let inner = Rc::new(RefCell::new(Inner {
55 queue,
56 rx_alive: true,
57 rx_waker: None,
58 tx_wakers: Vec::new(),
59 }));
60 let tx = BoundedSender {
61 inner: inner.clone(),
62 };
63 let rx = Receiver { inner };
64 (tx, rx)
65}
66
67#[derive(Debug, Display, Error)]
69#[display("receiver has been dropped")]
70pub struct SendError<T>(pub T);
71
72#[derive(Debug, Display, Error)]
74pub enum TrySendError<T> {
75 #[display("channel is full")]
77 Full(T),
78 #[display("receiver has been dropped")]
80 Closed(T),
81}
82
83impl<T> BoundedSender<T> {
84 pub async fn send(&mut self, val: T) -> Result<(), SendError<T>> {
88 let mut val = Some(val);
89 poll_fn(|cx| {
90 if let Some(waker) = self.inner.borrow_mut().rx_waker.take() {
91 waker.wake();
92 }
93 match self.try_send(val.take().unwrap()) {
94 Ok(()) => Poll::Ready(Ok(())),
95 Err(TrySendError::Full(v)) => {
96 val = Some(v);
97 self.inner.borrow_mut().tx_wakers.push(cx.waker().clone());
98 Poll::Pending
99 }
100 Err(TrySendError::Closed(v)) => Poll::Ready(Err(SendError(v))),
101 }
102 })
103 .await
104 }
105
106 pub fn try_send(&mut self, val: T) -> Result<(), TrySendError<T>> {
111 let mut inner = self.inner.borrow_mut();
112 if !inner.rx_alive {
113 Err(TrySendError::Closed(val))
114 } else if inner.queue.len() == inner.queue.capacity() {
115 Err(TrySendError::Full(val))
116 } else {
117 inner.queue.push_back(val);
118 Ok(())
119 }
120 }
121}
122
123#[derive(Debug, Display, Error, PartialEq, Eq)]
125#[display("all senders have been dropped")]
126pub struct RecvError;
127
128#[derive(Debug, Display, Error, PartialEq, Eq)]
130pub enum TryRecvError {
131 #[display("channel is empty")]
133 Empty,
134 #[display("all senders have been dropped")]
136 Closed,
137}
138
139impl<T> Receiver<T> {
140 pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
141 match self.try_recv() {
142 Ok(val) => {
143 for waker in self.inner.borrow_mut().tx_wakers.drain(..) {
144 waker.wake();
145 }
146 Poll::Ready(Ok(val))
147 }
148 Err(TryRecvError::Empty) => {
149 self.inner.borrow_mut().rx_waker = Some(cx.waker().clone());
150 Poll::Pending
151 }
152 Err(TryRecvError::Closed) => Poll::Ready(Err(RecvError)),
153 }
154 }
155
156 pub async fn recv(&mut self) -> Result<T, RecvError> {
160 poll_fn(|cx| self.poll_recv(cx)).await
161 }
162
163 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
168 if let Some(val) = self.inner.borrow_mut().queue.pop_front() {
169 return Ok(val);
170 }
171 if Rc::strong_count(&self.inner) == 1 {
172 Err(TryRecvError::Closed)
173 } else {
174 Err(TryRecvError::Empty)
175 }
176 }
177}
178
179impl<T> Stream for Receiver<T> {
180 type Item = T;
181
182 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183 self.poll_recv(cx).map(|r| r.ok())
184 }
185
186 fn size_hint(&self) -> (usize, Option<usize>) {
187 (self.inner.borrow().queue.len(), None)
188 }
189}
190
191impl<T> Drop for BoundedSender<T> {
192 fn drop(&mut self) {
193 if Rc::strong_count(&self.inner) == 2 {
196 if let Some(waker) = self.inner.borrow_mut().rx_waker.take() {
197 waker.wake();
198 }
199 }
200 }
201}
202
203impl<T> Drop for Receiver<T> {
204 fn drop(&mut self) {
205 self.inner.borrow_mut().rx_alive = false;
206 }
207}
208
#[cfg(test)]
mod tests {
    use std::{
        pin::{Pin, pin},
        task::Poll,
    };

    use futures::StreamExt;
    use tempest_io::VirtualIo;

    use crate::{block_on, spawn};

    use super::*;

    /// Returns a context whose waker does nothing, for polling futures by hand.
    ///
    /// `Waker::noop()` already returns a `&'static Waker`, so the context can
    /// be `'static` too.
    fn noop_cx() -> std::task::Context<'static> {
        std::task::Context::from_waker(std::task::Waker::noop())
    }

    #[test]
    #[should_panic(expected = "capacity 0")]
    fn test_bounded_zero_capacity_panics() {
        let _ = bounded::<i32>(0);
    }

    #[test]
    fn test_try_send_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            assert!(tx.try_send(42).is_ok());
        });
    }

    #[test]
    fn test_try_send_exactly_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(2);
            assert!(tx.try_send(1).is_ok());
            assert!(tx.try_send(2).is_ok());
        });
    }

    #[test]
    fn test_try_send_over_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.try_send(1).unwrap();
            match tx.try_send(99) {
                Err(TrySendError::Full(v)) => assert_eq!(v, 99),
                _ => panic!("expected Full"),
            }
        });
    }

    #[test]
    fn test_try_send_closed() {
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            drop(rx);
            match tx.try_send(99) {
                Err(TrySendError::Closed(v)) => assert_eq!(v, 99),
                _ => panic!("expected Closed"),
            }
        });
    }

    #[test]
    fn test_send_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.send(42).await.unwrap();
        });
    }

    #[test]
    fn test_send_exactly_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(2);
            tx.send(1).await.unwrap();
            tx.send(2).await.unwrap();
        });
    }

    #[test]
    fn test_send_pending_when_full() {
        block_on(VirtualIo::default(), async {
            let (mut tx, _rx) = bounded(1);
            tx.try_send(1).unwrap();

            let mut cx = noop_cx();
            let mut fut = pin!(tx.send(2));
            assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Pending));
        });
    }

    #[test]
    fn test_send_when_full_eventually_resolves() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.try_send(1).unwrap();
            let mut handle = spawn(async move { tx.send(2).await.unwrap() });

            let mut cx = noop_cx();
            assert!(matches!(Pin::new(&mut handle).poll(&mut cx), Poll::Pending));

            assert_eq!(rx.recv().await, Ok(1));
            assert_eq!(rx.recv().await, Ok(2));
            assert!(rx.try_recv().is_err());
        });
    }

    #[test]
    fn test_recv_when_empty_eventually_resolves() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            spawn(async move {
                assert_eq!(rx.recv().await, Ok(1));
                assert_eq!(rx.recv().await, Ok(2));
            });

            tx.send(1).await.unwrap();
            assert!(matches!(tx.try_send(2), Err(TrySendError::Full(2))));
            tx.send(2).await.unwrap();
        });
    }

    #[test]
    fn test_send_closed() {
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            drop(rx);
            match tx.send(99).await {
                Err(SendError(v)) => assert_eq!(v, 99),
                Ok(()) => panic!("expected Err"),
            }
        });
    }

    #[test]
    fn test_try_recv_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.try_send(42).unwrap();
            assert_eq!(rx.try_recv().unwrap(), 42);
        });
    }

    #[test]
    fn test_try_recv_in_order() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(3);
            tx.try_send(1).unwrap();
            tx.try_send(2).unwrap();
            tx.try_send(3).unwrap();
            assert_eq!(rx.try_recv().unwrap(), 1);
            assert_eq!(rx.try_recv().unwrap(), 2);
            assert_eq!(rx.try_recv().unwrap(), 3);
        });
    }

    #[test]
    fn test_try_recv_empty() {
        block_on(VirtualIo::default(), async {
            let (_tx, mut rx) = bounded::<i32>(1);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        });
    }

    #[test]
    fn test_try_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            drop(tx);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
        });
    }

    #[test]
    fn test_try_recv_drains_before_closed() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.try_send(7).unwrap();
            drop(tx);
            assert_eq!(rx.try_recv(), Ok(7));
            assert_eq!(rx.try_recv(), Err(TryRecvError::Closed));
        });
    }

    #[test]
    fn test_recv_one() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(1);
            tx.send(42).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 42);
        });
    }

    #[test]
    fn test_recv_in_order() {
        block_on(VirtualIo::default(), async {
            let (mut tx, mut rx) = bounded(3);
            tx.send(1).await.unwrap();
            tx.send(2).await.unwrap();
            tx.send(3).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 1);
            assert_eq!(rx.recv().await.unwrap(), 2);
            assert_eq!(rx.recv().await.unwrap(), 3);
        });
    }

    #[test]
    fn test_recv_pending_when_empty() {
        block_on(VirtualIo::default(), async {
            let (_tx, mut rx) = bounded::<i32>(1);

            let mut cx = noop_cx();
            let mut fut = pin!(rx.recv());
            assert!(matches!(fut.as_mut().poll(&mut cx), Poll::Pending));
        });
    }

    #[test]
    fn test_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            drop(tx);
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    #[test]
    fn test_recv_woken_when_last_sender_dropped() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            spawn(async move {
                drop(tx);
            });
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    #[test]
    fn test_recv_woken_when_last_of_multiple_senders_dropped() {
        block_on(VirtualIo::default(), async {
            let (tx, mut rx) = bounded::<i32>(1);
            let tx2 = tx.clone();
            spawn(async move {
                drop(tx);
                drop(tx2);
            });
            assert_eq!(rx.recv().await, Err(RecvError));
        });
    }

    #[test]
    fn test_stream_recv() {
        const ITEMS: &[i32; 3] = &[1, 2, 3];
        block_on(VirtualIo::default(), async {
            let (mut tx, rx) = bounded::<i32>(1);
            spawn(async move {
                for &item in ITEMS {
                    tx.send(item).await.unwrap();
                }
            });

            let result: Vec<_> = rx.collect().await;
            assert_eq!(result, ITEMS);
        });
    }
}