async_graphql/extensions/
mod.rs

1//! Extensions for schema
2
3mod analyzer;
4#[cfg(feature = "apollo_persisted_queries")]
5pub mod apollo_persisted_queries;
6#[cfg(feature = "apollo_tracing")]
7mod apollo_tracing;
8#[cfg(feature = "log")]
9mod logger;
10
11#[cfg(feature = "tracing")]
12mod tracing;
13
14use std::{
15    any::{Any, TypeId},
16    future::Future,
17    sync::Arc,
18};
19
20use futures_util::{FutureExt, future::BoxFuture, stream::BoxStream};
21
22pub use self::analyzer::Analyzer;
23#[cfg(feature = "apollo_tracing")]
24pub use self::apollo_tracing::ApolloTracing;
25#[cfg(feature = "log")]
26pub use self::logger::Logger;
27#[cfg(feature = "tracing")]
28pub use self::tracing::Tracing;
29use crate::{
30    Data, DataContext, Error, QueryPathNode, Request, Response, Result, SDLExportOptions,
31    SchemaEnv, ServerError, ServerResult, ValidationResult, Value, Variables,
32    parser::types::{ExecutableDocument, Field},
33};
34
/// Context for extension
///
/// Handed by reference to every [`Extension`] hook. Typed lookups made
/// through [`ExtensionContext::data_opt`] consult `query_data` first, then
/// `session_data`, and finally the data stored in `schema_env`.
pub struct ExtensionContext<'a> {
    /// Schema-scope context data, [`Registry`], and custom directives.
    pub schema_env: &'a SchemaEnv,

    /// Extension-scoped context data shared across all extensions.
    ///
    /// Can be accessed only from hooks that implement the [`Extension`] trait.
    ///
    /// It is created with each new [`Request`] and is empty by default.
    ///
    /// For subscriptions, the session ends when the subscription is closed.
    pub session_data: &'a Data,

    /// Request-scoped context data shared across all resolvers.
    ///
    /// This is a reference to the [`Request::data`](Request) field.
    /// If the request has not been initialized yet, the value is seen as
    /// `None` inside the [`Extension::request`], [`Extension::subscribe`], and
    /// [`Extension::prepare_request`] hooks.
    pub query_data: Option<&'a Data>,
}
57
58impl<'a> DataContext<'a> for ExtensionContext<'a> {
59    fn data<D: Any + Send + Sync>(&self) -> Result<&'a D> {
60        ExtensionContext::data::<D>(self)
61    }
62
63    fn data_unchecked<D: Any + Send + Sync>(&self) -> &'a D {
64        ExtensionContext::data_unchecked::<D>(self)
65    }
66
67    fn data_opt<D: Any + Send + Sync>(&self) -> Option<&'a D> {
68        ExtensionContext::data_opt::<D>(self)
69    }
70}
71
72impl<'a> ExtensionContext<'a> {
73    /// Convert the specified [ExecutableDocument] into a query string.
74    ///
75    /// Usually used for log extension, it can hide secret arguments.
76    pub fn stringify_execute_doc(&self, doc: &ExecutableDocument, variables: &Variables) -> String {
77        self.schema_env
78            .registry
79            .stringify_exec_doc(variables, doc)
80            .unwrap_or_default()
81    }
82
83    /// Returns SDL(Schema Definition Language) of this schema.
84    pub fn sdl(&self) -> String {
85        self.schema_env.registry.export_sdl(Default::default())
86    }
87
88    /// Returns SDL(Schema Definition Language) of this schema with options.
89    pub fn sdl_with_options(&self, options: SDLExportOptions) -> String {
90        self.schema_env.registry.export_sdl(options)
91    }
92
93    /// Gets the global data defined in the `Context` or `Schema`.
94    ///
95    /// If both `Schema` and `Query` have the same data type, the data in the
96    /// `Query` is obtained.
97    ///
98    /// # Errors
99    ///
100    /// Returns a `Error` if the specified type data does not exist.
101    pub fn data<D: Any + Send + Sync>(&self) -> Result<&'a D> {
102        self.data_opt::<D>().ok_or_else(|| {
103            Error::new(format!(
104                "Data `{}` does not exist.",
105                std::any::type_name::<D>()
106            ))
107        })
108    }
109
110    /// Gets the global data defined in the `Context` or `Schema`.
111    ///
112    /// # Panics
113    ///
114    /// It will panic if the specified data type does not exist.
115    pub fn data_unchecked<D: Any + Send + Sync>(&self) -> &'a D {
116        self.data_opt::<D>()
117            .unwrap_or_else(|| panic!("Data `{}` does not exist.", std::any::type_name::<D>()))
118    }
119
120    /// Gets the global data defined in the `Context` or `Schema` or `None` if
121    /// the specified type data does not exist.
122    pub fn data_opt<D: Any + Send + Sync>(&self) -> Option<&'a D> {
123        self.query_data
124            .and_then(|query_data| query_data.get(&TypeId::of::<D>()))
125            .or_else(|| self.session_data.get(&TypeId::of::<D>()))
126            .or_else(|| self.schema_env.data.get(&TypeId::of::<D>()))
127            .and_then(|d| d.downcast_ref::<D>())
128    }
129}
130
/// Parameters for [`Extension::resolve`]
///
/// Describes the field that is about to be resolved.
pub struct ResolveInfo<'a> {
    /// Current path node. You can walk the entire path from it.
    pub path_node: &'a QueryPathNode<'a>,

    /// Parent type
    pub parent_type: &'a str,

    /// Current return type, as a qualified name.
    pub return_type: &'a str,

    /// Current field name
    pub name: &'a str,

    /// Current field alias
    pub alias: Option<&'a str>,

    /// If `true`, the current field is for introspection.
    pub is_for_introspection: bool,

    /// Current field
    pub field: &'a Field,
}
154
/// Borrowed, type-erased future yielding the final [`Response`]; awaited by
/// [`NextRequest::run`] once no extension is left in the chain.
type RequestFut<'a> = &'a mut (dyn Future<Output = Response> + Send + Unpin);

