car-ast 0.14.0

Tree-sitter AST parsing for code-aware inference
Documentation
use super::{extract_doc_comment, extract_signature, field_text, node_text};
use crate::types::*;

/// Extracts symbols and `#include` imports from a parsed C translation unit.
///
/// Scans the items under the root node and also descends into preprocessor
/// conditionals (`#if`, `#ifdef`/`#ifndef`, `#elif`, `#else`) and
/// `extern "..." { ... }` linkage blocks. Without that descent, every
/// declaration inside an include guard would be invisible, because the guard
/// wraps the whole header in a single `preproc_ifdef` node.
///
/// Returns `(symbols, imports)` in the order the items are encountered.
/// All branches of a conditional are visited, so a name defined in both the
/// `#if` and the `#else` arm is reported twice.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    let mut symbols = Vec::new();
    let mut imports = Vec::new();
    collect_items(&tree.root_node(), source, &mut symbols, &mut imports);
    (symbols, imports)
}

/// Appends the symbols and imports found among the direct children of `node`,
/// recursing into container nodes that merely wrap further top-level items.
fn collect_items(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    imports: &mut Vec<Import>,
) {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let found = match child.kind() {
            "function_definition" => extract_function(&child, source),
            "struct_specifier" => extract_struct(&child, source),
            "enum_specifier" => extract_enum(&child, source),
            "type_definition" => extract_typedef(&child, source),
            "preproc_def" => extract_define(&child, source),
            "declaration" => {
                // May yield several symbols (tagged type + prototype), so it
                // pushes directly instead of returning a single one.
                extract_declaration(&child, source, symbols);
                None
            }
            "preproc_include" => {
                // Fall back to the whole directive text when the grammar
                // exposes no `path` field for this include.
                let path =
                    field_text(&child, "path", source).unwrap_or_else(|| node_text(&child, source));
                imports.push(Import {
                    path: path.to_string(),
                    alias: None,
                    span: Span::from_node(&child),
                });
                None
            }
            // Containers whose children are themselves top-level items.
            "preproc_if" | "preproc_ifdef" | "preproc_elif" | "preproc_elifdef"
            | "preproc_else" | "linkage_specification" | "declaration_list" => {
                collect_items(&child, source, symbols, imports);
                None
            }
            _ => None,
        };
        symbols.extend(found);
    }
}

/// Builds a `Function` symbol from a `function_definition` node.
///
/// Returns `None` when the node has no `declarator` field or when no name
/// can be resolved from that declarator.
fn extract_function(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    node.child_by_field_name("declarator")
        .and_then(|decl| find_declarator_name(&decl, source))
        .map(|ident| Symbol {
            name: ident.to_string(),
            kind: SymbolKind::Function,
            span: Span::from_node(node),
            // Signature is derived relative to the node's "body" field.
            signature: extract_signature(node, "body", source),
            doc_comment: extract_doc_comment(node, source),
            parent: None,
            children: Vec::new(),
        })
}

/// Resolves the identifier text buried inside a (possibly nested) declarator.
///
/// Identifier-like nodes yield their own text. Any other node is searched
/// through its `declarator` field, with its `name` field as a fallback.
/// Returns `None` when neither route produces a name.
fn find_declarator_name<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> Option<&'a str> {
    let kind = node.kind();
    if matches!(kind, "identifier" | "type_identifier" | "field_identifier") {
        return Some(node_text(node, source));
    }

    // Fallback shared by both remaining cases: text of the `name` field, if any.
    let name_field = || {
        node.child_by_field_name("name")
            .map(|n| node_text(&n, source))
    };
    let nested = node.child_by_field_name("declarator");

    if matches!(
        kind,
        "function_declarator" | "pointer_declarator" | "parenthesized_declarator"
    ) {
        // For these wrappers a present inner declarator is authoritative: if it
        // yields no name, the `name` field is NOT consulted.
        return match nested {
            Some(inner) => find_declarator_name(&inner, source),
            None => name_field(),
        };
    }

    // Any other node kind: try the inner declarator, then fall back to `name`
    // even when the inner declarator exists but resolves to nothing.
    nested
        .and_then(|inner| find_declarator_name(&inner, source))
        .or_else(name_field)
}

/// Collects symbols from a `declaration` node and appends them to `symbols`.
///
/// Two things are harvested, in this order:
/// 1. named `struct_specifier` / `enum_specifier` children of the declaration;
/// 2. a function prototype, when the declaration's `declarator` field is a
///    `function_declarator`, possibly wrapped in one or more
///    `pointer_declarator` layers (e.g. `char *name(void);`).
///
/// Plain variable declarations produce no symbol of their own. Only the
/// declarator returned by the `declarator` field lookup is inspected.
fn extract_declaration(node: &tree_sitter::Node, source: &[u8], symbols: &mut Vec<Symbol>) {
    // A declaration can carry a struct/enum specifier as its type; record those first.
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let tagged = match child.kind() {
            "struct_specifier" => extract_struct(&child, source),
            "enum_specifier" => extract_enum(&child, source),
            _ => None,
        };
        symbols.extend(tagged);
    }

    let declarator = match node.child_by_field_name("declarator") {
        Some(d) => d,
        None => return,
    };

    // A prototype returning a pointer is shaped as
    // pointer_declarator -> ... -> function_declarator, so peel the pointer
    // layers before deciding whether this declares a function.
    let mut core = declarator;
    while core.kind() == "pointer_declarator" {
        match core.child_by_field_name("declarator") {
            Some(inner) => core = inner,
            None => break,
        }
    }
    if core.kind() != "function_declarator" {
        return;
    }

    if let Some(name) = find_declarator_name(&core, source) {
        symbols.push(Symbol {
            name: name.to_string(),
            kind: SymbolKind::Function,
            span: Span::from_node(node),
            // The prototype text itself, minus the terminating semicolon.
            signature: node_text(node, source)
                .trim_end_matches(';')
                .trim()
                .to_string(),
            doc_comment: extract_doc_comment(node, source),
            parent: None,
            children: Vec::new(),
        });
    }
}

