Skip to main content

lance_core/utils/
futures.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    collections::VecDeque,
6    sync::{Arc, Mutex},
7    task::Waker,
8};
9
10use futures::{Stream, StreamExt, stream::BoxStream};
11use pin_project::{pin_project, pinned_drop};
12use tokio::sync::Semaphore;
13use tokio_util::sync::PollSemaphore;
14
/// Identifies which of the two halves of a [`SharedStream`] a handle represents.
///
/// The stream implementation uses this to pick the matching buffered-item
/// counter and to record which half is currently polling the inner stream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Side {
    /// The first half returned by [`SharedStream::new`].
    Left,
    /// The second half returned by [`SharedStream::new`].
    Right,
}
20
/// A potentially unbounded capacity
///
/// Used by [`SharedStream::new`] to decide whether the number of buffered
/// items is limited.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Capacity {
    /// At most this many items may be buffered at once.
    Bounded(u32),
    /// No limit is placed on the number of buffered items.
    Unbounded,
}
27
/// State shared (behind a mutex) by the two halves of a [`SharedStream`].
struct InnerState<'a, T> {
    /// The wrapped stream.  Temporarily taken (left as `None`) while one half
    /// polls it with the mutex released, and put back afterwards.
    inner: Option<BoxStream<'a, T>>,
    /// Items that one half has already pulled from `inner` and the other half
    /// has not consumed yet.
    buffer: VecDeque<T>,
    /// The half that currently owns polling of `inner`, if any.
    polling: Option<Side>,
    /// Waker of the half that found the other half already polling `inner`.
    waker: Option<Waker>,
    /// Set once `inner` has yielded `None`.
    exhausted: bool,
    /// Number of items in `buffer` waiting to be read by the left half.
    left_buffered: u32,
    /// Number of items in `buffer` waiting to be read by the right half.
    right_buffered: u32,
    /// Tracks the remaining buffer capacity; `None` when the capacity is unbounded.
    available_buffer: Option<PollSemaphore>,
}
38
39/// A stream that can be shared between two consumers.
40pub struct SharedStream<'a, T: Clone> {
41    state: Arc<Mutex<InnerState<'a, T>>>,
42    side: Side,
43}
44
45impl<'a, T: Clone> SharedStream<'a, T> {
46    pub fn new(inner: BoxStream<'a, T>, capacity: Capacity) -> (Self, Self) {
47        let available_buffer = match capacity {
48            Capacity::Unbounded => None,
49            Capacity::Bounded(capacity) => Some(PollSemaphore::new(Arc::new(Semaphore::new(
50                capacity as usize,
51            )))),
52        };
53        let state = InnerState {
54            inner: Some(inner),
55            buffer: VecDeque::new(),
56            polling: None,
57            waker: None,
58            exhausted: false,
59            left_buffered: 0,
60            right_buffered: 0,
61            available_buffer,
62        };
63
64        let state = Arc::new(Mutex::new(state));
65
66        let left = Self {
67            state: state.clone(),
68            side: Side::Left,
69        };
70        let right = Self {
71            state,
72            side: Side::Right,
73        };
74        (left, right)
75    }
76}
77
impl<T: Clone> Stream for SharedStream<'_, T> {
    type Item = T;

    /// Yields the next item for this half of the shared stream.
    ///
    /// If the other half has already pulled items this half has not seen yet,
    /// the oldest one is returned from the shared buffer.  Otherwise this half
    /// polls the inner stream itself (unless the other half is already doing
    /// so) and, on success, stores a clone of the item for the other half.
    ///
    /// Once the inner stream has yielded `None`, every later call returns
    /// `None` after any remaining buffered items have been drained.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned, or if the inner stream is
    /// missing because the other half panicked while polling it.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut inner_state = self.state.lock().unwrap();
        // Does the buffer hold items that the other side fetched for us?
        let can_take_buffered = match self.side {
            Side::Left => inner_state.left_buffered > 0,
            Side::Right => inner_state.right_buffered > 0,
        };
        if can_take_buffered {
            // Easy case, there is an item in the buffer.  Grab it, decrement the count, and return it.
            let item = inner_state.buffer.pop_front();
            match self.side {
                Side::Left => {
                    inner_state.left_buffered -= 1;
                }
                Side::Right => {
                    inner_state.right_buffered -= 1;
                }
            }
            // The slot we just emptied becomes available to the side that is ahead.
            if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
                available_buffer.add_permits(1);
            }
            std::task::Poll::Ready(item)
        } else {
            // Nothing buffered for us and the inner stream has ended: we are done.
            if inner_state.exhausted {
                return std::task::Poll::Ready(None);
            }
            // No buffered items, if we have room in the buffer, then try and poll for one
            let permit = if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
                match available_buffer.poll_acquire(cx) {
                    // Can return None if the semaphore is closed but we never close the semaphore
                    // so it's safe to unwrap here
                    std::task::Poll::Ready(permit) => Some(permit.unwrap()),
                    std::task::Poll::Pending => {
                        // The buffer is full; wait until the other side consumes an item.
                        return std::task::Poll::Pending;
                    }
                }
            } else {
                // Unbounded capacity: no permit is needed.
                None
            };
            if let Some(polling_side) = inner_state.polling.as_ref()
                && *polling_side != self.side
            {
                // Another task is already polling the inner stream, so we don't need to do anything

                // Per rust docs:
                //   Note that on multiple calls to poll, only the Waker from the Context
                //   passed to the most recent call should be scheduled to receive a wakeup.
                //
                // So it is safe to replace a potentially stale waker here.
                inner_state.waker = Some(cx.waker().clone());
                return std::task::Poll::Pending;
            }
            // Claim the poll so the other side defers to us until it completes.
            inner_state.polling = Some(self.side);
            // Release the mutex here as polling the inner stream is potentially expensive
            let mut to_poll = inner_state
                .inner
                .take()
                .expect("Other half of shared stream panic'd while polling inner stream");
            drop(inner_state);
            let res = to_poll.poll_next_unpin(cx);
            let mut inner_state = self.state.lock().unwrap();

            let mut should_wake = true;
            match &res {
                std::task::Poll::Ready(None) => {
                    // End of the inner stream; both sides will observe this via `exhausted`.
                    inner_state.exhausted = true;
                    inner_state.polling = None;
                }
                std::task::Poll::Ready(Some(item)) => {
                    // We got an item, forget the permit to mark that we can take one fewer items
                    if let Some(permit) = permit {
                        permit.forget();
                    }
                    inner_state.polling = None;
                    // Let the other side know an item is available
                    match self.side {
                        Side::Left => {
                            inner_state.right_buffered += 1;
                        }
                        Side::Right => {
                            inner_state.left_buffered += 1;
                        }
                    };
                    inner_state.buffer.push_back(item.clone());
                }
                std::task::Poll::Pending => {
                    // `polling` deliberately stays set to this side: the inner stream holds
                    // our waker, so we are the one that must poll it again.
                    //
                    // NOTE(review): no `Drop` impl for `SharedStream` is visible in this
                    // file, so if this half is dropped while it still owns `polling`, the
                    // other half appears to keep returning `Pending` — confirm.
                    should_wake = false;
                }
            };

            // Put the inner stream back so that either side can poll it next time.
            inner_state.inner = Some(to_poll);

            // If the other side was waiting for us to poll, wake them up, but only after we release the mutex
            let to_wake = if should_wake {
                inner_state.waker.take()
            } else {
                // If the inner stream is pending then the inner stream will wake us up and we will wake the
                // other side up then.
                None
            };
            drop(inner_state);
            if let Some(waker) = to_wake {
                waker.wake();
            }
            res
        }
    }
}
191
/// Extension trait that adds [`share`](SharedStreamExt::share) to streams
/// whose items can be cloned.
pub trait SharedStreamExt<'a>: Stream + Send
where
    Self::Item: Clone,
{
    /// Split a stream into two shared streams
    ///
    /// Each shared stream will return the full set of items from the underlying stream.
    /// This works by buffering the items from the underlying stream and then replaying
    /// them to the other side.
    ///
    /// The capacity parameter controls how many items can be buffered at once.  Be careful
    /// with the capacity parameter as it can lead to deadlock if the two streams are not
    /// polled evenly.
    ///
    /// If the capacity is unbounded then the stream could potentially buffer the entire
    /// input stream in memory.
    fn share(
        self,
        capacity: Capacity,
    ) -> (SharedStream<'a, Self::Item>, SharedStream<'a, Self::Item>);
}
213
214impl<'a, T: Clone> SharedStreamExt<'a> for BoxStream<'a, T> {
215    fn share(self, capacity: Capacity) -> (SharedStream<'a, T>, SharedStream<'a, T>) {
216        SharedStream::new(self, capacity)
217    }
218}
219
220#[pin_project]
221pub struct FinallyStream<S: Stream, F: FnOnce()> {
222    #[pin]
223    stream: S,
224    f: Option<F>,
225}
226
227impl<S: Stream, F: FnOnce()> FinallyStream<S, F> {
228    pub fn new(stream: S, f: F) -> Self {
229        Self { stream, f: Some(f) }
230    }
231}
232
233impl<S: Stream, F: FnOnce()> Stream for FinallyStream<S, F> {
234    type Item = S::Item;
235
236    fn poll_next(
237        self: std::pin::Pin<&mut Self>,
238        cx: &mut std::task::Context<'_>,
239    ) -> std::task::Poll<Option<Self::Item>> {
240        let this = self.project();
241        let res = this.stream.poll_next(cx);
242        if matches!(res, std::task::Poll::Ready(None)) {
243            // It's possible that None is polled multiple times, but we only call the function once
244            if let Some(f) = this.f.take() {
245                f();
246            }
247        }
248        res
249    }
250}
251
252pub trait FinallyStreamExt<S: Stream>: Stream + Sized {
253    fn finally<F: FnOnce()>(self, f: F) -> FinallyStream<Self, F> {
254        FinallyStream {
255            stream: self,
256            f: Some(f),
257        }
258    }
259}
260
261impl<S: Stream> FinallyStreamExt<S> for S {
262    fn finally<F: FnOnce()>(self, f: F) -> FinallyStream<Self, F> {
263        FinallyStream::new(self, f)
264    }
265}
266
267/// A stream wrapper that calls a function when dropped.
268///
269/// Unlike [`FinallyStream`], which fires when the inner stream yields `None`,
270/// this fires when the wrapper is dropped — even if the stream was not fully
271/// consumed.
272#[pin_project(PinnedDrop)]
273pub struct OnDropStream<S: Stream, F: FnOnce()> {
274    #[pin]
275    stream: S,
276    f: Option<F>,
277}
278
279impl<S: Stream, F: FnOnce()> OnDropStream<S, F> {
280    pub fn new(stream: S, f: F) -> Self {
281        Self { stream, f: Some(f) }
282    }
283}
284
285impl<S: Stream, F: FnOnce()> Stream for OnDropStream<S, F> {
286    type Item = S::Item;
287
288    fn poll_next(
289        self: std::pin::Pin<&mut Self>,
290        cx: &mut std::task::Context<'_>,
291    ) -> std::task::Poll<Option<Self::Item>> {
292        self.project().stream.poll_next(cx)
293    }
294}
295
296#[pinned_drop]
297impl<S: Stream, F: FnOnce()> PinnedDrop for OnDropStream<S, F> {
298    fn drop(self: std::pin::Pin<&mut Self>) {
299        let this = self.project();
300        if let Some(f) = this.f.take() {
301            f();
302        }
303    }
304}
305
306pub trait StreamOnDropExt: Stream + Sized {
307    /// Wrap this stream so that `f` is called when the stream is dropped.
308    fn on_drop<F: FnOnce()>(self, f: F) -> OnDropStream<Self, F> {
309        OnDropStream::new(self, f)
310    }
311}
312
313impl<S: Stream> StreamOnDropExt for S {}
314
315#[cfg(test)]
316mod tests {
317
318    use std::sync::Arc;
319    use std::sync::atomic::{AtomicBool, Ordering};
320
321    use futures::{FutureExt, StreamExt};
322    use tokio_stream::wrappers::ReceiverStream;
323
324    use crate::utils::futures::{Capacity, SharedStreamExt, StreamOnDropExt};
325
326    fn is_pending(fut: &mut (impl std::future::Future + Unpin)) -> bool {
327        let noop_waker = futures::task::noop_waker();
328        let mut context = std::task::Context::from_waker(&noop_waker);
329        fut.poll_unpin(&mut context).is_pending()
330    }
331
    /// Walks a bounded (capacity 2) shared stream through its main states:
    /// one side running ahead until the buffer is full, the other side
    /// catching up, both sides waiting on the source, and end of stream.
    #[tokio::test]
    async fn test_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        // Feed in a few items
        for i in 0..3 {
            tx.send(i).await.unwrap();
        }

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

        // We should be able to immediately poll 2 items
        assert_eq!(left.next().await.unwrap(), 0);
        assert_eq!(left.next().await.unwrap(), 1);

        // Polling again should block because the right side has fallen behind
        let mut left_fut = left.next();

        assert!(is_pending(&mut left_fut));

        // Polling the right side should yield the first cached item and unblock the left
        assert_eq!(right.next().await.unwrap(), 0);
        assert_eq!(left_fut.await.unwrap(), 2);

        // Drain the rest of the stream from the right
        assert_eq!(right.next().await.unwrap(), 1);
        assert_eq!(right.next().await.unwrap(), 2);

        // The channel isn't closed yet so we should get pending on both sides
        let mut right_fut = right.next();
        let mut left_fut = left.next();
        assert!(is_pending(&mut right_fut));
        assert!(is_pending(&mut left_fut));

        // Send one more item
        tx.send(3).await.unwrap();

        // Should be received by both
        assert_eq!(right_fut.await.unwrap(), 3);
        assert_eq!(left_fut.await.unwrap(), 3);

        // Dropping the only sender closes the channel, which ends the inner stream
        drop(tx);

        // Now we should be able to poll the end from either side
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);

        // We should be self-fused
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);
    }
