1use std::collections::{BTreeSet, HashMap};
15
16use tree_sitter::{Node as TsNode, Tree};
17
18use crate::ast::{api as ast_api, Language, Symbol, SymbolKind};
19
20use super::file_table::FileId;
21
/// Identifier of a node in a [`SymbolGraph`]; handed out sequentially by the graph.
pub type NodeId = u32;

/// Category of a [`Node`] in the symbol graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeKind {
    /// Function or method definition.
    Function,
    /// Class, struct, enum, interface, protocol or other type definition.
    Type,
    /// A file's own module node, or a module symbol declared inside a file.
    Module,
    /// One import string recorded for a file.
    Import,
    /// A call expression found in a file.
    CallSite,
    /// Macro node.
    Macro,
}

impl NodeKind {
    /// Every variant, in declaration order; used to resolve labels back to kinds.
    const ALL: [NodeKind; 6] = [
        NodeKind::Function,
        NodeKind::Type,
        NodeKind::Module,
        NodeKind::Import,
        NodeKind::CallSite,
        NodeKind::Macro,
    ];

    /// Returns the label of this kind, spelled exactly like the variant name.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Function => "Function",
            Self::Type => "Type",
            Self::Module => "Module",
            Self::Import => "Import",
            Self::CallSite => "CallSite",
            Self::Macro => "Macro",
        }
    }

    /// Inverse of [`NodeKind::as_str`]: resolves a (case-sensitive) label back to
    /// its kind, or returns `None` for an unknown label.
    pub fn parse(label: &str) -> Option<Self> {
        Self::ALL.iter().copied().find(|kind| kind.as_str() == label)
    }
}
70
/// Relation carried by an [`Edge`].
///
/// Each kind has a forward label (e.g. `CALLS`) and an inverse `*_BY` label
/// (e.g. `CALLED_BY`) that names the same relation traversed backwards.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeKind {
    /// Call site → function it may call.
    Calls,
    /// File module → same-named node in another file.
    Refs,
    /// Module → import node, or module → imported module.
    Imports,
    /// Container → contained node.
    Contains,
    /// Override relation (not emitted by `SymbolGraph::rebuild_file`).
    Overrides,
}

impl EdgeKind {
    /// Every variant, in declaration order; used to resolve labels back to kinds.
    const ALL: [EdgeKind; 5] = [
        EdgeKind::Calls,
        EdgeKind::Refs,
        EdgeKind::Imports,
        EdgeKind::Contains,
        EdgeKind::Overrides,
    ];

    /// Returns the forward (upper-case) label of this kind.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Calls => "CALLS",
            Self::Refs => "REFS",
            Self::Imports => "IMPORTS",
            Self::Contains => "CONTAINS",
            Self::Overrides => "OVERRIDES",
        }
    }

    /// Returns the label naming this relation when traversed in reverse.
    fn inverse_label(self) -> &'static str {
        match self {
            Self::Calls => "CALLED_BY",
            Self::Refs => "REFERENCED_BY",
            Self::Imports => "IMPORTED_BY",
            Self::Contains => "CONTAINED_BY",
            Self::Overrides => "OVERRIDDEN_BY",
        }
    }

    /// Parses either a forward or an inverse label (case-sensitive).
    ///
    /// Returns the kind together with `false` for a forward label such as
    /// `CALLS`, `true` for an inverse label such as `CALLED_BY`, and `None`
    /// for anything else.
    pub fn parse_with_direction(label: &str) -> Option<(Self, bool)> {
        forward_match(label)
            .map(|kind| (kind, false))
            .or_else(|| {
                Self::ALL
                    .iter()
                    .copied()
                    .find(|kind| kind.inverse_label() == label)
                    .map(|kind| (kind, true))
            })
    }
}

