use crate::types::*;
use super::{node_text, field_text, extract_doc_comment, extract_signature};
/// Collects the symbols and imports declared in a parsed Scala source file.
///
/// Every direct child of the tree's root node is handed to `extract_node`
/// with no parent name, so the returned symbols are the top-level ones
/// (nested members are attached through each symbol's `children`).
///
/// * `tree` - the parsed syntax tree.
/// * `source` - the bytes the tree was parsed from.
///
/// Returns `(symbols, imports)`.
pub fn extract(tree: &tree_sitter::Tree, source: &[u8]) -> (Vec<Symbol>, Vec<Import>) {
    let mut found_symbols = Vec::new();
    let mut found_imports = Vec::new();

    let root = tree.root_node();
    let mut walker = root.walk();
    root.children(&mut walker).for_each(|top_level| {
        extract_node(&top_level, source, &mut found_symbols, &mut found_imports, None)
    });

    (found_symbols, found_imports)
}
fn extract_node(
node: &tree_sitter::Node,
source: &[u8],
symbols: &mut Vec<Symbol>,
imports: &mut Vec<Import>,
parent_name: Option<&str>,
) {
match node.kind() {
"function_definition" => {
if let Some(sym) = extract_function(node, source, parent_name) {
symbols.push(sym);
}
}
"class_definition" => {
if let Some(sym) = extract_type_def(node, source, SymbolKind::Class, parent_name) {
symbols.push(sym);
}
}
"object_definition" => {
if let Some(sym) = extract_type_def(node, source, SymbolKind::Class, parent_name) {
symbols.push(sym);
}
}
"trait_definition" => {
if let Some(sym) = extract_type_def(node, source, SymbolKind::Trait, parent_name) {
symbols.push(sym);
}
}
"val_definition" | "var_definition" => {
if let Some(sym) = extract_val(node, source, parent_name) {
symbols.push(sym);
}
}
"type_definition" => {
if let Some(sym) = extract_type_alias(node, source, parent_name) {
symbols.push(sym);
}
}
"import_declaration" => {
let text = node_text(node, source);
let path = text.trim_start_matches("import").trim().trim_end_matches(';').trim();
imports.push(Import {
path: path.to_string(),
alias: None,
span: Span::from_node(node),
});
}
"package_clause" => {
let text = node_text(node, source);
let path = text.trim_start_matches("package").trim().trim_end_matches(';').trim();
if !path.is_empty() {
symbols.push(Symbol {
name: path.to_string(),
kind: SymbolKind::Module,
span: Span::from_node(node),
signature: text.trim().to_string(),
doc_comment: None,
parent: None,
children: Vec::new(),
});
}
}
"template_body" => {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
extract_node(&child, source, symbols, imports, parent_name);
}
}
_ => {}
}
}
/// Builds a `Symbol` for a function node.
///
/// Returns `None` when the node has no `name` field text. The symbol is a
/// `Method` when `parent_name` is present and a `Function` otherwise; its
/// signature is whatever `extract_signature` yields for the `body` field.
fn extract_function(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let fn_name = field_text(node, "name", source)?;

    // A function nested in a type is reported as a method of that type.
    let kind = match parent_name {
        Some(_) => SymbolKind::Method,
        None => SymbolKind::Function,
    };

    Some(Symbol {
        name: String::from(fn_name),
        kind,
        span: Span::from_node(node),
        signature: extract_signature(node, "body", source),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(String::from),
        children: Vec::new(),
    })
}
/// Builds a `Symbol` of the given `kind` for a class, object or trait node,
/// including the members declared in its body.
///
/// Returns `None` when the node has no `name` field text.
///
/// The signature is the source text from the start of the node up to the
/// start of its `body` field (trimmed), or the full node text when there is
/// no body. Members are gathered by running `extract_node` over the body's
/// direct children with this type's name as their parent. Imports found
/// inside the body are collected into a scratch vector and not returned.
fn extract_type_def(
    node: &tree_sitter::Node,
    source: &[u8],
    kind: SymbolKind,
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let type_name = field_text(node, "name", source)?;
    let body = node.child_by_field_name("body");

    let signature = match body.as_ref() {
        Some(b) => {
            let header = &source[node.start_byte()..b.start_byte()];
            // Invalid UTF-8 in the header degrades to an empty signature.
            std::str::from_utf8(header).unwrap_or("").trim().to_string()
        }
        None => node_text(node, source).to_string(),
    };

    let mut members = Vec::new();
    if let Some(b) = body.as_ref() {
        let mut discarded_imports = Vec::new();
        let mut walker = b.walk();
        for member in b.children(&mut walker) {
            extract_node(
                &member,
                source,
                &mut members,
                &mut discarded_imports,
                Some(type_name),
            );
        }
    }

    Some(Symbol {
        name: String::from(type_name),
        kind,
        span: Span::from_node(node),
        signature,
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(String::from),
        children: members,
    })
}
/// Builds a `Const` symbol for a `val`/`var` node.
///
/// The name is taken from the `pattern` field, then the `name` field, then
/// the first direct child of kind `identifier`. Returns `None` when none of
/// those exist or the resulting name is empty. The signature is the trimmed
/// first line of the node's text.
fn extract_val(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let binding = field_text(node, "pattern", source)
        .or_else(|| field_text(node, "name", source))
        .or_else(|| {
            let mut walker = node.walk();
            // Bound to a local so the child iterator is dropped before `walker`.
            let first_ident = node
                .children(&mut walker)
                .find(|c| c.kind() == "identifier")
                .map(|c| node_text(&c, source));
            first_ident
        })
        .filter(|candidate| !candidate.is_empty())?;

    let first_line = node_text(node, source).lines().next().map_or("", str::trim);

    Some(Symbol {
        name: String::from(binding),
        kind: SymbolKind::Const,
        span: Span::from_node(node),
        signature: first_line.to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(String::from),
        children: Vec::new(),
    })
}
/// Builds a `TypeAlias` symbol for a `type` definition node.
///
/// The name is the `name` field text, or failing that the text of the first
/// direct child of kind `identifier` or `type_identifier`. Returns `None`
/// when neither is found. The signature is the node's full trimmed text.
fn extract_type_alias(
    node: &tree_sitter::Node,
    source: &[u8],
    parent_name: Option<&str>,
) -> Option<Symbol> {
    let alias_name = match field_text(node, "name", source) {
        Some(found) => found,
        None => {
            let mut walker = node.walk();
            // Bound to a local so the child iterator is dropped before `walker`.
            let fallback = node
                .children(&mut walker)
                .find(|c| matches!(c.kind(), "identifier" | "type_identifier"))
                .map(|c| node_text(&c, source));
            fallback?
        }
    };

    let full_text = node_text(node, source).trim();

    Some(Symbol {
        name: String::from(alias_name),
        kind: SymbolKind::TypeAlias,
        span: Span::from_node(node),
        signature: full_text.to_string(),
        doc_comment: extract_doc_comment(node, source),
        parent: parent_name.map(String::from),
        children: Vec::new(),
    })
}
// Unit tests that parse small Scala snippets with tree-sitter and check the
// symbols and imports returned by `extract`.
#[cfg(test)]
mod tests {
use super::*;
// Parses `source` with the tree-sitter Scala grammar and runs `extract` on it.
// Panics (via `unwrap`) if the language cannot be set or parsing returns None.
fn parse_scala(source: &str) -> (Vec<Symbol>, Vec<Import>) {
let mut parser = tree_sitter::Parser::new();
parser.set_language(&tree_sitter_scala::LANGUAGE.into()).unwrap();
let tree = parser.parse(source, None).unwrap();
extract(&tree, source.as_bytes())
}
// A package clause, two imports and a class with a val and two methods:
// checks import paths, the Module symbol for the package, and that the
// class's children include both methods.
#[test]
fn test_scala_class_with_methods() {
let source = r#"
package com.example
import scala.collection.mutable.ListBuffer
import scala.math.Pi
class Calculator(initial: Int) {
val name: String = "calc"
def add(x: Int): Int = {
initial + x
}
def multiply(x: Int, y: Int): Int = x * y
}
"#;
let (symbols, imports) = parse_scala(source);
assert_eq!(imports.len(), 2);
assert_eq!(imports[0].path, "scala.collection.mutable.ListBuffer");
assert_eq!(imports[1].path, "scala.math.Pi");
let pkg = symbols.iter().find(|s| s.kind == SymbolKind::Module);
assert!(pkg.is_some());
assert_eq!(pkg.unwrap().name, "com.example");
let calc = symbols.iter().find(|s| s.name == "Calculator").unwrap();
assert_eq!(calc.kind, SymbolKind::Class);
let child_names: Vec<&str> = calc.children.iter().map(|s| s.name.as_str()).collect();
assert!(child_names.contains(&"add"), "expected 'add' in {:?}", child_names);
assert!(child_names.contains(&"multiply"), "expected 'multiply' in {:?}", child_names);
}
// A trait and an object: the trait is reported as Trait, the object as
// Class, and the object's `def` shows up as a Method child.
#[test]
fn test_scala_trait_and_object() {
let source = r#"
trait Drawable {
def draw(): Unit
}
object Singleton {
val instance: String = "single"
def doWork(): Boolean = true
}
"#;
let (symbols, _imports) = parse_scala(source);
let drawable = symbols.iter().find(|s| s.name == "Drawable").unwrap();
assert_eq!(drawable.kind, SymbolKind::Trait);
let singleton = symbols.iter().find(|s| s.name == "Singleton").unwrap();
assert_eq!(singleton.kind, SymbolKind::Class);
let method_names: Vec<&str> = singleton.children.iter()
.filter(|s| s.kind == SymbolKind::Method)
.map(|s| s.name.as_str())
.collect();
assert!(method_names.contains(&"doWork"), "expected 'doWork' in {:?}", method_names);
}
// Top-level type alias, val and def: checks they are reported as
// TypeAlias, Const and Function respectively.
#[test]
fn test_scala_type_alias_and_val() {
let source = r#"
type StringMap = Map[String, String]
val PI: Double = 3.14159
def greet(name: String): String = s"Hello, $name"
"#;
let (symbols, _imports) = parse_scala(source);
let type_alias = symbols.iter().find(|s| s.kind == SymbolKind::TypeAlias);
assert!(type_alias.is_some(), "expected TypeAlias in {:?}", symbols.iter().map(|s| (&s.name, &s.kind)).collect::<Vec<_>>());
assert_eq!(type_alias.unwrap().name, "StringMap");
let pi = symbols.iter().find(|s| s.name == "PI");
assert!(pi.is_some(), "expected PI in {:?}", symbols.iter().map(|s| &s.name).collect::<Vec<_>>());
assert_eq!(pi.unwrap().kind, SymbolKind::Const);
// With no enclosing type, a `def` is a Function rather than a Method.
let greet = symbols.iter().find(|s| s.name == "greet").unwrap();
assert_eq!(greet.kind, SymbolKind::Function);
}
}