use std::sync::Arc;
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
use opentelemetry::Context as OtelContext;
use opentelemetry::InstrumentationScope;
use opentelemetry::global::{BoxedTracer, tracer_provider};
use opentelemetry::trace::{SpanBuilder, TraceContextExt, Tracer, TracerProvider};
use tower::{Layer, Service};
/// Describes the span that should cover the handling of one request.
///
/// [`TracedService`] calls this with a reference to each request *before*
/// forwarding the request to the inner service, so an implementation can read
/// any request data it wants to record as span attributes.
pub trait MakeSpan<Req> {
    /// Returns the builder for the span covering `req`.
    fn make_span(&self, req: &Req) -> SpanBuilder;
}

/// Any function or closure of shape `Fn(&Req) -> SpanBuilder` can be used
/// directly as a [`MakeSpan`].
impl<Req, F: Fn(&Req) -> SpanBuilder> MakeSpan<Req> for F {
    fn make_span(&self, req: &Req) -> SpanBuilder {
        self(req)
    }
}
/// A [`tower::Layer`] that wraps services in a [`TracedService`], so each
/// request handled by the wrapped service gets its own OpenTelemetry span.
///
/// The tracer is shared (via `Arc`) and `make_span` is cloned into every
/// service this layer produces.
#[derive(Clone)]
pub struct TracedLayer<F> {
    // Shared by every service produced by this layer.
    tracer: Arc<BoxedTracer>,
    // Builds the span description for each request.
    make_span: F,
}

impl<F> TracedLayer<F> {
    /// Creates a layer whose tracer is obtained from the global tracer
    /// provider for the given instrumentation `scope`.
    pub fn new(scope: &'static InstrumentationScope, make_span: F) -> Self {
        let scoped_tracer = tracer_provider().tracer_with_scope(scope.clone());
        Self::with_tracer(Arc::new(scoped_tracer), make_span)
    }

    /// Creates a layer that records spans with an already-constructed `tracer`.
    pub fn with_tracer(tracer: Arc<BoxedTracer>, make_span: F) -> Self {
        Self { tracer, make_span }
    }
}
impl<S, F> Layer<S> for TracedLayer<F>
where
F: Clone,
{
type Service = TracedService<S, F>;
fn layer(&self, inner: S) -> Self::Service {
TracedService {
inner,
tracer: Arc::clone(&self.tracer),
make_span: self.make_span.clone(),
}
}
}
/// A [`tower::Service`] that wraps every response future of `inner` in a
/// [`TracedFuture`], so each request is recorded as an OpenTelemetry span.
///
/// The span is described by `make_span` (see [`MakeSpan`]) and is only
/// started when the returned response future is first polled.
#[derive(Clone)]
pub struct TracedService<S, F> {
    inner: S,
    tracer: Arc<BoxedTracer>,
    make_span: F,
}

impl<S, F> TracedService<S, F> {
    /// Wraps `inner`, using a tracer obtained from the global tracer provider
    /// for the given instrumentation `scope`.
    pub fn new(inner: S, scope: &'static InstrumentationScope, make_span: F) -> Self {
        let tracer = tracer_provider().tracer_with_scope(scope.clone());
        Self::with_tracer(inner, Arc::new(tracer), make_span)
    }

    /// Wraps `inner`, recording spans with an already-constructed `tracer`.
    ///
    /// Mirrors [`TracedLayer::with_tracer`]. It is useful when several services
    /// should share one tracer instead of each looking one up from the global
    /// provider.
    pub fn with_tracer(inner: S, tracer: Arc<BoxedTracer>, make_span: F) -> Self {
        Self {
            inner,
            tracer,
            make_span,
        }
    }
}
impl<S, F, Req> Service<Req> for TracedService<S, F>
where
    S: Service<Req>,
    F: MakeSpan<Req>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = TracedFuture<S::Future>;

    /// Readiness is delegated unchanged to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Describes the span from `req`, forwards `req` to the inner service, and
    /// returns the inner future wrapped so the span starts on its first poll.
    fn call(&mut self, req: Req) -> Self::Future {
        // The request must be inspected before it is moved into the inner service.
        let span_builder = self.make_span.make_span(&req);
        // `inner.call` runs here, before the span exists, so any synchronous
        // work it does is not covered by the span.
        let response_future = self.inner.call(req);
        let tracer = Arc::clone(&self.tracer);
        TracedFuture::with_tracer(response_future, tracer, span_builder)
    }
}
/// Lifecycle of the span owned by a `TracedFuture`.
enum SpanState {
    /// The span has not been started yet. It will be built from `builder`
    /// with `tracer` on the first call to `ensure_started`.
    Pending {
        tracer: Arc<BoxedTracer>,
        // NOTE(review): boxed presumably to keep the enum small because
        // `SpanBuilder` is a large struct — confirm.
        builder: Box<SpanBuilder>,
    },
    /// The span has been started; this context carries it as the active span.
    Active(OtelContext),
}

