Skip to main content

queue_ext/
queue_stream.rs

1use std::collections::VecDeque;
2use std::fmt;
3use std::marker::PhantomData;
4use std::marker::Unpin;
5use std::ops::{Deref, DerefMut};
6use std::pin::Pin;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::sync::Mutex;
10use std::task::{Context, Poll};
11
12use futures::task::AtomicWaker;
13use futures::Stream;
14use pin_project_lite::pin_project;
15
16use super::Waker;
17
pin_project! {
    /// A [`Stream`] adapter over a user-supplied queue `Q`.
    ///
    /// Items are obtained by calling the poll function `F` with a pinned reference to
    /// the queue. The waker slot, the parked-waker list and the closed flag are held in
    /// `Arc`s, so clones of the stream share them; they are driven through the `Waker`
    /// trait implementation in this module.
    #[must_use = "streams do nothing unless polled"]
    pub struct QueueStream<Q, Item, F> {
        // The underlying queue; passed to `f` as `Pin<&mut Q>` on every poll and
        // exposed directly through `Deref`/`DerefMut`.
        #[pin]
        q: Q,
        // Poll function that tries to take the next item out of `q`.
        #[pin]
        f: F,
        // Waker of the task waiting in `poll_next`; woken by `rx_wake` and by
        // `close_channel`.
        recv_task: Arc<AtomicWaker>,
        // Wakers handed over via `tx_park`; `poll_next` pops and wakes one each time it
        // yields a value, and `close_channel` wakes from it as well.
        parked_queue: Arc<Mutex<VecDeque<std::task::Waker>>>,
        // Set by `close_channel`; while set, a `Pending` result from `f` ends the stream.
        closed: Arc<AtomicBool>,
        // `Item` is never stored; this marker only ties the type parameter to the struct.
        _item: PhantomData<Item>,
    }
}
31
32unsafe impl<Q, Item, F> Sync for QueueStream<Q, Item, F> {}
33
34unsafe impl<Q, Item, F> Send for QueueStream<Q, Item, F> {}
35
36impl<Q, Item, F> Clone for QueueStream<Q, Item, F>
37where
38    Q: Clone,
39    F: Clone,
40{
41    #[inline]
42    fn clone(&self) -> Self {
43        Self {
44            q: self.q.clone(),
45            f: self.f.clone(),
46            recv_task: self.recv_task.clone(),
47            parked_queue: self.parked_queue.clone(),
48            closed: self.closed.clone(),
49            _item: PhantomData,
50        }
51    }
52}
53
54impl<Q, Item, F> fmt::Debug for QueueStream<Q, Item, F>
55where
56    Q: fmt::Debug,
57{
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        f.debug_struct("QueueStream")
60            .field("queue", &self.q)
61            .finish()
62    }
63}
64
65impl<Q: Unpin, Item, F> QueueStream<Q, Item, F> {
66    #[inline]
67    pub(super) fn new(q: Q, f: F) -> Self {
68        Self {
69            q,
70            f,
71            recv_task: Arc::new(AtomicWaker::new()),
72            parked_queue: Arc::new(Mutex::new(VecDeque::default())),
73            closed: Arc::new(AtomicBool::new(false)),
74            _item: PhantomData,
75        }
76    }
77
78    #[inline]
79    pub fn is_closed(&self) -> bool {
80        self.closed.load(Ordering::SeqCst)
81    }
82}
83
84impl<Q, Item, F> Waker for QueueStream<Q, Item, F> {
85    #[inline]
86    fn rx_wake(&self) {
87        self.recv_task.wake()
88    }
89
90    #[inline]
91    fn tx_park(&self, w: std::task::Waker) {
92        self.parked_queue.lock().unwrap().push_back(w);
93    }
94
95    #[inline]
96    fn close_channel(&self) {
97        if !self.closed.load(Ordering::SeqCst) {
98            self.closed.store(true, Ordering::SeqCst);
99            self.rx_wake();
100            if let Some(w) = self.parked_queue.lock().unwrap().pop_front() {
101                w.wake();
102            }
103        }
104    }
105
106    #[inline]
107    fn is_closed(&self) -> bool {
108        self.closed.load(Ordering::SeqCst)
109    }
110}
111
112impl<Q, Item, F> Stream for QueueStream<Q, Item, F>
113where
114    Q: Unpin,
115    F: Fn(Pin<&mut Q>, &mut Context<'_>) -> Poll<Option<Item>>,
116{
117    type Item = Item;
118
119    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
120        let mut this = self.project();
121        let f = this.f.as_mut();
122        match f(this.q.as_mut(), ctx) {
123            Poll::Ready(msg) => {
124                if let Some(w) = this.parked_queue.lock().unwrap().pop_front() {
125                    w.wake();
126                }
127                Poll::Ready(msg)
128            }
129            Poll::Pending => {
130                if this.closed.load(Ordering::SeqCst) {
131                    Poll::Ready(None)
132                } else {
133                    this.recv_task.register(ctx.waker());
134                    f(this.q.as_mut(), ctx)
135                }
136            }
137        }
138    }
139}
140
141impl<Q, Item, F> Deref for QueueStream<Q, Item, F> {
142    type Target = Q;
143    #[inline]
144    fn deref(&self) -> &Self::Target {
145        &self.q
146    }
147}
148
149impl<Q, Item, F> DerefMut for QueueStream<Q, Item, F> {
150    #[inline]
151    fn deref_mut(&mut self) -> &mut Self::Target {
152        &mut self.q
153    }
154}
155
156#[cfg(test)]
157use futures::pin_mut;
158#[cfg(test)]
159use futures::task::noop_waker;
160#[cfg(test)]
161use std::cell::Cell;
162
/// Minimal queue type for testing QueueStream: a plain FIFO of `i32`s with no
/// synchronization or wake-up logic of its own.
#[cfg(test)]
struct TestQueue {
    // Backing storage; `poll_items` pops from the front.
    items: VecDeque<i32>,
}
168
/// Poll function that always completes: yields the front item, or `None` once the
/// queue is empty. The context is ignored.
#[cfg(test)]
fn poll_items(pin_q: Pin<&mut TestQueue>, _cx: &mut Context<'_>) -> Poll<Option<i32>> {
    let queue: &mut TestQueue = Pin::into_inner(pin_q);
    let next = queue.items.pop_front();
    Poll::Ready(next)
}
173
/// Poll function that never yields: it reports `Pending` on every call and touches
/// neither the queue nor the context.
#[cfg(test)]
fn poll_pending(_pin_q: Pin<&mut TestQueue>, _cx: &mut Context<'_>) -> Poll<Option<i32>> {
    Poll::<Option<i32>>::Pending
}
178
179// ---------------------------------------------------------------------------
180// poll_next
181// ---------------------------------------------------------------------------
182
#[test]
fn poll_next_yields_items() {
    let queue = TestQueue {
        items: VecDeque::from([10, 20, 30]),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(queue, poll_items);
    pin_mut!(stream);

    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    // Items come out in FIFO order, then the exhausted queue ends the stream.
    for expected in [Some(10), Some(20), Some(30), None] {
        assert_eq!(stream.as_mut().poll_next(&mut cx), Poll::Ready(expected));
    }
}
201
#[test]
fn poll_next_none_when_empty() {
    let empty = TestQueue {
        items: VecDeque::new(),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(empty, poll_items);
    pin_mut!(stream);

    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    // An empty queue whose poll function reports `Ready(None)` ends the stream at once.
    let first = stream.as_mut().poll_next(&mut cx);
    assert_eq!(first, Poll::Ready(None));
}
216
217// ---------------------------------------------------------------------------
218// is_closed
219// ---------------------------------------------------------------------------
220
#[test]
fn is_closed_open() {
    let queue = TestQueue {
        items: VecDeque::new(),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(queue, poll_pending);

    let closed = stream.is_closed();
    assert!(!closed, "a freshly created stream must start open");
}
231
#[test]
fn is_closed_after_close() {
    let queue = TestQueue {
        items: VecDeque::new(),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(queue, poll_pending);

    stream.close_channel();

    let closed = stream.is_closed();
    assert!(closed, "close_channel must set the closed flag");
}
243
244// ---------------------------------------------------------------------------
245// Stream termination when channel is closed
246// ---------------------------------------------------------------------------
247
#[test]
fn poll_next_terminates_on_closed() {
    let queue = TestQueue {
        items: VecDeque::new(),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(queue, poll_pending);

    // Close before the first poll: a pending poll function must not keep the stream alive.
    stream.close_channel();

    pin_mut!(stream);
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    let polled = stream.as_mut().poll_next(&mut cx);
    assert_eq!(polled, Poll::Ready(None));
}
263
#[test]
fn close_channel_during_pending_returns_none() {
    // Counts how often the stream consults the poll function.
    let call_count = Cell::new(0u32);
    let poll_fn = |_: Pin<&mut TestQueue>, _cx: &mut Context<'_>| -> Poll<Option<i32>> {
        call_count.set(call_count.get() + 1);
        Poll::Pending
    };

    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(
        TestQueue {
            items: VecDeque::new(),
        },
        poll_fn,
    );

    pin_mut!(stream);
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    // First poll: empty queue, open channel -> Pending, after consulting the poll function.
    assert_eq!(stream.as_mut().poll_next(&mut cx), Poll::Pending);
    let calls_before_close = call_count.get();
    assert!(calls_before_close >= 1, "poll function was never called");

    // Close the channel while the stream is waiting (simulating sender drop).
    stream.close_channel();

    // Second poll: the closed flag ends the stream, but only after re-checking the queue,
    // so items pushed before the close are never silently dropped.
    assert_eq!(stream.as_mut().poll_next(&mut cx), Poll::Ready(None));
    assert!(
        call_count.get() > calls_before_close,
        "closed stream ended without re-checking the queue"
    );
}
292
293// ---------------------------------------------------------------------------
294// Waker registration on pending
295// ---------------------------------------------------------------------------
296
#[test]
fn waker_registered_on_pending() {
    /// Waker that records whether it has been woken.
    struct FlagWaker(Arc<AtomicBool>);

    impl futures::task::ArcWake for FlagWaker {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.0.store(true, Ordering::SeqCst);
        }
    }

    let woken = Arc::new(AtomicBool::new(false));
    let waker = futures::task::waker(Arc::new(FlagWaker(Arc::clone(&woken))));

    let always_pending =
        |_: Pin<&mut TestQueue>, _cx: &mut Context<'_>| -> Poll<Option<i32>> { Poll::Pending };
    let queue = TestQueue {
        items: VecDeque::new(),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(queue, always_pending);
    pin_mut!(stream);

    // Polling an empty, open stream must leave our waker registered...
    let mut cx = Context::from_waker(&waker);
    let _ = stream.as_mut().poll_next(&mut cx);

    // ...so that a sender signalling via rx_wake() reaches it.
    stream.rx_wake();
    assert!(woken.load(Ordering::SeqCst));
}
335
336// ---------------------------------------------------------------------------
337// Waker trait impl on QueueStream
338// ---------------------------------------------------------------------------
339
#[test]
fn queue_stream_implements_waker() {
    // Compiles only if `QueueStream` implements the parent module's `Waker` trait.
    fn assert_waker<T: super::Waker>(_t: &T) {}

    let queue = TestQueue {
        items: VecDeque::new(),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(queue, poll_pending);
    assert_waker(&stream);
}
352
#[test]
fn queue_stream_waker_tx_park_and_wake() {
    /// Waker that records whether it has been woken.
    struct FlagWaker(Arc<AtomicBool>);

    impl futures::task::ArcWake for FlagWaker {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.0.store(true, Ordering::SeqCst);
        }
    }

    let queue = TestQueue {
        items: VecDeque::new(),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(queue, poll_pending);

    let woken = Arc::new(AtomicBool::new(false));
    stream.tx_park(futures::task::waker(Arc::new(FlagWaker(Arc::clone(&woken)))));

    // Parking alone must not wake anything...
    assert!(!woken.load(Ordering::SeqCst));
    // ...but closing the channel must wake the parked waker.
    stream.close_channel();
    assert!(woken.load(Ordering::SeqCst));
}
385
386// ---------------------------------------------------------------------------
387// Send / Sync compile check
388// ---------------------------------------------------------------------------
389
#[test]
fn queue_stream_is_send_sync() {
    // Compiles only if the argument's type is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>(_t: &T) {}

    let queue = TestQueue {
        items: VecDeque::new(),
    };
    let stream: QueueStream<TestQueue, i32, _> = QueueStream::new(queue, poll_pending);
    assert_send_sync(&stream);
}