1use std::collections::{BTreeMap, BTreeSet, HashMap};
21
22use serde::Serialize;
23
24use super::hasher::hash_content;
25use super::types::EmbedChunk;
26
/// A single vertex of the exported property graph.
///
/// Serializes as `~id` and `~label` columns followed by every entry of `properties` inlined at
/// the same level (via `#[serde(flatten)]`). The `~`-prefixed column names presumably target a
/// graph bulk-load format (e.g. Amazon Neptune's) — NOTE(review): confirm the consumer.
#[derive(Debug, Clone, Serialize)]
pub struct GraphVertex {
    /// Vertex identifier, serialized as `~id`. `generate_graph_export` uses the chunk id for
    /// chunk vertices, `file:<path>` for file vertices and `mod:<dir>` for module vertices.
    #[serde(rename = "~id")]
    pub id: String,

    /// Vertex label, serialized as `~label` (the chunk kind name, `file`, or `module`).
    #[serde(rename = "~label")]
    pub label: String,

    /// Additional key/value properties, flattened into the same serialized record as
    /// `~id` and `~label`.
    #[serde(flatten)]
    pub properties: serde_json::Map<String, serde_json::Value>,
}
42
/// A directed edge of the exported property graph.
///
/// Serializes as the `~id`, `~from`, `~to` and `~label` columns. Edges carry no additional
/// properties.
#[derive(Debug, Clone, Serialize)]
pub struct GraphEdge {
    /// Edge identifier, serialized as `~id`; `generate_graph_export` derives it from
    /// `(from, to, label)` through `edge_id`, so identical edges share an id.
    #[serde(rename = "~id")]
    pub id: String,

    /// `~id` of the source vertex, serialized as `~from`.
    #[serde(rename = "~from")]
    pub from: String,

    /// `~id` of the target vertex, serialized as `~to`.
    #[serde(rename = "~to")]
    pub to: String,

    /// Relationship type, serialized as `~label` (`DEFINED_IN`, `CALLS`, `BELONGS_TO`,
    /// or `CONTAINS` in `generate_graph_export`).
    #[serde(rename = "~label")]
    pub label: String,
}
62
/// The complete graph produced by `generate_graph_export`.
///
/// Not itself `Serialize`; callers serialize `vertices` and `edges` separately.
#[derive(Debug, Clone)]
pub struct GraphExport {
    /// All vertices (chunks, files and modules), sorted by `~id`.
    pub vertices: Vec<GraphVertex>,

