apollo-opentelemetry 0.8.0

OpenTelemetry configuration types for the Apollo platform
Documentation
//! Layer for recording service results to the current span.

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use pin_project_lite::pin_project;
use tower::{Layer, Service};

use crate::__private::{RecordSpanResult, SpanResultRef};

/// A tower [`Layer`] that sets span status based on the service's `Result`.
///
/// Place this layer between a span-creating layer (like [`TracedLayer`])
/// and the inner service, so the result is recorded before the span ends.
///
/// When the service returns:
/// - `Ok(_)` → span status set to `Ok`
/// - `Err(e)` → span status set to `Error` with the error's display message
///
/// The layer holds no configuration; it is `Copy` and can be created in
/// `const` contexts through [`RecordResultLayer::new`].
///
/// # Example
///
/// ```
/// use apollo_opentelemetry::{default_instrumentation_scope, tower::{RecordResultLayer, ServiceBuilderExt}};
/// use opentelemetry::trace::SpanBuilder;
/// use tower::ServiceBuilder;
///
/// # fn wrap<S, E: std::fmt::Display>(inner: S) -> impl tower::Service<String, Response = String, Error = E>
/// # where S: tower::Service<String, Response = String, Error = E> {
/// ServiceBuilder::new()
///     .traced(default_instrumentation_scope!(), |req: &String| SpanBuilder::from_name("handle-request"))
///     .layer(RecordResultLayer::new())
///     .service(inner)
/// # }
/// ```
///
/// [`TracedLayer`]: super::TracedLayer
#[derive(Clone, Copy, Debug, Default)]
pub struct RecordResultLayer {
    // Private unit field: the struct cannot be built with a literal outside
    // this module. Presumably kept so fields can be added later without
    // breaking callers — confirm intent before removing.
    _priv: (),
}

impl RecordResultLayer {
    /// Creates a new `RecordResultLayer`.
    ///
    /// Produces the same value as [`Default::default`]. Because this is a
    /// `const fn`, it can also initialize `const` and `static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self { _priv: () }
    }
}

impl<S> Layer<S> for RecordResultLayer {
    type Service = RecordResultService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        RecordResultService { inner }
    }
}

/// A tower [`Service`] that records the result to the current span.
///
/// This is typically created via [`RecordResultLayer`] rather than constructed directly.
#[derive(Clone, Debug)]
pub struct RecordResultService<S> {
    // The wrapped service that requests are forwarded to.
    inner: S,
}

impl<S> RecordResultService<S> {
    /// Creates a new `RecordResultService` wrapping the given service.
    ///
    /// This is a `const fn`, so a wrapper around a const-constructible
    /// service can be built in `const` and `static` items.
    #[must_use]
    pub const fn new(inner: S) -> Self {
        Self { inner }
    }
}

/// Delegates to the wrapped service and wraps its response future in a
/// [`RecordResultFuture`].
impl<S, Req> Service<Req> for RecordResultService<S>
where
    S: Service<Req>,
    S::Error: std::fmt::Display,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = RecordResultFuture<S::Future>;

    /// Readiness is reported exactly as the wrapped service reports it.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the wrapped service and returns its future wrapped
    /// in a [`RecordResultFuture`].
    fn call(&mut self, req: Req) -> Self::Future {
        let pending = self.inner.call(req);
        RecordResultFuture { inner: pending }
    }
}

pin_project! {
    /// Future returned by [`RecordResultService`].
    ///
    /// When the inner future completes, this sets the current span's status
    /// based on whether the result is `Ok` or `Err`.
    ///
    /// Implements `Debug` whenever the wrapped future does, matching the
    /// other public types in this module.
    #[derive(Debug)]
    pub struct RecordResultFuture<F> {
        // Structurally pinned so `poll` can obtain `Pin<&mut F>` via `project()`.
        #[pin]
        inner: F,
    }
}

