car-ast 0.7.0

Tree-sitter AST parsing for code-aware inference
Documentation
//! Tree-sitter AST parsing for code-aware inference.
//!
//! Provides structured code understanding: parse source files into symbols
//! (functions, structs, classes, imports) with signatures, spans, and
//! doc comments. Used by car-reason for smart context assembly and by
//! car-inference for accurate code detection.

pub mod types;
pub mod languages;
pub mod index;
pub mod global_index;

pub use types::*;
pub use index::ProjectIndex;
pub use global_index::{GlobalHit, GlobalIndex, GlobalReference, ImplementationHit, RepoId, RepoProvenance};

use languages::{extract_symbols, ts_language};

/// Parse `source` as `lang` and collect its symbols and imports.
///
/// Yields `None` when `ts_language` has no grammar for `lang`, when the
/// tree-sitter parser rejects that grammar, or when the parser does not
/// produce a tree.
pub fn parse(source: &str, lang: Language) -> Option<ParsedFile> {
    let grammar = ts_language(lang)?;
    let bytes = source.as_bytes();

    let mut ts_parser = tree_sitter::Parser::new();
    if ts_parser.set_language(&grammar).is_err() {
        return None;
    }

    ts_parser.parse(bytes, None).map(|tree| {
        let (symbols, imports) = extract_symbols(lang, &tree, bytes);
        ParsedFile {
            language: lang,
            symbols,
            imports,
        }
    })
}

/// Parse `source`, choosing the language via `Language::from_filename`.
///
/// Yields `None` when no language is detected for `filename` or when
/// [`parse`] itself yields `None`.
pub fn parse_file(source: &str, filename: &str) -> Option<ParsedFile> {
    Language::from_filename(filename).and_then(|detected| parse(source, detected))
}

/// Return the text of `source` covered by `symbol`'s byte span.
///
/// An empty string is returned when the span cannot be sliced out of
/// `source`: it runs past the end, its start is after its end, or either
/// offset is not on a UTF-8 character boundary. This can happen when
/// `symbol` was produced from a different revision of the text, so it is
/// treated as "no source available" rather than a panic.
pub fn extract_source(symbol: &Symbol, source: &str) -> String {
    // `str::get` performs the bounds, ordering, and char-boundary checks
    // that direct range indexing would turn into a panic.
    source
        .get(symbol.span.start_byte..symbol.span.end_byte)
        .map(str::to_owned)
        .unwrap_or_default()
}

/// Build a [`SymbolContext`] for `symbol`: its own source text plus every
/// whole-word textual occurrence of its name elsewhere in `source`.
///
/// References are found by plain substring search, not via the AST, so
/// occurrences inside comments or string literals are reported too.
/// An occurrence is skipped when it starts inside the symbol's own byte span
/// or when the byte directly before or after it is an ASCII identifier
/// character (see `is_ident_char`).
///
/// In the returned spans, lines are zero-based and columns are byte offsets
/// from the start of the line. `_parsed` is accepted for interface
/// compatibility and is not consulted. A symbol with an empty name yields
/// no references.
pub fn extract_context(symbol: &Symbol, _parsed: &ParsedFile, source: &str) -> SymbolContext {
    let sym_source = extract_source(symbol, source);
    let source_bytes = source.as_bytes();
    let name = symbol.name.as_str();
    let own_span = symbol.span.start_byte..symbol.span.end_byte;

    let mut references = Vec::new();

    // `step` is the UTF-8 length of the name's first character. Advancing the
    // scan by that amount after a match keeps `search_from` on a character
    // boundary (advancing by one byte would panic for names starting with a
    // multi-byte character) while still visiting overlapping matches.
    // An empty name has no first character, so the scan is skipped entirely.
    if let Some(step) = name.chars().next().map(char::len_utf8) {
        let mut search_from = 0;
        while let Some(pos) = source[search_from..].find(name) {
            let abs_pos = search_from + pos;
            let end_pos = abs_pos + name.len();
            search_from = abs_pos + step;

            // Skip the symbol's own definition.
            if own_span.contains(&abs_pos) {
                continue;
            }

            // Require a whole-word match; the edges of the file count as
            // word boundaries.
            let before = abs_pos.checked_sub(1).map_or(b' ', |i| source_bytes[i]);
            let after = source_bytes.get(end_pos).copied().unwrap_or(b' ');
            if is_ident_char(before) || is_ident_char(after) {
                continue;
            }

            let (start_line, start_col) = line_col_at(source, abs_pos);
            let (end_line, end_col) = line_col_at(source, end_pos);
            references.push(Span {
                start_byte: abs_pos,
                end_byte: end_pos,
                start_line,
                start_col,
                end_line,
                end_col,
            });
        }
    }

    SymbolContext {
        symbol: symbol.clone(),
        source: sym_source,
        references_in_file: references,
    }
}