/// Builds a `Struct` symbol from a `struct_specifier` node.
///
/// Returns `None` when the specifier has no `name` field or the name text is
/// empty. The signature is the full source text of the node.
fn extract_struct(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let tag = field_text(node, "name", source).filter(|text| !text.is_empty())?;
    let symbol = Symbol {
        name: tag.to_string(),
        kind: SymbolKind::Struct,
        span: Span::from_node(node),
        signature: node_text(node, source).to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    };
    Some(symbol)
}

/// Builds an `Enum` symbol from an `enum_specifier` node.
///
/// Returns `None` when the specifier has no `name` field or the name text is
/// empty. The signature is the full source text of the node.
fn extract_enum(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    match field_text(node, "name", source) {
        Some(tag) if !tag.is_empty() => Some(Symbol {
            name: tag.to_string(),
            kind: SymbolKind::Enum,
            span: Span::from_node(node),
            signature: node_text(node, source).to_string(),
            doc_comment: extract_doc_comment(node, source),
            parent: None,
            children: Vec::new(),
        }),
        _ => None,
    }
}

/// Builds a `TypeAlias` symbol from a `type_definition` node.
///
/// The alias name is resolved from the node's `declarator` field via
/// `find_declarator_name`. Returns `None` when that field is missing or no
/// name can be resolved. The signature is the full source text of the typedef.
fn extract_typedef(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let alias = match node.child_by_field_name("declarator") {
        Some(decl) => find_declarator_name(&decl, source)?,
        None => return None,
    };
    let symbol = Symbol {
        name: alias.to_string(),
        kind: SymbolKind::TypeAlias,
        span: Span::from_node(node),
        signature: node_text(node, source).to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    };
    Some(symbol)
}

/// Builds a `Const` symbol from a `preproc_def` node (an object-like `#define`).
///
/// Returns `None` when the node has no `name` field, or when it has a
/// `preproc_params` child, which marks a macro that takes parameters.
/// The signature is the full source text of the directive.
fn extract_define(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let macro_name = field_text(node, "name", source)?;
    // Bail out as well if the `name` field node itself cannot be found.
    let _ = node.child_by_field_name("name")?;

    // Macros with a parameter list are not simple constants; skip them.
    let mut cursor = node.walk();
    let takes_params = node
        .children(&mut cursor)
        .any(|part| part.kind() == "preproc_params");
    if takes_params {
        return None;
    }

    Some(Symbol {
        name: macro_name.to_string(),
        kind: SymbolKind::Const,
        span: Span::from_node(node),
        signature: node_text(node, source).to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}

#[cfg(all(test, feature = "c"))]
mod tests {
    use super::*;

    /// Parses `source` with the tree-sitter C grammar and runs `extract` on it.
    fn parse_c(source: &str) -> (Vec<Symbol>, Vec<Import>) {
        let language: tree_sitter::Language = tree_sitter_c::LANGUAGE.into();
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&language)
            .expect("failed to set c language");
        let tree = parser.parse(source, None).expect("failed to parse");
        extract(&tree, source.as_bytes())
    }

    /// Names of all symbols of the given kind, in extraction order.
    fn names_of(symbols: &[Symbol], kind: SymbolKind) -> Vec<&str> {
        symbols
            .iter()
            .filter(|s| s.kind == kind)
            .map(|s| s.name.as_str())
            .collect()
    }

    #[test]
    fn test_c_extraction() {
        let src = r#"
#include <stdio.h>
#include "myheader.h"

#define MAX_SIZE 100
#define SQUARE(x) ((x)*(x))

typedef unsigned long ulong;

struct Point {
    int x;
    int y;
};

enum Color { RED, GREEN, BLUE };

int add(int a, int b) {
    return a + b;
}

void greet(const char* name) {
    printf("Hello, %s\n", name);
}
"#;
        let (symbols, imports) = parse_c(src);

        // Both the system and the local include are reported, in order.
        assert_eq!(imports.len(), 2);
        assert!(imports[0].path.contains("stdio.h"));
        assert!(imports[1].path.contains("myheader.h"));

        // Only the object-like macro counts as a constant; SQUARE takes params.
        assert_eq!(names_of(&symbols, SymbolKind::Const), ["MAX_SIZE"]);

        // Exactly one typedef, one struct and one enum.
        assert_eq!(names_of(&symbols, SymbolKind::TypeAlias), ["ulong"]);
        assert_eq!(names_of(&symbols, SymbolKind::Struct), ["Point"]);
        assert_eq!(names_of(&symbols, SymbolKind::Enum), ["Color"]);

        // At least the two defined functions must be present.
        let func_names = names_of(&symbols, SymbolKind::Function);
        assert!(func_names.len() >= 2);
        assert!(func_names.contains(&"add"));
        assert!(func_names.contains(&"greet"));
    }
}