use std::sync::OnceLock;
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator, Tree};
use crate::graph::node::{DecoratorInfo, SymbolInfo, SymbolKind, SymbolVisibility};
/// Tree-sitter query locating Rust item declarations.
///
/// Each pattern captures the item's identifier as `@name` and the whole item node as
/// `@symbol`; `extract_rust_symbols` maps the `@symbol` node kind to a `SymbolKind`.
/// The patterns are not anchored to a parent, so items nested inside other items (for
/// example `fn`s inside `impl` blocks or function bodies) match as well.
const SYMBOL_QUERY_RS: &str = r#"
(function_item name: (identifier) @name) @symbol
(struct_item name: (type_identifier) @name) @symbol
(enum_item name: (type_identifier) @name) @symbol
(trait_item name: (type_identifier) @name) @symbol
(type_item name: (type_identifier) @name) @symbol
(const_item name: (identifier) @name) @symbol
(static_item name: (identifier) @name) @symbol
(macro_definition name: (identifier) @name) @symbol
"#;
/// Tree-sitter query locating TypeScript symbol declarations.
///
/// Every pattern captures the declared identifier as `@name` and the node handed to
/// `classify_symbol` as `@symbol`; the exported-variable pattern additionally captures the
/// initializer as `@val`. Declaration patterns are not anchored to a parent, so they also
/// match declarations nested inside other declarations.
const SYMBOL_QUERY_TS: &str = r#"
; Function declarations
(function_declaration
  name: (identifier) @name) @symbol
; Class declarations
(class_declaration
  name: (type_identifier) @name) @symbol
; Interface declarations (TS-only)
(interface_declaration
  name: (type_identifier) @name) @symbol
; Type alias declarations (TS-only)
(type_alias_declaration
  name: (type_identifier) @name) @symbol
; Enum declarations
(enum_declaration
  name: (identifier) @name) @symbol
; Exported arrow-function constants: export const Foo = () => {}
(export_statement
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (arrow_function)))) @symbol
; Top-level non-exported arrow-function constants: const Foo = () => {}
; The lexical_declaration itself is captured as @symbol: the program root cannot be
; classified as a declaration and would give the symbol the whole file's line range.
(program
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (arrow_function))) @symbol)
; Exported variables with any initializer: export const Foo = value
; Function-valued initializers are classified as functions later on.
(export_statement
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (_) @val))) @symbol
"#;
/// Tree-sitter query locating TSX symbol declarations.
///
/// Identical in shape to the TypeScript query but compiled against the TSX grammar. Every
/// pattern captures the declared identifier as `@name` and the node handed to
/// `classify_symbol` as `@symbol`; the exported-variable pattern additionally captures the
/// initializer as `@val`.
const SYMBOL_QUERY_TSX: &str = r#"
; Function declarations
(function_declaration
  name: (identifier) @name) @symbol
; Class declarations
(class_declaration
  name: (type_identifier) @name) @symbol
; Interface declarations (TS-only but TSX grammar supports it)
(interface_declaration
  name: (type_identifier) @name) @symbol
; Type alias declarations (TS-only but TSX grammar supports it)
(type_alias_declaration
  name: (type_identifier) @name) @symbol
; Enum declarations
(enum_declaration
  name: (identifier) @name) @symbol
; Exported arrow-function constants
(export_statement
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (arrow_function)))) @symbol
; Top-level non-exported arrow-function constants
; The lexical_declaration itself is captured as @symbol: the program root cannot be
; classified as a declaration and would give the symbol the whole file's line range.
(program
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (arrow_function))) @symbol)
; Exported variables with any initializer
; Function-valued initializers are classified as functions later on.
(export_statement
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (_) @val))) @symbol
"#;
/// Tree-sitter query locating JavaScript symbol declarations.
///
/// Every pattern captures the declared identifier as `@name` and the node handed to
/// `classify_symbol` as `@symbol`; the exported-variable pattern additionally captures the
/// initializer as `@val`. JavaScript class names are plain `identifier`s (TypeScript uses
/// `type_identifier`).
const SYMBOL_QUERY_JS: &str = r#"
; Function declarations
(function_declaration
  name: (identifier) @name) @symbol
; Class declarations
(class_declaration
  name: (identifier) @name) @symbol
; Exported arrow-function constants
(export_statement
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (arrow_function)))) @symbol
; Top-level non-exported arrow-function constants
; The lexical_declaration itself is captured as @symbol: the program root cannot be
; classified as a declaration and would give the symbol the whole file's line range.
(program
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (arrow_function))) @symbol)
; Exported variables with any initializer
; Function-valued initializers are classified as functions later on.
(export_statement
  (lexical_declaration
    (variable_declarator
      name: (identifier) @name
      value: (_) @val))) @symbol
"#;
// Lazily-compiled symbol queries, one per grammar. Each is built on first use and then
// shared for the rest of the process.
static TS_QUERY: OnceLock<Query> = OnceLock::new();
static TSX_QUERY: OnceLock<Query> = OnceLock::new();
static JS_QUERY: OnceLock<Query> = OnceLock::new();
static RS_SYMBOL_QUERY: OnceLock<Query> = OnceLock::new();

/// Returns the query stored in `cell`, compiling `pattern` for `language` the first time.
///
/// Once the cell is initialised, the `language` passed on later calls is ignored.
///
/// # Panics
/// Panics with `failure` if `pattern` does not compile for `language`.
fn cached_query(
    cell: &'static OnceLock<Query>,
    language: &Language,
    pattern: &str,
    failure: &str,
) -> &'static Query {
    cell.get_or_init(|| Query::new(language, pattern).expect(failure))
}

