async_graphql/validation/rules/variables_in_allowed_position.rs

1use std::collections::{HashMap, HashSet};
2
3use async_graphql_value::Value;
4
5use crate::{
6    Name, Pos, Positioned,
7    parser::types::{
8        ExecutableDocument, FragmentDefinition, FragmentSpread, OperationDefinition,
9        VariableDefinition,
10    },
11    registry::MetaTypeName,
12    validation::{
13        utils::Scope,
14        visitor::{Visitor, VisitorContext},
15    },
16};
17
18#[derive(Default)]
19pub struct VariableInAllowedPosition<'a> {
20    spreads: HashMap<Scope<'a>, HashSet<&'a str>>,
21    variable_usages: HashMap<Scope<'a>, Vec<(&'a str, Pos, MetaTypeName<'a>)>>,
22    variable_defs: HashMap<Scope<'a>, Vec<&'a Positioned<VariableDefinition>>>,
23    current_scope: Option<Scope<'a>>,
24}
25
26impl<'a> VariableInAllowedPosition<'a> {
27    fn collect_incorrect_usages(
28        &self,
29        from: &Scope<'a>,
30        var_defs: &[&'a Positioned<VariableDefinition>],
31        ctx: &mut VisitorContext<'a>,
32        visited: &mut HashSet<Scope<'a>>,
33    ) {
34        if visited.contains(from) {
35            return;
36        }
37
38        visited.insert(*from);
39
40        if let Some(usages) = self.variable_usages.get(from) {
41            for (var_name, usage_pos, var_type) in usages {
42                if let Some(def) = var_defs.iter().find(|def| def.node.name.node == *var_name) {
43                    let expected_type =
44                        if def.node.var_type.node.nullable && def.node.default_value.is_some() {
45                            // A nullable type with a default value functions as a non-nullable
46                            format!("{}!", def.node.var_type.node)
47                        } else {
48                            def.node.var_type.node.to_string()
49                        };
50
51                    if !var_type.is_subtype(&MetaTypeName::create(&expected_type)) {
52                        ctx.report_error(
53                            vec![def.pos, *usage_pos],
54                            format!(
55                                "Variable \"{}\" of type \"{}\" used in position expecting type \"{}\"",
56                                var_name, var_type, expected_type
57                            ),
58                        );
59                    }
60                }
61            }
62        }
63
64        if let Some(spreads) = self.spreads.get(from) {
65            for spread in spreads {
66                self.collect_incorrect_usages(&Scope::Fragment(spread), var_defs, ctx, visited);
67            }
68        }
69    }
70}
71
72impl<'a> Visitor<'a> for VariableInAllowedPosition<'a> {
73    fn exit_document(&mut self, ctx: &mut VisitorContext<'a>, _doc: &'a ExecutableDocument) {
74        for (op_scope, var_defs) in &self.variable_defs {
75            self.collect_incorrect_usages(op_scope, var_defs, ctx, &mut HashSet::new());
76        }
77    }
78
79    fn enter_operation_definition(
80        &mut self,
81        _ctx: &mut VisitorContext<'a>,
82        name: Option<&'a Name>,
83        _operation_definition: &'a Positioned<OperationDefinition>,
84    ) {
85        self.current_scope = Some(Scope::Operation(name.map(Name::as_str)));
86    }
87
88    fn enter_fragment_definition(
89        &mut self,
90        _ctx: &mut VisitorContext<'a>,
91        name: &'a Name,
92        _fragment_definition: &'a Positioned<FragmentDefinition>,
93    ) {
94        self.current_scope = Some(Scope::Fragment(name));
95    }
96
97    fn enter_variable_definition(
98        &mut self,
99        _ctx: &mut VisitorContext<'a>,
100        variable_definition: &'a Positioned<VariableDefinition>,
101    ) {
102        if let Some(ref scope) = self.current_scope {
103            self.variable_defs
104                .entry(*scope)
105                .or_default()
106                .push(variable_definition);
107        }
108    }
109
110    fn enter_fragment_spread(
111        &mut self,
112        _ctx: &mut VisitorContext<'a>,
113        fragment_spread: &'a Positioned<FragmentSpread>,
114    ) {
115        if let Some(ref scope) = self.current_scope {
116            self.spreads
117                .entry(*scope)
118                .or_default()
119                .insert(&fragment_spread.node.fragment_name.node);
120        }
121    }
122
123    fn enter_input_value(
124        &mut self,
125        _ctx: &mut VisitorContext<'a>,
126        pos: Pos,
127        expected_type: &Option<MetaTypeName<'a>>,
128        value: &'a Value,
129    ) {
130        if let Value::Variable(name) = value
131            && let Some(expected_type) = expected_type
132            && let Some(scope) = &self.current_scope
133        {
134            self.variable_usages
135                .entry(*scope)
136                .or_default()
137                .push((name, pos, *expected_type));
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    // Tests for the `VariableInAllowedPosition` rule. The schema fields used in
    // the documents below (`complicatedArgs`, `booleanArgField`, `dog`, ...) are
    // not defined in this file; presumably they come from the shared validation
    // test schema used by `expect_passes_rule!` / `expect_fails_rule!`.
    use super::*;

    /// Creates a fresh, empty rule instance; passed to the test macros below.
    pub fn factory<'a>() -> VariableInAllowedPosition<'a> {
        VariableInAllowedPosition::default()
    }

    // ---- Usages the rule must accept ----

    #[test]
    fn boolean_into_boolean() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($booleanArg: Boolean)
          {
            complicatedArgs {
              booleanArgField(booleanArg: $booleanArg)
            }
          }
        "#,
        );
    }

    /// A usage inside a fragment is checked against the operation that spreads
    /// it, whether the fragment is defined before or after that operation.
    #[test]
    fn boolean_into_boolean_within_fragment() {
        expect_passes_rule!(
            factory,
            r#"
          fragment booleanArgFrag on ComplicatedArgs {
            booleanArgField(booleanArg: $booleanArg)
          }
          query Query($booleanArg: Boolean)
          {
            complicatedArgs {
              ...booleanArgFrag
            }
          }
        "#,
        );

        expect_passes_rule!(
            factory,
            r#"
          query Query($booleanArg: Boolean)
          {
            complicatedArgs {
              ...booleanArgFrag
            }
          }
          fragment booleanArgFrag on ComplicatedArgs {
            booleanArgField(booleanArg: $booleanArg)
          }
        "#,
        );
    }

    #[test]
    fn non_null_boolean_into_boolean() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($nonNullBooleanArg: Boolean!)
          {
            complicatedArgs {
              booleanArgField(booleanArg: $nonNullBooleanArg)
            }
          }
        "#,
        );
    }

    #[test]
    fn non_null_boolean_into_boolean_within_fragment() {
        expect_passes_rule!(
            factory,
            r#"
          fragment booleanArgFrag on ComplicatedArgs {
            booleanArgField(booleanArg: $nonNullBooleanArg)
          }
          query Query($nonNullBooleanArg: Boolean!)
          {
            complicatedArgs {
              ...booleanArgFrag
            }
          }
        "#,
        );
    }

    /// A nullable variable with a default value may be used where a non-null
    /// type is expected.
    #[test]
    fn int_into_non_null_int_with_default() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($intArg: Int = 1)
          {
            complicatedArgs {
              nonNullIntArgField(nonNullIntArg: $intArg)
            }
          }
        "#,
        );
    }

    #[test]
    fn string_list_into_string_list() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($stringListVar: [String])
          {
            complicatedArgs {
              stringListArgField(stringListArg: $stringListVar)
            }
          }
        "#,
        );
    }

    #[test]
    fn non_null_string_list_into_string_list() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($stringListVar: [String!])
          {
            complicatedArgs {
              stringListArgField(stringListArg: $stringListVar)
            }
          }
        "#,
        );
    }

    /// A variable used as an element of a list literal (item position).
    #[test]
    fn string_into_string_list_in_item_position() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($stringVar: String)
          {
            complicatedArgs {
              stringListArgField(stringListArg: [$stringVar])
            }
          }
        "#,
        );
    }

    #[test]
    fn non_null_string_into_string_list_in_item_position() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($stringVar: String!)
          {
            complicatedArgs {
              stringListArgField(stringListArg: [$stringVar])
            }
          }
        "#,
        );
    }

    #[test]
    fn complex_input_into_complex_input() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($complexVar: ComplexInput)
          {
            complicatedArgs {
              complexArgField(complexArg: $complexVar)
            }
          }
        "#,
        );
    }

    /// A variable used as a field value inside an input-object literal.
    #[test]
    fn complex_input_into_complex_input_in_field_position() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($boolVar: Boolean = false)
          {
            complicatedArgs {
              complexArgField(complexArg: {requiredArg: $boolVar})
            }
          }
        "#,
        );
    }

    /// Directive arguments are checked like field arguments.
    #[test]
    fn non_null_boolean_into_non_null_boolean_in_directive() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($boolVar: Boolean!)
          {
            dog @include(if: $boolVar)
          }
        "#,
        );
    }

    #[test]
    fn boolean_in_non_null_in_directive_with_default() {
        expect_passes_rule!(
            factory,
            r#"
          query Query($boolVar: Boolean = false)
          {
            dog @include(if: $boolVar)
          }
        "#,
        );
    }

    // ---- Usages the rule must reject ----

    #[test]
    fn int_into_non_null_int() {
        expect_fails_rule!(
            factory,
            r#"
          query Query($intArg: Int) {
            complicatedArgs {
              nonNullIntArgField(nonNullIntArg: $intArg)
            }
          }
        "#,
        );
    }

    #[test]
    fn int_into_non_null_int_within_fragment() {
        expect_fails_rule!(
            factory,
            r#"
          fragment nonNullIntArgFieldFrag on ComplicatedArgs {
            nonNullIntArgField(nonNullIntArg: $intArg)
          }
          query Query($intArg: Int) {
            complicatedArgs {
              ...nonNullIntArgFieldFrag
            }
          }
        "#,
        );
    }

    /// Usages are followed through fragments spread by other fragments.
    #[test]
    fn int_into_non_null_int_within_nested_fragment() {
        expect_fails_rule!(
            factory,
            r#"
          fragment outerFrag on ComplicatedArgs {
            ...nonNullIntArgFieldFrag
          }
          fragment nonNullIntArgFieldFrag on ComplicatedArgs {
            nonNullIntArgField(nonNullIntArg: $intArg)
          }
          query Query($intArg: Int) {
            complicatedArgs {
              ...outerFrag
            }
          }
        "#,
        );
    }

    #[test]
    fn string_over_boolean() {
        expect_fails_rule!(
            factory,
            r#"
          query Query($stringVar: String) {
            complicatedArgs {
              booleanArgField(booleanArg: $stringVar)
            }
          }
        "#,
        );
    }

    #[test]
    fn string_into_string_list() {
        expect_fails_rule!(
            factory,
            r#"
          query Query($stringVar: String) {
            complicatedArgs {
              stringListArgField(stringListArg: $stringVar)
            }
          }
        "#,
        );
    }

    #[test]
    fn boolean_into_non_null_boolean_in_directive() {
        expect_fails_rule!(
            factory,
            r#"
          query Query($boolVar: Boolean) {
            dog @include(if: $boolVar)
          }
        "#,
        );
    }

    #[test]
    fn string_into_non_null_boolean_in_directive() {
        expect_fails_rule!(
            factory,
            r#"
          query Query($stringVar: String) {
            dog @include(if: $stringVar)
          }
        "#,
        );
    }
}