async_unsync/
queue.rs

1use core::{
2    cell::UnsafeCell,
3    cmp,
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll, Waker},
7};
8
9use crate::{
10    alloc::collections::VecDeque,
11    error::{SendError, TryRecvError, TrySendError},
12    mask::Mask,
13    semaphore::Semaphore,
14};
15
16/// Specialization of `UnsyncQueue` for bounded queues.
17pub(crate) type BoundedQueue<T> = UnsyncQueue<T, Bounded>;
18/// Specialization of `UnsyncQueue` for unbounded queues.
19pub(crate) type UnboundedQueue<T> = UnsyncQueue<T, Unbounded>;
20
/// An unsynchronized wrapper for a [`Queue`] using [`UnsafeCell`].
///
/// Because [`UnsafeCell`] is `!Sync`, this wrapper is `!Sync` as well. All
/// access to the inner queue goes through the raw pointer returned by
/// `UnsafeCell::get`. Soundness relies on no reference into the queue being
/// alive while another access happens, and in particular on no reference
/// being held across an `.await`.
pub(crate) struct UnsyncQueue<T, B>(pub(crate) UnsafeCell<Queue<T, B>>);
23
24impl<T, B> UnsyncQueue<T, B>
25where
26    Queue<T, B>: MaybeBoundedQueue<Item = T>,
27{
28    pub(crate) fn into_deque(self) -> VecDeque<T> {
29        self.0.into_inner().queue
30    }
31
32    pub(crate) fn len(&self) -> usize {
33        // SAFETY: no mutable or aliased access to queue possible
34        unsafe { (*self.0.get()).queue.len() }
35    }
36
37    pub(crate) fn close<const COUNTED: bool>(&self) {
38        // SAFETY: no mutable or aliased access to queue possible
39        unsafe { &mut *self.0.get() }.close::<COUNTED>();
40    }
41
42    pub(crate) fn is_closed<const COUNTED: bool>(&self) -> bool {
43        // SAFETY: no mutable or aliased access to queue possible
44        unsafe { (*self.0.get()).mask.is_closed::<COUNTED>() }
45    }
46
47    pub(crate) fn try_recv<const COUNTED: bool>(&self) -> Result<T, TryRecvError> {
48        // SAFETY: no mutable or aliased access to queue possible
49        unsafe { (*self.0.get()).try_recv::<COUNTED>() }
50    }
51
52    pub(crate) fn poll_recv<const COUNTED: bool>(&self, cx: &mut Context<'_>) -> Poll<Option<T>> {
53        // SAFETY: no mutable or aliased access to queue possible
54        unsafe { (*self.0.get()).poll_recv::<COUNTED>(cx) }
55    }
56
57    pub(crate) async fn recv<const COUNTED: bool>(&self) -> Option<T> {
58        RecvFuture::<'_, _, _, COUNTED> { queue: &self.0 }.await
59    }
60}
61
62impl<T> UnsyncQueue<T, Unbounded> {
63    pub(crate) const fn new() -> Self {
64        Self(UnsafeCell::new(Queue::new(VecDeque::new(), Unbounded)))
65    }
66
67    pub(crate) fn with_capacity(capacity: usize) -> Self {
68        Self(UnsafeCell::new(Queue::new(VecDeque::with_capacity(capacity), Unbounded)))
69    }
70
71    pub(crate) fn send<const COUNTED: bool>(&self, elem: T) -> Result<(), SendError<T>> {
72        // SAFETY: no mutable or aliased access to queue possible
73        let queue = unsafe { &mut *self.0.get() };
74
75        // check if the channel was closed
76        if queue.mask.is_closed::<COUNTED>() {
77            return Err(SendError(elem));
78        }
79
80        // ..otherwise push `elem` and wake a potential waiter
81        queue.push_and_wake(elem);
82        Ok(())
83    }
84}
85
86impl<T> FromIterator<T> for UnsyncQueue<T, Unbounded> {
87    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
88        Self(UnsafeCell::new(Queue::new(VecDeque::from_iter(iter), Unbounded)))
89    }
90}
91
/// Panics unless `capacity` is non-zero; a channel with zero capacity could
/// never accept an element.
///
/// This is a `const fn` so it can also guard `const` constructors.
#[cold]
const fn assert_capacity(capacity: usize) {
    if capacity == 0 {
        panic!("channel capacity must be at least 1");
    }
}
96
97impl<T> UnsyncQueue<T, Bounded> {
98    pub(crate) const fn new(capacity: usize) -> Self {
99        assert_capacity(capacity);
100        Self(UnsafeCell::new(Queue::new(
101            VecDeque::new(),
102            Bounded { semaphore: Semaphore::new(capacity), max_capacity: capacity, reserved: 0 },
103        )))
104    }
105
106    pub(crate) fn with_capacity(capacity: usize, initial: usize) -> Self {
107        assert_capacity(capacity);
108        let initial = core::cmp::max(capacity, initial);
109        Self(UnsafeCell::new(Queue::new(
110            VecDeque::with_capacity(initial),
111            Bounded { semaphore: Semaphore::new(capacity), max_capacity: capacity, reserved: 0 },
112        )))
113    }
114
115    pub(crate) fn from_iter(capacity: usize, iter: impl IntoIterator<Item = T>) -> Self {
116        let queue = VecDeque::from_iter(iter);
117        let capacity = cmp::max(queue.len(), capacity);
118        let initial_capacity = capacity - queue.len();
119
120        Self(UnsafeCell::new(Queue::new(
121            queue,
122            Bounded {
123                semaphore: Semaphore::new(initial_capacity),
124                max_capacity: capacity,
125                reserved: 0,
126            },
127        )))
128    }
129
130    pub(crate) fn max_capacity(&self) -> usize {
131        // SAFETY: no mutable or aliased access to queue possible
132        unsafe { (*self.0.get()).extra.max_capacity }
133    }
134
135    pub(crate) fn capacity(&self) -> usize {
136        // SAFETY: no mutable or aliased access to queue possible
137        unsafe { (*self.0.get()).extra.semaphore.available_permits() }
138    }
139
140    pub(crate) fn unbounded_send(&self, elem: T) {
141        // SAFETY: no mutable or aliased access to queue possible
142        let queue = unsafe { &mut *self.0.get() };
143        queue.push_and_wake(elem);
144    }
145
146    pub(crate) fn try_send<const COUNTED: bool>(&self, elem: T) -> Result<(), TrySendError<T>> {
147        // SAFETY: no mutable or aliased access to queue possible
148        let queue = unsafe { &mut *self.0.get() };
149
150        // check if there is room in the channel and the channel is still open
151        let permit = match queue.extra.semaphore.try_acquire() {
152            Ok(permit) => permit,
153            Err(e) => return Err((e, elem).into()),
154        };
155
156        // Forgetting the permit permanently decreases the number of available
157        // permits, which is increased again after the element is dequeued.
158        // The order (i.e., forget first) is somewhat important, because `wake`
159        // might panic (which can be caught), but only after `elem` is pushed.
160        permit.forget();
161        queue.push_and_wake(elem);
162
163        Ok(())
164    }
165
166    /// Performs a bounded send.
167    pub(crate) async fn send<const COUNTED: bool>(&self, elem: T) -> Result<(), SendError<T>> {
168        // try to acquire a free slot in the queue
169        let ptr = self.0.get();
170        // SAFETY: no mutable or aliased access to queue possible (a mutable
171        // reference **MUST NOT** be held across the await!)
172        let Ok(permit) = unsafe { (*ptr).extra.semaphore.acquire() }.await else {
173            return Err(SendError(elem));
174        };
175
176        // Forgetting the permit permanently decreases the number of available
177        // permits, which is increased again after the element is dequeued.
178        // The order, i.e., forget first, is somewhat important, because `wake`
179        // might panic (which can be caught), but only after `elem` is pushed.
180        permit.forget();
181        // SAFETY: no mutable or aliased access to queue possible
182        unsafe { (*ptr).push_and_wake(elem) };
183
184        Ok(())
185    }
186
187    pub(crate) fn try_reserve<const COUNTED: bool>(&self) -> Result<(), TrySendError<()>> {
188        // SAFETY: no mutable or aliased access to queue possible
189        let queue = unsafe { &mut *self.0.get() };
190
191        // check if there is room in the channel and the channel is still open
192        let permit = queue.extra.semaphore.try_acquire()?;
193
194        // Forgetting the permit permanently decreases the number of
195        // available permits. This (semaphore) permit is later "revived"
196        // when the returned (queue/channel) permit is dropped, so that the
197        // semaphore's permit count is correctly increased. This is done to
198        // avoid storing an additional (redundant) reference in the `Permit`
199        // struct.
200        permit.forget();
201        queue.extra.reserved += 1;
202        Ok(())
203    }
204
205    pub(crate) async fn reserve<const COUNTED: bool>(&self) -> Result<(), SendError<()>> {
206        // acquire a free slot in the queue
207        let ptr = self.0.get();
208        // SAFETY: no mutable or aliased access to queue possible (a mutable
209        // reference **MUST NOT** be held across the await!)
210        let Ok(permit) = unsafe { (*ptr).extra.semaphore.acquire() }.await else {
211            return Err(SendError(()));
212        };
213
214        // Forgetting the permit permanently decreases the number of
215        // available permits. This (semaphore) permit is later "revived"
216        // when the returned (queue/channel) permit is dropped, so that the
217        // semaphore's permit count is correctly increased. This is done to
218        // avoid storing an additional (redundant) reference in the `Permit`
219        // struct.
220        permit.forget();
221        unsafe { (*ptr).extra.reserved += 1 };
222        Ok(())
223    }
224
225    pub(crate) fn unreserve(&self, consumed: bool) {
226        // SAFETY: no mutable or aliased access to queue possible
227        let queue = unsafe { &mut (*self.0.get()) };
228        queue.extra.reserved -= 1;
229        if !consumed {
230            queue.extra.semaphore.add_permits(1);
231        }
232    }
233}
234
/// The channel state: the buffered elements, the receiver's waker, the
/// closed/sender-count mask and bounded/unbounded specific extra state.
pub(crate) struct Queue<T, B = Unbounded> {
    /// The mask storing the closed flag and number of active senders.
    pub(crate) mask: Mask,
    /// The buffer storing each element sent through the channel.
    queue: VecDeque<T>,
    /// The number of successful pop operations since the buffer was last
    /// checked for shrinking.
    pop_count: usize,
    /// The waker for the receiver, registered when it polls an empty queue.
    waker: Option<Waker>,
    /// Extra state for bounded or unbounded specialization.
    extra: B,
}
247
248impl<T, B> Queue<T, B> {
249    pub(crate) fn decrease_sender_count(&mut self) {
250        if self.mask.decrease_sender_count() {
251            if let Some(waker) = self.waker.take() {
252                waker.wake();
253            }
254        }
255    }
256
257    const fn new(queue: VecDeque<T>, extra: B) -> Self {
258        Queue { mask: Mask::new(), queue, pop_count: 0, waker: None, extra }
259    }
260
261    /// Pushes `elem` to the back of the queue and wakes the registered
262    /// waker if set.
263    fn push_and_wake(&mut self, elem: T) {
264        self.queue.push_back(elem);
265        if let Some(waker) = &self.waker {
266            waker.wake_by_ref();
267        }
268    }
269
270    /// Pops the first element in the queue and checks if the queue's capacity
271    /// should be shrunk.
272    fn pop_front(&mut self) -> Option<T> {
273        match self.queue.pop_front() {
274            Some(elem) => {
275                // check every 4k ops, if the queue can be shrunk
276                self.pop_count += 1;
277                if self.pop_count == 4096 {
278                    self.try_shrink_queue();
279                }
280
281                Some(elem)
282            }
283            None => {
284                // when the queue first becomes empty, try to shrink it once.
285                if self.pop_count > 0 {
286                    self.try_shrink_queue();
287                }
288
289                None
290            }
291        }
292    }
293
294    /// Shrinks the queue's capacity to `length + 32` if current capacity is
295    /// at least 4 times that.
296    fn try_shrink_queue(&mut self) {
297        let target_capacity = self.queue.len() + 32;
298        if self.queue.capacity() / 4 > (target_capacity) {
299            self.queue.shrink_to(target_capacity);
300        }
301
302        self.pop_count = 0;
303    }
304}
305
306impl<T, B> Queue<T, B>
307where
308    Self: MaybeBoundedQueue<Item = T>,
309{
310    #[cold]
311    pub(crate) fn set_counted(&mut self) {
312        self.reset();
313        self.waker = None;
314        self.mask.reset::<{ crate::mask::COUNTED }>();
315    }
316
317    pub(crate) fn poll_recv<const COUNTED: bool>(
318        &mut self,
319        cx: &mut Context<'_>,
320    ) -> Poll<Option<T>> {
321        match self.try_recv::<COUNTED>() {
322            Ok(elem) => Poll::Ready(Some(elem)),
323            Err(TryRecvError::Disconnected) => Poll::Ready(None),
324            Err(TryRecvError::Empty) => {
325                // this overwrite any previous waker, this is unproblematic if
326                // the same future is polled (spuriously) more than once, but
327                // would like result in one future to stay pending forever if
328                // more than one `RecvFuture`s for one channel with overlapping
329                // lifetimes were to be polled.
330                self.waker = Some(cx.waker().clone());
331                Poll::Pending
332            }
333        }
334    }
335}
336
337impl<T> MaybeBoundedQueue for Queue<T, Unbounded> {
338    type Item = T;
339
340    fn reset(&mut self) {}
341
342    #[cold]
343    fn close<const COUNTED: bool>(&mut self) {
344        self.mask.close::<COUNTED>();
345    }
346
347    fn try_recv<const COUNTED: bool>(&mut self) -> Result<Self::Item, TryRecvError> {
348        match self.pop_front() {
349            Some(elem) => Ok(elem),
350            // the channel is empty, but may also have been closed already
351            None => match self.mask.is_closed::<COUNTED>() {
352                true => Err(TryRecvError::Disconnected),
353                false => Err(TryRecvError::Empty),
354            },
355        }
356    }
357}
358
359impl<T> MaybeBoundedQueue for Queue<T, Bounded> {
360    type Item = T;
361
362    #[cold]
363    fn reset(&mut self) {
364        // this can never underflow, because `permits` is never increased above
365        // the specified `max_capacity`
366        let diff = self.extra.max_capacity - self.extra.semaphore.available_permits();
367        self.extra.semaphore.add_permits(diff);
368    }
369
370    #[cold]
371    fn close<const COUNTED: bool>(&mut self) {
372        // must also close semaphore in order to notify all waiting senders
373        self.mask.close::<COUNTED>();
374        let _ = self.extra.semaphore.close();
375    }
376
377    fn try_recv<const COUNTED: bool>(&mut self) -> Result<Self::Item, TryRecvError> {
378        match self.pop_front() {
379            // an element exists in the channel, wake the next blocked
380            // sender, if any, and return the element
381            Some(elem) => {
382                self.extra.semaphore.add_permits(1);
383                Ok(elem)
384            }
385            // the channel is empty, but may also have been closed already
386            // we must also check, if there are outstanding reserved permits
387            // before the queue can be assessed to be empty
388            None => match self.extra.reserved == 0 && self.mask.is_closed::<COUNTED>() {
389                true => Err(TryRecvError::Disconnected),
390                false => Err(TryRecvError::Empty),
391            },
392        }
393    }
394}
395
396/// A trait abstracting over either *bounded* or *unbounded* queues.
397///
398/// This is declared as public but not exported in the crate's API.
399pub trait MaybeBoundedQueue {
400    /// The type stored in the queue.
401    type Item: Sized;
402
403    /// Resets the available capacity for a bounded queue.
404    fn reset(&mut self);
405
406    /// Closes the queue and notifies all waiters.
407    fn close<const COUNTED: bool>(&mut self);
408
409    /// Dequeues an element from the queue.
410    fn try_recv<const COUNTED: bool>(&mut self) -> Result<Self::Item, TryRecvError>;
411}
412
/// Marker type used as the (stateless) extra state of an unbounded queue.
pub(crate) struct Unbounded;
414
/// Extra state of a bounded queue.
pub(crate) struct Bounded {
    /// The semaphore sequencing the blocked senders. Its available permits
    /// are the slots not taken by queued elements or reservations.
    semaphore: Semaphore,
    /// The channel's capacity, i.e., the maximum number of permits.
    max_capacity: usize,
    /// The number of currently outstanding reservations (slots acquired via
    /// `reserve`/`try_reserve` that have not yet been released by `unreserve`).
    reserved: usize,
}
423
/// The [`Future`] for receiving an element through the channel.
///
/// Each poll delegates to the queue's `poll_recv`, which registers the
/// polling task's waker while the queue is empty. The queue stores only a
/// single waker, so at most one `RecvFuture` per channel should be polled at
/// a time.
pub(crate) struct RecvFuture<'a, T, B, const COUNTED: bool> {
    /// The cell holding the queue to receive from.
    pub(crate) queue: &'a UnsafeCell<Queue<T, B>>,
}
428
429impl<T, B, const COUNTED: bool> Future for RecvFuture<'_, T, B, COUNTED>
430where
431    Queue<T, B>: MaybeBoundedQueue<Item = T>,
432{
433    type Output = Option<T>;
434
435    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
436        let queue = self.get_mut().queue;
437        // SAFETY: no mutable or aliased access to queue possible
438        unsafe { (*queue.get()).poll_recv::<COUNTED>(cx) }
439    }
440}