/// Resolves a forward label (`CALLS`, `REFS`, …) to its [`EdgeKind`].
fn forward_match(label: &str) -> Option<EdgeKind> {
    EdgeKind::ALL
        .iter()
        .copied()
        .find(|kind| kind.as_str() == label)
}
126
/// A vertex in the [`SymbolGraph`].
#[derive(Debug, Clone)]
pub struct Node {
    /// Graph-assigned identifier; builders pass a placeholder (0) and
    /// `SymbolGraph::add_node` overwrites it with the next free id.
    pub id: NodeId,
    /// Category of this node.
    pub kind: NodeKind,
    /// Lookup name: the symbol name, the callee name for call sites, the raw
    /// import string for imports, or the file stem for a file's module node.
    pub name: String,
    /// File that owns this node; removing the file removes the node.
    pub file_id: FileId,
    /// Path of the owning file, as passed to `SymbolGraph::rebuild_file`.
    pub path: String,
    /// 1-based line of the node (computed as row + 1); module and import
    /// nodes use line 1.
    pub line: u32,
    /// Display text such as `module {path}`, `import {raw}`, `{callee}(…)`, or
    /// the parser-reported signature for symbols.
    pub signature: String,
    /// Name of the enclosing container reported by the parser, if any.
    pub container: Option<String>,
    /// Name of the source language (`Language::name`).
    pub language: String,
}
151
152#[derive(Debug, Clone, Copy)]
154pub struct Edge {
155 pub from: NodeId,
157 pub to: NodeId,
159 pub kind: EdgeKind,
161}
162
/// Result of `SymbolGraph::rebuild_file`.
#[derive(Debug, Clone, Default)]
pub struct RebuildOutcome {
    /// Number of graph nodes attributed to the rebuilt file, module node included.
    pub node_count: usize,
    /// Symbols returned by the parser for the file; empty when parsing failed.
    pub symbols: Vec<Symbol>,
}
175
176#[derive(Debug, Default, Clone)]
178pub struct SymbolGraph {
179 nodes: HashMap<NodeId, Node>,
180 by_file: HashMap<FileId, Vec<NodeId>>,
181 by_name: HashMap<String, Vec<NodeId>>,
182 out_edges: HashMap<NodeId, Vec<Edge>>,
183 in_edges: HashMap<NodeId, Vec<Edge>>,
184 next_id: NodeId,
185}
186
187impl SymbolGraph {
188 pub fn new() -> Self {
190 Self {
191 next_id: 1,
192 ..Self::default()
193 }
194 }
195
196 pub fn node_count(&self) -> usize {
198 self.nodes.len()
199 }
200
201 pub fn edge_count(&self) -> usize {
203 self.out_edges.values().map(Vec::len).sum()
204 }
205
206 pub fn node(&self, id: NodeId) -> Option<&Node> {
208 self.nodes.get(&id)
209 }
210
211 pub fn iter_nodes(&self) -> impl Iterator<Item = &Node> {
213 self.nodes.values()
214 }
215
216 pub fn nodes_of_kind(&self, kind: NodeKind) -> Vec<NodeId> {
219 let mut out: Vec<NodeId> = self
220 .nodes
221 .values()
222 .filter(|n| n.kind == kind)
223 .map(|n| n.id)
224 .collect();
225 out.sort_unstable();
226 out
227 }
228
229 pub fn all_node_ids(&self) -> Vec<NodeId> {
232 let mut out: Vec<NodeId> = self.nodes.keys().copied().collect();
233 out.sort_unstable();
234 out
235 }
236
237 pub fn nodes_named(&self, name: &str) -> &[NodeId] {
239 match self.by_name.get(name) {
240 Some(v) => v.as_slice(),
241 None => &[],
242 }
243 }
244
245 pub fn outgoing(&self, id: NodeId) -> &[Edge] {
247 self.out_edges.get(&id).map(Vec::as_slice).unwrap_or(&[])
248 }
249
250 pub fn incoming(&self, id: NodeId) -> &[Edge] {
252 self.in_edges.get(&id).map(Vec::as_slice).unwrap_or(&[])
253 }
254
255 pub fn file_ids(&self) -> Vec<FileId> {
257 let mut out: Vec<FileId> = self.by_file.keys().copied().collect();
258 out.sort_unstable();
259 out
260 }
261
262 pub fn remove_file(&mut self, file_id: FileId) {
264 let Some(node_ids) = self.by_file.remove(&file_id) else {
265 return;
266 };
267 for id in node_ids {
268 self.drop_node(id);
269 }
270 }
271
272 fn drop_node(&mut self, id: NodeId) {
273 let Some(node) = self.nodes.remove(&id) else {
274 return;
275 };
276 if let Some(bucket) = self.by_name.get_mut(&node.name) {
277 bucket.retain(|n| *n != id);
278 if bucket.is_empty() {
279 self.by_name.remove(&node.name);
280 }
281 }
282 if let Some(outs) = self.out_edges.remove(&id) {
283 for e in outs {
284 if let Some(bucket) = self.in_edges.get_mut(&e.to) {
285 bucket.retain(|edge| edge.from != id);
286 }
287 }
288 }
289 if let Some(ins) = self.in_edges.remove(&id) {
290 for e in ins {
291 if let Some(bucket) = self.out_edges.get_mut(&e.from) {
292 bucket.retain(|edge| edge.to != id);
293 }
294 }
295 }
296 }
297
298 pub fn rebuild_file(
305 &mut self,
306 file_id: FileId,
307 path: &str,
308 language: Language,
309 source: &str,
310 import_strings: &[String],
311 ) -> RebuildOutcome {
312 self.remove_file(file_id);
313 let module_id = self.add_module_for_file(file_id, path, &language);
314
315 let (tree, symbols) = match ast_api::parse_with_symbols(source, language) {
320 Ok((t, s)) => (Some(t), s),
321 Err(err) => {
322 tracing::debug!(
323 "code_index: tree-sitter parse failed for `{path}`: {err}; \
324 symbol graph slice will be Module-only"
325 );
326 (None, Vec::new())
327 }
328 };
329
330 let mut container_ids: HashMap<String, NodeId> = HashMap::new();
334 for sym in &symbols {
335 let Some(kind) = map_symbol_kind(sym.kind) else {
336 continue;
337 };
338 let id = self.add_node(Node {
339 id: 0,
340 kind,
341 name: sym.name.clone(),
342 file_id,
343 path: path.to_string(),
344 line: sym.start_row.saturating_add(1),
345 signature: sym.signature.clone(),
346 container: sym.container.clone(),
347 language: language.name().to_string(),
348 });
349 if matches!(kind, NodeKind::Type | NodeKind::Module) {
350 container_ids.insert(sym.name.clone(), id);
351 }
352 let parent_id = sym
353 .container
354 .as_deref()
355 .and_then(|c| container_ids.get(c).copied())
356 .unwrap_or(module_id);
357 self.add_edge(parent_id, id, EdgeKind::Contains);
358 }
359
360 if let Some(tree) = tree.as_ref() {
364 for (callee_name, line) in extract_call_sites_from_tree(tree, source) {
365 let call_id = self.add_node(Node {
366 id: 0,
367 kind: NodeKind::CallSite,
368 name: callee_name.clone(),
369 file_id,
370 path: path.to_string(),
371 line,
372 signature: format!("{callee_name}(…)"),
373 container: None,
374 language: language.name().to_string(),
375 });
376 self.add_edge(module_id, call_id, EdgeKind::Contains);
377 let targets: Vec<NodeId> = self
378 .nodes_named(&callee_name)
379 .iter()
380 .copied()
381 .filter(|nid| {
382 self.nodes
383 .get(nid)
384 .is_some_and(|n| matches!(n.kind, NodeKind::Function))
385 })
386 .collect();
387 for t in targets {
388 self.add_edge(call_id, t, EdgeKind::Calls);
389 }
390 }
391 }
392
393 for raw in import_strings {
398 let imp_id = self.add_node(Node {
399 id: 0,
400 kind: NodeKind::Import,
401 name: raw.clone(),
402 file_id,
403 path: path.to_string(),
404 line: 1,
405 signature: format!("import {raw}"),
406 container: None,
407 language: language.name().to_string(),
408 });
409 self.add_edge(module_id, imp_id, EdgeKind::Imports);
410 }
411
412 for target in self.collect_cross_file_refs(source, file_id) {
416 self.add_edge(module_id, target, EdgeKind::Refs);
417 }
418
419 let node_count = self.by_file.get(&file_id).map(Vec::len).unwrap_or_default();
420 RebuildOutcome {
421 node_count,
422 symbols,
423 }
424 }
425
426 pub fn link_imports(&mut self, resolved: &HashMap<FileId, Vec<FileId>>) {
431 for (src_file, targets) in resolved {
432 let Some(src_module) = self.module_node_for_file(*src_file) else {
433 continue;
434 };
435 for tgt_file in targets {
436 let Some(tgt_module) = self.module_node_for_file(*tgt_file) else {
437 continue;
438 };
439 let already_linked = self.out_edges.get(&src_module).is_some_and(|edges| {
446 edges
447 .iter()
448 .any(|e| e.to == tgt_module && e.kind == EdgeKind::Imports)
449 });
450 if !already_linked {
451 self.add_edge(src_module, tgt_module, EdgeKind::Imports);
452 }
453 }
454 }
455 }
456
457 pub fn module_node_for_file(&self, file_id: FileId) -> Option<NodeId> {
459 let ids = self.by_file.get(&file_id)?;
460 ids.iter().copied().find(|id| {
461 self.nodes
462 .get(id)
463 .is_some_and(|n| matches!(n.kind, NodeKind::Module))
464 })
465 }
466
467 fn collect_cross_file_refs(&self, source: &str, this_file: FileId) -> BTreeSet<NodeId> {
471 let mut out: BTreeSet<NodeId> = BTreeSet::new();
472 if self.by_name.is_empty() {
473 return out;
474 }
475 let mut word = String::with_capacity(32);
476 for ch in source.chars() {
477 if ch.is_alphanumeric() || ch == '_' {
478 word.push(ch);
479 } else if !word.is_empty() {
480 self.absorb_word_refs(&word, this_file, &mut out);
481 word.clear();
482 }
483 }
484 if !word.is_empty() {
485 self.absorb_word_refs(&word, this_file, &mut out);
486 }
487 out
488 }
489
490 fn absorb_word_refs(&self, word: &str, this_file: FileId, bag: &mut BTreeSet<NodeId>) {
491 if word.len() < 3 {
492 return;
493 }
494 let Some(ids) = self.by_name.get(word) else {
495 return;
496 };
497 for nid in ids {
498 let same_file = self.nodes.get(nid).is_some_and(|n| n.file_id == this_file);
499 if !same_file {
500 bag.insert(*nid);
501 }
502 }
503 }
504
505 fn add_module_for_file(&mut self, file_id: FileId, path: &str, language: &Language) -> NodeId {
506 let name = module_name_from_path(path);
507 self.add_node(Node {
508 id: 0,
509 kind: NodeKind::Module,
510 name,
511 file_id,
512 path: path.to_string(),
513 line: 1,
514 signature: format!("module {path}"),
515 container: None,
516 language: language.name().to_string(),
517 })
518 }
519
520 fn add_node(&mut self, mut node: Node) -> NodeId {
521 let id = self.next_id;
522 self.next_id = self.next_id.checked_add(1).expect("NodeId overflow");
523 node.id = id;
524 self.by_file.entry(node.file_id).or_default().push(id);
525 self.by_name.entry(node.name.clone()).or_default().push(id);
526 self.nodes.insert(id, node);
527 id
528 }
529
530 fn add_edge(&mut self, from: NodeId, to: NodeId, kind: EdgeKind) {
531 let edge = Edge { from, to, kind };
532 self.out_edges.entry(from).or_default().push(edge);
533 self.in_edges.entry(to).or_default().push(edge);
534 }
535}
536
/// Derives a module name from a file path.
///
/// Takes the final `/`-separated component and strips its last extension
/// (`src/a.rs` → `a`, `x/y.tar.gz` → `y.tar`). A leading dot is not treated
/// as an extension separator, so dotfiles keep their full name
/// (`cfg/.eslintrc` → `.eslintrc`) instead of collapsing to an empty string.
pub fn module_name_from_path(path: &str) -> String {
    let file_name = path.rsplit_once('/').map_or(path, |(_, name)| name);
    let stem = match file_name.rsplit_once('.') {
        Some((name, _ext)) if !name.is_empty() => name,
        _ => file_name,
    };
    stem.to_string()
}
544
545fn map_symbol_kind(kind: SymbolKind) -> Option<NodeKind> {
546 match kind {
547 SymbolKind::Function | SymbolKind::Method => Some(NodeKind::Function),
548 SymbolKind::Class
549 | SymbolKind::Struct
550 | SymbolKind::Enum
551 | SymbolKind::Interface
552 | SymbolKind::Protocol
553 | SymbolKind::Type => Some(NodeKind::Type),
554 SymbolKind::Module => Some(NodeKind::Module),
555 SymbolKind::Variable | SymbolKind::Other => None,
556 }
557}
558
559fn extract_call_sites_from_tree(tree: &Tree, source: &str) -> Vec<(String, u32)> {
564 let mut out: Vec<(String, u32)> = Vec::new();
565 let mut cursor = tree.root_node().walk();
566 let mut stack: Vec<TsNode<'_>> = vec![tree.root_node()];
567 while let Some(node) = stack.pop() {
568 if is_call_kind(node.kind()) {
569 if let Some(name) = call_callee_name(node, source) {
570 let line = node.start_position().row as u32 + 1;
571 out.push((name, line));
572 }
573 }
574 for child in node.children(&mut cursor) {
575 stack.push(child);
576 }
577 }
578 out
579}
580
/// Returns whether a tree-sitter node kind name denotes a call-like node.
///
/// The list spans grammar-specific spellings (e.g. `call_expression`, `call`,
/// `method_invocation`, `macro_invocation`); matching is exact.
fn is_call_kind(kind: &str) -> bool {
    const CALL_KINDS: &[&str] = &[
        "call_expression",
        "call",
        "function_call",
        "method_invocation",
        "method_call_expression",
        "invocation_expression",
        "function_call_expression",
        "macro_invocation",
    ];
    CALL_KINDS.contains(&kind)
}
594
595fn call_callee_name(node: TsNode<'_>, source: &str) -> Option<String> {
596 let callee = node
597 .child_by_field_name("function")
598 .or_else(|| node.child_by_field_name("name"))
599 .or_else(|| node.child_by_field_name("method"))
600 .or_else(|| node.child(0u32))?;
601 let text = &source[callee.start_byte()..callee.end_byte()];
602 let last = text.rsplit_once(['.', ':', '!']);
603 let raw = last.map(|(_, name)| name).unwrap_or(text);
604 let trimmed = raw.trim();
605 let plain: String = trimmed
606 .chars()
607 .take_while(|c| c.is_alphanumeric() || *c == '_')
608 .collect();
609 if plain.is_empty() {
610 None
611 } else {
612 Some(plain)
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of nodes in `graph` with exactly this `kind` and `name`.
    fn count_nodes(graph: &SymbolGraph, kind: NodeKind, name: &str) -> usize {
        graph
            .iter_nodes()
            .filter(|node| node.kind == kind && node.name == name)
            .count()
    }

    /// Two TypeScript files: `src/a.ts` (file 1) imports `./b`, and
    /// `src/b.ts` (file 2) exports `x`.
    fn importing_pair() -> SymbolGraph {
        let mut graph = SymbolGraph::new();
        graph.rebuild_file(
            1,
            "src/a.ts",
            Language::TypeScript,
            "import { x } from \"./b\";\n",
            &["./b".into()],
        );
        graph.rebuild_file(
            2,
            "src/b.ts",
            Language::TypeScript,
            "export const x = 1;\n",
            &[],
        );
        graph
    }

    /// Resolution table in which file 1's imports resolve to file 2.
    fn a_resolves_to_b() -> HashMap<FileId, Vec<FileId>> {
        let mut table = HashMap::new();
        table.insert(1, vec![2]);
        table
    }

    /// Number of `IMPORTS` edges from file 1's module node to file 2's.
    fn module_import_edges(graph: &SymbolGraph) -> usize {
        let from = graph.module_node_for_file(1).unwrap();
        let to = graph.module_node_for_file(2).unwrap();
        graph
            .outgoing(from)
            .iter()
            .filter(|edge| edge.kind == EdgeKind::Imports && edge.to == to)
            .count()
    }

    #[test]
    fn add_and_remove_round_trip() {
        let mut graph = SymbolGraph::new();
        let outcome = graph.rebuild_file(1, "src/a.rs", Language::Rust, "fn foo() {}\n", &[]);
        assert!(
            outcome.node_count >= 2,
            "module + function expected, got {}",
            outcome.node_count
        );
        assert!(
            outcome.symbols.iter().any(|s| s.name == "foo"),
            "rebuild_file should surface the parsed `foo` symbol"
        );
        assert!(!graph.nodes_named("foo").is_empty());

        graph.remove_file(1);
        assert_eq!(graph.node_count(), 0);
        assert!(graph.nodes_named("foo").is_empty());
    }

    #[test]
    fn rebuild_file_emits_function_module_and_call_nodes() {
        let mut graph = SymbolGraph::new();
        let src = "fn alpha() {}\nfn beta() { alpha(); }\n";
        let outcome = graph.rebuild_file(7, "src/x.rs", Language::Rust, src, &[]);
        assert!(
            outcome.node_count >= 3,
            "expected module + 2 functions, got {}",
            outcome.node_count
        );
        assert_eq!(count_nodes(&graph, NodeKind::Function, "alpha"), 1);
        assert_eq!(count_nodes(&graph, NodeKind::Function, "beta"), 1);
        assert!(
            count_nodes(&graph, NodeKind::CallSite, "alpha") > 0,
            "expected a CallSite for alpha()"
        );
    }

    #[test]
    fn called_by_inverse_label_resolves() {
        let inverse = EdgeKind::parse_with_direction("CALLED_BY").unwrap();
        assert_eq!(inverse.0, EdgeKind::Calls);
        assert!(inverse.1);
        let forward = EdgeKind::parse_with_direction("CALLS").unwrap();
        assert_eq!(forward.0, EdgeKind::Calls);
        assert!(!forward.1);
    }

    #[test]
    fn link_imports_creates_module_to_module_edges() {
        let mut graph = importing_pair();
        graph.link_imports(&a_resolves_to_b());
        assert!(
            module_import_edges(&graph) > 0,
            "expected Module→Module IMPORTS edge"
        );
    }

    #[test]
    fn link_imports_is_idempotent_across_repeated_relinks() {
        let mut graph = importing_pair();
        let resolved = a_resolves_to_b();
        // Relinking with the same resolution must not stack duplicate edges.
        for _ in 0..3 {
            graph.link_imports(&resolved);
        }
        assert_eq!(
            module_import_edges(&graph),
            1,
            "Module→Module IMPORTS edge must not duplicate across relinks"
        );
    }
}