use pin_project_lite::pin_project;
use polars_utils::{UnitVec, unitvec};
use crate::async_executor::{AbortOnDropHandle, TaskPriority, spawn};
pin_project! {
#[project = LocalOrSpawnedFutureProj]
/// A future that is either held directly (`Local`) or represented by the handle of a task
/// that was passed to `spawn` (`Spawned`).
///
/// Both variants are structurally pinned, so the `Future` impl can poll whichever one is
/// active through the generated `LocalOrSpawnedFutureProj` projection.
pub enum LocalOrSpawnedFuture<F, O> {
// The wrapped future itself; it only makes progress when this value is polled.
Local { #[pin] fut: F },
// Handle to a spawned task producing `O`.
// NOTE(review): judging by its name, `AbortOnDropHandle` presumably cancels the task when
// dropped -- confirm against `crate::async_executor`.
Spawned { #[pin] handle: AbortOnDropHandle<O> }
}
}
impl<F, O> LocalOrSpawnedFuture<F, O>
where
F: Future<Output = O>,
{
pub fn new_local(fut: F) -> Self {
LocalOrSpawnedFuture::Local { fut }
}
}
impl<F, O> LocalOrSpawnedFuture<F, O>
where
    F: Future<Output = O> + Send + 'static,
    O: Send + 'static,
{
    /// Passes `fut` to the executor's `spawn` function with the given `task_priority` and
    /// stores the resulting handle, wrapped in an `AbortOnDropHandle`, in the `Spawned`
    /// variant.
    ///
    /// The spawn call happens immediately, inside this constructor, not on first poll.
    pub fn spawn(task_priority: TaskPriority, fut: F) -> Self {
        // The bare `spawn` here is the free function imported from `async_executor`,
        // not this associated function.
        let handle = AbortOnDropHandle::new(spawn(task_priority, fut));
        Self::Spawned { handle }
    }
}
impl<F, O> Future for LocalOrSpawnedFuture<F, O>
where
    F: Future<Output = O>,
{
    type Output = O;

    /// Forwards the poll to the active variant: the wrapped future for `Local`, or the
    /// task handle for `Spawned`.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // Pin-project into the active variant so the inner value can be polled in place.
        let projected = self.project();
        match projected {
            LocalOrSpawnedFutureProj::Spawned { handle: task } => task.poll(cx),
            LocalOrSpawnedFutureProj::Local { fut: inner } => inner.poll(cx),
        }
    }
}
/// Turns an iterator of futures into an iterator of wrapped futures in which the first one
/// is kept local and every following one is spawned with `task_priority`.
///
/// The input iterator is consumed in full before this function returns, so all spawning
/// happens eagerly during the call rather than while the returned iterator is advanced.
/// The returned futures are yielded in the same order as the input futures.
///
/// An empty input yields an empty iterator; a single-element input yields only the local
/// future and spawns nothing.
pub fn parallelize_first_to_local<'i, 'o, I, F, O>(
    task_priority: TaskPriority,
    futures_iter: I,
) -> impl ExactSizeIterator<Item = impl Future<Output = O> + Send + 'static> + 'o
where
    I: Iterator<Item = F> + 'i,
    F: Future<Output = O> + Send + 'static,
    O: Send + 'static,
{
    let wrapped = parallelize_first_to_local_impl(task_priority, futures_iter);
    wrapped.into_iter()
}
/// Collects `futures_iter` into a `UnitVec`, wrapping the first future with
/// `LocalOrSpawnedFuture::new_local` and every later one with `LocalOrSpawnedFuture::spawn`.
///
/// Returns an empty vector when the iterator yields nothing, and a one-element vector
/// holding only the local future when it yields exactly one item.
fn parallelize_first_to_local_impl<I, F, O>(
    task_priority: TaskPriority,
    mut futures_iter: I,
) -> UnitVec<LocalOrSpawnedFuture<F, O>>
where
    I: Iterator<Item = F>,
    F: Future<Output = O> + Send + 'static,
    O: Send + 'static,
{
    let spawn_with_priority = |fut: F| LocalOrSpawnedFuture::spawn(task_priority, fut);

    // The first future is never spawned.
    let local = match futures_iter.next() {
        Some(fut) => LocalOrSpawnedFuture::new_local(fut),
        None => return UnitVec::new(),
    };

    // With a single future there is nothing to spawn and no reason to size a larger buffer.
    let second = match futures_iter.next() {
        Some(fut) => fut,
        None => return unitvec![local],
    };

    // Reserve room for the two futures already taken plus the iterator's guaranteed remainder.
    let (remaining_lower_bound, _) = futures_iter.size_hint();
    let mut wrapped = UnitVec::with_capacity(2 + remaining_lower_bound);

    wrapped.push(local);
    wrapped.push(spawn_with_priority(second));
    wrapped.extend(futures_iter.map(spawn_with_priority));
    wrapped
}