grafbase_sdk_mock/
builder.rs

1#![allow(clippy::panic)]
2
3use std::collections::HashMap;
4
5use async_graphql::{
6    ServerError,
7    dynamic::{FieldValue, ResolverContext},
8};
9use cynic_parser::{common::WrappingType, type_system as parser};
10use serde::Deserialize;
11
12use super::{
13    GraphqlSubgraph,
14    entity_resolver::{EntityResolver, EntityResolverContext},
15    resolver::Resolver,
16};
17
18type ResolverMap = HashMap<(String, String), Box<dyn Resolver>>;
19type EntityResolverMap = HashMap<String, Box<dyn EntityResolver>>;
20
21/// A builder for dynamic GraphQL schemas.
22pub struct GraphqlSubgraphBuilder {
23    sdl: String,
24    name: String,
25    field_resolvers: ResolverMap,
26    entity_resolvers: EntityResolverMap,
27}
28
29impl GraphqlSubgraphBuilder {
30    pub(crate) fn new(sdl: String, name: String) -> Self {
31        GraphqlSubgraphBuilder {
32            sdl,
33            name,
34            field_resolvers: Default::default(),
35            entity_resolvers: Default::default(),
36        }
37    }
38
39    /// Specify the name of this subgraph.
40    pub fn with_name(mut self, name: impl AsRef<str>) -> Self {
41        self.name = name.as_ref().to_string();
42        self
43    }
44
45    /// Adds a field resolver to this schema.
46    ///
47    /// # Arguments
48    /// * `ty` - The name of the type that contains the field
49    /// * `field` - The name of the field to resolve
50    /// * `resolver` - A resolver implementation for the field
51    pub fn with_resolver(mut self, ty: &str, field: &str, resolver: impl Resolver + 'static) -> Self {
52        self.field_resolvers
53            .insert((ty.into(), field.into()), Box::new(resolver));
54        self
55    }
56
57    /// Adds an entity resolver to this schema.
58    ///
59    /// # Arguments
60    /// * `entity` - The name of the entity type to resolve
61    /// * `resolver` - A resolver implementation for the entity
62    pub fn with_entity_resolver(mut self, entity: &str, resolver: impl EntityResolver + 'static) -> Self {
63        self.entity_resolvers.insert(entity.into(), Box::new(resolver));
64        self
65    }
66
67    /// Builds the GraphQL subgraph.
68    pub fn build(self) -> GraphqlSubgraph {
69        let Self {
70            sdl,
71            name,
72            mut field_resolvers,
73            entity_resolvers,
74        } = self;
75
76        let schema = cynic_parser::parse_type_system_document(&sdl)
77            .map_err(|e| e.to_report(&sdl))
78            .expect("a valid document");
79
80        let (query_type, ..) = root_types(&schema);
81
82        // Note: don't enable federation on this because we want to provide all that stuff ourselves
83        let mut builder = schema_builder(&schema);
84        builder = builder.register(service_type(&sdl));
85        let entities = find_entities(&schema);
86
87        if !entities.is_empty() {
88            builder = builder.register(any_type());
89            builder = builder.register(entity_type(&entities));
90        }
91
92        let mut entity_resolvers = Some(entity_resolvers);
93        for definition in schema.definitions() {
94            match definition {
95                parser::Definition::Type(def) => {
96                    let mut ty = convert_type(def, &mut field_resolvers);
97                    if def.name() == query_type {
98                        if let Some(entity_resolvers) = entity_resolvers.take() {
99                            ty = add_federation_fields(ty, &entities, entity_resolvers);
100                        }
101                    }
102                    builder = builder.register(ty);
103                }
104                parser::Definition::TypeExtension(_) => {
105                    unimplemented!("this is just for tests, extensions aren't supported")
106                }
107                _ => {}
108            }
109        }
110
111        if entity_resolvers.is_some() && !entities.is_empty() {
112            let entity_resolvers = entity_resolvers.unwrap();
113            builder = builder.register(add_federation_fields(
114                async_graphql::dynamic::Object::new("Query").into(),
115                &entities,
116                entity_resolvers,
117            ));
118        }
119
120        let executable_schema = builder.finish().unwrap();
121
122        GraphqlSubgraph {
123            executable_schema,
124            schema: sdl,
125            name,
126        }
127    }
128}
129
130fn find_entities(schema: &parser::TypeSystemDocument) -> Vec<&str> {
131    schema
132        .definitions()
133        .filter_map(|def| match def {
134            parser::Definition::Type(def) => Some(def),
135            parser::Definition::TypeExtension(def) => Some(def),
136            _ => None,
137        })
138        .filter(|def| def.directives().any(|directive| directive.name() == "key"))
139        .map(|def| def.name())
140        .collect()
141}
142
143fn convert_type(def: parser::TypeDefinition<'_>, resolvers: &mut ResolverMap) -> async_graphql::dynamic::Type {
144    match def {
145        parser::TypeDefinition::Scalar(def) => async_graphql::dynamic::Scalar::new(def.name()).into(),
146        parser::TypeDefinition::Object(def) => convert_object(def, resolvers),
147        parser::TypeDefinition::Interface(def) => convert_iface(def),
148        parser::TypeDefinition::Union(def) => convert_union(def),
149        parser::TypeDefinition::Enum(def) => convert_enum(def),
150        parser::TypeDefinition::InputObject(def) => convert_input_object(def),
151    }
152}
153
154fn convert_object(def: parser::ObjectDefinition<'_>, resolvers: &mut ResolverMap) -> async_graphql::dynamic::Type {
155    use async_graphql::dynamic::*;
156
157    let mut object = Object::new(def.name());
158
159    for name in def.implements_interfaces() {
160        object = object.implement(name);
161    }
162
163    for field_def in def.fields() {
164        let type_ref = convert_type_ref(field_def.ty());
165        let resolver = std::sync::Mutex::new(
166            resolvers
167                .remove(&(def.name().into(), field_def.name().into()))
168                .unwrap_or_else(|| Box::new(default_field_resolver(field_def.name()))),
169        );
170
171        let mut field = Field::new(field_def.name(), type_ref, move |context| {
172            let mut resolver = resolver.lock().expect("mutex to be unpoisoned");
173            FieldFuture::Value(resolver.resolve(context).map(|value| {
174                let value = async_graphql::Value::deserialize(value).unwrap();
175                transform_into_field_value(value)
176            }))
177        });
178
179        for argument in field_def.arguments() {
180            field = field.argument(convert_input_value(argument));
181        }
182
183        object = object.field(field);
184    }
185
186    object.into()
187}
188
189fn transform_into_field_value(mut value: async_graphql::Value) -> FieldValue<'static> {
190    match value {
191        async_graphql::Value::Object(ref mut fields) => {
192            if let Some(async_graphql::Value::String(ty)) = fields.swap_remove("__typename") {
193                FieldValue::from(value).with_type(ty)
194            } else {
195                FieldValue::from(value)
196            }
197        }
198        async_graphql::Value::List(values) => FieldValue::list(values.into_iter().map(transform_into_field_value)),
199        value => FieldValue::from(value),
200    }
201}
202
203fn convert_iface(def: parser::InterfaceDefinition<'_>) -> async_graphql::dynamic::Type {
204    use async_graphql::dynamic::*;
205    let mut interface = Interface::new(def.name());
206
207    for field_def in def.fields() {
208        let type_ref = convert_type_ref(field_def.ty());
209
210        let mut field = InterfaceField::new(field_def.name(), type_ref);
211
212        for argument in field_def.arguments() {
213            field = field.argument(convert_input_value(argument));
214        }
215
216        interface = interface.field(field);
217    }
218
219    interface.into()
220}
221
222fn convert_union(def: parser::UnionDefinition<'_>) -> async_graphql::dynamic::Type {
223    use async_graphql::dynamic::*;
224
225    let mut output = Union::new(def.name());
226
227    for member in def.members() {
228        output = output.possible_type(member.name());
229    }
230
231    output.into()
232}
233
234fn convert_enum(def: parser::EnumDefinition<'_>) -> async_graphql::dynamic::Type {
235    use async_graphql::dynamic::*;
236
237    Enum::new(def.name())
238        .items(def.values().map(|value| EnumItem::new(value.value())))
239        .into()
240}
241
242fn convert_input_object(def: parser::InputObjectDefinition<'_>) -> async_graphql::dynamic::Type {
243    use async_graphql::dynamic::*;
244
245    let mut object = InputObject::new(def.name());
246
247    for field_def in def.fields() {
248        object = object.field(convert_input_value(field_def))
249    }
250
251    object.into()
252}
253
254fn convert_type_ref(ty: parser::Type<'_>) -> async_graphql::dynamic::TypeRef {
255    use async_graphql::dynamic::TypeRef;
256
257    let mut output = TypeRef::named(ty.name());
258
259    for wrapper in ty.wrappers() {
260        match wrapper {
261            WrappingType::NonNull => {
262                output = TypeRef::NonNull(Box::new(output));
263            }
264            WrappingType::List => {
265                output = TypeRef::List(Box::new(output));
266            }
267        }
268    }
269
270    output
271}
272
273fn convert_input_value(value_def: parser::InputValueDefinition<'_>) -> async_graphql::dynamic::InputValue {
274    use async_graphql::dynamic::InputValue;
275
276    let mut value = InputValue::new(value_def.name(), convert_type_ref(value_def.ty()));
277
278    if let Some(default) = value_def.default_value() {
279        value = value.default_value(convert_value(default))
280    }
281
282    value
283}
284
285fn convert_value(value: cynic_parser::ConstValue<'_>) -> async_graphql::Value {
286    match value {
287        cynic_parser::ConstValue::Int(inner) => async_graphql::Value::Number(inner.as_i64().into()),
288        cynic_parser::ConstValue::Float(inner) => {
289            async_graphql::Value::Number(serde_json::Number::from_f64(inner.as_f64()).unwrap())
290        }
291        cynic_parser::ConstValue::String(inner) => async_graphql::Value::String(inner.as_str().into()),
292        cynic_parser::ConstValue::Boolean(inner) => async_graphql::Value::Boolean(inner.as_bool()),
293        cynic_parser::ConstValue::Null(_) => async_graphql::Value::Null,
294        cynic_parser::ConstValue::Enum(inner) => async_graphql::Value::Enum(async_graphql::Name::new(inner.name())),
295        cynic_parser::ConstValue::List(inner) => async_graphql::Value::List(inner.items().map(convert_value).collect()),
296        cynic_parser::ConstValue::Object(inner) => async_graphql::Value::Object(
297            inner
298                .fields()
299                .map(|field| (async_graphql::Name::new(field.name()), convert_value(field.value())))
300                .collect(),
301        ),
302    }
303}
304
305fn schema_builder(schema: &cynic_parser::TypeSystemDocument) -> async_graphql::dynamic::SchemaBuilder {
306    let (query_name, mutation_name, subscription_name) = root_types(schema);
307    async_graphql::dynamic::Schema::build(query_name, mutation_name, subscription_name)
308}
309
310fn root_types(schema: &cynic_parser::TypeSystemDocument) -> (&str, Option<&str>, Option<&str>) {
311    use parser::Definition;
312
313    let mut query_name = "Query";
314    let mut mutation_name = None;
315    let mut subscription_name = None;
316    let mut found_schema_def = false;
317    let mut mutation_present = false;
318    let mut subscription_present = false;
319    for definition in schema.definitions() {
320        if let Definition::Schema(_) = definition {
321            found_schema_def = true;
322        }
323        match definition {
324            Definition::Schema(schema) | Definition::SchemaExtension(schema) => {
325                if let Some(def) = schema.query_type() {
326                    query_name = def.named_type();
327                }
328                if let Some(def) = schema.mutation_type() {
329                    mutation_name = Some(def.named_type());
330                }
331                if let Some(def) = schema.subscription_type() {
332                    subscription_name = Some(def.named_type());
333                }
334            }
335            Definition::Type(ty) | Definition::TypeExtension(ty) if ty.name() == "Mutation" => mutation_present = true,
336            Definition::Type(ty) | Definition::TypeExtension(ty) if ty.name() == "Subscription" => {
337                subscription_present = true
338            }
339            _ => {}
340        }
341    }
342    if !found_schema_def {
343        if mutation_present {
344            mutation_name = Some("Mutation");
345        }
346        if subscription_present {
347            mutation_name = Some("Subscription");
348        }
349    }
350
351    (query_name, mutation_name, subscription_name)
352}
353
354fn default_field_resolver(field_name: &str) -> impl Resolver + 'static {
355    let field_name = async_graphql::Name::new(field_name);
356
357    move |context: ResolverContext<'_>| {
358        if let Some(value) = context.parent_value.as_value() {
359            return match value {
360                async_graphql::Value::Object(map) => {
361                    map.get(&field_name).map(|value| value.clone().into_json().unwrap())
362                }
363                _ => None,
364            };
365        }
366        panic!("Unexpected parent value for tests: {:?}", context.parent_value)
367    }
368}
369
370fn service_type(sdl: &str) -> async_graphql::dynamic::Type {
371    use async_graphql::dynamic::*;
372    let mut object = Object::new("_Service");
373
374    let sdl = sdl.to_string();
375
376    object = object.field(Field::new("sdl", TypeRef::named_nn("String"), move |_| {
377        FieldFuture::from_value(Some(async_graphql::Value::String(sdl.clone())))
378    }));
379
380    object.into()
381}
382
383fn entity_type(entities: &[&str]) -> async_graphql::dynamic::Type {
384    use async_graphql::dynamic::*;
385
386    let mut ty = Union::new("_Entity");
387
388    for entity in entities {
389        ty = ty.possible_type(*entity);
390    }
391
392    ty.into()
393}
394
395fn any_type() -> async_graphql::dynamic::Type {
396    use async_graphql::dynamic::*;
397
398    Scalar::new("_Any").into()
399}
400
401fn add_federation_fields(
402    query_ty: async_graphql::dynamic::Type,
403    entities: &[&str],
404    entity_resolvers: EntityResolverMap,
405) -> async_graphql::dynamic::Type {
406    use async_graphql::dynamic::*;
407
408    let async_graphql::dynamic::Type::Object(mut obj) = query_ty else {
409        panic!("this shouldn't happen probably")
410    };
411    obj = obj.field(Field::new("_service", TypeRef::named_nn("_Service"), |_| {
412        // Doesnt matter what we return here hopefully?
413        FieldFuture::from_value(Some(async_graphql::Value::Object([].into())))
414    }));
415
416    for entity in entity_resolvers.keys() {
417        if !entities.contains(&entity.as_str()) {
418            panic!("Tried to add an resolver for {entity}, but this entity doesnt exist");
419        }
420    }
421
422    if !entities.is_empty() {
423        let resolvers = std::sync::Mutex::new(entity_resolvers);
424
425        let entity_field = Field::new("_entities", TypeRef::named_list_nn("_Entity"), move |context| {
426            let mut resolvers = resolvers.lock().expect("mutex to be unpoisoned");
427            let representations = context
428                .args
429                .get("representations")
430                .expect("_entities needs representations");
431
432            let reprs = representations
433                .deserialize::<Vec<serde_json::Value>>()
434                .expect("representations to be a list of objects");
435
436            let entities = reprs.into_iter().map(|repr| {
437                let context = EntityResolverContext::new(&context, repr);
438
439                let typename = context.typename.clone();
440                let Some(resolver) = resolvers.get_mut(&context.typename) else {
441                    context.add_error(ServerError::new(format!("{} has no resolver", context.typename), None));
442                    return FieldValue::NULL;
443                };
444
445                let json_value = resolver.resolve(context).unwrap_or_default();
446
447                transform_into_field_value(async_graphql::Value::deserialize(json_value).unwrap()).with_type(typename)
448            });
449
450            FieldFuture::Value(Some(FieldValue::list(entities)))
451        })
452        .argument(InputValue::new("representations", TypeRef::named_nn_list_nn("_Any")));
453
454        obj = obj.field(entity_field);
455    }
456
457    obj.into()
458}