local_sync/mpsc/chan.rs

1use std::{
2    cell::{Cell, RefCell},
3    error::Error,
4    fmt,
5    rc::Rc,
6    task::{Context, Poll, Waker},
7};
8
9use super::{block::Queue, semaphore::Semaphore};
10
11pub(crate) fn channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>)
12where
13    S: Semaphore,
14{
15    let chan = Rc::new(Chan::new(semaphore));
16    let tx = Tx::new(chan.clone());
17    let rx = Rx::new(chan);
18    (tx, rx)
19}
20
21pub(crate) struct Chan<T, S: Semaphore> {
22    queue: RefCell<Queue<T>>,
23    pub(crate) semaphore: S,
24    rx_waker: RefCell<Option<Waker>>,
25    tx_count: Cell<usize>,
26}
27
/// Error returned by `try_recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// Nothing is buffered right now, but the channel has not been
    /// disconnected.
    Empty,
    /// The channel is empty and disconnected: no further value will ever be
    /// received from it.
    Disconnected,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "receiving on an empty channel",
            Self::Disconnected => "receiving on a closed channel",
        };
        // Delegate to `str`'s `Display` so width/alignment flags are honored.
        fmt::Display::fmt(message, fmt)
    }
}

// The message comes from `Display`; there is no underlying error source.
impl Error for TryRecvError {}
49
50impl<T, S> Chan<T, S>
51where
52    S: Semaphore,
53{
54    pub(crate) fn new(semaphore: S) -> Self {
55        let queue = RefCell::new(Queue::new());
56        Self {
57            queue,
58            semaphore,
59            rx_waker: RefCell::new(None),
60            tx_count: Cell::new(0),
61        }
62    }
63}
64
65impl<T, S> Drop for Chan<T, S>
66where
67    S: Semaphore,
68{
69    fn drop(&mut self) {
70        // consume all elements:
71        // we cleared all elements on Rx drop, but there may still some
72        // values sent after permits added.
73        let mut queue = self.queue.borrow_mut();
74        while !queue.is_empty() {
75            drop(unsafe { queue.pop_unchecked() });
76        }
77        // drop all blocks of queue
78        unsafe { queue.free_blocks() }
79    }
80}
81
82pub(crate) struct Tx<T, S>
83where
84    S: Semaphore,
85{
86    pub(crate) chan: Rc<Chan<T, S>>,
87}
88
/// Error returned by `Tx::send`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// The channel's semaphore is closed (for example by `Rx::close` or by
    /// dropping the receiver), so the value was not enqueued.
    RxClosed,
}
93
94pub(crate) struct Rx<T, S>
95where
96    S: Semaphore,
97{
98    chan: Rc<Chan<T, S>>,
99}
100
101impl<T, S> Tx<T, S>
102where
103    S: Semaphore,
104{
105    pub(crate) fn new(chan: Rc<Chan<T, S>>) -> Self {
106        chan.tx_count.set(chan.tx_count.get() + 1);
107        Self { chan }
108    }
109
110    // caller must make sure the chan has spaces
111    pub(crate) fn send(&self, value: T) -> Result<(), SendError> {
112        // check if the semaphore is closed
113        if self.chan.semaphore.is_closed() {
114            return Err(SendError::RxClosed);
115        }
116
117        // put data into the queue
118        unsafe {
119            self.chan.queue.borrow_mut().push_unchecked(value);
120        }
121        // if rx waker is set, wake it
122        if let Some(w) = self.chan.rx_waker.replace(None) {
123            w.wake();
124        }
125        Ok(())
126    }
127
128    pub fn is_closed(&self) -> bool {
129        self.chan.semaphore.is_closed()
130    }
131
132    /// Returns `true` if senders belong to the same channel.
133    pub(crate) fn same_channel(&self, other: &Self) -> bool {
134        Rc::ptr_eq(&self.chan, &other.chan)
135    }
136}
137
138impl<T, S> Clone for Tx<T, S>
139where
140    S: Semaphore,
141{
142    fn clone(&self) -> Self {
143        self.chan.tx_count.set(self.chan.tx_count.get() + 1);
144        Self {
145            chan: self.chan.clone(),
146        }
147    }
148}
149
150impl<T, S> Drop for Tx<T, S>
151where
152    S: Semaphore,
153{
154    fn drop(&mut self) {
155        let cnt = self.chan.tx_count.get();
156        self.chan.tx_count.set(cnt - 1);
157
158        if cnt == 1 {
159            self.chan.semaphore.close();
160            if let Some(rx_waker) = self.chan.rx_waker.take() {
161                rx_waker.wake();
162            }
163        }
164    }
165}
166
167impl<T, S> Rx<T, S>
168where
169    S: Semaphore,
170{
171    pub(crate) fn new(chan: Rc<Chan<T, S>>) -> Self {
172        Self { chan }
173    }
174
175    pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
176        let mut queue = self.chan.queue.borrow_mut();
177        if !queue.is_empty() {
178            let val = unsafe { queue.pop_unchecked() };
179            self.chan.semaphore.add_permits(1);
180            return Ok(val);
181        }
182        if self.chan.tx_count.get() == 0 {
183            Err(TryRecvError::Disconnected)
184        } else {
185            Err(TryRecvError::Empty)
186        }
187    }
188
189    pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
190        let mut queue = self.chan.queue.borrow_mut();
191        if !queue.is_empty() {
192            let val = unsafe { queue.pop_unchecked() };
193            self.chan.semaphore.add_permits(1);
194            return Poll::Ready(Some(val));
195        }
196        if self.chan.tx_count.get() == 0 {
197            return Poll::Ready(None);
198        }
199        let mut borrowed = self.chan.rx_waker.borrow_mut();
200        match borrowed.as_mut() {
201            Some(inner) => {
202                if !inner.will_wake(cx.waker()) {
203                    *inner = cx.waker().clone();
204                }
205            }
206            None => {
207                *borrowed = Some(cx.waker().clone());
208            }
209        }
210        Poll::Pending
211    }
212
213    pub(crate) fn close(&mut self) {
214        self.chan.semaphore.close();
215    }
216}
217
218impl<T, S> Drop for Rx<T, S>
219where
220    S: Semaphore,
221{
222    fn drop(&mut self) {
223        // close semaphore on close, this will make tx send await return.
224        self.chan.semaphore.close();
225        // consume all elements
226        let mut queue = self.chan.queue.borrow_mut();
227        let len = queue.len();
228        while !queue.is_empty() {
229            drop(unsafe { queue.pop_unchecked() });
230        }
231        self.chan.semaphore.add_permits(len);
232    }
233}
234
#[cfg(test)]
mod tests {
    use futures_util::future::poll_fn;

    use super::channel;
    use crate::semaphore::Inner;

    /// Sends one value through a channel built on `Inner::new(1)` and receives
    /// it. It then checks that closing the receiver is visible from the
    /// sender.
    #[monoio::test]
    async fn test_chan() {
        let (sender, mut receiver) = channel::<u32, _>(Inner::new(1));

        let sent = sender.send(1);
        assert!(sent.is_ok());
        let received = poll_fn(|cx| receiver.recv(cx)).await;
        assert_eq!(received, Some(1));

        receiver.close();
        assert!(sender.is_closed());
    }
}