1use std::collections::{BTreeSet, HashMap};
15
16use tree_sitter::{Node as TsNode, Tree};
17
18use crate::ast::{api as ast_api, Language, Symbol, SymbolKind};
19
20use super::file_table::FileId;
21
/// Numeric identifier of a [`Node`] inside a [`SymbolGraph`].
///
/// Ids are handed out by `SymbolGraph::add_node` from an increasing counter.
pub type NodeId = u32;
26
/// Category of a node stored in the symbol graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeKind {
    /// A function or method symbol.
    Function,
    /// A type-like symbol (class, struct, enum, interface, protocol, type).
    Type,
    /// A module: the per-file root node, or a module symbol found by the parser.
    Module,
    /// One raw import string attached to a file.
    Import,
    /// One call expression found in a file's syntax tree.
    CallSite,
    /// A macro. NOTE(review): no builder code visible here emits this kind.
    Macro,
}

impl NodeKind {
    /// Every variant in declaration order; `parse` searches this table.
    const ALL: [NodeKind; 6] = [
        NodeKind::Function,
        NodeKind::Type,
        NodeKind::Module,
        NodeKind::Import,
        NodeKind::CallSite,
        NodeKind::Macro,
    ];

    /// Returns the canonical label of this kind (identical to the variant name).
    pub fn as_str(self) -> &'static str {
        match self {
            NodeKind::Function => "Function",
            NodeKind::Type => "Type",
            NodeKind::Module => "Module",
            NodeKind::Import => "Import",
            NodeKind::CallSite => "CallSite",
            NodeKind::Macro => "Macro",
        }
    }

    /// Parses a label produced by [`NodeKind::as_str`].
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    pub fn parse(label: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == label)
    }
}
70
/// Relationship carried by a directed edge of the symbol graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeKind {
    /// Source calls target.
    Calls,
    /// Source references target.
    Refs,
    /// Source imports target.
    Imports,
    /// Source contains target.
    Contains,
    /// Source overrides target.
    Overrides,
}

impl EdgeKind {
    /// Returns the upper-case forward label of this edge kind.
    pub fn as_str(self) -> &'static str {
        match self {
            EdgeKind::Calls => "CALLS",
            EdgeKind::Refs => "REFS",
            EdgeKind::Imports => "IMPORTS",
            EdgeKind::Contains => "CONTAINS",
            EdgeKind::Overrides => "OVERRIDES",
        }
    }

    /// Parses either a forward label (`"CALLS"`) or its inverse spelling
    /// (`"CALLED_BY"`).
    ///
    /// The returned flag is `false` for a forward label and `true` for an
    /// inverse one. Unknown labels yield `None`.
    pub fn parse_with_direction(label: &str) -> Option<(Self, bool)> {
        forward_match(label)
            .map(|kind| (kind, false))
            .or_else(|| {
                let inverse = match label {
                    "CALLED_BY" => EdgeKind::Calls,
                    "REFERENCED_BY" => EdgeKind::Refs,
                    "IMPORTED_BY" => EdgeKind::Imports,
                    "CONTAINED_BY" => EdgeKind::Contains,
                    "OVERRIDDEN_BY" => EdgeKind::Overrides,
                    _ => return None,
                };
                Some((inverse, true))
            })
    }
}