impl<F, T, E> Future for RecordResultFuture<F>
where
    F: Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    type Output = F::Output;

    /// Polls the wrapped future; once it resolves, records the outcome and
    /// returns it unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match this.inner.poll(cx) {
            // Nothing to record until the inner future has produced a value.
            Poll::Pending => Poll::Pending,
            Poll::Ready(outcome) => {
                // NOTE(review): the triple reference looks like it drives
                // autoref-based dispatch of `record`; keep the `&&&` depth
                // exactly as-is unless `RecordSpanResult` says otherwise.
                (&&&SpanResultRef(&outcome)).record();
                Poll::Ready(outcome)
            }
        }
    }
}

// Tests that drive a mock service through a traced stack containing
// `RecordResultLayer` and then snapshot the recorded spans.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tower::ServiceBuilderExt;
    use apollo_opentelemetry_test::{TelemetryContext, assert_spans_snapshot};
    use opentelemetry::InstrumentationScope;
    use opentelemetry::trace::SpanBuilder;
    use tower::ServiceBuilder;

    // Returns a shared `InstrumentationScope` named "test", built on first
    // use through `LazyLock` and reused afterwards.
    fn test_scope() -> &'static InstrumentationScope {
        static SCOPE: std::sync::LazyLock<InstrumentationScope> =
            std::sync::LazyLock::new(|| InstrumentationScope::builder("test").build());
        &SCOPE
    }

    // A successful response should leave `test-span` with status `Ok`.
    #[tokio::test]
    async fn test_record_result_ok() {
        // NOTE(review): presumably captures the spans emitted during this test
        // so the snapshot below can read them — confirm against
        // `apollo_opentelemetry_test`.
        let ctx = TelemetryContext::new();

        // Stack under test: the traced layer (creates `test-span`) is outermost,
        // then `RecordResultLayer`, then the mock service.
        let (mut service, mut handle) = tower_test::mock::spawn_with(|inner| {
            ServiceBuilder::new()
                .traced(test_scope(), |_req: &String| {
                    SpanBuilder::from_name("test-span")
                })
                .layer(RecordResultLayer::new())
                .service(inner)
        });

        // Send one request, let the mock handle answer it, then await the response.
        assert!(service.poll_ready().is_ready());
        let response = service.call("hello".to_string());
        let (req, send_response) = handle.next_request().await.unwrap();
        assert_eq!(req, "hello");
        send_response.send_response("world".to_string());
        assert_eq!(response.await.unwrap(), "world");

        // Expect a single span whose status is `Ok`.
        assert_spans_snapshot!(ctx, @r#"
        - name: test-span
          span_kind: Internal
          is_sampled: true
          status: Ok
        "#);
    }

    // A failed response should leave `test-span` with an `Error` status that
    // carries the error's display text.
    #[tokio::test]
    async fn test_record_result_err() {
        // NOTE(review): presumably captures the spans emitted during this test
        // so the snapshot below can read them — confirm against
        // `apollo_opentelemetry_test`.
        let ctx = TelemetryContext::new();

        // Same stack as the success test. The mock's request/response types are
        // spelled out here; presumably because this test never sends a response
        // value that inference could use.
        let (mut service, mut handle) =
            tower_test::mock::spawn_with(|inner: tower_test::mock::Mock<String, String>| {
                ServiceBuilder::new()
                    .traced(test_scope(), |_req: &String| {
                        SpanBuilder::from_name("test-span")
                    })
                    .layer(RecordResultLayer::new())
                    .service(inner)
            });

        // Send one request and have the mock handle reply with an error.
        assert!(service.poll_ready().is_ready());
        let response = service.call("hello".to_string());
        let (req, send_response) = handle.next_request().await.unwrap();
        assert_eq!(req, "hello");
        send_response.send_error("something went wrong");
        let err = response.await.unwrap_err();
        assert_eq!(err.to_string(), "something went wrong");

        // Expect a single span whose status embeds the error message.
        assert_spans_snapshot!(ctx, @r#"
        - name: test-span
          span_kind: Internal
          is_sampled: true
          status: "Error: something went wrong"
        "#);
    }
}