//! car-ast 0.8.0
//!
//! Tree-sitter AST parsing for code-aware inference.
//!
//! Documentation: Kotlin symbol and import extraction.
use crate::types::*;
use super::{node_text, field_text, extract_doc_comment};

/// Walk the top-level nodes of a parsed Kotlin file.
///
/// Every direct child of the tree's root node is handed to `extract_node`
/// with no enclosing parent. Returns the collected symbols and imports, in
/// the source order of the root's children.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    let (mut symbols, mut imports) = (Vec::new(), Vec::new());

    let root = tree.root_node();
    let mut walker = root.walk();
    root.children(&mut walker).for_each(|top_level| {
        extract_node(&top_level, source, &mut symbols, &mut imports, None);
    });

    (symbols, imports)
}

/// Dispatch one syntax node to the matching extractor.
///
/// Declarations (functions, classes/interfaces/enums, objects, properties and
/// the package header) are appended to `symbols`; `import` directives are
/// appended to `imports`. Any other node kind is ignored. This function does
/// not recurse on its own: class bodies are walked by `extract_class`.
///
/// `parent_name` is the name of the enclosing class/object when the node is a
/// class-body member, and `None` for top-level nodes.
fn extract_node(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    imports: &mut Vec<Import>,
    parent_name: Option<&str>,
) {
    match node.kind() {
        "function_declaration" => {
            symbols.extend(extract_function(node, source, parent_name));
        }
        "class_declaration" => {
            // Kotlin grammar uses class_declaration for class, interface, and enum.
            let kind = detect_class_kind(node);
            symbols.extend(extract_class(node, source, kind, parent_name));
        }
        "object_declaration" => {
            symbols.extend(extract_class(node, source, SymbolKind::Class, parent_name));
        }
        "property_declaration" => {
            symbols.extend(extract_property(node, source, parent_name));
        }
        "import" => {
            // `import a.b.C as D` is recorded as path `a.b.C` with alias `D`,
            // rather than as the literal path "a.b.C as D".
            let spec = strip_directive_keyword(node_text(node, source), "import");
            let (path, alias) = split_import_alias(spec);
            imports.push(Import {
                path: path.to_string(),
                alias: alias.map(|a| a.to_string()),
                span: Span::from_node(node),
            });
        }
        "package_header" => {
            let text = node_text(node, source);
            let path = strip_directive_keyword(text, "package");
            if !path.is_empty() {
                symbols.push(Symbol {
                    name: path.to_string(),
                    kind: SymbolKind::Module,
                    span: Span::from_node(node),
                    signature: text.trim().to_string(),
                    doc_comment: None,
                    parent: None,
                    children: Vec::new(),
                });
            }
        }
        _ => {}
    }
}

/// Remove a single leading `keyword` (e.g. `import` / `package`) and an
/// optional trailing `;` from a directive's text, returning the trimmed rest.
///
/// Text that does not start with `keyword` is only trimmed.
fn strip_directive_keyword<'a>(text: &'a str, keyword: &str) -> &'a str {
    let text = text.trim();
    let rest = text.strip_prefix(keyword).unwrap_or(text);
    rest.trim().trim_end_matches(';').trim()
}

/// Split an import spec such as `a.b.C as D` into `("a.b.C", Some("D"))`.
///
/// The alias is the last whitespace-separated word when it is preceded by a
/// standalone `as` word; otherwise the whole trimmed spec is the path and the
/// alias is `None`.
fn split_import_alias(spec: &str) -> (&str, Option<&str>) {
    let spec = spec.trim();
    if let Some((head, alias)) = spec.rsplit_once(char::is_whitespace) {
        if let Some(path) = head.trim_end().strip_suffix("as") {
            // Require whitespace before `as` so a path ending in e.g. `.bas`
            // is not mistaken for an alias clause.
            if path.ends_with(char::is_whitespace) && !alias.is_empty() {
                return (path.trim_end(), Some(alias));
            }
        }
    }
    (spec, None)
}

/// Classify a `class_declaration` as a class, interface, or enum class.
///
/// The Kotlin grammar uses `class_declaration` for all three.
/// - An `interface` keyword child yields `Interface`.
/// - An `enum` keyword yields `Enum`. It may be a direct child or nested
///   inside the `modifiers` child (e.g. `modifiers > class_modifier > enum`).
/// - An `enum_class_body` child also yields `Enum`.
/// - Anything else is a `Class`.
///
/// Unlike stopping at the `class` keyword, this scans every direct child, so
/// an `enum` that appears only under `modifiers` is not missed.
fn detect_class_kind(node: &tree_sitter::Node) -> SymbolKind {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "interface" => return SymbolKind::Interface,
            "enum" | "enum_class_body" => return SymbolKind::Enum,
            "modifiers" if contains_enum_keyword(&child) => return SymbolKind::Enum,
            _ => {}
        }
    }
    SymbolKind::Class
}

