Skip to main content

graphql_tools/ast/
ext.rs

1use std::collections::HashMap;
2
3use lazy_static::lazy_static;
4
5use crate::static_graphql::query::{
6    self, Directive, FragmentSpread, OperationDefinition, SelectionSet, Type, Value,
7    VariableDefinition,
8};
9use crate::static_graphql::schema::{
10    self, DirectiveDefinition, EnumValue, Field, InputValue, InterfaceType, ObjectType,
11    TypeDefinition, TypeExtension, UnionType,
12};
13
14lazy_static! {
15    static ref QUERY_TYPE_DEFAULT_NAME: String = "Query".to_string();
16    static ref MUTATION_TYPE_DEFAULT_NAME: String = "Mutation".to_string();
17    static ref SUBSCRIPTION_TYPE_DEFAULT_NAME: String = "Subscription".to_string();
18}
19
20impl TypeDefinition {
21    pub fn field_by_name(&self, name: &str) -> Option<&schema::Field> {
22        match self {
23            TypeDefinition::Object(object) => {
24                object.fields.iter().find(|field| field.name.eq(name))
25            }
26            TypeDefinition::Interface(interface) => {
27                interface.fields.iter().find(|field| field.name.eq(name))
28            }
29            _ => None,
30        }
31    }
32
33    pub fn input_field_by_name(&self, name: &str) -> Option<&InputValue> {
34        match self {
35            TypeDefinition::InputObject(input_object) => {
36                input_object.fields.iter().find(|field| field.name.eq(name))
37            }
38            _ => None,
39        }
40    }
41}
42
43impl OperationDefinition {
44    pub fn variable_definitions(&self) -> &[VariableDefinition] {
45        match self {
46            OperationDefinition::Query(query) => &query.variable_definitions,
47            OperationDefinition::SelectionSet(_) => &[],
48            OperationDefinition::Mutation(mutation) => &mutation.variable_definitions,
49            OperationDefinition::Subscription(subscription) => &subscription.variable_definitions,
50        }
51    }
52
53    pub fn selection_set(&self) -> &SelectionSet {
54        match self {
55            OperationDefinition::Query(query) => &query.selection_set,
56            OperationDefinition::SelectionSet(selection_set) => selection_set,
57            OperationDefinition::Mutation(mutation) => &mutation.selection_set,
58            OperationDefinition::Subscription(subscription) => &subscription.selection_set,
59        }
60    }
61
62    pub fn directives(&self) -> &[Directive] {
63        match self {
64            OperationDefinition::Query(query) => &query.directives,
65            OperationDefinition::SelectionSet(_) => &[],
66            OperationDefinition::Mutation(mutation) => &mutation.directives,
67            OperationDefinition::Subscription(subscription) => &subscription.directives,
68        }
69    }
70}
71
72impl schema::Document {
73    pub fn type_by_name(&self, name: &str) -> Option<&TypeDefinition> {
74        for def in &self.definitions {
75            if let schema::Definition::TypeDefinition(type_def) = def {
76                if type_def.name().eq(name) {
77                    return Some(type_def);
78                }
79            }
80        }
81
82        None
83    }
84
85    pub fn directive_by_name(&self, name: &str) -> Option<&DirectiveDefinition> {
86        for def in &self.definitions {
87            if let schema::Definition::DirectiveDefinition(directive_def) = def {
88                if directive_def.name.eq(name) {
89                    return Some(directive_def);
90                }
91            }
92        }
93
94        None
95    }
96
97    fn schema_definition(&self) -> &schema::SchemaDefinition {
98        lazy_static! {
99            static ref DEFAULT_SCHEMA_DEF: schema::SchemaDefinition = {
100                schema::SchemaDefinition {
101                    query: Some("Query".to_string()),
102                    ..Default::default()
103                }
104            };
105        }
106        self.definitions
107            .iter()
108            .find_map(|definition| match definition {
109                schema::Definition::SchemaDefinition(schema_definition) => Some(schema_definition),
110                _ => None,
111            })
112            .unwrap_or(&*DEFAULT_SCHEMA_DEF)
113    }
114
115    pub fn query_type(&self) -> &ObjectType {
116        let schema_definition = self.schema_definition();
117        self.object_type_by_name(
118            schema_definition
119                .query
120                .as_ref()
121                .unwrap_or(&QUERY_TYPE_DEFAULT_NAME),
122        )
123        .unwrap()
124    }
125
126    pub fn mutation_type(&self) -> Option<&ObjectType> {
127        let schema_definition = self.schema_definition();
128        self.object_type_by_name(
129            schema_definition
130                .mutation
131                .as_ref()
132                .unwrap_or(&MUTATION_TYPE_DEFAULT_NAME),
133        )
134    }
135
136    pub fn subscription_type(&self) -> Option<&ObjectType> {
137        let schema_definition = self.schema_definition();
138
139        self.object_type_by_name(
140            schema_definition
141                .subscription
142                .as_ref()
143                .unwrap_or(&SUBSCRIPTION_TYPE_DEFAULT_NAME),
144        )
145    }
146
147    fn object_type_by_name(&self, name: &str) -> Option<&ObjectType> {
148        match self.type_by_name(name) {
149            Some(TypeDefinition::Object(object_def)) => Some(object_def),
150            _ => None,
151        }
152    }
153
154    pub fn type_map(&self) -> HashMap<&str, &TypeDefinition> {
155        let mut type_map = HashMap::new();
156
157        for def in &self.definitions {
158            if let schema::Definition::TypeDefinition(type_def) = def {
159                type_map.insert(type_def.name(), type_def);
160            }
161        }
162
163        type_map
164    }
165
166    pub fn is_named_subtype(&self, sub_type_name: &str, super_type_name: &str) -> bool {
167        if sub_type_name == super_type_name {
168            true
169        } else if let (Some(sub_type), Some(super_type)) = (
170            self.type_by_name(sub_type_name),
171            self.type_by_name(super_type_name),
172        ) {
173            super_type.is_abstract_type() && self.is_possible_type(super_type, sub_type)
174        } else {
175            false
176        }
177    }
178
179    fn is_possible_type(
180        &self,
181        abstract_type: &TypeDefinition,
182        possible_type: &TypeDefinition,
183    ) -> bool {
184        match abstract_type {
185            TypeDefinition::Union(union_typedef) => union_typedef
186                .types
187                .iter()
188                .any(|t| t == possible_type.name()),
189            TypeDefinition::Interface(interface_typedef) => {
190                let implementes_interfaces = possible_type.interfaces();
191
192                implementes_interfaces.contains(&interface_typedef.name)
193            }
194            _ => false,
195        }
196    }
197
198    pub fn is_subtype(&self, sub_type: &Type, super_type: &Type) -> bool {
199        // Equivalent type is a valid subtype
200        if sub_type == super_type {
201            return true;
202        }
203
204        // If superType is non-null, maybeSubType must also be non-null.
205        if super_type.is_non_null() {
206            if sub_type.is_non_null() {
207                return self.is_subtype(sub_type.of_type(), super_type.of_type());
208            }
209            return false;
210        }
211
212        if sub_type.is_non_null() {
213            // If superType is nullable, maybeSubType may be non-null or nullable.
214            return self.is_subtype(sub_type.of_type(), super_type);
215        }
216
217        // If superType type is a list, maybeSubType type must also be a list.
218        if super_type.is_list_type() {
219            if sub_type.is_list_type() {
220                return self.is_subtype(sub_type.of_type(), super_type.of_type());
221            }
222
223            return false;
224        }
225
226        if sub_type.is_list_type() {
227            // If superType is nullable, maybeSubType may be non-null or nullable.
228            return false;
229        }
230
231        // If superType type is an abstract type, check if it is super type of maybeSubType.
232        // Otherwise, the child type is not a valid subtype of the parent type.
233        if let (Some(sub_type), Some(super_type)) = (
234            self.type_by_name(sub_type.inner_type()),
235            self.type_by_name(super_type.inner_type()),
236        ) {
237            return super_type.is_abstract_type()
238                && (sub_type.is_interface_type() || sub_type.is_object_type())
239                && self.is_possible_type(super_type, sub_type);
240        }
241
242        false
243    }
244
245    pub fn query_type_name(&self) -> &str {
246        "Query"
247    }
248
249    pub fn mutation_type_name(&self) -> Option<&str> {
250        for def in &self.definitions {
251            if let schema::Definition::SchemaDefinition(schema_def) = def {
252                if let Some(name) = schema_def.mutation.as_ref() {
253                    return Some(name.as_str());
254                }
255            }
256        }
257
258        self.type_by_name("Mutation").map(|typ| typ.name())
259    }
260
261    pub fn subscription_type_name(&self) -> Option<&str> {
262        for def in &self.definitions {
263            if let schema::Definition::SchemaDefinition(schema_def) = def {
264                if let Some(name) = schema_def.subscription.as_ref() {
265                    return Some(name.as_str());
266                }
267            }
268        }
269
270        self.type_by_name("Subscription").map(|typ| typ.name())
271    }
272}
273
274impl Type {
275    pub fn inner_type(&self) -> &str {
276        match self {
277            Type::NamedType(name) => name.as_str(),
278            Type::ListType(child) => child.inner_type(),
279            Type::NonNullType(child) => child.inner_type(),
280        }
281    }
282
283    fn of_type(&self) -> &Type {
284        match self {
285            Type::ListType(child) => child,
286            Type::NonNullType(child) => child,
287            Type::NamedType(_) => self,
288        }
289    }
290
291    pub fn is_non_null(&self) -> bool {
292        matches!(self, Type::NonNullType(_))
293    }
294
295    fn is_list_type(&self) -> bool {
296        matches!(self, Type::ListType(_))
297    }
298
299    pub fn is_named_type(&self) -> bool {
300        matches!(self, Type::NamedType(_))
301    }
302}
303
304impl Value {
305    pub fn compare(&self, other: &Self) -> bool {
306        match (self, other) {
307            (Value::Null, Value::Null) => true,
308            (Value::Boolean(a), Value::Boolean(b)) => a == b,
309            (Value::Int(a), Value::Int(b)) => a == b,
310            (Value::Float(a), Value::Float(b)) => a == b,
311            (Value::String(a), Value::String(b)) => a.eq(b),
312            (Value::Enum(a), Value::Enum(b)) => a.eq(b),
313            (Value::List(a), Value::List(b)) => a.iter().zip(b.iter()).all(|(a, b)| a.compare(b)),
314            (Value::Object(a), Value::Object(b)) => {
315                a.iter().zip(b.iter()).all(|(a, b)| a.1.compare(b.1))
316            }
317            (Value::Variable(a), Value::Variable(b)) => a.eq(b),
318            _ => false,
319        }
320    }
321
322    pub fn variables_in_use(&self) -> Vec<&str> {
323        match self {
324            Value::Variable(v) => vec![v],
325            Value::List(list) => list.iter().flat_map(|v| v.variables_in_use()).collect(),
326            Value::Object(object) => object.values().flat_map(|v| v.variables_in_use()).collect(),
327            _ => vec![],
328        }
329    }
330}
331
332impl InputValue {
333    pub fn is_required(&self) -> bool {
334        if let Type::NonNullType(_inner_type) = &self.value_type {
335            if self.default_value.is_none() {
336                return true;
337            }
338        }
339
340        false
341    }
342}
343
344impl TypeDefinition {
345    fn interfaces(&self) -> Vec<String> {
346        match self {
347            schema::TypeDefinition::Object(o) => o.interfaces(),
348            schema::TypeDefinition::Interface(i) => i.interfaces(),
349            _ => vec![],
350        }
351    }
352
353    pub fn has_sub_type(&self, other_type: &TypeDefinition) -> bool {
354        match self {
355            TypeDefinition::Interface(interface_type) => {
356                interface_type.is_implemented_by(other_type)
357            }
358            TypeDefinition::Union(union_type) => union_type.has_sub_type(other_type.name()),
359            _ => false,
360        }
361    }
362
363    pub fn has_concrete_sub_type(&self, concrete_type: &TypeDefinition) -> bool {
364        match self {
365            TypeDefinition::Interface(interface_type) => {
366                interface_type.is_implemented_by(concrete_type)
367            }
368            TypeDefinition::Union(union_type) => union_type.has_sub_type(concrete_type.name()),
369            _ => false,
370        }
371    }
372}
373
374impl TypeDefinition {
375    pub fn possible_types<'a>(&self, schema: &'a schema::Document) -> Vec<&'a TypeDefinition> {
376        match self {
377            TypeDefinition::Object(_) => vec![],
378            TypeDefinition::InputObject(_) => vec![],
379            TypeDefinition::Enum(_) => vec![],
380            TypeDefinition::Scalar(_) => vec![],
381            TypeDefinition::Interface(i) => schema
382                .type_map()
383                .values()
384                .filter_map(|type_def| {
385                    if i.is_implemented_by(type_def) {
386                        return Some(*type_def);
387                    }
388
389                    None
390                })
391                .collect(),
392            TypeDefinition::Union(u) => u
393                .types
394                .iter()
395                .filter_map(|type_name| {
396                    if let Some(type_def) = schema.type_by_name(type_name) {
397                        return Some(type_def);
398                    }
399
400                    None
401                })
402                .collect(),
403        }
404    }
405}
406
407impl InterfaceType {
408    fn interfaces(&self) -> Vec<String> {
409        self.implements_interfaces.clone()
410    }
411
412    pub fn has_sub_type(&self, other_type: &TypeDefinition) -> bool {
413        self.is_implemented_by(other_type)
414    }
415
416    pub fn has_concrete_sub_type(&self, concrete_type: &TypeDefinition) -> bool {
417        self.is_implemented_by(concrete_type)
418    }
419}
420
421impl ObjectType {
422    fn interfaces(&self) -> Vec<String> {
423        self.implements_interfaces.clone()
424    }
425
426    pub fn has_sub_type(&self, _other_type: &TypeDefinition) -> bool {
427        false
428    }
429
430    pub fn has_concrete_sub_type(&self, _concrete_type: &ObjectType) -> bool {
431        false
432    }
433}
434
435impl UnionType {
436    pub fn has_sub_type(&self, other_type_name: &str) -> bool {
437        self.types.iter().any(|v| other_type_name.eq(v))
438    }
439}
440
441impl InterfaceType {
442    pub fn is_implemented_by(&self, other_type: &TypeDefinition) -> bool {
443        other_type.interfaces().iter().any(|v| self.name.eq(v))
444    }
445}
446
447impl schema::TypeDefinition {
448    pub fn name(&self) -> &str {
449        match self {
450            schema::TypeDefinition::Object(o) => &o.name,
451            schema::TypeDefinition::Interface(i) => &i.name,
452            schema::TypeDefinition::Union(u) => &u.name,
453            schema::TypeDefinition::Scalar(s) => &s.name,
454            schema::TypeDefinition::Enum(e) => &e.name,
455            schema::TypeDefinition::InputObject(i) => &i.name,
456        }
457    }
458
459    pub fn is_abstract_type(&self) -> bool {
460        matches!(
461            self,
462            schema::TypeDefinition::Interface(_) | schema::TypeDefinition::Union(_)
463        )
464    }
465
466    fn is_interface_type(&self) -> bool {
467        matches!(self, schema::TypeDefinition::Interface(_))
468    }
469
470    pub fn is_leaf_type(&self) -> bool {
471        matches!(
472            self,
473            schema::TypeDefinition::Scalar(_) | schema::TypeDefinition::Enum(_)
474        )
475    }
476
477    pub fn is_input_type(&self) -> bool {
478        matches!(
479            self,
480            schema::TypeDefinition::Scalar(_)
481                | schema::TypeDefinition::Enum(_)
482                | schema::TypeDefinition::InputObject(_)
483        )
484    }
485
486    pub fn is_composite_type(&self) -> bool {
487        matches!(
488            self,
489            schema::TypeDefinition::Object(_)
490                | schema::TypeDefinition::Interface(_)
491                | schema::TypeDefinition::Union(_)
492        )
493    }
494
495    pub fn is_object_type(&self) -> bool {
496        matches!(self, schema::TypeDefinition::Object(_o))
497    }
498
499    pub fn is_union_type(&self) -> bool {
500        matches!(self, schema::TypeDefinition::Union(_o))
501    }
502
503    pub fn is_enum_type(&self) -> bool {
504        matches!(self, schema::TypeDefinition::Enum(_o))
505    }
506
507    pub fn is_scalar_type(&self) -> bool {
508        matches!(self, schema::TypeDefinition::Scalar(_o))
509    }
510}
511
/// An AST node that may carry a name.
pub trait AstNodeWithName {
    /// The node's name, or `None` when the node is anonymous.
    fn node_name(&self) -> Option<&str>;
}
515
516impl AstNodeWithName for query::OperationDefinition {
517    fn node_name(&self) -> Option<&str> {
518        match self {
519            query::OperationDefinition::Query(q) => q.name.as_deref(),
520            query::OperationDefinition::SelectionSet(_s) => None,
521            query::OperationDefinition::Mutation(m) => m.name.as_deref(),
522            query::OperationDefinition::Subscription(s) => s.name.as_deref(),
523        }
524    }
525}
526
527impl AstNodeWithName for query::FragmentDefinition {
528    fn node_name(&self) -> Option<&str> {
529        Some(&self.name)
530    }
531}
532
533impl AstNodeWithName for query::FragmentSpread {
534    fn node_name(&self) -> Option<&str> {
535        Some(&self.fragment_name)
536    }
537}
538
539impl query::SelectionSet {
540    pub fn get_recursive_fragment_spreads(&self) -> Vec<&FragmentSpread> {
541        self.items
542            .iter()
543            .flat_map(|v| match v {
544                query::Selection::FragmentSpread(f) => vec![f],
545                query::Selection::Field(f) => f.selection_set.get_fragment_spreads(),
546                query::Selection::InlineFragment(f) => f.selection_set.get_fragment_spreads(),
547            })
548            .collect()
549    }
550
551    fn get_fragment_spreads(&self) -> Vec<&FragmentSpread> {
552        self.items
553            .iter()
554            .flat_map(|v| match v {
555                query::Selection::FragmentSpread(f) => vec![f],
556                _ => vec![],
557            })
558            .collect()
559    }
560}
561
562impl query::Selection {
563    pub fn directives(&self) -> &[Directive] {
564        match self {
565            query::Selection::Field(f) => &f.directives,
566            query::Selection::FragmentSpread(f) => &f.directives,
567            query::Selection::InlineFragment(f) => &f.directives,
568        }
569    }
570    pub fn selection_set(&self) -> Option<&SelectionSet> {
571        match self {
572            query::Selection::Field(f) => Some(&f.selection_set),
573            query::Selection::FragmentSpread(_) => None,
574            query::Selection::InlineFragment(f) => Some(&f.selection_set),
575        }
576    }
577}
578
579impl schema::Definition<'static, String> {
580    pub fn name(&self) -> Option<&str> {
581        match self {
582            schema::Definition::SchemaDefinition(_) => None,
583            schema::Definition::TypeDefinition(type_def) => Some(type_def.name()),
584            schema::Definition::TypeExtension(type_ext) => Some(type_ext.name()),
585            schema::Definition::DirectiveDefinition(directive_def) => Some(&directive_def.name),
586        }
587    }
588    pub fn fields<'a>(&'a self) -> Option<TypeDefinitionFields<'a>> {
589        match self {
590            schema::Definition::SchemaDefinition(_) => None,
591            schema::Definition::TypeDefinition(type_def) => type_def.fields(),
592            schema::Definition::TypeExtension(type_ext) => type_ext.fields(),
593            schema::Definition::DirectiveDefinition(_) => None,
594        }
595    }
596    pub fn directives(&self) -> Option<&[Directive]> {
597        match self {
598            schema::Definition::SchemaDefinition(schema_def) => Some(&schema_def.directives),
599            schema::Definition::TypeDefinition(type_def) => type_def.directives(),
600            schema::Definition::TypeExtension(type_ext) => type_ext.directives(),
601            schema::Definition::DirectiveDefinition(_) => None,
602        }
603    }
604}
605
606pub enum TypeDefinitionFields<'a> {
607    Fields(&'a [Field]),
608    InputValues(&'a [InputValue]),
609    EnumValues(&'a [EnumValue]),
610}
611
612impl TypeDefinition {
613    pub fn fields<'a>(&'a self) -> Option<TypeDefinitionFields<'a>> {
614        match self {
615            TypeDefinition::Scalar(_) => None,
616            TypeDefinition::Object(object) => Some(TypeDefinitionFields::Fields(&object.fields)),
617            TypeDefinition::Interface(interface) => {
618                Some(TypeDefinitionFields::Fields(&interface.fields))
619            }
620            TypeDefinition::Union(_) => None,
621            TypeDefinition::Enum(enum_) => Some(TypeDefinitionFields::EnumValues(&enum_.values)),
622            TypeDefinition::InputObject(input_object) => {
623                Some(TypeDefinitionFields::InputValues(&input_object.fields))
624            }
625        }
626    }
627    pub fn directives(&self) -> Option<&[Directive]> {
628        match self {
629            TypeDefinition::Scalar(_) => None,
630            TypeDefinition::Object(object) => Some(&object.directives),
631            TypeDefinition::Interface(interface) => Some(&interface.directives),
632            TypeDefinition::Union(union) => Some(&union.directives),
633            TypeDefinition::Enum(enum_) => Some(&enum_.directives),
634            TypeDefinition::InputObject(input_object) => Some(&input_object.directives),
635        }
636    }
637}
638
639impl TypeExtension<'static, String> {
640    pub fn name(&self) -> &str {
641        match self {
642            TypeExtension::Object(object) => &object.name,
643            TypeExtension::Interface(interface) => &interface.name,
644            TypeExtension::Union(union) => &union.name,
645            TypeExtension::Scalar(scalar) => &scalar.name,
646            TypeExtension::Enum(enum_) => &enum_.name,
647            TypeExtension::InputObject(input_object) => &input_object.name,
648        }
649    }
650    pub fn fields<'a>(&'a self) -> Option<TypeDefinitionFields<'a>> {
651        match self {
652            TypeExtension::Object(object) => Some(TypeDefinitionFields::Fields(&object.fields)),
653            TypeExtension::Interface(interface) => {
654                Some(TypeDefinitionFields::Fields(&interface.fields))
655            }
656            _ => None,
657        }
658    }
659    pub fn directives(&self) -> Option<&[Directive]> {
660        match self {
661            TypeExtension::Object(object) => Some(&object.directives),
662            TypeExtension::Interface(interface) => Some(&interface.directives),
663            TypeExtension::Union(union) => Some(&union.directives),
664            TypeExtension::Enum(enum_) => Some(&enum_.directives),
665            TypeExtension::InputObject(input_object) => Some(&input_object.directives),
666            TypeExtension::Scalar(scalar) => Some(&scalar.directives),
667        }
668    }
669}