// car-ast 0.15.0
//
// Tree-sitter AST parsing for code-aware inference
// Documentation
use super::{extract_doc_comment, node_text};
use crate::types::*;

/// Collect the symbols and imports declared at the top level of a parsed Zig file.
///
/// Only the direct children of the root node are inspected:
/// - `function_declaration` nodes become function symbols,
/// - `variable_declaration` nodes become either an import or a
///   struct/enum/const symbol,
/// - `test_declaration` nodes become function symbols named after the test.
///
/// Every other node kind is ignored. Results are returned in source order as
/// `(symbols, imports)`.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    let mut symbols: Vec<Symbol> = Vec::new();
    let mut imports: Vec<Import> = Vec::new();

    let root = tree.root_node();
    let mut walker = root.walk();
    for decl in root.children(&mut walker) {
        let symbol = match decl.kind() {
            "function_declaration" => extract_function(&decl, source),
            "test_declaration" => extract_test(&decl, source),
            "variable_declaration" => {
                // Variable declarations push into whichever list applies themselves.
                extract_var_decl(&decl, source, &mut symbols, &mut imports);
                None
            }
            _ => None,
        };
        // `Option` is iterable, so this appends zero or one symbol.
        symbols.extend(symbol);
    }

    (symbols, imports)
}

/// Build a function symbol from a `function_declaration` node.
///
/// The name is the text of the node's first non-empty direct `identifier`
/// child; `None` is returned when no such child exists. The signature is the
/// declaration text up to (not including) the body block.
fn extract_function(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    // Expected children: optional `pub`, `fn`, identifier, parameters, return type, block.
    find_child_text(node, "identifier", source).map(|name| Symbol {
        name,
        kind: SymbolKind::Function,
        span: Span::from_node(node),
        signature: extract_zig_signature(node, source),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}

/// Classify one `variable_declaration` node and record it.
///
/// - A declaration with a direct `builtin_function` child whose text contains
///   `@import` is pushed onto `imports`, with the declared name as the alias.
///   If the quoted path cannot be extracted, the full declaration text is
///   stored as the path instead.
/// - Anything else is pushed onto `symbols` as a `Struct`, `Enum`, or `Const`,
///   depending on which declaration child (if any) is present. Its signature
///   is the first line of the declaration text.
///
/// Declarations without a non-empty `identifier` child are skipped.
fn extract_var_decl(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    imports: &mut Vec<Import>,
) {
    let Some(name) = find_child_text(node, "identifier", source) else {
        return;
    };

    let decl_text = node_text(node, source);

    // Import case: the initializer is a builtin call and the text mentions `@import`.
    if has_child_kind(node, "builtin_function") && decl_text.contains("@import") {
        let path = match extract_import_path(decl_text) {
            Some(found) => found,
            None => decl_text.to_string(),
        };
        imports.push(Import {
            path,
            alias: Some(name),
            span: Span::from_node(node),
        });
        return;
    }

    // Symbol case: the kind follows the type of the value node.
    let kind = if has_child_kind(node, "struct_declaration") {
        SymbolKind::Struct
    } else if has_child_kind(node, "enum_declaration") {
        SymbolKind::Enum
    } else {
        SymbolKind::Const
    };

    let symbol = Symbol {
        name,
        kind,
        span: Span::from_node(node),
        signature: first_line(decl_text).to_owned(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    };
    symbols.push(symbol);
}

fn extract_test(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    // test_declaration children: `test`, string (the name), block
    let name = find_child_string_content(node, source).unwrap_or_else(|| "test".to_string());

    let text = node_text(node, source);
    Some(Symbol {
        name,
        kind: SymbolKind::Function,
        span: Span::from_node(node),
        signature: first_line(text).to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}

/// Return the text of the first direct child of `node` whose kind equals `kind`
/// and whose text is non-empty.
///
/// Matching children with empty text are skipped; `None` means no child
/// qualified.
fn find_child_text(node: &tree_sitter::Node, kind: &str, source: &[u8]) -> Option<String> {
    let mut walker = node.walk();
    // Bound to a local (rather than returned as a tail expression) so the
    // child iterator, which holds `&mut walker`, is dropped before `walker`.
    let first_match = node.children(&mut walker).find_map(|child| {
        if child.kind() != kind {
            return None;
        }
        let text = node_text(&child, source);
        if text.is_empty() {
            None
        } else {
            Some(text.to_string())
        }
    });
    first_match
}

/// Report whether any direct child of `node` has the given kind.
///
/// Only immediate children are examined; deeper descendants are not searched.
fn has_child_kind(node: &tree_sitter::Node, kind: &str) -> bool {
    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        if child.kind() == kind {
            return true;
        }
    }
    false
}

/// Return the unquoted text of the first `string` child of `node` that has a
/// `string_content` part (used for test names).
///
/// `string` children without a `string_content` part (for example an empty
/// literal) are skipped. `None` means no string child yielded any content.
fn find_child_string_content(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    let mut outer = node.walk();
    let string_nodes = node.children(&mut outer).filter(|c| c.kind() == "string");

    for string_node in string_nodes {
        // A string node is made of `"`, string_content, `"`.
        let mut inner = string_node.walk();
        let content = string_node
            .children(&mut inner)
            .find(|part| part.kind() == "string_content");
        if let Some(part) = content {
            return Some(node_text(&part, source).to_string());
        }
    }
    None
}

/// Extract the quoted path from an `@import("path")` call found in `text`.
///
/// Whitespace (including newlines) is allowed between `@import` and `(`, and
/// between `(` and the opening quote, so a call formatted across several
/// lines, such as
///
/// ```text
/// @import(
///     "path.zig",
/// )
/// ```
///
/// is recognised as well as the compact form. Occurrences of `@import` that
/// are not followed by a string literal are skipped and the search continues
/// with the next occurrence.
///
/// Returns `None` when no occurrence is followed by `(` and a terminated
/// string literal. Escape sequences inside the literal are not interpreted:
/// the path ends at the first `"` after the opening quote.
fn extract_import_path(text: &str) -> Option<String> {
    text.match_indices("@import").find_map(|(idx, keyword)| {
        let after_keyword = text[idx + keyword.len()..].trim_start();
        let after_paren = after_keyword.strip_prefix('(')?.trim_start();
        let literal = after_paren.strip_prefix('"')?;
        let end = literal.find('"')?;
        Some(literal[..end].to_string())
    })
}

/// Produce a declaration's signature: the source text from the start of `node`
/// up to (not including) its first direct `block` child, trimmed.
///
/// If that byte range is not valid UTF-8 the signature is the empty string.
/// When the node has no `block` child, the full node text is returned.
fn extract_zig_signature(node: &tree_sitter::Node, source: &[u8]) -> String {
    let mut walker = node.walk();
    let body = node.children(&mut walker).find(|c| c.kind() == "block");

    match body {
        Some(block) => {
            let header = &source[node.start_byte()..block.start_byte()];
            std::str::from_utf8(header).unwrap_or("").trim().to_string()
        }
        None => node_text(node, source).to_string(),
    }
}

/// Return the first line of `text`, without its line terminator.
///
/// When `text` contains no lines at all (the empty string), `text` itself is
/// returned.
fn first_line(text: &str) -> &str {
    match text.lines().next() {
        Some(head) => head,
        None => text,
    }
}

#[cfg(test)]
mod tests {
    use crate::{parse, Language, SymbolKind};

    /// Parses a small Zig file and checks that imports, functions, a struct,
    /// an enum, a const and a named test are all extracted, and that function
    /// signatures exclude the body.
    #[test]
    fn test_zig_extract() {
        let source = r#"const std = @import("std");
const os = @import("os");

const Point = struct {
    x: f32,
    y: f32,
};

const Color = enum {
    red,
    green,
    blue,
};

const MAX_SIZE: usize = 1024;

pub fn add(a: i32, b: i32) i32 {
    return a + b;
}

fn helper() void {
    std.debug.print("hello\n", .{});
}

test "addition works" {
    const result = add(2, 3);
    try std.testing.expectEqual(@as(i32, 5), result);
}
"#;
        let parsed = parse(source, Language::Zig).expect("Zig parse should succeed");

        // Imports: both `@import` declarations, and nothing else.
        let import_paths: Vec<&str> = parsed.imports.iter().map(|i| i.path.as_str()).collect();
        for expected in ["std", "os"] {
            assert!(
                import_paths.contains(&expected),
                "Should find {} import, got: {:?}",
                expected,
                import_paths
            );
        }
        assert_eq!(parsed.imports.len(), 2);

        // Symbols: flatten to (name, kind) pairs for lookup and for failure messages.
        let found: Vec<(&str, SymbolKind)> = parsed
            .symbols
            .iter()
            .map(|s| (s.name.as_str(), s.kind))
            .collect();
        let has = |name: &str, kind: SymbolKind| found.iter().any(|(n, k)| *n == name && *k == kind);

        // Functions
        assert!(
            has("add", SymbolKind::Function),
            "Should find 'add' function, got: {:?}",
            found
        );
        assert!(
            has("helper", SymbolKind::Function),
            "Should find 'helper' function, got: {:?}",
            found
        );

        // Struct
        assert!(
            has("Point", SymbolKind::Struct),
            "Should find 'Point' struct, got: {:?}",
            found
        );

        // Enum
        assert!(
            has("Color", SymbolKind::Enum),
            "Should find 'Color' enum, got: {:?}",
            found
        );

        // Const
        assert!(
            has("MAX_SIZE", SymbolKind::Const),
            "Should find 'MAX_SIZE' const, got: {:?}",
            found
        );

        // Test declarations are reported as functions named after the test string.
        assert!(
            has("addition works", SymbolKind::Function),
            "Should find test 'addition works', got: {:?}",
            found
        );

        // The signature of `add` must stop before the body.
        let add_signature = &parsed
            .symbols
            .iter()
            .find(|s| s.name == "add")
            .unwrap()
            .signature;
        assert!(
            add_signature.contains("fn add"),
            "Signature should contain 'fn add': {}",
            add_signature
        );
        assert!(
            !add_signature.contains("return"),
            "Signature should not contain body: {}",
            add_signature
        );
    }
}