async_graphql/validation/visitors/complexity.rs

1use async_graphql_parser::types::{ExecutableDocument, OperationDefinition, VariableDefinition};
2use async_graphql_value::Name;
3
4use crate::{
5    Positioned,
6    parser::types::Field,
7    registry::{MetaType, MetaTypeName},
8    validation::visitor::{VisitMode, Visitor, VisitorContext},
9};
10
11pub struct ComplexityCalculate<'ctx, 'a> {
12    pub complexity: &'a mut usize,
13    pub complexity_stack: Vec<usize>,
14    pub variable_definition: Option<&'ctx [Positioned<VariableDefinition>]>,
15}
16
17impl<'a> ComplexityCalculate<'_, 'a> {
18    pub fn new(complexity: &'a mut usize) -> Self {
19        Self {
20            complexity,
21            complexity_stack: Default::default(),
22            variable_definition: None,
23        }
24    }
25}
26
27impl<'ctx> Visitor<'ctx> for ComplexityCalculate<'ctx, '_> {
28    fn mode(&self) -> VisitMode {
29        VisitMode::Inline
30    }
31
32    fn enter_document(&mut self, _ctx: &mut VisitorContext<'ctx>, _doc: &'ctx ExecutableDocument) {
33        self.complexity_stack.push(0);
34    }
35
36    fn exit_document(&mut self, _ctx: &mut VisitorContext<'ctx>, _doc: &'ctx ExecutableDocument) {
37        *self.complexity = self.complexity_stack.pop().unwrap();
38    }
39
40    fn enter_operation_definition(
41        &mut self,
42        _ctx: &mut VisitorContext<'ctx>,
43        _name: Option<&'ctx Name>,
44        operation_definition: &'ctx Positioned<OperationDefinition>,
45    ) {
46        self.variable_definition = Some(&operation_definition.node.variable_definitions);
47    }
48
49    fn enter_field(&mut self, _ctx: &mut VisitorContext<'_>, _field: &Positioned<Field>) {
50        self.complexity_stack.push(0);
51    }
52
53    fn exit_field(&mut self, ctx: &mut VisitorContext<'ctx>, field: &'ctx Positioned<Field>) {
54        let children_complex = self.complexity_stack.pop().unwrap();
55
56        if let Some(MetaType::Object { fields, .. }) = ctx.parent_type() {
57            if let Some(meta_field) = fields.get(MetaTypeName::concrete_typename(
58                field.node.name.node.as_str(),
59            )) {
60                if let Some(f) = &meta_field.compute_complexity {
61                    match f(
62                        ctx,
63                        self.variable_definition.unwrap_or(&[]),
64                        &field.node,
65                        children_complex,
66                    ) {
67                        Ok(n) => {
68                            *self.complexity_stack.last_mut().unwrap() += n;
69                        }
70                        Err(err) => ctx.report_error(vec![field.pos], err.to_string()),
71                    }
72                    return;
73                }
74            }
75        }
76
77        *self.complexity_stack.last_mut().unwrap() += 1 + children_complex;
78    }
79}
80
// Every resolver below is a `todo!()` stub: these tests only run the static
// complexity visitor over parsed queries and never execute a field, so the
// diverging bodies are never reached.
#[cfg(test)]
#[allow(clippy::diverging_sub_expression)]
mod tests {
    use async_graphql_derive::SimpleObject;
    use futures_util::stream::BoxStream;

    use super::*;
    use crate::{
        EmptyMutation, Object, Schema, Subscription, parser::parse_query, validation::visit,
    };

    struct Query;

    #[derive(SimpleObject)]
    #[graphql(internal)]
    struct MySimpleObj {
        #[graphql(complexity = 0)]
        a: i32,
        #[graphql(complexity = 0)]
        b: String,
        #[graphql(complexity = 5)]
        c: i32,
    }

    #[derive(Copy, Clone)]
    struct MyObj;

    #[Object(internal)]
    #[allow(unreachable_code)]
    impl MyObj {
        async fn a(&self) -> i32 {
            todo!()
        }

        async fn b(&self) -> i32 {
            todo!()
        }

        async fn c(&self) -> MyObj {
            todo!()
        }
    }

    #[Object(internal)]
    #[allow(unreachable_code)]
    impl Query {
        async fn value(&self) -> i32 {
            todo!()
        }

        async fn simple_obj(&self) -> MySimpleObj {
            todo!()
        }

        #[graphql(complexity = "count * child_complexity + 2")]
        #[allow(unused_variables)]
        async fn simple_objs(
            &self,
            #[graphql(default_with = "5")] count: usize,
        ) -> Vec<MySimpleObj> {
            todo!()
        }

        async fn obj(&self) -> MyObj {
            todo!()
        }

        #[graphql(complexity = "5 * child_complexity")]
        async fn obj2(&self) -> MyObj {
            todo!()
        }

        #[graphql(complexity = "count * child_complexity")]
        #[allow(unused_variables)]
        async fn objs(&self, #[graphql(default_with = "5")] count: usize) -> Vec<MyObj> {
            todo!()
        }

