1use std::collections::{BTreeSet, HashMap};
15
16use tree_sitter::{Node as TsNode, Tree};
17
18use crate::ast::{api as ast_api, Language, Symbol, SymbolKind};
19
20use super::file_table::FileId;
21
/// Identifier of a node in the symbol graph. Ids are allocated by the graph
/// itself (`SymbolGraph::new` hands them out sequentially starting at 1).
pub type NodeId = u32;
26
/// Category of a vertex in the symbol graph.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NodeKind {
    /// A function or method declaration.
    Function,
    /// A type-like declaration: class, struct, enum, interface, protocol or
    /// type declaration.
    Type,
    /// A field / property declared inside a type.
    Field,
    /// A single case (variant) of an enum.
    EnumCase,
    /// A module: either a whole source file or a module-like symbol in it.
    Module,
    /// One raw import string of a file.
    Import,
    /// One occurrence of a call expression.
    CallSite,
    /// A macro; not emitted by the graph builders visible in this module.
    Macro,
}
47
48impl NodeKind {
49 pub fn as_str(self) -> &'static str {
51 match self {
52 NodeKind::Function => "Function",
53 NodeKind::Type => "Type",
54 NodeKind::Field => "Field",
55 NodeKind::EnumCase => "EnumCase",
56 NodeKind::Module => "Module",
57 NodeKind::Import => "Import",
58 NodeKind::CallSite => "CallSite",
59 NodeKind::Macro => "Macro",
60 }
61 }
62
63 pub fn parse(label: &str) -> Option<Self> {
65 match label {
66 "Function" => Some(NodeKind::Function),
67 "Type" => Some(NodeKind::Type),
68 "Field" => Some(NodeKind::Field),
69 "EnumCase" => Some(NodeKind::EnumCase),
70 "Module" => Some(NodeKind::Module),
71 "Import" => Some(NodeKind::Import),
72 "CallSite" => Some(NodeKind::CallSite),
73 "Macro" => Some(NodeKind::Macro),
74 _ => None,
75 }
76 }
77}
78
/// Relationship type of a directed edge in the symbol graph.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum EdgeKind {
    /// From a `CallSite` node to a `Function` node it may invoke.
    Calls,
    /// From a file's module node to a node in another file whose name
    /// appears in that file's source text.
    Refs,
    /// From a file's module node to one of its `Import` nodes, or to the
    /// module node of a file it imports.
    Imports,
    /// From a container (file module, type or module symbol) to a member.
    Contains,
    /// An override relationship; not emitted by the graph builders visible
    /// in this module.
    Overrides,
}
93
94impl EdgeKind {
95 pub fn as_str(self) -> &'static str {
97 match self {
98 EdgeKind::Calls => "CALLS",
99 EdgeKind::Refs => "REFS",
100 EdgeKind::Imports => "IMPORTS",
101 EdgeKind::Contains => "CONTAINS",
102 EdgeKind::Overrides => "OVERRIDES",
103 }
104 }
105
106 pub fn parse_with_direction(label: &str) -> Option<(Self, bool)> {
110 if let Some(kind) = forward_match(label) {
111 return Some((kind, false));
112 }
113 match label {
114 "CALLED_BY" => Some((EdgeKind::Calls, true)),
115 "REFERENCED_BY" => Some((EdgeKind::Refs, true)),
116 "IMPORTED_BY" => Some((EdgeKind::Imports, true)),
117 "CONTAINED_BY" => Some((EdgeKind::Contains, true)),
118 "OVERRIDDEN_BY" => Some((EdgeKind::Overrides, true)),
119 _ => None,
120 }
121 }
122}
123
124fn forward_match(label: &str) -> Option<EdgeKind> {
125 match label {
126 "CALLS" => Some(EdgeKind::Calls),
127 "REFS" => Some(EdgeKind::Refs),
128 "IMPORTS" => Some(EdgeKind::Imports),
129 "CONTAINS" => Some(EdgeKind::Contains),
130 "OVERRIDES" => Some(EdgeKind::Overrides),
131 _ => None,
132 }
133}
134
135#[derive(Debug, Clone)]
138pub struct Node {
139 pub id: NodeId,
141 pub kind: NodeKind,
143 pub name: String,
146 pub file_id: FileId,
148 pub path: String,
150 pub line: u32,
152 pub signature: String,
154 pub container: Option<String>,
156 pub access_level: Option<String>,
158 pub language: String,
160}
161
162#[derive(Debug, Clone, Copy)]
164pub struct Edge {
165 pub from: NodeId,
167 pub to: NodeId,
169 pub kind: EdgeKind,
171}
172
173#[derive(Debug, Clone, Default)]
177pub struct RebuildOutcome {
178 pub node_count: usize,
181 pub symbols: Vec<Symbol>,
184}
185
186#[derive(Debug, Default, Clone)]
188pub struct SymbolGraph {
189 nodes: HashMap<NodeId, Node>,
190 by_file: HashMap<FileId, Vec<NodeId>>,
191 by_name: HashMap<String, Vec<NodeId>>,
192 out_edges: HashMap<NodeId, Vec<Edge>>,
193 in_edges: HashMap<NodeId, Vec<Edge>>,
194 next_id: NodeId,
195}
196
197impl SymbolGraph {
198 pub fn new() -> Self {
200 Self {
201 next_id: 1,
202 ..Self::default()
203 }
204 }
205
206 pub fn node_count(&self) -> usize {
208 self.nodes.len()
209 }
210
211 pub fn edge_count(&self) -> usize {
213 self.out_edges.values().map(Vec::len).sum()
214 }
215
216 pub fn node(&self, id: NodeId) -> Option<&Node> {
218 self.nodes.get(&id)
219 }
220
221 pub fn iter_nodes(&self) -> impl Iterator<Item = &Node> {
223 self.nodes.values()
224 }
225
226 pub fn nodes_of_kind(&self, kind: NodeKind) -> Vec<NodeId> {
229 let mut out: Vec<NodeId> = self
230 .nodes
231 .values()
232 .filter(|n| n.kind == kind)
233 .map(|n| n.id)
234 .collect();
235 out.sort_unstable();
236 out
237 }
238
239 pub fn all_node_ids(&self) -> Vec<NodeId> {
242 let mut out: Vec<NodeId> = self.nodes.keys().copied().collect();
243 out.sort_unstable();
244 out
245 }
246
247 pub fn nodes_named(&self, name: &str) -> &[NodeId] {
249 match self.by_name.get(name) {
250 Some(v) => v.as_slice(),
251 None => &[],
252 }
253 }
254
255 pub fn outgoing(&self, id: NodeId) -> &[Edge] {
257 self.out_edges.get(&id).map(Vec::as_slice).unwrap_or(&[])
258 }
259
260 pub fn incoming(&self, id: NodeId) -> &[Edge] {
262 self.in_edges.get(&id).map(Vec::as_slice).unwrap_or(&[])
263 }
264
265 pub fn file_ids(&self) -> Vec<FileId> {
267 let mut out: Vec<FileId> = self.by_file.keys().copied().collect();
268 out.sort_unstable();
269 out
270 }
271
272 pub fn remove_file(&mut self, file_id: FileId) {
274 let Some(node_ids) = self.by_file.remove(&file_id) else {
275 return;
276 };
277 for id in node_ids {
278 self.drop_node(id);
279 }
280 }
281
282 fn drop_node(&mut self, id: NodeId) {
283 let Some(node) = self.nodes.remove(&id) else {
284 return;
285 };
286 if let Some(bucket) = self.by_name.get_mut(&node.name) {
287 bucket.retain(|n| *n != id);
288 if bucket.is_empty() {
289 self.by_name.remove(&node.name);
290 }
291 }
292 if let Some(outs) = self.out_edges.remove(&id) {
293 for e in outs {
294 if let Some(bucket) = self.in_edges.get_mut(&e.to) {
295 bucket.retain(|edge| edge.from != id);
296 }
297 }
298 }
299 if let Some(ins) = self.in_edges.remove(&id) {
300 for e in ins {
301 if let Some(bucket) = self.out_edges.get_mut(&e.from) {
302 bucket.retain(|edge| edge.to != id);
303 }
304 }
305 }
306 }
307
308 pub fn rebuild_file(
315 &mut self,
316 file_id: FileId,
317 path: &str,
318 language: Language,
319 source: &str,
320 import_strings: &[String],
321 ) -> RebuildOutcome {
322 self.remove_file(file_id);
323 let module_id = self.add_module_for_file(file_id, path, &language);
324
325 let (tree, symbols) = match ast_api::parse_with_symbols(source, language) {
330 Ok((t, s)) => (Some(t), s),
331 Err(err) => {
332 tracing::debug!(
333 "code_index: tree-sitter parse failed for `{path}`: {err}; \
334 symbol graph slice will be Module-only"
335 );
336 (None, Vec::new())
337 }
338 };
339
340 let mut container_ids: HashMap<String, NodeId> = HashMap::new();
344 for sym in &symbols {
345 let Some(kind) = map_symbol_kind(sym.kind) else {
346 continue;
347 };
348 let id = self.add_node(Node {
349 id: 0,
350 kind,
351 name: sym.name.clone(),
352 file_id,
353 path: path.to_string(),
354 line: sym.start_row.saturating_add(1),
355 signature: sym.signature.clone(),
356 container: sym.container.clone(),
357 access_level: sym.access_level.clone(),
358 language: language.name().to_string(),
359 });
360 if matches!(kind, NodeKind::Type | NodeKind::Module) {
361 container_ids.insert(sym.name.clone(), id);
362 }
363 let parent_id = sym
364 .container
365 .as_deref()
366 .and_then(|c| container_ids.get(c).copied())
367 .unwrap_or(module_id);
368 self.add_edge(parent_id, id, EdgeKind::Contains);
369 }
370
371 if let Some(tree) = tree.as_ref() {
375 for (callee_name, line) in extract_call_sites_from_tree(tree, source) {
376 let call_id = self.add_node(Node {
377 id: 0,
378 kind: NodeKind::CallSite,
379 name: callee_name.clone(),
380 file_id,
381 path: path.to_string(),
382 line,
383 signature: format!("{callee_name}(…)"),
384 container: None,
385 access_level: None,
386 language: language.name().to_string(),
387 });
388 self.add_edge(module_id, call_id, EdgeKind::Contains);
389 let targets: Vec<NodeId> = self
390 .nodes_named(&callee_name)
391 .iter()
392 .copied()
393 .filter(|nid| {
394 self.nodes
395 .get(nid)
396 .is_some_and(|n| matches!(n.kind, NodeKind::Function))
397 })
398 .collect();
399 for t in targets {
400 self.add_edge(call_id, t, EdgeKind::Calls);
401 }
402 }
403 }
404
405 for raw in import_strings {
410 let imp_id = self.add_node(Node {
411 id: 0,
412 kind: NodeKind::Import,
413 name: raw.clone(),
414 file_id,
415 path: path.to_string(),
416 line: 1,
417 signature: format!("import {raw}"),
418 container: None,
419 access_level: None,
420 language: language.name().to_string(),
421 });
422 self.add_edge(module_id, imp_id, EdgeKind::Imports);
423 }
424
425 for target in self.collect_cross_file_refs(source, file_id) {
429 self.add_edge(module_id, target, EdgeKind::Refs);
430 }
431
432 let node_count = self.by_file.get(&file_id).map(Vec::len).unwrap_or_default();
433 RebuildOutcome {
434 node_count,
435 symbols,
436 }
437 }
438
439 pub fn link_imports(&mut self, resolved: &HashMap<FileId, Vec<FileId>>) {
444 for (src_file, targets) in resolved {
445 let Some(src_module) = self.module_node_for_file(*src_file) else {
446 continue;
447 };
448 for tgt_file in targets {
449 let Some(tgt_module) = self.module_node_for_file(*tgt_file) else {
450 continue;
451 };
452 let already_linked = self.out_edges.get(&src_module).is_some_and(|edges| {
459 edges
460 .iter()
461 .any(|e| e.to == tgt_module && e.kind == EdgeKind::Imports)
462 });
463 if !already_linked {
464 self.add_edge(src_module, tgt_module, EdgeKind::Imports);
465 }
466 }
467 }
468 }
469
470 pub fn module_node_for_file(&self, file_id: FileId) -> Option<NodeId> {
472 let ids = self.by_file.get(&file_id)?;
473 ids.iter().copied().find(|id| {
474 self.nodes
475 .get(id)
476 .is_some_and(|n| matches!(n.kind, NodeKind::Module))
477 })
478 }
479
480 fn collect_cross_file_refs(&self, source: &str, this_file: FileId) -> BTreeSet<NodeId> {
484 let mut out: BTreeSet<NodeId> = BTreeSet::new();
485 if self.by_name.is_empty() {
486 return out;
487 }
488 let mut word = String::with_capacity(32);
489 for ch in source.chars() {
490 if ch.is_alphanumeric() || ch == '_' {
491 word.push(ch);
492 } else if !word.is_empty() {
493 self.absorb_word_refs(&word, this_file, &mut out);
494 word.clear();
495 }
496 }
497 if !word.is_empty() {
498 self.absorb_word_refs(&word, this_file, &mut out);
499 }
500 out
501 }
502
503 fn absorb_word_refs(&self, word: &str, this_file: FileId, bag: &mut BTreeSet<NodeId>) {
504 if word.len() < 3 {
505 return;
506 }
507 let Some(ids) = self.by_name.get(word) else {
508 return;
509 };
510 for nid in ids {
511 let same_file = self.nodes.get(nid).is_some_and(|n| n.file_id == this_file);
512 if !same_file {
513 bag.insert(*nid);
514 }
515 }
516 }
517
518 fn add_module_for_file(&mut self, file_id: FileId, path: &str, language: &Language) -> NodeId {
519 let name = module_name_from_path(path);
520 self.add_node(Node {
521 id: 0,
522 kind: NodeKind::Module,
523 name,
524 file_id,
525 path: path.to_string(),
526 line: 1,
527 signature: format!("module {path}"),
528 container: None,
529 access_level: None,
530 language: language.name().to_string(),
531 })
532 }
533
534 fn add_node(&mut self, mut node: Node) -> NodeId {
535 let id = self.next_id;
536 self.next_id = self.next_id.checked_add(1).expect("NodeId overflow");
537 node.id = id;
538 self.by_file.entry(node.file_id).or_default().push(id);
539 self.by_name.entry(node.name.clone()).or_default().push(id);
540 self.nodes.insert(id, node);
541 id
542 }
543
544 fn add_edge(&mut self, from: NodeId, to: NodeId, kind: EdgeKind) {
545 let edge = Edge { from, to, kind };
546 self.out_edges.entry(from).or_default().push(edge);
547 self.in_edges.entry(to).or_default().push(edge);
548 }
549}
550
/// Derives a module name from a file path: the final path component with its
/// last extension removed (`src/a.rs` → `a`, `pkg/x.tar.gz` → `x.tar`).
///
/// Both `/` and `\` are accepted as separators, so Windows-style paths also
/// yield the file stem. A leading dot is not treated as an extension
/// separator, so dotfiles keep their name (`.env` → `.env`) instead of
/// collapsing to an empty module name.
pub fn module_name_from_path(path: &str) -> String {
    // `rsplit` always yields at least one piece; the fallback is defensive.
    let file_name = path.rsplit(['/', '\\']).next().unwrap_or(path);
    let stem = match file_name.rfind('.') {
        None | Some(0) => file_name,
        Some(dot) => &file_name[..dot],
    };
    stem.to_string()
}
558
559fn map_symbol_kind(kind: SymbolKind) -> Option<NodeKind> {
560 match kind {
561 SymbolKind::Function | SymbolKind::Method => Some(NodeKind::Function),
562 SymbolKind::Field => Some(NodeKind::Field),
563 SymbolKind::EnumCase => Some(NodeKind::EnumCase),
564 SymbolKind::Class
565 | SymbolKind::Struct
566 | SymbolKind::Enum
567 | SymbolKind::Interface
568 | SymbolKind::Protocol
569 | SymbolKind::Type => Some(NodeKind::Type),
570 SymbolKind::Module => Some(NodeKind::Module),
571 SymbolKind::Variable | SymbolKind::Other => None,
572 }
573}
574
575fn extract_call_sites_from_tree(tree: &Tree, source: &str) -> Vec<(String, u32)> {
580 let mut out: Vec<(String, u32)> = Vec::new();
581 let mut cursor = tree.root_node().walk();
582 let mut stack: Vec<TsNode<'_>> = vec![tree.root_node()];
583 while let Some(node) = stack.pop() {
584 if is_call_kind(node.kind()) {
585 if let Some(name) = call_callee_name(node, source) {
586 let line = node.start_position().row as u32 + 1;
587 out.push((name, line));
588 }
589 }
590 for child in node.children(&mut cursor) {
591 stack.push(child);
592 }
593 }
594 out
595}
596
/// Returns `true` when `kind` is a tree-sitter node kind that denotes a call
/// or invocation in one of the supported grammars, including Rust macro
/// invocations. The comparison is exact and case-sensitive.
fn is_call_kind(kind: &str) -> bool {
    const CALL_KINDS: &[&str] = &[
        "call_expression",
        "call",
        "function_call",
        "method_invocation",
        "method_call_expression",
        "invocation_expression",
        "function_call_expression",
        "macro_invocation",
    ];
    CALL_KINDS.contains(&kind)
}
610
611fn call_callee_name(node: TsNode<'_>, source: &str) -> Option<String> {
612 let callee = node
613 .child_by_field_name("function")
614 .or_else(|| node.child_by_field_name("name"))
615 .or_else(|| node.child_by_field_name("method"))
616 .or_else(|| node.child(0u32))?;
617 let text = &source[callee.start_byte()..callee.end_byte()];
618 let last = text.rsplit_once(['.', ':', '!']);
619 let raw = last.map(|(_, name)| name).unwrap_or(text);
620 let trimmed = raw.trim();
621 let plain: String = trimmed
622 .chars()
623 .take_while(|c| c.is_alphanumeric() || *c == '_')
624 .collect();
625 if plain.is_empty() {
626 None
627 } else {
628 Some(plain)
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Every node of `graph` with the given kind and name, in iteration order.
    fn nodes_matching<'g>(graph: &'g SymbolGraph, kind: NodeKind, name: &str) -> Vec<&'g Node> {
        graph
            .iter_nodes()
            .filter(|n| n.kind == kind && n.name == name)
            .collect()
    }

    /// Builds `src/a.ts` (file 1, importing `./b`) and `src/b.ts` (file 2),
    /// then runs `link_imports` for 1 → 2 `link_rounds` times.
    fn linked_ts_pair(link_rounds: usize) -> SymbolGraph {
        let mut graph = SymbolGraph::new();
        graph.rebuild_file(
            1,
            "src/a.ts",
            Language::TypeScript,
            "import { x } from \"./b\";\n",
            &["./b".into()],
        );
        graph.rebuild_file(
            2,
            "src/b.ts",
            Language::TypeScript,
            "export const x = 1;\n",
            &[],
        );
        let resolved: HashMap<FileId, Vec<FileId>> = HashMap::from([(1, vec![2])]);
        for _ in 0..link_rounds {
            graph.link_imports(&resolved);
        }
        graph
    }

    #[test]
    fn add_and_remove_round_trip() {
        let mut graph = SymbolGraph::new();
        let outcome = graph.rebuild_file(1, "src/a.rs", Language::Rust, "fn foo() {}\n", &[]);
        assert!(
            outcome.node_count >= 2,
            "module + function expected, got {}",
            outcome.node_count
        );
        assert!(
            outcome.symbols.iter().any(|s| s.name == "foo"),
            "rebuild_file should surface the parsed `foo` symbol"
        );
        assert!(!graph.nodes_named("foo").is_empty());

        graph.remove_file(1);
        assert_eq!(graph.node_count(), 0);
        assert!(graph.nodes_named("foo").is_empty());
    }

    #[test]
    fn rebuild_file_emits_function_module_and_call_nodes() {
        let mut graph = SymbolGraph::new();
        let src = "fn alpha() {}\nfn beta() { alpha(); }\n";
        let outcome = graph.rebuild_file(7, "src/x.rs", Language::Rust, src, &[]);
        assert!(
            outcome.node_count >= 3,
            "expected module + 2 functions, got {}",
            outcome.node_count
        );
        assert_eq!(nodes_matching(&graph, NodeKind::Function, "alpha").len(), 1);
        assert_eq!(nodes_matching(&graph, NodeKind::Function, "beta").len(), 1);
        assert!(
            !nodes_matching(&graph, NodeKind::CallSite, "alpha").is_empty(),
            "expected a CallSite for alpha()"
        );
    }

    #[test]
    fn rebuild_file_emits_fields_and_enum_cases() {
        let mut graph = SymbolGraph::new();
        let src = "pub struct Greeter {\n pub name: String,\n}\n\nenum Color {\n Red,\n}\n";
        graph.rebuild_file(9, "src/lib.rs", Language::Rust, src, &[]);

        let field = nodes_matching(&graph, NodeKind::Field, "name")
            .into_iter()
            .next()
            .expect("expected public field node");
        assert_eq!(field.container.as_deref(), Some("Greeter"));
        assert_eq!(field.access_level.as_deref(), Some("public"));

        let case = nodes_matching(&graph, NodeKind::EnumCase, "Red")
            .into_iter()
            .next()
            .expect("expected enum case node");
        assert_eq!(case.container.as_deref(), Some("Color"));

        let color = nodes_matching(&graph, NodeKind::Type, "Color")
            .into_iter()
            .next()
            .expect("expected enum type node");
        let contains_case = graph
            .outgoing(color.id)
            .iter()
            .any(|edge| edge.kind == EdgeKind::Contains && edge.to == case.id);
        assert!(contains_case, "enum type should contain its case");
    }

    #[test]
    fn called_by_inverse_label_resolves() {
        for (label, expect_reversed) in [("CALLED_BY", true), ("CALLS", false)] {
            let (kind, reversed) = EdgeKind::parse_with_direction(label).unwrap();
            assert_eq!(kind, EdgeKind::Calls);
            assert_eq!(reversed, expect_reversed);
        }
    }

    #[test]
    fn link_imports_creates_module_to_module_edges() {
        let graph = linked_ts_pair(1);
        let a_mod = graph.module_node_for_file(1).unwrap();
        let b_mod = graph.module_node_for_file(2).unwrap();
        let edge_exists = graph
            .outgoing(a_mod)
            .iter()
            .any(|e| e.kind == EdgeKind::Imports && e.to == b_mod);
        assert!(edge_exists, "expected Module→Module IMPORTS edge");
    }

    #[test]
    fn link_imports_is_idempotent_across_repeated_relinks() {
        let graph = linked_ts_pair(3);
        let a_mod = graph.module_node_for_file(1).unwrap();
        let b_mod = graph.module_node_for_file(2).unwrap();
        let module_import_edges = graph
            .outgoing(a_mod)
            .iter()
            .filter(|e| e.kind == EdgeKind::Imports && e.to == b_mod)
            .count();
        assert_eq!(
            module_import_edges, 1,
            "Module→Module IMPORTS edge must not duplicate across relinks"
        );
    }
}