async_graphql/validation/visitors/complexity.rs

1use async_graphql_parser::types::{ExecutableDocument, OperationDefinition, VariableDefinition};
2use async_graphql_value::Name;
3
4use crate::{
5    Positioned,
6    parser::types::Field,
7    registry::{MetaType, MetaTypeName},
8    validation::visitor::{VisitMode, Visitor, VisitorContext},
9};
10
11pub struct ComplexityCalculate<'ctx, 'a> {
12    pub complexity: &'a mut usize,
13    pub complexity_stack: Vec<usize>,
14    pub variable_definition: Option<&'ctx [Positioned<VariableDefinition>]>,
15}
16
17impl<'a> ComplexityCalculate<'_, 'a> {
18    pub fn new(complexity: &'a mut usize) -> Self {
19        Self {
20            complexity,
21            complexity_stack: Default::default(),
22            variable_definition: None,
23        }
24    }
25}
26
27impl<'ctx> Visitor<'ctx> for ComplexityCalculate<'ctx, '_> {
28    fn mode(&self) -> VisitMode {
29        VisitMode::Inline
30    }
31
32    fn enter_document(&mut self, _ctx: &mut VisitorContext<'ctx>, _doc: &'ctx ExecutableDocument) {
33        self.complexity_stack.push(0);
34    }
35
36    fn exit_document(&mut self, _ctx: &mut VisitorContext<'ctx>, _doc: &'ctx ExecutableDocument) {
37        *self.complexity = self.complexity_stack.pop().unwrap();
38    }
39
40    fn enter_operation_definition(
41        &mut self,
42        _ctx: &mut VisitorContext<'ctx>,
43        _name: Option<&'ctx Name>,
44        operation_definition: &'ctx Positioned<OperationDefinition>,
45    ) {
46        self.variable_definition = Some(&operation_definition.node.variable_definitions);
47    }
48
49    fn enter_field(&mut self, _ctx: &mut VisitorContext<'_>, _field: &Positioned<Field>) {
50        self.complexity_stack.push(0);
51    }
52
53    fn exit_field(&mut self, ctx: &mut VisitorContext<'ctx>, field: &'ctx Positioned<Field>) {
54        let children_complex = self.complexity_stack.pop().unwrap();
55
56        if let Some(MetaType::Object { fields, .. }) = ctx.parent_type()
57            && let Some(meta_field) = fields.get(MetaTypeName::concrete_typename(
58                field.node.name.node.as_str(),
59            ))
60            && let Some(f) = &meta_field.compute_complexity
61        {
62            match f(
63                ctx,
64                self.variable_definition.unwrap_or(&[]),
65                &field.node,
66                children_complex,
67            ) {
68                Ok(n) => {
69                    *self.complexity_stack.last_mut().unwrap() += n;
70                }
71                Err(err) => ctx.report_error(vec![field.pos], err.to_string()),
72            }
73            return;
74        }
75
76        *self.complexity_stack.last_mut().unwrap() += 1 + children_complex;
77    }
78}
79
#[cfg(test)]
#[allow(clippy::diverging_sub_expression)]
mod tests {
    use async_graphql_derive::SimpleObject;
    use futures_util::stream::BoxStream;

    use super::*;
    use crate::{
        EmptyMutation, Object, Schema, Subscription, parser::parse_query, validation::visit,
    };

    // The schema below is only inspected through its registry; no resolver is
    // ever executed, which is why every body is `todo!()`.

    struct Query;

    #[derive(SimpleObject)]
    #[graphql(internal)]
    struct MySimpleObj {
        #[graphql(complexity = 0)]
        a: i32,
        #[graphql(complexity = 0)]
        b: String,
        #[graphql(complexity = 5)]
        c: i32,
    }

    #[derive(Copy, Clone)]
    struct MyObj;

    #[Object(internal)]
    #[allow(unreachable_code)]
    impl MyObj {
        async fn a(&self) -> i32 {
            todo!()
        }

        async fn b(&self) -> i32 {
            todo!()
        }

        async fn c(&self) -> MyObj {
            todo!()
        }
    }

    #[Object(internal)]
    #[allow(unreachable_code)]
    impl Query {
        async fn value(&self) -> i32 {
            todo!()
        }

        async fn simple_obj(&self) -> MySimpleObj {
            todo!()
        }

        #[graphql(complexity = "count * child_complexity + 2")]
        #[allow(unused_variables)]
        async fn simple_objs(
            &self,
            #[graphql(default_with = "5")] count: usize,
        ) -> Vec<MySimpleObj> {
            todo!()
        }

        async fn obj(&self) -> MyObj {
            todo!()
        }

        #[graphql(complexity = "5 * child_complexity")]
        async fn obj2(&self) -> MyObj {
            todo!()
        }

        #[graphql(complexity = "count * child_complexity")]
        #[allow(unused_variables)]
        async fn objs(&self, #[graphql(default_with = "5")] count: usize) -> Vec<MyObj> {
            todo!()
        }

