#![allow(dyn_drop)]
use super::{GuardedGauge, IntCounterWithLabels, Labels};
use pin_project::pin_project;
use prometheus::core::{Atomic, GenericCounter};
use std::{any::Any, future, ops::Deref, pin::Pin, task};
/// A future wrapper that runs instrumentation hooks right before the wrapped future is first
/// polled.
///
/// Created via [`IntoInstrumentedFuture::into_instrumented_future`] and configured with the
/// `with_*` builder methods. A hook may hand back a guard value; guards are held while the
/// inner future is pending and released once it resolves.
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct InstrumentedFuture<F: future::Future> {
    /// The wrapped future, structurally pinned so it can be polled in place.
    #[pin]
    inner: F,
    /// Hooks that have not run yet. `poll` drains this list, so each hook runs at most once.
    pre_polls: Vec<Box<dyn FnOnce() -> Option<Box<dyn Any + Send>> + Send>>,
    /// Values returned by hooks, kept alive until the inner future completes. Fields drop in
    /// declaration order, so if the wrapper is dropped while still pending, `inner` is dropped
    /// before these guards.
    resource_guards: Vec<Box<dyn Any + Send>>,
}
pub trait IntoInstrumentedFuture {
type Future: future::Future;
fn into_instrumented_future(self) -> InstrumentedFuture<Self::Future>;
}
/// Any future can be instrumented; the wrapper starts with no hooks and no held guards.
impl<F> IntoInstrumentedFuture for F
where
    F: future::Future,
{
    type Future = F;

    fn into_instrumented_future(self) -> InstrumentedFuture<F> {
        let pre_polls = Vec::new();
        let resource_guards = Vec::new();
        InstrumentedFuture {
            inner: self,
            pre_polls,
            resource_guards,
        }
    }
}
impl<F: future::Future> InstrumentedFuture<F> {
    /// Registers an arbitrary hook that runs once, immediately before the inner future is
    /// first polled.
    ///
    /// If `guard_fn` returns `Some(guard)`, the guard is kept alive until the inner future
    /// resolves, or until this wrapper is dropped if that happens first. Returning `None` runs
    /// the hook purely for its side effects.
    pub fn with_guard<GuardFn>(mut self, guard_fn: GuardFn) -> Self
    where
        GuardFn: FnOnce() -> Option<Box<dyn Any + Send>> + Send + 'static,
    {
        let hook: Box<dyn FnOnce() -> Option<Box<dyn Any + Send>> + Send> = Box::new(guard_fn);
        self.pre_polls.push(hook);
        self
    }

