trustfall_core/schema/mod.rs

1#![allow(dead_code)]
2use std::{
3    collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet, VecDeque},
4    ops::Add,
5    sync::{Arc, OnceLock},
6};
7
8use async_graphql_parser::{
9    parse_schema,
10    types::{
11        BaseType, DirectiveDefinition, FieldDefinition, ObjectType, SchemaDefinition,
12        ServiceDocument, TypeDefinition, TypeKind, TypeSystemDefinition,
13    },
14    Positioned,
15};
16
17pub use ::async_graphql_parser::Error;
18use async_graphql_value::Name;
19use itertools::Itertools;
20use serde::{Deserialize, Serialize};
21
22use crate::ir::Type;
23use crate::util::{BTreeMapTryInsertExt, HashMapTryInsertExt};
24
25use self::error::InvalidSchemaError;
26
27mod adapter;
28pub mod error;
29
30pub use adapter::SchemaAdapter;
31
32#[derive(Debug, Clone)]
33pub struct Schema {
34    pub(crate) schema: SchemaDefinition,
35    pub(crate) query_type: ObjectType,
36    pub(crate) directives: HashMap<Arc<str>, DirectiveDefinition>,
37    pub(crate) scalars: HashMap<Arc<str>, TypeDefinition>,
38    pub(crate) vertex_types: HashMap<Arc<str>, TypeDefinition>,
39    pub(crate) fields: HashMap<(Arc<str>, Arc<str>), FieldDefinition>,
40    pub(crate) field_origins: BTreeMap<(Arc<str>, Arc<str>), FieldOrigin>,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44pub(crate) enum FieldOrigin {
45    SingleAncestor(Arc<str>), // the name of the parent (super) type that first defined this field
46    MultipleAncestors(BTreeSet<Arc<str>>),
47}
48
49impl Add for &FieldOrigin {
50    type Output = FieldOrigin;
51
52    fn add(self, rhs: Self) -> Self::Output {
53        match (self, rhs) {
54            (FieldOrigin::SingleAncestor(l), FieldOrigin::SingleAncestor(r)) => {
55                if l == r {
56                    self.clone()
57                } else {
58                    FieldOrigin::MultipleAncestors(btreeset![l.clone(), r.clone()])
59                }
60            }
61            (FieldOrigin::SingleAncestor(single), FieldOrigin::MultipleAncestors(multi))
62            | (FieldOrigin::MultipleAncestors(multi), FieldOrigin::SingleAncestor(single)) => {
63                let mut new_set = multi.clone();
64                new_set.insert(single.clone());
65                FieldOrigin::MultipleAncestors(new_set)
66            }
67            (FieldOrigin::MultipleAncestors(l_set), FieldOrigin::MultipleAncestors(r_set)) => {
68                let mut new_set = l_set.clone();
69                new_set.extend(r_set.iter().cloned());
70                FieldOrigin::MultipleAncestors(new_set)
71            }
72        }
73    }
74}
75
/// Storage for the builtin scalar name set, initialized on first use by [`get_builtin_scalars`].
static BUILTIN_SCALARS: OnceLock<HashSet<&'static str>> = OnceLock::new();

/// The names of the scalar types built into GraphQL: `Int`, `Float`, `String`, `Boolean`, `ID`.
///
/// The set is built once and shared by every caller.
pub(crate) fn get_builtin_scalars() -> &'static HashSet<&'static str> {
    BUILTIN_SCALARS.get_or_init(|| HashSet::from(["Int", "Float", "String", "Boolean", "ID"]))
}

/// Type and field names beginning with this prefix are reserved and rejected in schemas.
const RESERVED_PREFIX: &str = "__";
91
92impl Schema {
93    pub const ALL_DIRECTIVE_DEFINITIONS: &'static str = "
94directive @filter(op: String!, value: [String!]) repeatable on FIELD | INLINE_FRAGMENT
95directive @tag(name: String) on FIELD
96directive @output(name: String) on FIELD
97directive @optional on FIELD
98directive @recurse(depth: Int!) on FIELD
99directive @fold on FIELD
100directive @transform(op: String!) on FIELD
101";
102
103    pub fn parse(input: impl AsRef<str>) -> Result<Self, InvalidSchemaError> {
104        let doc = parse_schema(input)?;
105        Self::new(doc)
106    }
107
108    pub fn new(doc: ServiceDocument) -> Result<Self, InvalidSchemaError> {
109        let mut schema: Option<SchemaDefinition> = None;
110        let mut directives: HashMap<Arc<str>, DirectiveDefinition> = Default::default();
111        let mut scalars: HashMap<Arc<str>, TypeDefinition> = Default::default();
112
113        // The schema is mostly type definitions, except for one schema definition, and
114        // perhaps a small number of other definitions like custom scalars or directives.
115        let mut vertex_types: HashMap<Arc<str>, TypeDefinition> =
116            HashMap::with_capacity(doc.definitions.len() - 1);
117
118        // Each type has probably at least one field.
119        let mut fields: HashMap<(Arc<str>, Arc<str>), FieldDefinition> =
120            HashMap::with_capacity(doc.definitions.len() - 1);
121
122        for definition in doc.definitions {
123            match definition {
124                TypeSystemDefinition::Schema(s) => {
125                    assert!(schema.is_none());
126                    if s.node.extend {
127                        unimplemented!("Trustfall does not support extending schemas");
128                    }
129
130                    schema = Some(s.node);
131                }
132                TypeSystemDefinition::Directive(d) => {
133                    directives
134                        .insert_or_error(Arc::from(d.node.name.node.to_string()), d.node)
135                        .unwrap();
136                }
137                TypeSystemDefinition::Type(t) => {
138                    let node = t.node;
139                    let type_name: Arc<str> = Arc::from(node.name.node.to_string());
140                    assert!(!get_builtin_scalars().contains(type_name.as_ref()));
141
142                    if node.extend {
143                        unimplemented!("Trustfall does not support extending schemas");
144                    }
145
146                    match &node.kind {
147                        TypeKind::Scalar => {
148                            scalars.insert_or_error(type_name.clone(), node.clone()).unwrap();
149                        }
150                        TypeKind::Object(_) | TypeKind::Interface(_) => {
151                            match vertex_types.insert_or_error(type_name.clone(), node.clone()) {
152                                Ok(_) => {}
153                                Err(err) => {
154                                    let type_or_interface_name = err.entry.key();
155                                    return Err(
156                                        InvalidSchemaError::DuplicateTypeOrInterfaceDefinition(
157                                            type_or_interface_name.to_string(),
158                                        ),
159                                    );
160                                }
161                            }
162                        }
163                        TypeKind::Enum(_) => unimplemented!("Trustfall does not support enum's"),
164                        TypeKind::Union(_) => unimplemented!("Trustfall does not support unions's"),
165                        TypeKind::InputObject(_) => {
166                            unimplemented!("Trustfall does not support input objects's")
167                        }
168                    }
169
170                    let field_defs = match node.kind {
171                        TypeKind::Object(object) => Some(object.fields),
172
173                        TypeKind::Interface(interface) => Some(interface.fields),
174                        _ => None,
175                    };
176                    if let Some(field_defs) = field_defs {
177                        for field in field_defs {
178                            let field_node = field.node;
179                            let field_name = Arc::from(field_node.name.node.to_string());
180
181                            match fields
182                                .insert_or_error((type_name.clone(), field_name), field_node)
183                            {
184                                Ok(_) => {}
185                                Err(err) => {
186                                    let (key, value) = err.entry.key();
187                                    return Err(InvalidSchemaError::DuplicateFieldDefinition(
188                                        key.to_string(),
189                                        value.to_string(),
190                                    ));
191                                }
192                            }
193                        }
194                    }
195                }
196            }
197        }
198
199        let schema = schema.expect("Schema definition was not present.");
200        let query_type_name =
201            schema.query.as_ref().expect("No query type was declared in the schema").node.as_ref();
202        let query_type_definition = vertex_types
203            .get(query_type_name)
204            .expect("The query type set in the schema object was never defined.");
205        let query_type = match &query_type_definition.kind {
206            TypeKind::Object(o) => o.clone(),
207            _ => unreachable!(),
208        };
209
210        let mut errors = vec![];
211        if let Err(e) = check_required_transitive_implementations(&vertex_types) {
212            errors.extend(e);
213        }
214        if let Err(e) = check_field_type_narrowing(&vertex_types, &fields) {
215            errors.extend(e);
216        }
217        if let Err(e) = check_fields_required_by_interface_implementations(&vertex_types, &fields) {
218            errors.extend(e);
219        }
220        if let Err(e) =
221            check_type_and_property_and_edge_invariants(query_type_definition, &vertex_types)
222        {
223            errors.extend(e);
224        }
225        if let Err(e) = check_root_query_type_invariants(query_type_definition, &query_type) {
226            errors.extend(e);
227        }
228
229        let field_origins = match get_field_origins(&vertex_types) {
230            Ok(field_origins) => {
231                if let Err(e) = check_ambiguous_field_origins(&fields, &field_origins) {
232                    errors.extend(e);
233                }
234                Some(field_origins)
235            }
236            Err(e) => {
237                errors.push(e);
238                None
239            }
240        };
241
242        if errors.is_empty() {
243            Ok(Self {
244                schema,
245                query_type,
246                directives,
247                scalars,
248                vertex_types,
249                fields,
250                field_origins: field_origins.expect("no field origins but also no errors"),
251            })
252        } else {
253            Err(errors.into())
254        }
255    }
256
257    /// If the named type is defined, iterate through the names of its subtypes including itself.
258    /// Otherwise, return None.
259    pub fn subtypes<'a, 'slf: 'a>(
260        &'slf self,
261        type_name: &'a str,
262    ) -> Option<impl Iterator<Item = &'slf str> + 'a> {
263        if !self.vertex_types.contains_key(type_name) {
264            return None;
265        }
266
267        Some(self.vertex_types.iter().sorted_by_key(|(name, _)| *name).filter_map(
268            move |(name, defn)| {
269                if name.as_ref() == type_name
270                    || get_vertex_type_implements(defn).iter().any(|x| x.node.as_ref() == type_name)
271                {
272                    Some(name.as_ref())
273                } else {
274                    None
275                }
276            },
277        ))
278    }
279
280    pub(crate) fn query_type_name(&self) -> &str {
281        self.schema.query.as_ref().unwrap().node.as_ref()
282    }
283
284    pub(crate) fn vertex_type_implements(&self, vertex_type: &str) -> &[Positioned<Name>] {
285        get_vertex_type_implements(&self.vertex_types[vertex_type])
286    }
287
288    pub(crate) fn is_subtype(
289        &self,
290        parent_type: &async_graphql_parser::types::Type,
291        maybe_subtype: &async_graphql_parser::types::Type,
292    ) -> bool {
293        is_subtype(&self.vertex_types, parent_type, maybe_subtype)
294    }
295
296    pub(crate) fn is_named_type_subtype(&self, parent_type: &str, maybe_subtype: &str) -> bool {
297        is_named_type_subtype(&self.vertex_types, parent_type, maybe_subtype)
298    }
299}
300
301fn check_root_query_type_invariants(
302    query_type_definition: &TypeDefinition,
303    query_type: &ObjectType,
304) -> Result<(), Vec<InvalidSchemaError>> {
305    let mut errors: Vec<InvalidSchemaError> = vec![];
306
307    for field_defn in &query_type.fields {
308        let field_type = Type::from_type(&field_defn.node.ty.node);
309        let base_named_type = field_type.base_type();
310        if get_builtin_scalars().contains(base_named_type) {
311            errors.push(InvalidSchemaError::PropertyFieldOnRootQueryType(
312                query_type_definition.name.node.to_string(),
313                field_defn.node.name.node.to_string(),
314                field_type.to_string(),
315            ));
316        }
317
318        // The invariant that vertex_types.contains_key(base_named_type) is
319        // ensured by check_type_and_property_and_edge_invariants. This is also
320        // verified by these tests:
321        // unknown_type_not_on_root
322        // unknown_type_on_root
323        // unknown_type_on_root_and_outside
324    }
325
326    if errors.is_empty() {
327        Ok(())
328    } else {
329        Err(errors)
330    }
331}
332
333fn check_type_and_property_and_edge_invariants(
334    query_type_definition: &TypeDefinition,
335    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
336) -> Result<(), Vec<InvalidSchemaError>> {
337    let mut errors: Vec<InvalidSchemaError> = vec![];
338
339    for (type_name, type_defn) in vertex_types.iter().sorted_by_key(|(name, _)| *name) {
340        if type_name.as_ref().starts_with(RESERVED_PREFIX) {
341            errors.push(InvalidSchemaError::ReservedTypeName(type_name.to_string()));
342        }
343
344        let type_fields = get_vertex_type_fields(type_defn);
345
346        for defn in type_fields {
347            let field_defn = &defn.node;
348            let field_type = &field_defn.ty.node;
349
350            if field_defn.name.node.as_ref().starts_with(RESERVED_PREFIX) {
351                errors.push(InvalidSchemaError::ReservedFieldName(
352                    type_name.to_string(),
353                    field_defn.name.node.to_string(),
354                ));
355            }
356
357            let field_type = Type::from_type(field_type);
358
359            let base_named_type = field_type.base_type();
360            if get_builtin_scalars().contains(base_named_type) {
361                // We're looking at a property field.
362                if !field_defn.arguments.is_empty() {
363                    errors.push(InvalidSchemaError::PropertyFieldWithParameters(
364                        type_name.to_string(),
365                        field_defn.name.node.to_string(),
366                        field_type.to_string(),
367                        field_defn.arguments.iter().map(|x| x.node.name.node.to_string()).collect(),
368                    ));
369                }
370            } else if vertex_types.contains_key(base_named_type) {
371                // We're looking at an edge field.
372                if base_named_type == query_type_definition.name.node.as_ref() {
373                    // This edge points to the root query type. That's not supported.
374                    errors.push(InvalidSchemaError::EdgePointsToRootQueryType(
375                        type_name.to_string(),
376                        field_defn.name.node.to_string(),
377                        field_type.to_string(),
378                    ));
379                } else {
380                    // Check if the parameters this edge accepts (if any) have valid default values.
381                    for param_defn in &field_defn.arguments {
382                        if let Some(value) = &param_defn.node.default_value {
383                            let param_type = &param_defn.node.ty.node;
384                            match value.node.clone().try_into() {
385                                Ok(value) => {
386                                    if !Type::from_type(param_type).is_valid_value(&value) {
387                                        errors.push(InvalidSchemaError::InvalidDefaultValueForFieldParameter(
388                                            type_name.to_string(),
389                                            field_defn.name.node.to_string(),
390                                            param_defn.node.name.node.to_string(),
391                                            param_type.to_string(),
392                                            format!("{value:?}"),
393                                        ));
394                                    }
395                                }
396                                Err(_) => {
397                                    errors.push(
398                                        InvalidSchemaError::InvalidDefaultValueForFieldParameter(
399                                            type_name.to_string(),
400                                            field_defn.name.node.to_string(),
401                                            param_defn.node.name.node.to_string(),
402                                            param_type.to_string(),
403                                            value.node.to_string(),
404                                        ),
405                                    );
406                                }
407                            }
408                        }
409                    }
410
411                    // Check that the edge field doesn't have
412                    // a list-of-list or more nested list type.
413                    if let Some(inner_list) = field_type.as_list() {
414                        if inner_list.is_list() {
415                            errors.push(InvalidSchemaError::InvalidEdgeType(
416                                type_name.to_string(),
417                                field_defn.name.node.to_string(),
418                                field_type.to_string(),
419                            ));
420                        }
421                    }
422                }
423            } else {
424                errors.push(InvalidSchemaError::UnknownPropertyOrEdgeType(
425                    field_defn.name.node.as_ref().to_string(),
426                    field_type.to_string(),
427                ))
428            }
429        }
430    }
431
432    if errors.is_empty() {
433        Ok(())
434    } else {
435        Err(errors)
436    }
437}
438
439fn is_named_type_subtype(
440    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
441    parent_type: &str,
442    maybe_subtype: &str,
443) -> bool {
444    let parent_is_vertex = vertex_types.contains_key(parent_type);
445    let maybe_sub = vertex_types.get(maybe_subtype);
446
447    match (parent_is_vertex, maybe_sub) {
448        (false, None) => {
449            // The types could both be scalars, which have no inheritance hierarchy.
450            // Any type is a subtype of itself, so we check equality.
451            parent_type == maybe_subtype
452        }
453        (true, Some(maybe_subtype_vertex)) => {
454            // Both types are vertex types. We have a subtype relationship if
455            // - the two types are actually the same type, or if
456            // - the "maybe subtype" implements the parent type.
457            parent_type == maybe_subtype
458                || get_vertex_type_implements(maybe_subtype_vertex)
459                    .iter()
460                    .any(|pos| pos.node.as_ref() == parent_type)
461        }
462        _ => {
463            // One type is a vertex type, the other should be a scalar.
464            // No subtype relationship is possible between them.
465            false
466        }
467    }
468}
469
470fn is_subtype(
471    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
472    parent_type: &async_graphql_parser::types::Type,
473    maybe_subtype: &async_graphql_parser::types::Type,
474) -> bool {
475    // If the parent type is non-nullable, all its subtypes must be non-nullable as well.
476    // If the parent type is nullable, it can have both nullable and non-nullable subtypes.
477    if !parent_type.nullable && maybe_subtype.nullable {
478        return false;
479    }
480
481    match (&parent_type.base, &maybe_subtype.base) {
482        (BaseType::Named(parent), BaseType::Named(subtype)) => {
483            is_named_type_subtype(vertex_types, parent.as_ref(), subtype.as_ref())
484        }
485        (BaseType::List(parent_type), BaseType::List(maybe_subtype)) => {
486            is_subtype(vertex_types, parent_type, maybe_subtype)
487        }
488        (BaseType::Named(..), BaseType::List(..)) | (BaseType::List(..), BaseType::Named(..)) => {
489            false
490        }
491    }
492}
493
494fn check_ambiguous_field_origins(
495    fields: &HashMap<(Arc<str>, Arc<str>), FieldDefinition>,
496    field_origins: &BTreeMap<(Arc<str>, Arc<str>), FieldOrigin>,
497) -> Result<(), Vec<InvalidSchemaError>> {
498    let mut errors = vec![];
499
500    for (key, origin) in field_origins {
501        let (type_name, field_name) = key;
502        if let FieldOrigin::MultipleAncestors(ancestors) = &origin {
503            let field_type = fields[key].ty.node.to_string();
504            errors.push(InvalidSchemaError::AmbiguousFieldOrigin(
505                type_name.to_string(),
506                field_name.to_string(),
507                field_type,
508                ancestors.iter().map(|x| x.to_string()).collect(),
509            ))
510        }
511    }
512
513    if errors.is_empty() {
514        Ok(())
515    } else {
516        Err(errors)
517    }
518}
519
520/// Check the `implements` portion of the type definitions.
521///
522/// Checked invariants:
523/// - Implemented types must be defined in the schema.
524/// - Implemented types must be interfaces.
525/// - If type X implements interface A, and A implements interface B,
526///   then X must also implement B by transitivity.
527fn check_required_transitive_implementations(
528    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
529) -> Result<(), Vec<InvalidSchemaError>> {
530    let mut errors: Vec<InvalidSchemaError> = vec![];
531
532    for (type_name, type_defn) in vertex_types.iter().sorted_by_key(|(name, _)| *name) {
533        let implementations: BTreeSet<&str> =
534            get_vertex_type_implements(type_defn).iter().map(|x| x.node.as_ref()).collect();
535
536        // Check the `implements` portion of the type definition.
537        for implements_type in implementations.iter().copied() {
538            match vertex_types.get(implements_type) {
539                Some(implementation_defn) => {
540                    if !matches!(implementation_defn.kind, TypeKind::Interface(..)) {
541                        errors.push(InvalidSchemaError::ImplementingNonInterface(
542                            type_name.to_string(),
543                            implements_type.to_string(),
544                        ));
545                    } else {
546                        for expected_impl in get_vertex_type_implements(implementation_defn) {
547                            let expected_impl_name = expected_impl.node.as_ref();
548
549                            // Ignore situations with an immediate cycle here
550                            // (`expected_impl_name != type_name`) since we have a dedicated
551                            // check for those elsewhere.
552                            if expected_impl_name != type_name.as_ref()
553                                && !implementations.contains(expected_impl_name)
554                            {
555                                errors.push(
556                                    InvalidSchemaError::MissingTransitiveInterfaceImplementation(
557                                        type_name.to_string(),
558                                        implements_type.to_string(),
559                                        expected_impl_name.to_string(),
560                                    ),
561                                );
562                            }
563                        }
564                    }
565                }
566                None => {
567                    errors.push(InvalidSchemaError::ImplementingNonExistentType(
568                        type_name.to_string(),
569                        implements_type.to_string(),
570                    ));
571                }
572            }
573        }
574    }
575
576    if errors.is_empty() {
577        Ok(())
578    } else {
579        Err(errors)
580    }
581}
582
583fn check_fields_required_by_interface_implementations(
584    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
585    fields: &HashMap<(Arc<str>, Arc<str>), FieldDefinition>,
586) -> Result<(), Vec<InvalidSchemaError>> {
587    let mut errors: Vec<InvalidSchemaError> = vec![];
588
589    for (type_name, type_defn) in vertex_types.iter().sorted_by_key(|(name, _)| *name) {
590        let implementations = get_vertex_type_implements(type_defn);
591
592        for implementation in implementations {
593            let implementation = implementation.node.as_ref();
594            let Some(impl_defn) = vertex_types.get(implementation) else {
595                continue;
596            };
597
598            for field in get_vertex_type_fields(impl_defn) {
599                let field_name = field.node.name.node.as_ref();
600
601                // If the current type does not contain the implemented interface's field,
602                // that's an error.
603                if !fields.contains_key(&(type_name.clone(), Arc::from(field_name))) {
604                    errors.push(InvalidSchemaError::MissingRequiredField(
605                        type_name.to_string(),
606                        implementation.to_string(),
607                        field_name.to_string(),
608                        field.node.ty.node.to_string(),
609                    ))
610                }
611            }
612        }
613    }
614
615    if errors.is_empty() {
616        Ok(())
617    } else {
618        Err(errors)
619    }
620}
621
622fn check_field_type_narrowing(
623    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
624    fields: &HashMap<(Arc<str>, Arc<str>), FieldDefinition>,
625) -> Result<(), Vec<InvalidSchemaError>> {
626    let mut errors: Vec<InvalidSchemaError> = vec![];
627
628    for (type_name, type_defn) in vertex_types.iter().sorted_by_key(|(name, _)| *name) {
629        let implementations = get_vertex_type_implements(type_defn);
630        let type_fields = get_vertex_type_fields(type_defn);
631
632        for field in type_fields {
633            let field_name = field.node.name.node.as_ref();
634            let field_type = &field.node.ty.node;
635            let field_parameters: BTreeMap<_, _> = field
636                .node
637                .arguments
638                .iter()
639                .map(|arg| (arg.node.name.node.as_ref(), &arg.node.ty.node))
640                .collect();
641
642            for implementation in implementations {
643                let implementation = implementation.node.as_ref();
644
645                // The parent type might not contain this field. But if it does,
646                // ensure that the parent field's type is a supertype of the current field's type.
647                if let Some(parent_field) =
648                    fields.get(&(Arc::from(implementation), Arc::from(field_name)))
649                {
650                    let parent_field_type = &parent_field.ty.node;
651                    if !is_subtype(vertex_types, parent_field_type, field_type) {
652                        errors.push(InvalidSchemaError::InvalidTypeWideningOfInheritedField(
653                            field_name.to_string(),
654                            type_name.to_string(),
655                            implementation.to_string(),
656                            field_type.to_string(),
657                            parent_field_type.to_string(),
658                        ));
659                    }
660
661                    let parent_field_parameters: BTreeMap<_, _> = parent_field
662                        .arguments
663                        .iter()
664                        .map(|arg| (arg.node.name.node.as_ref(), &arg.node.ty.node))
665                        .collect();
666
667                    // Check for field parameters that the parent type requires but
668                    // the child type does not accept.
669                    let missing_parameters = parent_field_parameters
670                        .keys()
671                        .copied()
672                        .filter(|name| !field_parameters.contains_key(*name))
673                        .collect_vec();
674                    if !missing_parameters.is_empty() {
675                        errors.push(InvalidSchemaError::InheritedFieldMissingParameters(
676                            field_name.to_owned(),
677                            type_name.to_string(),
678                            implementation.to_owned(),
679                            missing_parameters.into_iter().map(ToOwned::to_owned).collect_vec(),
680                        ));
681                    }
682
683                    // Check for field parameters that the parent type does not accept,
684                    // but the child type defines anyway.
685                    let unexpected_parameters = field_parameters
686                        .keys()
687                        .copied()
688                        .filter(|name| !parent_field_parameters.contains_key(*name))
689                        .collect_vec();
690                    if !unexpected_parameters.is_empty() {
691                        errors.push(InvalidSchemaError::InheritedFieldUnexpectedParameters(
692                            field_name.to_owned(),
693                            type_name.to_string(),
694                            implementation.to_owned(),
695                            unexpected_parameters.into_iter().map(ToOwned::to_owned).collect_vec(),
696                        ));
697                    }
698
699                    // Check that all field parameters defined by the child have types
700                    // that are legal widenings of the corresponding field parameter's type
701                    // on the parent type. Field parameters are contravariant, hence widenings.
702                    for (&field_parameter, &field_type) in &field_parameters {
703                        if let Some(&parent_field_type) =
704                            parent_field_parameters.get(field_parameter)
705                        {
706                            if !Type::from_type(field_type)
707                                .is_scalar_only_subtype(&Type::from_type(parent_field_type))
708                            {
709                                errors.push(InvalidSchemaError::InvalidTypeNarrowingOfInheritedFieldParameter(
710                                    field_name.to_owned(),
711                                    type_name.to_string(),
712                                    implementation.to_owned(),
713                                    field_parameter.to_string(),
714                                    field_type.to_string(),
715                                    parent_field_type.to_string(),
716                                ));
717                            }
718                        }
719                    }
720                }
721            }
722        }
723    }
724
725    if errors.is_empty() {
726        Ok(())
727    } else {
728        Err(errors)
729    }
730}
731
732fn get_vertex_type_fields(vertex: &TypeDefinition) -> &[Positioned<FieldDefinition>] {
733    match &vertex.kind {
734        TypeKind::Object(obj) => &obj.fields,
735        TypeKind::Interface(iface) => &iface.fields,
736        _ => unreachable!(),
737    }
738}
739
740fn get_vertex_type_implements(vertex: &TypeDefinition) -> &[Positioned<Name>] {
741    match &vertex.kind {
742        TypeKind::Object(obj) => &obj.implements,
743        TypeKind::Interface(iface) => &iface.implements,
744        _ => unreachable!(),
745    }
746}
747
748#[allow(clippy::type_complexity)]
749fn get_field_origins(
750    vertex_types: &HashMap<Arc<str>, TypeDefinition>,
751) -> Result<BTreeMap<(Arc<str>, Arc<str>), FieldOrigin>, InvalidSchemaError> {
752    let mut field_origins: BTreeMap<(Arc<str>, Arc<str>), FieldOrigin> = Default::default();
753    let mut queue = VecDeque::new();
754
755    // for each type, which types have yet to have their field origins resolved first
756    let mut required_resolutions: BTreeMap<&str, BTreeSet<&str>> = vertex_types
757        .iter()
758        .sorted_by_key(|(name, _)| *name)
759        .map(|(name, defn)| {
760            let resolutions: BTreeSet<&str> = get_vertex_type_implements(defn)
761                .iter()
762                .map(|x| x.node.as_ref())
763                .filter(|name| vertex_types.contains_key(*name)) // ignore undefined types
764                .collect();
765            if resolutions.is_empty() {
766                queue.push_back(name);
767            }
768            (name.as_ref(), resolutions)
769        })
770        .collect();
771
772    // for each type, which types does it enable resolution of
773    let resolvers: BTreeMap<&str, BTreeSet<Arc<str>>> = vertex_types
774        .iter()
775        .sorted_by_key(|(name, _)| *name)
776        .flat_map(|(name, defn)| {
777            get_vertex_type_implements(defn)
778                .iter()
779                .map(|x| (x.node.as_ref(), Arc::from(name.as_ref())))
780        })
781        .fold(Default::default(), |mut acc, (interface, implementer)| {
782            match acc.entry(interface) {
783                Entry::Vacant(v) => {
784                    v.insert(btreeset![implementer]);
785                }
786                Entry::Occupied(occ) => {
787                    occ.into_mut().insert(implementer);
788                }
789            }
790            acc
791        });
792
793    while let Some(type_name) = queue.pop_front() {
794        let defn = &vertex_types[type_name];
795        let implements = get_vertex_type_implements(defn);
796        let fields = get_vertex_type_fields(defn);
797
798        let mut implemented_fields: BTreeMap<&str, FieldOrigin> = Default::default();
799        for implemented_interface in implements {
800            let implemented_interface = implemented_interface.node.as_ref();
801            let Some(implemented_defn) = vertex_types.get(implemented_interface) else {
802                continue;
803            };
804            let parent_fields = get_vertex_type_fields(implemented_defn);
805            for field in parent_fields {
806                let parent_field_origin = &field_origins
807                    [&(Arc::from(implemented_interface), Arc::from(field.node.name.node.as_ref()))];
808
809                implemented_fields
810                    .entry(field.node.name.node.as_ref())
811                    .and_modify(|origin| *origin = (origin as &FieldOrigin) + parent_field_origin)
812                    .or_insert_with(|| parent_field_origin.clone());
813            }
814        }
815
816        for field in fields {
817            let field = &field.node;
818            let field_name = &field.name.node;
819
820            let origin = implemented_fields
821                .remove(field_name.as_ref())
822                .unwrap_or_else(|| FieldOrigin::SingleAncestor(type_name.clone()));
823            field_origins
824                .insert_or_error((type_name.clone(), Arc::from(field_name.as_ref())), origin)
825                .unwrap();
826        }
827
828        if let Some(next_types) = resolvers.get(type_name.as_ref()) {
829            for next_type in next_types.iter() {
830                let remaining = required_resolutions.get_mut(next_type.as_ref()).unwrap();
831                if remaining.remove(type_name.as_ref()) && remaining.is_empty() {
832                    queue.push_back(next_type);
833                }
834            }
835        }
836    }
837
838    for (required, mut remaining) in required_resolutions.into_iter() {
839        if !remaining.is_empty() {
840            remaining.insert(required);
841            let circular_implementations =
842                remaining.into_iter().map(|x| x.to_string()).collect_vec();
843            return Err(InvalidSchemaError::CircularImplementsRelationships(
844                circular_implementations,
845            ));
846        }
847    }
848
849    Ok(field_origins)
850}
851
#[cfg(test)]
mod tests {
    use std::{
        fs,
        path::{Path, PathBuf},
    };

