1use std::collections::VecDeque;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::{Arc, Mutex};
8use std::task::{Context, Poll};
9
10pub struct MkAsyncChannel<T> {
12 inner: Arc<ChannelInner<T>>,
13}
14
/// State shared (behind an `Arc`) by the channel and every handle split off it.
struct ChannelInner<T> {
    /// Pending values in FIFO order: pushed at the back, popped at the front.
    queue: Mutex<VecDeque<T>>,
    /// Upper bound on how many values `queue` may hold at once.
    capacity: usize,
    /// Mirror of `queue.len()`, adjusted while the queue lock is held, so that
    /// callers can consult the length without locking.
    len: AtomicUsize,
    /// Set by the sender's `close`; checked by both senders and the receiver.
    closed: AtomicBool,
}
21
22impl<T> MkAsyncChannel<T> {
23 pub fn bounded(capacity: usize) -> Self {
25 Self {
26 inner: Arc::new(ChannelInner {
27 queue: Mutex::new(VecDeque::with_capacity(capacity)),
28 capacity,
29 len: AtomicUsize::new(0),
30 closed: AtomicBool::new(false),
31 }),
32 }
33 }
34
35 pub fn split(self) -> (MkAsyncSender<T>, MkAsyncReceiver<T>) {
37 (
38 MkAsyncSender { inner: Arc::clone(&self.inner) },
39 MkAsyncReceiver { inner: self.inner },
40 )
41 }
42
43 pub fn capacity(&self) -> usize {
45 self.inner.capacity
46 }
47
48 pub fn len(&self) -> usize {
50 self.inner.len.load(Ordering::Relaxed)
51 }
52
53 pub fn is_empty(&self) -> bool {
55 self.len() == 0
56 }
57}
58
59pub struct MkAsyncSender<T> {
61 inner: Arc<ChannelInner<T>>,
62}
63
64impl<T> MkAsyncSender<T> {
65 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
67 if self.inner.closed.load(Ordering::Acquire) {
68 return Err(SendError::Closed(value));
69 }
70
71 loop {
72 let len = self.inner.len.load(Ordering::Acquire);
73 if len < self.inner.capacity {
74 let mut queue = self.inner.queue.lock().unwrap();
75 if queue.len() < self.inner.capacity {
76 queue.push_back(value);
77 self.inner.len.fetch_add(1, Ordering::Release);
78 return Ok(());
79 }
80 }
81
82 YieldOnce::new().await;
84
85 if self.inner.closed.load(Ordering::Acquire) {
86 return Err(SendError::Closed(value));
87 }
88 }
89 }
90
91 pub fn try_send(&self, value: T) -> Result<(), SendError<T>> {
93 if self.inner.closed.load(Ordering::Acquire) {
94 return Err(SendError::Closed(value));
95 }
96
97 let mut queue = self.inner.queue.lock().unwrap();
98 if queue.len() < self.inner.capacity {
99 queue.push_back(value);
100 self.inner.len.fetch_add(1, Ordering::Release);
101 Ok(())
102 } else {
103 Err(SendError::Full(value))
104 }
105 }
106
107 pub fn close(&self) {
109 self.inner.closed.store(true, Ordering::Release);
110 }
111}
112
113impl<T> Clone for MkAsyncSender<T> {
114 fn clone(&self) -> Self {
115 Self { inner: Arc::clone(&self.inner) }
116 }
117}
118
119pub struct MkAsyncReceiver<T> {
121 inner: Arc<ChannelInner<T>>,
122}
123
124impl<T> MkAsyncReceiver<T> {
125 pub async fn recv(&self) -> Option<T> {
127 loop {
128 if let Some(value) = self.try_recv() {
129 return Some(value);
130 }
131
132 if self.inner.closed.load(Ordering::Acquire) && self.inner.len.load(Ordering::Acquire) == 0 {
133 return None;
134 }
135
136 YieldOnce::new().await;
137 }
138 }
139
140 pub fn try_recv(&self) -> Option<T> {
142 let mut queue = self.inner.queue.lock().unwrap();
143 if let Some(value) = queue.pop_front() {
144 self.inner.len.fetch_sub(1, Ordering::Release);
145 Some(value)
146 } else {
147 None
148 }
149 }
150
151 pub fn is_closed(&self) -> bool {
153 self.inner.closed.load(Ordering::Acquire)
154 }
155}
156
/// Reason a value could not be enqueued; the rejected value is handed back.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendError<T> {
    /// The queue was at capacity. Only `try_send` reports this.
    Full(T),
    /// The channel had been closed.
    Closed(T),
}

