car-ast 0.8.0

Tree-sitter AST parsing for code-aware inference
Documentation
use crate::types::*;
use super::{node_text, extract_doc_comment};

/// Collect the top-level symbols and imports of a parsed R source file.
///
/// Only direct children of the root node are inspected. Each child is routed by
/// node kind: assignments (in either the `binary_operator` form or the older
/// dedicated `*_assignment` forms) become symbols or imports, bare calls may be
/// `library()`/`require()`/`source()` imports, and an unbound function literal
/// becomes an `<anonymous>` function symbol. Every other node kind is ignored.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    let mut symbols = Vec::new();
    let mut imports = Vec::new();

    let root = tree.root_node();
    let mut walker = root.walk();
    for node in root.children(&mut walker) {
        match node.kind() {
            // In tree-sitter-r, `x <- value` and `x = value` parse as binary_operator.
            "binary_operator" => extract_binary_op(&node, source, &mut symbols, &mut imports),
            // Some tree-sitter-r versions emit one node kind per assignment operator.
            "left_assignment" | "equals_assignment" | "super_assignment" | "right_assignment" => {
                extract_assignment(&node, source, &mut symbols, &mut imports)
            }
            // A statement-level call such as `library(x)` or `source("f.R")`.
            "call" => {
                extract_call_import(&node, source, &mut imports);
            }
            // A function literal that is not bound to a name.
            "function_definition" => symbols.extend(extract_bare_function(&node, source)),
            _ => {}
        }
    }

    (symbols, imports)
}

/// Handle a top-level `binary_operator` node, the form assignments take in
/// current tree-sitter-r grammars.
///
/// Only the assignment operators `<-`, `<<-`, `=`, `->` and `->>` are
/// considered; any other binary expression is ignored. The bound value is then
/// classified:
/// - a `function_definition` becomes a [`SymbolKind::Function`] symbol;
/// - a `library()` / `require()` / `source()` call is pushed onto `imports` and
///   produces no symbol;
/// - anything else becomes a [`SymbolKind::Const`] symbol whose signature is the
///   first source line of the assignment.
fn extract_binary_op(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    imports: &mut Vec<Import>,
) {
    // Layout: child(0) = lhs, child(1) = operator, child(2) = rhs.
    let Some(op_node) = node.child(1) else {
        return;
    };
    let op = node_text(&op_node, source);

    // Rightward assignments (`value -> name`) bind the right-hand operand. Their
    // signature operator is normalised to the leftward equivalent so that
    // function signatures always read `name <- function(...)`.
    let (name_idx, value_idx, sig_op) = match op {
        "<-" => (0, 2, "<-"),
        "<<-" => (0, 2, "<<-"),
        "=" => (0, 2, "="),
        "->" => (2, 0, "<-"),
        "->>" => (2, 0, "<<-"),
        _ => return,
    };

    let (Some(name_node), Some(value_node)) = (node.child(name_idx), node.child(value_idx)) else {
        return;
    };

    let name = node_text(&name_node, source);
    if name.is_empty() {
        return;
    }

    if value_node.kind() == "function_definition" {
        let sig = format!("{} {} function{}", name, sig_op, extract_params(&value_node, source));
        symbols.push(Symbol {
            name: name.to_string(),
            kind: SymbolKind::Function,
            span: Span::from_node(node),
            signature: sig,
            doc_comment: extract_doc_comment(node, source),
            parent: None,
            children: Vec::new(),
        });
        return;
    }

    // `x <- library(pkg)` / `x <- source("f.R")`: record the import in the
    // caller's list and emit no constant for the binding.
    if value_node.kind() == "call" && extract_call_import(&value_node, source, imports) {
        return;
    }

    // Otherwise it's a constant/variable assignment.
    symbols.push(Symbol {
        name: name.to_string(),
        kind: SymbolKind::Const,
        span: Span::from_node(node),
        signature: node_text(node, source).lines().next().unwrap_or("").to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    });
}

fn extract_assignment(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    _imports: &mut Vec<Import>,
) {
    let (name_node, value_node) = match node.kind() {
        "right_assignment" => {
            (node.child((node.child_count().saturating_sub(1)) as u32), node.child(0))
        }
        _ => {
            (node.child(0), node.child((node.child_count().saturating_sub(1)) as u32))
        }
    };

    let name_node = match name_node {
        Some(n) => n,
        None => return,
    };
    let value_node = match value_node {
        Some(v) => v,
        None => return,
    };

    let name = node_text(&name_node, source);
    if name.is_empty() {
        return;
    }

    if value_node.kind() == "function_definition" {
        let sig = format!("{} <- function{}", name, extract_params(&value_node, source));
        symbols.push(Symbol {
            name: name.to_string(),
            kind: SymbolKind::Function,
            span: Span::from_node(node),
            signature: sig,
            doc_comment: extract_doc_comment(node, source),
            parent: None,
            children: Vec::new(),
        });
    } else {
        symbols.push(Symbol {
            name: name.to_string(),
            kind: SymbolKind::Const,
            span: Span::from_node(node),
            signature: node_text(node, source).lines().next().unwrap_or("").to_string(),
            doc_comment: extract_doc_comment(node, source),
            parent: None,
            children: Vec::new(),
        });
    }
}