/// Resolves a forward edge label (as produced by `EdgeKind::as_str`) to its kind.
fn forward_match(label: &str) -> Option<EdgeKind> {
    const FORWARD: [EdgeKind; 5] = [
        EdgeKind::Calls,
        EdgeKind::Refs,
        EdgeKind::Imports,
        EdgeKind::Contains,
        EdgeKind::Overrides,
    ];
    FORWARD
        .iter()
        .copied()
        .find(|candidate| candidate.as_str() == label)
}
126
/// One vertex of the symbol graph: a module, symbol, import, or call site
/// belonging to a single file.
#[derive(Debug, Clone)]
pub struct Node {
    /// Graph-assigned identifier; `SymbolGraph::add_node` overwrites whatever
    /// value the caller supplied.
    pub id: NodeId,
    /// Category of this node.
    pub kind: NodeKind,
    /// Lookup name: the symbol name, the callee name for call sites, the raw
    /// import string for imports, or the path-derived name for file modules.
    pub name: String,
    /// File that owns this node.
    pub file_id: FileId,
    /// Path string that was passed to `rebuild_file` for the owning file.
    pub path: String,
    /// Line number: start row plus one for symbols and call sites, always `1`
    /// for file modules and imports.
    pub line: u32,
    /// Human-readable signature or synthesized description of the node.
    pub signature: String,
    /// Name of the enclosing symbol as reported by the parser, if any.
    pub container: Option<String>,
    /// Name of the source language, taken from `Language::name()`.
    pub language: String,
}
151
/// A directed, typed edge between two nodes of the symbol graph.
#[derive(Debug, Clone, Copy)]
pub struct Edge {
    /// Source node of the edge.
    pub from: NodeId,
    /// Target node of the edge.
    pub to: NodeId,
    /// Relationship expressed by the edge.
    pub kind: EdgeKind,
}
162
/// Result of `SymbolGraph::rebuild_file`.
#[derive(Debug, Clone, Default)]
pub struct RebuildOutcome {
    /// Number of nodes registered for the file once the rebuild finished.
    pub node_count: usize,
    /// Symbols returned by the parser for the file; empty when parsing failed.
    pub symbols: Vec<Symbol>,
}
175
/// In-memory graph of nodes and typed edges, indexed by file and by name.
#[derive(Debug, Default, Clone)]
pub struct SymbolGraph {
    // Primary node storage, keyed by node id.
    nodes: HashMap<NodeId, Node>,
    // Ids of the nodes owned by each file, in insertion order.
    by_file: HashMap<FileId, Vec<NodeId>>,
    // Ids of the nodes sharing a given name, in insertion order.
    by_name: HashMap<String, Vec<NodeId>>,
    // Edges grouped by their `from` node.
    out_edges: HashMap<NodeId, Vec<Edge>>,
    // The same edges grouped by their `to` node.
    in_edges: HashMap<NodeId, Vec<Edge>>,
    // Id that the next `add_node` call will assign.
    next_id: NodeId,
}
186
187impl SymbolGraph {
188 pub fn new() -> Self {
190 Self {
191 next_id: 1,
192 ..Self::default()
193 }
194 }
195
196 pub fn node_count(&self) -> usize {
198 self.nodes.len()
199 }
200
201 pub fn edge_count(&self) -> usize {
203 self.out_edges.values().map(Vec::len).sum()
204 }
205
206 pub fn node(&self, id: NodeId) -> Option<&Node> {
208 self.nodes.get(&id)
209 }
210
211 pub fn iter_nodes(&self) -> impl Iterator<Item = &Node> {
213 self.nodes.values()
214 }
215
216 pub fn nodes_of_kind(&self, kind: NodeKind) -> Vec<NodeId> {
219 let mut out: Vec<NodeId> = self
220 .nodes
221 .values()
222 .filter(|n| n.kind == kind)
223 .map(|n| n.id)
224 .collect();
225 out.sort_unstable();
226 out
227 }
228
229 pub fn all_node_ids(&self) -> Vec<NodeId> {
232 let mut out: Vec<NodeId> = self.nodes.keys().copied().collect();
233 out.sort_unstable();
234 out
235 }
236
237 pub fn nodes_named(&self, name: &str) -> &[NodeId] {
239 match self.by_name.get(name) {
240 Some(v) => v.as_slice(),
241 None => &[],
242 }
243 }
244
245 pub fn outgoing(&self, id: NodeId) -> &[Edge] {
247 self.out_edges.get(&id).map(Vec::as_slice).unwrap_or(&[])
248 }
249
250 pub fn incoming(&self, id: NodeId) -> &[Edge] {
252 self.in_edges.get(&id).map(Vec::as_slice).unwrap_or(&[])
253 }
254
255 pub fn file_ids(&self) -> Vec<FileId> {
257 let mut out: Vec<FileId> = self.by_file.keys().copied().collect();
258 out.sort_unstable();
259 out
260 }
261
262 pub fn remove_file(&mut self, file_id: FileId) {
264 let Some(node_ids) = self.by_file.remove(&file_id) else {
265 return;
266 };
267 for id in node_ids {
268 self.drop_node(id);
269 }
270 }
271
272 fn drop_node(&mut self, id: NodeId) {
273 let Some(node) = self.nodes.remove(&id) else {
274 return;
275 };
276 if let Some(bucket) = self.by_name.get_mut(&node.name) {
277 bucket.retain(|n| *n != id);
278 if bucket.is_empty() {
279 self.by_name.remove(&node.name);
280 }
281 }
282 if let Some(outs) = self.out_edges.remove(&id) {
283 for e in outs {
284 if let Some(bucket) = self.in_edges.get_mut(&e.to) {
285 bucket.retain(|edge| edge.from != id);
286 }
287 }
288 }
289 if let Some(ins) = self.in_edges.remove(&id) {
290 for e in ins {
291 if let Some(bucket) = self.out_edges.get_mut(&e.from) {
292 bucket.retain(|edge| edge.to != id);
293 }
294 }
295 }
296 }
297
298 pub fn rebuild_file(
305 &mut self,
306 file_id: FileId,
307 path: &str,
308 language: Language,
309 source: &str,
310 import_strings: &[String],
311 ) -> RebuildOutcome {
312 self.remove_file(file_id);
313 let module_id = self.add_module_for_file(file_id, path, &language);
314
315 let (tree, symbols) = match ast_api::parse_with_symbols(source, language) {
320 Ok((t, s)) => (Some(t), s),
321 Err(err) => {
322 tracing::debug!(
323 "code_index: tree-sitter parse failed for `{path}`: {err}; \
324 symbol graph slice will be Module-only"
325 );
326 (None, Vec::new())
327 }
328 };
329
330 let mut container_ids: HashMap<String, NodeId> = HashMap::new();
334 for sym in &symbols {
335 let Some(kind) = map_symbol_kind(sym.kind) else {
336 continue;
337 };
338 let id = self.add_node(Node {
339 id: 0,
340 kind,
341 name: sym.name.clone(),
342 file_id,
343 path: path.to_string(),
344 line: sym.start_row.saturating_add(1),
345 signature: sym.signature.clone(),
346 container: sym.container.clone(),
347 language: language.name().to_string(),
348 });
349 if matches!(kind, NodeKind::Type | NodeKind::Module) {
350 container_ids.insert(sym.name.clone(), id);
351 }
352 let parent_id = sym
353 .container
354 .as_deref()
355 .and_then(|c| container_ids.get(c).copied())
356 .unwrap_or(module_id);
357 self.add_edge(parent_id, id, EdgeKind::Contains);
358 }
359
360 if let Some(tree) = tree.as_ref() {
364 for (callee_name, line) in extract_call_sites_from_tree(tree, source) {
365 let call_id = self.add_node(Node {
366 id: 0,
367 kind: NodeKind::CallSite,
368 name: callee_name.clone(),
369 file_id,
370 path: path.to_string(),
371 line,
372 signature: format!("{callee_name}(…)"),
373 container: None,
374 language: language.name().to_string(),
375 });
376 self.add_edge(module_id, call_id, EdgeKind::Contains);
377 let targets: Vec<NodeId> = self
378 .nodes_named(&callee_name)
379 .iter()
380 .copied()
381 .filter(|nid| {
382 self.nodes
383 .get(nid)
384 .is_some_and(|n| matches!(n.kind, NodeKind::Function))
385 })
386 .collect();
387 for t in targets {
388 self.add_edge(call_id, t, EdgeKind::Calls);
389 }
390 }
391 }
392
393 for raw in import_strings {
398 let imp_id = self.add_node(Node {
399 id: 0,
400 kind: NodeKind::Import,
401 name: raw.clone(),
402 file_id,
403 path: path.to_string(),
404 line: 1,
405 signature: format!("import {raw}"),
406 container: None,
407 language: language.name().to_string(),
408 });
409 self.add_edge(module_id, imp_id, EdgeKind::Imports);
410 }
411
412 for target in self.collect_cross_file_refs(source, file_id) {
416 self.add_edge(module_id, target, EdgeKind::Refs);
417 }
418
419 let node_count = self.by_file.get(&file_id).map(Vec::len).unwrap_or_default();
420 RebuildOutcome {
421 node_count,
422 symbols,
423 }
424 }
425
426 pub fn link_imports(&mut self, resolved: &HashMap<FileId, Vec<FileId>>) {
431 for (src_file, targets) in resolved {
432 let Some(src_module) = self.module_node_for_file(*src_file) else {
433 continue;
434 };
435 for tgt_file in targets {
436 let Some(tgt_module) = self.module_node_for_file(*tgt_file) else {
437 continue;
438 };
439 self.add_edge(src_module, tgt_module, EdgeKind::Imports);
440 }
441 }
442 }
443
444 pub fn module_node_for_file(&self, file_id: FileId) -> Option<NodeId> {
446 let ids = self.by_file.get(&file_id)?;
447 ids.iter().copied().find(|id| {
448 self.nodes
449 .get(id)
450 .is_some_and(|n| matches!(n.kind, NodeKind::Module))
451 })
452 }
453
454 fn collect_cross_file_refs(&self, source: &str, this_file: FileId) -> BTreeSet<NodeId> {
458 let mut out: BTreeSet<NodeId> = BTreeSet::new();
459 if self.by_name.is_empty() {
460 return out;
461 }
462 let mut word = String::with_capacity(32);
463 for ch in source.chars() {
464 if ch.is_alphanumeric() || ch == '_' {
465 word.push(ch);
466 } else if !word.is_empty() {
467 self.absorb_word_refs(&word, this_file, &mut out);
468 word.clear();
469 }
470 }
471 if !word.is_empty() {
472 self.absorb_word_refs(&word, this_file, &mut out);
473 }
474 out
475 }
476
477 fn absorb_word_refs(&self, word: &str, this_file: FileId, bag: &mut BTreeSet<NodeId>) {
478 if word.len() < 3 {
479 return;
480 }
481 let Some(ids) = self.by_name.get(word) else {
482 return;
483 };
484 for nid in ids {
485 let same_file = self.nodes.get(nid).is_some_and(|n| n.file_id == this_file);
486 if !same_file {
487 bag.insert(*nid);
488 }
489 }
490 }
491
492 fn add_module_for_file(&mut self, file_id: FileId, path: &str, language: &Language) -> NodeId {
493 let name = module_name_from_path(path);
494 self.add_node(Node {
495 id: 0,
496 kind: NodeKind::Module,
497 name,
498 file_id,
499 path: path.to_string(),
500 line: 1,
501 signature: format!("module {path}"),
502 container: None,
503 language: language.name().to_string(),
504 })
505 }
506
507 fn add_node(&mut self, mut node: Node) -> NodeId {
508 let id = self.next_id;
509 self.next_id = self.next_id.checked_add(1).expect("NodeId overflow");
510 node.id = id;
511 self.by_file.entry(node.file_id).or_default().push(id);
512 self.by_name.entry(node.name.clone()).or_default().push(id);
513 self.nodes.insert(id, node);
514 id
515 }
516
517 fn add_edge(&mut self, from: NodeId, to: NodeId, kind: EdgeKind) {
518 let edge = Edge { from, to, kind };
519 self.out_edges.entry(from).or_default().push(edge);
520 self.in_edges.entry(to).or_default().push(edge);
521 }
522}
523
/// Derives a module name from a `/`-separated path: the final path component
/// with its last extension removed.
///
/// * `"src/a.rs"` → `"a"`
/// * `"pkg/a.tar.gz"` → `"a.tar"` (only the last extension is stripped)
/// * `"Makefile"` → `"Makefile"`
/// * `"conf/.env"` → `".env"` — a leading dot with nothing before it is part of
///   the name, not an extension separator, so dotfiles no longer collapse to an
///   empty module name.
pub fn module_name_from_path(path: &str) -> String {
    let file_name = path.rsplit_once('/').map_or(path, |(_, name)| name);
    let base = match file_name.rsplit_once('.') {
        // Strip the extension only when something is left in front of it.
        Some((name, _ext)) if !name.is_empty() => name,
        _ => file_name,
    };
    base.to_string()
}
531
532fn map_symbol_kind(kind: SymbolKind) -> Option<NodeKind> {
533 match kind {
534 SymbolKind::Function | SymbolKind::Method => Some(NodeKind::Function),
535 SymbolKind::Class
536 | SymbolKind::Struct
537 | SymbolKind::Enum
538 | SymbolKind::Interface
539 | SymbolKind::Protocol
540 | SymbolKind::Type => Some(NodeKind::Type),
541 SymbolKind::Module => Some(NodeKind::Module),
542 SymbolKind::Variable | SymbolKind::Other => None,
543 }
544}
545
546fn extract_call_sites_from_tree(tree: &Tree, source: &str) -> Vec<(String, u32)> {
551 let mut out: Vec<(String, u32)> = Vec::new();
552 let mut cursor = tree.root_node().walk();
553 let mut stack: Vec<TsNode<'_>> = vec![tree.root_node()];
554 while let Some(node) = stack.pop() {
555 if is_call_kind(node.kind()) {
556 if let Some(name) = call_callee_name(node, source) {
557 let line = node.start_position().row as u32 + 1;
558 out.push((name, line));
559 }
560 }
561 for child in node.children(&mut cursor) {
562 stack.push(child);
563 }
564 }
565 out
566}
567
/// Reports whether a syntax-node kind string denotes a call-like expression in
/// one of the grammars this module recognizes.
fn is_call_kind(kind: &str) -> bool {
    const CALL_KINDS: [&str; 8] = [
        "call_expression",
        "call",
        "function_call",
        "method_invocation",
        "method_call_expression",
        "invocation_expression",
        "function_call_expression",
        "macro_invocation",
    ];
    CALL_KINDS.iter().any(|known| *known == kind)
}
581
582fn call_callee_name(node: TsNode<'_>, source: &str) -> Option<String> {
583 let callee = node
584 .child_by_field_name("function")
585 .or_else(|| node.child_by_field_name("name"))
586 .or_else(|| node.child_by_field_name("method"))
587 .or_else(|| node.child(0u32))?;
588 let text = &source[callee.start_byte()..callee.end_byte()];
589 let last = text.rsplit_once(['.', ':', '!']);
590 let raw = last.map(|(_, name)| name).unwrap_or(text);
591 let trimmed = raw.trim();
592 let plain: String = trimmed
593 .chars()
594 .take_while(|c| c.is_alphanumeric() || *c == '_')
595 .collect();
596 if plain.is_empty() {
597 None
598 } else {
599 Some(plain)
600 }
601}
602
// Unit tests for the symbol graph: file rebuild/removal, call-site extraction,
// edge-label parsing, and import linking.
#[cfg(test)]
mod tests {
    use super::*;