        #[graphql(complexity = 3)]
        async fn d(&self) -> MyObj {
            todo!()
        }
    }

    struct Subscription;

    #[Subscription(internal)]
    impl Subscription {
        async fn value(&self) -> BoxStream<'static, i32> {
            todo!()
        }

        async fn obj(&self) -> BoxStream<'static, MyObj> {
            todo!()
        }

        #[graphql(complexity = "count * child_complexity")]
        #[allow(unused_variables)]
        async fn objs(
            &self,
            #[graphql(default_with = "5")] count: usize,
        ) -> BoxStream<'static, Vec<MyObj>> {
            todo!()
        }

        #[graphql(complexity = 3)]
        async fn d(&self) -> BoxStream<'static, MyObj> {
            todo!()
        }
    }

    /// Parses `query`, runs `ComplexityCalculate` over it against the registry of
    /// the test schema, and asserts the computed total equals `expect_complexity`.
    #[track_caller]
    fn check_complexity(query: &str, expect_complexity: usize) {
        let registry =
            Schema::<Query, EmptyMutation, Subscription>::create_registry(Default::default());
        let doc = parse_query(query).unwrap();
        let mut ctx = VisitorContext::new(&registry, &doc, None);
        let mut complexity = 0;
        let mut complexity_calculate = ComplexityCalculate::new(&mut complexity);
        visit(&mut complexity_calculate, &mut ctx, &doc);
        assert_eq!(complexity, expect_complexity);
    }

    #[test]
    fn simple_object() {
        check_complexity(
            r#"{
                simpleObj { a b }
            }"#,
            1,
        );

        check_complexity(
            r#"{
                simpleObj { a b c }
            }"#,
            6,
        );

        check_complexity(
            r#"{
                simpleObjs(count: 7) { a b c }
            }"#,
            7 * 5 + 2,
        );
    }

    #[test]
    fn complex_object() {
        check_complexity(
            r#"
        {
            value #1
        }"#,
            1,
        );

        check_complexity(
            r#"
        {
            value #1
            d #3
        }"#,
            4,
        );

        check_complexity(
            r#"
        {
            value obj { #2
                a b #2
            }
        }"#,
            4,
        );

        check_complexity(
            r#"
        {
            value obj { #2
                a b obj { #3
                    a b obj { #3
                        a #1
                    }
                }
            }
        }"#,
            9,
        );

        check_complexity(
            r#"
        fragment A on MyObj {
            a b ... A2 #2
        }

        fragment A2 on MyObj {
            obj { # 1
                a # 1
            }
        }

        query {
            obj { # 1
                ... A
            }
        }"#,
            5,
        );

        check_complexity(
            r#"
        {
            obj { # 1
                ... on MyObj {
                    a b #2
                    ... on MyObj {
                        obj { #1
                            a #1
                        }
                    }
                }
            }
        }"#,
            5,
        );

        check_complexity(
            r#"
        {
            objs(count: 10) {
                a b
            }
        }"#,
            20,
        );

        check_complexity(
            r#"
        {
            objs {
                a b
            }
        }"#,
            10,
        );

        check_complexity(
            r#"
        fragment A on MyObj {
            a b
        }

        query {
            objs(count: 10) {
                ... A
            }
        }"#,
            20,
        );

        // `obj2` multiplies its child complexity (a + b = 2) by 5.
        check_complexity(
            r#"
            query {
                obj2 { a b }
            }"#,
            10,
        );
    }

    #[test]
    fn complex_subscription() {
        check_complexity(
            r#"
        subscription {
            value #1
        }"#,
            1,
        );

        check_complexity(
            r#"
        subscription {
            value #1
            d #3
        }"#,
            4,
        );

        check_complexity(
            r#"
        subscription {
            value obj { #2
                a b #2
            }
        }"#,
            4,
        );

        check_complexity(
            r#"
        subscription {
            value obj { #2
                a b obj { #3
                    a b obj { #3
                        a #1
                    }
                }
            }
        }"#,
            9,
        );

        check_complexity(
            r#"
        fragment A on MyObj {
            a b ... A2 #2
        }

        fragment A2 on MyObj {
            obj { # 1
                a # 1
            }
        }

        subscription query {
            obj { # 1
                ... A
            }
        }"#,
            5,
        );

        check_complexity(
            r#"
        subscription {
            obj { # 1
                ... on MyObj {
                    a b #2
                    ... on MyObj {
                        obj { #1
                            a #1
                        }
                    }
                }
            }
        }"#,
            5,
        );

        check_complexity(
            r#"
        subscription {
            objs(count: 10) {
                a b
            }
        }"#,
            20,
        );

        check_complexity(
            r#"
        subscription {
            objs {
                a b
            }
        }"#,
            10,
        );

        check_complexity(
            r#"
        fragment A on MyObj {
            a b
        }

        subscription query {
            objs(count: 10) {
                ... A
            }
        }"#,
            20,
        );
    }
}
463        );
464    }
465}