use crate::models::{Class, Function};
use crate::parsers::ParseResult;
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::path::Path;
use streaming_iterator::StreamingIterator;
use tree_sitter::{Node, Parser, Query, QueryCursor, Language};
/// Reads the file at `path` and parses it as TypeScript/JavaScript.
///
/// The file extension (if any, and if valid UTF-8) is forwarded to
/// [`parse_source`] to select the grammar; a missing extension becomes `""`.
///
/// # Errors
/// Fails if the file cannot be read into a `String`, or if [`parse_source`] fails.
pub fn parse(path: &Path) -> Result<ParseResult> {
    let ext = match path.extension().and_then(|e| e.to_str()) {
        Some(e) => e,
        None => "",
    };
    let source = std::fs::read_to_string(path)
        .with_context(|| format!("Failed to read file: {}", path.display()))?;
    parse_source(&source, path, ext)
}
/// Parses TypeScript/JavaScript `source` and extracts functions, classes,
/// imports and call edges into a [`ParseResult`].
///
/// `ext` is the file extension without the dot and is compared
/// case-insensitively:
/// - `tsx` uses the TSX grammar;
/// - `js`, `jsx`, `mjs`, `cjs` use the JavaScript grammar;
/// - anything else (`ts`, `mts`, `cts`, unknown) uses the TypeScript grammar.
///
/// `path` is only used to fill in file paths and qualified names.
///
/// # Errors
/// Fails if the grammar cannot be set on the parser, if tree-sitter returns no
/// tree, or if one of the extraction queries fails to compile.
pub fn parse_source(source: &str, path: &Path, ext: &str) -> Result<ParseResult> {
    // Match case-insensitively: `Component.TSX` contains JSX and must not be
    // handed to the plain TypeScript grammar just because of its casing.
    let language: Language = match ext.to_ascii_lowercase().as_str() {
        "tsx" => tree_sitter_typescript::LANGUAGE_TSX.into(),
        "js" | "jsx" | "mjs" | "cjs" => tree_sitter_javascript::LANGUAGE.into(),
        _ => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
    };

    let mut parser = Parser::new();
    parser
        .set_language(&language)
        .context("Failed to set TypeScript/JavaScript language")?;
    let tree = parser
        .parse(source, None)
        .context("Failed to parse source")?;

    let root = tree.root_node();
    let source_bytes = source.as_bytes();
    let mut result = ParseResult::default();
    extract_functions(&root, source_bytes, path, &mut result, &language)?;
    // Also appends class methods to `result.functions`.
    extract_classes(&root, source_bytes, path, &mut result, &language)?;
    extract_imports(&root, source_bytes, &mut result, &language)?;
    // Builds its scope map from `result.functions`, so it must run last.
    extract_calls(&root, source_bytes, path, &mut result)?;
    Ok(result)
}
/// Collects named, non-method function definitions into `result.functions`.
///
/// Recognised forms (exported or not, at any nesting depth):
/// - `function name(...) { ... }` and `function* name(...) { ... }`
/// - `const name = (...) => ...` and the single-parameter form `const name = x => ...`
/// - `const name = function (...) { ... }`
///
/// Definitions located anywhere inside a `class_body` are skipped; class
/// methods are collected by `extract_classes`. Qualified names have the form
/// `<path>::<name>:<line_start>`.
///
/// # Errors
/// Returns an error if the function query cannot be compiled for `language`.
fn extract_functions(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    language: &Language,
) -> Result<()> {
    // Each definition node is matched by exactly one pattern. Patterns match at
    // any depth, so `export function f` is already covered by the first one; an
    // extra `export_statement` pattern, or a bare `arrow_function` pattern next to
    // the `variable_declarator` one, would report the same function twice.
    let query_str = r#"
(function_declaration
  name: (identifier) @name) @func
(generator_function_declaration
  name: (identifier) @name) @func
(variable_declarator
  name: (identifier) @name
  value: [(arrow_function) (function_expression)] @func)
"#;
    let query = Query::new(language, query_str).context("Failed to create function query")?;
    let capture_names = query.capture_names();
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, *root, source);
    while let Some(m) = matches.next() {
        let mut func_node = None;
        let mut name = "";
        for capture in m.captures {
            match capture_names[capture.index as usize] {
                "func" => func_node = Some(capture.node),
                "name" => name = capture.node.utf8_text(source).unwrap_or(""),
                _ => {}
            }
        }
        let Some(node) = func_node else {
            continue;
        };
        if name.is_empty() || is_inside_class(&node) {
            continue;
        }

        // Regular functions and `(a, b) => ...` keep their list under `parameters`;
        // the single-identifier arrow form `a => ...` uses the `parameter` field.
        let params_node = node
            .child_by_field_name("parameters")
            .or_else(|| node.child_by_field_name("parameter"));
        let line_start = node.start_position().row as u32 + 1;
        let line_end = node.end_position().row as u32 + 1;

        result.functions.push(Function {
            name: name.to_string(),
            qualified_name: format!("{}::{}:{}", path.display(), name, line_start),
            file_path: path.to_path_buf(),
            line_start,
            line_end,
            parameters: extract_parameters(params_node, source),
            return_type: extract_return_type(&node, source),
            is_async: is_async_function(&node, source),
            complexity: Some(calculate_complexity(&node, source)),
        });
    }
    Ok(())
}
fn is_inside_class(node: &Node) -> bool {
let mut current = node.parent();
while let Some(parent) = current {
if parent.kind() == "class_body" {
return true;
}
current = parent.parent();
}
false
}
/// Reports whether a function-like `node` is `async`.
///
/// The check passes in either of two cases:
/// - the node's source text begins with `async` followed by a space or newline;
/// - the node has a direct child token of kind `async`.
fn is_async_function(node: &Node, source: &[u8]) -> bool {
    let starts_with_keyword = node
        .utf8_text(source)
        .map(|text| text.starts_with("async ") || text.starts_with("async\n"))
        .unwrap_or(false);
    if starts_with_keyword {
        return true;
    }
    let mut cursor = node.walk();
    let has_async_token = node
        .children(&mut cursor)
        .any(|child| child.kind() == "async");
    has_async_token
}
/// Returns the declared return type of a function-like node, without the
/// leading `:` (e.g. `Promise<string>`), or `None` if none is declared.
///
/// The TypeScript grammar stores the annotation in the `return_type` field.
/// That field can hold a plain `type_annotation`, a type predicate
/// (`x is Foo`) or an `asserts` annotation; reading the field covers all three.
/// If the field is absent (e.g. the JavaScript grammar, which has no such
/// field), the first direct `type_annotation` child is used instead.
fn extract_return_type(func_node: &Node, source: &[u8]) -> Option<String> {
    let annotation = func_node.child_by_field_name("return_type").or_else(|| {
        let mut cursor = func_node.walk();
        let found = func_node
            .children(&mut cursor)
            .find(|child| child.kind() == "type_annotation");
        found
    })?;
    annotation
        .utf8_text(source)
        .ok()
        .map(|text| text.trim_start_matches(':').trim().to_string())
}
fn extract_parameters(params_node: Option<Node>, source: &[u8]) -> Vec<String> {
let Some(node) = params_node else {
return vec![];
};
if node.kind() == "identifier" {
return node.utf8_text(source)
.ok()
.map(|s| vec![s.to_string()])
.unwrap_or_default();
}
let mut params = Vec::new();
for child in node.children(&mut node.walk()) {
match child.kind() {
"identifier" => {
if let Ok(text) = child.utf8_text(source) {
params.push(text.to_string());
}
}
"required_parameter" | "optional_parameter" => {
if let Some(pattern) = child.child_by_field_name("pattern") {
if let Ok(text) = pattern.utf8_text(source) {
let name = if child.kind() == "optional_parameter" {
format!("{}?", text)
} else {
text.to_string()
};
params.push(name);
}
}
}
"rest_parameter" => {
if let Some(pattern) = child.child_by_field_name("pattern") {
if let Ok(text) = pattern.utf8_text(source) {
params.push(format!("...{}", text));
}
}
}
"assignment_pattern" => {
if let Some(left) = child.child_by_field_name("left") {
if let Ok(text) = left.utf8_text(source) {
params.push(text.to_string());
}
}
}
_ => {}
}
}
params
}
/// Collects classes, interfaces and type aliases into `result.classes`.
///
/// For each class, the method definitions are also parsed into
/// `result.functions` (see `extract_class_methods`). Qualified names have the
/// form `<path>::<kind>::<name>:<line_start>`, where `kind` is `class`,
/// `interface` or `type`.
///
/// Per-kind handling:
/// - Classes: record their heritage (bases) and member names.
/// - Interfaces: record their heritage and their method/property signature names.
/// - Type aliases: record neither bases nor members.
///
/// The TypeScript query is tried first. If it fails to compile (as with the
/// JavaScript grammar, which lacks interfaces and type aliases), a
/// class-only JavaScript query is used instead.
///
/// # Errors
/// Returns an error if neither query compiles for `language`.
fn extract_classes(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    language: &Language,
) -> Result<()> {
    // Patterns match at any depth, so exported declarations are already covered.
    // A separate `export_statement` pattern would report every exported class
    // twice and push its methods into `result.functions` twice.
    let ts_query_str = r#"
(class_declaration
  name: (type_identifier) @class_name) @class
(interface_declaration
  name: (type_identifier) @iface_name) @interface
(type_alias_declaration
  name: (type_identifier) @type_name) @type_alias
"#;
    let js_query_str = r#"
(class_declaration
  name: (identifier) @class_name) @class
"#;
    let query = Query::new(language, ts_query_str)
        .or_else(|_| Query::new(language, js_query_str))
        .context("Failed to create class query")?;
    let capture_names = query.capture_names();
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, *root, source);
    while let Some(m) = matches.next() {
        let mut decl_node = None;
        let mut kind = "class";
        let mut name = "";
        for capture in m.captures {
            match capture_names[capture.index as usize] {
                "class" => decl_node = Some(capture.node),
                "interface" => {
                    decl_node = Some(capture.node);
                    kind = "interface";
                }
                "type_alias" => {
                    decl_node = Some(capture.node);
                    kind = "type";
                }
                "class_name" | "iface_name" | "type_name" => {
                    name = capture.node.utf8_text(source).unwrap_or("");
                }
                _ => {}
            }
        }
        let Some(node) = decl_node else {
            continue;
        };

        let line_start = node.start_position().row as u32 + 1;
        let line_end = node.end_position().row as u32 + 1;
        let (bases, methods) = match kind {
            "class" => {
                // Class methods are also recorded as standalone functions.
                extract_class_methods(&node, source, path, result, name);
                (
                    extract_class_heritage(&node, source),
                    extract_method_names(&node, source),
                )
            }
            "interface" => (
                extract_class_heritage(&node, source),
                extract_interface_methods(&node, source),
            ),
            _ => (Vec::new(), Vec::new()),
        };

        result.classes.push(Class {
            name: name.to_string(),
            qualified_name: format!("{}::{}::{}:{}", path.display(), kind, name, line_start),
            file_path: path.to_path_buf(),
            line_start,
            line_end,
            methods,
            bases,
        });
    }
    Ok(())
}
fn extract_class_heritage(class_node: &Node, source: &[u8]) -> Vec<String> {
let mut bases = Vec::new();
for child in class_node.children(&mut class_node.walk()) {
if child.kind() == "class_heritage" {
for heritage_child in child.children(&mut child.walk()) {
if heritage_child.kind() == "extends_clause" || heritage_child.kind() == "implements_clause" {
for type_child in heritage_child.children(&mut heritage_child.walk()) {
if type_child.kind() == "type_identifier" || type_child.kind() == "generic_type" {
if let Ok(text) = type_child.utf8_text(source) {
bases.push(text.to_string());
}
}
}
}
}
}
}
bases
}
/// Lists the names of a class's members, in source order.
///
/// Both `method_definition` and `public_field_definition` members found
/// directly in the class `body` are included. Returns an empty list if the
/// node has no `body` field.
fn extract_method_names(class_node: &Node, source: &[u8]) -> Vec<String> {
    let Some(body) = class_node.child_by_field_name("body") else {
        return Vec::new();
    };
    let mut cursor = body.walk();
    let names: Vec<String> = body
        .children(&mut cursor)
        .filter(|member| matches!(member.kind(), "method_definition" | "public_field_definition"))
        .filter_map(|member| member.child_by_field_name("name"))
        .filter_map(|name_node| name_node.utf8_text(source).ok().map(str::to_string))
        .collect();
    names
}
/// Lists the member names declared in an interface body, in source order.
///
/// Both `method_signature` and `property_signature` members are included.
/// Returns an empty list if the node has no `body` field.
fn extract_interface_methods(iface_node: &Node, source: &[u8]) -> Vec<String> {
    let Some(body) = iface_node.child_by_field_name("body") else {
        return Vec::new();
    };
    let mut cursor = body.walk();
    let members: Vec<String> = body
        .children(&mut cursor)
        .filter(|member| matches!(member.kind(), "method_signature" | "property_signature"))
        .filter_map(|member| member.child_by_field_name("name"))
        .filter_map(|name_node| name_node.utf8_text(source).ok().map(str::to_string))
        .collect();
    members
}
/// Parses every `method_definition` directly inside the class `body` and
/// appends the resulting [`Function`]s to `result.functions`, in source order.
///
/// Methods that `parse_method_node` rejects (e.g. no readable name) are skipped.
/// Does nothing if the class node has no `body` field.
fn extract_class_methods(
    class_node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    class_name: &str,
) {
    let Some(body) = class_node.child_by_field_name("body") else {
        return;
    };
    let mut cursor = body.walk();
    let methods = body
        .children(&mut cursor)
        .filter(|member| member.kind() == "method_definition")
        .filter_map(|member| parse_method_node(&member, source, path, class_name));
    result.functions.extend(methods);
}
/// Builds a [`Function`] for a class `method_definition` node.
///
/// The qualified name has the form `<path>::<Class>.<method>:<line_start>`.
/// Returns `None` if the node has no `name` field or its text is not valid UTF-8.
fn parse_method_node(node: &Node, source: &[u8], path: &Path, class_name: &str) -> Option<Function> {
    let name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .ok()?
        .to_owned();
    let line_start = node.start_position().row as u32 + 1;
    let line_end = node.end_position().row as u32 + 1;
    Some(Function {
        qualified_name: format!("{}::{}.{}:{}", path.display(), class_name, name, line_start),
        parameters: extract_parameters(node.child_by_field_name("parameters"), source),
        return_type: extract_return_type(node, source),
        is_async: is_async_function(node, source),
        complexity: Some(calculate_complexity(node, source)),
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        name,
    })
}
/// Records module specifiers in `result.imports`, without quotes.
///
/// Both `import ... from '<src>'` statements and re-exports
/// (`export ... from '<src>'`) are covered. Empty specifiers and specifiers
/// already present are skipped, so first-occurrence order is preserved.
///
/// # Errors
/// Returns an error if the import query cannot be compiled for `language`.
fn extract_imports(root: &Node, source: &[u8], result: &mut ParseResult, language: &Language) -> Result<()> {
    // Capture only the whole `string` node. The original code also captured its
    // `string_fragment` children, which is redundant for plain specifiers. For a
    // specifier containing an escape sequence (several fragments), it recorded
    // each fragment as a bogus partial import.
    let query_str = r#"
(import_statement
  source: (string) @import_source)
(export_statement
  source: (string) @export_source)
"#;
    let query = Query::new(language, query_str).context("Failed to create import query")?;
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, *root, source);
    while let Some(m) = matches.next() {
        for capture in m.captures {
            let Ok(text) = capture.node.utf8_text(source) else {
                continue;
            };
            let import = text.trim_matches(|c: char| c == '"' || c == '\'');
            if !import.is_empty() && !result.imports.iter().any(|known| known == import) {
                result.imports.push(import.to_string());
            }
        }
    }
    Ok(())
}
/// Appends `(caller, callee)` edges to `result.calls` for every call and `new`
/// expression in the tree that lies within a known function.
///
/// Scopes are taken from `result.functions` (keyed by their inclusive line
/// range), so this must run after functions and methods have been collected.
/// If two functions share the same line range, the later one in
/// `result.functions` wins.
fn extract_calls(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let scope_map: HashMap<(u32, u32), String> = result
        .functions
        .iter()
        .map(|func| ((func.line_start, func.line_end), func.qualified_name.clone()))
        .collect();
    extract_calls_recursive(root, source, path, &scope_map, result);
    Ok(())
}
/// Walks `node` and its descendants in pre-order and records call edges.
///
/// One `(caller, callee)` pair is appended to `result.calls` for each
/// `call_expression` or `new_expression` whose start line falls inside a scope
/// from `scope_map`. Callees are handled as follows:
/// - For `new Foo(...)`, the callee is recorded as `new Foo`.
/// - For ordinary calls, the callee comes from `extract_call_target`.
///
/// Expressions without a resolvable callee or enclosing scope are skipped.
fn extract_calls_recursive(
    node: &Node,
    source: &[u8],
    path: &Path,
    scope_map: &HashMap<(u32, u32), String>,
    result: &mut ParseResult,
) {
    let callee = match node.kind() {
        "call_expression" => node
            .child_by_field_name("function")
            .and_then(|target| extract_call_target(&target, source)),
        "new_expression" => node
            .child_by_field_name("constructor")
            .and_then(|ctor| ctor.utf8_text(source).ok())
            .map(|ctor| format!("new {}", ctor)),
        _ => None,
    };
    if let Some(callee) = callee {
        let line = node.start_position().row as u32 + 1;
        if let Some(caller) = find_containing_scope(line, scope_map) {
            result.calls.push((caller, callee));
        }
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_calls_recursive(&child, source, path, scope_map, result);
    }
}
/// Returns the name of the innermost scope whose inclusive line range
/// `(start, end)` contains `line`, or `None` if no range contains it.
///
/// "Innermost" means the smallest `end - start` span. Several candidates can
/// share that span, e.g. one function ends on the line where the next begins.
/// In that case the one that starts later wins. The result therefore no longer
/// depends on `HashMap` iteration order.
fn find_containing_scope(line: u32, scope_map: &HashMap<(u32, u32), String>) -> Option<String> {
    scope_map
        .iter()
        .filter(|((start, end), _)| *start <= line && line <= *end)
        .min_by(|((start_a, end_a), _), ((start_b, end_b), _)| {
            // `start <= end` is guaranteed by the filter, so no underflow.
            let span_a = end_a - start_a;
            let span_b = end_b - start_b;
            span_a.cmp(&span_b).then_with(|| start_b.cmp(start_a))
        })
        .map(|(_, name)| name.clone())
}
/// Returns the textual target of a call's `function` node.
///
/// For `obj[key](...)` (a `subscript_expression`), only the indexed object's
/// text is returned, or `None` if it has no `object` field. Every other node
/// kind (identifiers, member expressions, ...) yields its full source text.
fn extract_call_target(node: &Node, source: &[u8]) -> Option<String> {
    let target = if node.kind() == "subscript_expression" {
        node.child_by_field_name("object")?
    } else {
        *node
    };
    target.utf8_text(source).ok().map(str::to_string)
}
/// Estimates the cyclomatic complexity of a function-like `node`.
///
/// Starts at 1 and adds one per branching construct among the node's
/// descendants:
/// - `if` statements and loops;
/// - `switch` cases and defaults;
/// - `catch` clauses and ternaries;
/// - each `&&`, `||` or `??` operator;
/// - optional chaining;
/// - nested arrow functions and function expressions.
///
/// The node itself is not counted. An arrow function or function expression
/// therefore starts at 1, just like a function declaration or method (the
/// previous version gave those roots an extra +1).
fn calculate_complexity(node: &Node, _source: &[u8]) -> u32 {
    /// Complexity contributed by `node` alone, excluding its descendants.
    fn own_weight(node: &Node) -> u32 {
        match node.kind() {
            "if_statement" | "while_statement" | "for_statement" | "for_in_statement"
            | "do_statement" | "switch_case" | "switch_default" | "catch_clause"
            | "ternary_expression" | "arrow_function" | "function_expression"
            | "optional_chain" => 1,
            // Short-circuiting operators are decision points; the operator is an
            // anonymous child token of the binary expression.
            "binary_expression" => {
                let mut cursor = node.walk();
                let operators = node
                    .children(&mut cursor)
                    .filter(|child| matches!(child.kind(), "&&" | "||" | "??"))
                    .count();
                operators as u32
            }
            _ => 0,
        }
    }

    /// Sum of `own_weight` over every descendant of `node`.
    fn descendants_weight(node: &Node) -> u32 {
        let mut cursor = node.walk();
        let total: u32 = node
            .children(&mut cursor)
            .map(|child| own_weight(&child) + descendants_weight(&child))
            .sum();
        total
    }

    1 + descendants_weight(node)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parses `source` as though it were read from `file_name`, choosing the
    /// grammar from `ext`; panics if parsing fails.
    fn parse_as(source: &str, file_name: &str, ext: &str) -> ParseResult {
        let path = PathBuf::from(file_name);
        parse_source(source, &path, ext).unwrap()
    }

    #[test]
    fn test_parse_simple_function() {
        let source = r#"
function hello(name: string): string {
return `Hello, ${name}!`;
}
"#;
        let result = parse_as(source, "test.ts", "ts");
        assert_eq!(result.functions.len(), 1);
        assert_eq!(result.functions[0].name, "hello");
    }

    #[test]
    fn test_parse_async_function() {
        let source = r#"
async function fetchData(url: string): Promise<string> {
return await fetch(url);
}
"#;
        let result = parse_as(source, "test.ts", "ts");
        assert_eq!(result.functions.len(), 1);
        assert!(result.functions[0].is_async);
    }

    #[test]
    fn test_parse_arrow_function() {
        let source = r#"
const add = (a: number, b: number): number => a + b;
"#;
        let result = parse_as(source, "test.ts", "ts");
        assert!(result.functions.iter().any(|f| f.name == "add"));
    }

    #[test]
    fn test_parse_class() {
        let source = r#"
class MyClass extends BaseClass implements Interface {
constructor() {
super();
}
method(): void {
console.log("hello");
}
}
"#;
        let result = parse_as(source, "test.ts", "ts");
        assert_eq!(result.classes.len(), 1);
        assert_eq!(result.classes[0].name, "MyClass");
    }

    #[test]
    fn test_parse_interface() {
        let source = r#"
interface MyInterface {
name: string;
doSomething(): void;
}
"#;
        let result = parse_as(source, "test.ts", "ts");
        assert_eq!(result.classes.len(), 1);
        assert_eq!(result.classes[0].name, "MyInterface");
    }

    #[test]
    fn test_parse_imports() {
        let source = r#"
import { Component } from 'react';
import axios from 'axios';
import * as fs from 'fs';
export function main() {}
"#;
        let result = parse_as(source, "test.ts", "ts");
        assert!(result.imports.iter().any(|i| i == "react"));
        assert!(result.imports.iter().any(|i| i == "axios"));
    }

    #[test]
    fn test_parse_javascript() {
        let source = r#"
function greet(name) {
return "Hello, " + name;
}
"#;
        let result = parse_as(source, "test.js", "js");
        assert_eq!(result.functions.len(), 1);
        assert_eq!(result.functions[0].name, "greet");
    }
}