        #[graphql(complexity = 3)]
        async fn d(&self) -> MyObj {
            todo!()
        }
    }

    struct Subscription;

    #[Subscription(internal)]
    impl Subscription {
        async fn value(&self) -> BoxStream<'static, i32> {
            todo!()
        }

        async fn obj(&self) -> BoxStream<'static, MyObj> {
            todo!()
        }

        #[graphql(complexity = "count * child_complexity")]
        #[allow(unused_variables)]
        async fn objs(
            &self,
            #[graphql(default_with = "5")] count: usize,
        ) -> BoxStream<'static, Vec<MyObj>> {
            todo!()
        }

        #[graphql(complexity = 3)]
        async fn d(&self) -> BoxStream<'static, MyObj> {
            todo!()
        }
    }

    /// Scores `query` against the test schema and asserts the result equals
    /// `expect_complexity`.
    #[track_caller]
    fn check_complexity(query: &str, expect_complexity: usize) {
        let doc = parse_query(query).unwrap();
        let registry =
            Schema::<Query, EmptyMutation, Subscription>::create_registry(Default::default());
        let mut ctx = VisitorContext::new(&registry, &doc, None, None);
        let mut actual = 0;
        visit(&mut ComplexityCalculate::new(&mut actual), &mut ctx, &doc);
        assert_eq!(actual, expect_complexity);
    }

    /// Runs [`check_complexity`] on every `(query, expected)` pair, in order.
    #[track_caller]
    fn check_all(cases: &[(&str, usize)]) {
        for &(query, expected) in cases {
            check_complexity(query, expected);
        }
    }

    #[test]
    fn simple_object() {
        check_all(&[
            (
                r#"{
                simpleObj { a b }
            }"#,
                1,
            ),
            (
                r#"{
                simpleObj { a b c }
            }"#,
                6,
            ),
            (
                r#"{
                simpleObjs(count: 7) { a b c }
            }"#,
                7 * 5 + 2,
            ),
        ]);
    }

    #[test]
    fn complex_object() {
        check_all(&[
            (
                r#"
        {
            value #1
        }"#,
                1,
            ),
            (
                r#"
        {
            value #1
            d #3
        }"#,
                4,
            ),
            (
                r#"
        {
            value obj { #2
                a b #2
            }
        }"#,
                4,
            ),
            (
                r#"
        {
            value obj { #2
                a b obj { #3
                    a b obj { #3
                        a #1
                    }
                }
            }
        }"#,
                9,
            ),
            (
                r#"
        fragment A on MyObj {
            a b ... A2 #2
        }

        fragment A2 on MyObj {
            obj { # 1
                a # 1
            }
        }

        query {
            obj { # 1
                ... A
            }
        }"#,
                5,
            ),
            (
                r#"
        {
            obj { # 1
                ... on MyObj {
                    a b #2
                    ... on MyObj {
                        obj { #1
                            a #1
                        }
                    }
                }
            }
        }"#,
                5,
            ),
            (
                r#"
        {
            objs(count: 10) {
                a b
            }
        }"#,
                20,
            ),
            (
                r#"
        {
            objs {
                a b
            }
        }"#,
                10,
            ),
            (
                r#"
        fragment A on MyObj {
            a b
        }

        query {
            objs(count: 10) {
                ... A
            }
        }"#,
                20,
            ),
            // `obj2` is a Query field scored as `5 * child_complexity`; it was
            // previously (mis)filed under the subscription test.
            (
                r#"
            query {
                obj2 { a b }
            }"#,
                10,
            ),
        ]);
    }

    #[test]
    fn complex_subscription() {
        check_all(&[
            (
                r#"
        subscription {
            value #1
        }"#,
                1,
            ),
            (
                r#"
        subscription {
            value #1
            d #3
        }"#,
                4,
            ),
            (
                r#"
        subscription {
            value obj { #2
                a b #2
            }
        }"#,
                4,
            ),
            (
                r#"
        subscription {
            value obj { #2
                a b obj { #3
                    a b obj { #3
                        a #1
                    }
                }
            }
        }"#,
                9,
            ),
            (
                r#"
        fragment A on MyObj {
            a b ... A2 #2
        }

        fragment A2 on MyObj {
            obj { # 1
                a # 1
            }
        }

        subscription query {
            obj { # 1
                ... A
            }
        }"#,
                5,
            ),
            (
                r#"
        subscription {
            obj { # 1
                ... on MyObj {
                    a b #2
                    ... on MyObj {
                        obj { #1
                            a #1
                        }
                    }
                }
            }
        }"#,
                5,
            ),
            (
                r#"
        subscription {
            objs(count: 10) {
                a b
            }
        }"#,
                20,
            ),
            (
                r#"
        subscription {
            objs {
                a b
            }
        }"#,
                10,
            ),
            (
                r#"
        fragment A on MyObj {
            a b
        }

        subscription query {
            objs(count: 10) {
                ... A
            }
        }"#,
                20,
            ),
        ]);
    }
}