impl SpanState {
    /// Starts the span if needed and returns the context carrying it.
    ///
    /// On the first call (state `Pending`) the span is built from the stored
    /// builder, parented to `OtelContext::current()` at that moment, and the
    /// state becomes `Active`. Every later call returns a clone of the stored
    /// context.
    ///
    /// NOTE(review): the parent is taken at first poll, not at construction.
    /// A future moved to another task (e.g. `tokio::spawn`) before its first
    /// poll will therefore presumably not be parented to its creator's context.
    fn ensure_started(&mut self) -> OtelContext {
        match self {
            SpanState::Active(cx) => cx.clone(),
            SpanState::Pending { tracer, builder } => {
                // Move the builder out without allocating a placeholder box;
                // the emptied box is dropped when `*self` is overwritten below.
                let builder = std::mem::take(&mut **builder);
                let parent_cx = OtelContext::current();
                let span = tracer.build_with_context(builder, &parent_cx);
                let span_cx = parent_cx.with_span(span);
                *self = SpanState::Active(span_cx.clone());
                span_cx
            }
        }
    }
}
pin_project! {
    /// Future returned by [`TracedService`] and [`TracedFutureExt::traced`]
    /// that runs `inner` inside an OpenTelemetry span.
    ///
    /// The span is created lazily on the first poll and parented to the
    /// context that is current at that moment. It is attached as the current
    /// context for the duration of every poll of `inner`. If this future is
    /// dropped without ever being polled, no span is created.
    pub struct TracedFuture<F> {
        // The wrapped future; structurally pinned so it can be polled in place.
        #[pin]
        inner: F,
        // Holds the span builder until the first poll, then the started span's context.
        state: SpanState,
    }
}
impl<F> TracedFuture<F> {
    /// Wraps `inner` so it runs inside a span described by `builder`. The
    /// tracer is obtained from the global tracer provider for `scope`.
    ///
    /// A tracer is looked up from the global provider on every call. On hot
    /// paths where a tracer can be reused, prefer [`TracedFuture::with_tracer`].
    pub fn new(inner: F, scope: &'static InstrumentationScope, builder: SpanBuilder) -> Self {
        let tracer = tracer_provider().tracer_with_scope(scope.clone());
        Self::with_tracer(inner, Arc::new(tracer), builder)
    }

    /// Wraps `inner` so it runs inside a span described by `builder`, recorded
    /// with an already-constructed `tracer`.
    ///
    /// No span is started here; it is created on the first poll of the
    /// returned future.
    pub fn with_tracer(inner: F, tracer: Arc<BoxedTracer>, builder: SpanBuilder) -> Self {
        Self {
            inner,
            state: SpanState::Pending {
                tracer,
                builder: Box::new(builder),
            },
        }
    }
}
impl<F> std::future::Future for TracedFuture<F>
where
    F: std::future::Future,
{
    type Output = F::Output;

    /// Polls `inner` with the span's context attached as the current context.
    ///
    /// The first poll starts the span. The guard returned by `attach` is bound
    /// to a named local, so it stays alive for the whole of `inner.poll`.
    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let _attached = projected.state.ensure_started().attach();
        projected.inner.poll(cx)
    }
}
/// Extension method for wrapping any future in an OpenTelemetry span.
///
/// The trait is sealed through a private supertrait and blanket-implemented
/// for every [`std::future::Future`], so downstream code cannot implement it.
pub trait TracedFutureExt: std::future::Future + Sized + private::Sealed {
    /// Wraps `self` in a [`TracedFuture`]. The span is described by `builder`
    /// and recorded with a tracer for `scope`.
    ///
    /// Equivalent to `TracedFuture::new(self, scope, builder)`.
    fn traced(
        self,
        scope: &'static InstrumentationScope,
        builder: SpanBuilder,
    ) -> TracedFuture<Self> {
        TracedFuture::new(self, scope, builder)
    }
}