/// Borrowed, type-erased future yielding the parsed document; awaited by
/// [`NextParseQuery::run`] once no extension is left in the chain.
type ParseFut<'a> = &'a mut (dyn Future<Output = ServerResult<ExecutableDocument>> + Send + Unpin);

/// Borrowed, type-erased future yielding the validation outcome; awaited by
/// [`NextValidation::run`] once no extension is left in the chain.
type ValidationFut<'a> =
    &'a mut (dyn Future<Output = Result<ValidationResult, Vec<ServerError>>> + Send + Unpin);

/// One-shot factory for the execution future. It is invoked at the end of the
/// execute chain with the context data accumulated along the way (if any).
type ExecuteFutFactory<'a> = Box<dyn FnOnce(Option<Data>) -> BoxFuture<'a, Response> + Send + 'a>;

/// A future type used to resolve the field
///
/// Awaited by [`NextResolve::run`] once no extension is left in the chain.
pub type ResolveFut<'a> = &'a mut (dyn Future<Output = ServerResult<Option<Value>>> + Send + Unpin);
166
167/// The remainder of a extension chain for request.
168pub struct NextRequest<'a> {
169    chain: &'a [Arc<dyn Extension>],
170    request_fut: RequestFut<'a>,
171}
172
173impl NextRequest<'_> {
174    /// Call the [Extension::request] function of next extension.
175    pub async fn run(self, ctx: &ExtensionContext<'_>) -> Response {
176        if let Some((first, next)) = self.chain.split_first() {
177            first
178                .request(
179                    ctx,
180                    NextRequest {
181                        chain: next,
182                        request_fut: self.request_fut,
183                    },
184                )
185                .await
186        } else {
187            self.request_fut.await
188        }
189    }
190}
191
192/// The remainder of a extension chain for subscribe.
193pub struct NextSubscribe<'a> {
194    chain: &'a [Arc<dyn Extension>],
195}
196
197impl NextSubscribe<'_> {
198    /// Call the [Extension::subscribe] function of next extension.
199    pub fn run<'s>(
200        self,
201        ctx: &ExtensionContext<'_>,
202        stream: BoxStream<'s, Response>,
203    ) -> BoxStream<'s, Response> {
204        if let Some((first, next)) = self.chain.split_first() {
205            first.subscribe(ctx, stream, NextSubscribe { chain: next })
206        } else {
207            stream
208        }
209    }
210}
211
212/// The remainder of a extension chain for subscribe.
213pub struct NextPrepareRequest<'a> {
214    chain: &'a [Arc<dyn Extension>],
215}
216
217impl NextPrepareRequest<'_> {
218    /// Call the [Extension::prepare_request] function of next extension.
219    pub async fn run(self, ctx: &ExtensionContext<'_>, request: Request) -> ServerResult<Request> {
220        if let Some((first, next)) = self.chain.split_first() {
221            first
222                .prepare_request(ctx, request, NextPrepareRequest { chain: next })
223                .await
224        } else {
225            Ok(request)
226        }
227    }
228}
229
230/// The remainder of a extension chain for parse query.
231pub struct NextParseQuery<'a> {
232    chain: &'a [Arc<dyn Extension>],
233    parse_query_fut: ParseFut<'a>,
234}
235
236impl NextParseQuery<'_> {
237    /// Call the [Extension::parse_query] function of next extension.
238    pub async fn run(
239        self,
240        ctx: &ExtensionContext<'_>,
241        query: &str,
242        variables: &Variables,
243    ) -> ServerResult<ExecutableDocument> {
244        if let Some((first, next)) = self.chain.split_first() {
245            first
246                .parse_query(
247                    ctx,
248                    query,
249                    variables,
250                    NextParseQuery {
251                        chain: next,
252                        parse_query_fut: self.parse_query_fut,
253                    },
254                )
255                .await
256        } else {
257            self.parse_query_fut.await
258        }
259    }
260}
261
262/// The remainder of a extension chain for validation.
263pub struct NextValidation<'a> {
264    chain: &'a [Arc<dyn Extension>],
265    validation_fut: ValidationFut<'a>,
266}
267
268impl NextValidation<'_> {
269    /// Call the [Extension::validation] function of next extension.
270    pub async fn run(
271        self,
272        ctx: &ExtensionContext<'_>,
273    ) -> Result<ValidationResult, Vec<ServerError>> {
274        if let Some((first, next)) = self.chain.split_first() {
275            first
276                .validation(
277                    ctx,
278                    NextValidation {
279                        chain: next,
280                        validation_fut: self.validation_fut,
281                    },
282                )
283                .await
284        } else {
285            self.validation_fut.await
286        }
287    }
288}
289
290/// The remainder of a extension chain for execute.
291pub struct NextExecute<'a> {
292    chain: &'a [Arc<dyn Extension>],
293    execute_fut_factory: ExecuteFutFactory<'a>,
294    execute_data: Option<Data>,
295}
296
297impl NextExecute<'_> {
298    async fn internal_run(
299        self,
300        ctx: &ExtensionContext<'_>,
301        operation_name: Option<&str>,
302        data: Option<Data>,
303    ) -> Response {
304        let execute_data = match (self.execute_data, data) {
305            (Some(mut data1), Some(data2)) => {
306                data1.merge(data2);
307                Some(data1)
308            }
309            (Some(data), None) => Some(data),
310            (None, Some(data)) => Some(data),
311            (None, None) => None,
312        };
313
314        if let Some((first, next)) = self.chain.split_first() {
315            first
316                .execute(
317                    ctx,
318                    operation_name,
319                    NextExecute {
320                        chain: next,
321                        execute_fut_factory: self.execute_fut_factory,
322                        execute_data,
323                    },
324                )
325                .await
326        } else {
327            (self.execute_fut_factory)(execute_data).await
328        }
329    }
330
331    /// Call the [Extension::execute] function of next extension.
332    pub async fn run(self, ctx: &ExtensionContext<'_>, operation_name: Option<&str>) -> Response {
333        self.internal_run(ctx, operation_name, None).await
334    }
335
336    /// Call the [Extension::execute] function of next extension with context
337    /// data.
338    pub async fn run_with_data(
339        self,
340        ctx: &ExtensionContext<'_>,
341        operation_name: Option<&str>,
342        data: Data,
343    ) -> Response {
344        self.internal_run(ctx, operation_name, Some(data)).await
345    }
346}
347
348/// The remainder of a extension chain for resolve.
349pub struct NextResolve<'a> {
350    chain: &'a [Arc<dyn Extension>],
351    resolve_fut: ResolveFut<'a>,
352}
353
354impl NextResolve<'_> {
355    /// Call the [Extension::resolve] function of next extension.
356    pub async fn run(
357        self,
358        ctx: &ExtensionContext<'_>,
359        info: ResolveInfo<'_>,
360    ) -> ServerResult<Option<Value>> {
361        if let Some((first, next)) = self.chain.split_first() {
362            first
363                .resolve(
364                    ctx,
365                    info,
366                    NextResolve {
367                        chain: next,
368                        resolve_fut: self.resolve_fut,
369                    },
370                )
371                .await
372        } else {
373            self.resolve_fut.await
374        }
375    }
376}
377
/// Represents a GraphQL extension
///
/// Each hook receives the [`ExtensionContext`] and a `next` handle for the
/// remainder of the extension chain. Every default implementation simply
/// forwards to `next`, so an implementor only overrides the hooks it needs.
#[async_trait::async_trait]
pub trait Extension: Sync + Send + 'static {
    /// Called at the start of a query/mutation request.
    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
        next.run(ctx).await
    }

    /// Called at subscribe request.
    ///
    /// Receives the response stream and returns the stream to be used in its
    /// place.
    fn subscribe<'s>(
        &self,
        ctx: &ExtensionContext<'_>,
        stream: BoxStream<'s, Response>,
        next: NextSubscribe<'_>,
    ) -> BoxStream<'s, Response> {
        next.run(ctx, stream)
    }

    /// Called at prepare request.
    ///
    /// Receives the [`Request`] by value and returns the request to continue
    /// with, or an error.
    async fn prepare_request(
        &self,
        ctx: &ExtensionContext<'_>,
        request: Request,
        next: NextPrepareRequest<'_>,
    ) -> ServerResult<Request> {
        next.run(ctx, request).await
    }

    /// Called at parse query.
    async fn parse_query(
        &self,
        ctx: &ExtensionContext<'_>,
        query: &str,
        variables: &Variables,
        next: NextParseQuery<'_>,
    ) -> ServerResult<ExecutableDocument> {
        next.run(ctx, query, variables).await
    }

    /// Called at validation query.
    async fn validation(
        &self,
        ctx: &ExtensionContext<'_>,
        next: NextValidation<'_>,
    ) -> Result<ValidationResult, Vec<ServerError>> {
        next.run(ctx).await
    }

    /// Called at execute query.
    ///
    /// Use [`NextExecute::run_with_data`] instead of [`NextExecute::run`] to
    /// pass additional context data down the chain.
    async fn execute(
        &self,
        ctx: &ExtensionContext<'_>,
        operation_name: Option<&str>,
        next: NextExecute<'_>,
    ) -> Response {
        next.run(ctx, operation_name).await
    }

    /// Called at resolve field.
    async fn resolve(
        &self,
        ctx: &ExtensionContext<'_>,
        info: ResolveInfo<'_>,
        next: NextResolve<'_>,
    ) -> ServerResult<Option<Value>> {
        next.run(ctx, info).await
    }
}
446
/// Extension factory
///
/// Used to create an extension instance.
pub trait ExtensionFactory: Send + Sync + 'static {
    /// Create an extension instance.
    fn create(&self) -> Arc<dyn Extension>;
}
454
/// The ordered extension chain together with the data needed to build an
/// [`ExtensionContext`] for each hook invocation.
#[derive(Clone)]
#[doc(hidden)]
pub struct Extensions {
    /// Extensions in invocation order; the first element is called first.
    extensions: Vec<Arc<dyn Extension>>,
    /// Schema environment exposed to hooks as `ExtensionContext::schema_env`.
    schema_env: SchemaEnv,
    /// Data exposed to hooks as `ExtensionContext::session_data`.
    session_data: Arc<Data>,
    /// Data exposed to hooks as `ExtensionContext::query_data`; `None` until
    /// `attach_query_data` is called.
    query_data: Option<Arc<Data>>,
}
463
464#[doc(hidden)]
465impl Extensions {
466    pub(crate) fn new(
467        extensions: impl IntoIterator<Item = Arc<dyn Extension>>,
468        schema_env: SchemaEnv,
469        session_data: Arc<Data>,
470    ) -> Self {
471        Extensions {
472            extensions: extensions.into_iter().collect(),
473            schema_env,
474            session_data,
475            query_data: None,
476        }
477    }
478
479    #[inline]
480    pub(crate) fn attach_query_data(&mut self, data: Arc<Data>) {
481        self.query_data = Some(data);
482    }
483
484    #[inline]
485    pub(crate) fn is_empty(&self) -> bool {
486        self.extensions.is_empty()
487    }
488
489    #[inline]
490    fn create_context(&self) -> ExtensionContext<'_> {
491        ExtensionContext {
492            schema_env: &self.schema_env,
493            session_data: &self.session_data,
494            query_data: self.query_data.as_deref(),
495        }
496    }
497
498    pub async fn request(&self, request_fut: RequestFut<'_>) -> Response {
499        let next = NextRequest {
500            chain: &self.extensions,
501            request_fut,
502        };
503        next.run(&self.create_context()).await
504    }
505
506    pub fn subscribe<'s>(&self, stream: BoxStream<'s, Response>) -> BoxStream<'s, Response> {
507        let next = NextSubscribe {
508            chain: &self.extensions,
509        };
510        next.run(&self.create_context(), stream)
511    }
512
513    pub async fn prepare_request(&self, request: Request) -> ServerResult<Request> {
514        let next = NextPrepareRequest {
515            chain: &self.extensions,
516        };
517        next.run(&self.create_context(), request).await
518    }
519
520    pub async fn parse_query(
521        &self,
522        query: &str,
523        variables: &Variables,
524        parse_query_fut: ParseFut<'_>,
525    ) -> ServerResult<ExecutableDocument> {
526        let next = NextParseQuery {
527            chain: &self.extensions,
528            parse_query_fut,
529        };
530        next.run(&self.create_context(), query, variables).await
531    }
532
533    pub async fn validation(
534        &self,
535        validation_fut: ValidationFut<'_>,
536    ) -> Result<ValidationResult, Vec<ServerError>> {
537        let next = NextValidation {
538            chain: &self.extensions,
539            validation_fut,
540        };
541        next.run(&self.create_context()).await
542    }
543
544    pub async fn execute<'a, 'b, F, T>(
545        &'a self,
546        operation_name: Option<&str>,
547        execute_fut_factory: F,
548    ) -> Response
549    where
550        F: FnOnce(Option<Data>) -> T + Send + 'a,
551        T: Future<Output = Response> + Send + 'a,
552    {
553        let next = NextExecute {
554            chain: &self.extensions,
555            execute_fut_factory: Box::new(|data| execute_fut_factory(data).boxed()),
556            execute_data: None,
557        };
558        next.run(&self.create_context(), operation_name).await
559    }
560
561    pub async fn resolve(
562        &self,
563        info: ResolveInfo<'_>,
564        resolve_fut: ResolveFut<'_>,
565    ) -> ServerResult<Option<Value>> {
566        let next = NextResolve {
567            chain: &self.extensions,
568            resolve_fut,
569        };
570        next.run(&self.create_context(), info).await
571    }
572}