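//! car-ast 0.7.0 — Tree-sitter AST parsing for code-aware inference.
//!
//! Swift support: extracts symbols and `import` declarations from a
//! tree-sitter-swift syntax tree.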
use crate::types::*;
use super::{node_text, extract_doc_comment};

/// Extracts Swift symbols and `import` declarations from a parsed syntax tree.
///
/// Only the root's direct children (and any body nodes among them) are walked at
/// the top level; members of types and protocols are reported inside each
/// symbol's `children` rather than in the returned top-level list.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    let (mut symbols, mut imports) = (Vec::new(), Vec::new());
    extract_children(&tree.root_node(), source, &mut symbols, &mut imports, None);
    (symbols, imports)
}

fn extract_children(
    node: &tree_sitter::Node,
    source: &[u8],
    symbols: &mut Vec<Symbol>,
    imports: &mut Vec<Import>,
    parent_name: Option<&str>,
) {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "import_declaration" => {
                let text = node_text(&child, source).trim().to_string();
                let path = text.strip_prefix("import").unwrap_or(&text).trim().to_string();
                imports.push(Import {
                    path,
                    alias: None,
                    span: Span::from_node(&child),
                });
            }
            "function_declaration" => {
                if let Some(sym) = extract_function(&child, source, parent_name) {
                    symbols.push(sym);
                }
            }
            "protocol_function_declaration" => {
                if let Some(sym) = extract_function(&child, source, parent_name) {
                    symbols.push(sym);
                }
            }
            "class_declaration" => {
                // tree-sitter-swift uses class_declaration for struct, class, and enum
                let kind = detect_class_kind(&child, source);
                if let Some(sym) = extract_type_decl(&child, source, kind, parent_name) {
                    symbols.push(sym);
                }
            }
            "protocol_declaration" => {
                if let Some(sym) = extract_protocol(&child, source, parent_name) {
                    symbols.push(sym);
                }
            }
            "property_declaration" => {
                if let Some(sym) = extract_property(&child, source, parent_name) {
                    symbols.push(sym);
                }
            }
            "typealias_declaration" => {
                if let Some(sym) = extract_typealias(&child, source, parent_name) {
                    symbols.push(sym);
                }
            }
            // Recurse into body nodes
            "class_body" | "enum_class_body" | "protocol_body" | "source_file" => {
                extract_children(&child, source, symbols, imports, parent_name);
            }
            _ => {}
        }
    }
}

/// Classifies a `class_declaration` node as a struct, enum, or class.
///
/// tree-sitter-swift reuses `class_declaration` for all three; they are told apart
/// by a direct keyword child (`struct`, `enum`, or `class`), and the first such
/// keyword found wins. Declarations with none of these keywords fall back to
/// `SymbolKind::Class`; this presumably includes `extension`/`actor`, which is
/// worth confirming. `_source` is unused.
fn detect_class_kind(node: &tree_sitter::Node, _source: &[u8]) -> SymbolKind {
    let mut cursor = node.walk();
    let detected = node.children(&mut cursor).find_map(|child| match child.kind() {
        "struct" => Some(SymbolKind::Struct),
        "enum" => Some(SymbolKind::Enum),
        "class" => Some(SymbolKind::Class),
        _ => None,
    });
    detected.unwrap_or(SymbolKind::Class)
}

/// Builds a symbol for a `function_declaration` or `protocol_function_declaration`.
///
/// Returns `None` when the node has no direct `simple_identifier` child. A function
/// with an enclosing `parent_name` is reported as a `Method`, otherwise as a
/// `Function`. The signature is the trimmed text preceding the `function_body`.
/// Body-less declarations (protocol requirements) use their whole trimmed text.
fn extract_function(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let name = find_first_child_of_kind(node, "simple_identifier", source)?;

    let signature = match find_first_child_of_kind_node(node, "function_body") {
        Some(body) => {
            let header = &source[node.start_byte()..body.start_byte()];
            std::str::from_utf8(header).unwrap_or("").trim().to_string()
        }
        None => node_text(node, source).trim().to_string(),
    };

    Some(Symbol {
        name: name.to_string(),
        kind: match parent_name {
            Some(_) => SymbolKind::Method,
            None => SymbolKind::Function,
        },
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(str::to_owned),
        children: Vec::new(),
    })
}

/// Builds a symbol for a `class_declaration` node (a Swift struct, class, or enum).
///
/// `kind` is chosen by the caller. The name is the node's first direct
/// `type_identifier` child; without one the function returns `None`.
/// NOTE(review): `extension` declarations presumably name their target with a
/// `user_type` rather than a `type_identifier`, so they and their members would be
/// skipped here — confirm whether that is intended.
///
/// The signature is the trimmed source text before the body, or the first line of
/// the declaration when there is no body. Members of the body are extracted
/// recursively into `children`, with this type's name as their parent. Imports
/// found inside the body are discarded.
fn extract_type_decl(
    node: &tree_sitter::Node,
    source: &[u8],
    kind: SymbolKind,
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let name = find_first_child_of_kind(node, "type_identifier", source)?;

    // Look the body up once: it is either a `class_body` or an `enum_class_body`.
    let body = find_first_child_of_kind_node(node, "class_body")
        .or_else(|| find_first_child_of_kind_node(node, "enum_class_body"));

    // Signature: everything before the body
    let signature = match &body {
        Some(body) => {
            let sig = &source[node.start_byte()..body.start_byte()];
            std::str::from_utf8(sig).unwrap_or("").trim().to_string()
        }
        None => node_text(node, source).lines().next().unwrap_or("").trim().to_string(),
    };

    // Recurse into the body to find members.
    let mut children = Vec::new();
    if let Some(body) = &body {
        // Imports nested inside a type body are not reported to the caller.
        let mut nested_imports = Vec::new();
        extract_children(body, source, &mut children, &mut nested_imports, Some(name));
    }

    Some(Symbol {
        name: name.to_string(),
        kind,
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(|s| s.to_string()),
        children,
    })
}

fn extract_protocol(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let name = find_first_child_of_kind(node, "type_identifier", source)?;

    let body = find_first_child_of_kind_node(node, "protocol_body");
    let signature = if let Some(ref body) = body {
        let sig = &source[node.start_byte()..body.start_byte()];
        std::str::from_utf8(sig).unwrap_or("").trim().to_string()
    } else {
        node_text(node, source).lines().next().unwrap_or("").trim().to_string()
    };

    let mut children = Vec::new();
    let mut child_imports = Vec::new();
    if let Some(body) = body {
        extract_children(&body, source, &mut children, &mut child_imports, Some(name));
    }

    Some(Symbol {
        name: name.to_string(),
        kind: SymbolKind::Interface,
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(|s| s.to_string()),
        children,
    })
}

/// Builds a symbol for a `property_declaration` node.
///
/// The name comes from `find_property_name`; `None` is returned when no name is
/// found. The signature is the declaration's first source line, trimmed. Every
/// property, `var` as well as `let`, is reported as `SymbolKind::Const`.
fn extract_property(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let name = find_property_name(node, source)?;
    let first_line = node_text(node, source).lines().next().unwrap_or("");

    Some(Symbol {
        name: name.to_string(),
        kind: SymbolKind::Const,
        span: Span::from_node(node),
        signature: first_line.trim().to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(str::to_owned),
        children: Vec::new(),
    })
}

/// Builds a `TypeAlias` symbol for a `typealias_declaration` node.
///
/// The alias name is the node's first direct `type_identifier` child; `None` is
/// returned without one. The signature is the entire declaration text, trimmed.
fn extract_typealias(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let alias = find_first_child_of_kind(node, "type_identifier", source)?;
    let full_text = node_text(node, source).trim();

    Some(Symbol {
        name: alias.to_string(),
        kind: SymbolKind::TypeAlias,
        span: Span::from_node(node),
        signature: full_text.to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(str::to_owned),
        children: Vec::new(),
    })
}

/// Returns the source text of the first direct child of `node` whose kind equals
/// `kind`, or `None` when no immediate child matches (grandchildren are not searched).
fn find_first_child_of_kind<'a>(
    node: &tree_sitter::Node,
    kind: &str,
    source: &'a [u8],
) -> Option<&'a str> {
    let mut cursor = node.walk();
    // Bound to a local so the cursor-borrowing iterator is dropped before `cursor`.
    let matched = node.children(&mut cursor).find(|child| child.kind() == kind);
    matched.map(|child| node_text(&child, source))
}

/// Returns the first direct child of `node` whose kind equals `kind`, or `None`
/// when no immediate child matches (grandchildren are not searched).
fn find_first_child_of_kind_node<'a>(
    node: &'a tree_sitter::Node,
    kind: &str,
) -> Option<tree_sitter::Node<'a>> {
    let mut cursor = node.walk();
    // Bound to a local so the cursor-borrowing iterator is dropped before `cursor`.
    let first_match = node.children(&mut cursor).find(|candidate| candidate.kind() == kind);
    first_match
}

/// Finds the bound name of a `property_declaration` node.
///
/// Swift property declarations use pattern bindings (`let/var name: Type = value`).
/// Children are searched depth-first in source order. A direct `simple_identifier`
/// child is returned immediately. `pattern` and `directly_assignable_expression`
/// children are searched recursively, and any other node kinds are not descended into.
fn find_property_name<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> Option<&'a str> {
    fn walk<'s>(node: &tree_sitter::Node, source: &'s [u8]) -> Option<&'s str> {
        let mut cursor = node.walk();
        let hit = node.children(&mut cursor).find_map(|child| match child.kind() {
            "simple_identifier" => Some(node_text(&child, source)),
            "pattern" | "directly_assignable_expression" => walk(&child, source),
            _ => None,
        });
        hit
    }
    walk(node, source)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the tree-sitter-swift grammar and runs `extract` on it.
    fn parse_swift(source: &str) -> (Vec<Symbol>, Vec<Import>) {
        let language: tree_sitter::Language = tree_sitter_swift::LANGUAGE.into();
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&language).unwrap();
        let tree = parser.parse(source, None).unwrap();
        extract(&tree, source.as_bytes())
    }

    /// Names of `symbols` in order, used in assertion messages.
    fn names_of(symbols: &[Symbol]) -> Vec<&str> {
        symbols.iter().map(|sym| sym.name.as_str()).collect()
    }

    /// First symbol in `symbols` whose name is exactly `name`.
    fn by_name<'s>(symbols: &'s [Symbol], name: &str) -> Option<&'s Symbol> {
        symbols.iter().find(|sym| sym.name == name)
    }

    #[test]
    fn test_struct_and_protocol() {
        let source = r#"
import Foundation
import UIKit

protocol Drawable {
    func draw()
}

struct Point {
    let x: Double
    let y: Double

    func distance(to other: Point) -> Double {
        let dx = x - other.x
        let dy = y - other.y
        return (dx * dx + dy * dy).squareRoot()
    }
}
"#;
        let (symbols, imports) = parse_swift(source);

        // Both imports are reported, in source order.
        assert_eq!(imports.len(), 2);
        assert_eq!(imports[0].path, "Foundation");
        assert_eq!(imports[1].path, "UIKit");

        // A protocol is reported as an interface.
        let Some(proto) = by_name(&symbols, "Drawable") else {
            panic!("missing Drawable in {:?}", names_of(&symbols));
        };
        assert_eq!(proto.kind, SymbolKind::Interface);

        // A struct keeps its own kind.
        let Some(point) = by_name(&symbols, "Point") else {
            panic!("missing Point in {:?}", names_of(&symbols));
        };
        assert_eq!(point.kind, SymbolKind::Struct);

        // Point's members: the properties x and y plus the method distance.
        let member_names = names_of(&point.children);
        assert!(member_names.contains(&"distance"), "missing distance in {:?}", member_names);
        assert!(member_names.contains(&"x"), "missing x in {:?}", member_names);
        assert!(member_names.contains(&"y"), "missing y in {:?}", member_names);

        let distance = by_name(&point.children, "distance").unwrap();
        assert_eq!(distance.kind, SymbolKind::Method);

        let x_field = by_name(&point.children, "x").unwrap();
        assert_eq!(x_field.kind, SymbolKind::Const);
    }

    #[test]
    fn test_class_and_enum() {
        let source = r#"
class Vehicle {
    var speed: Int

    func accelerate() {
        speed += 10
    }
}

enum Direction {
    case north
    case south
    case east
    case west
}

func freeFunction() -> String {
    return "hello"
}

typealias Speed = Double
"#;
        let (symbols, _imports) = parse_swift(source);
        let names = names_of(&symbols);

        let Some(vehicle) = by_name(&symbols, "Vehicle") else {
            panic!("missing Vehicle in {:?}", names);
        };
        assert_eq!(vehicle.kind, SymbolKind::Class);

        let Some(direction) = by_name(&symbols, "Direction") else {
            panic!("missing Direction in {:?}", names);
        };
        assert_eq!(direction.kind, SymbolKind::Enum);

        // A top-level function has no parent, so it is a Function rather than a Method.
        let Some(free_fn) = by_name(&symbols, "freeFunction") else {
            panic!("missing freeFunction in {:?}", names);
        };
        assert_eq!(free_fn.kind, SymbolKind::Function);

        let Some(alias) = by_name(&symbols, "Speed") else {
            panic!("missing Speed typealias in {:?}", names);
        };
        assert_eq!(alias.kind, SymbolKind::TypeAlias);
    }
}