memkit_async/
channel.rs

//! Zero-copy async channels: bounded queues that move owned values from sender tasks to a receiver task.
2
3use std::collections::VecDeque;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::{Arc, Mutex};
8use std::task::{Context, Poll};
9
10/// Zero-copy channel for transferring ownership between tasks.
11pub struct MkAsyncChannel<T> {
12    inner: Arc<ChannelInner<T>>,
13}
14
/// State shared between a channel's sender handle(s) and its receiver.
struct ChannelInner<T> {
    /// Buffered values, oldest first.
    queue: Mutex<VecDeque<T>>,
    /// Maximum number of values that may be buffered at once.
    capacity: usize,
    /// Count of buffered values, kept in step with `queue` so it can be read without the lock.
    len: AtomicUsize,
    /// Becomes `true` once the channel has been closed through a sender.
    closed: AtomicBool,
}
21
22impl<T> MkAsyncChannel<T> {
23    /// Create a new bounded channel.
24    pub fn bounded(capacity: usize) -> Self {
25        Self {
26            inner: Arc::new(ChannelInner {
27                queue: Mutex::new(VecDeque::with_capacity(capacity)),
28                capacity,
29                len: AtomicUsize::new(0),
30                closed: AtomicBool::new(false),
31            }),
32        }
33    }
34
35    /// Create sender and receiver handles.
36    pub fn split(self) -> (MkAsyncSender<T>, MkAsyncReceiver<T>) {
37        (
38            MkAsyncSender { inner: Arc::clone(&self.inner) },
39            MkAsyncReceiver { inner: self.inner },
40        )
41    }
42
43    /// Get the channel capacity.
44    pub fn capacity(&self) -> usize {
45        self.inner.capacity
46    }
47
48    /// Get the current length.
49    pub fn len(&self) -> usize {
50        self.inner.len.load(Ordering::Relaxed)
51    }
52
53    /// Check if empty.
54    pub fn is_empty(&self) -> bool {
55        self.len() == 0
56    }
57}
58
59/// Sender half of an async channel.
60pub struct MkAsyncSender<T> {
61    inner: Arc<ChannelInner<T>>,
62}
63
64impl<T> MkAsyncSender<T> {
65    /// Send a value through the channel.
66    pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
67        if self.inner.closed.load(Ordering::Acquire) {
68            return Err(SendError::Closed(value));
69        }
70
71        loop {
72            let len = self.inner.len.load(Ordering::Acquire);
73            if len < self.inner.capacity {
74                let mut queue = self.inner.queue.lock().unwrap();
75                if queue.len() < self.inner.capacity {
76                    queue.push_back(value);
77                    self.inner.len.fetch_add(1, Ordering::Release);
78                    return Ok(());
79                }
80            }
81            
82            // Yield and retry
83            YieldOnce::new().await;
84            
85            if self.inner.closed.load(Ordering::Acquire) {
86                return Err(SendError::Closed(value));
87            }
88        }
89    }
90
91    /// Try to send without waiting.
92    pub fn try_send(&self, value: T) -> Result<(), SendError<T>> {
93        if self.inner.closed.load(Ordering::Acquire) {
94            return Err(SendError::Closed(value));
95        }
96
97        let mut queue = self.inner.queue.lock().unwrap();
98        if queue.len() < self.inner.capacity {
99            queue.push_back(value);
100            self.inner.len.fetch_add(1, Ordering::Release);
101            Ok(())
102        } else {
103            Err(SendError::Full(value))
104        }
105    }
106
107    /// Close the channel.
108    pub fn close(&self) {
109        self.inner.closed.store(true, Ordering::Release);
110    }
111}
112
113impl<T> Clone for MkAsyncSender<T> {
114    fn clone(&self) -> Self {
115        Self { inner: Arc::clone(&self.inner) }
116    }
117}
118
119/// Receiver half of an async channel.
120pub struct MkAsyncReceiver<T> {
121    inner: Arc<ChannelInner<T>>,
122}
123
124impl<T> MkAsyncReceiver<T> {
125    /// Receive a value from the channel.
126    pub async fn recv(&self) -> Option<T> {
127        loop {
128            if let Some(value) = self.try_recv() {
129                return Some(value);
130            }
131            
132            if self.inner.closed.load(Ordering::Acquire) && self.inner.len.load(Ordering::Acquire) == 0 {
133                return None;
134            }
135            
136            YieldOnce::new().await;
137        }
138    }
139
140    /// Try to receive without waiting.
141    pub fn try_recv(&self) -> Option<T> {
142        let mut queue = self.inner.queue.lock().unwrap();
143        if let Some(value) = queue.pop_front() {
144            self.inner.len.fetch_sub(1, Ordering::Release);
145            Some(value)
146        } else {
147            None
148        }
149    }
150
151    /// Check if the channel is closed.
152    pub fn is_closed(&self) -> bool {
153        self.inner.closed.load(Ordering::Acquire)
154    }
155}
156
/// Error returned when a value cannot be sent; the rejected value is carried inside.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendError<T> {
    /// The channel already holds as many values as its capacity allows.
    Full(T),
    /// The channel has been closed.
    Closed(T),
}

