gabb_cli/languages/
rust.rs

1use crate::languages::ImportBindingInfo;
2use crate::store::{normalize_path, EdgeRecord, FileDependency, ReferenceRecord, SymbolRecord};
3use anyhow::{Context, Result};
4use once_cell::sync::Lazy;
5use std::collections::{HashMap, HashSet};
6use std::fs;
7use std::path::Path;
8use tree_sitter::{Language, Node, Parser, TreeCursor};
9
/// Tree-sitter grammar handle for Rust, built from `tree_sitter_rust::LANGUAGE` on first access.
static RUST_LANGUAGE: Lazy<Language> = Lazy::new(|| tree_sitter_rust::LANGUAGE.into());
11
12#[derive(Clone, Debug)]
13struct SymbolBinding {
14    id: String,
15    qualifier: Option<String>,
16}
17
18impl From<&SymbolRecord> for SymbolBinding {
19    fn from(value: &SymbolRecord) -> Self {
20        Self {
21            id: value.id.clone(),
22            qualifier: value.qualifier.clone(),
23        }
24    }
25}
26
/// Result of resolving a type or trait name to an id plus an optional qualifier.
#[derive(Clone, Debug)]
struct ResolvedTarget {
    /// Id of the resolved target.
    id: String,
    /// Qualifier of the resolved target, when one is known.
    qualifier: Option<String>,
}