/// Return the parenthesised parameter list of a `function_definition` node,
/// e.g. `(x, y = c(1, 2))`.
///
/// Lookup order:
/// 1. the grammar's `parameters` field;
/// 2. a `formal_parameters` / `parameters` child;
/// 3. a text scan for the first balanced `(...)` group.
///
/// Returns `"()"` when none of these finds anything.
fn extract_params(node: &tree_sitter::Node, source: &[u8]) -> String {
    if let Some(params) = node.child_by_field_name("parameters") {
        return node_text(&params, source).to_string();
    }

    let mut cursor = node.walk();
    if let Some(params) = node
        .children(&mut cursor)
        .find(|c| matches!(c.kind(), "formal_parameters" | "parameters"))
    {
        return node_text(&params, source).to_string();
    }

    // Text fallback: pair parentheses by depth so that nested calls in default
    // values (`function(x = c(1, 2))`) are not cut off at the first `)`.
    balanced_parens(node_text(node, source))
        .unwrap_or("()")
        .to_string()
}

/// Return the first balanced parenthesised group in `text`, outer parentheses
/// included.
///
/// Returns `None` if `text` has no `(` or the group is never closed. Parentheses
/// inside `"..."`, `'...'` or backtick-quoted text (with backslash escapes) are
/// ignored.
fn balanced_parens(text: &str) -> Option<&str> {
    let start = text.find('(')?;
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut escaped = false;

    for (offset, ch) in text[start..].char_indices() {
        if let Some(q) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
            continue;
        }
        match ch {
            '"' | '\'' | '`' => quote = Some(ch),
            '(' => depth += 1,
            ')' => {
                // The scan starts on a `(`, so depth is >= 1 here.
                depth -= 1;
                if depth == 0 {
                    return Some(&text[start..=start + offset]);
                }
            }
            _ => {}
        }
    }
    None
}

/// Record a `library(pkg)`, `require(pkg)` or `source("file")` call as an import.
///
/// The callee is read from the call's `function` field, falling back to its
/// first child. The import path is the call's first argument as resolved by
/// `extract_first_arg`. Returns `true` only when an import was pushed onto
/// `imports`.
fn extract_call_import(node: &tree_sitter::Node, source: &[u8], imports: &mut Vec<Import>) -> bool {
    let Some(callee) = node
        .child_by_field_name("function")
        .or_else(|| node.child(0))
    else {
        return false;
    };

    if !matches!(node_text(&callee, source), "library" | "require" | "source") {
        return false;
    }

    match extract_first_arg(node, source) {
        Some(path) => {
            imports.push(Import {
                path,
                alias: None,
                span: Span::from_node(node),
            });
            true
        }
        None => false,
    }
}

/// Return the first non-empty argument of a call node as plain text.
///
/// For an `argument` wrapper, the grammar's `value` field is used, so both
/// `library(dplyr)` and `library(package = dplyr)` yield `dplyr`. Grammars
/// without that field fall back to the last child that is not `=` or a
/// parenthesis.
///
/// Quoting is removed from ordinary strings (`"x"`, `'x'`), backtick names and
/// R raw strings (`r"(x)"`). Returns `None` if the call has no `arguments` node
/// or every argument is empty.
fn extract_first_arg(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    // Find the arguments child.
    let args = node.child_by_field_name("arguments").or_else(|| {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if child.kind() == "arguments" {
                return Some(child);
            }
        }
        None
    })?;

    let mut cursor = args.walk();
    for child in args.children(&mut cursor) {
        // Skip punctuation. Newer grammars emit separators as named `comma` nodes.
        if matches!(child.kind(), "(" | ")" | "," | "comma") {
            continue;
        }
        let text = if child.kind() == "argument" {
            argument_value_text(&child, source)
        } else {
            node_text(&child, source).to_string()
        };
        let unquoted = unquote_r_string(&text);
        if !unquoted.is_empty() {
            return Some(unquoted.to_string());
        }
    }
    None
}

