use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use futures::FutureExt;
use crate::*;
/// A boxed future that is `Send` on native targets and non-`Send` on wasm32,
/// where execution is single-threaded.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxedFuture<'a, T> = futures::future::BoxFuture<'a, T>;
#[cfg(target_arch = "wasm32")]
pub type BoxedFuture<'a, T> = futures::future::LocalBoxFuture<'a, T>;

/// A boxed `'static` future; `Send` on native targets, non-`Send` on wasm32.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxedStaticFuture<T> = futures::future::BoxFuture<'static, T>;
#[cfg(target_arch = "wasm32")]
pub type BoxedStaticFuture<T> = futures::future::LocalBoxFuture<'static, T>;

/// Marker trait meaning `Send` everywhere except wasm32, where values may be
/// `!Send` because the runtime is single-threaded.
#[cfg(not(target_arch = "wasm32"))]
pub trait MaybeSend: Send {}
#[cfg(target_arch = "wasm32")]
pub trait MaybeSend {}

// Blanket impls: every `Send` type (native) / every type (wasm32) is MaybeSend.
#[cfg(not(target_arch = "wasm32"))]
impl<T: Send> MaybeSend for T {}
#[cfg(target_arch = "wasm32")]
impl<T> MaybeSend for T {}
/// Runs tasks concurrently with a bounded level of parallelism plus a
/// prefetch allowance for tasks that finished but were not yet consumed.
pub struct ConcurrentTasks<I, O> {
    /// Executor the task futures are spawned onto.
    executor: Executor,
    /// Builds a task future from an input; the input is returned alongside
    /// the result so a temporarily-failed input can be resubmitted.
    factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
    /// In-flight (or retried) tasks, in submission order.
    tasks: VecDeque<Task<(I, Result<O>)>>,
    /// Successful results retrieved from tasks but not yet handed out.
    results: VecDeque<O>,
    /// Maximum number of concurrently queued tasks.
    concurrent: usize,
    /// Extra submission headroom granted per completed-but-unretrieved task,
    /// capped at this value (see `has_remaining`).
    prefetch: usize,
    /// Number of tasks that resolved but whose output has not been awaited
    /// yet; shared with the spawned futures, which increment it on completion.
    completed_but_unretrieved: Arc<AtomicUsize>,
    /// Set after a non-temporary error; makes all later calls fail fast.
    errored: bool,
}
impl<I: Send + 'static, O: Send + 'static> ConcurrentTasks<I, O> {
    /// Create a new `ConcurrentTasks` that runs up to `concurrent` tasks at
    /// once on `executor`, with up to `prefetch` extra submission headroom.
    ///
    /// `factory` must hand the input back together with the result so that a
    /// temporarily-failed input can be resubmitted.
    pub fn new(
        executor: Executor,
        concurrent: usize,
        prefetch: usize,
        factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
    ) -> Self {
        Self {
            executor,
            factory,
            tasks: VecDeque::with_capacity(concurrent),
            results: VecDeque::with_capacity(concurrent),
            concurrent,
            prefetch,
            completed_but_unretrieved: Arc::default(),
            errored: false,
        }
    }

    /// True when tasks may actually run concurrently (`concurrent > 1`);
    /// otherwise `execute` runs each task inline.
    #[inline]
    fn is_concurrent(&self) -> bool {
        self.concurrent > 1
    }

    /// Drop all queued tasks and buffered results.
    ///
    /// NOTE(review): `completed_but_unretrieved` is not reset here. On the
    /// poisoned-error path that is harmless (every later call fails fast),
    /// but a caller invoking `clear` directly may leave the counter stale and
    /// inflate `has_remaining` — confirm intended usage.
    pub fn clear(&mut self) {
        self.tasks.clear();
        self.results.clear();
    }

    /// Whether a new task may be submitted right now.
    ///
    /// There is room while fewer than `concurrent` tasks are queued, plus up
    /// to `prefetch` extra slots backed by tasks that already completed but
    /// whose output has not been retrieved yet.
    #[inline]
    pub fn has_remaining(&self) -> bool {
        let completed = self.completed_but_unretrieved.load(Ordering::Relaxed);
        self.tasks.len() < self.concurrent + completed.min(self.prefetch)
    }