/// Return `true` if an `enum` keyword token is a child or grandchild of the
/// given `modifiers` node.
fn contains_enum_keyword(modifiers: &tree_sitter::Node) -> bool {
    let mut cursor = modifiers.walk();
    for modifier in modifiers.children(&mut cursor) {
        if modifier.kind() == "enum" {
            return true;
        }
        let mut inner = modifier.walk();
        for token in modifier.children(&mut inner) {
            if token.kind() == "enum" {
                return true;
            }
        }
    }
    false
}

/// Return the text of the first direct child of `node` whose kind is
/// `identifier`, or `None` when there is no such child.
///
/// Used as a fallback when a declaration has no `name` field.
fn first_identifier<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> Option<&'a str> {
    let mut walker = node.walk();
    // Bind the result first so the iterator's borrow of `walker` ends here.
    let ident = node
        .children(&mut walker)
        .find(|child| child.kind() == "identifier");
    ident.map(|child| node_text(&child, source))
}

/// Build a [`Symbol`] for a `function_declaration` node.
///
/// The name comes from the `name` field, falling back to the first
/// `identifier` child; returns `None` if neither exists. A declaration with a
/// `parent_name` is reported as a `Method`, otherwise as a `Function`.
///
/// The signature is the trimmed source text preceding the body. The body is
/// the `body` field, or else the first `function_body` child. A declaration
/// without a body uses its full node text, untrimmed, as the signature.
fn extract_function(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let name = field_text(node, "name", source).or_else(|| first_identifier(node, source))?;

    let kind = match parent_name {
        Some(_) => SymbolKind::Method,
        None => SymbolKind::Function,
    };

    let mut walker = node.walk();
    let body = node
        .child_by_field_name("body")
        .or_else(|| node.children(&mut walker).find(|child| child.kind() == "function_body"));

    let signature = match body {
        Some(body) => {
            let head = &source[node.start_byte()..body.start_byte()];
            std::str::from_utf8(head).unwrap_or("").trim().to_string()
        }
        None => node_text(node, source).to_string(),
    };

    Some(Symbol {
        name: name.to_string(),
        kind,
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(|s| s.to_string()),
        children: Vec::new(),
    })
}

/// Build a [`Symbol`] for a class-like declaration and collect its members.
///
/// Called for `class_declaration` and `object_declaration` nodes; `kind` is
/// supplied by the caller. The name comes from the `name` field, falling back
/// to the first `identifier` child; returns `None` if neither exists.
///
/// The signature is the trimmed text before the first `class_body` /
/// `enum_class_body` child, or the whole node text when there is no body.
/// Every body member is passed to `extract_node` with this declaration's name
/// as its parent. Any imports reported from the body are discarded.
fn extract_class(
    node: &tree_sitter::Node,
    source: &[u8],
    kind: SymbolKind,
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let name = field_text(node, "name", source).or_else(|| first_identifier(node, source))?;

    // Body nodes in source order; the first one bounds the signature.
    let mut walker = node.walk();
    let bodies: Vec<_> = node
        .children(&mut walker)
        .filter(|child| matches!(child.kind(), "class_body" | "enum_class_body"))
        .collect();

    let signature = match bodies.first() {
        Some(body) => {
            let head = &source[node.start_byte()..body.start_byte()];
            std::str::from_utf8(head).unwrap_or("").trim().to_string()
        }
        None => node_text(node, source).to_string(),
    };

    let mut children = Vec::new();
    // Sink required by `extract_node`'s signature; its contents are not used.
    let mut ignored_imports = Vec::new();
    for body in &bodies {
        let mut inner = body.walk();
        for member in body.children(&mut inner) {
            extract_node(&member, source, &mut children, &mut ignored_imports, Some(name));
        }
    }

    Some(Symbol {
        name: name.to_string(),
        kind,
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(|s| s.to_string()),
        children,
    })
}

