use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use crate::__private::{RecordSpanResult, SpanResultRef};
/// A [`Layer`] that wraps services in [`RecordResultService`].
///
/// The wrapped service records the `Result` of every call (via
/// `SpanResultRef::record`) once the response future resolves.
///
/// Construct it with [`RecordResultLayer::new`] or [`Default::default`].
#[derive(Clone, Copy, Debug, Default)]
pub struct RecordResultLayer {
    // Private unit field: prevents construction via a struct literal outside
    // this module (presumably to keep room for future fields).
    _priv: (),
}

impl RecordResultLayer {
    /// Creates a new `RecordResultLayer`.
    ///
    /// This is a `const fn`, so the layer can also be built in `const` and
    /// `static` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self { _priv: () }
    }
}
impl<S> Layer<S> for RecordResultLayer {
type Service = RecordResultService<S>;
fn layer(&self, inner: S) -> Self::Service {
RecordResultService { inner }
}
}
/// Service produced by [`RecordResultLayer`].
///
/// Every call is forwarded to the inner service. The returned future is
/// wrapped in a [`RecordResultFuture`], which records the call's `Result`
/// when it completes.
#[derive(Clone, Debug)]
pub struct RecordResultService<S> {
    inner: S,
}

impl<S> RecordResultService<S> {
    /// Wraps `inner` so that the outcome of each call is recorded.
    #[must_use]
    pub const fn new(inner: S) -> Self {
        Self { inner }
    }

    /// Returns a shared reference to the wrapped service.
    pub const fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped service.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes the wrapper and returns the wrapped service.
    pub fn into_inner(self) -> S {
        self.inner
    }
}
/// Delegates readiness and requests to the inner service, wrapping each
/// response future in a [`RecordResultFuture`].
impl<S, Req> Service<Req> for RecordResultService<S>
where
    S: Service<Req>,
    S::Error: std::fmt::Display,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = RecordResultFuture<S::Future>;

    /// Readiness (including any error) comes straight from the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let Self { inner } = self;
        inner.poll_ready(cx)
    }

    /// Forwards `req` to the inner service and wraps the resulting future.
    fn call(&mut self, req: Req) -> Self::Future {
        let pending_response = self.inner.call(req);
        RecordResultFuture {
            inner: pending_response,
        }
    }
}
pin_project! {
    /// Response future returned by [`RecordResultService`].
    ///
    /// Polls the wrapped future. Once that future yields a `Result`, this
    /// future records it and returns it unchanged.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    #[derive(Debug)]
    pub struct RecordResultFuture<F> {
        #[pin]
        inner: F,
    }
}
/// Drives the inner future and records its `Result` once it resolves.
/// The output is then handed back unchanged.
impl<F, T, E> Future for RecordResultFuture<F>
where
    F: Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        match projected.inner.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(output) => {
                // NOTE(review): the triple borrow presumably selects among
                // `RecordSpanResult` impls at different reference depths in
                // `crate::__private` (not visible here). Keep exactly three `&`.
                (&&&SpanResultRef(&output)).record();
                Poll::Ready(output)
            }
        }
    }
}
// Tests for `RecordResultLayer`.
//
// A `tower_test` mock service is wrapped by the crate's `traced` layer
// (outermost, added first to `ServiceBuilder`) and then `RecordResultLayer`.
// The exported span is compared against an inline snapshot.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tower::ServiceBuilderExt;
    use apollo_opentelemetry_test::{TelemetryContext, assert_spans_snapshot};
    use opentelemetry::InstrumentationScope;
    use opentelemetry::trace::SpanBuilder;
    use tower::ServiceBuilder;

    /// Instrumentation scope named "test", built lazily once and shared by all tests.
    fn test_scope() -> &'static InstrumentationScope {
        static SCOPE: std::sync::LazyLock<InstrumentationScope> =
            std::sync::LazyLock::new(|| InstrumentationScope::builder("test").build());
        &SCOPE
    }

    /// A successful inner response should leave the span with status `Ok`.
    #[tokio::test]
    async fn test_record_result_ok() {
        let ctx = TelemetryContext::new();
        let (mut service, mut handle) = tower_test::mock::spawn_with(|inner| {
            ServiceBuilder::new()
                .traced(test_scope(), |_req: &String| {
                    SpanBuilder::from_name("test-span")
                })
                .layer(RecordResultLayer::new())
                .service(inner)
        });
        assert!(service.poll_ready().is_ready());
        let response = service.call("hello".to_string());
        // Play the inner service: receive the forwarded request and answer it.
        let (req, send_response) = handle.next_request().await.unwrap();
        assert_eq!(req, "hello");
        send_response.send_response("world".to_string());
        assert_eq!(response.await.unwrap(), "world");
        // Snapshot text below must stay byte-identical (inline snapshot).
        assert_spans_snapshot!(ctx, @r#"
- name: test-span
span_kind: Internal
is_sampled: true
status: Ok
"#);
    }

    /// An inner error should be recorded as `Error: <error's Display output>`.
    #[tokio::test]
    async fn test_record_result_err() {
        let ctx = TelemetryContext::new();
        let (mut service, mut handle) =
            tower_test::mock::spawn_with(|inner: tower_test::mock::Mock<String, String>| {
                ServiceBuilder::new()
                    .traced(test_scope(), |_req: &String| {
                        SpanBuilder::from_name("test-span")
                    })
                    .layer(RecordResultLayer::new())
                    .service(inner)
            });
        assert!(service.poll_ready().is_ready());
        let response = service.call("hello".to_string());
        // Play the inner service: receive the forwarded request and fail it.
        let (req, send_response) = handle.next_request().await.unwrap();
        assert_eq!(req, "hello");
        send_response.send_error("something went wrong");
        // The error itself must pass through the layer unchanged.
        let err = response.await.unwrap_err();
        assert_eq!(err.to_string(), "something went wrong");
        // Snapshot text below must stay byte-identical (inline snapshot).
        assert_spans_snapshot!(ctx, @r#"
- name: test-span
span_kind: Internal
is_sampled: true
status: "Error: something went wrong"
"#);
    }
}