car-ast 0.6.0

Tree-sitter AST parsing for code-aware inference
Documentation
use crate::types::*;
use super::{node_text, field_text, extract_doc_comment};

/// Extracts symbols and imports from a parsed Haskell syntax tree.
///
/// Only the direct children of the root node are inspected. `imports` and
/// `declarations` wrapper nodes are handed to their dedicated walkers; bare
/// import or declaration nodes sitting directly under the root are processed
/// in place.
///
/// Returns `(symbols, imports)` in encounter order.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    // Declaration kinds accepted directly under the root.
    // NOTE(review): "type_synomym" looks misspelled, but it is presumably the
    // node kind the grammar actually emits — confirm against the grammar's
    // node types before "correcting" it.
    const TOP_LEVEL_DECLS: [&str; 8] = [
        "function",
        "signature",
        "data_type",
        "newtype",
        "type_synomym",
        "type_alias",
        "class",
        "instance",
    ];

    let mut found_symbols = Vec::new();
    let mut found_imports = Vec::new();

    let root = tree.root_node();
    let mut walker = root.walk();
    for top in root.children(&mut walker) {
        let kind = top.kind();
        if kind == "imports" {
            extract_imports(&top, source, &mut found_imports);
        } else if kind == "declarations" {
            extract_declarations(&top, source, &mut found_symbols);
        } else if kind == "import" {
            // Some tree-sitter-haskell versions put these at top level
            extract_single_import(&top, source, &mut found_imports);
        } else if TOP_LEVEL_DECLS.contains(&kind) {
            extract_decl(&top, source, &mut found_symbols);
        }
    }

    (found_symbols, found_imports)
}

/// Feeds every direct `import` child of `node` to `extract_single_import`.
///
/// Children of any other kind are ignored.
fn extract_imports(node: &tree_sitter::Node, source: &[u8], imports: &mut Vec<Import>) {
    let mut walker = node.walk();
    node.children(&mut walker)
        .filter(|entry| entry.kind() == "import")
        .for_each(|entry| extract_single_import(&entry, source, imports));
}

/// Converts one `import` node into an `Import` record and appends it to `imports`.
///
/// The first non-empty `module` child seen before any `as` child becomes the
/// import path. Every `module` child seen after an `as` child overwrites the
/// alias, so the last one wins. Nothing is pushed when no path was found.
fn extract_single_import(node: &tree_sitter::Node, source: &[u8], imports: &mut Vec<Import>) {
    let mut target = String::new();
    let mut rename: Option<String> = None;
    let mut after_as = false;

    let mut walker = node.walk();
    for part in node.children(&mut walker) {
        let kind = part.kind();
        if kind == "as" {
            // From here on, `module` children name the alias, not the path.
            after_as = true;
            continue;
        }
        if kind != "module" {
            continue;
        }

        let text = node_text(&part, source);
        if after_as {
            rename = Some(text.to_string());
        } else if target.is_empty() {
            target = text.to_string();
        }
    }

    if target.is_empty() {
        return;
    }

    imports.push(Import {
        path: target,
        alias: rename,
        span: Span::from_node(node),
    });
}

/// Runs `extract_decl` on every direct child of a declarations wrapper node.
fn extract_declarations(node: &tree_sitter::Node, source: &[u8], symbols: &mut Vec<Symbol>) {
    let mut walker = node.walk();
    node.children(&mut walker)
        .for_each(|decl| extract_decl(&decl, source, symbols));
}

/// Dispatches a single declaration node to the matching extractor and appends
/// the resulting symbol, if any, to `symbols`.
///
/// Unknown node kinds are ignored.
fn extract_decl(node: &tree_sitter::Node, source: &[u8], symbols: &mut Vec<Symbol>) {
    let extracted = match node.kind() {
        // Haskell has multi-clause definitions that produce multiple `function`
        // nodes; only the first clause for a given name is recorded.
        "function" => extract_function(node, source).filter(|candidate| {
            !symbols
                .iter()
                .any(|known| known.kind == SymbolKind::Function && known.name == candidate.name)
        }),
        // Standalone type signatures are deliberately not emitted as symbols:
        // they are picked up by the function that follows them. A signature
        // with no following function is dropped for simplicity.
        "signature" => None,
        "data_type" | "newtype" => extract_data_type(node, source),
        "type_synomym" | "type_alias" => extract_type_alias(node, source),
        "class" => extract_class(node, source),
        "instance" => extract_instance(node, source),
        _ => None,
    };

    if let Some(symbol) = extracted {
        symbols.push(symbol);
    }
}

