grafbase_sdk/
extension.rs

1#![allow(static_mut_refs)]
2
3pub mod authentication;
4pub mod resolver;
5
6pub use authentication::Authenticator;
7pub use resolver::Resolver;
8
9use crate::{
10    types::{Configuration, FieldInputs},
11    wit::{Directive, Error, ExtensionType, FieldDefinition, FieldOutput, Guest, Headers, SharedContext, Token},
12    Component,
13};
14
15/// A trait representing an extension that can be initialized from schema directives.
16///
17/// This trait is intended to define a common interface for extensions in Grafbase Gateway,
18/// particularly focusing on their initialization. Extensions are constructed using
19/// a vector of `Directive` instances provided by the type definitions in the schema.
20pub trait Extension {
21    /// Creates a new instance of the extension from the given schema directives.
22    ///
23    /// The directives must be defined in the extension configuration, and written
24    /// to the federated schema. The directives are deserialized from their GraphQL
25    /// definitions to the corresponding `Directive` instances.
26    fn new(
27        schema_directives: Vec<crate::types::Directive>,
28        config: Configuration,
29    ) -> Result<Self, Box<dyn std::error::Error>>
30    where
31        Self: Sized;
32}
33
34impl Guest for Component {
35    fn init_gateway_extension(
36        r#type: ExtensionType,
37        directives: Vec<Directive>,
38        configuration: Vec<u8>,
39    ) -> Result<(), String> {
40        let directives = directives.into_iter().map(Into::into).collect();
41        let config = Configuration::new(configuration);
42
43        let result = match r#type {
44            ExtensionType::Resolver => resolver::init(directives, config),
45            ExtensionType::Authentication => authentication::init(directives, config),
46        };
47
48        result.map_err(|e| e.to_string())
49    }
50
51    fn resolve_field(
52        context: SharedContext,
53        directive: Directive,
54        definition: FieldDefinition,
55        inputs: Vec<Vec<u8>>,
56    ) -> Result<FieldOutput, Error> {
57        let result = resolver::get_extension()?.resolve_field(
58            context,
59            directive.into(),
60            definition.into(),
61            FieldInputs::new(inputs),
62        );
63
64        result.map(Into::into)
65    }
66
67    fn resolve_subscription(
68        context: SharedContext,
69        directive: Directive,
70        definition: FieldDefinition,
71    ) -> Result<(), Error> {
72        let subscriber =
73            resolver::get_extension()?.resolve_subscription(context, directive.into(), definition.into())?;
74
75        resolver::set_subscriber(subscriber);
76
77        Ok(())
78    }
79
80    fn resolve_next_subscription_item() -> Result<Option<FieldOutput>, Error> {
81        Ok(resolver::get_subscriber()?.next()?.map(Into::into))
82    }
83
84    fn authenticate(headers: Headers) -> Result<Token, crate::wit::ErrorResponse> {
85        let result = authentication::get_extension()
86            .map_err(|_| crate::wit::ErrorResponse {
87                status_code: 500,
88                errors: vec![Error {
89                    extensions: Vec::new(),
90                    message: String::from("internal server error"),
91                }],
92            })?
93            .authenticate(headers);
94
95        result.map(Into::into).map_err(Into::into)
96    }
97}