    /// Whether at least one successful result is buffered.
    #[inline]
    pub fn has_result(&self) -> bool {
        !self.results.is_empty()
    }

    /// Spawn the task for `input` on the executor.
    ///
    /// The future is wrapped so it increments `completed_but_unretrieved` the
    /// moment it resolves — this is what opens the prefetch headroom reported
    /// by `has_remaining`.
    pub fn create_task(&self, input: I) -> Task<(I, Result<O>)> {
        let completed = self.completed_but_unretrieved.clone();
        let fut = (self.factory)(input).inspect(move |_| {
            completed.fetch_add(1, Ordering::Relaxed);
        });
        self.executor.execute(fut)
    }

    /// Submit `input` as a new task, draining the oldest one first if the
    /// queue is full.
    ///
    /// - Non-concurrent mode (`concurrent <= 1`): the task runs inline; a
    ///   success is buffered in `results`, an error is returned directly.
    /// - Concurrent mode, queue full: the oldest task is awaited. A success
    ///   is buffered; a temporary error resubmits that task's own input and
    ///   returns `Err` — note the caller's `input` is NOT submitted in that
    ///   case and must be passed to `execute` again; a permanent error clears
    ///   all state and poisons `self` so later calls fail fast.
    pub async fn execute(&mut self, input: I) -> Result<()> {
        // Fail fast once a permanent error has been observed.
        if self.errored {
            return Err(Error::new(
                ErrorKind::Unexpected,
                "concurrent tasks met an unrecoverable error",
            ));
        }

        // No concurrency: run the task inline and buffer its output.
        if !self.is_concurrent() {
            let (_, o) = (self.factory)(input).await;
            return match o {
                Ok(o) => {
                    self.results.push_back(o);
                    Ok(())
                }
                Err(err) => Err(err),
            };
        }

        // Queue full: make room by retiring the oldest task.
        if !self.has_remaining() {
            let (i, o) = self
                .tasks
                .front_mut()
                .expect("tasks must be available")
                .await;
            // The task's completion hook incremented the counter; retrieving
            // its output here balances that increment.
            self.completed_but_unretrieved
                .fetch_sub(1, Ordering::Relaxed);
            match o {
                Ok(o) => {
                    let _ = self.tasks.pop_front();
                    self.results.push_back(o)
                }
                Err(err) => {
                    if err.is_temporary() {
                        // Retry: replace the finished front task with a fresh
                        // one built from the same input.
                        let task = self.create_task(i);
                        self.tasks
                            .front_mut()
                            .expect("tasks must be available")
                            .replace(task)
                    } else {
                        // Permanent failure: drop everything and poison.
                        self.clear();
                        self.errored = true;
                    }
                    return Err(err);
                }
            }
        }

        self.tasks.push_back(self.create_task(input));
        Ok(())
    }