    // Rebuilding a file adds its nodes; removing the file leaves the graph empty.
    #[test]
    fn add_and_remove_round_trip() {
        let mut g = SymbolGraph::new();
        let outcome = g.rebuild_file(1, "src/a.rs", Language::Rust, "fn foo() {}\n", &[]);
        assert!(
            outcome.node_count >= 2,
            "module + function expected, got {}",
            outcome.node_count
        );
        assert!(
            outcome.symbols.iter().any(|s| s.name == "foo"),
            "rebuild_file should surface the parsed `foo` symbol"
        );
        assert!(!g.nodes_named("foo").is_empty());
        g.remove_file(1);
        assert_eq!(g.node_count(), 0);
        assert!(g.nodes_named("foo").is_empty());
    }

    // A file with two functions, one calling the other, yields one Function
    // node per function and at least one CallSite node named after the callee.
    #[test]
    fn rebuild_file_emits_function_module_and_call_nodes() {
        let mut g = SymbolGraph::new();
        let src = "fn alpha() {}\nfn beta() { alpha(); }\n";
        let outcome = g.rebuild_file(7, "src/x.rs", Language::Rust, src, &[]);
        assert!(
            outcome.node_count >= 3,
            "expected module + 2 functions, got {}",
            outcome.node_count
        );
        let alpha_funcs: Vec<_> = g
            .iter_nodes()
            .filter(|n| n.kind == NodeKind::Function && n.name == "alpha")
            .collect();
        assert_eq!(alpha_funcs.len(), 1);
        let beta_funcs: Vec<_> = g
            .iter_nodes()
            .filter(|n| n.kind == NodeKind::Function && n.name == "beta")
            .collect();
        assert_eq!(beta_funcs.len(), 1);
        let beta_calls: Vec<_> = g
            .iter_nodes()
            .filter(|n| n.kind == NodeKind::CallSite && n.name == "alpha")
            .collect();
        assert!(!beta_calls.is_empty(), "expected a CallSite for alpha()");
    }

