use crate::types::*;
use super::{node_text, extract_doc_comment};
/// Extracts top-level Swift symbols and import declarations from a parsed tree.
///
/// `source` must be the exact bytes the tree was parsed from. Returns the
/// collected symbols (nested members live in each symbol's `children`) and the
/// imports found at the top level.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    let (mut symbols, mut imports) = (Vec::new(), Vec::new());
    // The root has no enclosing type, so top-level functions stay `Function`s.
    extract_children(&tree.root_node(), source, &mut symbols, &mut imports, None);
    (symbols, imports)
}
fn extract_children(
node: &tree_sitter::Node,
source: &[u8],
symbols: &mut Vec<Symbol>,
imports: &mut Vec<Import>,
parent_name: Option<&str>,
) {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"import_declaration" => {
let text = node_text(&child, source).trim().to_string();
let path = text.strip_prefix("import").unwrap_or(&text).trim().to_string();
imports.push(Import {
path,
alias: None,
span: Span::from_node(&child),
});
}
"function_declaration" => {
if let Some(sym) = extract_function(&child, source, parent_name) {
symbols.push(sym);
}
}
"protocol_function_declaration" => {
if let Some(sym) = extract_function(&child, source, parent_name) {
symbols.push(sym);
}
}
"class_declaration" => {
let kind = detect_class_kind(&child, source);
if let Some(sym) = extract_type_decl(&child, source, kind, parent_name) {
symbols.push(sym);
}
}
"protocol_declaration" => {
if let Some(sym) = extract_protocol(&child, source, parent_name) {
symbols.push(sym);
}
}
"property_declaration" => {
if let Some(sym) = extract_property(&child, source, parent_name) {
symbols.push(sym);
}
}
"typealias_declaration" => {
if let Some(sym) = extract_typealias(&child, source, parent_name) {
symbols.push(sym);
}
}
"class_body" | "enum_class_body" | "protocol_body" | "source_file" => {
extract_children(&child, source, symbols, imports, parent_name);
}
_ => {}
}
}
}
/// Maps a `class_declaration` node to a [`SymbolKind`] using its keyword child.
///
/// The first direct child whose kind is `struct`, `enum` or `class` decides the
/// result. When no such child exists, the kind defaults to `SymbolKind::Class`.
/// `_source` is unused.
fn detect_class_kind(node: &tree_sitter::Node, _source: &[u8]) -> SymbolKind {
    let mut cursor = node.walk();
    let keyword_kind = node.children(&mut cursor).find_map(|child| match child.kind() {
        "struct" => Some(SymbolKind::Struct),
        "enum" => Some(SymbolKind::Enum),
        "class" => Some(SymbolKind::Class),
        _ => None,
    });
    keyword_kind.unwrap_or(SymbolKind::Class)
}
/// Builds a [`Symbol`] for a function declaration or a protocol function requirement.
///
/// The name is the first direct `simple_identifier` child; if there is none,
/// `None` is returned. The kind is `Method` when `parent_name` is set, and
/// `Function` otherwise.
///
/// The signature is the source text from the start of the declaration up to its
/// `function_body`, trimmed. If there is no body, it is the whole declaration text,
/// trimmed.
fn extract_function(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let fn_name = find_first_child_of_kind(node, "simple_identifier", source)?;

    let signature = match find_first_child_of_kind_node(node, "function_body") {
        // Everything before the body is the declaration header.
        Some(body) => std::str::from_utf8(&source[node.start_byte()..body.start_byte()])
            .unwrap_or("")
            .trim()
            .to_string(),
        None => node_text(node, source).trim().to_string(),
    };

    let kind = match parent_name {
        Some(_) => SymbolKind::Method,
        None => SymbolKind::Function,
    };

    Some(Symbol {
        name: fn_name.to_string(),
        kind,
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(str::to_string),
        children: Vec::new(),
    })
}
/// Builds a [`Symbol`] for a struct, class or enum declaration, including its members.
///
/// The name is the first direct `type_identifier` child; if there is none, `None`
/// is returned. `kind` is supplied by the caller.
///
/// The signature is the header text before the body (`class_body`, or
/// `enum_class_body` for enums), trimmed. Without a body, it is the first line of
/// the declaration.
///
/// Members are extracted recursively with this type's name as their parent. Imports
/// found inside the body are discarded.
fn extract_type_decl(
    node: &tree_sitter::Node,
    source: &[u8],
    kind: SymbolKind,
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let name = find_first_child_of_kind(node, "type_identifier", source)?;

    // Look the body up once; it is needed for both the signature and the members.
    let body = find_first_child_of_kind_node(node, "class_body")
        .or_else(|| find_first_child_of_kind_node(node, "enum_class_body"));

    let signature = match body.as_ref() {
        Some(b) => std::str::from_utf8(&source[node.start_byte()..b.start_byte()])
            .unwrap_or("")
            .trim()
            .to_string(),
        None => node_text(node, source)
            .lines()
            .next()
            .unwrap_or("")
            .trim()
            .to_string(),
    };

    let mut children = Vec::new();
    if let Some(b) = body.as_ref() {
        // Imports nested in a type body are not reported.
        let mut nested_imports = Vec::new();
        extract_children(b, source, &mut children, &mut nested_imports, Some(name));
    }

    Some(Symbol {
        name: name.to_string(),
        kind,
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(|s| s.to_string()),
        children,
    })
}
fn extract_protocol(
node: &tree_sitter::Node,
source: &[u8],
parent_name: Option<&str>,
) -> Option<Symbol> {
let name = find_first_child_of_kind(node, "type_identifier", source)?;
let body = find_first_child_of_kind_node(node, "protocol_body");
let signature = if let Some(ref body) = body {
let sig = &source[node.start_byte()..body.start_byte()];
std::str::from_utf8(sig).unwrap_or("").trim().to_string()
} else {
node_text(node, source).lines().next().unwrap_or("").trim().to_string()
};
let mut children = Vec::new();
let mut child_imports = Vec::new();
if let Some(body) = body {
extract_children(&body, source, &mut children, &mut child_imports, Some(name));
}
Some(Symbol {
name: name.to_string(),
kind: SymbolKind::Interface,
span: Span::from_node(node),
signature,
doc_comment: extract_doc_comment(node, source),
parent: parent_name.map(|s| s.to_string()),
children,
})
}
/// Builds a [`Symbol`] of kind `Const` for a `property_declaration`.
///
/// The name is found by [`find_property_name`]; if none is found, `None` is
/// returned. The signature is the first line of the declaration, trimmed. Both
/// `let` and `var` declarations are reported as `Const`.
fn extract_property(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let property_name = find_property_name(node, source)?;
    let first_line = node_text(node, source).lines().next().unwrap_or("");

    Some(Symbol {
        name: property_name.to_owned(),
        kind: SymbolKind::Const,
        span: Span::from_node(node),
        signature: first_line.trim().to_owned(),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(str::to_owned),
        children: Vec::new(),
    })
}
/// Builds a `TypeAlias` [`Symbol`] for a `typealias_declaration`.
///
/// The name is the first direct `type_identifier` child; if there is none, `None`
/// is returned. The signature is the full declaration text, trimmed.
fn extract_typealias(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let alias_name = find_first_child_of_kind(node, "type_identifier", source)?;
    let declaration = node_text(node, source).trim();

    Some(Symbol {
        name: alias_name.to_owned(),
        kind: SymbolKind::TypeAlias,
        span: Span::from_node(node),
        signature: declaration.to_owned(),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(str::to_owned),
        children: Vec::new(),
    })
}
/// Returns the source text of the first direct child of `node` whose kind is `kind`.
///
/// Only immediate children are inspected (no recursion). Returns `None` when no
/// child matches.
fn find_first_child_of_kind<'a>(
    node: &tree_sitter::Node,
    kind: &str,
    source: &'a [u8],
) -> Option<&'a str> {
    let mut cursor = node.walk();
    // Bind before returning: the children iterator borrows `cursor`, so it must be
    // dropped before the local goes out of scope.
    let matched = node.children(&mut cursor).find(|child| child.kind() == kind);
    matched.map(|child| node_text(&child, source))
}
/// Returns the first direct child of `node` whose kind is `kind`, or `None`.
///
/// Only immediate children are inspected (no recursion). The returned node is tied
/// to the tree's lifetime `'tree`, not to the borrow of `node`, so it can outlive
/// the reference passed in.
fn find_first_child_of_kind_node<'tree>(
    node: &tree_sitter::Node<'tree>,
    kind: &str,
) -> Option<tree_sitter::Node<'tree>> {
    let mut cursor = node.walk();
    // Bind before returning: the children iterator borrows `cursor`, so it must be
    // dropped before the local goes out of scope.
    let found = node.children(&mut cursor).find(|child| child.kind() == kind);
    found
}
/// Finds the declared name of a property by searching `node`'s children in order.
///
/// For each child:
/// - a `simple_identifier` child yields its text immediately;
/// - `pattern` and `directly_assignable_expression` children are searched
///   recursively, with the same rules;
/// - all other children are skipped.
///
/// Returns the first match, or `None`.
fn find_property_name<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> Option<&'a str> {
    let mut cursor = node.walk();
    let found = node.children(&mut cursor).find_map(|child| match child.kind() {
        "simple_identifier" => Some(node_text(&child, source)),
        "pattern" | "directly_assignable_expression" => find_property_name(&child, source),
        _ => None,
    });
    // Bound to a local so the iterator borrowing `cursor` is dropped first.
    found
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as Swift and runs the extractor on it.
    fn parse_swift(source: &str) -> (Vec<Symbol>, Vec<Import>) {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&tree_sitter_swift::LANGUAGE.into()).unwrap();
        let tree = parser.parse(source, None).unwrap();
        extract(&tree, source.as_bytes())
    }

    /// Names of `symbols`, used in assertion messages.
    fn names(symbols: &[Symbol]) -> Vec<&str> {
        symbols.iter().map(|s| s.name.as_str()).collect()
    }

    /// Returns the symbol called `name`, panicking with the available names if absent.
    fn find_named<'a>(symbols: &'a [Symbol], name: &str) -> &'a Symbol {
        symbols
            .iter()
            .find(|s| s.name == name)
            .unwrap_or_else(|| panic!("missing {} in {:?}", name, names(symbols)))
    }

    #[test]
    fn test_struct_and_protocol() {
        let source = r#"
import Foundation
import UIKit
protocol Drawable {
func draw()
}
struct Point {
let x: Double
let y: Double
func distance(to other: Point) -> Double {
let dx = x - other.x
let dy = y - other.y
return (dx * dx + dy * dy).squareRoot()
}
}
"#;
        let (symbols, imports) = parse_swift(source);

        assert_eq!(imports.len(), 2);
        assert_eq!(imports[0].path, "Foundation");
        assert_eq!(imports[1].path, "UIKit");

        let proto = find_named(&symbols, "Drawable");
        assert_eq!(proto.kind, SymbolKind::Interface);
        let draw = find_named(&proto.children, "draw");
        assert_eq!(draw.kind, SymbolKind::Method);
        assert_eq!(draw.parent.as_deref(), Some("Drawable"));

        let point = find_named(&symbols, "Point");
        assert_eq!(point.kind, SymbolKind::Struct);

        let dist = find_named(&point.children, "distance");
        assert_eq!(dist.kind, SymbolKind::Method);
        assert_eq!(dist.parent.as_deref(), Some("Point"));

        let x = find_named(&point.children, "x");
        assert_eq!(x.kind, SymbolKind::Const);
        let y = find_named(&point.children, "y");
        assert_eq!(y.kind, SymbolKind::Const);
    }

    #[test]
    fn test_class_and_enum() {
        let source = r#"
class Vehicle {
var speed: Int
func accelerate() {
speed += 10
}
}
enum Direction {
case north
case south
case east
case west
}
func freeFunction() -> String {
return "hello"
}
typealias Speed = Double
"#;
        let (symbols, _imports) = parse_swift(source);

        let vehicle = find_named(&symbols, "Vehicle");
        assert_eq!(vehicle.kind, SymbolKind::Class);
        assert_eq!(find_named(&vehicle.children, "speed").kind, SymbolKind::Const);
        assert_eq!(find_named(&vehicle.children, "accelerate").kind, SymbolKind::Method);

        assert_eq!(find_named(&symbols, "Direction").kind, SymbolKind::Enum);

        let func = find_named(&symbols, "freeFunction");
        assert_eq!(func.kind, SymbolKind::Function);
        assert_eq!(func.parent, None);

        assert_eq!(find_named(&symbols, "Speed").kind, SymbolKind::TypeAlias);
    }
}