flowly_spsc/
lib.rs

1mod atomic_waker;
2mod error;
3
4use std::{
5    cell::UnsafeCell,
6    future::poll_fn,
7    mem::MaybeUninit,
8    pin::Pin,
9    sync::{
10        Arc,
11        atomic::{AtomicBool, AtomicUsize, Ordering},
12    },
13    task::{Context, Poll},
14};
15
16use atomic_waker::AtomicWaker;
17pub use error::{SendError, TryRecvError, TrySendError};
18use futures::Stream;
19
20pub fn channel<T>(size: usize) -> (Sender<T>, Receiver<T>) {
21    let shared = Arc::new(Shared::new(size));
22
23    (
24        Sender {
25            shared: shared.clone(),
26            pos: 0,
27        },
28        Receiver { shared, pos: 0 },
29    )
30}
31
32struct Shared<T> {
33    buf: Box<[UnsafeCell<MaybeUninit<T>>]>,
34    consumer: AtomicWaker,
35    producer: AtomicWaker,
36    count: AtomicUsize,
37    closed: AtomicBool,
38    capacity: usize,
39}
40
41unsafe impl<T: Send> Send for Shared<T> {}
42unsafe impl<T: Sync> Sync for Shared<T> {}
43
44impl<T> Shared<T> {
45    pub(crate) fn new(capacity: usize) -> Self {
46        let capacity = std::cmp::max(capacity + 1, 2);
47        let buf = (0..capacity)
48            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
49            .collect();
50
51        Self {
52            buf,
53            consumer: Default::default(),
54            producer: Default::default(),
55            closed: AtomicBool::new(false),
56            count: AtomicUsize::new(0),
57            capacity,
58        }
59    }
60
61    #[inline]
62    pub(crate) fn index(&self, index: usize) -> usize {
63        index % self.capacity
64    }
65
66    #[inline]
67    pub(crate) fn len(&self) -> usize {
68        self.count.load(Ordering::Relaxed)
69    }
70
71    #[inline]
72    pub(crate) fn is_empty(&self) -> bool {
73        self.len() == 0
74    }
75
76    #[inline]
77    pub(crate) fn is_full(&self) -> bool {
78        self.len() == self.capacity
79    }
80
81    #[inline]
82    pub(crate) unsafe fn get_unchecked(&self, idx: usize) -> T {
83        let ptr = self.buf.as_ptr();
84
85        unsafe { (&*ptr.add(idx)).get().read().assume_init() }
86    }
87
88    #[inline]
89    pub(crate) unsafe fn set_unchecked(&self, idx: usize, value: T) {
90        unsafe {
91            self.buf
92                .get_unchecked(idx)
93                .get()
94                .write(MaybeUninit::new(value))
95        };
96    }
97}
98
99pub struct Sender<T> {
100    shared: Arc<Shared<T>>,
101    pos: usize,
102}
103
104impl<T> Sender<T> {
105    /// Returns whether this channel is closed.
106    #[inline]
107    pub fn is_closed(&self) -> bool {
108        self.shared.closed.load(Ordering::Relaxed)
109    }
110
111    pub fn start_send(&mut self, item: T) -> Result<(), TrySendError<T>> {
112        if self.is_closed() {
113            return Err(TrySendError {
114                err: SendError::Disconnected,
115                val: item,
116            });
117        }
118
119        if let Some(idx) = self.next_idx() {
120            unsafe {
121                self.shared.set_unchecked(idx, item);
122            }
123
124            Ok(())
125        } else {
126            Err(TrySendError {
127                err: SendError::Full,
128                val: item,
129            })
130        }
131    }
132
133    #[inline]
134    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError>> {
135        if self.shared.is_full() {
136            self.poll_flush(cx)
137        } else {
138            Poll::Ready(Ok(()))
139        }
140    }
141
142    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError>> {
143        if self.is_closed() {
144            Poll::Ready(Err(SendError::Disconnected))
145        } else if self.shared.is_empty() {
146            // if the inner bounded is already empty,
147            // we just return ok to avoid some atomic operation.
148            Poll::Ready(Ok(()))
149        } else {
150            self.shared.producer.register(cx.waker());
151            self.shared.consumer.wake();
152            Poll::Pending
153        }
154    }
155
156    #[inline]
157    pub async fn flush(&mut self) -> Result<(), SendError> {
158        poll_fn(|cx| self.poll_flush(cx)).await
159    }
160
161    pub async fn send(&mut self, item: T) -> Result<(), TrySendError<T>> {
162        let idx = match poll_fn(|cx| self.poll_next_pos(cx)).await {
163            Ok(idx) => idx,
164            Err(err) => return Err(TrySendError { err, val: item }),
165        };
166
167        unsafe {
168            self.shared.set_unchecked(idx, item);
169        }
170
171        self.shared.consumer.wake();
172
173        Ok(())
174    }
175
176    fn poll_next_pos(&mut self, cx: &mut Context<'_>) -> Poll<Result<usize, SendError>> {
177        if self.is_closed() {
178            return Poll::Ready(Err(SendError::Disconnected));
179        }
180
181        if let Some(idx) = self.next_idx() {
182            Poll::Ready(Ok(idx))
183        } else {
184            self.shared.producer.register(cx.waker());
185
186            // We need to poll again, in case of the receiver take some items during
187            // the register and the previous poll
188            if let Some(idx) = self.next_idx() {
189                Poll::Ready(Ok(idx))
190            } else {
191                Poll::Pending
192            }
193        }
194    }
195
196    #[inline]
197    fn next_idx(&mut self) -> Option<usize> {
198        if self.shared.is_full() {
199            None
200        } else {
201            let idx = self.pos;
202            self.pos += 1;
203            self.shared.count.fetch_add(1, Ordering::Relaxed);
204            Some(self.shared.index(idx))
205        }
206    }
207}
208
209impl<T> Drop for Sender<T> {
210    fn drop(&mut self) {
211        // we need to wake up the receiver before
212        // the sender was totally dropped, otherwise the receiver may hang up.
213        self.shared.closed.store(true, Ordering::Relaxed);
214        self.shared.consumer.wake();
215    }
216}
217
218pub struct Receiver<T> {
219    shared: Arc<Shared<T>>,
220    pos: usize,
221}
222
223impl<T> Receiver<T> {
224    pub fn try_recv(&mut self) -> Result<Option<T>, TryRecvError> {
225        match self.try_pop() {
226            None => {
227                // If there is no item in this bounded, we need to
228                // check closed and try pop again.
229                //
230                // Consider this situation:
231                // receiver try pop first, and sender send an item then close.
232                // If we just check closed without pop again, the remaining item will be lost.
233                if self.is_closed() {
234                    match self.try_pop() {
235                        None => Err(TryRecvError::Disconnected),
236                        Some(item) => Ok(Some(item)),
237                    }
238                } else {
239                    Ok(None)
240                }
241            }
242            Some(item) => Ok(Some(item)),
243        }
244    }
245
246    pub fn poll_want_recv(&mut self, cx: &mut Context<'_>) -> Poll<()> {
247        if self.is_closed() {
248            return Poll::Ready(());
249        }
250
251        self.shared.consumer.register(cx.waker());
252        self.shared.producer.wake();
253
254        if self.shared.is_empty() {
255            Poll::Pending
256        } else {
257            Poll::Ready(())
258        }
259    }
260
261    #[inline]
262    pub async fn want_recv(&mut self) {
263        poll_fn(|cx| self.poll_want_recv(cx)).await
264    }
265
266    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
267        if let Poll::Ready(op) = self.poll_next_msg() {
268            return Poll::Ready(Some(op));
269        }
270
271        self.shared.consumer.register(cx.waker());
272
273        // 1. We need to poll again,
274        //    in case of some item was sent between the registering and the previous poll.
275        //
276        // 2. We need to see whether this channel is closed. Because the sender could
277        //    be closed and wake receiver before the register operation, so if we don't check close,
278        //    this method may return Pending and will never be wakeup.
279        if self.is_closed() {
280            match self.poll_next_msg() {
281                Poll::Ready(op) => Poll::Ready(Some(op)),
282                Poll::Pending => Poll::Ready(None),
283            }
284        } else {
285            self.poll_next_msg().map(Some)
286        }
287    }
288
289    #[inline]
290    pub async fn recv(&mut self) -> Option<T> {
291        poll_fn(|cx| self.poll_recv(cx)).await
292    }
293
294    #[inline]
295    pub fn is_closed(&self) -> bool {
296        self.shared.closed.load(Ordering::Relaxed)
297    }
298
299    #[inline]
300    pub fn close(&mut self) {
301        self.shared.closed.store(true, Ordering::Relaxed)
302    }
303
304    fn poll_next_msg(&mut self) -> Poll<T> {
305        match self.try_pop() {
306            None => Poll::Pending,
307            Some(item) => {
308                self.shared.producer.wake();
309                Poll::Ready(item)
310            }
311        }
312    }
313
314    pub(crate) fn try_pop(&mut self) -> Option<T> {
315        if self.shared.is_empty() {
316            None
317        } else {
318            unsafe {
319                let now = self.pos;
320                let idx = self.shared.index(now);
321                self.pos = now + 1;
322                self.shared.count.fetch_sub(1, Ordering::Relaxed);
323                Some(self.shared.get_unchecked(idx))
324            }
325        }
326    }
327}
328
329impl<T> Stream for Receiver<T> {
330    type Item = T;
331
332    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
333        self.poll_recv(cx)
334    }
335}
336
337impl<T> Drop for Receiver<T> {
338    fn drop(&mut self) {
339        self.close();
340    }
341}