agner_utils/
spsc.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll, Waker};
5
6use futures::lock::BiLock;
7
8#[cfg(test)]
9mod tests;
10
11pub fn channel<T>(max_len: usize) -> (Sender<T>, Receiver<T>) {
12    let inner =
13        Inner { queue: Default::default(), max_len, sender_waker: None, receiver_waker: None };
14
15    let (sender, receiver) = BiLock::new(inner);
16    (Sender(sender), Receiver(receiver))
17}
18
19#[derive(Debug)]
20pub struct Receiver<T>(BiLock<Inner<T>>);
21
22#[derive(Debug)]
23pub struct Sender<T>(BiLock<Inner<T>>);
24
25impl<T> Receiver<T>
26where
27    T: Unpin,
28{
29    pub fn recv(&mut self, should_block: bool) -> impl Future<Output = Option<T>> + '_ {
30        Receive { lock: &self.0, should_block }
31    }
32
33    pub async fn len(&self) -> (usize, usize) {
34        let locked = self.0.lock().await;
35        (locked.queue.len(), locked.max_len)
36    }
37}
38
39impl<T> Sender<T>
40where
41    T: Unpin,
42{
43    pub fn send(
44        &mut self,
45        item: T,
46        should_block: bool,
47    ) -> impl Future<Output = Result<(), T>> + '_ {
48        Send { lock: &self.0, should_block, item: Some(item) }
49    }
50
51    pub async fn len(&self) -> (usize, usize) {
52        let locked = self.0.lock().await;
53        (locked.queue.len(), locked.max_len)
54    }
55}
56
/// State shared by the two channel halves behind the `BiLock`.
#[derive(Debug)]
struct Inner<T> {
    /// Items that were sent and not yet received, oldest at the front.
    queue: VecDeque<T>,
    /// A send only pushes while `queue.len()` is below this value.
    max_len: usize,
    /// Waker stored by a send that found the queue full; woken after a receive pops an item.
    sender_waker: Option<Waker>,
    /// Waker stored by a receive that found the queue empty; woken after a send pushes an item.
    receiver_waker: Option<Waker>,
}
64
65#[pin_project::pin_project]
66struct Receive<'a, T> {
67    lock: &'a BiLock<Inner<T>>,
68    should_block: bool,
69}
70
71#[pin_project::pin_project]
72struct Send<'a, T> {
73    lock: &'a BiLock<Inner<T>>,
74    should_block: bool,
75    item: Option<T>,
76}
77
78impl<'a, T> Future for Receive<'a, T>
79where
80    T: Unpin,
81{
82    type Output = Option<T>;
83
84    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85        let this = self.project();
86
87        let mut locked = futures::ready!(this.lock.poll_lock(cx));
88        let _ = locked.receiver_waker.take();
89
90        match (locked.queue.pop_front(), this.should_block) {
91            (Some(item), _) => {
92                if let Some(waker) = locked.sender_waker.take() {
93                    waker.wake();
94                }
95                Poll::Ready(Some(item))
96            },
97            (None, false) => Poll::Ready(None),
98            (None, true) => {
99                let should_be_none = locked.receiver_waker.replace(cx.waker().to_owned());
100                assert!(should_be_none.is_none());
101                Poll::Pending
102            },
103        }
104    }
105}
106
107impl<'a, T> Future for Send<'a, T>
108where
109    T: Unpin,
110{
111    type Output = Result<(), T>;
112
113    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
114        let this = self.project();
115
116        let mut locked = futures::ready!(this.lock.poll_lock(cx));
117        let _ = locked.sender_waker.take();
118
119        match (locked.queue.len() < locked.max_len, this.should_block) {
120            (true, _) => {
121                let item = this.item.take().expect("Item empty");
122                locked.queue.push_back(item);
123                if let Some(waker) = locked.receiver_waker.take() {
124                    waker.wake();
125                }
126                Poll::Ready(Ok(()))
127            },
128            (false, false) => {
129                let item = this.item.take().expect("Item empty");
130                Poll::Ready(Err(item))
131            },
132            (false, true) => {
133                let should_be_none = locked.sender_waker.replace(cx.waker().to_owned());
134                assert!(should_be_none.is_none());
135                Poll::Pending
136            },
137        }
138    }
139}