lance_datafusion/utils/
background_iterator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use futures::ready;
5use futures::Stream;
6use std::{
7    future::Future,
8    panic,
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tokio::task::JoinHandle;
13
14/// Wrap an iterator as a stream that executes the iterator in a background
15/// blocking thread.
16///
17/// The size hint is preserved, but the stream is not fused.
18#[pin_project::pin_project]
19pub struct BackgroundIterator<I: Iterator + Send + 'static> {
20    #[pin]
21    state: BackgroundIterState<I>,
22}
23
24impl<I: Iterator + Send + 'static> BackgroundIterator<I> {
25    pub fn new(iter: I) -> Self {
26        Self {
27            state: BackgroundIterState::Current { iter },
28        }
29    }
30}
31
32impl<I: Iterator + Send + 'static> Stream for BackgroundIterator<I>
33where
34    I::Item: Send + 'static,
35{
36    type Item = I::Item;
37
38    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
39        let mut this = self.project();
40
41        if let Some(mut iter) = this.state.as_mut().take_iter() {
42            this.state.set(BackgroundIterState::Running {
43                size_hint: iter.size_hint(),
44                task: tokio::task::spawn_blocking(move || {
45                    let next = iter.next();
46                    next.map(|next| (iter, next))
47                }),
48            });
49        }
50
51        let step = match this.state.as_mut().project_future() {
52            Some(task) => ready!(task.poll(cx)),
53            None => panic!(
54                "BackgroundIterator must not be polled after it returned `Poll::Ready(None)`"
55            ),
56        };
57
58        match step {
59            Ok(Some((iter, next))) => {
60                this.state.set(BackgroundIterState::Current { iter });
61                Poll::Ready(Some(next))
62            }
63            Ok(None) => {
64                this.state.set(BackgroundIterState::Empty);
65                Poll::Ready(None)
66            }
67            Err(err) => {
68                if err.is_panic() {
69                    // Resume the panic on the main task
70                    panic::resume_unwind(err.into_panic());
71                } else {
72                    panic!("Background task failed: {:?}", err);
73                }
74            }
75        }
76    }
77
78    fn size_hint(&self) -> (usize, Option<usize>) {
79        match &self.state {
80            BackgroundIterState::Current { iter } => iter.size_hint(),
81            BackgroundIterState::Running { size_hint, .. } => *size_hint,
82            BackgroundIterState::Empty => (0, Some(0)),
83        }
84    }
85}
86
87// Inspired by Unfold implementation: https://github.com/rust-lang/futures-rs/blob/master/futures-util/src/unfold_state.rs#L22
88#[pin_project::pin_project(project = StateProj, project_replace = StateReplace)]
89enum BackgroundIterState<I: Iterator> {
90    Current {
91        iter: I,
92    },
93    Running {
94        size_hint: (usize, Option<usize>),
95        #[pin]
96        task: NextHandle<I, I::Item>,
97    },
98    Empty,
99}
100
101type NextHandle<I, Item> = JoinHandle<Option<(I, Item)>>;
102
103impl<I: Iterator + Send + 'static> BackgroundIterState<I> {
104    fn project_future(self: Pin<&mut Self>) -> Option<Pin<&mut NextHandle<I, I::Item>>> {
105        match self.project() {
106            StateProj::Running { task, .. } => Some(task),
107            _ => None,
108        }
109    }
110
111    fn take_iter(self: Pin<&mut Self>) -> Option<I> {
112        match &*self {
113            Self::Current { .. } => match self.project_replace(Self::Empty) {
114                StateReplace::Current { iter } => Some(iter),
115                _ => None,
116            },
117            _ => None,
118        }
119    }
120}