    /// Retrieve the next result in submission order.
    ///
    /// Returns a buffered result if any, otherwise awaits the oldest task.
    /// A temporary error resubmits the failed input and is surfaced as
    /// `Some(Err(_))` so the caller can poll again; a permanent error poisons
    /// `self`. Returns `None` when no tasks or results remain.
    pub async fn next(&mut self) -> Option<Result<O>> {
        if self.errored {
            return Some(Err(Error::new(
                ErrorKind::Unexpected,
                "concurrent tasks met an unrecoverable error",
            )));
        }

        // Prefer already-buffered successes.
        if let Some(result) = self.results.pop_front() {
            return Some(Ok(result));
        }

        if let Some(task) = self.tasks.front_mut() {
            let (i, o) = task.await;
            // Balance the increment done by the task's completion hook.
            self.completed_but_unretrieved
                .fetch_sub(1, Ordering::Relaxed);
            return match o {
                Ok(o) => {
                    let _ = self.tasks.pop_front();
                    Some(Ok(o))
                }
                Err(err) => {
                    if err.is_temporary() {
                        // Retry the same input in place of the failed task.
                        let task = self.create_task(i);
                        self.tasks
                            .front_mut()
                            .expect("tasks must be available")
                            .replace(task)
                    } else {
                        self.clear();
                        self.errored = true;
                    }
                    Some(Err(err))
                }
            };
        }

        None
    }
}
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use rand::Rng;
    use tokio::time::sleep;

    use super::*;
    use crate::raw::Duration;

    /// Submit many tasks that randomly fail with a temporary error and make
    /// sure every result comes back exactly once, in submission order.
    #[tokio::test]
    async fn test_concurrent_tasks() {
        let executor = Executor::new();
        let mut tasks = ConcurrentTasks::new(executor, 16, 8, |(i, dur)| {
            Box::pin(async move {
                sleep(dur).await;
                // Roughly one in ten tasks fails with a retryable error.
                if rand::thread_rng().gen_range(0..100) > 90 {
                    return (
                        (i, dur),
                        Err(Error::new(ErrorKind::Unexpected, "I'm lucky").set_temporary()),
                    );
                }
                ((i, dur), Ok(i))
            })
        });

        let mut ans = Vec::new();

        for i in 0..10240 {
            let dur = Duration::from_millis(rand::thread_rng().gen_range(0..10));
            // Keep resubmitting until the submission itself succeeds.
            while tasks.execute((i, dur)).await.is_err() {}
        }

        loop {
            match tasks.next().await.transpose() {
                Ok(Some(i)) => ans.push(i),
                Ok(None) => break,
                // Temporary failures were requeued internally; poll again.
                Err(_) => continue,
            }
        }

        assert_eq!(ans, (0..10240).collect::<Vec<_>>())
    }

    /// Prefetch headroom only opens after tasks complete, and closes again
    /// once new submissions consume it.
    #[tokio::test]
    async fn test_prefetch_backpressure() {
        let executor = Executor::new();
        let concurrent = 4;
        let prefetch = 2;
        let mut tasks = ConcurrentTasks::new(executor, concurrent, prefetch, |i: usize| {
            Box::pin(async move {
                sleep(Duration::from_millis(100)).await;
                (i, Ok(i))
            })
        });

        assert!(tasks.has_remaining(), "Should have space initially");

        // Fill every concurrent slot.
        for i in 0..concurrent {
            assert!(tasks.has_remaining(), "Should have space for task {i}");
            tasks.execute(i).await.unwrap();
        }
        assert!(
            !tasks.has_remaining(),
            "Should not have space after submitting concurrent tasks"
        );

        // Let every task finish, which unlocks the prefetch headroom.
        sleep(Duration::from_millis(150)).await;

        for i in concurrent..concurrent + prefetch {
            assert!(
                tasks.has_remaining(),
                "Should have space for prefetch task {i}"
            );
            tasks.execute(i).await.unwrap();
        }
        assert!(
            !tasks.has_remaining(),
            "Should not have remaining space after filling up prefetch buffer"
        );

        // Consuming a result frees one slot again.
        let first = tasks.next().await;
        assert!(first.is_some());
        assert!(
            tasks.has_remaining(),
            "Should have remaining space after retrieving one result"
        );
    }

    /// With `prefetch == 0`, capacity is exactly `concurrent`: a finished
    /// task grants no headroom until its result is actually retrieved.
    #[tokio::test]
    async fn test_prefetch_zero() {
        let executor = Executor::new();
        let concurrent = 4;
        let prefetch = 0;
        let mut tasks = ConcurrentTasks::new(executor, concurrent, prefetch, |i: usize| {
            Box::pin(async move {
                sleep(Duration::from_millis(10)).await;
                (i, Ok(i))
            })
        });

        for i in 0..concurrent {
            tasks.execute(i).await.unwrap();
        }
        assert!(
            !tasks.has_remaining(),
            "Should not have remaining space with prefetch=0"
        );

        // Taking one result frees exactly one slot...
        let first = tasks.next().await;
        assert!(first.is_some());
        assert!(
            tasks.has_remaining(),
            "Should have remaining space after retrieving one result"
        );

        // ...and the next submission consumes it again.
        tasks.execute(concurrent).await.unwrap();
        assert!(!tasks.has_remaining(), "Should be full again");
    }
}