async_graphql/extensions/opentelemetry.rs

1use std::sync::Arc;
2
3use async_graphql_parser::types::ExecutableDocument;
4use async_graphql_value::Variables;
5use futures_util::{TryFutureExt, stream::BoxStream};
6use opentelemetry::{
7    Context as OpenTelemetryContext, Key, KeyValue,
8    trace::{FutureExt, SpanKind, TraceContextExt, Tracer},
9};
10
11use crate::{
12    Response, ServerError, ServerResult, ValidationResult, Value,
13    extensions::{
14        Extension, ExtensionContext, ExtensionFactory, NextExecute, NextParseQuery, NextRequest,
15        NextResolve, NextSubscribe, NextValidation, ResolveInfo,
16    },
17    registry::MetaTypeName,
18};
19
20const KEY_SOURCE: Key = Key::from_static_str("graphql.source");
21const KEY_VARIABLES: Key = Key::from_static_str("graphql.variables");
22const KEY_PARENT_TYPE: Key = Key::from_static_str("graphql.parentType");
23const KEY_RETURN_TYPE: Key = Key::from_static_str("graphql.returnType");
24const KEY_ERROR: Key = Key::from_static_str("graphql.error");
25const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity");
26const KEY_DEPTH: Key = Key::from_static_str("graphql.depth");
27
/// OpenTelemetry tracing extension.
///
/// Records spans for the request, subscription, parse, validation, execute
/// and field-resolution phases using the supplied OpenTelemetry [`Tracer`].
///
/// # Example
///
/// ```ignore
/// use async_graphql::{extensions::OpenTelemetry, *};
///
/// let tracer = todo!("create your OpenTelemetry tracer");
///
/// let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
///     .extension(OpenTelemetry::new(tracer))
///     .finish();
/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))]
pub struct OpenTelemetry<T> {
    /// Tracer shared (via `Arc`) with every extension instance this factory creates.
    tracer: Arc<T>,
    /// When `true`, fields returning scalar or enum types also get their own span.
    trace_scalars: bool,
}
46
47impl<T> OpenTelemetry<T> {
48    /// Use `tracer` to create an OpenTelemetry extension.
49    pub fn new(tracer: T) -> OpenTelemetry<T>
50    where
51        T: Tracer + Send + Sync + 'static,
52        <T as Tracer>::Span: Sync + Send,
53    {
54        Self {
55            tracer: Arc::new(tracer),
56            trace_scalars: false,
57        }
58    }
59
60    /// Enable or disable tracing for scalar and enum field resolutions.
61    ///
62    /// When `false` (the default), spans are not created for fields that return
63    /// scalar or enum types. This significantly reduces the number of spans
64    /// generated, preventing span explosion in queries with many scalar fields.
65    ///
66    /// When `true`, spans are created for all field resolutions, including
67    /// scalars and enums.
68    ///
69    /// # Example
70    ///
71    /// ```ignore
72    /// use async_graphql::{extensions::OpenTelemetry, *};
73    ///
74    /// let tracer = todo!("create your OpenTelemetry tracer");
75    ///
76    /// // Trace all fields including scalars
77    /// let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
78    ///     .extension(OpenTelemetry::new(tracer).with_trace_scalars(true))
79    ///     .finish();
80    /// ```
81    pub fn with_trace_scalars(mut self, trace_scalars: bool) -> Self {
82        self.trace_scalars = trace_scalars;
83        self
84    }
85}
86
87impl<T> ExtensionFactory for OpenTelemetry<T>
88where
89    T: Tracer + Send + Sync + 'static,
90    <T as Tracer>::Span: Sync + Send,
91{
92    fn create(&self) -> Arc<dyn Extension> {
93        Arc::new(OpenTelemetryExtension {
94            tracer: self.tracer.clone(),
95            trace_scalars: self.trace_scalars,
96        })
97    }
98}
99
/// Extension instance produced by [`OpenTelemetry`]'s `ExtensionFactory` impl.
struct OpenTelemetryExtension<T> {
    /// Tracer shared with the factory that created this instance.
    tracer: Arc<T>,
    /// Copied from the factory: whether scalar/enum field resolutions are traced.
    trace_scalars: bool,
}
104
105#[async_trait::async_trait]
106impl<T> Extension for OpenTelemetryExtension<T>
107where
108    T: Tracer + Send + Sync + 'static,
109    <T as Tracer>::Span: Sync + Send,
110{
111    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
112        next.run(ctx)
113            .with_context(OpenTelemetryContext::current_with_span(
114                self.tracer
115                    .span_builder("request")
116                    .with_kind(SpanKind::Server)
117                    .start(&*self.tracer),
118            ))
119            .await
120    }
121
122    fn subscribe<'s>(
123        &self,
124        ctx: &ExtensionContext<'_>,
125        stream: BoxStream<'s, Response>,
126        next: NextSubscribe<'_>,
127    ) -> BoxStream<'s, Response> {
128        Box::pin(
129            next.run(ctx, stream)
130                .with_context(OpenTelemetryContext::current_with_span(
131                    self.tracer
132                        .span_builder("subscribe")
133                        .with_kind(SpanKind::Server)
134                        .start(&*self.tracer),
135                )),
136        )
137    }
138
139    async fn parse_query(
140        &self,
141        ctx: &ExtensionContext<'_>,
142        query: &str,
143        variables: &Variables,
144        next: NextParseQuery<'_>,
145    ) -> ServerResult<ExecutableDocument> {
146        let attributes = vec![
147            KeyValue::new(KEY_SOURCE, query.to_string()),
148            KeyValue::new(KEY_VARIABLES, serde_json::to_string(variables).unwrap()),
149        ];
150        let span = self
151            .tracer
152            .span_builder("parse")
153            .with_kind(SpanKind::Server)
154            .with_attributes(attributes)
155            .start(&*self.tracer);
156
157        async move {
158            let res = next.run(ctx, query, variables).await;
159            if let Ok(doc) = &res {
160                OpenTelemetryContext::current()
161                    .span()
162                    .set_attribute(KeyValue::new(
163                        KEY_SOURCE,
164                        ctx.stringify_execute_doc(doc, variables),
165                    ));
166            }
167            res
168        }
169        .with_context(OpenTelemetryContext::current_with_span(span))
170        .await
171    }
172
173    async fn validation(
174        &self,
175        ctx: &ExtensionContext<'_>,
176        next: NextValidation<'_>,
177    ) -> Result<ValidationResult, Vec<ServerError>> {
178        let span = self
179            .tracer
180            .span_builder("validation")
181            .with_kind(SpanKind::Server)
182            .start(&*self.tracer);
183        next.run(ctx)
184            .with_context(OpenTelemetryContext::current_with_span(span))
185            .map_ok(|res| {
186                let current_cx = OpenTelemetryContext::current();
187                let span = current_cx.span();
188                span.set_attribute(KeyValue::new(KEY_COMPLEXITY, res.complexity as i64));
189                span.set_attribute(KeyValue::new(KEY_DEPTH, res.depth as i64));
190                res
191            })
192            .await
193    }
194
195    async fn execute(
196        &self,
197        ctx: &ExtensionContext<'_>,
198        operation_name: Option<&str>,
199        next: NextExecute<'_>,
200    ) -> Response {
201        let span = self
202            .tracer
203            .span_builder("execute")
204            .with_kind(SpanKind::Server)
205            .start(&*self.tracer);
206        next.run(ctx, operation_name)
207            .with_context(OpenTelemetryContext::current_with_span(span))
208            .await
209    }
210
211    async fn resolve(
212        &self,
213        ctx: &ExtensionContext<'_>,
214        info: ResolveInfo<'_>,
215        next: NextResolve<'_>,
216    ) -> ServerResult<Option<Value>> {
217        // Check if we should skip tracing for this field
218        let should_trace = if info.is_for_introspection {
219            false
220        } else if !self.trace_scalars {
221            // Check if the return type is a scalar or enum (leaf type)
222            let concrete_type = MetaTypeName::concrete_typename(info.return_type);
223            !ctx.schema_env
224                .registry
225                .types
226                .get(concrete_type)
227                .map(|ty| ty.is_leaf())
228                .unwrap_or(false)
229        } else {
230            true
231        };
232
233        let span = if should_trace {
234            let attributes = vec![
235                KeyValue::new(KEY_PARENT_TYPE, info.parent_type.to_string()),
236                KeyValue::new(KEY_RETURN_TYPE, info.return_type.to_string()),
237            ];
238            Some(
239                self.tracer
240                    .span_builder(info.path_node.to_string())
241                    .with_kind(SpanKind::Server)
242                    .with_attributes(attributes)
243                    .start(&*self.tracer),
244            )
245        } else {
246            None
247        };
248
249        let fut = next.run(ctx, info).inspect_err(|err| {
250            let current_cx = OpenTelemetryContext::current();
251            current_cx.span().add_event(
252                "error".to_string(),
253                vec![KeyValue::new(KEY_ERROR, err.to_string())],
254            );
255        });
256
257        match span {
258            Some(span) => {
259                fut.with_context(OpenTelemetryContext::current_with_span(span))
260                    .await
261            }
262            None => fut.await,
263        }
264    }
265}