    use async_graphql_parser::parse_schema;
    use itertools::Itertools;
    use trustfall_filetests_macros::parameterize;

    use super::{error::InvalidSchemaError, Schema};

    /// Each `*.graphql` file under `schema_errors` must be rejected by `Schema::new`
    /// with exactly the error stored in the matching `*.schema-error.ron` file.
    #[parameterize("trustfall_core/test_data/tests/schema_errors", "*.graphql")]
    fn schema_errors(base: &Path, stem: &str) {
        let input_path: PathBuf = base.join(format!("{stem}.graphql"));
        let error_path: PathBuf = base.join(format!("{stem}.schema-error.ron"));

        let input_data = fs::read_to_string(input_path).unwrap();
        let error_data = fs::read_to_string(error_path).unwrap();
        let expected_error: InvalidSchemaError = ron::from_str(&error_data).unwrap();

        let schema_doc = parse_schema(input_data).unwrap();
        match Schema::new(schema_doc) {
            Ok(_) => panic!("Expected an error but got valid schema."),
            Err(e) => assert_eq!(e, expected_error),
        }
    }

    /// Each `*.graphql` file under `valid_schemas` must include the built-in directive
    /// definitions and must parse into a valid `Schema`.
    #[parameterize("trustfall_core/test_data/tests/valid_schemas", "*.graphql")]
    fn valid_schemas(base: &Path, stem: &str) {
        let input_path: PathBuf = base.join(format!("{stem}.graphql"));
        let input_data = fs::read_to_string(input_path).unwrap();

        // Ensure all test schemas contain the directive definitions this module promises are valid.
        assert!(input_data.contains(Schema::ALL_DIRECTIVE_DEFINITIONS));

        if let Err(e) = Schema::parse(input_data) {
            panic!("{}", e);
        }
    }

    /// `Schema::subtypes` returns `None` for unknown types. For known types it returns the
    /// type itself plus all of its transitive implementers.
    #[test]
    fn schema_subtypes() {
        let schema = Schema::parse(include_str!("../../test_data/schemas/numbers.graphql"))
            .expect("valid schema");

        assert!(schema.subtypes("Nonexistent").is_none());

        let composite_subtypes = schema.subtypes("Composite").unwrap().collect_vec();
        assert_eq!(vec!["Composite"], composite_subtypes);

        let number_subtypes = schema.subtypes("Number").unwrap().sorted().collect_vec();
        assert_eq!(vec!["Composite", "Neither", "Number", "Prime"], number_subtypes);
    }
}