384
385    #[tokio::test]
386    async fn test_unbounded_shared_stream() {
387        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
388        let inner_stream = ReceiverStream::new(rx);
389
390        // Feed in a few items
391        for i in 0..10 {
392            tx.send(i).await.unwrap();
393        }
394        drop(tx);
395
396        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Unbounded);
397
398        // We should be able to completely drain one side
399        for i in 0..10 {
400            assert_eq!(left.next().await.unwrap(), i);
401        }
402        assert_eq!(left.next().await, None);
403
404        // And still drain the other side from the buffer
405        for i in 0..10 {
406            assert_eq!(right.next().await.unwrap(), i);
407        }
408        assert_eq!(right.next().await, None);
409    }
410
411    #[tokio::test(flavor = "multi_thread")]
412    async fn stress_shared_stream() {
413        for _ in 0..100 {
414            let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
415            let inner_stream = ReceiverStream::new(rx);
416            let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));
417
418            let left_handle = tokio::spawn(async move {
419                let mut counter = 0;
420                while let Some(item) = left.next().await {
421                    assert_eq!(item, counter);
422                    counter += 1;
423                }
424            });
425
426            let right_handle = tokio::spawn(async move {
427                let mut counter = 0;
428                while let Some(item) = right.next().await {
429                    assert_eq!(item, counter);
430                    counter += 1;
431                }
432            });
433
434            for i in 0..1000 {
435                tx.send(i).await.unwrap();
436            }
437            drop(tx);
438            left_handle.await.unwrap();
439            right_handle.await.unwrap();
440        }
441    }
442
443    #[tokio::test]
444    async fn test_on_drop_fires_on_early_drop() {
445        let called = Arc::new(AtomicBool::new(false));
446        let called_clone = called.clone();
447
448        let stream = futures::stream::iter(vec![1, 2, 3]);
449        let mut stream = stream.on_drop(move || {
450            called_clone.store(true, Ordering::SeqCst);
451        });
452
453        // Consume only one item, then drop
454        assert_eq!(stream.next().await, Some(1));
455        assert!(!called.load(Ordering::SeqCst));
456        drop(stream);
457        assert!(called.load(Ordering::SeqCst));
458    }
459
460    #[tokio::test]
461    async fn test_on_drop_fires_after_exhaustion() {
462        let called = Arc::new(AtomicBool::new(false));
463        let called_clone = called.clone();
464
465        let stream = futures::stream::iter(vec![1]);
466        let mut stream = stream.on_drop(move || {
467            called_clone.store(true, Ordering::SeqCst);
468        });
469
470        assert_eq!(stream.next().await, Some(1));
471        assert_eq!(stream.next().await, None);
472        assert!(!called.load(Ordering::SeqCst));
473        drop(stream);
474        assert!(called.load(Ordering::SeqCst));
475    }
476
477    #[tokio::test]
478    async fn test_on_drop_fires_without_polling() {
479        let called = Arc::new(AtomicBool::new(false));
480        let called_clone = called.clone();
481
482        let stream = futures::stream::iter(vec![1, 2, 3]);
483        let stream = stream.on_drop(move || {
484            called_clone.store(true, Ordering::SeqCst);
485        });
486
487        assert!(!called.load(Ordering::SeqCst));
488        drop(stream);
489        assert!(called.load(Ordering::SeqCst));
490    }
491}