Skip to main content

fraiseql_core/graphql/
fragment_resolver.rs

//! Fragment resolution for GraphQL queries.
//!
//! Handles:
//! - Fragment spread resolution (`...FragmentName`)
//! - Inline fragment handling (`... on TypeName { fields }`)
//! - Selection set merging with deduplication
7
8use std::collections::{HashMap, HashSet};
9
10use thiserror::Error;
11
12use crate::graphql::types::{FieldSelection, FragmentDefinition};
13
14/// Errors that can occur during fragment resolution.
15#[derive(Debug, Error)]
16#[non_exhaustive]
17pub enum FragmentError {
18    /// Indicates that the requested fragment was not found.
19    #[error("Fragment not found: {0}")]
20    FragmentNotFound(String),
21
22    /// Indicates that fragment depth limit was exceeded.
23    #[error("Fragment depth exceeded (max: {0})")]
24    FragmentDepthExceeded(u32),
25
26    /// Indicates a circular reference was detected in fragments.
27    #[error("Circular fragment reference detected")]
28    CircularFragmentReference,
29}
30
31/// Resolves GraphQL fragment spreads in query selection sets.
32///
33/// Handles fragment spreads (`...FragmentName`) by expanding them to their actual field selections.
34/// Also merges multiple fragment definitions and handles field deduplication.
35///
36/// # Example
37///
38/// ```
39/// use fraiseql_core::graphql::{FragmentResolver, FragmentDefinition, FieldSelection};
40///
41/// let fragment = FragmentDefinition {
42///     name: "UserFields".to_string(),
43///     type_condition: "User".to_string(),
44///     selections: vec![
45///         FieldSelection {
46///             name: "id".to_string(),
47///             alias: None,
48///             arguments: vec![],
49///             nested_fields: vec![],
50///             directives: vec![],
51///         },
52///     ],
53///     fragment_spreads: vec![],
54/// };
55///
56/// let resolver = FragmentResolver::new(&[fragment]);
57/// ```
58pub struct FragmentResolver {
59    fragments: HashMap<String, FragmentDefinition>,
60    max_depth: u32,
61}
62
63impl FragmentResolver {
64    /// Create a new fragment resolver from a list of fragment definitions.
65    #[must_use]
66    pub fn new(fragments: &[FragmentDefinition]) -> Self {
67        let map = fragments.iter().map(|f| (f.name.clone(), f.clone())).collect();
68        Self {
69            fragments: map,
70            max_depth: 10,
71        }
72    }
73
74    /// Create a resolver with a custom max depth.
75    #[must_use]
76    pub const fn with_max_depth(mut self, max_depth: u32) -> Self {
77        self.max_depth = max_depth;
78        self
79    }
80
81    /// Resolve all fragment spreads in selections.
82    ///
83    /// # Errors
84    /// Returns error if:
85    /// - Fragment is not found
86    /// - Fragment depth exceeds maximum
87    /// - Circular references are detected
88    pub fn resolve_spreads(
89        &self,
90        selections: &[FieldSelection],
91    ) -> Result<Vec<FieldSelection>, FragmentError> {
92        self.resolve_selections(selections, 0, &mut HashSet::new())
93    }
94
95    /// Recursively resolve selections at a given depth.
96    fn resolve_selections(
97        &self,
98        selections: &[FieldSelection],
99        depth: u32,
100        visited_fragments: &mut HashSet<String>,
101    ) -> Result<Vec<FieldSelection>, FragmentError> {
102        if depth > self.max_depth {
103            return Err(FragmentError::FragmentDepthExceeded(self.max_depth));
104        }
105
106        let mut result = Vec::new();
107
108        for selection in selections {
109            // Check if this is a fragment spread (starts with "...")
110            if let Some(fragment_name) = selection.name.strip_prefix("...") {
111                // Skip inline fragments (they have " on " in the name)
112                if fragment_name.starts_with("on ") {
113                    // Inline fragment — counts as a nesting level (depth + 1) so that
114                    // deeply-nested inline fragments cannot bypass the depth limit.
115                    let mut field = selection.clone();
116                    if !field.nested_fields.is_empty() {
117                        field.nested_fields = self.resolve_selections(
118                            &field.nested_fields,
119                            depth + 1,
120                            visited_fragments,
121                        )?;
122                    }
123                    result.push(field);
124                    continue;
125                }
126
127                // Named fragment spread
128                let fragment_name = fragment_name.to_string();
129
130                // Detect circular references
131                if visited_fragments.contains(&fragment_name) {
132                    return Err(FragmentError::CircularFragmentReference);
133                }
134
135                // Get fragment definition
136                let fragment = self
137                    .fragments
138                    .get(&fragment_name)
139                    .ok_or_else(|| FragmentError::FragmentNotFound(fragment_name.clone()))?;
140
141                // Mark as visited
142                visited_fragments.insert(fragment_name.clone());
143
144                // Recursively resolve the fragment's selections
145                let resolved =
146                    self.resolve_selections(&fragment.selections, depth + 1, visited_fragments)?;
147                result.extend(resolved);
148
149                // Unmark for other paths
150                visited_fragments.remove(&fragment_name);
151            } else {
152                // Regular field: nested fields are one level deeper.
153                let mut field = selection.clone();
154                if !field.nested_fields.is_empty() {
155                    field.nested_fields = self.resolve_selections(
156                        &field.nested_fields,
157                        depth + 1,
158                        visited_fragments,
159                    )?;
160                }
161                result.push(field);
162            }
163        }
164
165        Ok(result)
166    }
167
168    /// Handle inline fragments with type conditions.
169    ///
170    /// Evaluates whether an inline fragment applies based on type conditions.
171    /// Returns the selections if the type condition matches, or an empty vector if it doesn't.
172    #[must_use]
173    pub fn evaluate_inline_fragment(
174        selections: &[FieldSelection],
175        type_condition: Option<&str>,
176        actual_type: &str,
177    ) -> Vec<FieldSelection> {
178        // If no type condition, inline fragment applies to all types
179        if type_condition.is_none() {
180            return selections.to_vec();
181        }
182
183        // If type condition matches actual type, include the fields
184        if type_condition == Some(actual_type) {
185            selections.to_vec()
186        } else {
187            // Type condition doesn't match - skip these fields
188            vec![]
189        }
190    }
191
192    /// Merge field selections from multiple sources (e.g., fragment spreads).
193    ///
194    /// Handles:
195    /// - Combining fields from multiple fragments
196    /// - Deduplicating fields by name/alias
197    /// - Merging nested selections
198    #[must_use]
199    pub fn merge_selections(
200        base: &[FieldSelection],
201        additional: Vec<FieldSelection>,
202    ) -> Vec<FieldSelection> {
203        // Build map of existing fields by response key (alias or name)
204        let mut by_key: HashMap<String, FieldSelection> =
205            base.iter().map(|f| (f.response_key().to_string(), f.clone())).collect();
206
207        // Merge additional fields
208        for field in additional {
209            let key = field.response_key().to_string();
210            if let Some(existing) = by_key.get_mut(&key) {
211                // Field already exists - merge nested selections
212                if !field.nested_fields.is_empty() {
213                    existing.nested_fields.extend(field.nested_fields);
214                    // Deduplicate nested fields
215                    existing.nested_fields = Self::deduplicate_fields(&existing.nested_fields);
216                }
217            } else {
218                // New field - add it
219                by_key.insert(key, field);
220            }
221        }
222
223        by_key.into_values().collect()
224    }
225
226    /// Deduplicate fields in a selection set by response key.
227    fn deduplicate_fields(fields: &[FieldSelection]) -> Vec<FieldSelection> {
228        let mut seen = HashSet::new();
229        fields
230            .iter()
231            .filter(|f| seen.insert(f.response_key().to_string()))
232            .cloned()
233            .collect()
234    }
235}
236
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use super::*;

    /// Build a plain field selection with the given nested selections.
    fn make_field(name: &str, nested: Vec<FieldSelection>) -> FieldSelection {
        FieldSelection {
            name:          name.to_string(),
            alias:         None,
            arguments:     vec![],
            nested_fields: nested,
            directives:    vec![],
        }
    }

    /// Build a named fragment spread selection (`...fragment_name`).
    fn make_spread(fragment_name: &str) -> FieldSelection {
        make_field(&format!("...{fragment_name}"), vec![])
    }

    /// Build a fragment on `User` with no recorded spreads.
    fn make_fragment(name: &str, selections: Vec<FieldSelection>) -> FragmentDefinition {
        FragmentDefinition {
            name: name.to_string(),
            type_condition: "User".to_string(),
            selections,
            fragment_spreads: vec![],
        }
    }

    /// The `UserFields { id name }` fragment shared by several tests.
    fn user_fields() -> FragmentDefinition {
        make_fragment("UserFields", vec![make_field("id", vec![]), make_field("name", vec![])])
    }

    /// Field names of a selection set, in order.
    fn names(fields: &[FieldSelection]) -> Vec<&str> {
        fields.iter().map(|f| f.name.as_str()).collect()
    }

    #[test]
    fn test_simple_fragment_spread_resolution() {
        let resolver = FragmentResolver::new(&[user_fields()]);
        let result = resolver.resolve_spreads(&[make_spread("UserFields")]).unwrap();

        assert_eq!(names(&result), ["id", "name"]);
    }

    #[test]
    fn test_fragment_not_found() {
        let resolver = FragmentResolver::new(&[]);
        let result = resolver.resolve_spreads(&[make_spread("NonexistentFragment")]);

        assert!(matches!(
            result,
            Err(FragmentError::FragmentNotFound(ref missing)) if missing == "NonexistentFragment"
        ));
    }

    #[test]
    fn test_nested_fragment_spreads() {
        // FragmentB spreads FragmentA; the query spreads FragmentB.
        let fragment_a = make_fragment("FragmentA", vec![make_field("id", vec![])]);
        let fragment_b =
            make_fragment("FragmentB", vec![make_spread("FragmentA"), make_field("name", vec![])]);

        let resolver = FragmentResolver::new(&[fragment_a, fragment_b]);
        let result = resolver.resolve_spreads(&[make_spread("FragmentB")]).unwrap();

        assert_eq!(names(&result), ["id", "name"]);
    }

    #[test]
    fn test_repeated_sibling_spread_is_not_circular() {
        let resolver = FragmentResolver::new(&[user_fields()]);
        let selections = [make_spread("UserFields"), make_spread("UserFields")];
        let result = resolver.resolve_spreads(&selections).unwrap();

        assert_eq!(names(&result), ["id", "name", "id", "name"]);
    }

    #[test]
    fn test_spread_inside_inline_fragment_is_resolved() {
        let resolver = FragmentResolver::new(&[user_fields()]);
        let selections = [make_field("...on User", vec![make_spread("UserFields")])];
        let result = resolver.resolve_spreads(&selections).unwrap();

        assert_eq!(names(&result), ["...on User"]);
        assert_eq!(names(&result[0].nested_fields), ["id", "name"]);
    }

    #[test]
    fn test_spread_inside_field_is_resolved() {
        let resolver = FragmentResolver::new(&[user_fields()]);
        let selections = [make_field("user", vec![make_spread("UserFields")])];
        let result = resolver.resolve_spreads(&selections).unwrap();

        assert_eq!(names(&result), ["user"]);
        assert_eq!(names(&result[0].nested_fields), ["id", "name"]);
    }

    #[test]
    fn test_inline_fragment_matching_type() {
        let selections = vec![make_field("id", vec![]), make_field("name", vec![])];

        let result = FragmentResolver::evaluate_inline_fragment(&selections, Some("User"), "User");

        assert_eq!(names(&result), ["id", "name"]);
    }

    #[test]
    fn test_inline_fragment_non_matching_type() {
        let selections = vec![make_field("id", vec![]), make_field("name", vec![])];

        let result = FragmentResolver::evaluate_inline_fragment(&selections, Some("User"), "Post");

        assert!(result.is_empty());
    }

    #[test]
    fn test_inline_fragment_without_type_condition() {
        let selections = vec![make_field("id", vec![]), make_field("name", vec![])];

        let result = FragmentResolver::evaluate_inline_fragment(&selections, None, "User");

        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_merge_non_conflicting_fields() {
        let base = vec![make_field("id", vec![]), make_field("name", vec![])];
        let additional = vec![make_field("email", vec![])];

        let merged = FragmentResolver::merge_selections(&base, additional);

        assert_eq!(merged.len(), 3);
        let merged_names = names(&merged);
        for expected in ["id", "name", "email"] {
            assert!(merged_names.contains(&expected));
        }
    }

    #[test]
    fn test_merge_conflicting_fields_with_alias() {
        let aliased_user = |nested: Vec<FieldSelection>| FieldSelection {
            name:          "user".to_string(),
            alias:         Some("primaryUser".to_string()),
            arguments:     vec![],
            nested_fields: nested,
            directives:    vec![],
        };

        let base = vec![aliased_user(vec![make_field("id", vec![])])];
        let additional = vec![aliased_user(vec![make_field("name", vec![])])];

        let merged = FragmentResolver::merge_selections(&base, additional);

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].nested_fields.len(), 2); // id and name merged
    }

    #[test]
    fn test_depth_limit() {
        // Fragment0 -> Fragment1 -> ... -> Fragment11 -> field: deeper than the default limit.
        let fragments: Vec<FragmentDefinition> = (0..12)
            .map(|i| {
                let body = if i < 11 {
                    make_spread(&format!("Fragment{}", i + 1))
                } else {
                    make_field("field", vec![])
                };
                make_fragment(&format!("Fragment{i}"), vec![body])
            })
            .collect();

        let resolver = FragmentResolver::new(&fragments);
        let result = resolver.resolve_spreads(&[make_spread("Fragment0")]);

        assert!(matches!(result, Err(FragmentError::FragmentDepthExceeded(_))));
    }

    #[test]
    fn test_custom_max_depth() {
        // Outer -> Inner -> id puts `id` at nesting level 2.
        let fragments = [
            make_fragment("Outer", vec![make_spread("Inner")]),
            make_fragment("Inner", vec![make_field("id", vec![])]),
        ];
        let selections = [make_spread("Outer")];

        let strict = FragmentResolver::new(&fragments).with_max_depth(1);
        assert!(matches!(
            strict.resolve_spreads(&selections),
            Err(FragmentError::FragmentDepthExceeded(1))
        ));

        let relaxed = FragmentResolver::new(&fragments).with_max_depth(2);
        assert_eq!(names(&relaxed.resolve_spreads(&selections).unwrap()), ["id"]);
    }

    #[test]
    fn test_circular_reference_detection() {
        // FragmentA -> FragmentB -> FragmentA (circular)
        let fragment_a = FragmentDefinition {
            name:             "FragmentA".to_string(),
            type_condition:   "User".to_string(),
            selections:       vec![make_spread("FragmentB")],
            fragment_spreads: vec!["FragmentB".to_string()],
        };

        let fragment_b = FragmentDefinition {
            name:             "FragmentB".to_string(),
            type_condition:   "User".to_string(),
            selections:       vec![make_spread("FragmentA")],
            fragment_spreads: vec!["FragmentA".to_string()],
        };

        let resolver = FragmentResolver::new(&[fragment_a, fragment_b]);
        let result = resolver.resolve_spreads(&[make_spread("FragmentA")]);

        assert!(matches!(result, Err(FragmentError::CircularFragmentReference)));
    }
}