lance_core/utils/
futures.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    collections::VecDeque,
6    sync::{Arc, Mutex},
7    task::Waker,
8};
9
10use futures::{stream::BoxStream, Stream, StreamExt};
11use pin_project::pin_project;
12use tokio::sync::Semaphore;
13use tokio_util::sync::PollSemaphore;
14
/// Identifies which half of a [`SharedStream`] pair a handle refers to.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Side {
    /// The first stream of the pair returned by [`SharedStream::new`].
    Left,
    /// The second stream of the pair returned by [`SharedStream::new`].
    Right,
}
20
/// A potentially unbounded capacity.
///
/// Used by [`SharedStreamExt::share`] to limit how many items may sit in the shared replay
/// buffer at once.  Note that `Bounded(0)` leaves no room to buffer anything, and every item
/// pulled from the underlying stream must be buffered for the other side, so a shared stream
/// created with `Bounded(0)` can never yield an item.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Capacity {
    /// At most this many items may be buffered at once.
    Bounded(u32),
    /// No limit on the number of buffered items.
    Unbounded,
}
27
/// Mutable state shared (behind a mutex) by the two halves of a [`SharedStream`] pair.
///
/// Invariant maintained by `poll_next`: a side only pulls from `inner` when it has nothing
/// left to consume in `buffer`, so at any moment every entry of `buffer` is owed to the same
/// side (at most one of `left_buffered` / `right_buffered` is non-zero).
struct InnerState<'a, T> {
    // The underlying stream; temporarily `None` while one side has it checked out for polling.
    inner: Option<BoxStream<'a, T>>,
    // Items pulled by one side that the other side has not consumed yet (FIFO).
    buffer: VecDeque<T>,
    // The side currently polling `inner`, or the side that last got `Pending` from it.
    polling: Option<Side>,
    // Waker of the side waiting for the other side to finish polling `inner`.
    waker: Option<Waker>,
    // Set once `inner` has returned `None`.
    exhausted: bool,
    // Number of entries in `buffer` still owed to the left side.
    left_buffered: u32,
    // Number of entries in `buffer` still owed to the right side.
    right_buffered: u32,
    // Caps the number of entries in `buffer`; `None` when the capacity is unbounded.
    available_buffer: Option<PollSemaphore>,
}
38
39/// The stream returned by [`share`].
40pub struct SharedStream<'a, T: Clone> {
41    state: Arc<Mutex<InnerState<'a, T>>>,
42    side: Side,
43}
44
45impl<'a, T: Clone> SharedStream<'a, T> {
46    pub fn new(inner: BoxStream<'a, T>, capacity: Capacity) -> (Self, Self) {
47        let available_buffer = match capacity {
48            Capacity::Unbounded => None,
49            Capacity::Bounded(capacity) => Some(PollSemaphore::new(Arc::new(Semaphore::new(
50                capacity as usize,
51            )))),
52        };
53        let state = InnerState {
54            inner: Some(inner),
55            buffer: VecDeque::new(),
56            polling: None,
57            waker: None,
58            exhausted: false,
59            left_buffered: 0,
60            right_buffered: 0,
61            available_buffer,
62        };
63
64        let state = Arc::new(Mutex::new(state));
65
66        let left = Self {
67            state: state.clone(),
68            side: Side::Left,
69        };
70        let right = Self {
71            state,
72            side: Side::Right,
73        };
74        (left, right)
75    }
76}
77
78impl<T: Clone> Stream for SharedStream<'_, T> {
79    type Item = T;
80
81    fn poll_next(
82        self: std::pin::Pin<&mut Self>,
83        cx: &mut std::task::Context<'_>,
84    ) -> std::task::Poll<Option<Self::Item>> {
85        let mut inner_state = self.state.lock().unwrap();
86        let can_take_buffered = match self.side {
87            Side::Left => inner_state.left_buffered > 0,
88            Side::Right => inner_state.right_buffered > 0,
89        };
90        if can_take_buffered {
91            // Easy case, there is an item in the buffer.  Grab it, decrement the count, and return it.
92            let item = inner_state.buffer.pop_front();
93            match self.side {
94                Side::Left => {
95                    inner_state.left_buffered -= 1;
96                }
97                Side::Right => {
98                    inner_state.right_buffered -= 1;
99                }
100            }
101            if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
102                available_buffer.add_permits(1);
103            }
104            std::task::Poll::Ready(item)
105        } else {
106            if inner_state.exhausted {
107                return std::task::Poll::Ready(None);
108            }
109            // No buffered items, if we have room in the buffer, then try and poll for one
110            let permit = if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
111                match available_buffer.poll_acquire(cx) {
112                    // Can return None if the semaphore is closed but we never close the semaphore
113                    // so its safe to unwrap here
114                    std::task::Poll::Ready(permit) => Some(permit.unwrap()),
115                    std::task::Poll::Pending => {
116                        return std::task::Poll::Pending;
117                    }
118                }
119            } else {
120                None
121            };
122            if let Some(polling_side) = inner_state.polling.as_ref() {
123                if *polling_side != self.side {
124                    // Another task is already polling the inner stream, so we don't need to do anything
125
126                    // Per rust docs:
127                    //   Note that on multiple calls to poll, only the Waker from the Context
128                    //   passed to the most recent call should be scheduled to receive a wakeup.
129                    //
130                    // So it is safe to replace a potentially stale waker here.
131                    inner_state.waker = Some(cx.waker().clone());
132                    return std::task::Poll::Pending;
133                }
134            }
135            inner_state.polling = Some(self.side);
136            // Release the mutex here as polling the inner stream is potentially expensive
137            let mut to_poll = inner_state
138                .inner
139                .take()
140                .expect("Other half of shared stream panic'd while polling inner stream");
141            drop(inner_state);
142            let res = to_poll.poll_next_unpin(cx);
143            let mut inner_state = self.state.lock().unwrap();
144
145            let mut should_wake = true;
146            match &res {
147                std::task::Poll::Ready(None) => {
148                    inner_state.exhausted = true;
149                    inner_state.polling = None;
150                }
151                std::task::Poll::Ready(Some(item)) => {
152                    // We got an item, forget the permit to mark that we can take one fewer items
153                    if let Some(permit) = permit {
154                        permit.forget();
155                    }
156                    inner_state.polling = None;
157                    // Let the other side know an item is available
158                    match self.side {
159                        Side::Left => {
160                            inner_state.right_buffered += 1;
161                        }
162                        Side::Right => {
163                            inner_state.left_buffered += 1;
164                        }
165                    };
166                    inner_state.buffer.push_back(item.clone());
167                }
168                std::task::Poll::Pending => {
169                    should_wake = false;
170                }
171            };
172
173            inner_state.inner = Some(to_poll);
174
175            // If the other side was waiting for us to poll, wake them up, but only after we release the mutex
176            let to_wake = if should_wake {
177                inner_state.waker.take()
178            } else {
179                // If the inner stream is pending then the inner stream will wake us up and we will wake the
180                // other side up then.
181                None
182            };
183            drop(inner_state);
184            if let Some(waker) = to_wake {
185                waker.wake();
186            }
187            res
188        }
189    }
190}
191
192pub trait SharedStreamExt<'a>: Stream + Send
193where
194    Self::Item: Clone,
195{
196    /// Split a stream into two shared streams
197    ///
198    /// Each shared stream will return the full set of items from the underlying stream.
199    /// This works by buffering the items from the underlying stream and then replaying
200    /// them to the other side.
201    ///
202    /// The capacity parameter controls how many items can be buffered at once.  Be careful
203    /// with the capacity parameter as it can lead to deadlock if the two streams are not
204    /// polled evenly.
205    ///
206    /// If the capacity is unbounded then the stream could potentially buffer the entire
207    /// input stream in memory.
208    fn share(
209        self,
210        capacity: Capacity,
211    ) -> (SharedStream<'a, Self::Item>, SharedStream<'a, Self::Item>);
212}
213
214impl<'a, T: Clone> SharedStreamExt<'a> for BoxStream<'a, T> {
215    fn share(self, capacity: Capacity) -> (SharedStream<'a, T>, SharedStream<'a, T>) {
216        SharedStream::new(self, capacity)
217    }
218}
219
220#[pin_project]
221pub struct FinallyStream<S: Stream, F: FnOnce()> {
222    #[pin]
223    stream: S,
224    f: Option<F>,
225}
226
227impl<S: Stream, F: FnOnce()> FinallyStream<S, F> {
228    pub fn new(stream: S, f: F) -> Self {
229        Self { stream, f: Some(f) }
230    }
231}
232
233impl<S: Stream, F: FnOnce()> Stream for FinallyStream<S, F> {
234    type Item = S::Item;
235
236    fn poll_next(
237        self: std::pin::Pin<&mut Self>,
238        cx: &mut std::task::Context<'_>,
239    ) -> std::task::Poll<Option<Self::Item>> {
240        let this = self.project();
241        let res = this.stream.poll_next(cx);
242        if matches!(res, std::task::Poll::Ready(None)) {
243            // It's possible that None is polled multiple times, but we only call the function once
244            if let Some(f) = this.f.take() {
245                f();
246            }
247        }
248        res
249    }
250}
251
252pub trait FinallyStreamExt<S: Stream>: Stream + Sized {
253    fn finally<F: FnOnce()>(self, f: F) -> FinallyStream<Self, F> {
254        FinallyStream {
255            stream: self,
256            f: Some(f),
257        }
258    }
259}
260
261impl<S: Stream> FinallyStreamExt<S> for S {
262    fn finally<F: FnOnce()>(self, f: F) -> FinallyStream<Self, F> {
263        FinallyStream::new(self, f)
264    }
265}
266
#[cfg(test)]
mod tests {

