Skip to main content

async_graphql_extras/extensions/
opentelemetry.rs

1use std::sync::Arc;
2
3use async_graphql::{
4    Response, ServerError, ServerResult, ValidationResult, Value,
5    extensions::{
6        Extension, ExtensionContext, ExtensionFactory, NextExecute, NextParseQuery, NextRequest,
7        NextResolve, NextSubscribe, NextValidation, ResolveInfo,
8    },
9    registry::{MetaType, MetaTypeName},
10};
11use async_graphql_parser::types::ExecutableDocument;
12use async_graphql_value::Variables;
13use futures_util::{TryFutureExt, stream::BoxStream};
14use opentelemetry::{
15    Context as OpenTelemetryContext, Key, KeyValue,
16    trace::{FutureExt, SpanKind, TraceContextExt, Tracer},
17};
18
19const KEY_SOURCE: Key = Key::from_static_str("graphql.source");
20const KEY_VARIABLES: Key = Key::from_static_str("graphql.variables");
21const KEY_PARENT_TYPE: Key = Key::from_static_str("graphql.parentType");
22const KEY_RETURN_TYPE: Key = Key::from_static_str("graphql.returnType");
23const KEY_ERROR: Key = Key::from_static_str("graphql.error");
24const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity");
25const KEY_DEPTH: Key = Key::from_static_str("graphql.depth");
26
/// Extension factory that reports GraphQL processing as OpenTelemetry spans.
///
/// The factory owns a shared tracer; every extension instance it creates
/// starts its spans through that tracer.
///
/// # Examples
///
/// ```ignore
/// use async_graphql_extras::OpenTelemetry;
///
/// let tracer = todo!("create your OpenTelemetry tracer");
///
/// let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
///     .extension(OpenTelemetry::new(tracer))
///     .finish();
/// ```
#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))]
pub struct OpenTelemetry<T> {
    /// Tracer shared with every extension instance created by this factory.
    tracer: Arc<T>,
    /// Whether fields returning scalar or enum types get their own span.
    trace_scalars: bool,
}
45
46impl<T> OpenTelemetry<T> {
47    /// Use `tracer` to create an OpenTelemetry extension.
48    pub fn new(tracer: T) -> OpenTelemetry<T>
49    where
50        T: Tracer + Send + Sync + 'static,
51        <T as Tracer>::Span: Sync + Send,
52    {
53        Self {
54            tracer: Arc::new(tracer),
55            trace_scalars: false,
56        }
57    }
58
59    /// Enable or disable tracing for scalar and enum field resolutions.
60    ///
61    /// When `false` (the default), spans are not created for fields that return
62    /// scalar or enum types. This significantly reduces the number of spans
63    /// generated, preventing span explosion in queries with many scalar fields.
64    ///
65    /// When `true`, spans are created for all field resolutions, including
66    /// scalars and enums.
67    ///
68    /// # Example
69    ///
70    /// ```ignore
71    /// use async_graphql::extensions::OpenTelemetry;
72    /// use async_graphql_extras::OpenTelemetry as ExtrasOpenTelemetry;
73    ///
74    /// let tracer = todo!("create your OpenTelemetry tracer");
75    ///
76    /// // Trace all fields including scalars
77    /// let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
78    ///     .extension(ExtrasOpenTelemetry::new(tracer).with_trace_scalars(true))
79    ///     .finish();
80    /// ```
81    #[must_use]
82    pub fn with_trace_scalars(mut self, trace_scalars: bool) -> Self {
83        self.trace_scalars = trace_scalars;
84        self
85    }
86}
87
88impl<T> ExtensionFactory for OpenTelemetry<T>
89where
90    T: Tracer + Send + Sync + 'static,
91    <T as Tracer>::Span: Sync + Send,
92{
93    fn create(&self) -> Arc<dyn Extension> {
94        Arc::new(OpenTelemetryExtension {
95            tracer: self.tracer.clone(),
96            trace_scalars: self.trace_scalars,
97        })
98    }
99}
100
/// Extension instance produced by the `OpenTelemetry` factory; implements the
/// tracing hooks.
struct OpenTelemetryExtension<T> {
    /// Tracer used to start every span.
    tracer: Arc<T>,
    /// When `false`, fields returning scalar or enum types are not traced.
    trace_scalars: bool,
}
105
106#[async_trait::async_trait]
107impl<T> Extension for OpenTelemetryExtension<T>
108where
109    T: Tracer + Send + Sync + 'static,
110    <T as Tracer>::Span: Sync + Send,
111{
112    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
113        next.run(ctx)
114            .with_context(OpenTelemetryContext::current_with_span(
115                self.tracer
116                    .span_builder("request")
117                    .with_kind(SpanKind::Server)
118                    .start(&*self.tracer),
119            ))
120            .await
121    }
122
123    fn subscribe<'s>(
124        &self,
125        ctx: &ExtensionContext<'_>,
126        stream: BoxStream<'s, Response>,
127        next: NextSubscribe<'_>,
128    ) -> BoxStream<'s, Response> {
129        Box::pin(
130            next.run(ctx, stream)
131                .with_context(OpenTelemetryContext::current_with_span(
132                    self.tracer
133                        .span_builder("subscribe")
134                        .with_kind(SpanKind::Server)
135                        .start(&*self.tracer),
136                )),
137        )
138    }
139
140    async fn parse_query(
141        &self,
142        ctx: &ExtensionContext<'_>,
143        query: &str,
144        variables: &Variables,
145        next: NextParseQuery<'_>,
146    ) -> ServerResult<ExecutableDocument> {
147        let attributes = vec![
148            KeyValue::new(KEY_SOURCE, query.to_string()),
149            KeyValue::new(KEY_VARIABLES, serde_json::to_string(variables).unwrap()),
150        ];
151        let span = self
152            .tracer
153            .span_builder("parse")
154            .with_kind(SpanKind::Server)
155            .with_attributes(attributes)
156            .start(&*self.tracer);
157
158        async move {
159            let res = next.run(ctx, query, variables).await;
160            if let Ok(doc) = &res {
161                OpenTelemetryContext::current()
162                    .span()
163                    .set_attribute(KeyValue::new(
164                        KEY_SOURCE,
165                        ctx.stringify_execute_doc(doc, variables),
166                    ));
167            }
168            res
169        }
170        .with_context(OpenTelemetryContext::current_with_span(span))
171        .await
172    }
173
174    async fn validation(
175        &self,
176        ctx: &ExtensionContext<'_>,
177        next: NextValidation<'_>,
178    ) -> Result<ValidationResult, Vec<ServerError>> {
179        let span = self
180            .tracer
181            .span_builder("validation")
182            .with_kind(SpanKind::Server)
183            .start(&*self.tracer);
184        next.run(ctx)
185            .with_context(OpenTelemetryContext::current_with_span(span))
186            .map_ok(|res| {
187                let current_cx = OpenTelemetryContext::current();
188                let span = current_cx.span();
189                span.set_attribute(KeyValue::new(KEY_COMPLEXITY, res.complexity as i64));
190                span.set_attribute(KeyValue::new(KEY_DEPTH, res.depth as i64));
191                res
192            })
193            .await
194    }
195
196    async fn execute(
197        &self,
198        ctx: &ExtensionContext<'_>,
199        operation_name: Option<&str>,
200        next: NextExecute<'_>,
201    ) -> Response {
202        let span = self
203            .tracer
204            .span_builder("execute")
205            .with_kind(SpanKind::Server)
206            .start(&*self.tracer);
207        next.run(ctx, operation_name)
208            .with_context(OpenTelemetryContext::current_with_span(span))
209            .await
210    }
211
212    async fn resolve(
213        &self,
214        ctx: &ExtensionContext<'_>,
215        info: ResolveInfo<'_>,
216        next: NextResolve<'_>,
217    ) -> ServerResult<Option<Value>> {
218        // Check if we should skip tracing for this field
219        let should_trace = if info.is_for_introspection {
220            false
221        } else if !self.trace_scalars {
222            // Check if the return type is a scalar or enum (leaf type)
223            let concrete_type = MetaTypeName::concrete_typename(info.return_type);
224            !ctx.schema_env
225                .registry
226                .types
227                .get(concrete_type)
228                .is_some_and(MetaType::is_leaf)
229        } else {
230            true
231        };
232
233        let span = if should_trace {
234            let attributes = vec![
235                KeyValue::new(KEY_PARENT_TYPE, info.parent_type.to_string()),
236                KeyValue::new(KEY_RETURN_TYPE, info.return_type.to_string()),
237            ];
238            Some(
239                self.tracer
240                    .span_builder(info.path_node.to_string())
241                    .with_kind(SpanKind::Server)
242                    .with_attributes(attributes)
243                    .start(&*self.tracer),
244            )
245        } else {
246            None
247        };
248
249        let fut = next.run(ctx, info).inspect_err(|err| {
250            let current_cx = OpenTelemetryContext::current();
251            current_cx.span().add_event(
252                "error".to_string(),
253                vec![KeyValue::new(KEY_ERROR, err.to_string())],
254            );
255        });
256
257        match span {
258            Some(span) => {
259                fut.with_context(OpenTelemetryContext::current_with_span(span))
260                    .await
261            }
262            None => fut.await,
263        }
264    }
265}