impl<T> SendError<T> {
    /// Consumes the error and returns the value that failed to send.
    pub fn into_inner(self) -> T {
        match self {
            SendError::Full(v) | SendError::Closed(v) => v,
        }
    }
}

impl<T> std::fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The payload is deliberately omitted so `T` needs no `Display` bound.
        match self {
            SendError::Full(_) => f.write_str("sending on a full channel"),
            SendError::Closed(_) => f.write_str("sending on a closed channel"),
        }
    }
}

impl<T: std::fmt::Debug> std::error::Error for SendError<T> {}
172
/// Future that is `Pending` on its first poll (after waking its own task right away) and
/// `Ready(())` on every later poll, so the current poll hands control back to the executor
/// once before continuing.
struct YieldOnce {
    /// Whether the first, pending poll has already happened.
    yielded: bool,
}

impl YieldOnce {
    /// Creates a future that has not yielded yet.
    fn new() -> Self {
        YieldOnce { yielded: false }
    }
}

impl Future for YieldOnce {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let this = self.get_mut();
        if !this.yielded {
            this.yielded = true;
            // Reschedule immediately: the aim is only to give up the current turn.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        Poll::Ready(())
    }
}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Minimal std-only executor: polls `fut` with a no-op waker until it completes.
    ///
    /// Only use it with futures that can finish on their own; a future that waits on
    /// another task would spin forever here.
    fn block_on<F: Future>(fut: F) -> F::Output {
        struct NoopWaker;
        impl Wake for NoopWaker {
            fn wake(self: Arc<Self>) {}
        }

        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        loop {
            if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
                return output;
            }
        }
    }

    #[test]
    fn test_channel_sync() {
        let channel: MkAsyncChannel<u32> = MkAsyncChannel::bounded(3);
        assert_eq!(channel.capacity(), 3);
        assert!(channel.is_empty());
        let (tx, rx) = channel.split();

        tx.try_send(1).unwrap();
        tx.try_send(2).unwrap();
        tx.try_send(3).unwrap();
        // The rejected value must come back to the caller, tagged as `Full`.
        assert!(matches!(tx.try_send(4), Err(SendError::Full(4))));

        assert_eq!(rx.try_recv(), Some(1));
        assert_eq!(rx.try_recv(), Some(2));
        assert_eq!(rx.try_recv(), Some(3));
        assert_eq!(rx.try_recv(), None);
    }

    #[test]
    fn test_async_send_recv_round_trip() {
        let (tx, rx) = MkAsyncChannel::<u32>::bounded(1).split();
        block_on(tx.send(10)).unwrap();
        assert_eq!(block_on(rx.recv()), Some(10));
        // The single slot is free again, so a second send must not wait.
        block_on(tx.send(11)).unwrap();
        assert_eq!(rx.try_recv(), Some(11));
    }

    #[test]
    fn test_close_drains_then_ends() {
        let (tx, rx) = MkAsyncChannel::<u32>::bounded(2).split();
        tx.try_send(7).unwrap();
        tx.close();

        assert!(rx.is_closed());
        assert!(matches!(tx.try_send(8), Err(SendError::Closed(8))));
        assert_eq!(block_on(tx.send(9)).unwrap_err().into_inner(), 9);
        // Values buffered before `close` are still delivered, then `recv` ends.
        assert_eq!(block_on(rx.recv()), Some(7));
        assert_eq!(block_on(rx.recv()), None);
    }

    #[test]
    fn test_cloned_sender_shares_queue() {
        let (tx, rx) = MkAsyncChannel::<u32>::bounded(2).split();
        let tx2 = tx.clone();
        tx.try_send(1).unwrap();
        tx2.try_send(2).unwrap();
        assert!(matches!(tx2.try_send(3), Err(SendError::Full(3))));
        assert_eq!(rx.try_recv(), Some(1));
        assert_eq!(rx.try_recv(), Some(2));
    }
}
215}