    use futures::{FutureExt, StreamExt};
    use tokio_stream::wrappers::ReceiverStream;

    use crate::utils::futures::{Capacity, SharedStreamExt};

    /// Polls `fut` once with a no-op waker and reports whether it returned `Pending`.
    ///
    /// Because the waker does nothing, a `Pending` result is never followed by a wakeup; the
    /// tests re-poll the same future later (by `.await`ing it) to observe progress.
    fn is_pending(fut: &mut (impl std::future::Future + Unpin)) -> bool {
        let noop_waker = futures::task::noop_waker();
        let mut context = std::task::Context::from_waker(&noop_waker);
        fut.poll_unpin(&mut context).is_pending()
    }

    /// Drives a `Bounded(2)` split step by step from a single task.  It covers back-pressure
    /// on the leading side, replay to the lagging side, `Pending` on both sides while the
    /// channel is open, end-of-stream, and repeated `None` after completion.
    #[tokio::test]
    async fn test_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        // Feed in a few items
        for i in 0..3 {
            tx.send(i).await.unwrap();
        }

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

        // We should be able to immediately poll 2 items
        assert_eq!(left.next().await.unwrap(), 0);
        assert_eq!(left.next().await.unwrap(), 1);

        // Polling again should block because the right side has fallen behind
        // (both buffer slots hold items the right side has not consumed yet)
        let mut left_fut = left.next();

