use async_cell::sync::AsyncCell;
use futures::Future;
use std::sync::Arc;
use tracing::Instrument;
/// A value produced once by a background task and shared by any number of waiters.
///
/// The inner cell stays empty until the task finishes. It then holds `Ok(value)` on
/// success, or `Err(message)` with a textual description of the failure. The error is
/// stored as a `String` so the outcome can be cloned out to every waiter.
///
/// No trait bounds are placed on the struct itself; the `Clone` requirement lives on
/// the `impl` blocks that actually clone `T`.
pub struct SharedPrerequisite<T>(Arc<AsyncCell<std::result::Result<T, String>>>);
impl<T: Clone> SharedPrerequisite<T> {
#[allow(dead_code)]
pub async fn get_fut(&self) -> crate::Result<T> {
self.0
.get()
.await
.map_err(|err| crate::Error::prerequisite_failed(err))
}
pub fn get_ready(&self) -> T {
self.0
.try_get()
.expect("SharedPrerequisite cached value accessed without call to wait_ready")
.expect("SharedPrerequisite cached value accessed without call to wait_ready")
}
pub async fn wait_ready(&self) -> crate::Result<()> {
self.0
.get()
.await
.map(|_| ())
.map_err(|err| crate::Error::prerequisite_failed(err))
}
pub fn spawn<F>(future: F) -> Arc<Self>
where
T: Clone + Send + 'static,
F: Future<Output = crate::Result<T>> + Send + 'static,
{
let cell = AsyncCell::<std::result::Result<T, String>>::shared();
let dst = cell.clone();
tokio::spawn(
(async move {
let res = future.await;
dst.set(res.map_err(|err| err.to_string()));
})
.in_current_span(),
);
Arc::new(Self(cell))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future;

    /// How many concurrent waiters each scenario spawns against one prerequisite.
    const WAITERS: usize = 10;

    #[tokio::test]
    async fn test_spawn_prereq() {
        // A prerequisite that succeeds: every waiter observes the same cached value.
        let ready = SharedPrerequisite::spawn(future::ready(crate::Result::Ok(7_u32)));
        let handles: Vec<_> = (0..WAITERS)
            .map(|_| {
                let waiter = Arc::clone(&ready);
                tokio::spawn(async move {
                    waiter.wait_ready().await.unwrap();
                    assert_eq!(waiter.get_ready(), 7_u32);
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        // A prerequisite that fails: every waiter gets an error whose text carries both
        // the original message and the "task failed" marker.
        let failing = SharedPrerequisite::<u32>::spawn(future::ready(crate::Result::Err(
            crate::Error::invalid_input("xyz"),
        )));
        let handles: Vec<_> = (0..WAITERS)
            .map(|_| {
                let waiter = Arc::clone(&failing);
                tokio::spawn(async move {
                    let message = waiter.wait_ready().await.unwrap_err().to_string();
                    assert!(message.contains("xyz"));
                    assert!(message.contains("task failed"));
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
    }
}