1use std::{
2 cell::{Cell, RefCell},
3 error::Error,
4 fmt,
5 rc::Rc,
6 task::{Context, Poll, Waker},
7};
8
9use super::{block::Queue, semaphore::Semaphore};
10
11pub(crate) fn channel<T, S>(semaphore: S) -> (Tx<T, S>, Rx<T, S>)
12where
13 S: Semaphore,
14{
15 let chan = Rc::new(Chan::new(semaphore));
16 let tx = Tx::new(chan.clone());
17 let rx = Rx::new(chan);
18 (tx, rx)
19}
20
21pub(crate) struct Chan<T, S: Semaphore> {
22 queue: RefCell<Queue<T>>,
23 pub(crate) semaphore: S,
24 rx_waker: RefCell<Option<Waker>>,
25 tx_count: Cell<usize>,
26}
27
/// Error returned by a non-blocking receive attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// The channel holds no value at the moment.
    Empty,
    /// The channel is disconnected; no value was received.
    Disconnected,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "receiving on an empty channel",
            Self::Disconnected => "receiving on a closed channel",
        };
        // `pad` is what `<str as Display>::fmt` does, so width/alignment
        // flags given by the caller keep working.
        fmt.pad(message)
    }
}

impl Error for TryRecvError {}
49
50impl<T, S> Chan<T, S>
51where
52 S: Semaphore,
53{
54 pub(crate) fn new(semaphore: S) -> Self {
55 let queue = RefCell::new(Queue::new());
56 Self {
57 queue,
58 semaphore,
59 rx_waker: RefCell::new(None),
60 tx_count: Cell::new(0),
61 }
62 }
63}
64
65impl<T, S> Drop for Chan<T, S>
66where
67 S: Semaphore,
68{
69 fn drop(&mut self) {
70 let mut queue = self.queue.borrow_mut();
74 while !queue.is_empty() {
75 drop(unsafe { queue.pop_unchecked() });
76 }
77 unsafe { queue.free_blocks() }
79 }
80}
81
82pub(crate) struct Tx<T, S>
83where
84 S: Semaphore,
85{
86 pub(crate) chan: Rc<Chan<T, S>>,
87}
88
/// Error returned when a value could not be sent.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum SendError {
    /// The channel has been closed, so the value was not enqueued.
    RxClosed,
}

// `TryRecvError` already implements `Display` and `Error`; `SendError` gets
// the same treatment so it can be printed with `{}` and converted into
// `Box<dyn Error>` via `?`.
impl fmt::Display for SendError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            SendError::RxClosed => "sending on a closed channel",
        };
        // `pad` honours width/alignment flags supplied by the caller.
        fmt.pad(message)
    }
}

impl Error for SendError {}
93
94pub(crate) struct Rx<T, S>
95where
96 S: Semaphore,
97{
98 chan: Rc<Chan<T, S>>,
99}
100
101impl<T, S> Tx<T, S>
102where
103 S: Semaphore,
104{
105 pub(crate) fn new(chan: Rc<Chan<T, S>>) -> Self {
106 chan.tx_count.set(chan.tx_count.get() + 1);
107 Self { chan }
108 }
109
110 pub(crate) fn send(&self, value: T) -> Result<(), SendError> {
112 if self.chan.semaphore.is_closed() {
114 return Err(SendError::RxClosed);
115 }
116
117 unsafe {
119 self.chan.queue.borrow_mut().push_unchecked(value);
120 }
121 if let Some(w) = self.chan.rx_waker.replace(None) {
123 w.wake();
124 }
125 Ok(())
126 }
127
128 pub fn is_closed(&self) -> bool {
129 self.chan.semaphore.is_closed()
130 }
131
132 pub(crate) fn same_channel(&self, other: &Self) -> bool {
134 Rc::ptr_eq(&self.chan, &other.chan)
135 }
136}
137
138impl<T, S> Clone for Tx<T, S>
139where
140 S: Semaphore,
141{
142 fn clone(&self) -> Self {
143 self.chan.tx_count.set(self.chan.tx_count.get() + 1);
144 Self {
145 chan: self.chan.clone(),
146 }
147 }
148}
149
150impl<T, S> Drop for Tx<T, S>
151where
152 S: Semaphore,
153{
154 fn drop(&mut self) {
155 let cnt = self.chan.tx_count.get();
156 self.chan.tx_count.set(cnt - 1);
157
158 if cnt == 1 {
159 self.chan.semaphore.close();
160 if let Some(rx_waker) = self.chan.rx_waker.take() {
161 rx_waker.wake();
162 }
163 }
164 }
165}
166
167impl<T, S> Rx<T, S>
168where
169 S: Semaphore,
170{
171 pub(crate) fn new(chan: Rc<Chan<T, S>>) -> Self {
172 Self { chan }
173 }
174
175 pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
176 let mut queue = self.chan.queue.borrow_mut();
177 if !queue.is_empty() {
178 let val = unsafe { queue.pop_unchecked() };
179 self.chan.semaphore.add_permits(1);
180 return Ok(val);
181 }
182 if self.chan.tx_count.get() == 0 {
183 Err(TryRecvError::Disconnected)
184 } else {
185 Err(TryRecvError::Empty)
186 }
187 }
188
189 pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
190 let mut queue = self.chan.queue.borrow_mut();
191 if !queue.is_empty() {
192 let val = unsafe { queue.pop_unchecked() };
193 self.chan.semaphore.add_permits(1);
194 return Poll::Ready(Some(val));
195 }
196 if self.chan.tx_count.get() == 0 {
197 return Poll::Ready(None);
198 }
199 let mut borrowed = self.chan.rx_waker.borrow_mut();
200 match borrowed.as_mut() {
201 Some(inner) => {
202 if !inner.will_wake(cx.waker()) {
203 *inner = cx.waker().clone();
204 }
205 }
206 None => {
207 *borrowed = Some(cx.waker().clone());
208 }
209 }
210 Poll::Pending
211 }
212
213 pub(crate) fn close(&mut self) {
214 self.chan.semaphore.close();
215 }
216}
217
218impl<T, S> Drop for Rx<T, S>
219where
220 S: Semaphore,
221{
222 fn drop(&mut self) {
223 self.chan.semaphore.close();
225 let mut queue = self.chan.queue.borrow_mut();
227 let len = queue.len();
228 while !queue.is_empty() {
229 drop(unsafe { queue.pop_unchecked() });
230 }
231 self.chan.semaphore.add_permits(len);
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::channel;
    use crate::semaphore::Inner;
    use futures_util::future::poll_fn;

    /// A sent value is received, and closing the receiver is visible from
    /// the sender.
    #[monoio::test]
    async fn test_chan() {
        let (tx, mut rx) = channel::<u32, _>(Inner::new(1));

        assert!(tx.send(1).is_ok());
        let received = poll_fn(|cx| rx.recv(cx)).await;
        assert_eq!(received, Some(1));

        rx.close();
        // The sender observes the closure through the shared semaphore.
        assert!(tx.is_closed());
    }
}