flowly_spsc/
lib.rs

1mod atomic_waker;
2mod error;
3
4use std::{
5    cell::UnsafeCell,
6    future::poll_fn,
7    mem::MaybeUninit,
8    pin::Pin,
9    sync::{
10        Arc,
11        atomic::{AtomicBool, AtomicUsize, Ordering},
12    },
13    task::{Context, Poll},
14};
15
16use atomic_waker::AtomicWaker;
17pub use error::{SendError, TryRecvError, TrySendError};
18use futures::Stream;
19
20pub fn channel<T>(size: usize) -> (Sender<T>, Receiver<T>) {
21    let shared = Arc::new(Shared::new(size));
22
23    (
24        Sender {
25            shared: shared.clone(),
26            pos: 0,
27        },
28        Receiver { shared, pos: 0 },
29    )
30}
31
32struct Shared<T> {
33    buf: Box<[UnsafeCell<MaybeUninit<T>>]>,
34    consumer: AtomicWaker,
35    producer: AtomicWaker,
36    count: AtomicUsize,
37    closed: AtomicBool,
38    capacity: usize,
39}
40
41unsafe impl<T: Send> Send for Shared<T> {}
42unsafe impl<T: Send> Sync for Shared<T> {}
43
44impl<T> Shared<T> {
45    pub(crate) fn new(capacity: usize) -> Self {
46        let capacity = std::cmp::max(capacity + 1, 2);
47        let buf = (0..capacity)
48            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
49            .collect();
50
51        Self {
52            buf,
53            consumer: Default::default(),
54            producer: Default::default(),
55            closed: AtomicBool::new(false),
56            count: AtomicUsize::new(0),
57            capacity,
58        }
59    }
60
61    #[inline]
62    pub(crate) fn index(&self, index: usize) -> usize {
63        index % self.capacity
64    }
65
66    #[inline]
67    pub(crate) fn len(&self) -> usize {
68        self.count.load(Ordering::Relaxed)
69    }
70
71    #[inline]
72    pub(crate) fn is_empty(&self) -> bool {
73        self.len() == 0
74    }
75
76    #[inline]
77    pub(crate) fn is_full(&self) -> bool {
78        self.len() == self.capacity
79    }
80
81    #[inline]
82    pub(crate) unsafe fn get_unchecked(&self, idx: usize) -> T {
83        let ptr = self.buf.as_ptr();
84
85        unsafe { (&*ptr.add(idx)).get().read().assume_init() }
86    }
87
88    #[inline]
89    pub(crate) unsafe fn set_unchecked(&self, idx: usize, value: T) {
90        unsafe {
91            self.buf
92                .get_unchecked(idx)
93                .get()
94                .write(MaybeUninit::new(value))
95        };
96    }
97}
98
99pub struct Sender<T> {
100    shared: Arc<Shared<T>>,
101    pos: usize,
102}
103
104impl<T> Sender<T> {
105    /// Returns whether this channel is closed.
106    #[inline]
107    pub fn is_closed(&self) -> bool {
108        self.shared.closed.load(Ordering::Relaxed)
109    }
110
111    #[inline]
112    pub fn close(&mut self) {
113        self.shared.closed.store(true, Ordering::Relaxed)
114    }
115
116    #[inline]
117    pub fn try_send(&mut self, item: T) -> Result<(), TrySendError<T>> {
118        self.try_send_inner(item, true)
119    }
120
121    #[inline]
122    pub fn start_send(&mut self, item: T) -> Result<(), TrySendError<T>> {
123        self.try_send_inner(item, false)
124    }
125
126    fn try_send_inner(&mut self, item: T, wake: bool) -> Result<(), TrySendError<T>> {
127        if self.is_closed() {
128            return Err(TrySendError {
129                err: SendError::Disconnected,
130                val: item,
131            });
132        }
133
134        if let Some(idx) = self.next_idx() {
135            unsafe {
136                self.shared.set_unchecked(idx, item);
137            }
138
139            if wake {
140                self.shared.consumer.wake();
141            }
142
143            Ok(())
144        } else {
145            Err(TrySendError {
146                err: SendError::Full,
147                val: item,
148            })
149        }
150    }
151
152    #[inline]
153    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError>> {
154        if self.shared.is_full() {
155            self.poll_flush(cx)
156        } else {
157            Poll::Ready(Ok(()))
158        }
159    }
160
161    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError>> {
162        if self.is_closed() {
163            Poll::Ready(Err(SendError::Disconnected))
164        } else if self.shared.is_empty() {
165            // if the inner bounded is already empty,
166            // we just return ok to avoid some atomic operation.
167            Poll::Ready(Ok(()))
168        } else {
169            self.shared.producer.register(cx.waker());
170            self.shared.consumer.wake();
171            Poll::Pending
172        }
173    }
174
175    #[inline]
176    pub async fn flush(&mut self) -> Result<(), SendError> {
177        poll_fn(|cx| self.poll_flush(cx)).await
178    }
179
180    pub async fn send(&mut self, item: T) -> Result<(), TrySendError<T>> {
181        let idx = match poll_fn(|cx| self.poll_next_pos(cx)).await {
182            Ok(idx) => idx,
183            Err(err) => return Err(TrySendError { err, val: item }),
184        };
185
186        unsafe {
187            self.shared.set_unchecked(idx, item);
188        }
189
190        self.shared.consumer.wake();
191
192        Ok(())
193    }
194
195    fn poll_next_pos(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize, SendError>> {
196        if self.is_closed() {
197            return Poll::Ready(Err(SendError::Disconnected));
198        }
199
200        if let Some(idx) = self.next_idx() {
201            Poll::Ready(Ok(idx))
202        } else {
203            self.shared.producer.register(cx.waker());
204
205            // We need to poll again, in case of the receiver take some items during
206            // the register and the previous poll
207            if let Some(idx) = self.next_idx() {
208                Poll::Ready(Ok(idx))
209            } else {
210                Poll::Pending
211            }
212        }
213    }
214
215    #[inline]
216    fn next_idx(&mut self) -> Option<usize> {
217        if self.shared.is_full() {
218            None
219        } else {
220            let idx = self.pos;
221            self.pos += 1;
222            self.shared.count.fetch_add(1, Ordering::Relaxed);
223            Some(self.shared.index(idx))
224        }
225    }
226}
227
228impl<T> Drop for Sender<T> {
229    fn drop(&mut self) {
230        // we need to wake up the receiver before
231        // the sender was totally dropped, otherwise the receiver may hang up.
232        self.shared.closed.store(true, Ordering::Relaxed);
233        self.shared.consumer.wake();
234    }
235}
236
237pub struct Receiver<T> {
238    shared: Arc<Shared<T>>,
239    pos: usize,
240}
241
242impl<T> Receiver<T> {
243    pub fn try_recv(&mut self) -> Result<Option<T>, TryRecvError> {
244        match self.try_pop() {
245            None => {
246                // If there is no item in this bounded, we need to
247                // check closed and try pop again.
248                //
249                // Consider this situation:
250                // receiver try pop first, and sender send an item then close.
251                // If we just check closed without pop again, the remaining item will be lost.
252                if self.is_closed() {
253                    match self.try_pop() {
254                        None => Err(TryRecvError::Disconnected),
255                        Some(item) => Ok(Some(item)),
256                    }
257                } else {
258                    Ok(None)
259                }
260            }
261            Some(item) => Ok(Some(item)),
262        }
263    }
264
265    pub fn poll_want_recv(&mut self, cx: &mut Context<'_>) -> Poll<()> {
266        if self.is_closed() {
267            return Poll::Ready(());
268        }
269
270        self.shared.consumer.register(cx.waker());
271        self.shared.producer.wake();
272
273        if self.shared.is_empty() {
274            Poll::Pending
275        } else {
276            Poll::Ready(())
277        }
278    }
279
280    #[inline]
281    pub async fn want_recv(&mut self) {
282        poll_fn(|cx| self.poll_want_recv(cx)).await
283    }
284
285    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
286        if let Poll::Ready(op) = self.poll_next_msg() {
287            return Poll::Ready(Some(op));
288        }
289
290        self.shared.consumer.register(cx.waker());
291
292        // 1. We need to poll again,
293        //    in case of some item was sent between the registering and the previous poll.
294        //
295        // 2. We need to see whether this channel is closed. Because the sender could
296        //    be closed and wake receiver before the register operation, so if we don't check close,
297        //    this method may return Pending and will never be wakeup.
298        if self.is_closed() {
299            match self.poll_next_msg() {
300                Poll::Ready(op) => Poll::Ready(Some(op)),
301                Poll::Pending => Poll::Ready(None),
302            }
303        } else {
304            self.poll_next_msg().map(Some)
305        }
306    }
307
308    #[inline]
309    pub async fn recv(&mut self) -> Option<T> {
310        poll_fn(|cx| self.poll_recv(cx)).await
311    }
312
313    #[inline]
314    pub fn is_closed(&self) -> bool {
315        self.shared.closed.load(Ordering::Relaxed)
316    }
317
318    #[inline]
319    pub fn close(&mut self) {
320        self.shared.closed.store(true, Ordering::Relaxed)
321    }
322
323    fn poll_next_msg(&mut self) -> Poll<T> {
324        match self.try_pop() {
325            None => Poll::Pending,
326            Some(item) => {
327                self.shared.producer.wake();
328                Poll::Ready(item)
329            }
330        }
331    }
332
333    pub(crate) fn try_pop(&mut self) -> Option<T> {
334        if self.shared.is_empty() {
335            None
336        } else {
337            unsafe {
338                let now = self.pos;
339                let idx = self.shared.index(now);
340                self.pos = now + 1;
341                self.shared.count.fetch_sub(1, Ordering::Relaxed);
342                Some(self.shared.get_unchecked(idx))
343            }
344        }
345    }
346}
347
348impl<T> Stream for Receiver<T> {
349    type Item = T;
350
351    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
352        self.poll_recv(cx)
353    }
354}
355
356impl<T> Drop for Receiver<T> {
357    fn drop(&mut self) {
358        self.close();
359    }
360}