/// Builds a `Function` symbol from a `function` node.
///
/// The name comes from the node's `name` field, falling back to the first
/// non-empty `variable` or `identifier` child. Returns `None` when neither
/// yields a name.
///
/// The symbol's signature is the preceding standalone type signature when
/// `extract_type_signature` finds one, otherwise the first line of the
/// function's own source text.
fn extract_function(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let name = match field_text(node, "name", source) {
        Some(text) => text,
        None => first_child_text(node, source, &["variable", "identifier"])?,
    };

    let signature = match extract_type_signature(node, source) {
        Some(declared) => declared,
        // No declared signature: use the first line of the definition instead.
        None => node_text(node, source)
            .lines()
            .next()
            .unwrap_or("")
            .to_string(),
    };

    Some(Symbol {
        name: name.to_string(),
        kind: SymbolKind::Function,
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}

/// Look at the preceding sibling to find a type signature for this function.
fn extract_type_signature(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    let mut sibling = node.prev_sibling();
    while let Some(s) = sibling {
        match s.kind() {
            "signature" => return Some(node_text(&s, source).to_string()),
            "comment" => {
                sibling = s.prev_sibling();
                continue;
            }
            _ => break,
        }
    }
    None
}

/// Builds an `Enum` symbol from a data type declaration node, with its
/// constructors attached as children.
///
/// The name comes from the node's `name` field, falling back to the first
/// non-empty `name`, `type` or `identifier` child. Returns `None` when no name
/// is found. The signature is the first line of the declaration's source text.
fn extract_data_type(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let name = match field_text(node, "name", source) {
        Some(text) => text,
        None => first_child_text(node, source, &["name", "type", "identifier"])?,
    };

    // Constructors become child symbols whose parent is this type.
    let mut constructors = Vec::new();
    walk_constructors(node, source, name, &mut constructors);

    let header = node_text(node, source)
        .lines()
        .next()
        .unwrap_or("")
        .to_string();

    Some(Symbol {
        name: name.to_string(),
        kind: SymbolKind::Enum,
        span: Span::from_node(node),
        signature: header,
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: constructors,
    })
}

/// Collects the data constructors found under `node` into `children`.
///
/// Each `data_constructor` / `data_constructor_record` child that yields a
/// constructor name is recorded as a `Const` symbol whose parent is
/// `parent_name` and whose signature is the constructor's full source text.
/// `data_constructors` / `constructors` wrapper nodes are descended into
/// recursively; every other child is ignored.
fn walk_constructors(node: &tree_sitter::Node, source: &[u8], parent_name: &str, children: &mut Vec<Symbol>) {
    let mut walker = node.walk();
    for entry in node.children(&mut walker) {
        match entry.kind() {
            "data_constructor" | "data_constructor_record" => {
                // The constructor name is inside a `prefix` > `constructor`
                // or directly a `constructor`.
                if let Some(ctor) = find_constructor_name(&entry, source) {
                    children.push(Symbol {
                        name: ctor.to_string(),
                        kind: SymbolKind::Const,
                        span: Span::from_node(&entry),
                        signature: node_text(&entry, source).to_string(),
                        doc_comment: None,
                        parent: Some(parent_name.to_string()),
                        children: Vec::new(),
                    });
                }
            }
            "data_constructors" | "constructors" => {
                walk_constructors(&entry, source, parent_name, children);
            }
            _ => {}
        }
    }
}

/// Returns the text of the first non-empty `constructor` node in the subtree
/// rooted at `node`, searching depth-first in child order.
///
/// `node` itself is checked before its children. Returns `None` when the
/// subtree holds no such node.
fn find_constructor_name<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> Option<&'a str> {
    if node.kind() == "constructor" {
        let own = node_text(node, source);
        if !own.is_empty() {
            return Some(own);
        }
    }

    let mut walker = node.walk();
    // Bound to a local so the child iterator (which borrows `walker`) is
    // dropped before `walker` goes out of scope.
    let nested = node
        .children(&mut walker)
        .find_map(|child| find_constructor_name(&child, source));
    nested
}

/// Builds a `TypeAlias` symbol from a type synonym node.
///
/// The name comes from the node's `name` field, falling back to the first
/// non-empty `name`, `type` or `identifier` child. Returns `None` when no name
/// is found. The signature is the node's complete source text.
fn extract_type_alias(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let name = match field_text(node, "name", source) {
        Some(text) => text,
        None => first_child_text(node, source, &["name", "type", "identifier"])?,
    };

    let full_text = node_text(node, source).to_string();

    Some(Symbol {
        name: name.to_string(),
        kind: SymbolKind::TypeAlias,
        span: Span::from_node(node),
        signature: full_text,
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}

/// Builds a `Trait` symbol from a `class` node, with its members as children.
///
/// The name comes from the node's `name` field, falling back to the first
/// non-empty `name` or `identifier` child. Returns `None` when no name is
/// found. The signature is the first line of the class's source text.
///
/// Members are looked for both directly under the class node and one level
/// down inside `class_declarations`, `where` or `declarations` wrappers.
fn extract_class(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let name = match field_text(node, "name", source) {
        Some(text) => text,
        None => first_child_text(node, source, &["name", "identifier"])?,
    };

    let mut members = Vec::new();

    let mut walker = node.walk();
    for part in node.children(&mut walker) {
        match part.kind() {
            // Wrapper node: the members are its children.
            "class_declarations" | "where" | "declarations" => {
                let mut body_walker = part.walk();
                for member in part.children(&mut body_walker) {
                    extract_class_member(&member, source, name, &mut members);
                }
            }
            // Anything else may itself be a member; non-members are filtered
            // out by `extract_class_member`.
            _ => extract_class_member(&part, source, name, &mut members),
        }
    }

    let header = node_text(node, source)
        .lines()
        .next()
        .unwrap_or("")
        .to_string();

    Some(Symbol {
        name: name.to_string(),
        kind: SymbolKind::Trait,
        span: Span::from_node(node),
        signature: header,
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: members,
    })
}

/// Records `node` as a `Method` of `parent_name` when it is a `signature` or
/// `function` node with a resolvable name not already present in `methods`.
///
/// The name comes from the node's `name` field, falling back to the first
/// non-empty `variable`, `identifier` or `name` child. Nodes of any other
/// kind, nodes without a name, and repeated names are ignored.
fn extract_class_member(node: &tree_sitter::Node, source: &[u8], parent_name: &str, methods: &mut Vec<Symbol>) {
    let kind = node.kind();
    if kind != "signature" && kind != "function" {
        return;
    }

    let member_name = match field_text(node, "name", source)
        .or_else(|| first_child_text(node, source, &["variable", "identifier", "name"]))
    {
        Some(text) => text,
        None => return,
    };

    // Avoid duplicates from multi-clause functions: the first node seen for a
    // name wins.
    if methods.iter().any(|existing| existing.name == member_name) {
        return;
    }

    methods.push(Symbol {
        name: member_name.to_string(),
        kind: SymbolKind::Method,
        span: Span::from_node(node),
        signature: node_text(node, source).to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: Some(parent_name.to_string()),
        children: Vec::new(),
    });
}

/// Builds an `Impl` symbol from an `instance` node.
///
/// The symbol name is the instance head: the first line of the node's source
/// text with the leading `instance` keyword and a trailing `where` keyword
/// removed. The untouched first line is kept as the signature. Returns `None`
/// when the resulting name is empty.
///
/// Each keyword is stripped at most once, and `where` is only stripped when it
/// is not the tail of a longer identifier, so `instance Show Elsewhere` keeps
/// the name `Show Elsewhere` instead of being cut down to `Show Else`.
fn extract_instance(node: &tree_sitter::Node, source: &[u8]) -> Option<Symbol> {
    let text = node_text(node, source);
    let first_line = text.lines().next().unwrap_or("").to_string();

    // Drop the leading keyword once; leave the line alone if it is absent.
    let head = first_line
        .strip_prefix("instance")
        .unwrap_or(first_line.as_str())
        .trim();

    // Drop a trailing `where` only when it stands on its own, i.e. the
    // character before it cannot be part of a Haskell identifier.
    let head = match head.strip_suffix("where") {
        Some(rest)
            if !rest.ends_with(|c: char| c.is_alphanumeric() || c == '_' || c == '\'') =>
        {
            rest.trim_end()
        }
        _ => head,
    };

    if head.is_empty() {
        return None;
    }
    let name = head.to_string();

    Some(Symbol {
        name,
        kind: SymbolKind::Impl,
        span: Span::from_node(node),
        signature: first_line,
        doc_comment: extract_doc_comment(node, source),
        parent: None,
        children: Vec::new(),
    })
}

/// Helper: returns the text of the first direct child of `node` whose kind is
/// listed in `kinds` and whose text is non-empty.
///
/// Children are visited in order; matching children with empty text are
/// skipped. Returns `None` when no child qualifies.
fn first_child_text<'a>(node: &tree_sitter::Node, source: &'a [u8], kinds: &[&str]) -> Option<&'a str> {
    let mut walker = node.walk();
    // Bound to a local so the child iterator (which borrows `walker`) is
    // dropped before `walker` goes out of scope.
    let hit = node
        .children(&mut walker)
        .filter(|child| kinds.contains(&child.kind()))
        .map(|child| node_text(&child, source))
        .find(|text| !text.is_empty());
    hit
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the tree-sitter Haskell grammar and runs `extract`
    /// on the resulting tree.
    fn parse_haskell(source: &str) -> (Vec<Symbol>, Vec<Import>) {
        let language: tree_sitter::Language = tree_sitter_haskell::LANGUAGE.into();
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&language).unwrap();
        let tree = parser.parse(source, None).unwrap();
        extract(&tree, source.as_bytes())
    }

    /// Smoke test: one module exercising imports, a data type, a type alias,
    /// a class, an instance and a function with a signature.
    #[test]
    fn test_haskell_basics() {
        let source = r#"import Data.List
import qualified Data.Map as Map

data Color = Red | Green | Blue

type Name = String

class Printable a where
  display :: a -> String

instance Printable Color where
  display Red = "red"
  display Green = "green"
  display Blue = "blue"

greet :: String -> String
greet name = "Hello, " ++ name
"#;
        let (symbols, imports) = parse_haskell(source);

        // Imports
        let has_import = |needle: &str| imports.iter().any(|i| i.path.contains(needle));
        assert!(imports.len() >= 2, "Expected at least 2 imports, got {}", imports.len());
        assert!(has_import("Data.List"));
        assert!(has_import("Data.Map"));

        // Data type (Enum)
        let color = symbols.iter().find(|s| s.name == "Color");
        assert!(color.is_some(), "Expected Color symbol, symbols: {:?}", symbols.iter().map(|s| &s.name).collect::<Vec<_>>());
        let color = color.unwrap();
        assert_eq!(color.kind, SymbolKind::Enum);
        // Constructors
        assert!(color.children.len() >= 3, "Expected 3 constructors for Color");

        // Type alias
        let name_alias = symbols.iter().find(|s| s.name == "Name");
        assert!(name_alias.is_some(), "Expected Name type alias");
        assert_eq!(name_alias.unwrap().kind, SymbolKind::TypeAlias);

        // Class (Trait)
        let printable = symbols
            .iter()
            .find(|s| s.name.contains("Printable") && s.kind == SymbolKind::Trait);
        assert!(printable.is_some(), "Expected Printable class");
        assert_eq!(printable.unwrap().kind, SymbolKind::Trait);

        // Instance (Impl)
        let has_instance = symbols.iter().any(|s| s.kind == SymbolKind::Impl);
        assert!(has_instance, "Expected an instance (Impl) symbol");

        // Function
        let greet = symbols.iter().find(|s| s.name == "greet");
        assert!(greet.is_some(), "Expected greet function");
        assert_eq!(greet.unwrap().kind, SymbolKind::Function);
    }
}