use std::sync::Arc;
use futures_util::future::BoxFuture;
use futures_util::future::LocalBoxFuture;
use futures_util::future::Shared;
use futures_util::FutureExt;
use parking_lot::Mutex;
use tokio::task::JoinError;
/// Outcome of joining a spawned creation task. The `JoinError` is wrapped in an
/// `Arc` so that the whole result is `Clone`, which `Shared` requires of its output.
type JoinResult<T> = Result<T, Arc<JoinError>>;

/// Factory invoked to start (or restart) creation of the value. Each call yields a
/// fresh future; the future itself does not need to be `Send`.
type CreateFutureFn<T> = Box<dyn Fn() -> LocalBoxFuture<'static, T> + Send + Sync>;
/// Mutable bookkeeping for a value creator; kept behind its mutex.
#[derive(Debug)]
struct State<T> {
    /// Bumped every time a cancelled creation task is replaced. A waiter compares
    /// the index it started with against this value to tell whether another waiter
    /// has already spawned the replacement.
    retry_index: usize,
    /// The current creation task, shared by every waiter. `None` until creation
    /// is first requested.
    future: Option<Shared<BoxFuture<'static, JoinResult<T>>>>,
}
/// Lazily creates a value and hands clones of it to every caller, even when
/// callers are driven by different tokio runtimes. If the creation task is
/// cancelled before finishing, creation is restarted by one of the waiters.
pub struct MultiRuntimeAsyncValueCreator<T>
where
    T: Send + Clone + 'static,
{
    /// Produces a new creation future each time creation (re)starts.
    create_future: CreateFutureFn<T>,
    /// Shared creation state; only ever locked briefly.
    state: Mutex<State<T>>,
}
impl<T> std::fmt::Debug for MultiRuntimeAsyncValueCreator<T>
where
    T: Send + Clone + 'static,
{
    // Only the type name is printed: the boxed factory and the shared future
    // have no useful debug representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("MultiRuntimeAsyncValueCreator")
    }
}
impl<TResult: Send + Clone + 'static> MultiRuntimeAsyncValueCreator<TResult> {
    /// Maximum number of times a cancelled creation task is replaced before
    /// [`Self::get`] gives up and panics, so a creator whose task keeps being
    /// cancelled cannot loop forever.
    const MAX_RETRIES: usize = 1000;

    /// Creates a new value creator. `create_future` is not invoked until the
    /// first call to [`Self::get`].
    pub fn new(create_future: CreateFutureFn<TResult>) -> Self {
        Self {
            state: Mutex::new(State {
                retry_index: 0,
                future: None,
            }),
            create_future,
        }
    }

    /// Returns the value, starting its creation on the first call.
    ///
    /// All concurrent callers await the same spawned creation task. If that task
    /// is cancelled (for example because the runtime it was spawned on shut down),
    /// the first caller to observe that particular cancellation spawns a
    /// replacement, and every caller then awaits the replacement.
    ///
    /// The internal lock is never held across an `.await`. It *is* held while the
    /// factory passed to [`Self::new`] is invoked.
    ///
    /// # Panics
    ///
    /// - If the creation task fails for a reason other than cancellation
    ///   (i.e. the creation future panicked).
    /// - If the creation task has been cancelled and replaced more than
    ///   `MAX_RETRIES` times.
    pub async fn get(&self) -> TResult {
        let (mut future, mut retry_index) = {
            let mut state = self.state.lock();
            let future = state
                .future
                .get_or_insert_with(|| self.create_shared_future())
                .clone();
            (future, state.retry_index)
        };

        loop {
            let join_error = match future.await {
                Ok(value) => return value,
                Err(join_error) => join_error,
            };
            if !join_error.is_cancelled() {
                panic!("{}", join_error);
            }

            {
                let mut state = self.state.lock();
                // Only replace the task if nobody else has done so since we last
                // read the index; otherwise just join the newer attempt.
                if state.retry_index == retry_index {
                    state.retry_index += 1;
                    state.future = Some(self.create_shared_future());
                }
                retry_index = state.retry_index;
                future = state
                    .future
                    .clone()
                    .expect("state.future is always set once creation has started");
            }

            if retry_index > Self::MAX_RETRIES {
                panic!(
                    "MultiRuntimeAsyncValueCreator: creation task was cancelled {} times; giving up",
                    retry_index
                );
            }
        }
    }

    /// Invokes the factory and immediately passes the resulting future to
    /// `crate::spawn`. The resulting `JoinError` is wrapped in an `Arc` so the
    /// joined result is `Clone`, and the future is made `Shared` so any number of
    /// waiters can await it.
    ///
    /// NOTE(review): which runtime the task lands on depends on `crate::spawn`
    /// (presumably the caller's current runtime); confirm.
    fn create_shared_future(&self) -> Shared<BoxFuture<'static, JoinResult<TResult>>> {
        let future = (self.create_future)();
        crate::spawn(future)
            .map(|result| result.map_err(Arc::new))
            .boxed()
            .shared()
    }
}
#[cfg(test)]
mod test {
use crate::spawn;
use super::*;
// One runtime, one caller: the factory's value is returned as-is.
#[tokio::test]
async fn single_runtime() {
let value_creator = MultiRuntimeAsyncValueCreator::new(Box::new(|| {
async { 1 }.boxed_local()
}));
let value = value_creator.get().await;
assert_eq!(value, 1);
}
// Three OS threads, each driving its own current-thread runtime, call `get` on
// the same shared creator; every one of them must observe the value 1.
#[test]
fn multi_runtimes() {
let value_creator =
Arc::new(MultiRuntimeAsyncValueCreator::new(Box::new(|| {
async {
// Yield once so the creation does not complete synchronously.
tokio::task::yield_now().await;
1
}
.boxed_local()
})));
let handles = (0..3)
.map(|_| {
let value_creator = value_creator.clone();
std::thread::spawn(|| {
create_runtime().block_on(async move { value_creator.get().await })
})
})
.collect::<Vec<_>>();
for handle in handles {
assert_eq!(handle.join().unwrap(), 1);
}
}
// The first creation attempt starts on a runtime that is dropped before the
// attempt can finish. A caller on a second runtime must see the cancellation,
// restart creation on its own runtime, and receive 1.
#[test]
fn multi_runtimes_first_never_finishes() {
// `true` until the first creation attempt starts.
let is_first_run = Arc::new(Mutex::new(true));
// Each creation attempt sends one message when it starts.
let (tx, rx) = std::sync::mpsc::channel::<()>();
let value_creator = Arc::new(MultiRuntimeAsyncValueCreator::new({
let is_first_run = is_first_run.clone();
Box::new(move || {
let is_first_run = is_first_run.clone();
let tx = tx.clone();
async move {
// Atomically read-and-clear the flag and signal that an attempt started.
let is_first_run = {
let mut is_first_run = is_first_run.lock();
let initial_value = *is_first_run;
*is_first_run = false;
tx.send(()).unwrap();
initial_value
};
if is_first_run {
// The first attempt should never complete: it is expected to be
// cancelled when its runtime is dropped. Reaching the panic means it
// was not cancelled within 30s.
tokio::time::sleep(std::time::Duration::from_millis(30_000)).await;
panic!("TIMED OUT"); } else {
tokio::task::yield_now().await;
}
1
}
.boxed_local()
})
}));
// Thread 1 (detached, never joined): trigger creation from inside its runtime,
// wait until the first attempt has started, then return from `block_on`. The
// runtime is dropped at the end of that statement, which cancels its pending
// tasks, including the sleeping first attempt (assuming `crate::spawn` targets
// the current runtime).
std::thread::spawn({
let value_creator = value_creator.clone();
let is_first_run = is_first_run.clone();
move || {
create_runtime().block_on(async {
let value_creator = value_creator.clone();
spawn(async move { value_creator.get().await });
while *is_first_run.lock() {
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
}
})
}
});
// Thread 2: wait (blocking) until the first attempt has started, then `get` on
// a separate runtime. `rx` stays alive for the whole `get`, so the retried
// attempt's `tx.send` also succeeds.
let handle = {
let value_creator = value_creator.clone();
std::thread::spawn(|| {
create_runtime().block_on(async move {
let value_creator = value_creator.clone();
rx.recv().unwrap();
value_creator.get().await
})
})
};
assert_eq!(handle.join().unwrap(), 1);
}
// Builds a single-threaded tokio runtime with all drivers (I/O and time) enabled.
fn create_runtime() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
}
}