/// Zero-based line and byte column of `byte_pos` within `source`.
///
/// `byte_pos` must lie on a character boundary of `source`.
fn line_col_at(source: &str, byte_pos: usize) -> (u32, u32) {
    let prefix = &source[..byte_pos];
    let line = prefix.matches('\n').count();
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    // Span stores positions as u32; files beyond that size are not expected.
    (line as u32, (byte_pos - line_start) as u32)
}

/// Concatenate the source text of `symbols` in file order, placing a
/// `// ...` elision marker (surrounded by blank lines) between consecutive
/// snippets.
///
/// An empty `symbols` slice produces an empty string.
pub fn extract_symbols_source(symbols: &[&Symbol], source: &str) -> String {
    let mut ordered: Vec<(usize, String)> = Vec::with_capacity(symbols.len());
    for sym in symbols {
        ordered.push((sym.span.start_byte, extract_source(sym, source)));
    }

    // Stable sort: snippets follow their position in the file, ties keep
    // the caller's order.
    ordered.sort_by_key(|entry| entry.0);

    let snippets: Vec<&str> = ordered.iter().map(|(_, text)| text.as_str()).collect();
    snippets.join("\n\n// ...\n\n")
}

/// Diff two parsed files to find structural changes.
///
/// Symbols are matched by `(name, kind)`. For each key present in `old`:
///
/// * if the key is also in `new` and the signature or the byte span
///   (start or end) differs, a `SymbolChange::Modified` is emitted, with
///   `signature_changed` telling whether the signature itself differs;
/// * if the key is missing from `new`, a `SymbolChange::Removed` is emitted.
///
/// Keys present only in `new` produce `SymbolChange::Added`.
///
/// The result is ordered deterministically: removed/modified entries follow
/// the order of `old.all_symbols()`, then added entries follow the order of
/// `new.all_symbols()`. When several symbols in one file share a key, a
/// single entry is emitted for that key, using the last such symbol.
pub fn diff_symbols(old: &ParsedFile, new: &ParsedFile) -> Vec<SymbolChange> {
    let old_syms = old.all_symbols();
    let new_syms = new.all_symbols();

    // (name, kind) -> symbol lookup tables; later duplicates overwrite
    // earlier ones.
    let old_map: std::collections::HashMap<(&str, SymbolKind), &Symbol> = old_syms
        .iter()
        .map(|s| ((s.name.as_str(), s.kind), *s))
        .collect();
    let new_map: std::collections::HashMap<(&str, SymbolKind), &Symbol> = new_syms
        .iter()
        .map(|s| ((s.name.as_str(), s.kind), *s))
        .collect();

    let mut changes = Vec::new();

    // Removed and modified. Walk the symbol list rather than the map so the
    // output order does not depend on hash iteration order; `seen_old`
    // ensures each key is reported once.
    let mut seen_old = std::collections::HashSet::with_capacity(old_map.len());
    for sym in old_syms.iter() {
        let key = (sym.name.as_str(), sym.kind);
        if !seen_old.insert(key) {
            continue;
        }
        let old_sym = old_map[&key];
        match new_map.get(&key).copied() {
            Some(new_sym) => {
                let signature_changed = old_sym.signature != new_sym.signature;
                let span_changed = old_sym.span.start_byte != new_sym.span.start_byte
                    || old_sym.span.end_byte != new_sym.span.end_byte;
                if signature_changed || span_changed {
                    changes.push(SymbolChange::Modified {
                        old: old_sym.clone(),
                        new: new_sym.clone(),
                        signature_changed,
                    });
                }
            }
            None => changes.push(SymbolChange::Removed(old_sym.clone())),
        }
    }

    // Added: keys that exist only in the new file.
    let mut seen_new = std::collections::HashSet::with_capacity(new_map.len());
    for sym in new_syms.iter() {
        let key = (sym.name.as_str(), sym.kind);
        if old_map.contains_key(&key) || !seen_new.insert(key) {
            continue;
        }
        changes.push(SymbolChange::Added(new_map[&key].clone()));
    }

    changes
}

/// True when `b` is an ASCII letter, an ASCII digit, or an underscore.
fn is_ident_char(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}

// Unit tests for the crate-level parsing, context-extraction, and diffing
// helpers. Each test feeds a small inline snippet through `parse` and checks
// the resulting symbols.
#[cfg(test)]
mod tests {
    use super::*;