/// Cached TypeScript symbol query.
fn ts_query(language: &Language) -> &'static Query {
    cached_query(&TS_QUERY, language, SYMBOL_QUERY_TS, "invalid TS symbol query")
}

/// Cached TSX symbol query.
fn tsx_query(language: &Language) -> &'static Query {
    cached_query(&TSX_QUERY, language, SYMBOL_QUERY_TSX, "invalid TSX symbol query")
}

/// Cached JavaScript symbol query.
fn js_query(language: &Language) -> &'static Query {
    cached_query(&JS_QUERY, language, SYMBOL_QUERY_JS, "invalid JS symbol query")
}

/// Cached Rust symbol query.
fn rs_symbol_query(language: &Language) -> &'static Query {
    cached_query(&RS_SYMBOL_QUERY, language, SYMBOL_QUERY_RS, "invalid RS symbol query")
}
/// Returns the source text covered by `node`, or `""` when `utf8_text` fails
/// (e.g. the byte range is not valid UTF-8).
fn node_text<'a>(node: Node<'a>, source: &'a [u8]) -> &'a str {
    match node.utf8_text(source) {
        Ok(text) => text,
        Err(_) => "",
    }
}
/// Reports whether the declaration `node` is exported and whether it is a default export.
///
/// `node` counts as exported only when it *is* an `export_statement`, or when it is the
/// declaration wrapped directly by one, optionally through a TypeScript
/// `ambient_declaration` (`export declare class Foo {}`). Declarations nested inside the body
/// of an exported item are not exported themselves. An example is a helper `function`
/// declared inside `export function outer() { ... }`. The previous version walked every
/// ancestor and wrongly reported such nested declarations as exported.
///
/// Returns `(is_exported, is_default)`. `is_default` is true when the export statement has
/// a direct child whose text is `default`.
fn detect_export(node: Node, source: &[u8]) -> (bool, bool) {
    let mut current = node;
    loop {
        if current.kind() == "export_statement" {
            let mut cursor = current.walk();
            let is_default = current
                .children(&mut cursor)
                .any(|child| node_text(child, source) == "default");
            return (true, is_default);
        }
        // Only climb through wrappers that can sit between `export` and the declaration;
        // any other ancestor means `node` lives inside some other construct's body.
        match current.parent() {
            Some(parent)
                if matches!(parent.kind(), "export_statement" | "ambient_declaration") =>
            {
                current = parent;
            }
            _ => return (false, false),
        }
    }
}
/// Returns true if `node` or any of its descendants is a JSX element, fragment, or
/// self-closing element. The subtree is searched recursively.
fn contains_jsx(node: Node) -> bool {
    const JSX_KINDS: [&str; 3] = ["jsx_element", "jsx_fragment", "jsx_self_closing_element"];
    if JSX_KINDS.contains(&node.kind()) {
        return true;
    }
    let mut cursor = node.walk();
    node.children(&mut cursor).any(contains_jsx)
}
/// Returns true when `node` is a function-valued expression: an arrow function or a
/// `function` expression, including generator function expressions.
///
/// Current tree-sitter-javascript / tree-sitter-typescript grammars name function
/// expressions `function_expression`. `function` is the kind used by older grammar
/// releases and is kept for compatibility. Without `function_expression`, an initializer
/// like `export const f = function () {}` would be classified as a plain variable.
fn is_arrow_or_function_value(node: Node) -> bool {
    matches!(
        node.kind(),
        "arrow_function" | "function_expression" | "function" | "generator_function"
    )
}
/// Maps a matched `@symbol` node to the [`SymbolKind`] it declares, or `None` when the
/// node is not a recognised declaration.
///
/// Functions and function-valued constants become `Component` instead of `Function`
/// when `is_tsx` is set and their body contains JSX. `val_node` is the captured
/// initializer, if any. It decides between `Function`/`Component` and `Variable` for
/// declarations that `find_declaration_kind` labels `exported_variable`.
fn classify_symbol(
    symbol_node: Node,
    name_node: Node,
    val_node: Option<Node>,
    is_tsx: bool,
    _source: &[u8],
) -> Option<SymbolKind> {
    /// `Component` when JSX detection is enabled and the body renders JSX, else `Function`.
    /// The JSX check is only run when `is_tsx` is true.
    fn callable(is_tsx: bool, body_has_jsx: impl FnOnce() -> bool) -> SymbolKind {
        if is_tsx && body_has_jsx() {
            SymbolKind::Component
        } else {
            SymbolKind::Function
        }
    }

    let arrow_kind = || callable(is_tsx, || arrow_body_contains_jsx(symbol_node, name_node));

    let declared = find_declaration_kind(symbol_node, name_node)?;
    let kind = match declared.as_str() {
        "function_declaration" => callable(is_tsx, || function_body_contains_jsx(symbol_node)),
        "class_declaration" => SymbolKind::Class,
        "interface_declaration" => SymbolKind::Interface,
        "type_alias_declaration" => SymbolKind::TypeAlias,
        "enum_declaration" => SymbolKind::Enum,
        "arrow_function_decl" => arrow_kind(),
        "exported_variable" => match val_node {
            Some(val) if is_arrow_or_function_value(val) => arrow_kind(),
            _ => SymbolKind::Variable,
        },
        _ => return None,
    };
    Some(kind)
}
/// Returns a label for the declaration represented by `symbol_node`.
///
/// Declaration nodes yield their own kind. An `export_statement` yields the kind of its
/// first declaration child. A `lexical_declaration` (direct or inside an export) yields
/// `"arrow_function_decl"` or `"exported_variable"` via `classify_lexical_declaration`.
/// Any other node yields `None`.
fn find_declaration_kind(symbol_node: Node, _name_node: Node) -> Option<String> {
    const DECLARATIONS: [&str; 5] = [
        "function_declaration",
        "class_declaration",
        "interface_declaration",
        "type_alias_declaration",
        "enum_declaration",
    ];

    let own_kind = symbol_node.kind();
    if DECLARATIONS.contains(&own_kind) {
        return Some(own_kind.to_owned());
    }

    match own_kind {
        "lexical_declaration" => classify_lexical_declaration(symbol_node),
        "export_statement" => {
            let mut cursor = symbol_node.walk();
            let declaration = symbol_node.children(&mut cursor).find(|child| {
                DECLARATIONS.contains(&child.kind()) || child.kind() == "lexical_declaration"
            })?;
            if declaration.kind() == "lexical_declaration" {
                classify_lexical_declaration(declaration)
            } else {
                Some(declaration.kind().to_owned())
            }
        }
        _ => None,
    }
}
/// Labels a `lexical_declaration` by the initializer of its first declarator that has one.
///
/// The result is `"arrow_function_decl"` for a function-valued initializer and
/// `"exported_variable"` for any other. It is `None` when no declarator has an initializer.
fn classify_lexical_declaration(lex_decl: Node) -> Option<String> {
    let mut cursor = lex_decl.walk();
    let initializer = lex_decl
        .children(&mut cursor)
        .filter(|child| child.kind() == "variable_declarator")
        .find_map(|declarator| declarator.child_by_field_name("value"))?;
    let label = if is_arrow_or_function_value(initializer) {
        "arrow_function_decl"
    } else {
        "exported_variable"
    };
    Some(label.to_owned())
}
fn function_body_contains_jsx(func_node: Node) -> bool {
if let Some(body) = func_node.child_by_field_name("body") {
return contains_jsx(body);
}
false
}
fn arrow_body_contains_jsx(symbol_node: Node, name_node: Node) -> bool {
find_arrow_body(symbol_node, name_node)
.map(contains_jsx)
.unwrap_or(false)
}
/// Depth-first search under `node` for the `variable_declarator` whose `name` is exactly
/// `name_node` (by node id) and whose `value` is function-valued.
///
/// For the first such declarator it returns that value's `body` field, which may itself be
/// absent. The search stops at that declarator either way. Returns `None` if no declarator
/// matches.
fn find_arrow_body<'a>(node: Node<'a>, name_node: Node<'a>) -> Option<Node<'a>> {
    let matching_value = (node.kind() == "variable_declarator")
        .then(|| node.child_by_field_name("name"))
        .flatten()
        .filter(|declared| declared.id() == name_node.id())
        .and_then(|_| node.child_by_field_name("value"))
        .filter(|value| is_arrow_or_function_value(*value));
    if let Some(value) = matching_value {
        return value.child_by_field_name("body");
    }

    let mut cursor = node.walk();
    node.children(&mut cursor)
        .find_map(|child| find_arrow_body(child, name_node))
}
/// Lists the members of an interface as `Property` symbols.
///
/// Members are the `property_signature` and `method_signature` children of the first
/// `interface_body` child of `iface_node`. Each symbol's position is that of the member's
/// name, and `line_end` is the member's last line (1-based). Returns an empty list when
/// there is no `interface_body`.
fn extract_interface_children(iface_node: Node, source: &[u8]) -> Vec<SymbolInfo> {
    let mut outer = iface_node.walk();
    let Some(body) = iface_node
        .children(&mut outer)
        .find(|child| child.kind() == "interface_body")
    else {
        return Vec::new();
    };

    let mut cursor = body.walk();
    body.children(&mut cursor)
        .filter(|member| matches!(member.kind(), "property_signature" | "method_signature"))
        .filter_map(|member| {
            let name_node = member.child_by_field_name("name")?;
            let start = name_node.start_position();
            Some(SymbolInfo {
                name: node_text(name_node, source).to_owned(),
                kind: SymbolKind::Property,
                line: start.row + 1,
                col: start.column,
                line_end: member.end_position().row + 1,
                ..Default::default()
            })
        })
        .collect()
}
/// Lists the methods of a class as `Method` symbols, including their decorators.
///
/// Methods are the named `method_definition` children of the first `class_body` child of
/// `class_node`. Returns an empty list when there is no `class_body`.
fn extract_class_children(class_node: Node, source: &[u8]) -> Vec<SymbolInfo> {
    let mut outer = class_node.walk();
    let Some(body) = class_node
        .children(&mut outer)
        .find(|child| child.kind() == "class_body")
    else {
        return Vec::new();
    };

    let mut methods = Vec::new();
    let mut cursor = body.walk();
    for member in body.children(&mut cursor) {
        if member.kind() != "method_definition" {
            continue;
        }
        let Some(name_node) = member.child_by_field_name("name") else {
            continue;
        };
        let start = name_node.start_position();
        methods.push(SymbolInfo {
            name: node_text(name_node, source).to_owned(),
            kind: SymbolKind::Method,
            line: start.row + 1,
            col: start.column,
            line_end: member.end_position().row + 1,
            decorators: extract_ts_decorators(member, source),
            ..Default::default()
        });
    }
    methods
}
/// Collects the decorators applied to the TypeScript/JavaScript declaration `node`.
///
/// Three placements are tried in order, and the first non-empty result is returned:
/// 1. `decorator` children of `node` itself (e.g. a decorated `class_declaration`);
/// 2. `decorator` children of a declaration wrapped by `node` (e.g. when `node` is an
///    `export_statement`);
/// 3. `decorator` siblings immediately preceding `node` in its parent. This presumably
///    covers grammars that place member decorators as siblings inside `class_body`.
///
/// For placement 3, only the contiguous run directly before `node` is taken, with comments
/// allowed in between. The previous version collected *every* earlier decorator sibling, so
/// a later class member inherited the decorators of all members before it. Decorators are
/// returned in source order.
fn extract_ts_decorators(node: tree_sitter::Node, source: &[u8]) -> Vec<DecoratorInfo> {
    let own = direct_ts_decorators(node, source);
    if !own.is_empty() {
        return own;
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if matches!(
            child.kind(),
            "class_declaration"
                | "function_declaration"
                | "interface_declaration"
                | "type_alias_declaration"
                | "enum_declaration"
        ) {
            let wrapped = direct_ts_decorators(child, source);
            if !wrapped.is_empty() {
                return wrapped;
            }
        }
    }

    // Walk backwards from `node`; any non-decorator, non-comment sibling ends the run.
    let mut preceding = Vec::new();
    let mut sibling = node.prev_sibling();
    while let Some(prev) = sibling {
        match prev.kind() {
            "decorator" => preceding.push(parse_decorator_node(prev, source)),
            "comment" => {}
            _ => break,
        }
        sibling = prev.prev_sibling();
    }
    preceding.reverse();
    preceding
}