impl<F> TracedFutureExt for F where F: std::future::Future {}

mod private {
    /// Seals `TracedFutureExt` so it cannot be implemented outside this crate.
    pub trait Sealed {}

    impl<F> Sealed for F where F: std::future::Future {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use apollo_opentelemetry_test::{TelemetryContext, assert_spans_snapshot};
    use opentelemetry::KeyValue;

    /// Instrumentation scope shared by all tests, built once on first use.
    fn test_scope() -> &'static InstrumentationScope {
        static SCOPE: std::sync::LazyLock<InstrumentationScope> =
            std::sync::LazyLock::new(|| InstrumentationScope::builder("test").build());
        &SCOPE
    }

    /// One request through a `TracedService` yields exactly one span, named
    /// by the closure and carrying the request-derived attribute.
    #[tokio::test]
    async fn test_traced_service() {
        // NOTE(review): presumably `TelemetryContext` installs a test tracer
        // provider and collects spans for `assert_spans_snapshot!` — confirm.
        let ctx = TelemetryContext::new();
        // Wrap a tower_test mock service; the closure records the request as "input".
        let (mut service, mut handle) = tower_test::mock::spawn_with(|inner| {
            TracedService::new(inner, test_scope(), |req: &String| {
                SpanBuilder::from_name("echo")
                    .with_attributes([KeyValue::new("input", req.clone())])
            })
        });
        assert!(service.poll_ready().is_ready());
        let response = service.call("hello".to_string());
        // Drive the mock: receive the forwarded request and answer it.
        let (req, send_response) = handle.next_request().await.unwrap();
        assert_eq!(req, "hello");
        send_response.send_response("hello".to_string());
        // Awaiting polls the TracedFuture, which starts the span.
        assert_eq!(response.await.unwrap(), "hello");
        assert_spans_snapshot!(ctx, @r#"
- name: echo
span_kind: Internal
is_sampled: true
attributes:
input: hello
"#);
    }

    /// `.traced(...)` on a plain async block emits one span with the given
    /// name and attributes.
    #[tokio::test]
    async fn traced_future_new_emits_span() {
        let ctx = TelemetryContext::new();
        async { 42 }
            .traced(
                test_scope(),
                SpanBuilder::from_name("work").with_attributes([KeyValue::new("kind", "compute")]),
            )
            .await;
        assert_spans_snapshot!(ctx, @r#"
- name: work
span_kind: Internal
is_sampled: true
attributes:
kind: compute
"#);
    }

    /// Spans are created lazily: dropping a `TracedFuture` that was never
    /// polled must not record any span.
    #[tokio::test]
    async fn traced_future_dropped_before_poll_emits_no_span() {
        let ctx = TelemetryContext::new();
        let traced = TracedFuture::new(
            async { 42 },
            test_scope(),
            SpanBuilder::from_name("never-started"),
        );
        drop(traced);
        assert_spans_snapshot!(ctx, @r"[]");
    }

    /// A service produced by `TracedLayer::layer` behaves like `TracedService`:
    /// one span per request, with the closure-derived attribute.
    #[tokio::test]
    async fn test_traced_layer() {
        let ctx = TelemetryContext::new();
        let (mut service, mut handle) = tower_test::mock::spawn_with(|inner| {
            let layer = TracedLayer::new(test_scope(), |req: &String| {
                SpanBuilder::from_name("layered")
                    .with_attributes([KeyValue::new("len", req.len() as i64)])
            });
            layer.layer(inner)
        });
        assert!(service.poll_ready().is_ready());
        let response = service.call("test".to_string());
        let (req, send_response) = handle.next_request().await.unwrap();
        assert_eq!(req, "test");
        send_response.send_response("test".to_string());
        assert_eq!(response.await.unwrap(), "test");
        // NOTE(review): the i64 attribute appears quoted below, presumably
        // because the snapshot helper renders attribute values as strings — confirm.
        assert_spans_snapshot!(ctx, @r#"
- name: layered
span_kind: Internal
is_sampled: true
attributes:
len: "4"
"#);
    }
}