1use std::collections::{HashMap, HashSet};
9
10use thiserror::Error;
11
12use crate::graphql::types::{FieldSelection, FragmentDefinition};
13
14#[derive(Debug, Error)]
16pub enum FragmentError {
17 #[error("Fragment not found: {0}")]
19 FragmentNotFound(String),
20
21 #[error("Fragment depth exceeded (max: {0})")]
23 FragmentDepthExceeded(u32),
24
25 #[error("Circular fragment reference detected")]
27 CircularFragmentReference,
28}
29
30pub struct FragmentResolver {
58 fragments: HashMap<String, FragmentDefinition>,
59 max_depth: u32,
60}
61
62impl FragmentResolver {
63 #[must_use]
65 pub fn new(fragments: &[FragmentDefinition]) -> Self {
66 let map = fragments.iter().map(|f| (f.name.clone(), f.clone())).collect();
67 Self {
68 fragments: map,
69 max_depth: 10,
70 }
71 }
72
73 #[must_use]
75 pub fn with_max_depth(mut self, max_depth: u32) -> Self {
76 self.max_depth = max_depth;
77 self
78 }
79
80 pub fn resolve_spreads(
88 &self,
89 selections: &[FieldSelection],
90 ) -> Result<Vec<FieldSelection>, FragmentError> {
91 self.resolve_selections(selections, 0, &mut HashSet::new())
92 }
93
94 fn resolve_selections(
96 &self,
97 selections: &[FieldSelection],
98 depth: u32,
99 visited_fragments: &mut HashSet<String>,
100 ) -> Result<Vec<FieldSelection>, FragmentError> {
101 if depth > self.max_depth {
102 return Err(FragmentError::FragmentDepthExceeded(self.max_depth));
103 }
104
105 let mut result = Vec::new();
106
107 for selection in selections {
108 if let Some(fragment_name) = selection.name.strip_prefix("...") {
110 if fragment_name.starts_with("on ") {
112 let mut field = selection.clone();
114 if !field.nested_fields.is_empty() {
115 field.nested_fields = self.resolve_selections(
116 &field.nested_fields,
117 depth,
118 visited_fragments,
119 )?;
120 }
121 result.push(field);
122 continue;
123 }
124
125 let fragment_name = fragment_name.to_string();
127
128 if visited_fragments.contains(&fragment_name) {
130 return Err(FragmentError::CircularFragmentReference);
131 }
132
133 let fragment = self
135 .fragments
136 .get(&fragment_name)
137 .ok_or_else(|| FragmentError::FragmentNotFound(fragment_name.clone()))?;
138
139 visited_fragments.insert(fragment_name.clone());
141
142 let resolved =
144 self.resolve_selections(&fragment.selections, depth + 1, visited_fragments)?;
145 result.extend(resolved);
146
147 visited_fragments.remove(&fragment_name);
149 } else {
150 let mut field = selection.clone();
152 if !field.nested_fields.is_empty() {
153 field.nested_fields =
154 self.resolve_selections(&field.nested_fields, depth, visited_fragments)?;
155 }
156 result.push(field);
157 }
158 }
159
160 Ok(result)
161 }
162
163 #[must_use]
168 pub fn evaluate_inline_fragment(
169 selections: &[FieldSelection],
170 type_condition: Option<&str>,
171 actual_type: &str,
172 ) -> Vec<FieldSelection> {
173 if type_condition.is_none() {
175 return selections.to_vec();
176 }
177
178 if type_condition == Some(actual_type) {
180 selections.to_vec()
181 } else {
182 vec![]
184 }
185 }
186
187 #[must_use]
194 pub fn merge_selections(
195 base: &[FieldSelection],
196 additional: Vec<FieldSelection>,
197 ) -> Vec<FieldSelection> {
198 let mut by_key: HashMap<String, FieldSelection> =
200 base.iter().map(|f| (f.response_key().to_string(), f.clone())).collect();
201
202 for field in additional {
204 let key = field.response_key().to_string();
205 if let Some(existing) = by_key.get_mut(&key) {
206 if !field.nested_fields.is_empty() {
208 existing.nested_fields.extend(field.nested_fields);
209 existing.nested_fields = Self::deduplicate_fields(&existing.nested_fields);
211 }
212 } else {
213 by_key.insert(key, field);
215 }
216 }
217
218 by_key.into_values().collect()
219 }
220
221 fn deduplicate_fields(fields: &[FieldSelection]) -> Vec<FieldSelection> {
223 let mut seen = HashSet::new();
224 fields
225 .iter()
226 .filter(|f| seen.insert(f.response_key().to_string()))
227 .cloned()
228 .collect()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain field selection with no alias, arguments or directives.
    fn make_field(name: &str, nested: Vec<FieldSelection>) -> FieldSelection {
        FieldSelection {
            name: name.to_string(),
            alias: None,
            arguments: vec![],
            nested_fields: nested,
            directives: vec![],
        }
    }

    /// Builds a named fragment spread selection, i.e. a field called `...<fragment>`.
    fn spread(fragment: &str) -> FieldSelection {
        make_field(&format!("...{fragment}"), vec![])
    }

    /// Builds a fragment on `User` with an empty `fragment_spreads` list.
    fn make_fragment(name: &str, selections: Vec<FieldSelection>) -> FragmentDefinition {
        FragmentDefinition {
            name: name.to_string(),
            type_condition: "User".to_string(),
            selections,
            fragment_spreads: vec![],
        }
    }

    /// The two leaf fields most tests start from.
    fn id_and_name() -> Vec<FieldSelection> {
        vec![make_field("id", vec![]), make_field("name", vec![])]
    }

    /// Field names of `fields`, in order.
    fn names(fields: &[FieldSelection]) -> Vec<&str> {
        fields.iter().map(|f| f.name.as_str()).collect()
    }

    #[test]
    fn test_simple_fragment_spread_resolution() {
        let resolver = FragmentResolver::new(&[make_fragment("UserFields", id_and_name())]);

        let resolved = resolver.resolve_spreads(&[spread("UserFields")]).unwrap();

        assert_eq!(names(&resolved), ["id", "name"]);
    }

    #[test]
    fn test_fragment_not_found() {
        let resolver = FragmentResolver::new(&[]);

        let outcome = resolver.resolve_spreads(&[spread("NonexistentFragment")]);

        assert!(matches!(outcome, Err(FragmentError::FragmentNotFound(_))));
    }

    #[test]
    fn test_nested_fragment_spreads() {
        let inner = make_fragment("FragmentA", vec![make_field("id", vec![])]);
        let outer = make_fragment(
            "FragmentB",
            vec![spread("FragmentA"), make_field("name", vec![])],
        );
        let resolver = FragmentResolver::new(&[inner, outer]);

        let resolved = resolver.resolve_spreads(&[spread("FragmentB")]).unwrap();

        assert_eq!(names(&resolved), ["id", "name"]);
    }

    #[test]
    fn test_inline_fragment_matching_type() {
        let kept = FragmentResolver::evaluate_inline_fragment(&id_and_name(), Some("User"), "User");

        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].name, "id");
    }

    #[test]
    fn test_inline_fragment_non_matching_type() {
        let kept = FragmentResolver::evaluate_inline_fragment(&id_and_name(), Some("User"), "Post");

        assert_eq!(kept.len(), 0);
    }

    #[test]
    fn test_inline_fragment_without_type_condition() {
        let kept = FragmentResolver::evaluate_inline_fragment(&id_and_name(), None, "User");

        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn test_merge_non_conflicting_fields() {
        let merged =
            FragmentResolver::merge_selections(&id_and_name(), vec![make_field("email", vec![])]);

        assert_eq!(merged.len(), 3);
        let merged_names = names(&merged);
        for expected in ["id", "name", "email"] {
            assert!(merged_names.contains(&expected));
        }
    }

    #[test]
    fn test_merge_conflicting_fields_with_alias() {
        // Both selections share the response key `primaryUser`.
        let aliased_user = |child: &str| FieldSelection {
            name: "user".to_string(),
            alias: Some("primaryUser".to_string()),
            arguments: vec![],
            nested_fields: vec![make_field(child, vec![])],
            directives: vec![],
        };

        let merged =
            FragmentResolver::merge_selections(&[aliased_user("id")], vec![aliased_user("name")]);

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].nested_fields.len(), 2);
    }

    #[test]
    fn test_depth_limit() {
        // Fragment0 -> Fragment1 -> ... -> Fragment11: twelve levels of spreads,
        // deeper than the resolver's default limit.
        let chain: Vec<FragmentDefinition> = (0..12)
            .map(|i| {
                let body = if i < 11 {
                    spread(&format!("Fragment{}", i + 1))
                } else {
                    make_field("field", vec![])
                };
                make_fragment(&format!("Fragment{i}"), vec![body])
            })
            .collect();
        let resolver = FragmentResolver::new(&chain);

        let outcome = resolver.resolve_spreads(&[spread("Fragment0")]);

        assert!(matches!(outcome, Err(FragmentError::FragmentDepthExceeded(_))));
    }

    #[test]
    fn test_circular_reference_detection() {
        // A fragment whose only selection spreads `target`.
        let pointing_at = |name: &str, target: &str| FragmentDefinition {
            name: name.to_string(),
            type_condition: "User".to_string(),
            selections: vec![spread(target)],
            fragment_spreads: vec![target.to_string()],
        };
        let resolver = FragmentResolver::new(&[
            pointing_at("FragmentA", "FragmentB"),
            pointing_at("FragmentB", "FragmentA"),
        ]);

        let outcome = resolver.resolve_spreads(&[spread("FragmentA")]);

        assert!(matches!(outcome, Err(FragmentError::CircularFragmentReference)));
    }
}