use crate::types::Language;
use tree_sitter::Node;
fn find_function_in_node<'a>(node: Node<'a>, max_depth: usize) -> Option<Node<'a>> {
if max_depth == 0 {
return None;
}
if node.kind() == "function_definition" {
return Some(node);
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if let Some(func) = find_function_in_node(child, max_depth - 1) {
return Some(func);
}
}
None
}
pub fn find_function_node<'a>(
root: Node<'a>,
function_name: &str,
language: Language,
source: &str,
) -> Option<Node<'a>> {
if function_name.contains('.') && !matches!(language, Language::Lua | Language::Luau) {
let parts: Vec<&str> = function_name.split('.').collect();
if parts.len() >= 2 {
let class_name = parts[0];
let remainder = parts[1..].join(".");
if let Some(class_node) = find_class_node(root, class_name, language, source) {
let scope = class_node
.child_by_field_name("body")
.unwrap_or(class_node);
if let Some(found) =
find_function_node_in_subtree(scope, &remainder, language, source)
{
return Some(found);
}
}
let last = *parts.last().unwrap();
return find_function_node_in_subtree(root, last, language, source);
}
}
if function_name.contains("::")
&& matches!(language, Language::C | Language::Cpp)
{
if let Some(found) = find_function_node_in_subtree(root, function_name, language, source)
{
return Some(found);
}
if let Some(found) =
find_cpp_qualified_function_definition(root, function_name, source)
{
return Some(found);
}
let parts: Vec<&str> = function_name.split("::").collect();
if parts.len() >= 2 {
let class_name = parts[0];
let remainder = parts[1..].join("::");
if let Some(class_node) = find_class_node(root, class_name, language, source) {
let scope = class_node
.child_by_field_name("body")
.unwrap_or(class_node);
if let Some(found) =
find_function_node_in_subtree(scope, &remainder, language, source)
{
return Some(found);
}
}
let last = *parts.last().unwrap();
return find_function_node_in_subtree(root, last, language, source);
}
}
find_function_node_in_subtree(root, function_name, language, source)
}
fn find_cpp_qualified_function_definition<'a>(
root: Node<'a>,
qualified_name: &str,
source: &str,
) -> Option<Node<'a>> {
let mut queue: std::collections::VecDeque<Node> =
std::collections::VecDeque::new();
queue.push_back(root);
while let Some(node) = queue.pop_front() {
if node.kind() == "function_definition" {
if let Some(declarator) = node.child_by_field_name("declarator") {
if let Some(inner) = peel_to_function_declarator(declarator) {
if let Some(decl_name) = inner.child_by_field_name("declarator") {
if decl_name.kind() == "qualified_identifier" {
let text = decl_name
.utf8_text(source.as_bytes())
.unwrap_or("")
.trim();
if text == qualified_name {
return Some(node);
}
}
}
}
}
}
let mut cursor = node.walk();
if cursor.goto_first_child() {
loop {
queue.push_back(cursor.node());
if !cursor.goto_next_sibling() {
break;
}
}
}
}
None
}
fn peel_to_function_declarator(node: Node) -> Option<Node> {
match node.kind() {
"function_declarator" => Some(node),
"pointer_declarator" | "reference_declarator" | "parenthesized_declarator" => {
let inner = node.child_by_field_name("declarator")?;
peel_to_function_declarator(inner)
}
_ => None,
}
}
fn find_function_node_in_subtree<'a>(
root: Node<'a>,
function_name: &str,
language: Language,
source: &str,
) -> Option<Node<'a>> {
let func_kinds = get_function_node_kinds(language);
let mut stack = vec![root];
while let Some(node) = stack.pop() {
if func_kinds.contains(&node.kind()) {
if let Some(name) = get_function_name(node, language, source) {
if name == function_name
|| name
.strip_prefix('#')
.is_some_and(|stripped| stripped == function_name)
|| (matches!(language, Language::Lua | Language::Luau)
&& name.contains('.')
&& name
.rsplit('.')
.next()
.is_some_and(|short| short == function_name))
{
return Some(node);
}
}
}
if matches!(language, Language::TypeScript | Language::JavaScript)
&& matches!(node.kind(), "lexical_declaration" | "variable_declaration")
{
let mut child_cursor = node.walk();
for child in node.children(&mut child_cursor) {
if child.kind() == "variable_declarator" {
if let Some(name_node) = child.child_by_field_name("name") {
let var_name = name_node.utf8_text(source.as_bytes()).unwrap_or("");
if var_name == function_name {
if let Some(value_node) = child.child_by_field_name("value") {
if matches!(
value_node.kind(),
"arrow_function"
| "function"
| "function_expression"
| "generator_function"
) {
return Some(value_node);
}
}
}
}
}
}
}
if matches!(language, Language::TypeScript | Language::JavaScript) {
if node.kind() == "assignment_expression" {
let left = node.child_by_field_name("left");
let right = node.child_by_field_name("right");
if let (Some(left), Some(right)) = (left, right) {
let matches_name = match left.kind() {
"identifier" => {
left.utf8_text(source.as_bytes()).unwrap_or("") == function_name
}
"member_expression" => {
match left.child_by_field_name("property") {
Some(p) => {
p.utf8_text(source.as_bytes()).unwrap_or("") == function_name
}
None => false,
}
}
_ => false,
};
if matches_name
&& matches!(
right.kind(),
"arrow_function"
| "function_expression"
| "function"
| "generator_function"
)
{
return Some(right);
}
}
}
if node.kind() == "pair" {
if let (Some(key), Some(value)) = (
node.child_by_field_name("key"),
node.child_by_field_name("value"),
) {
let key_name = match key.kind() {
"property_identifier" | "identifier" => {
key.utf8_text(source.as_bytes()).unwrap_or("").to_string()
}
"string" => key
.utf8_text(source.as_bytes())
.unwrap_or("")
.trim_matches(|c| c == '"' || c == '\'' || c == '`')
.to_string(),
_ => String::new(),
};
if key_name == function_name
&& matches!(
value.kind(),
"arrow_function"
| "function_expression"
| "function"
| "generator_function"
)
{
return Some(value);
}
}
}
}
if matches!(language, Language::Lua | Language::Luau)
&& matches!(node.kind(), "assignment_statement" | "variable_assignment")
{
let mut child_cursor = node.walk();
let children: Vec<_> = node.children(&mut child_cursor).collect();
for child in &children {
if matches!(child.kind(), "variable_list" | "assignment_variable_list") {
let mut inner_cursor = child.walk();
for inner in child.children(&mut inner_cursor) {
if matches!(inner.kind(), "field_expression" | "dot_index_expression") {
let lhs_text = inner.utf8_text(source.as_bytes()).unwrap_or("");
if let Some(field_name) = lhs_text.rsplit('.').next() {
if field_name == function_name || lhs_text == function_name {
for rhs in &children {
if matches!(
rhs.kind(),
"expression_list" | "assignment_expression_list"
) {
let mut rhs_cursor = rhs.walk();
for rhs_child in rhs.children(&mut rhs_cursor) {
if let Some(func) =
find_function_in_node(rhs_child, 3)
{
return Some(func);
}
}
}
if let Some(func) = find_function_in_node(*rhs, 3) {
return Some(func);
}
}
}
}
}
}
}
}
}
let mut cursor = node.walk();
let children: Vec<_> = node.children(&mut cursor).collect();
for child in children.into_iter().rev() {
stack.push(child);
}
}
None
}
pub fn get_function_node_kinds(language: Language) -> &'static [&'static str] {
match language {
Language::Python => &["function_definition"],
Language::TypeScript | Language::JavaScript => &[
"function_declaration",
"arrow_function",
"method_definition",
"function",
"generator_function_declaration",
"generator_function",
],
Language::Go => &["function_declaration", "method_declaration"],
Language::Rust => &["function_item"],
Language::Java => &["method_declaration", "constructor_declaration"],
Language::C | Language::Cpp => &["function_definition"],
Language::Ruby => &["method", "singleton_method"],
Language::Php => &["function_definition", "method_declaration"],
Language::CSharp => &["method_declaration", "constructor_declaration"],
Language::Kotlin => &["function_declaration"],
Language::Scala => &["function_definition", "function_declaration"],
Language::Elixir => &["call"], Language::Lua | Language::Luau => &[
"function_declaration",
"function_definition",
"local_function",
],
Language::Swift => &["function_declaration", "init_declaration"],
Language::Ocaml => &["let_binding", "value_definition"],
}
}
pub fn get_class_node_kinds(language: Language) -> &'static [&'static str] {
match language {
Language::Python => &["class_definition"],
Language::TypeScript | Language::JavaScript => {
&["class_declaration", "class", "interface_declaration"]
}
Language::Go => &["type_declaration"],
Language::Rust => &[
"struct_item",
"enum_item",
"trait_item",
"impl_item",
"union_item",
],
Language::Java => &[
"class_declaration",
"interface_declaration",
"enum_declaration",
"record_declaration",
],
Language::Cpp => &[
"class_specifier",
"struct_specifier",
"union_specifier",
],
Language::C => &[],
Language::Ruby => &["class", "module"],
Language::Php => &[
"class_declaration",
"interface_declaration",
"trait_declaration",
"enum_declaration",
],
Language::CSharp => &[
"class_declaration",
"struct_declaration",
"interface_declaration",
"record_declaration",
"enum_declaration",
],
Language::Kotlin => &["class_declaration", "object_declaration"],
Language::Scala => &[
"class_definition",
"object_definition",
"trait_definition",
],
Language::Elixir | Language::Lua | Language::Luau | Language::Ocaml => &[],
Language::Swift => &[
"class_declaration",
"protocol_declaration",
"extension_declaration",
],
}
}
pub fn find_class_node<'a>(
root: Node<'a>,
class_name: &str,
language: Language,
source: &str,
) -> Option<Node<'a>> {
let class_kinds = get_class_node_kinds(language);
if class_kinds.is_empty() {
return None;
}
let mut stack = vec![root];
while let Some(node) = stack.pop() {
if class_kinds.contains(&node.kind()) {
if let Some(name_node) = node.child_by_field_name("name") {
let name = name_node.utf8_text(source.as_bytes()).unwrap_or("");
if name == class_name {
return Some(node);
}
} else {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if matches!(
child.kind(),
"identifier"
| "type_identifier"
| "constant"
| "scoped_type_identifier"
) {
let name = child.utf8_text(source.as_bytes()).unwrap_or("");
if name == class_name {
return Some(node);
}
break;
}
}
}
}
let mut cursor = node.walk();
let children: Vec<_> = node.children(&mut cursor).collect();
for child in children.into_iter().rev() {
stack.push(child);
}
}
None
}
fn extract_c_declarator_name(declarator: Node, source: &str) -> Option<String> {
match declarator.kind() {
"identifier" | "field_identifier" => {
Some(
declarator
.utf8_text(source.as_bytes())
.unwrap_or("")
.to_string(),
)
}
"destructor_name" => {
Some(
declarator
.utf8_text(source.as_bytes())
.unwrap_or("")
.to_string(),
)
}
"qualified_identifier" => {
let mut cursor = declarator.walk();
for child in declarator.children(&mut cursor) {
if child.kind() == "qualified_identifier" {
return extract_c_declarator_name(child, source);
}
}
let mut cursor2 = declarator.walk();
for child in declarator.children(&mut cursor2) {
if matches!(child.kind(), "identifier" | "destructor_name") {
return Some(child.utf8_text(source.as_bytes()).unwrap_or("").to_string());
}
if child.kind() == "template_function" {
return child
.child_by_field_name("name")
.or_else(|| child.named_child(0))
.map(|n| n.utf8_text(source.as_bytes()).unwrap_or("").to_string());
}
}
None
}
"template_function" => {
declarator
.child_by_field_name("name")
.or_else(|| declarator.named_child(0))
.map(|n| n.utf8_text(source.as_bytes()).unwrap_or("").to_string())
}
"function_declarator" => {
if let Some(inner) = declarator.child_by_field_name("declarator") {
return extract_c_declarator_name(inner, source);
}
None
}
"pointer_declarator" | "reference_declarator" => {
if let Some(inner) = declarator.child_by_field_name("declarator") {
return extract_c_declarator_name(inner, source);
}
let mut cursor = declarator.walk();
for child in declarator.children(&mut cursor) {
if matches!(
child.kind(),
"function_declarator" | "identifier" | "field_identifier"
) {
return extract_c_declarator_name(child, source);
}
}
None
}
"parenthesized_declarator" => {
let mut cursor = declarator.walk();
for child in declarator.children(&mut cursor) {
if child.is_named() {
if let Some(name) = extract_c_declarator_name(child, source) {
return Some(name);
}
}
}
None
}
_ => None,
}
}
pub fn get_function_name(node: Node, language: Language, source: &str) -> Option<String> {
match language {
Language::C | Language::Cpp => {
if let Some(declarator) = node.child_by_field_name("declarator") {
return extract_c_declarator_name(declarator, source);
}
None
}
Language::Ruby => {
node.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap_or("").to_string())
}
Language::Php => {
node.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap_or("").to_string())
}
Language::Elixir => {
if node.kind() == "call" {
let first_child = node.child(0)?;
let first_text = first_child.utf8_text(source.as_bytes()).unwrap_or("");
if first_text == "def" || first_text == "defp" {
if let Some(args) = node.child(1) {
if args.kind() == "identifier" {
return Some(
args.utf8_text(source.as_bytes()).unwrap_or("").to_string(),
);
}
if args.kind() == "arguments" || args.kind() == "call" {
let mut cursor = args.walk();
for child in args.children(&mut cursor) {
if child.kind() == "identifier" {
return Some(
child
.utf8_text(source.as_bytes())
.unwrap_or("")
.to_string(),
);
}
if child.kind() == "call" {
if let Some(name) = child.child(0) {
if name.kind() == "identifier" {
return Some(
name.utf8_text(source.as_bytes())
.unwrap_or("")
.to_string(),
);
}
}
}
if child.kind() == "binary_operator" {
let mut bin_cursor = child.walk();
for bin_child in child.children(&mut bin_cursor) {
if bin_child.kind() == "call" {
if let Some(name) = bin_child.child(0) {
if name.kind() == "identifier" {
return Some(
name.utf8_text(source.as_bytes())
.unwrap_or("")
.to_string(),
);
}
}
}
if bin_child.kind() == "identifier" {
return Some(
bin_child
.utf8_text(source.as_bytes())
.unwrap_or("")
.to_string(),
);
}
}
}
}
}
}
}
None
} else {
None
}
}
Language::Ocaml => {
if node.kind() == "value_definition" {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "let_binding" {
return get_function_name(child, language, source);
}
}
None
} else {
node.child_by_field_name("pattern")
.map(|n| n.utf8_text(source.as_bytes()).unwrap_or("").to_string())
}
}
Language::Swift => {
if node.kind() == "init_declaration" {
return Some("init".to_string());
}
node.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap_or("").to_string())
}
Language::Lua | Language::Luau => {
if let Some(name_node) = node.child_by_field_name("name") {
let name_text = name_node
.utf8_text(source.as_bytes())
.unwrap_or("")
.to_string();
return Some(name_text);
}
if node.kind() == "local_function" || node.kind() == "function_declaration" {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if matches!(child.kind(), "identifier" | "dot_index_expression") {
return Some(child.utf8_text(source.as_bytes()).unwrap_or("").to_string());
}
}
}
None
}
_ => {
node.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap_or("").to_string())
}
}
}
pub fn get_function_body(func_node: Node, language: Language) -> Option<Node> {
match language {
Language::Python => func_node.child_by_field_name("body"),
Language::TypeScript | Language::JavaScript => func_node.child_by_field_name("body"),
Language::Go => func_node.child_by_field_name("body"),
Language::Rust => func_node.child_by_field_name("body"),
Language::Java => func_node.child_by_field_name("body"),
Language::C | Language::Cpp => func_node.child_by_field_name("body"),
Language::Ruby => func_node.child_by_field_name("body"),
Language::Php => func_node.child_by_field_name("body"),
Language::CSharp => func_node.child_by_field_name("body"),
Language::Kotlin => {
func_node.child_by_field_name("body").or_else(|| {
let mut cursor = func_node.walk();
for child in func_node.children(&mut cursor) {
if child.kind() == "function_body" {
let mut inner = child.walk();
for inner_child in child.children(&mut inner) {
if inner_child.kind() == "block" {
return Some(inner_child);
}
}
return Some(child);
}
}
None
})
}
Language::Scala => func_node.child_by_field_name("body"),
Language::Elixir => {
let mut cursor = func_node.walk();
for child in func_node.children(&mut cursor) {
if child.kind() == "do_block" {
return Some(child);
}
}
func_node.child_by_field_name("body")
}
Language::Lua | Language::Luau => func_node.child_by_field_name("body"),
Language::Ocaml => {
if func_node.kind() == "value_definition" {
let child_count = func_node.child_count();
let mut binding_body = None;
for i in 0..child_count {
if let Some(child) = func_node.child(i) {
if child.kind() == "let_binding" {
binding_body = child.child_by_field_name("body");
break;
}
}
}
binding_body.or(Some(func_node))
} else {
func_node.child_by_field_name("body").or(Some(func_node))
}
}
_ => func_node.child_by_field_name("body"),
}
}
pub fn get_function_node_kinds_vec(language: Language) -> Vec<&'static str> {
get_function_node_kinds(language).to_vec()
}
pub fn find_function_bounds(
source: &str,
function_name: &str,
language: Language,
) -> Option<(u32, u32)> {
let tree = crate::ast::parser::parse(source, language).ok()?;
let root = tree.root_node();
let node = find_function_node(root, function_name, language, source)?;
let start = node.start_position().row as u32 + 1;
let end = node.end_position().row as u32 + 1;
if end < start {
return None;
}
Some((start, end))
}
pub fn find_function_bounds_from_path_or_source(
source_or_path: &str,
function_name: &str,
language: Language,
) -> Option<(u32, u32)> {
let source = if std::path::Path::new(source_or_path).exists() {
std::fs::read_to_string(source_or_path).ok()?
} else {
source_or_path.to_string()
};
find_function_bounds(&source, function_name, language)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ast::parser::parse;
#[test]
fn test_ts_generator_function_declaration() {
let source = r#"
function* genNumbers(): Generator<number> {
yield 1;
yield 2;
yield 3;
}
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "genNumbers", Language::TypeScript, source);
assert!(node.is_some(), "Should find generator function declaration");
let node = node.unwrap();
assert_eq!(node.kind(), "generator_function_declaration");
let name = get_function_name(node, Language::TypeScript, source);
assert_eq!(name.as_deref(), Some("genNumbers"));
let body = get_function_body(node, Language::TypeScript);
assert!(body.is_some(), "Should find body of generator function");
}
#[test]
fn test_ts_async_generator_function() {
let source = r#"
async function* asyncGen(): AsyncGenerator<string> {
yield "hello";
yield "world";
}
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "asyncGen", Language::TypeScript, source);
assert!(
node.is_some(),
"Should find async generator function declaration"
);
let node = node.unwrap();
assert_eq!(node.kind(), "generator_function_declaration");
let name = get_function_name(node, Language::TypeScript, source);
assert_eq!(name.as_deref(), Some("asyncGen"));
}
#[test]
fn test_ts_generator_function_expression() {
let source = r#"
const genArrow = function*(x: number): Generator<number> {
yield x;
};
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "genArrow", Language::TypeScript, source);
assert!(
node.is_some(),
"Should find generator function expression via const assignment"
);
let node = node.unwrap();
assert_eq!(node.kind(), "generator_function");
let body = get_function_body(node, Language::TypeScript);
assert!(
body.is_some(),
"Should find body of generator function expression"
);
}
#[test]
fn test_js_generator_function_declaration() {
let source = r#"
function* genNumbers() {
yield 1;
yield 2;
}
"#;
let tree = parse(source, Language::JavaScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "genNumbers", Language::JavaScript, source);
assert!(
node.is_some(),
"Should find JS generator function declaration"
);
let node = node.unwrap();
assert_eq!(node.kind(), "generator_function_declaration");
}
#[test]
fn test_js_async_generator_function() {
let source = r#"
async function* asyncIter() {
yield "a";
yield "b";
}
"#;
let tree = parse(source, Language::JavaScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "asyncIter", Language::JavaScript, source);
assert!(node.is_some(), "Should find JS async generator function");
}
#[test]
fn test_ts_regular_function() {
let source = r#"
function greet(name: string): string {
const greeting = "Hello, " + name;
return greeting;
}
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "greet", Language::TypeScript, source);
assert!(node.is_some(), "Should find regular function declaration");
assert_eq!(node.unwrap().kind(), "function_declaration");
}
#[test]
fn test_ts_arrow_function() {
let source = r#"
const add = (a: number, b: number): number => {
return a + b;
};
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "add", Language::TypeScript, source);
assert!(node.is_some(), "Should find arrow function via const");
assert_eq!(node.unwrap().kind(), "arrow_function");
}
#[test]
fn test_ts_class_method() {
let source = r#"
class MyClass {
myMethod(x: number): number {
return x * 2;
}
}
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "myMethod", Language::TypeScript, source);
assert!(node.is_some(), "Should find class method");
assert_eq!(node.unwrap().kind(), "method_definition");
}
#[test]
fn test_ts_exported_function() {
let source = r#"
export function fetchData(url: string): Promise<string> {
return fetch(url).then(r => r.text());
}
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "fetchData", Language::TypeScript, source);
assert!(node.is_some(), "Should find exported function");
}
#[test]
fn test_ts_exported_generator() {
let source = r#"
export function* items(): Generator<number> {
yield 1;
yield 2;
}
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "items", Language::TypeScript, source);
assert!(node.is_some(), "Should find exported generator function");
}
#[test]
fn test_ts_generator_dfg() {
use crate::dfg::get_dfg_context;
let source = r#"
function* genNumbers(): Generator<number> {
const x = 1;
yield x;
const y = 2;
yield y;
}
"#;
let result = get_dfg_context(source, "genNumbers", Language::TypeScript);
assert!(
result.is_ok(),
"DFG should succeed for generator functions, got: {:?}",
result.err()
);
let dfg = result.unwrap();
assert_eq!(dfg.function, "genNumbers");
assert!(
!dfg.variables.is_empty(),
"Should find variables in generator function"
);
}
#[test]
fn test_ts_async_generator_dfg() {
use crate::dfg::get_dfg_context;
let source = r#"
async function* asyncGen(): AsyncGenerator<string> {
const msg = "hello";
yield msg;
}
"#;
let result = get_dfg_context(source, "asyncGen", Language::TypeScript);
assert!(
result.is_ok(),
"DFG should succeed for async generator functions, got: {:?}",
result.err()
);
}
#[test]
fn test_ts_generator_cfg() {
use crate::cfg::get_cfg_context;
let source = r#"
function* genNumbers(): Generator<number> {
const x = 1;
yield x;
}
"#;
let result = get_cfg_context(source, "genNumbers", Language::TypeScript);
assert!(result.is_ok());
let cfg = result.unwrap();
assert_eq!(cfg.function, "genNumbers");
assert!(
!cfg.blocks.is_empty(),
"CFG should have blocks for generator function"
);
}
#[test]
fn test_ts_node_kinds_include_generators() {
let kinds = get_function_node_kinds(Language::TypeScript);
assert!(
kinds.contains(&"generator_function_declaration"),
"TypeScript node kinds should include generator_function_declaration"
);
assert!(
kinds.contains(&"generator_function"),
"TypeScript node kinds should include generator_function"
);
}
#[test]
fn test_js_node_kinds_include_generators() {
let kinds = get_function_node_kinds(Language::JavaScript);
assert!(
kinds.contains(&"generator_function_declaration"),
"JavaScript node kinds should include generator_function_declaration"
);
assert!(
kinds.contains(&"generator_function"),
"JavaScript node kinds should include generator_function"
);
}
#[test]
fn test_c_pointer_returning_function() {
let source = r#"
typedef struct { int x; } MyStruct;
MyStruct *createStruct(void) {
int y = 1;
return (MyStruct*)0;
}
"#;
let tree = parse(source, Language::C).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "createStruct", Language::C, source);
assert!(
node.is_some(),
"Should find C function with pointer return type"
);
let node = node.unwrap();
assert_eq!(node.kind(), "function_definition");
let name = get_function_name(node, Language::C, source);
assert_eq!(name.as_deref(), Some("createStruct"));
}
#[test]
fn test_c_pointer_returning_function_dfg() {
use crate::dfg::get_dfg_context;
let source = r#"
typedef struct { int x; } MyStruct;
MyStruct *createStruct(int val) {
int y = val + 1;
return (MyStruct*)0;
}
"#;
let result = get_dfg_context(source, "createStruct", Language::C);
assert!(
result.is_ok(),
"DFG should succeed for C pointer-returning function, got: {:?}",
result.err()
);
let dfg = result.unwrap();
assert_eq!(dfg.function, "createStruct");
}
#[test]
fn test_cpp_pointer_returning_function() {
let source = r#"
struct Node { int val; };
Node *createNode(int x) {
int temp = x * 2;
return nullptr;
}
"#;
let tree = parse(source, Language::Cpp).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "createNode", Language::Cpp, source);
assert!(
node.is_some(),
"Should find C++ function with pointer return type"
);
let name = get_function_name(node.unwrap(), Language::Cpp, source);
assert_eq!(name.as_deref(), Some("createNode"));
}
#[test]
fn test_swift_init_declaration() {
let source = r#"
class App {
init(port: Int) {
let x = port + 1
}
}
"#;
let tree = parse(source, Language::Swift).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "init", Language::Swift, source);
assert!(node.is_some(), "Should find Swift init declaration");
}
#[test]
fn test_swift_init_dfg() {
use crate::dfg::get_dfg_context;
let source = r#"
class Server {
init(port: Int) {
let addr = port + 1000
}
}
"#;
let result = get_dfg_context(source, "init", Language::Swift);
assert!(
result.is_ok(),
"DFG should succeed for Swift init, got: {:?}",
result.err()
);
}
#[test]
fn test_swift_node_kinds_include_init() {
let kinds = get_function_node_kinds(Language::Swift);
assert!(
kinds.contains(&"init_declaration"),
"Swift node kinds should include init_declaration"
);
}
#[test]
fn test_lua_dot_indexed_function_short_name() {
let source = r#"
function Kong.init()
local x = 1
return x
end
"#;
let tree = parse(source, Language::Lua).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "init", Language::Lua, source);
assert!(
node.is_some(),
"Should find Lua dot-indexed function by short name 'init'"
);
}
#[test]
fn test_lua_dot_indexed_function_full_name() {
let source = r#"
function Kong.init()
local x = 1
return x
end
"#;
let tree = parse(source, Language::Lua).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "Kong.init", Language::Lua, source);
assert!(
node.is_some(),
"Should find Lua dot-indexed function by full name 'Kong.init'"
);
}
#[test]
fn test_lua_dot_indexed_function_dfg() {
use crate::dfg::get_dfg_context;
let source = r#"
function M.request(url)
local result = url .. "/api"
return result
end
"#;
let result = get_dfg_context(source, "request", Language::Lua);
assert!(
result.is_ok(),
"DFG should succeed for Lua dot-indexed function by short name, got: {:?}",
result.err()
);
}
#[test]
fn test_luau_dot_indexed_function_short_name() {
let source = r#"
function Module.process(data)
local x = data + 1
return x
end
"#;
let tree = parse(source, Language::Luau).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "process", Language::Luau, source);
assert!(
node.is_some(),
"Should find Luau dot-indexed function by short name 'process'"
);
}
#[test]
fn test_cpp_qualified_method_definition() {
let source = r#"
class MyClass {
public:
void externalMethod();
};
void MyClass::externalMethod() {
int x = 1;
}
"#;
let tree = parse(source, Language::Cpp).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "externalMethod", Language::Cpp, source);
assert!(
node.is_some(),
"Should find C++ qualified method definition (ClassName::method)"
);
let node = node.unwrap();
assert_eq!(node.kind(), "function_definition");
let name = get_function_name(node, Language::Cpp, source);
assert_eq!(
name.as_deref(),
Some("externalMethod"),
"get_function_name should extract bare name from qualified C++ method"
);
}
#[test]
fn test_cpp_qualified_method_dfg() {
use crate::dfg::get_dfg_context;
let source = r#"
void MyClass::compute(int a) {
int b = a + 1;
int c = b * 2;
}
"#;
let result = get_dfg_context(source, "compute", Language::Cpp);
assert!(
result.is_ok(),
"DFG should succeed for C++ qualified method, got: {:?}",
result.err()
);
let dfg = result.unwrap();
assert_eq!(dfg.function, "compute");
}
#[test]
fn test_cpp_inline_class_method() {
let source = r#"
class MyClass {
public:
void myMethod() {
int x = 1;
}
};
"#;
let tree = parse(source, Language::Cpp).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "myMethod", Language::Cpp, source);
assert!(
node.is_some(),
"Should find C++ inline class method (field_identifier)"
);
let node = node.unwrap();
assert_eq!(node.kind(), "function_definition");
let name = get_function_name(node, Language::Cpp, source);
assert_eq!(
name.as_deref(),
Some("myMethod"),
"get_function_name should extract name from inline C++ class method"
);
}
#[test]
fn test_cpp_inline_class_method_dfg() {
use crate::dfg::get_dfg_context;
let source = r#"
class Widget {
public:
int calculate(int a, int b) {
int sum = a + b;
int product = a * b;
return sum;
}
};
"#;
let result = get_dfg_context(source, "calculate", Language::Cpp);
assert!(
result.is_ok(),
"DFG should succeed for C++ inline class method, got: {:?}",
result.err()
);
let dfg = result.unwrap();
assert_eq!(dfg.function, "calculate");
assert!(
!dfg.variables.is_empty(),
"Should find variables in inline class method"
);
}
#[test]
fn test_cpp_namespace_function() {
let source = r#"
namespace Foo {
void bar() {
int x = 1;
}
}
"#;
let tree = parse(source, Language::Cpp).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "bar", Language::Cpp, source);
assert!(node.is_some(), "Should find C++ function inside namespace");
}
#[test]
fn test_cpp_const_qualified_method() {
let source = r#"
bool NormalizedStringType::isNever() const {
return !isCofinite && singletons.empty();
}
"#;
let tree = parse(source, Language::Cpp).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "isNever", Language::Cpp, source);
assert!(node.is_some(), "Should find C++ const qualified method");
let name = get_function_name(node.unwrap(), Language::Cpp, source);
assert_eq!(name.as_deref(), Some("isNever"));
}
#[test]
fn test_cpp_nested_namespace_qualified_method() {
let source = r#"
void Luau::Analysis::Normalizer::normalize() {
int x = 1;
}
"#;
let tree = parse(source, Language::Cpp).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "normalize", Language::Cpp, source);
assert!(node.is_some(), "Should find deeply qualified C++ method");
let name = get_function_name(node.unwrap(), Language::Cpp, source);
assert_eq!(name.as_deref(), Some("normalize"));
}
#[test]
fn test_c_inline_struct_method_like() {
let source = r#"
void process(int x) {
int y = x + 1;
}
"#;
let tree = parse(source, Language::C).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "process", Language::C, source);
assert!(node.is_some(), "Should find simple C function");
}
#[test]
fn test_lua_local_function_node_kinds() {
let kinds = get_function_node_kinds(Language::Lua);
let ast_kinds = crate::security::ast_utils::function_node_kinds(Language::Lua);
for kind in ast_kinds {
assert!(
kinds.contains(kind),
"function_finder should include '{}' which ast_utils includes for Lua",
kind
);
}
}
#[test]
fn test_luau_local_function_node_kinds() {
let kinds = get_function_node_kinds(Language::Luau);
let ast_kinds = crate::security::ast_utils::function_node_kinds(Language::Luau);
for kind in ast_kinds {
assert!(
kinds.contains(kind),
"function_finder should include '{}' which ast_utils includes for Luau",
kind
);
}
}
#[test]
fn test_qualified_class_method_python() {
let source = r#"
class Flask:
def run(self):
x = 1
return x
def run():
pass
"#;
let tree = parse(source, Language::Python).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "Flask.run", Language::Python, source);
assert!(node.is_some(), "Should resolve Flask.run to class method");
let node = node.unwrap();
let mut p = node.parent();
let mut found_class = false;
while let Some(parent) = p {
if parent.kind() == "class_definition" {
if let Some(name_node) = parent.child_by_field_name("name") {
if name_node.utf8_text(source.as_bytes()).unwrap_or("") == "Flask" {
found_class = true;
break;
}
}
}
p = parent.parent();
}
assert!(
found_class,
"Resolved method must be lexically inside class Flask (not the bare top-level run)"
);
}
#[test]
fn test_complexity_unqualified_still_works() {
let source = r#"
class Flask:
def run(self):
if True:
pass
"#;
let tree = parse(source, Language::Python).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "run", Language::Python, source);
assert!(node.is_some(), "Bare-name 'run' must still resolve");
}
#[test]
fn test_qualified_class_not_found_falls_back_to_method() {
let source = r#"
def standalone():
return 42
"#;
let tree = parse(source, Language::Python).unwrap();
let root = tree.root_node();
let node = find_function_node(
root,
"NonexistentClass.standalone",
Language::Python,
source,
);
assert!(
node.is_some(),
"Should fall back to bare 'standalone' when class doesn't exist"
);
}
#[test]
fn test_qualified_class_method_typescript() {
let source = r#"
class Server {
start(): void {
const port = 8080;
}
}
class Logger {
start(): void {}
}
"#;
let tree = parse(source, Language::TypeScript).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "Server.start", Language::TypeScript, source);
assert!(node.is_some(), "Should resolve Server.start in TS");
let node = node.unwrap();
let mut p = node.parent();
let mut found = false;
while let Some(parent) = p {
if parent.kind() == "class_declaration" {
if let Some(name_node) = parent.child_by_field_name("name") {
if name_node.utf8_text(source.as_bytes()).unwrap_or("") == "Server" {
found = true;
break;
}
}
}
p = parent.parent();
}
assert!(found, "Resolved method must be inside class Server");
}
#[test]
fn test_qualified_class_method_rust_impl() {
let source = r#"
struct Foo;
impl Foo {
fn process(&self) -> i32 {
let x = 1;
x
}
}
"#;
let tree = parse(source, Language::Rust).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "Foo.process", Language::Rust, source);
assert!(node.is_some(), "Should resolve Foo.process via impl block");
}
#[test]
fn test_qualified_class_method_java() {
let source = r#"
class Calculator {
int add(int a, int b) {
return a + b;
}
}
"#;
let tree = parse(source, Language::Java).unwrap();
let root = tree.root_node();
let node = find_function_node(root, "Calculator.add", Language::Java, source);
assert!(node.is_some(), "Should resolve Calculator.add in Java");
}
#[test]
fn test_qualified_lookup_via_complexity_python() {
use crate::metrics::calculate_complexity;
let source = r#"
class Flask:
def run(self, debug=False):
if debug:
print("debug")
else:
print("normal")
return 0
"#;
let metrics = calculate_complexity(source, "Flask.run", Language::Python).unwrap();
assert!(
metrics.cyclomatic >= 2,
"Flask.run cyclomatic should be >= 2, got {}",
metrics.cyclomatic
);
}
#[test]
fn test_qualified_lookup_via_dfg_python() {
use crate::dfg::get_dfg_context;
let source = r#"
class Service:
def handler(self, req):
x = req.data
y = x + 1
return y
"#;
let result = get_dfg_context(source, "Service.handler", Language::Python);
assert!(
result.is_ok(),
"DFG should succeed for Service.handler, got: {:?}",
result.err()
);
let dfg = result.unwrap();
assert_eq!(dfg.function, "Service.handler");
}
#[test]
fn test_qualified_lookup_via_cfg_python() {
use crate::cfg::get_cfg_context;
let source = r#"
class Worker:
def run(self):
if True:
return 1
return 0
"#;
let result = get_cfg_context(source, "Worker.run", Language::Python);
assert!(
result.is_ok(),
"CFG should succeed for Worker.run, got: {:?}",
result.err()
);
}
#[test]
fn test_qualified_class_method_disambiguates_overloaded() {
let source = r#"
class Alpha:
def shared(self):
a = 1
return a
class Beta:
def shared(self):
b = 2
c = b + 1
return c
"#;
let tree = parse(source, Language::Python).unwrap();
let root = tree.root_node();
let beta_method =
find_function_node(root, "Beta.shared", Language::Python, source).unwrap();
let mut p = beta_method.parent();
let mut enclosing = None;
while let Some(parent) = p {
if parent.kind() == "class_definition" {
enclosing = parent
.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap_or("").to_string());
break;
}
p = parent.parent();
}
assert_eq!(
enclosing.as_deref(),
Some("Beta"),
"Beta.shared must resolve to the method inside class Beta"
);
}
#[test]
fn test_find_class_node_python() {
let source = r#"
class Outer:
pass
class Inner:
pass
"#;
let tree = parse(source, Language::Python).unwrap();
let root = tree.root_node();
let inner = find_class_node(root, "Inner", Language::Python, source);
assert!(inner.is_some(), "Should find class Inner");
let nonexistent = find_class_node(root, "DoesNotExist", Language::Python, source);
assert!(
nonexistent.is_none(),
"Nonexistent class should return None"
);
}
#[test]
fn test_find_class_node_languages_without_classes() {
let source = "int foo() { return 0; }";
let tree = parse(source, Language::C).unwrap();
let root = tree.root_node();
let result = find_class_node(root, "Anything", Language::C, source);
assert!(result.is_none(), "C has no class kinds — must return None");
}
}