impl<T> SendError<T> {
    /// Consumes the error and returns the value that failed to send.
    pub fn into_inner(self) -> T {
        match self {
            SendError::Full(v) | SendError::Closed(v) => v,
        }
    }
}

// `Display` + `Error` let callers propagate this with `?` into
// `Box<dyn Error>` or other error types. The payload is deliberately omitted
// from the message so `T` needs no formatting bound here.
impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let reason = match self {
            SendError::Full(_) => "channel is full",
            SendError::Closed(_) => "channel is closed",
        };
        f.write_str(reason)
    }
}

impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}
172
/// Future that returns `Pending` exactly once and completes on the next poll.
///
/// On the pending poll it wakes its own task immediately, asking the executor
/// to re-poll. The channel awaits it as a cooperative yield point between
/// attempts. The flag records whether the pending poll has already happened.
struct YieldOnce(bool);

impl YieldOnce {
    /// Creates a future that has not yielded yet.
    fn new() -> Self {
        YieldOnce(false)
    }
}

impl Future for YieldOnce {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if !self.0 {
            // First poll: remember that we yielded and request an immediate
            // re-poll so the task is not left parked without a waker.
            self.0 = true;
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        Poll::Ready(())
    }
}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Waker that ignores wake-ups; `block_on` re-polls unconditionally.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Minimal executor for tests: polls `fut` in a loop until it completes.
    fn block_on<F: Future>(fut: F) -> F::Output {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        loop {
            if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
                return out;
            }
            std::thread::yield_now();
        }
    }

    #[test]
    fn test_channel_sync() {
        let channel: MkAsyncChannel<u32> = MkAsyncChannel::bounded(3);
        let (tx, rx) = channel.split();

        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        tx.try_send(3).unwrap();
        // A full queue must report `Full` (not `Closed`) and hand the value back.
        assert!(matches!(tx.try_send(4), Err(SendError::Full(4))));

        assert_eq!(rx.try_recv(), Some(1));
        assert_eq!(rx.try_recv(), Some(2));
        assert_eq!(rx.try_recv(), Some(3));
        assert_eq!(rx.try_recv(), None);
    }

    #[test]
    fn close_rejects_sends_but_drains_queue() {
        let (tx, rx) = MkAsyncChannel::<u32>::bounded(2).split();
        tx.try_send(10).unwrap();
        tx.close();

        assert!(rx.is_closed());
        assert!(matches!(tx.try_send(11), Err(SendError::Closed(11))));
        assert!(matches!(block_on(tx.send(12)), Err(SendError::Closed(12))));

        // Values queued before the close are still delivered, then `None`.
        assert_eq!(block_on(rx.recv()), Some(10));
        assert_eq!(block_on(rx.recv()), None);
    }

    #[test]
    fn send_waits_for_capacity() {
        let (tx, rx) = MkAsyncChannel::<u32>::bounded(1).split();
        block_on(tx.send(1)).unwrap();

        // The queue is full, so this send can only finish after `recv` frees a slot.
        let tx2 = tx.clone();
        let handle = std::thread::spawn(move || block_on(tx2.send(2)));

        assert_eq!(block_on(rx.recv()), Some(1));
        assert_eq!(block_on(rx.recv()), Some(2));
        handle.join().unwrap().unwrap();
    }
}