/// Parses every direct `decorator` child of `node`, in source order.
fn direct_ts_decorators(node: tree_sitter::Node, source: &[u8]) -> Vec<DecoratorInfo> {
    let mut cursor = node.walk();
    node.children(&mut cursor)
        .filter(|child| child.kind() == "decorator")
        .map(|child| parse_decorator_node(child, source))
        .collect()
}
/// Converts a `decorator` node into a [`DecoratorInfo`] based on its first named child.
///
/// - `identifier` (`@Foo`): `name` is the identifier.
/// - `member_expression` / `attribute` (`@ns.Foo`): `object` and `attribute` are filled
///   from the `object` and `property` (or `attribute`) fields, and `name` is
///   `"object.attribute"`.
/// - `call_expression` (`@Foo(...)`, `@ns.Foo(...)`): the callee is handled as above, and
///   `args_raw` holds the text of the `arguments` field. Any other callee shape uses the
///   whole call text as `name`.
/// - anything else: `name` is the full text of the decorator node.
///
/// `framework` is always `None` here.
fn parse_decorator_node(decorator_node: tree_sitter::Node, source: &[u8]) -> DecoratorInfo {
    let text_of = |n: tree_sitter::Node| node_text(n, source).to_owned();
    let dotted = |object: Option<&str>, attribute: Option<&str>| {
        format!("{}.{}", object.unwrap_or(""), attribute.unwrap_or(""))
    };

    let (name, object, attribute, args_raw) = match decorator_node.named_child(0) {
        Some(ident) if ident.kind() == "identifier" => (text_of(ident), None, None, None),
        Some(member) if matches!(member.kind(), "member_expression" | "attribute") => {
            let object = member.child_by_field_name("object").map(text_of);
            let attribute = member
                .child_by_field_name("property")
                .or_else(|| member.child_by_field_name("attribute"))
                .map(text_of);
            (
                dotted(object.as_deref(), attribute.as_deref()),
                object,
                attribute,
                None,
            )
        }
        Some(call) if call.kind() == "call_expression" => {
            let args = call.child_by_field_name("arguments").map(text_of);
            let (name, object, attribute) = match call.child_by_field_name("function") {
                Some(callee) if callee.kind() == "identifier" => (text_of(callee), None, None),
                Some(callee) if callee.kind() == "member_expression" => {
                    let object = callee.child_by_field_name("object").map(text_of);
                    let attribute = callee.child_by_field_name("property").map(text_of);
                    (
                        dotted(object.as_deref(), attribute.as_deref()),
                        object,
                        attribute,
                    )
                }
                _ => (text_of(call), None, None),
            };
            (name, object, attribute, args)
        }
        _ => (text_of(decorator_node), None, None, None),
    };

    DecoratorInfo {
        name,
        object,
        attribute,
        args_raw,
        framework: None,
    }
}
/// Collects the outer attributes (`#[...]`) attached to the Rust item `item_node`.
///
/// Outer attributes appear as `attribute_item` siblings that precede the item in its
/// parent. Only the contiguous run immediately before `item_node` belongs to it. Comment
/// nodes (`line_comment` / `block_comment`, e.g. doc comments) may be interleaved. Anything
/// else ends the run: attributes further back belong to earlier items. The previous version
/// collected every earlier `attribute_item` in the parent, so e.g. a `fn` following
/// `#[derive(Debug)] struct A;` was reported with `derive`.
///
/// Attributes are returned in source order.
fn extract_rust_attributes(item_node: tree_sitter::Node, source: &[u8]) -> Vec<DecoratorInfo> {
    let mut attrs = Vec::new();
    let mut sibling = item_node.prev_sibling();
    while let Some(prev) = sibling {
        match prev.kind() {
            "attribute_item" => attrs.push(parse_rust_attribute(prev, source)),
            "line_comment" | "block_comment" => {}
            _ => break,
        }
        sibling = prev.prev_sibling();
    }
    // Walked backwards from the item, so restore source order.
    attrs.reverse();
    attrs
}
/// Parses an `attribute_item` node (`#[...]` or `#![...]`) into a [`DecoratorInfo`].
///
/// `name` is the attribute path: the text before the first `(` or `=`, trimmed.
/// `args_raw` is the remainder starting at that delimiter. Examples:
/// - `#[derive(Clone, Debug)]` gives name `derive` and args `(Clone, Debug)`;
/// - `#[path = "x.rs"]` gives name `path` and args `= "x.rs"`;
/// - `#[test]` gives name `test` and no args.
///
/// Splitting at whichever delimiter comes first means parentheses inside a string value
/// (`#[doc = "f(x)"]`) are not mistaken for an argument list. Exactly one leading `#[` /
/// `#![` and one trailing `]` are stripped; the previous greedy trims could strip
/// repeated characters.
fn parse_rust_attribute(attr_item: tree_sitter::Node, source: &[u8]) -> DecoratorInfo {
    let full_text = node_text(attr_item, source).trim();
    let inner = full_text
        .strip_prefix("#![")
        .or_else(|| full_text.strip_prefix("#["))
        .map(|body| body.strip_suffix(']').unwrap_or(body))
        .unwrap_or(full_text);
    let (name, args) = match inner.find(|c: char| c == '(' || c == '=') {
        Some(idx) => {
            let (path, rest) = inner.split_at(idx);
            (path.trim().to_owned(), Some(rest.to_owned()))
        }
        None => (inner.trim().to_owned(), None),
    };
    DecoratorInfo {
        name,
        object: None,
        attribute: None,
        args_raw: args,
        framework: None,
    }
}
/// Selects which symbol query `extract_symbols` runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LangKind {
    /// TypeScript grammar without JSX component detection.
    TypeScript,
    /// TSX grammar; JSX-rendering functions are classified as components.
    Tsx,
    /// JavaScript grammar.
    JavaScript,
}
/// Chooses the query flavour for `language`.
///
/// A grammar named `"javascript"` always selects JavaScript. Any other grammar (including
/// one without a name) selects TSX when `is_tsx` is set and TypeScript otherwise.
fn lang_kind(language: &Language, is_tsx: bool) -> LangKind {
    let is_javascript = language.name() == Some("javascript");
    if is_javascript {
        LangKind::JavaScript
    } else if is_tsx {
        LangKind::Tsx
    } else {
        LangKind::TypeScript
    }
}
pub fn extract_symbols(
tree: &Tree,
source: &[u8],
language: &Language,
is_tsx: bool,
) -> Vec<(SymbolInfo, Vec<SymbolInfo>)> {
let query = match lang_kind(language, is_tsx) {
LangKind::JavaScript => js_query(language),
LangKind::Tsx => tsx_query(language),
LangKind::TypeScript => ts_query(language),
};
let name_idx = query
.capture_index_for_name("name")
.expect("query must have @name capture");
let symbol_idx = query
.capture_index_for_name("symbol")
.expect("query must have @symbol capture");
let val_idx = query.capture_index_for_name("val");
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, tree.root_node(), source);
let mut seen: std::collections::HashSet<(String, usize)> = std::collections::HashSet::new();
let mut results: Vec<(SymbolInfo, Vec<SymbolInfo>)> = Vec::new();
while let Some(m) = matches.next() {
let mut symbol_node: Option<Node> = None;
let mut name_node: Option<Node> = None;
let mut val_node: Option<Node> = None;
for capture in m.captures {
if capture.index == symbol_idx {
symbol_node = Some(capture.node);
} else if capture.index == name_idx {
name_node = Some(capture.node);
} else if val_idx == Some(capture.index) {
val_node = Some(capture.node);
}
}
let (sym_node, name_node) = match (symbol_node, name_node) {
(Some(s), Some(n)) => (s, n),
_ => continue,
};
let name = node_text(name_node, source).to_owned();
let pos = name_node.start_position();
let key = (name.clone(), pos.row);
if !seen.insert(key) {
continue;
}
let kind = match classify_symbol(sym_node, name_node, val_node, is_tsx, source) {
Some(k) => k,
None => continue,
};
if kind == SymbolKind::Variable
&& let Some(val) = val_node
&& is_arrow_or_function_value(val)
{
continue;
}
let (is_exported, is_default) = detect_export(sym_node, source);
let decorators = extract_ts_decorators(sym_node, source);
let info = SymbolInfo {
name,
kind: kind.clone(),
line: pos.row + 1,
col: pos.column,
line_end: sym_node.end_position().row + 1,
is_exported,
is_default,
decorators,
..Default::default()
};
let children = match kind {
SymbolKind::Interface => {
let iface_node = find_declaration_node(sym_node, "interface_declaration");
iface_node
.map(|n| extract_interface_children(n, source))
.unwrap_or_default()
}
SymbolKind::Class => {
let class_node = find_declaration_node(sym_node, "class_declaration");
class_node
.map(|n| extract_class_children(n, source))
.unwrap_or_default()
}
_ => vec![],
};
results.push((info, children));
}
results
}
/// Returns the first node of kind `target_kind` in a pre-order, left-to-right traversal of
/// the subtree rooted at `node` (including `node` itself), or `None` if there is none.
fn find_declaration_node<'a>(node: Node<'a>, target_kind: &str) -> Option<Node<'a>> {
    // Explicit stack: children are pushed in reverse so the leftmost is visited first.
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        if current.kind() == target_kind {
            return Some(current);
        }
        let mut cursor = current.walk();
        let children: Vec<Node<'a>> = current.children(&mut cursor).collect();
        pending.extend(children.into_iter().rev());
    }
    None
}
/// Reads the visibility of a Rust item from its `visibility_modifier` child.
///
/// `pub` maps to `Pub`. Any restricted form starting with `pub(` (`pub(crate)`,
/// `pub(super)`, `pub(in ...)`, ...) maps to `PubCrate`. Otherwise, including when there is
/// no modifier, the result is `Private`.
fn extract_visibility(node: Node, source: &[u8]) -> SymbolVisibility {
    let mut cursor = node.walk();
    node.children(&mut cursor)
        .filter(|child| child.kind() == "visibility_modifier")
        .find_map(|modifier| {
            let text = node_text(modifier, source);
            if text == "pub" {
                Some(SymbolVisibility::Pub)
            } else if text.starts_with("pub(") {
                Some(SymbolVisibility::PubCrate)
            } else {
                None
            }
        })
        .unwrap_or(SymbolVisibility::Private)
}
/// Returns the text naming a Rust type node without its generic arguments.
///
/// For a `generic_type` this is the text of its `type` field (`Vec<T>` → `Vec`), falling
/// back to the whole node. For every other node it is the node's full text.
fn extract_simple_type_name<'a>(type_node: Node<'a>, source: &'a [u8]) -> &'a str {
    let named = if type_node.kind() == "generic_type" {
        type_node.child_by_field_name("type").unwrap_or(type_node)
    } else {
        type_node
    };
    node_text(named, source)
}
/// Lists the methods declared in a trait as `ImplMethod` symbols named `Trait::method`.
///
/// Both required signatures (`function_signature_item`) and default methods
/// (`function_item`) in the trait's first `declaration_list` are included, together with
/// their visibility and attributes. Returns an empty list when the trait has no
/// `declaration_list`.
fn extract_trait_methods(trait_node: Node, trait_name: &str, source: &[u8]) -> Vec<SymbolInfo> {
    let mut outer = trait_node.walk();
    let Some(decl_list) = trait_node
        .children(&mut outer)
        .find(|child| child.kind() == "declaration_list")
    else {
        return Vec::new();
    };

    let mut cursor = decl_list.walk();
    decl_list
        .children(&mut cursor)
        .filter(|item| matches!(item.kind(), "function_signature_item" | "function_item"))
        .filter_map(|item| {
            let name_node = item.child_by_field_name("name")?;
            let start = name_node.start_position();
            Some(SymbolInfo {
                name: format!("{}::{}", trait_name, node_text(name_node, source)),
                kind: SymbolKind::ImplMethod,
                line: start.row + 1,
                col: start.column,
                line_end: item.end_position().row + 1,
                visibility: extract_visibility(item, source),
                decorators: extract_rust_attributes(item, source),
                ..Default::default()
            })
        })
        .collect()
}
/// Extracts Rust items from `tree`: functions, structs, enums, traits, type aliases,
/// consts, statics and macro definitions.
///
/// Each result pairs an item with its children. Only traits have children: their methods,
/// via `extract_trait_methods`. Items are de-duplicated by `(name, row)` of their name node.
///
/// Some `function_item`s are skipped: those inside a trait body, or inside an `impl` block
/// that is a direct child of the file root. They are already reported as qualified
/// `Trait::method` / `Type::method` symbols by `extract_trait_methods` and
/// `extract_impl_methods`. Emitting them here as well would duplicate them and present
/// methods as free functions.
///
/// # Panics
/// Panics if the Rust symbol query lacks an `@name` or `@symbol` capture, which would be a
/// programming error in `SYMBOL_QUERY_RS`.
pub fn extract_rust_symbols(
    tree: &Tree,
    source: &[u8],
    language: &Language,
) -> Vec<(SymbolInfo, Vec<SymbolInfo>)> {
    let query = rs_symbol_query(language);
    let name_idx = query
        .capture_index_for_name("name")
        .expect("RS query must have @name capture");
    let symbol_idx = query
        .capture_index_for_name("symbol")
        .expect("RS query must have @symbol capture");
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(query, tree.root_node(), source);
    let mut seen: std::collections::HashSet<(String, usize)> = std::collections::HashSet::new();
    let mut results: Vec<(SymbolInfo, Vec<SymbolInfo>)> = Vec::new();
    while let Some(m) = matches.next() {
        let mut symbol_node: Option<Node> = None;
        let mut name_node: Option<Node> = None;
        for capture in m.captures {
            if capture.index == symbol_idx {
                symbol_node = Some(capture.node);
            } else if capture.index == name_idx {
                name_node = Some(capture.node);
            }
        }
        let (Some(sym_node), Some(name_node)) = (symbol_node, name_node) else {
            continue;
        };
        // Checked before de-duplication so a skipped method never consumes a `seen` key.
        if sym_node.kind() == "function_item" && is_reported_as_method(sym_node) {
            continue;
        }
        let name = node_text(name_node, source).to_owned();
        let pos = name_node.start_position();
        if !seen.insert((name.clone(), pos.row)) {
            continue;
        }
        let kind = match sym_node.kind() {
            "function_item" => SymbolKind::Function,
            "struct_item" => SymbolKind::Struct,
            "enum_item" => SymbolKind::Enum,
            "trait_item" => SymbolKind::Trait,
            "type_item" => SymbolKind::TypeAlias,
            "const_item" => SymbolKind::Const,
            "static_item" => SymbolKind::Static,
            "macro_definition" => SymbolKind::Macro,
            _ => continue,
        };
        let children = if kind == SymbolKind::Trait {
            extract_trait_methods(sym_node, &name, source)
        } else {
            Vec::new()
        };
        let info = SymbolInfo {
            name,
            kind,
            line: pos.row + 1,
            col: pos.column,
            line_end: sym_node.end_position().row + 1,
            visibility: extract_visibility(sym_node, source),
            decorators: extract_rust_attributes(sym_node, source),
            ..Default::default()
        };
        results.push((info, children));
    }
    results
}

