1use serde::{Deserialize, Serialize};
6
7use super::EdgeType;
8
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
11pub struct GraphQuery {
12 pub start_nodes: Vec<String>,
14
15 pub direction: TraversalDirection,
17
18 pub max_depth: usize,
20
21 pub edge_types: Option<Vec<EdgeType>>,
23
24 pub max_results: usize,
26
27 pub language_filter: Option<String>,
29
30 pub granularity_filter: Option<String>,
32
33 pub include_start: bool,
35}
36
37#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
39#[serde(rename_all = "snake_case")]
40pub enum TraversalDirection {
41 #[default]
43 Outgoing,
44 Incoming,
46 Both,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct QueryResult {
53 pub nodes: Vec<QueryNode>,
55
56 pub nodes_visited: usize,
58
59 pub truncated: bool,
61
62 pub execution_time_ms: u64,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct QueryNode {
69 pub chunk_id: String,
71
72 pub alias: Option<String>,
74
75 pub depth: usize,
77
78 pub reached_via: Option<EdgeType>,
80
81 pub parent: Option<String>,
83
84 pub token_estimate: usize,
86}
87
88impl GraphQuery {
89 pub fn new(start_nodes: Vec<String>) -> Self {
91 Self {
92 start_nodes,
93 direction: TraversalDirection::Outgoing,
94 max_depth: 3,
95 edge_types: None,
96 max_results: 100,
97 language_filter: None,
98 granularity_filter: None,
99 include_start: true,
100 }
101 }
102
103 pub fn dependencies(chunk_id: impl Into<String>) -> Self {
105 Self::new(vec![chunk_id.into()]).with_direction(TraversalDirection::Outgoing)
106 }
107
108 pub fn dependents(chunk_id: impl Into<String>) -> Self {
110 Self::new(vec![chunk_id.into()]).with_direction(TraversalDirection::Incoming)
111 }
112
113 pub fn with_direction(mut self, direction: TraversalDirection) -> Self {
115 self.direction = direction;
116 self
117 }
118
119 pub fn with_depth(mut self, depth: usize) -> Self {
121 self.max_depth = depth;
122 self
123 }
124
125 pub fn with_edge_types(mut self, types: Vec<EdgeType>) -> Self {
127 self.edge_types = Some(types);
128 self
129 }
130
131 pub fn strong_only(self) -> Self {
133 self.with_edge_types(vec![
134 EdgeType::Imports,
135 EdgeType::TypeRef,
136 EdgeType::Implements,
137 EdgeType::Extends,
138 ])
139 }
140
141 pub fn with_limit(mut self, limit: usize) -> Self {
143 self.max_results = limit;
144 self
145 }
146
147 pub fn with_language(mut self, language: impl Into<String>) -> Self {
149 self.language_filter = Some(language.into());
150 self
151 }
152
153 pub fn include_start(mut self, include: bool) -> Self {
155 self.include_start = include;
156 self
157 }
158
159 pub fn should_follow_edge(&self, edge_type: EdgeType) -> bool {
161 match &self.edge_types {
162 Some(types) => types.contains(&edge_type),
163 None => true,
164 }
165 }
166}
167
168impl QueryResult {
169 pub fn empty() -> Self {
171 Self {
172 nodes: Vec::new(),
173 nodes_visited: 0,
174 truncated: false,
175 execution_time_ms: 0,
176 }
177 }
178
179 pub fn chunk_ids(&self) -> Vec<&String> {
181 self.nodes.iter().map(|n| &n.chunk_id).collect()
182 }
183
184 pub fn total_tokens(&self) -> usize {
186 self.nodes.iter().map(|n| n.token_estimate).sum()
187 }
188
189 pub fn at_depth(&self, depth: usize) -> impl Iterator<Item = &QueryNode> {
191 self.nodes.iter().filter(move |n| n.depth == depth)
192 }
193
194 pub fn path_to(&self, chunk_id: &str) -> Vec<&QueryNode> {
196 let mut path = Vec::new();
197 let mut current = self.nodes.iter().find(|n| n.chunk_id == chunk_id);
198
199 while let Some(node) = current {
200 path.push(node);
201 current = node
202 .parent
203 .as_ref()
204 .and_then(|p| self.nodes.iter().find(|n| &n.chunk_id == p));
205 }
206
207 path.reverse();
208 path
209 }
210}
211
212#[derive(Debug, Clone)]
214pub struct QueryBuilder {
215 query: GraphQuery,
216}
217
218impl QueryBuilder {
219 pub fn new() -> Self {
220 Self {
221 query: GraphQuery::default(),
222 }
223 }
224
225 pub fn start_from(mut self, chunks: Vec<String>) -> Self {
226 self.query.start_nodes = chunks;
227 self
228 }
229
230 pub fn find_dependencies(self) -> Self {
231 Self {
232 query: self.query.with_direction(TraversalDirection::Outgoing),
233 }
234 }
235
236 pub fn find_dependents(self) -> Self {
237 Self {
238 query: self.query.with_direction(TraversalDirection::Incoming),
239 }
240 }
241
242 pub fn depth(mut self, d: usize) -> Self {
243 self.query.max_depth = d;
244 self
245 }
246
247 pub fn limit(mut self, l: usize) -> Self {
248 self.query.max_results = l;
249 self
250 }
251
252 pub fn only_strong(self) -> Self {
253 Self {
254 query: self.query.strong_only(),
255 }
256 }
257
258 pub fn build(self) -> GraphQuery {
259 self.query
260 }
261}
262
263impl Default for QueryBuilder {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-node fixture: root `a` at depth 0 and child `b` reached from `a`
    /// via an import edge.
    fn sample_result() -> QueryResult {
        QueryResult {
            nodes: vec![
                QueryNode {
                    chunk_id: "a".to_string(),
                    alias: Some("root".to_string()),
                    depth: 0,
                    reached_via: None,
                    parent: None,
                    token_estimate: 100,
                },
                QueryNode {
                    chunk_id: "b".to_string(),
                    alias: None,
                    depth: 1,
                    reached_via: Some(EdgeType::Imports),
                    parent: Some("a".to_string()),
                    token_estimate: 200,
                },
            ],
            nodes_visited: 2,
            truncated: false,
            execution_time_ms: 5,
        }
    }

    #[test]
    fn test_query_creation() {
        let query = GraphQuery::dependencies("chunk:abc123")
            .with_depth(2)
            .strong_only();

        assert_eq!(query.start_nodes.len(), 1);
        assert_eq!(query.direction, TraversalDirection::Outgoing);
        assert_eq!(query.max_depth, 2);
        assert!(query.should_follow_edge(EdgeType::Imports));
        assert!(!query.should_follow_edge(EdgeType::Calls));
    }

    #[test]
    fn test_new_defaults() {
        let query = GraphQuery::new(vec!["chunk:x".to_string()]);

        assert_eq!(query.start_nodes, vec!["chunk:x".to_string()]);
        assert_eq!(query.direction, TraversalDirection::Outgoing);
        assert_eq!(query.max_depth, 3);
        assert_eq!(query.max_results, 100);
        assert!(query.edge_types.is_none());
        assert!(query.language_filter.is_none());
        assert!(query.granularity_filter.is_none());
        assert!(query.include_start);
        // No edge-type restriction: every edge type is followed.
        assert!(query.should_follow_edge(EdgeType::Calls));
    }

    #[test]
    fn test_dependents_and_setters() {
        let query = GraphQuery::dependents("chunk:def456")
            .with_language("rust")
            .with_limit(10)
            .include_start(false);

        assert_eq!(query.start_nodes, vec!["chunk:def456".to_string()]);
        assert_eq!(query.direction, TraversalDirection::Incoming);
        assert_eq!(query.language_filter.as_deref(), Some("rust"));
        assert_eq!(query.max_results, 10);
        assert!(!query.include_start);
    }

    #[test]
    fn test_query_builder() {
        let query = QueryBuilder::new()
            .start_from(vec!["chunk:a".to_string(), "chunk:b".to_string()])
            .find_dependencies()
            .depth(3)
            .only_strong()
            .limit(50)
            .build();

        assert_eq!(query.start_nodes.len(), 2);
        assert_eq!(query.direction, TraversalDirection::Outgoing);
        assert_eq!(query.max_depth, 3);
        assert_eq!(query.max_results, 50);
        assert!(query.should_follow_edge(EdgeType::Extends));
        assert!(!query.should_follow_edge(EdgeType::Calls));
    }

    #[test]
    fn test_result_operations() {
        let result = sample_result();

        assert_eq!(result.total_tokens(), 300);
        assert_eq!(result.at_depth(0).count(), 1);
        assert_eq!(result.at_depth(1).count(), 1);
        assert_eq!(result.path_to("b").len(), 2);

        let ids: Vec<&str> = result.chunk_ids().into_iter().map(String::as_str).collect();
        assert_eq!(ids, ["a", "b"]);
    }

    #[test]
    fn test_path_to_order_and_missing() {
        let result = sample_result();

        let path: Vec<&str> = result
            .path_to("b")
            .into_iter()
            .map(|n| n.chunk_id.as_str())
            .collect();
        assert_eq!(path, ["a", "b"]);
        assert_eq!(result.path_to("a").len(), 1);
        assert!(result.path_to("missing").is_empty());
    }

    #[test]
    fn test_empty_result() {
        let result = QueryResult::empty();

        assert!(result.nodes.is_empty());
        assert_eq!(result.nodes_visited, 0);
        assert!(!result.truncated);
        assert_eq!(result.execution_time_ms, 0);
        assert_eq!(result.total_tokens(), 0);
        assert!(result.chunk_ids().is_empty());
        assert_eq!(result.at_depth(0).count(), 0);
        assert!(result.path_to("a").is_empty());
    }
}