    /// All edges, sorted by `~id`.
    pub edges: Vec<GraphEdge>,
}
72
73fn edge_id(from: &str, to: &str, label: &str) -> String {
75 let input = format!("{from}\0{to}\0{label}");
76 let result = hash_content(&input);
77 format!("e_{}", &result.short_id[3..])
79}
80
81pub fn generate_graph_export(chunks: &[EmbedChunk]) -> GraphExport {
86 let mut symbol_by_name_and_file: HashMap<(&str, &str), &str> = HashMap::new();
89 let mut symbol_by_name: HashMap<&str, &str> = HashMap::new();
90
91 for chunk in chunks {
92 let name = chunk.source.symbol.as_str();
93 let file = chunk.source.file.as_str();
94 symbol_by_name_and_file.insert((name, file), &chunk.id);
95 symbol_by_name.entry(name).or_insert(&chunk.id);
97 }
98
99 let mut files: BTreeSet<&str> = BTreeSet::new();
101 let mut modules: BTreeSet<String> = BTreeSet::new();
102
103 let mut parent_map: HashMap<(&str, &str), Vec<&str>> = HashMap::new();
105
106 for chunk in chunks {
107 files.insert(&chunk.source.file);
108
109 if let Some(module_path) = derive_module_path(&chunk.source.file) {
111 modules.insert(module_path);
112 }
113
114 if let Some(ref parent_name) = chunk.source.parent {
116 parent_map
117 .entry((parent_name.as_str(), chunk.source.file.as_str()))
118 .or_default()
119 .push(&chunk.id);
120 }
121 }
122
123 let mut vertices = Vec::new();
124 let mut edges = Vec::new();
125
126 for chunk in chunks {
128 let mut props = serde_json::Map::new();
129 props.insert("name".to_owned(), serde_json::Value::String(chunk.source.symbol.clone()));
130 props.insert("file".to_owned(), serde_json::Value::String(chunk.source.file.clone()));
131 props.insert(
132 "language".to_owned(),
133 serde_json::Value::String(chunk.source.language.clone()),
134 );
135 props.insert(
136 "visibility".to_owned(),
137 serde_json::Value::String(chunk.source.visibility.name().to_owned()),
138 );
139 if let Some(ref sig) = chunk.context.signature {
140 props.insert("signature".to_owned(), serde_json::Value::String(sig.clone()));
141 }
142 props.insert(
143 "start_line".to_owned(),
144 serde_json::Value::Number(chunk.source.lines.0.into()),
145 );
146 props.insert("end_line".to_owned(), serde_json::Value::Number(chunk.source.lines.1.into()));
147 props.insert("tokens".to_owned(), serde_json::Value::Number(chunk.tokens.into()));
148
149 vertices.push(GraphVertex {
150 id: chunk.id.clone(),
151 label: chunk.kind.name().to_owned(),
152 properties: props,
153 });
154 }
155
156 for file in &files {
158 let mut props = serde_json::Map::new();
159 props.insert("path".to_owned(), serde_json::Value::String((*file).to_owned()));
160
161 if let Some(chunk) = chunks.iter().find(|c| c.source.file == *file) {
163 props.insert(
164 "language".to_owned(),
165 serde_json::Value::String(chunk.source.language.clone()),
166 );
167 }
168
169 vertices.push(GraphVertex {
170 id: format!("file:{file}"),
171 label: "file".to_owned(),
172 properties: props,
173 });
174 }
175
176 let module_files: BTreeMap<&str, Vec<&str>> = {
179 let mut mf: BTreeMap<&str, Vec<&str>> = BTreeMap::new();
180 for file in &files {
181 if let Some(module_path) = derive_module_path(file) {
182 if modules.contains(&module_path) {
183 mf.entry(modules.get(&module_path).map_or("", |s| s.as_str()))
184 .or_default()
185 .push(file);
186 }
187 }
188 }
189 mf
190 };
191
192 for module_path in &modules {
193 let mut props = serde_json::Map::new();
194 props.insert("module_path".to_owned(), serde_json::Value::String(module_path.clone()));
195
196 vertices.push(GraphVertex {
197 id: format!("mod:{module_path}"),
198 label: "module".to_owned(),
199 properties: props,
200 });
201 }
202
203 for chunk in chunks {
205 let file_id = format!("file:{}", chunk.source.file);
206 edges.push(GraphEdge {
207 id: edge_id(&chunk.id, &file_id, "DEFINED_IN"),
208 from: chunk.id.clone(),
209 to: file_id,
210 label: "DEFINED_IN".to_owned(),
211 });
212 }
213
214 for chunk in chunks {
216 for call_name in &chunk.context.calls {
217 let target_id = symbol_by_name_and_file
219 .get(&(call_name.as_str(), chunk.source.file.as_str()))
220 .or_else(|| symbol_by_name.get(call_name.as_str()));
221
222 if let Some(target) = target_id {
223 if *target != chunk.id.as_str() {
225 edges.push(GraphEdge {
226 id: edge_id(&chunk.id, target, "CALLS"),
227 from: chunk.id.clone(),
228 to: (*target).to_owned(),
229 label: "CALLS".to_owned(),
230 });
231 }
232 }
233 }
234 }
235
236 for file in &files {
238 if let Some(module_path) = derive_module_path(file) {
239 if modules.contains(&module_path) {
240 let file_id = format!("file:{file}");
241 let mod_id = format!("mod:{module_path}");
242 edges.push(GraphEdge {
243 id: edge_id(&file_id, &mod_id, "BELONGS_TO"),
244 from: file_id,
245 to: mod_id,
246 label: "BELONGS_TO".to_owned(),
247 });
248 }
249 }
250 }
251
252 for chunk in chunks {
254 if let Some(ref parent_name) = chunk.source.parent {
255 let parent_id =
257 symbol_by_name_and_file.get(&(parent_name.as_str(), chunk.source.file.as_str()));
258
259 if let Some(pid) = parent_id {
260 if *pid != chunk.id.as_str() {
261 edges.push(GraphEdge {
262 id: edge_id(pid, &chunk.id, "CONTAINS"),
263 from: (*pid).to_owned(),
264 to: chunk.id.clone(),
265 label: "CONTAINS".to_owned(),
266 });
267 }
268 }
269 }
270 }
271
272 vertices.sort_by(|a, b| a.id.cmp(&b.id));
274 edges.sort_by(|a, b| a.id.cmp(&b.id));
275
276 GraphExport { vertices, edges }
277}
278
/// Returns the module path for a file: its parent directory with `/` separators.
///
/// Backslashes are normalized to `/` *before* splitting, so Windows-style paths such as
/// `src\auth\token.rs` yield `src/auth` on every platform. Previously they were only split on
/// Windows, and on Unix the result was `None`.
///
/// Returns `None` for files without a parent directory (e.g. `main.rs`) and for an empty path.
fn derive_module_path(file_path: &str) -> Option<String> {
    let normalized = file_path.replace('\\', "/");
    let parent = std::path::Path::new(&normalized).parent()?.to_str()?;
    (!parent.is_empty()).then(|| parent.to_owned())
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::types::{
        ChunkContext, ChunkKind, ChunkSource, EmbedChunk, RepoIdentifier, Visibility,
    };

    /// Builds a minimal public Rust chunk (lines 1-5, 10 tokens) with the given identity,
    /// outgoing call names and optional parent symbol.
    fn make_chunk(
        id: &str,
        symbol: &str,
        file: &str,
        kind: ChunkKind,
        calls: Vec<String>,
        parent: Option<String>,
    ) -> EmbedChunk {
        let source = ChunkSource {
            repo: RepoIdentifier::default(),
            file: file.to_owned(),
            lines: (1, 5),
            symbol: symbol.to_owned(),
            fqn: None,
            language: "Rust".to_owned(),
            parent,
            visibility: Visibility::Public,
            is_test: false,
            module_path: None,
            parent_chunk_id: None,
        };
        let context = ChunkContext { calls, ..Default::default() };

        EmbedChunk {
            id: id.to_owned(),
            full_hash: format!("{id}_full"),
            content: format!("fn {symbol}() {{}}"),
            tokens: 10,
            kind,
            source,
            context,
            children_ids: Vec::new(),
            repr: "code".to_string(),
            code_chunk_id: None,
            part: None,
        }
    }

    /// Collects the edges of `graph` whose label equals `label`, in export order.
    fn edges_labelled<'a>(graph: &'a GraphExport, label: &str) -> Vec<&'a GraphEdge> {
        graph.edges.iter().filter(|edge| edge.label == label).collect()
    }

    #[test]
    fn test_generate_graph_basic() {
        let caller = make_chunk(
            "ec_aaa",
            "foo",
            "src/lib.rs",
            ChunkKind::Function,
            vec!["bar".into()],
            None,
        );
        let callee = make_chunk("ec_bbb", "bar", "src/lib.rs", ChunkKind::Function, vec![], None);
        let graph = generate_graph_export(&[caller, callee]);

        // Two chunk vertices plus at least the file vertex.
        assert!(graph.vertices.len() >= 3);

        let calls = edges_labelled(&graph, "CALLS");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].from, "ec_aaa");
        assert_eq!(calls[0].to, "ec_bbb");

        assert_eq!(edges_labelled(&graph, "DEFINED_IN").len(), 2);
    }

    #[test]
    fn test_contains_edge() {
        let class = make_chunk("ec_cls", "MyClass", "src/model.rs", ChunkKind::Class, vec![], None);
        let method = make_chunk(
            "ec_mth",
            "my_method",
            "src/model.rs",
            ChunkKind::Method,
            vec![],
            Some("MyClass".into()),
        );
        let graph = generate_graph_export(&[class, method]);

        let contains = edges_labelled(&graph, "CONTAINS");
        assert_eq!(contains.len(), 1);
        assert_eq!(contains[0].from, "ec_cls");
        assert_eq!(contains[0].to, "ec_mth");
    }

    #[test]
    fn test_unresolved_calls_skipped() {
        let lonely = make_chunk(
            "ec_aaa",
            "foo",
            "src/lib.rs",
            ChunkKind::Function,
            vec!["nonexistent".into()],
            None,
        );
        let graph = generate_graph_export(&[lonely]);

        assert_eq!(edges_labelled(&graph, "CALLS").len(), 0);
    }

    #[test]
    fn test_edge_id_deterministic() {
        let first = edge_id("a", "b", "CALLS");
        assert_eq!(first, edge_id("a", "b", "CALLS"));
        assert_ne!(first, edge_id("a", "b", "DEFINED_IN"));
    }

    #[test]
    fn test_derive_module_path() {
        let cases: [(&str, Option<&str>); 4] = [
            ("src/auth/mod.rs", Some("src/auth")),
            ("src/auth/token.rs", Some("src/auth")),
            ("src/lib.rs", Some("src")),
            ("main.rs", None),
        ];
        for &(path, expected) in &cases {
            assert_eq!(derive_module_path(path).as_deref(), expected, "path: {path}");
        }
    }

    #[test]
    fn test_output_sorted_deterministically() {
        let zeta = make_chunk("ec_zzz", "zeta", "src/z.rs", ChunkKind::Function, vec![], None);
        let alpha = make_chunk("ec_aaa", "alpha", "src/a.rs", ChunkKind::Function, vec![], None);
        let graph = generate_graph_export(&[zeta, alpha]);

        let ids: Vec<&str> = graph.vertices.iter().map(|v| v.id.as_str()).collect();
        assert!(
            ids.windows(2).all(|pair| pair[0] <= pair[1]),
            "vertex ids not sorted: {ids:?}"
        );
    }
}