impl ResolvedTarget {
    /// Returns `"<owner>::<member>"`, where the owner is the qualifier when present and the
    /// target's own id otherwise.
    fn member_id(&self, member: &str) -> String {
        let owner = self.qualifier.as_deref().unwrap_or(self.id.as_str());
        format!("{owner}::{member}")
    }
}
42
43/// Index a Rust file, returning symbols, edges, references, file dependencies, and import bindings.
44#[allow(clippy::type_complexity)]
45pub fn index_file(
46    path: &Path,
47    source: &str,
48) -> Result<(
49    Vec<SymbolRecord>,
50    Vec<EdgeRecord>,
51    Vec<ReferenceRecord>,
52    Vec<FileDependency>,
53    Vec<ImportBindingInfo>,
54)> {
55    let mut parser = Parser::new();
56    parser
57        .set_language(&RUST_LANGUAGE)
58        .context("failed to set Rust language")?;
59    let tree = parser
60        .parse(source, None)
61        .context("failed to parse Rust file")?;
62
63    let mut symbols = Vec::new();
64    let mut edges = Vec::new();
65    let mut declared_spans: HashSet<(usize, usize)> = HashSet::new();
66    let mut symbol_by_name: HashMap<String, SymbolBinding> = HashMap::new();
67
68    {
69        let mut cursor = tree.walk();
70        walk_symbols(
71            path,
72            source,
73            &mut cursor,
74            None,
75            &[],
76            None,
77            &mut symbols,
78            &mut edges,
79            &mut declared_spans,
80            &mut symbol_by_name,
81        );
82    }
83
84    let references = collect_references(
85        path,
86        source,
87        &tree.root_node(),
88        &declared_spans,
89        &symbol_by_name,
90    );
91
92    // Extract file dependencies and import bindings from mod and use declarations
93    let (dependencies, import_bindings) = collect_dependencies(path, source, &tree.root_node());
94
95    Ok((symbols, edges, references, dependencies, import_bindings))
96}
97
98/// Extract file dependencies from `mod` and `use` declarations.
99/// - `mod foo;` indicates dependency on foo.rs or foo/mod.rs
100/// - `use crate::foo::Bar;` indicates dependency on the foo module
101fn collect_dependencies(
102    path: &Path,
103    source: &str,
104    root: &Node,
105) -> (Vec<FileDependency>, Vec<ImportBindingInfo>) {
106    let mut dependencies = Vec::new();
107    let mut import_bindings = Vec::new();
108    let mut seen = HashSet::new();
109    let from_file = normalize_path(path);
110    let parent = path.parent();
111
112    // Find the crate root directory (where Cargo.toml or lib.rs/main.rs is)
113    let crate_root = find_crate_root(path);
114
115    let mut stack = vec![*root];
116    while let Some(node) = stack.pop() {
117        // Handle `mod foo;` declarations (without body)
118        if node.kind() == "mod_item" {
119            let has_body = node
120                .children(&mut node.walk())
121                .any(|c| c.kind() == "declaration_list");
122            if !has_body {
123                if let Some(name_node) = node.child_by_field_name("name") {
124                    let mod_name = slice(source, &name_node);
125                    let key = format!("mod:{}", mod_name);
126                    if !mod_name.is_empty() && !seen.contains(&key) {
127                        seen.insert(key);
128                        if let Some(to_file) = resolve_mod_path(parent, &mod_name) {
129                            dependencies.push(FileDependency {
130                                from_file: from_file.clone(),
131                                to_file,
132                                kind: "mod".to_string(),
133                            });
134                        }
135                    }
136                }
137            }
138        }
139
140        // Handle `use` declarations
141        if node.kind() == "use_declaration" {
142            if let Some(use_path) = extract_use_path(source, &node) {
143                // Only handle crate-local paths (crate::, super::, self::)
144                if let Some(resolved) = resolve_use_path(&use_path, path, crate_root.as_deref()) {
145                    let key = format!("use:{}", resolved);
146                    if !seen.contains(&key) {
147                        seen.insert(key.clone());
148                        dependencies.push(FileDependency {
149                            from_file: from_file.clone(),
150                            to_file: resolved.clone(),
151                            kind: "use".to_string(),
152                        });
153                    }
154                    // Extract import bindings for two-phase resolution
155                    let bindings = extract_use_bindings(source, &node, &resolved);
156                    import_bindings.extend(bindings);
157                }
158            }
159        }
160
161        let mut cursor = node.walk();
162        for child in node.children(&mut cursor) {
163            stack.push(child);
164        }
165    }
166
167    (dependencies, import_bindings)
168}
169
170/// Extract import bindings from a use declaration
171fn extract_use_bindings(source: &str, node: &Node, source_file: &str) -> Vec<ImportBindingInfo> {
172    let mut bindings = Vec::new();
173    let mut stack = vec![*node];
174
175    while let Some(n) = stack.pop() {
176        match n.kind() {
177            "use_as_clause" => {
178                // `use foo::bar as baz;` - bar aliased as baz
179                if let Some(path_node) = n.child_by_field_name("path") {
180                    let original_name = extract_last_path_segment(source, &path_node);
181                    if let Some(alias_node) = n.child_by_field_name("alias") {
182                        let local_name = slice(source, &alias_node);
183                        if !local_name.is_empty() && !original_name.is_empty() {
184                            bindings.push(ImportBindingInfo {
185                                local_name,
186                                source_file: source_file.to_string(),
187                                original_name,
188                            });
189                        }
190                    }
191                }
192            }
193            "scoped_identifier" | "identifier" => {
194                // Simple use without alias: local_name == original_name
195                let name = extract_last_path_segment(source, &n);
196                if !name.is_empty() && n.parent().map(|p| p.kind()) != Some("use_as_clause") {
197                    bindings.push(ImportBindingInfo {
198                        local_name: name.clone(),
199                        source_file: source_file.to_string(),
200                        original_name: name,
201                    });
202                }
203            }
204            _ => {
205                let mut cursor = n.walk();
206                for child in n.children(&mut cursor) {
207                    stack.push(child);
208                }
209            }
210        }
211    }
212
213    bindings
214}
215
216/// Extract the last segment of a path (e.g., "bar" from "foo::bar")
217fn extract_last_path_segment(source: &str, node: &Node) -> String {
218    let text = slice(source, node);
219    text.rsplit("::").next().unwrap_or(&text).to_string()
220}
221
222/// Resolve a mod declaration to a file path
223fn resolve_mod_path(parent: Option<&Path>, mod_name: &str) -> Option<String> {
224    let parent_dir = parent?;
225    let mod_file = parent_dir.join(format!("{}.rs", mod_name));
226    let mod_dir_file = parent_dir.join(mod_name).join("mod.rs");
227
228    if mod_file.exists() {
229        Some(normalize_path(&mod_file))
230    } else if mod_dir_file.exists() {
231        Some(normalize_path(&mod_dir_file))
232    } else {
233        // Use the expected path even if it doesn't exist
234        Some(normalize_path(&mod_file))
235    }
236}
237
238/// Extract the path from a use declaration
239fn extract_use_path(source: &str, node: &Node) -> Option<String> {
240    // Find the use path - could be scoped_identifier, identifier, or use_wildcard
241    let mut stack = vec![*node];
242    while let Some(n) = stack.pop() {
243        match n.kind() {
244            "scoped_identifier" | "identifier" | "scoped_use_list" => {
245                return Some(slice(source, &n));
246            }
247            _ => {
248                let mut cursor = n.walk();
249                for child in n.children(&mut cursor) {
250                    stack.push(child);
251                }
252            }
253        }
254    }
255    None
256}
257
258/// Resolve a use path to a file path
259/// Handles crate::, super::, and self:: prefixes
260fn resolve_use_path(
261    use_path: &str,
262    current_file: &Path,
263    crate_root: Option<&Path>,
264) -> Option<String> {
265    let parts: Vec<&str> = use_path.split("::").collect();
266    if parts.is_empty() {
267        return None;
268    }
269
270    let first = parts[0];
271    let parent = current_file.parent()?;
272
273    match first {
274        "crate" => {
275            // crate:: paths start from crate root
276            let root = crate_root?;
277            if parts.len() < 2 {
278                return None;
279            }
280            // Take the first module after crate::
281            let module_name = parts[1];
282            resolve_mod_path(Some(root), module_name)
283        }
284        "super" => {
285            // super:: paths go up one directory
286            let grandparent = parent.parent()?;
287            if parts.len() < 2 {
288                // Just `use super::*` - depend on parent mod.rs
289                let mod_file = grandparent.join("mod.rs");
290                if mod_file.exists() {
291                    return Some(normalize_path(&mod_file));
292                }
293                return None;
294            }
295            let module_name = parts[1];
296            resolve_mod_path(Some(grandparent), module_name)
297        }
298        "self" => {
299            // self:: paths are in current module - no external dependency
300            None
301        }
302        _ => {
303            // External crate or other - no local file dependency
304            None
305        }
306    }
307}
308
/// Find the directory used as the crate's source root for `path`.
///
/// Inspects up to ten ancestor directories, nearest first, and returns the first one that
/// either is itself named `src`, or contains a `Cargo.toml` (then its `src` subdirectory if
/// that exists, otherwise the directory itself). Returns `None` when no ancestor qualifies.
fn find_crate_root(path: &Path) -> Option<std::path::PathBuf> {
    // `ancestors()` starts with `path` itself, so skip it to begin at the parent directory.
    path.ancestors().skip(1).take(10).find_map(|dir| {
        if dir.file_name().and_then(|n| n.to_str()) == Some("src") {
            return Some(dir.to_path_buf());
        }

        if !dir.join("Cargo.toml").exists() {
            return None;
        }

        let src = dir.join("src");
        Some(if src.exists() { src } else { dir.to_path_buf() })
    })
}
333
334#[allow(clippy::too_many_arguments, clippy::only_used_in_recursion)]
335fn walk_symbols(
336    path: &Path,
337    source: &str,
338    cursor: &mut TreeCursor,
339    container: Option<String>,
340    module_path: &[String],
341    impl_trait: Option<ResolvedTarget>,
342    symbols: &mut Vec<SymbolRecord>,
343    edges: &mut Vec<EdgeRecord>,
344    declared_spans: &mut HashSet<(usize, usize)>,
345    symbol_by_name: &mut HashMap<String, SymbolBinding>,
346) {
347    loop {
348        let node = cursor.node();
349        match node.kind() {
350            "function_item" => {
351                if let Some(name_node) = node.child_by_field_name("name") {
352                    let name = slice(source, &name_node);
353                    let sym = make_symbol(
354                        path,
355                        module_path,
356                        &node,
357                        &name,
358                        "function",
359                        container.clone(),
360                        source.as_bytes(),
361                    );
362                    declared_spans.insert((sym.start as usize, sym.end as usize));
363                    symbol_by_name
364                        .entry(name.clone())
365                        .or_insert_with(|| SymbolBinding::from(&sym));
366                    if let Some(trait_target) = &impl_trait {
367                        edges.push(EdgeRecord {
368                            src: sym.id.clone(),
369                            dst: trait_target.member_id(&name),
370                            kind: "overrides".to_string(),
371                        });
372                    }
373                    if let Some(parent) = &container {
374                        if let Some(binding) = symbol_by_name.get(parent) {
375                            edges.push(EdgeRecord {
376                                src: sym.id.clone(),
377                                dst: binding.id.clone(),
378                                kind: "inherent_impl".to_string(),
379                            });
380                        }
381                    }
382                    symbols.push(sym);
383                }
384            }
385            "struct_item" => {
386                if let Some(name_node) = node.child_by_field_name("name") {
387                    let name = slice(source, &name_node);
388                    let sym = make_symbol(
389                        path,
390                        module_path,
391                        &node,
392                        &name,
393                        "struct",
394                        container.clone(),
395                        source.as_bytes(),
396                    );
397                    declared_spans.insert((sym.start as usize, sym.end as usize));
398                    symbol_by_name
399                        .entry(name)
400                        .or_insert_with(|| SymbolBinding::from(&sym));
401                    symbols.push(sym);
402                }
403            }
404            "enum_item" => {
405                if let Some(name_node) = node.child_by_field_name("name") {
406                    let name = slice(source, &name_node);
407                    let sym = make_symbol(
408                        path,
409                        module_path,
410                        &node,
411                        &name,
412                        "enum",
413                        container.clone(),
414                        source.as_bytes(),
415                    );
416                    declared_spans.insert((sym.start as usize, sym.end as usize));
417                    symbol_by_name
418                        .entry(name)
419                        .or_insert_with(|| SymbolBinding::from(&sym));
420                    symbols.push(sym);
421                }
422            }
423            "trait_item" => {
424                if let Some(name_node) = node.child_by_field_name("name") {
425                    let name = slice(source, &name_node);
426                    let sym = make_symbol(
427                        path,
428                        module_path,
429                        &node,
430                        &name,
431                        "trait",
432                        container.clone(),
433                        source.as_bytes(),
434                    );
435                    declared_spans.insert((sym.start as usize, sym.end as usize));
436                    symbol_by_name
437                        .entry(name)
438                        .or_insert_with(|| SymbolBinding::from(&sym));
439                    symbols.push(sym);
440                }
441            }
442            "mod_item" => {
443                if let Some(name_node) = node.child_by_field_name("name") {
444                    let name = slice(source, &name_node);
445                    let mut mod_path = module_path.to_vec();
446                    mod_path.push(name);
447                    if cursor.goto_first_child() {
448                        walk_symbols(
449                            path,
450                            source,
451                            cursor,
452                            container.clone(),
453                            &mod_path,
454                            None,
455                            symbols,
456                            edges,
457                            declared_spans,
458                            symbol_by_name,
459                        );
460                        cursor.goto_parent();
461                    }
462                    if cursor.goto_next_sibling() {
463                        continue;
464                    } else {
465                        break;
466                    }
467                }
468            }
469            _ => {}
470        }
471
472        if cursor.goto_first_child() {
473            let mut child_container = container.clone();
474            let mut child_trait = impl_trait.clone();
475            let child_modules = module_path.to_vec();
476            if node.kind() == "impl_item" {
477                let (ty, trait_target) =
478                    record_impl_edges(path, source, &node, module_path, symbol_by_name, edges);
479                child_container = ty.or(container.clone());
480                child_trait = trait_target;
481            }
482            walk_symbols(
483                path,
484                source,
485                cursor,
486                child_container,
487                &child_modules,
488                child_trait,
489                symbols,
490                edges,
491                declared_spans,
492                symbol_by_name,
493            );
494            cursor.goto_parent();
495        }
496
497        if !cursor.goto_next_sibling() {
498            break;
499        }
500    }
501}
502
503fn collect_references(
504    path: &Path,
505    source: &str,
506    root: &Node,
507    declared_spans: &HashSet<(usize, usize)>,
508    symbol_by_name: &HashMap<String, SymbolBinding>,
509) -> Vec<ReferenceRecord> {
510    let mut refs = Vec::new();
511    let mut stack = vec![*root];
512    let file = normalize_path(path);
513
514    while let Some(node) = stack.pop() {
515        if node.kind() == "identifier" {
516            let span = (node.start_byte(), node.end_byte());
517            if !declared_spans.contains(&span) {
518                let name = slice(source, &node);
519                if let Some(sym) = symbol_by_name.get(&name) {
520                    refs.push(ReferenceRecord {
521                        file: file.clone(),
522                        start: node.start_byte() as i64,
523                        end: node.end_byte() as i64,
524                        symbol_id: sym.id.clone(),
525                    });
526                }
527            }
528        }
529
530        let mut cursor = node.walk();
531        for child in node.children(&mut cursor) {
532            stack.push(child);
533        }
534    }
535
536    refs
537}
538
539fn record_impl_edges(
540    path: &Path,
541    source: &str,
542    node: &Node,
543    module_path: &[String],
544    symbol_by_name: &HashMap<String, SymbolBinding>,
545    edges: &mut Vec<EdgeRecord>,
546) -> (Option<String>, Option<ResolvedTarget>) {
547    let ty_name = node
548        .child_by_field_name("type")
549        .map(|ty| slice(source, &ty))
550        .filter(|s| !s.is_empty());
551    let trait_name = node
552        .child_by_field_name("trait")
553        .map(|tr| slice(source, &tr))
554        .filter(|s| !s.is_empty());
555
556    let mut trait_target = None;
557    if let (Some(ty), Some(tr)) = (ty_name.as_ref(), trait_name.as_ref()) {
558        let src = resolve_rust_name(
559            ty,
560            Some((node.start_byte(), node.end_byte())),
561            path,
562            module_path,
563            symbol_by_name,
564        );
565        let dst = resolve_rust_name(
566            tr,
567            Some((node.start_byte(), node.end_byte())),
568            path,
569            module_path,
570            symbol_by_name,
571        );
572        trait_target = Some(dst.clone());
573        edges.push(EdgeRecord {
574            src: src.id,
575            dst: dst.id,
576            kind: "trait_impl".to_string(),
577        });
578    }
579
580    (ty_name, trait_target)
581}
582
583fn resolve_rust_name(
584    name: &str,
585    span: Option<(usize, usize)>,
586    path: &Path,
587    module_path: &[String],
588    symbol_by_name: &HashMap<String, SymbolBinding>,
589) -> ResolvedTarget {
590    if let Some(binding) = symbol_by_name.get(name) {
591        return ResolvedTarget {
592            id: binding.id.clone(),
593            qualifier: binding.qualifier.clone(),
594        };
595    }
596    let prefix = module_prefix(path, module_path);
597    let id = match span {
598        Some((start, end)) => format!("{}#{}-{}", normalize_path(path), start, end),
599        None => format!("{prefix}::{name}"),
600    };
601    let qualifier = Some(prefix);
602    ResolvedTarget { id, qualifier }
603}
604
605fn module_prefix(path: &Path, module_path: &[String]) -> String {
606    let mut base = normalize_path(path);
607    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
608        let trim = ext.len() + 1;
609        if base.len() > trim {
610            base.truncate(base.len() - trim);
611        }
612    }
613    for segment in module_path {
614        base.push_str("::");
615        base.push_str(segment);
616    }
617    base
618}
619
620fn make_symbol(
621    path: &Path,
622    module_path: &[String],
623    node: &Node,
624    name: &str,
625    kind: &str,
626    container: Option<String>,
627    source: &[u8],
628) -> SymbolRecord {
629    let qualifier = Some(module_qualifier(path, module_path, &container));
630    let visibility = visibility(node, path);
631    let content_hash = super::compute_content_hash(source, node.start_byte(), node.end_byte());
632    SymbolRecord {
633        id: format!(
634            "{}#{}-{}",
635            normalize_path(path),
636            node.start_byte(),
637            node.end_byte()
638        ),
639        file: normalize_path(path),
640        kind: kind.to_string(),
641        name: name.to_string(),
642        start: node.start_byte() as i64,
643        end: node.end_byte() as i64,
644        qualifier,
645        visibility,
646        container,
647        content_hash,
648    }
649}
650
651fn module_qualifier(path: &Path, module_path: &[String], container: &Option<String>) -> String {
652    let mut base = module_prefix(path, module_path);
653    if let Some(c) = container {
654        base.push_str("::");
655        base.push_str(c);
656    }
657    base
658}
659
660fn visibility(node: &Node, path: &Path) -> Option<String> {
661    if let Some(vis) = node.child_by_field_name("visibility") {
662        let text = slice_file(path, &vis);
663        if !text.is_empty() {
664            return Some(text);
665        }
666    }
667    let mut cursor = node.walk();
668    for child in node.children(&mut cursor) {
669        if child.kind() == "visibility_modifier" || child.kind() == "pub" {
670            let text = slice_file(path, &child);
671            if !text.is_empty() {
672                return Some(text);
673            }
674        }
675    }
676    None
677}
678
679fn slice(source: &str, node: &Node) -> String {
680    let bytes = node.byte_range();
681    source
682        .get(bytes.clone())
683        .unwrap_or_default()
684        .trim()
685        .to_string()
686}
687
688fn slice_file(path: &Path, node: &Node) -> String {
689    // Best-effort visibility slice using the file contents; if missing, fall back to node text.
690    let source = fs::read_to_string(path).unwrap_or_default();
691    slice(&source, node)
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A struct with an inherent impl inside a `pub mod` yields both symbols, the struct's
    /// visibility and module qualifier, and an `inherent_impl` edge.
    #[test]
    fn extracts_rust_symbols_and_visibility() {
        let tmp = tempdir().unwrap();
        let file_path = tmp.path().join("mod.rs");
        let source = r#"
            pub mod inner {
                pub struct Thing;
                impl Thing {
                    pub fn make() {}
                }
            }
        "#;
        fs::write(&file_path, source).unwrap();

        let (symbols, edges, _refs, _deps, _imports) = index_file(&file_path, source).unwrap();
        let has_symbol = |wanted: &str| symbols.iter().any(|s| s.name == wanted);
        assert!(has_symbol("Thing"));
        assert!(has_symbol("make"));

        let thing = symbols.iter().find(|s| s.name == "Thing").unwrap();
        assert_eq!(thing.visibility.as_deref(), Some("pub"));
        let thing_qualifier = thing.qualifier.as_deref().unwrap();
        assert!(thing_qualifier.contains("mod::inner"));

        let make = symbols.iter().find(|s| s.name == "make").unwrap();
        assert_eq!(make.kind, "function");
        let has_inherent_edge = edges.iter().any(|e| e.kind == "inherent_impl");
        assert!(
            has_inherent_edge,
            "expected inherent_impl edge from make to Thing"
        );
    }

    /// `impl Trait for Type` yields a `trait_impl` edge, and its methods yield `overrides`
    /// edges.
    #[test]
    fn captures_trait_impl_relationship() {
        let tmp = tempdir().unwrap();
        let file_path = tmp.path().join("impl.rs");
        let source = r#"
            trait Greeter {
                fn greet(&self);
            }
            struct Person;
            impl Greeter for Person {
                fn greet(&self) {}
            }
        "#;
        fs::write(&file_path, source).unwrap();

        let (symbols, edges, _refs, _deps, _imports) = index_file(&file_path, source).unwrap();
        let symbol_named = |wanted: &str| symbols.iter().find(|s| s.name == wanted);
        let person = symbol_named("Person").unwrap();
        let greeter = symbol_named("Greeter").unwrap();

        assert!(symbol_named("greet").is_some());
        let path_str = file_path.to_string_lossy();
        assert!(person.id.starts_with(path_str.as_ref()));
        assert!(greeter.id.starts_with(path_str.as_ref()));

        let has_edge = |kind: &str| edges.iter().any(|e| e.kind == kind);
        assert!(has_edge("trait_impl"), "expected trait_impl edge");
        assert!(
            has_edge("overrides"),
            "expected method overrides edges for trait methods"
        );
    }
}
763}