    // Inverse labels parse to the same kind with the direction flag set;
    // forward labels parse with the flag cleared.
    #[test]
    fn called_by_inverse_label_resolves() {
        let (kind, reversed) = EdgeKind::parse_with_direction("CALLED_BY").unwrap();
        assert_eq!(kind, EdgeKind::Calls);
        assert!(reversed);
        let (kind, reversed) = EdgeKind::parse_with_direction("CALLS").unwrap();
        assert_eq!(kind, EdgeKind::Calls);
        assert!(!reversed);
    }

    // After both files are indexed, link_imports connects their Module nodes
    // with an IMPORTS edge from importer to imported.
    #[test]
    fn link_imports_creates_module_to_module_edges() {
        let mut g = SymbolGraph::new();
        g.rebuild_file(
            1,
            "src/a.ts",
            Language::TypeScript,
            "import { x } from \"./b\";\n",
            &["./b".into()],
        );
        g.rebuild_file(
            2,
            "src/b.ts",
            Language::TypeScript,
            "export const x = 1;\n",
            &[],
        );
        let mut resolved: HashMap<FileId, Vec<FileId>> = HashMap::new();
        resolved.insert(1, vec![2]);
        g.link_imports(&resolved);
        let a_mod = g.module_node_for_file(1).unwrap();
        let b_mod = g.module_node_for_file(2).unwrap();
        let edge_exists = g
            .outgoing(a_mod)
            .iter()
            .any(|e| e.kind == EdgeKind::Imports && e.to == b_mod);
        assert!(edge_exists, "expected Module→Module IMPORTS edge");
    }
}