    // Rust snippet: expects a struct, one impl with two methods (the first
    // named `new`), a free function, and exactly one import.
    #[test]
    fn test_parse_rust() {
        let source = r#"
use std::collections::HashMap;

/// A cool struct.
pub struct Foo {
    bar: String,
}

impl Foo {
    /// Create a new Foo.
    pub fn new(bar: String) -> Self {
        Self { bar }
    }

    pub fn bar(&self) -> &str {
        &self.bar
    }
}

fn helper() -> bool {
    true
}
"#;
        let parsed = parse(source, Language::Rust).unwrap();

        // Should have: Foo (struct), Foo (impl with 2 methods), helper (function)
        assert!(!parsed.symbols.is_empty());
        assert!(!parsed.imports.is_empty());

        // Find the struct
        let foo_structs = parsed.find_symbol("Foo");
        assert!(foo_structs.iter().any(|s| s.kind == SymbolKind::Struct));

        // Find the impl and its methods
        let foo_impls: Vec<_> = parsed.symbols.iter()
            .filter(|s| s.kind == SymbolKind::Impl && s.name == "Foo")
            .collect();
        assert_eq!(foo_impls.len(), 1);
        assert_eq!(foo_impls[0].children.len(), 2);
        assert_eq!(foo_impls[0].children[0].name, "new");
        assert_eq!(foo_impls[0].children[0].kind, SymbolKind::Method);

        // Find helper function
        let helpers = parsed.find_symbol("helper");
        assert_eq!(helpers.len(), 1);
        assert_eq!(helpers[0].kind, SymbolKind::Function);

        // Check import
        assert_eq!(parsed.imports.len(), 1);
        assert!(parsed.imports[0].path.contains("HashMap"));
    }

    // Python snippet: expects a class with two child symbols, a standalone
    // function, a module-level constant, and at least one import.
    #[test]
    fn test_parse_python() {
        let source = r#"
import os
from pathlib import Path

MAX_SIZE = 1024

class Parser:
    """A code parser."""

    def __init__(self, lang: str):
        self.lang = lang

    def parse(self, source: str) -> dict:
        return {}

def standalone():
    pass
"#;
        let parsed = parse(source, Language::Python).unwrap();

        assert!(!parsed.imports.is_empty());

        let classes = parsed.find_symbol("Parser");
        assert_eq!(classes.len(), 1);
        assert_eq!(classes[0].kind, SymbolKind::Class);
        assert_eq!(classes[0].children.len(), 2); // __init__ and parse

        let funcs = parsed.find_symbol("standalone");
        assert_eq!(funcs.len(), 1);

        let consts = parsed.find_symbol("MAX_SIZE");
        assert_eq!(consts.len(), 1);
        assert_eq!(consts[0].kind, SymbolKind::Const);
    }

    // `parse_file` should pick Rust for a `.rs` filename and return `None`
    // for an extension it does not recognise.
    #[test]
    fn test_parse_file_detection() {
        let source = "fn main() {}";
        let parsed = parse_file(source, "main.rs").unwrap();
        assert_eq!(parsed.language, Language::Rust);

        assert!(parse_file(source, "unknown.xyz").is_none());
    }

    // Diffing two revisions should report one modified, one removed, and one
    // added function. Assertions use `any`, so they do not depend on the
    // order of the returned changes.
    #[test]
    fn test_diff_symbols() {
        let old_source = r#"
fn foo() -> i32 { 1 }
fn bar() -> i32 { 2 }
"#;
        let new_source = r#"
fn foo() -> i64 { 1 }
fn baz() -> i32 { 3 }
"#;
        let old = parse(old_source, Language::Rust).unwrap();
        let new = parse(new_source, Language::Rust).unwrap();

        let changes = diff_symbols(&old, &new);

        // foo: modified (signature changed i32 -> i64)
        // bar: removed
        // baz: added
        assert!(changes.iter().any(|c| matches!(c, SymbolChange::Modified { old, .. } if old.name == "foo")));
        assert!(changes.iter().any(|c| matches!(c, SymbolChange::Removed(s) if s.name == "bar")));
        assert!(changes.iter().any(|c| matches!(c, SymbolChange::Added(s) if s.name == "baz")));
    }

    // The context for `helper` should contain its definition text and at
    // least one reference outside the definition (the call in `main`).
    #[test]
    fn test_extract_context() {
        let source = r#"
fn helper() -> bool { true }

fn main() {
    let x = helper();
    println!("{}", x);
}
"#;
        let parsed = parse(source, Language::Rust).unwrap();
        let helper = &parsed.find_symbol("helper")[0];
        let ctx = extract_context(helper, &parsed, source);

        assert!(!ctx.source.is_empty());
        assert!(ctx.source.contains("fn helper()"));
        // Should find the reference in main()
        assert!(!ctx.references_in_file.is_empty());
    }

    // Selecting `a` and `c` should yield both bodies joined by an elision
    // marker, with the unselected `b` left out.
    #[test]
    fn test_extract_symbols_source() {
        let source = r#"
fn a() { 1 }

fn b() { 2 }

fn c() { 3 }
"#;
        let parsed = parse(source, Language::Rust).unwrap();
        let a = &parsed.find_symbol("a")[0];
        let c = &parsed.find_symbol("c")[0];

        let compact = extract_symbols_source(&[a, c], source);
        assert!(compact.contains("fn a()"));
        assert!(compact.contains("fn c()"));
        assert!(compact.contains("// ..."));
        assert!(!compact.contains("fn b()"));
    }
}