use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
pin_project! {
    /// A future that keeps a guard value alive while an inner future is pending.
    ///
    /// The guard is dropped as soon as the inner future resolves. If the wrapper
    /// is dropped before that (e.g. the future is cancelled), the guard is dropped
    /// together with it. Created via `FutureExt::with_guard`.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    #[derive(Debug)]
    pub struct WithGuard<F, G> {
        // The wrapped future; structurally pinned so it can be polled in place.
        #[pin]
        future: F,
        // `Some` while the inner future is still running, `None` once it has
        // completed and the guard has been released.
        guard: Option<G>,
    }
}
impl<F: Future, G> Future for WithGuard<F, G> {
    type Output = F::Output;

    /// Polls the inner future, releasing the guard the moment it resolves.
    ///
    /// While the inner future is pending, the guard stays in place and
    /// `Poll::Pending` is returned unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        match projected.future.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(output) => {
                // Release eagerly: the wrapper may outlive completion (e.g. when
                // polled through a pinned reference), but the guarded work is done.
                drop(projected.guard.take());
                Poll::Ready(output)
            }
        }
    }
}
/// Extension methods available on every [`Future`].
pub trait FutureExt: Future + Sized {
    /// Wraps this future so that `guard` stays alive while it is pending.
    ///
    /// The guard is dropped as soon as the future resolves. If the returned
    /// wrapper is dropped before completion, the guard is dropped with it.
    fn with_guard<Guard>(self, guard: Guard) -> WithGuard<Self, Guard>;
}
/// Blanket implementation: every future can be wrapped with a guard.
impl<Fut> FutureExt for Fut
where
    Fut: Future,
{
    fn with_guard<Guard>(self, guard: Guard) -> WithGuard<Self, Guard> {
        // The guard starts out held; polling clears it once the future completes.
        let guard = Some(guard);
        WithGuard {
            future: self,
            guard,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use apollo_opentelemetry_test::{TelemetryContext, assert_metrics_snapshot};
    use opentelemetry::global;

    use crate::metrics::{FutureExt, UpDownCounterExt};

    /// The guard must be released when the wrapped future completes, even
    /// while the `WithGuard` wrapper itself is still alive.
    #[tokio::test]
    async fn guard_drops_on_future_completion() {
        let ctx = TelemetryContext::new();
        let counter = global::meter_provider()
            .meter("test")
            .i64_up_down_counter("test.active")
            .build();
        let guard = counter.track([]);

        // Pin the wrapper on the stack and await it through `as_mut()`, so the
        // wrapper is NOT consumed by `.await`. Awaiting by value would drop the
        // wrapper (and its guard) anyway, hiding a missing eager release.
        let mut wrapped = pin!(
            async {
                assert_metrics_snapshot!(ctx, @r"
- name: test.active
data:
type: Sum
data_points:
- value: 1
is_monotonic: false
temporality: Cumulative
");
            }
            .with_guard(guard)
        );
        wrapped.as_mut().await;

        // `wrapped` is still alive here, so a zero value proves the guard was
        // released by the completing poll rather than by dropping the wrapper.
        assert_metrics_snapshot!(ctx, @r"
- name: test.active
data:
type: Sum
data_points:
- value: 0
is_monotonic: false
temporality: Cumulative
");
    }

    /// Dropping the wrapper before the inner future completes must release
    /// the guard.
    #[test]
    fn guard_drops_on_future_cancellation() {
        let ctx = TelemetryContext::new();
        let counter = global::meter_provider()
            .meter("test")
            .i64_up_down_counter("test.active")
            .build();
        let guard = counter.track([]);
        assert_metrics_snapshot!(ctx, @r"
- name: test.active
data:
type: Sum
data_points:
- value: 1
is_monotonic: false
temporality: Cumulative
");

        // `pending()` never resolves, so the guard can only be released by drop.
        let future = std::future::pending::<()>().with_guard(guard);
        drop(future);

        assert_metrics_snapshot!(ctx, @r"
- name: test.active
data:
type: Sum
data_points:
- value: 0
is_monotonic: false
temporality: Cumulative
");
    }
}