/// Build a [`Symbol`] for a `property_declaration` node, reported as `Const`.
///
/// The name comes from the `name` field. Otherwise it is the first
/// `identifier` found either as a direct child or directly inside a
/// `variable_declaration` child (kotlin-ng shape:
/// `val/var > variable_declaration > identifier`).
///
/// Returns `None` when no name is found or the name is empty. The signature
/// is the first line of the node's text, trimmed.
fn extract_property(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let name = match field_text(node, "name", source) {
        Some(text) => text,
        None => {
            let mut walker = node.walk();
            let found = node.children(&mut walker).find_map(|child| match child.kind() {
                "identifier" => Some(node_text(&child, source)),
                "variable_declaration" => {
                    let mut inner = child.walk();
                    let ident = child
                        .children(&mut inner)
                        .find(|grandchild| grandchild.kind() == "identifier");
                    ident.map(|grandchild| node_text(&grandchild, source))
                }
                _ => None,
            });
            found?
        }
    };

    if name.is_empty() {
        return None;
    }

    let first_line = node_text(node, source).lines().next().unwrap_or("");

    Some(Symbol {
        name: name.to_string(),
        kind: SymbolKind::Const,
        span: Span::from_node(node),
        signature: first_line.trim().to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(|s| s.to_string()),
        children: Vec::new(),
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` with the kotlin-ng grammar and run [`extract`] on the tree.
    fn parse_kotlin(source: &str) -> (Vec<Symbol>, Vec<Import>) {
        let language: tree_sitter::Language = tree_sitter_kotlin_ng::LANGUAGE.into();
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&language).unwrap();
        let tree = parser.parse(source, None).unwrap();
        extract(&tree, source.as_bytes())
    }

    /// Find the symbol called `name`, panicking if it is absent.
    fn by_name<'a>(symbols: &'a [Symbol], name: &str) -> &'a Symbol {
        symbols.iter().find(|s| s.name == name).unwrap()
    }

    #[test]
    fn test_kotlin_class_with_methods() {
        let source = r#"
package com.example.app

import kotlin.collections.List
import kotlin.math.sqrt

class Calculator(val initial: Int) {
    val name: String = "calc"

    fun add(x: Int): Int {
        return initial + x
    }

    fun multiply(x: Int, y: Int): Int {
        return x * y
    }
}
"#;
        let (symbols, imports) = parse_kotlin(source);

        // Imports keep their declaration order.
        let import_paths: Vec<&str> = imports.iter().map(|i| i.path.as_str()).collect();
        assert_eq!(import_paths, ["kotlin.collections.List", "kotlin.math.sqrt"]);

        // The package header becomes a Module symbol.
        let pkg = symbols.iter().find(|s| s.kind == SymbolKind::Module).unwrap();
        assert_eq!(pkg.name, "com.example.app");

        let calc = by_name(&symbols, "Calculator");
        assert_eq!(calc.kind, SymbolKind::Class);

        let child_names: Vec<&str> = calc.children.iter().map(|s| s.name.as_str()).collect();
        assert!(child_names.contains(&"add"), "expected 'add' in {:?}", child_names);
        assert!(child_names.contains(&"multiply"), "expected 'multiply' in {:?}", child_names);
    }

    #[test]
    fn test_kotlin_interface_and_object() {
        let source = r#"
interface Drawable {
    fun draw()
}

object Singleton {
    val instance: String = "single"

    fun doWork(): Boolean {
        return true
    }
}
"#;
        let (symbols, _imports) = parse_kotlin(source);

        assert_eq!(by_name(&symbols, "Drawable").kind, SymbolKind::Interface);

        // `object` declarations are reported as classes.
        let singleton = by_name(&symbols, "Singleton");
        assert_eq!(singleton.kind, SymbolKind::Class);

        let method_names: Vec<&str> = singleton
            .children
            .iter()
            .filter(|s| s.kind == SymbolKind::Method)
            .map(|s| s.name.as_str())
            .collect();
        assert!(method_names.contains(&"doWork"), "expected 'doWork' in {:?}", method_names);
    }

    #[test]
    fn test_kotlin_top_level_function_and_property() {
        let source = r#"
val PI: Double = 3.14159

fun greet(name: String): String {
    return "Hello, $name"
}
"#;
        let (symbols, _imports) = parse_kotlin(source);

        match symbols.iter().find(|s| s.name == "PI") {
            Some(pi) => assert_eq!(pi.kind, SymbolKind::Const),
            None => panic!(
                "expected PI in {:?}",
                symbols.iter().map(|s| &s.name).collect::<Vec<_>>()
            ),
        }

        assert_eq!(by_name(&symbols, "greet").kind, SymbolKind::Function);
    }
}