fraiseql_core/graphql/
fragments.rs1use std::collections::{HashMap, HashSet};
7
8use crate::graphql::types::{FragmentDefinition, ParsedQuery};
9
/// Directed graph of fragment-spread dependencies within a single query.
///
/// Each key is a fragment name; its value is the set of fragment names that
/// fragment depends on directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FragmentGraph {
    /// Fragment name -> names of the fragments it depends on directly.
    dependencies: HashMap<String, HashSet<String>>,
}
16
17impl FragmentGraph {
18 #[must_use]
20 pub fn new(query: &ParsedQuery) -> Self {
21 let mut dependencies = HashMap::new();
22
23 for fragment in &query.fragments {
25 let deps = Self::extract_fragment_dependencies(fragment, &query.fragments);
26 dependencies.insert(fragment.name.clone(), deps);
27 }
28
29 Self { dependencies }
30 }
31
32 fn extract_fragment_dependencies(
34 fragment: &FragmentDefinition,
35 all_fragments: &[FragmentDefinition],
36 ) -> HashSet<String> {
37 let mut deps = HashSet::new();
38
39 deps.extend(fragment.fragment_spreads.iter().cloned());
41
42 for selection in &fragment.selections {
44 Self::extract_selection_dependencies(selection, all_fragments, &mut deps);
45 }
46
47 deps
48 }
49
50 #[allow(clippy::only_used_in_recursion)] fn extract_selection_dependencies(
53 selection: &crate::graphql::types::FieldSelection,
54 all_fragments: &[FragmentDefinition],
55 deps: &mut HashSet<String>,
56 ) {
57 for nested in &selection.nested_fields {
63 Self::extract_selection_dependencies(nested, all_fragments, deps);
64 }
65 }
66
67 pub fn detect_cycles(&self) -> Result<(), Vec<String>> {
77 let mut visited = HashSet::new();
78 let mut recursion_stack = HashSet::new();
79 let mut cycle_path = Vec::new();
80
81 for fragment_name in self.dependencies.keys() {
82 if !visited.contains(fragment_name) {
83 if let Some(cycle) = self.dfs_cycle_detect(
84 fragment_name,
85 &mut visited,
86 &mut recursion_stack,
87 &mut cycle_path,
88 ) {
89 return Err(cycle);
90 }
91 }
92 }
93
94 Ok(())
95 }
96
97 fn dfs_cycle_detect(
99 &self,
100 fragment_name: &str,
101 visited: &mut HashSet<String>,
102 recursion_stack: &mut HashSet<String>,
103 cycle_path: &mut Vec<String>,
104 ) -> Option<Vec<String>> {
105 visited.insert(fragment_name.to_string());
106 recursion_stack.insert(fragment_name.to_string());
107 cycle_path.push(fragment_name.to_string());
108
109 if let Some(deps) = self.dependencies.get(fragment_name) {
110 for dep in deps {
111 if let Some(cycle) =
112 self.check_dependency_cycle(dep, visited, recursion_stack, cycle_path)
113 {
114 return Some(cycle);
115 }
116 }
117 }
118
119 recursion_stack.remove(fragment_name);
120 cycle_path.pop();
121 None
122 }
123
124 fn check_dependency_cycle(
126 &self,
127 dep: &str,
128 visited: &mut HashSet<String>,
129 recursion_stack: &mut HashSet<String>,
130 cycle_path: &mut Vec<String>,
131 ) -> Option<Vec<String>> {
132 if !visited.contains(dep) {
133 return self.dfs_cycle_detect(dep, visited, recursion_stack, cycle_path);
135 }
136
137 if recursion_stack.contains(dep) {
138 #[allow(clippy::expect_used)]
140 let cycle_start = cycle_path
142 .iter()
143 .position(|f| f == dep)
144 .expect("dep must be in cycle_path when in recursion_stack");
145 let cycle = cycle_path[cycle_start..].to_vec();
146 return Some(cycle);
147 }
148
149 None
150 }
151
152 pub fn validate_fragments(&self) -> Result<(), String> {
159 self.detect_cycles()
160 .map_err(|cycle| format!("Fragment cycle detected: {}", cycle.join(" -> ")))
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph from `(fragment, direct dependencies)` pairs.
    fn graph(edges: &[(&str, &[&str])]) -> FragmentGraph {
        FragmentGraph {
            dependencies: edges
                .iter()
                .map(|&(name, deps)| {
                    let deps = deps.iter().copied().map(str::to_owned).collect();
                    (name.to_owned(), deps)
                })
                .collect(),
        }
    }

    /// The distinct fragment names in a reported cycle.
    fn members(cycle: &[String]) -> HashSet<&str> {
        cycle.iter().map(String::as_str).collect()
    }

    #[test]
    fn test_no_cycles() {
        let g = graph(&[("FragA", &["FragB"]), ("FragB", &["FragC"]), ("FragC", &[])]);
        assert!(g.detect_cycles().is_ok());
    }

    #[test]
    fn test_empty_graph() {
        let g = graph(&[]);
        assert!(g.detect_cycles().is_ok());
        assert!(g.validate_fragments().is_ok());
    }

    #[test]
    fn test_diamond_is_not_a_cycle() {
        // FragD is reached twice, but never while already on the DFS path.
        let g = graph(&[
            ("FragA", &["FragB", "FragC"]),
            ("FragB", &["FragD"]),
            ("FragC", &["FragD"]),
            ("FragD", &[]),
        ]);
        assert!(g.detect_cycles().is_ok());
    }

    #[test]
    fn test_unknown_dependency_is_not_a_cycle() {
        let g = graph(&[("FragA", &["Missing"])]);
        assert!(g.detect_cycles().is_ok());
    }

    #[test]
    fn test_simple_cycle() {
        let g = graph(&[("FragA", &["FragB"]), ("FragB", &["FragA"])]);
        let cycle = g.detect_cycles().unwrap_err();
        assert_eq!(cycle.len(), 2);
        assert_eq!(members(&cycle), HashSet::from(["FragA", "FragB"]));
    }

    #[test]
    fn test_complex_cycle() {
        let g = graph(&[
            ("FragA", &["FragB"]),
            ("FragB", &["FragC"]),
            ("FragC", &["FragA"]),
            ("FragD", &["FragE"]),
            ("FragE", &[]),
        ]);
        let cycle = g.detect_cycles().unwrap_err();
        // Only the three fragments on the loop are reported, each exactly once.
        assert_eq!(cycle.len(), 3);
        assert_eq!(members(&cycle), HashSet::from(["FragA", "FragB", "FragC"]));
    }

    #[test]
    fn test_multiple_cycles() {
        let g = graph(&[
            ("FragA", &["FragB"]),
            ("FragB", &["FragA"]),
            ("FragC", &["FragD"]),
            ("FragD", &["FragC"]),
        ]);
        let cycle = g.detect_cycles().unwrap_err();
        assert_eq!(cycle.len(), 2);
        let found = members(&cycle);
        assert!(
            found == HashSet::from(["FragA", "FragB"])
                || found == HashSet::from(["FragC", "FragD"]),
            "unexpected cycle: {cycle:?}"
        );
    }

    #[test]
    fn test_self_reference_cycle() {
        let g = graph(&[("FragA", &["FragA"])]);
        assert_eq!(g.detect_cycles().unwrap_err(), ["FragA"]);
    }

    #[test]
    fn test_validate_fragments_message() {
        let g = graph(&[("FragA", &["FragA"])]);
        assert_eq!(
            g.validate_fragments(),
            Err("Fragment cycle detected: FragA".to_string())
        );
    }
}