        assert!(is_pending(&mut left_fut));

        // Polling the right side should yield the first cached item and unblock the left
        assert_eq!(right.next().await.unwrap(), 0);
        assert_eq!(left_fut.await.unwrap(), 2);

        // Drain the rest of the stream from the right
        assert_eq!(right.next().await.unwrap(), 1);
        assert_eq!(right.next().await.unwrap(), 2);

        // The channel isn't closed yet so we should get pending on both sides
        // (the right side polls the channel; the left side waits for the right side)
        let mut right_fut = right.next();
        let mut left_fut = left.next();
        assert!(is_pending(&mut right_fut));
        assert!(is_pending(&mut left_fut));

        // Send one more item
        tx.send(3).await.unwrap();

        // Should be received by both
        assert_eq!(right_fut.await.unwrap(), 3);
        assert_eq!(left_fut.await.unwrap(), 3);

        // Closing the sender ends the underlying ReceiverStream
        drop(tx);

        // Now we should be able to poll the end from either side
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);

        // We should be self-fused
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);
    }

    /// With an unbounded capacity one side can run all the way to the end while the other
    /// side's items accumulate in the buffer.
    #[tokio::test]
    async fn test_unbounded_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        // Feed in a few items
        for i in 0..10 {
            tx.send(i).await.unwrap();
        }
        drop(tx);

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Unbounded);

        // We should be able to completely drain one side
        for i in 0..10 {
            assert_eq!(left.next().await.unwrap(), i);
        }
        assert_eq!(left.next().await, None);

        // And still drain the other side from the buffer
        for i in 0..10 {
            assert_eq!(right.next().await.unwrap(), i);
        }
        assert_eq!(right.next().await, None);
    }

    /// Consumes both halves of a `Bounded(2)` split from two spawned tasks on a multi-threaded
    /// runtime while 1000 items are sent.  Each side must see every item in order.  The whole
    /// scenario is repeated 100 times to shake out races.
    #[tokio::test(flavor = "multi_thread")]
    async fn stress_shared_stream() {
        for _ in 0..100 {
            let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
            let inner_stream = ReceiverStream::new(rx);
            let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

            let left_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = left.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
            });

            let right_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = right.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
            });

            for i in 0..1000 {
                tx.send(i).await.unwrap();
            }
            // Closing the sender lets both consumer loops terminate
            drop(tx);
            left_handle.await.unwrap();
            right_handle.await.unwrap();
        }
    }
}