use std::future::Future;
use std::pin::Pin;
use std::task::Poll;
use indextree::NodeId;
use pin_project::{pin_project, pinned_drop};
use crate::context::ContextId;
use crate::root::current_context;
use crate::Span;
/// Tracking state of an [`Instrumented`] future's span with respect to the
/// tree of the context it is polled in.
enum State {
    /// Not yet polled inside a context. Holds the span that will be pushed into
    /// the context's tree on the first poll that happens inside a context.
    /// Polls made outside any context leave the future in this state.
    Initial(Span),
    /// The span has been pushed into the tree of the context identified by
    /// `this_context_id`, where it lives as node `this_node`.
    Polled {
        /// Node of this future's span in the context's tree.
        this_node: NodeId,
        /// Context in which the future was first polled. Later polls and the
        /// drop only touch the tree when they happen in this same context.
        this_context_id: ContextId,
    },
    /// The inner future completed while being tracked, and its node was popped
    /// from the tree.
    Ready,
    /// Tracking was skipped because the span is verbose but the context is not.
    /// The future is then polled without touching any tree.
    Disabled,
}
/// A future wrapper that records a span in the current context's tree while
/// the wrapped future is being polled.
///
/// Construct it with `Instrumented::new`. `pin_project` generates the
/// projection used to poll the structurally pinned `inner` future. `PinnedDrop`
/// wires in the custom drop logic that removes the span's node from the tree.
#[pin_project(PinnedDrop)]
pub struct Instrumented<F: Future> {
    /// The wrapped future. It is structurally pinned so it can be polled
    /// through the pin projection.
    #[pin]
    inner: F,
    /// Where this future's span currently stands with respect to the context
    /// tree (not yet attached, attached, finished, or disabled).
    state: State,
}
impl<F: Future> Instrumented<F> {
pub(crate) fn new(inner: F, span: Span) -> Self {
Self {
inner,
state: State::Initial(span),
}
}
}
impl<F: Future> Future for Instrumented<F> {
    type Output = F::Output;

    /// Polls the inner future while this future's span node is the current
    /// node of the context's tree.
    ///
    /// - First poll inside a context: the span is pushed into the context's
    ///   tree and the state becomes `Polled`. If the span is verbose but the
    ///   context is not, the state becomes `Disabled` instead and the span is
    ///   never recorded.
    /// - Later polls in the same context step back into the existing node.
    /// - Polls made before the first tracked poll, outside any context, are
    ///   forwarded silently and the state stays `Initial`.
    /// - Polls after the first tracked poll that happen in a different context,
    ///   or outside any context, log a warning and are forwarded without
    ///   touching a tree.
    /// - When the inner future completes, its node is popped and the state
    ///   becomes `Ready`. While it is pending, the tree steps back out of it.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it completed on the tracked path (state
    /// `Ready`). This is permitted by the `Future` contract.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let context = current_context();
        // Resolve which context and node this poll is tracked under. Every
        // untracked case returns early after forwarding the poll.
        let (context, this_node) = match this.state {
            State::Initial(span) => {
                match context {
                    Some(c) => {
                        // Verbose spans are only recorded in verbose contexts.
                        // Otherwise tracking is turned off permanently.
                        if !c.verbose() && span.is_verbose {
                            *this.state = State::Disabled;
                            return this.inner.poll(cx);
                        }
                        // First tracked poll: move the span into the tree.
                        // `mem::take` leaves a default `Span` behind, which is
                        // immediately discarded when the state is overwritten.
                        let node = c.tree().push(std::mem::take(span));
                        *this.state = State::Polled {
                            this_node: node,
                            this_context_id: c.id(),
                        };
                        (c, node)
                    }
                    // Not in a context: stay `Initial` so a later in-context
                    // poll can still attach the span.
                    None => return this.inner.poll(cx),
                }
            }
            State::Polled {
                this_node,
                this_context_id: this_context,
            } => {
                match context {
                    // Same context as the first poll: make our node current again.
                    Some(c) if c.id() == *this_context => {
                        c.tree().step_in(*this_node);
                        (c, *this_node)
                    }
                    // The node belongs to another context's tree and cannot be
                    // stepped into here, so just forward the poll.
                    Some(_) => {
                        tracing::warn!(
                            "future polled in a different context as it was first polled"
                        );
                        return this.inner.poll(cx);
                    }
                    None => {
                        tracing::warn!(
                            "future polled not in a context, while it was when first polled"
                        );
                        return this.inner.poll(cx);
                    }
                }
            }
            State::Ready => unreachable!("the instrumented future should always be fused"),
            State::Disabled => return this.inner.poll(cx),
        };
        // Invariant: our node is the tree's current node while the inner future runs.
        debug_assert_eq!(this_node, context.tree().current());
        // Each `context.tree()` call is a statement-scoped temporary, so no tree
        // handle is held across the inner poll below.
        match this.inner.poll(cx) {
            Poll::Ready(output) => {
                // Finished: pop our node and mark the state `Ready`, so the
                // drop logic has nothing left to clean up.
                context.tree().pop();
                *this.state = State::Ready;
                Poll::Ready(output)
            }
            Poll::Pending => {
                // Still pending: keep the node, but hand "current" back to the parent.
                context.tree().step_out();
                Poll::Pending
            }
        }
    }
}
#[pinned_drop]
impl<F: Future> PinnedDrop for Instrumented<F> {
    /// Cleans up this future's span node when the future is dropped while
    /// still attached to a tree (state `Polled`, e.g. cancelled while pending).
    ///
    /// The node is removed and detached only when the drop happens in the same
    /// context as the first poll. In any other context, or outside a context,
    /// the tree cannot be reached, so a warning is logged instead. All other
    /// states own no node, and the current context is not even looked up.
    fn drop(self: Pin<&mut Self>) {
        let projected = self.project();

        let (node, owner_id) = match projected.state {
            State::Polled {
                this_node,
                this_context_id,
            } => (*this_node, this_context_id),
            State::Initial(_) | State::Ready | State::Disabled => return,
        };

        let ctx = current_context();
        match ctx {
            Some(c) if c.id() == *owner_id => {
                c.tree().remove_and_detach(node);
            }
            Some(_) => {
                tracing::warn!("future is dropped in a different context as it was first polled, cannot clean up!");
            }
            None => {
                tracing::warn!("future is not in a context, while it was when first polled, cannot clean up!");
            }
        }
    }
}