    /// Increments `counter` by one when the wrapped future is first polled.
    ///
    /// A future that is built but never polled leaves the counter untouched.
    pub fn with_count<P>(mut self, counter: &'static GenericCounter<P>) -> Self
    where
        P: Atomic + 'static,
    {
        let bump = move || -> Option<Box<dyn Any + Send>> {
            counter.inc();
            None
        };
        self.pre_polls.push(Box::new(bump));
        self
    }

    /// Increments the `labels` series of a labeled counter when the wrapped future is first
    /// polled.
    ///
    /// `labels` is moved into the hook and only used at that point.
    pub fn with_count_labeled<C, L>(mut self, counter: &'static C, labels: L) -> Self
    where
        C: Deref<Target = IntCounterWithLabels<L>> + Sync,
        L: Labels + Sync + Send + 'static,
    {
        let bump = move || -> Option<Box<dyn Any + Send>> {
            counter.inc(&labels);
            None
        };
        self.pre_polls.push(Box::new(bump));
        self
    }

    /// Calls `guarded_inc` on the gauge when the wrapped future is first polled, and holds the
    /// returned guard until the inner future resolves (or the wrapper is dropped).
    ///
    /// NOTE(review): the matching decrement presumably happens in the guard's `Drop`; the test
    /// in this module expects the gauge to read 0 again once the future has completed.
    pub fn with_count_gauge<G, T, P>(mut self, gauge: &'static G) -> Self
    where
        G: Deref<Target = T> + Sync,
        T: GuardedGauge<P> + 'static,
        P: Atomic + 'static,
    {
        let acquire = move || -> Option<Box<dyn Any + Send>> {
            let guard = gauge.deref().guarded_inc();
            Some(Box::new(guard))
        };
        self.pre_polls.push(Box::new(acquire));
        self
    }
}
impl<F: future::Future> future::Future for InstrumentedFuture<F> {
    type Output = F::Output;

    /// Runs any hooks that have not run yet, then polls the inner future.
    ///
    /// Hooks are drained from the list, so each one runs exactly once, on the first poll.
    /// Guards they return are held across `Pending` results and released as soon as the inner
    /// future reports `Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        let this = self.project();

        let acquired = this.pre_polls.drain(..).filter_map(|pre_poll| pre_poll());
        this.resource_guards.extend(acquired);

        let outcome = this.inner.poll(cx);
        if outcome.is_ready() {
            // The work is finished: release every guard now instead of waiting for the
            // wrapper itself to be dropped.
            this.resource_guards.clear();
        }
        outcome
    }
}
#[test]
fn counters_increment_only_when_futures_run() {
    use lazy_static::lazy_static;
    use prometheus::{opts, register_int_counter, register_int_gauge, IntCounter, IntGauge};
    use std::sync::{atomic::AtomicU8, atomic::Ordering, Arc, Mutex};
    use std::time::{Duration, Instant};

    lazy_static! {
        static ref WORK_COUNTER: IntCounter = register_int_counter!(opts!(
            "work_counter",
            "the number of times `work()` has been called"
        ))
        .unwrap();
        static ref WORK_GAUGE: IntGauge =
            register_int_gauge!(opts!("work_gauge", "the number `work()` currently running"))
                .unwrap();
        static ref CAN_MEASURE: AtomicU8 = AtomicU8::new(0);
    }

    // Signals that it is running, then blocks until the test releases the shared lock.
    // Blocking a runtime worker on a `std` mutex is deliberate: it parks the task mid-poll,
    // so the metrics can be inspected while the future is "in flight".
    async fn work(stop_ref: Arc<Mutex<usize>>) {
        CAN_MEASURE.store(1, Ordering::SeqCst);
        *stop_ref.lock().unwrap() = 4;
    }

    // Build the runtime BEFORE taking the lock. Locals drop in reverse declaration order, so
    // if any assertion below panics, the lock guard is released before the runtime's drop
    // waits for its worker threads. With the opposite order, a failing assertion deadlocks
    // (the worker stays blocked on the lock) instead of reporting the failure.
    let rt = tokio::runtime::Builder::new_multi_thread()
        .build()
        .expect("can build runtime");

    let work_stoppage = Arc::new(Mutex::new(0));
    let stop_ref = Arc::clone(&work_stoppage);
    let value_lock = work_stoppage.lock().unwrap();
    let f = work(stop_ref)
        .into_instrumented_future()
        .with_count(&WORK_COUNTER)
        .with_count_gauge(&WORK_GAUGE);
    // Merely constructing the instrumented future must not touch the metrics.
    assert_eq!(WORK_COUNTER.get(), 0);
    assert_eq!(WORK_GAUGE.get(), 0);

    let handle = rt.spawn(f);
    // Bounded wait: if the task dies before reaching `work`, fail instead of spinning forever.
    let deadline = Instant::now() + Duration::from_secs(10);
    while CAN_MEASURE.load(Ordering::SeqCst) == 0 {
        assert!(
            Instant::now() < deadline,
            "instrumented future was never polled"
        );
        std::thread::sleep(Duration::from_millis(10));
    }
    assert_eq!(WORK_COUNTER.get(), 1);
    assert_eq!(WORK_GAUGE.get(), 1);

    drop(value_lock);
    rt.block_on(handle).expect("can block on f");
    assert_eq!(WORK_COUNTER.get(), 1);
    assert_eq!(WORK_GAUGE.get(), 0);
    assert_eq!(*work_stoppage.lock().unwrap(), 4);
}