use futures::future::BoxFuture;
use futures::FutureExt;
use std::any::Any;
use std::error::Error;
use std::fmt::{Display, Formatter, Result as FmtResult};
use std::future::Future;
use tokio::sync::OnceCell;
/// Hook for instrumenting work that is routed through `trace_future` / `trace_block`.
///
/// At most one implementation can be installed process-wide via `set_join_set_tracer`; it is
/// stored as a `&'static dyn JoinSetTracer` inside a `static`, which is why implementors must be
/// `Send + Sync + 'static`.
///
/// Both hooks operate on type-erased values (`Box<dyn Any + Send>`) so the trait stays object
/// safe. The typed helpers downcast the result back to the original type and panic on a
/// mismatch, so an implementation must hand back a value of the same concrete type it was
/// given.
pub trait JoinSetTracer: Send + Sync + 'static {
    /// Wraps a future whose output has been boxed as `dyn Any`.
    ///
    /// The returned future must resolve to a boxed value of the same concrete type that
    /// `future` resolves to.
    fn trace_future(
        &self,
        future: BoxFuture<'static, Box<dyn Any + Send>>,
    ) -> BoxFuture<'static, Box<dyn Any + Send>>;

    /// Wraps a synchronous closure whose return value has been boxed as `dyn Any`.
    ///
    /// The returned closure must produce a boxed value of the same concrete type that `block`
    /// produces.
    fn trace_block(
        &self,
        block: Box<dyn FnOnce() -> Box<dyn Any + Send> + Send>,
    ) -> Box<dyn FnOnce() -> Box<dyn Any + Send> + Send>;
}
/// Tracer used whenever no global tracer has been installed.
///
/// Both hooks hand their argument straight back, so traced work runs exactly as it would
/// without any tracing.
struct NoopTracer;

impl JoinSetTracer for NoopTracer {
    /// Identity: returns `future` untouched.
    fn trace_future(
        &self,
        future: BoxFuture<'static, Box<dyn Any + Send>>,
    ) -> BoxFuture<'static, Box<dyn Any + Send>> {
        future
    }

    /// Identity: returns `block` untouched.
    fn trace_block(
        &self,
        block: Box<dyn FnOnce() -> Box<dyn Any + Send> + Send>,
    ) -> Box<dyn FnOnce() -> Box<dyn Any + Send> + Send> {
        block
    }
}
/// Error returned by `set_join_set_tracer`.
///
/// Derives the cheap value traits (`Clone`, `Copy`, `PartialEq`, `Eq`, `Hash`) so callers can
/// compare against a variant (e.g. `assert_eq!(err, JoinSetTracerError::AlreadySet)`) and
/// freely copy the error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinSetTracerError {
    /// The global tracer slot rejected the new tracer because it was already filled.
    AlreadySet,
}

impl Display for JoinSetTracerError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self {
            JoinSetTracerError::AlreadySet => {
                write!(f, "The global JoinSetTracer is already set")
            }
        }
    }
}

/// The error carries no underlying cause, so the default `source()` (`None`) is used.
impl Error for JoinSetTracerError {}
/// Process-wide tracer slot; filled at most once via `set_join_set_tracer`.
static GLOBAL_TRACER: OnceCell<&'static dyn JoinSetTracer> = OnceCell::const_new();

/// Fallback handed out by `get_tracer` while `GLOBAL_TRACER` is still empty.
static NOOP_TRACER: NoopTracer = NoopTracer;

/// Returns the installed global tracer, or the no-op tracer if none has been set yet.
///
/// The slot is consulted on every call, so work traced before a tracer is installed simply
/// goes through the no-op tracer.
#[inline]
fn get_tracer() -> &'static dyn JoinSetTracer {
    if let Some(installed) = GLOBAL_TRACER.get() {
        return *installed;
    }
    &NOOP_TRACER
}
/// Installs `tracer` as the process-wide [`JoinSetTracer`].
///
/// Only the first successful call takes effect; later calls leave the existing tracer in place.
///
/// # Errors
///
/// Returns [`JoinSetTracerError::AlreadySet`] whenever the underlying `OnceCell::set` rejects
/// the value. That happens when a tracer is already installed or, per tokio's `OnceCell`
/// semantics, when the cell is being initialized concurrently.
pub fn set_join_set_tracer(
    tracer: &'static dyn JoinSetTracer,
) -> Result<(), JoinSetTracerError> {
    match GLOBAL_TRACER.set(tracer) {
        Ok(()) => Ok(()),
        // The rejected tracer is handed back inside the error; it is not needed here.
        Err(_rejected) => Err(JoinSetTracerError::AlreadySet),
    }
}
pub fn trace_future<T, F>(future: F) -> BoxFuture<'static, T>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let erased_future = async move {
let result = future.await;
Box::new(result) as Box<dyn Any + Send>
}
.boxed();
get_tracer()
.trace_future(erased_future)
.map(|any_box| {
*any_box
.downcast::<T>()
.expect("Tracer must preserve the future’s output type!")
})
.boxed()
}
/// Wraps the synchronous closure `f` with the currently installed [`JoinSetTracer`].
///
/// The tracer is looked up when this function is called; `f` itself does not run until the
/// returned closure is invoked. The closure's result is boxed as `dyn Any` for the tracer and
/// downcast back to `T` afterwards.
///
/// # Panics
///
/// The returned closure panics if the tracer yields a value whose concrete type is not `T`.
pub fn trace_block<T, F>(f: F) -> Box<dyn FnOnce() -> T + Send>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // Erase `T` so the closure fits the object-safe tracer signature.
    let type_erased = move || -> Box<dyn Any + Send> { Box::new(f()) };
    let traced = get_tracer().trace_block(Box::new(type_erased));

    Box::new(move || {
        traced()
            .downcast::<T>()
            .map(|value| *value)
            .expect("Tracer must preserve the closure’s return type!")
    })
}