1use std::collections::{HashMap, HashSet};
9
10use thiserror::Error;
11
12use crate::graphql::types::{FieldSelection, FragmentDefinition};
13
14#[derive(Debug, Error)]
16#[non_exhaustive]
17pub enum FragmentError {
18 #[error("Fragment not found: {0}")]
20 FragmentNotFound(String),
21
22 #[error("Fragment depth exceeded (max: {0})")]
24 FragmentDepthExceeded(u32),
25
26 #[error("Circular fragment reference detected")]
28 CircularFragmentReference,
29}
30
31pub struct FragmentResolver {
59 fragments: HashMap<String, FragmentDefinition>,
60 max_depth: u32,
61}
62
63impl FragmentResolver {
64 #[must_use]
66 pub fn new(fragments: &[FragmentDefinition]) -> Self {
67 let map = fragments.iter().map(|f| (f.name.clone(), f.clone())).collect();
68 Self {
69 fragments: map,
70 max_depth: 10,
71 }
72 }
73
74 #[must_use]
76 pub const fn with_max_depth(mut self, max_depth: u32) -> Self {
77 self.max_depth = max_depth;
78 self
79 }
80
81 pub fn resolve_spreads(
89 &self,
90 selections: &[FieldSelection],
91 ) -> Result<Vec<FieldSelection>, FragmentError> {
92 self.resolve_selections(selections, 0, &mut HashSet::new())
93 }
94
95 fn resolve_selections(
97 &self,
98 selections: &[FieldSelection],
99 depth: u32,
100 visited_fragments: &mut HashSet<String>,
101 ) -> Result<Vec<FieldSelection>, FragmentError> {
102 if depth > self.max_depth {
103 return Err(FragmentError::FragmentDepthExceeded(self.max_depth));
104 }
105
106 let mut result = Vec::new();
107
108 for selection in selections {
109 if let Some(fragment_name) = selection.name.strip_prefix("...") {
111 if fragment_name.starts_with("on ") {
113 let mut field = selection.clone();
116 if !field.nested_fields.is_empty() {
117 field.nested_fields = self.resolve_selections(
118 &field.nested_fields,
119 depth + 1,
120 visited_fragments,
121 )?;
122 }
123 result.push(field);
124 continue;
125 }
126
127 let fragment_name = fragment_name.to_string();
129
130 if visited_fragments.contains(&fragment_name) {
132 return Err(FragmentError::CircularFragmentReference);
133 }
134
135 let fragment = self
137 .fragments
138 .get(&fragment_name)
139 .ok_or_else(|| FragmentError::FragmentNotFound(fragment_name.clone()))?;
140
141 visited_fragments.insert(fragment_name.clone());
143
144 let resolved =
146 self.resolve_selections(&fragment.selections, depth + 1, visited_fragments)?;
147 result.extend(resolved);
148
149 visited_fragments.remove(&fragment_name);
151 } else {
152 let mut field = selection.clone();
154 if !field.nested_fields.is_empty() {
155 field.nested_fields = self.resolve_selections(
156 &field.nested_fields,
157 depth + 1,
158 visited_fragments,
159 )?;
160 }
161 result.push(field);
162 }
163 }
164
165 Ok(result)
166 }
167
168 #[must_use]
173 pub fn evaluate_inline_fragment(
174 selections: &[FieldSelection],
175 type_condition: Option<&str>,
176 actual_type: &str,
177 ) -> Vec<FieldSelection> {
178 if type_condition.is_none() {
180 return selections.to_vec();
181 }
182
183 if type_condition == Some(actual_type) {
185 selections.to_vec()
186 } else {
187 vec![]
189 }
190 }
191
192 #[must_use]
199 pub fn merge_selections(
200 base: &[FieldSelection],
201 additional: Vec<FieldSelection>,
202 ) -> Vec<FieldSelection> {
203 let mut by_key: HashMap<String, FieldSelection> =
205 base.iter().map(|f| (f.response_key().to_string(), f.clone())).collect();
206
207 for field in additional {
209 let key = field.response_key().to_string();
210 if let Some(existing) = by_key.get_mut(&key) {
211 if !field.nested_fields.is_empty() {
213 existing.nested_fields.extend(field.nested_fields);
214 existing.nested_fields = Self::deduplicate_fields(&existing.nested_fields);
216 }
217 } else {
218 by_key.insert(key, field);
220 }
221 }
222
223 by_key.into_values().collect()
224 }
225
226 fn deduplicate_fields(fields: &[FieldSelection]) -> Vec<FieldSelection> {
228 let mut seen = HashSet::new();
229 fields
230 .iter()
231 .filter(|f| seen.insert(f.response_key().to_string()))
232 .cloned()
233 .collect()
234 }
235}
236
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    // Unwrapping is deliberate in tests: a panic is the failure report we want.

    use super::*;

    /// Builds a plain field selection with no alias, arguments or directives.
    fn make_field(name: &str, nested: Vec<FieldSelection>) -> FieldSelection {
        FieldSelection {
            name: name.to_string(),
            alias: None,
            arguments: vec![],
            nested_fields: nested,
            directives: vec![],
        }
    }

    /// Builds a named fragment spread selection (`...Name`).
    fn make_spread(fragment_name: &str) -> FieldSelection {
        make_field(&format!("...{fragment_name}"), vec![])
    }

    /// Builds a fragment on `User` with no recorded spreads.
    fn make_fragment(name: &str, selections: Vec<FieldSelection>) -> FragmentDefinition {
        FragmentDefinition {
            name: name.to_string(),
            type_condition: "User".to_string(),
            selections,
            fragment_spreads: vec![],
        }
    }

    /// Collects selection names in order, for compact assertions.
    fn names(selections: &[FieldSelection]) -> Vec<&str> {
        selections.iter().map(|s| s.name.as_str()).collect()
    }

    #[test]
    fn test_simple_fragment_spread_resolution() {
        let user_fields =
            make_fragment("UserFields", vec![make_field("id", vec![]), make_field("name", vec![])]);
        let resolver = FragmentResolver::new(&[user_fields]);

        let resolved = resolver.resolve_spreads(&[make_spread("UserFields")]).unwrap();

        assert_eq!(names(&resolved), ["id", "name"]);
    }

    #[test]
    fn test_fragment_not_found() {
        let resolver = FragmentResolver::new(&[]);

        let outcome = resolver.resolve_spreads(&[make_spread("NonexistentFragment")]);

        assert!(matches!(outcome, Err(FragmentError::FragmentNotFound(_))));
    }

    #[test]
    fn test_nested_fragment_spreads() {
        let inner = make_fragment("FragmentA", vec![make_field("id", vec![])]);
        let outer =
            make_fragment("FragmentB", vec![make_spread("FragmentA"), make_field("name", vec![])]);
        let resolver = FragmentResolver::new(&[inner, outer]);

        let resolved = resolver.resolve_spreads(&[make_spread("FragmentB")]).unwrap();

        assert_eq!(names(&resolved), ["id", "name"]);
    }

    #[test]
    fn test_inline_fragment_matching_type() {
        let fields = [make_field("id", vec![]), make_field("name", vec![])];

        let kept = FragmentResolver::evaluate_inline_fragment(&fields, Some("User"), "User");

        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].name, "id");
    }

    #[test]
    fn test_inline_fragment_non_matching_type() {
        let fields = [make_field("id", vec![]), make_field("name", vec![])];

        let kept = FragmentResolver::evaluate_inline_fragment(&fields, Some("User"), "Post");

        assert!(kept.is_empty());
    }

    #[test]
    fn test_inline_fragment_without_type_condition() {
        let fields = [make_field("id", vec![]), make_field("name", vec![])];

        let kept = FragmentResolver::evaluate_inline_fragment(&fields, None, "User");

        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn test_merge_non_conflicting_fields() {
        let base = [make_field("id", vec![]), make_field("name", vec![])];
        let extra = vec![make_field("email", vec![])];

        let merged = FragmentResolver::merge_selections(&base, extra);

        assert_eq!(merged.len(), 3);
        let merged_names = names(&merged);
        for expected in ["id", "name", "email"] {
            assert!(merged_names.contains(&expected));
        }
    }

    #[test]
    fn test_merge_conflicting_fields_with_alias() {
        let aliased_user = |nested: Vec<FieldSelection>| FieldSelection {
            alias: Some("primaryUser".to_string()),
            ..make_field("user", nested)
        };
        let base = [aliased_user(vec![make_field("id", vec![])])];
        let extra = vec![aliased_user(vec![make_field("name", vec![])])];

        let merged = FragmentResolver::merge_selections(&base, extra);

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].nested_fields.len(), 2);
    }

    #[test]
    fn test_depth_limit() {
        // Fragment0 -> Fragment1 -> ... -> Fragment11 -> `field`: a chain of
        // 12 fragments, deeper than the resolver's default limit.
        let chain: Vec<FragmentDefinition> = (0..12)
            .map(|i| {
                let body = if i < 11 {
                    make_spread(&format!("Fragment{}", i + 1))
                } else {
                    make_field("field", vec![])
                };
                make_fragment(&format!("Fragment{i}"), vec![body])
            })
            .collect();
        let resolver = FragmentResolver::new(&chain);

        let outcome = resolver.resolve_spreads(&[make_spread("Fragment0")]);

        assert!(matches!(outcome, Err(FragmentError::FragmentDepthExceeded(_))));
    }

    #[test]
    fn test_circular_reference_detection() {
        let fragment_a = FragmentDefinition {
            fragment_spreads: vec!["FragmentB".to_string()],
            ..make_fragment("FragmentA", vec![make_spread("FragmentB")])
        };
        let fragment_b = FragmentDefinition {
            fragment_spreads: vec!["FragmentA".to_string()],
            ..make_fragment("FragmentB", vec![make_spread("FragmentA")])
        };
        let resolver = FragmentResolver::new(&[fragment_a, fragment_b]);

        let outcome = resolver.resolve_spreads(&[make_spread("FragmentA")]);

        assert!(matches!(outcome, Err(FragmentError::CircularFragmentReference)));
    }
}