/// Text of an `argument` node's value, ignoring any `name =` prefix.
fn argument_value_text(arg: &tree_sitter::Node, source: &[u8]) -> String {
    if let Some(value) = arg.child_by_field_name("value") {
        return node_text(&value, source).to_string();
    }
    // Grammars without a `value` field: keep the last non-punctuation child.
    let mut cursor = arg.walk();
    let mut val = String::new();
    for part in arg.children(&mut cursor) {
        if !matches!(part.kind(), "=" | "(" | ")") {
            val = node_text(&part, source).to_string();
        }
    }
    val
}

/// Strip R quoting from `text`.
///
/// Handles a matching pair of `"`, `'` or backticks, and R >= 4.0 raw strings
/// such as `r"(x)"`, `R'[x]'` or `r"--{x}--"`. Unquoted text is returned
/// unchanged.
fn unquote_r_string(text: &str) -> &str {
    if let Some(inner) = strip_raw_string(text) {
        return inner;
    }
    for quote in ['"', '\'', '`'] {
        if let Some(inner) = text.strip_prefix(quote).and_then(|t| t.strip_suffix(quote)) {
            return inner;
        }
    }
    text
}

/// If `text` is an R raw string literal, return its contents.
fn strip_raw_string(text: &str) -> Option<&str> {
    let rest = text.strip_prefix('r').or_else(|| text.strip_prefix('R'))?;
    let quote = match rest.as_bytes().first().copied()? {
        b'"' => '"',
        b'\'' => '\'',
        _ => return None,
    };
    let body = rest[1..].strip_suffix(quote)?;
    // An optional run of dashes precedes the opening delimiter and must be
    // mirrored before the closing quote.
    let dash_len = body.len() - body.trim_start_matches('-').len();
    let (dashes, body) = body.split_at(dash_len);
    let close = match body.as_bytes().first().copied()? {
        b'(' => ')',
        b'[' => ']',
        b'{' => '}',
        _ => return None,
    };
    body[1..].strip_suffix(dashes)?.strip_suffix(close)
}

/// Build a symbol for a top-level function literal that is not bound to a name.
///
/// The symbol is named `<anonymous>`, and its signature is the first source line
/// of the literal. This currently always returns `Some`.
fn extract_bare_function(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let text = node_text(node, source);
    let first_line = text.lines().next().unwrap_or("");

    let symbol = Symbol {
        name: String::from("<anonymous>"),
        kind: SymbolKind::Function,
        span: Span::from_node(node),
        signature: first_line.to_owned(),
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    };
    Some(symbol)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` with the tree-sitter R grammar and run `extract` on it.
    fn parse_r(source: &str) -> (Vec<Symbol>, Vec<Import>) {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_r::LANGUAGE.into())
            .unwrap();
        let tree = parser.parse(source, None).unwrap();
        extract(&tree, source.as_bytes())
    }

    /// Find the first symbol with the given name.
    fn symbol_named<'a>(symbols: &'a [Symbol], name: &str) -> Option<&'a Symbol> {
        symbols.iter().find(|s| s.name == name)
    }

    #[test]
    fn test_r_basics() {
        let source = r#"library(dplyr)
source("helpers.R")

MAX_SIZE <- 100

add <- function(x, y) {
  x + y
}

multiply = function(a, b) {
  a * b
}

greeting <- "hello"
"#;
        let (symbols, imports) = parse_r(source);

        // Both the package load and the sourced helper file are imports.
        assert!(imports.len() >= 2, "Expected at least 2 imports, got {}: {:?}", imports.len(), imports);
        assert!(imports.iter().any(|i| i.path == "dplyr"));
        assert!(imports.iter().any(|i| i.path == "helpers.R"));

        // `add` gets a richer failure message listing every extracted symbol.
        let Some(add) = symbol_named(&symbols, "add") else {
            panic!(
                "Expected add function, symbols: {:?}",
                symbols.iter().map(|s| (&s.name, &s.kind)).collect::<Vec<_>>()
            );
        };
        assert_eq!(add.kind, SymbolKind::Function);

        // Remaining expectations: name, expected kind, failure message.
        let expected = [
            ("multiply", SymbolKind::Function, "Expected multiply function"),
            ("MAX_SIZE", SymbolKind::Const, "Expected MAX_SIZE constant"),
            ("greeting", SymbolKind::Const, "Expected greeting constant"),
        ];
        for (name, kind, missing_msg) in expected {
            let sym = symbol_named(&symbols, name).unwrap_or_else(|| panic!("{}", missing_msg));
            assert_eq!(sym.kind, kind);
        }
    }
}