/// Returns true when `fn_node` is a method that another extractor reports.
///
/// That is the case when its enclosing `declaration_list` belongs to a `trait_item`, or to
/// an `impl_item` whose parent is the tree root.
fn is_reported_as_method(fn_node: Node) -> bool {
    let Some(container) = fn_node
        .parent()
        .filter(|list| list.kind() == "declaration_list")
        .and_then(|list| list.parent())
    else {
        return false;
    };
    match container.kind() {
        "trait_item" => true,
        // `extract_impl_methods` only scans impl blocks that are direct children of the root.
        "impl_item" => container
            .parent()
            .is_some_and(|parent| parent.parent().is_none()),
        _ => false,
    }
}
/// Extracts the methods of every `impl` block that is a direct child of the file root.
///
/// Each method becomes an `ImplMethod` symbol named `Type::method`, where `Type` is the
/// implementing type without generic arguments. `trait_impl` holds the implemented trait's
/// name (also without generics) for trait impls. Each result has no children. `impl` blocks
/// without a `type` field or a `declaration_list`, and methods without a name, are skipped.
pub fn extract_impl_methods(tree: &Tree, source: &[u8]) -> Vec<(SymbolInfo, Vec<SymbolInfo>)> {
    let root = tree.root_node();
    let mut top_cursor = root.walk();
    let mut methods = Vec::new();

    for impl_node in root
        .children(&mut top_cursor)
        .filter(|node| node.kind() == "impl_item")
    {
        let Some(self_ty) = impl_node.child_by_field_name("type") else {
            continue;
        };
        let type_name = extract_simple_type_name(self_ty, source).to_owned();
        let trait_name = impl_node
            .child_by_field_name("trait")
            .map(|trait_ty| extract_simple_type_name(trait_ty, source).to_owned());

        let mut body_cursor = impl_node.walk();
        let Some(decl_list) = impl_node
            .children(&mut body_cursor)
            .find(|node| node.kind() == "declaration_list")
        else {
            continue;
        };

        let mut item_cursor = decl_list.walk();
        for method in decl_list
            .children(&mut item_cursor)
            .filter(|node| node.kind() == "function_item")
        {
            let Some(name_node) = method.child_by_field_name("name") else {
                continue;
            };
            let start = name_node.start_position();
            let info = SymbolInfo {
                name: format!("{}::{}", type_name, node_text(name_node, source)),
                kind: SymbolKind::ImplMethod,
                line: start.row + 1,
                col: start.column,
                line_end: method.end_position().row + 1,
                visibility: extract_visibility(method, source),
                trait_impl: trait_name.clone(),
                decorators: extract_rust_attributes(method, source),
                ..Default::default()
            };
            methods.push((info, Vec::new()));
        }
    }
    methods
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::languages::language_for_extension;

    /// Parses `source` with the grammar registered for the file extension `ext`.
    fn parse_ext(ext: &str, source: &str) -> (tree_sitter::Tree, Language) {
        let lang = language_for_extension(ext).unwrap();
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&lang).unwrap();
        let tree = parser.parse(source.as_bytes(), None).unwrap();
        (tree, lang)
    }

    /// Runs the TS/TSX/JS extractor over `source` parsed with the `ext` grammar.
    fn symbols_for(ext: &str, source: &str, is_tsx: bool) -> Vec<(SymbolInfo, Vec<SymbolInfo>)> {
        let (tree, lang) = parse_ext(ext, source);
        extract_symbols(&tree, source.as_bytes(), &lang, is_tsx)
    }

    /// Runs the Rust item extractor over `source`.
    fn rust_symbols(source: &str) -> Vec<(SymbolInfo, Vec<SymbolInfo>)> {
        let (tree, lang) = parse_ext("rs", source);
        extract_rust_symbols(&tree, source.as_bytes(), &lang)
    }

    /// The first extracted symbol; panics when nothing was extracted.
    fn first_symbol(results: &[(SymbolInfo, Vec<SymbolInfo>)]) -> &SymbolInfo {
        match results.first() {
            Some((sym, _)) => sym,
            None => panic!("expected at least one symbol, got none"),
        }
    }

    #[test]
    fn test_export_function_declaration() {
        let results = symbols_for("ts", "export function hello() {}", false);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "hello");
        assert_eq!(sym.kind, SymbolKind::Function);
        assert!(sym.is_exported, "should be exported");
    }

    #[test]
    fn test_export_const_arrow_function() {
        let results = symbols_for("ts", "export const greet = () => {};", false);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "greet");
        assert_eq!(sym.kind, SymbolKind::Function);
        assert!(sym.is_exported, "should be exported");
    }

    #[test]
    fn test_class_declaration() {
        let results = symbols_for("ts", "class MyClass {}", false);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "MyClass");
        assert_eq!(sym.kind, SymbolKind::Class);
        assert!(!sym.is_exported);
    }

    #[test]
    fn test_interface_with_children() {
        let results = symbols_for(
            "ts",
            "interface IUser { name: string; getId(): number; }",
            false,
        );
        let (sym, children) = results.first().expect("expected interface symbol");
        assert_eq!(sym.name, "IUser");
        assert_eq!(sym.kind, SymbolKind::Interface);
        assert_eq!(children.len(), 2, "expected 2 child symbols (name, getId)");
        let child_names: Vec<&str> = children.iter().map(|c| c.name.as_str()).collect();
        assert!(child_names.contains(&"name"), "missing 'name' child");
        assert!(child_names.contains(&"getId"), "missing 'getId' child");
        assert!(
            children.iter().all(|c| c.kind == SymbolKind::Property),
            "all children should be Property kind"
        );
    }

    #[test]
    fn test_type_alias() {
        let results = symbols_for("ts", "type ID = string;", false);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "ID");
        assert_eq!(sym.kind, SymbolKind::TypeAlias);
    }

    #[test]
    fn test_enum_declaration() {
        let results = symbols_for("ts", "enum Color { Red, Blue }", false);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "Color");
        assert_eq!(sym.kind, SymbolKind::Enum);
    }

    #[test]
    fn test_tsx_component_detection() {
        let results = symbols_for("tsx", "export const App = () => <div/>;", true);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "App");
        assert_eq!(sym.kind, SymbolKind::Component);
        assert!(sym.is_exported);
    }

    #[test]
    fn test_tsx_non_component_arrow_fn() {
        let results = symbols_for(
            "tsx",
            "export const add = (a: number, b: number) => a + b;",
            true,
        );
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "add");
        assert_eq!(sym.kind, SymbolKind::Function);
    }

    #[test]
    fn test_class_with_methods() {
        let results = symbols_for("ts", "class Dog { bark() {} sit() {} }", false);
        let (sym, children) = results.first().expect("expected class");
        assert_eq!(sym.kind, SymbolKind::Class);
        assert_eq!(children.len(), 2, "expected 2 methods");
        assert!(children.iter().all(|c| c.kind == SymbolKind::Method));
    }

    #[test]
    fn test_ts_decorator_extraction() {
        let results = symbols_for("ts", "@Controller\nclass AppController {}", false);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "AppController");
        assert_eq!(
            sym.decorators.len(),
            1,
            "expected 1 decorator, got {:?}",
            sym.decorators
        );
        let dec = &sym.decorators[0];
        assert_eq!(dec.name, "Controller");
        assert!(dec.object.is_none());
        assert!(dec.attribute.is_none());
        assert!(dec.args_raw.is_none());
    }

    #[test]
    fn test_ts_attribute_decorator() {
        let results = symbols_for("ts", "@Injectable()\nclass MyService {}", false);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "MyService");
        assert_eq!(
            sym.decorators.len(),
            1,
            "expected 1 decorator, got {:?}",
            sym.decorators
        );
        let dec = &sym.decorators[0];
        assert_eq!(dec.name, "Injectable");
        assert!(
            dec.args_raw.is_some(),
            "expected args_raw to be Some for call decorator"
        );
    }

    #[test]
    fn test_rust_derive_decorator() {
        let results = rust_symbols("#[derive(Clone, Debug)]\npub struct MyStruct {}");
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "MyStruct");
        assert_eq!(
            sym.decorators.len(),
            1,
            "expected 1 attribute, got {:?}",
            sym.decorators
        );
        let attr = &sym.decorators[0];
        assert_eq!(attr.name, "derive");
        assert!(
            attr.args_raw.is_some(),
            "expected args_raw for derive attribute"
        );
        let args = attr.args_raw.as_deref().unwrap();
        for expected in ["Clone", "Debug"] {
            assert!(
                args.contains(expected),
                "args_raw should contain '{}', got '{}'",
                expected,
                args
            );
        }
    }

    #[test]
    fn test_rust_route_decorator() {
        let results = rust_symbols("#[get(\"/path\")]\npub fn get_path() {}");
        let sym = first_symbol(&results);
        assert_eq!(
            sym.decorators.len(),
            1,
            "expected 1 attribute, got {:?}",
            sym.decorators
        );
        let attr = &sym.decorators[0];
        assert_eq!(attr.name, "get");
        assert!(
            attr.args_raw.is_some(),
            "expected args_raw for get attribute"
        );
    }

    #[test]
    fn test_line_end_ts() {
        let results = symbols_for("ts", "export function hello() {\n return 42;\n}", false);
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "hello");
        assert!(
            sym.line_end > sym.line,
            "line_end ({}) should be > line ({}) for multi-line function",
            sym.line_end,
            sym.line
        );
    }

    #[test]
    fn test_line_end_rust() {
        let results = rust_symbols("pub fn hello() {\n let x = 1;\n x\n}");
        let sym = first_symbol(&results);
        assert_eq!(sym.name, "hello");
        assert!(
            sym.line_end > sym.line,
            "line_end ({}) should be > line ({}) for multi-line function",
            sym.line_end,
            sym.line
        );
    }

    #[test]
    fn test_stacked_decorators() {
        let results = symbols_for("ts", "@Controller\n@Injectable\nclass AppService {}", false);
        let sym = first_symbol(&results);
        assert_eq!(
            sym.decorators.len(),
            2,
            "expected 2 decorators, got {:?}",
            sym.decorators
        );
        assert_eq!(sym.decorators[0].name, "Controller");
        assert_eq!(sym.decorators[1].name, "Injectable");
    }
}