use super::tree_edit_distance::LabeledTree;
/// Converts a tree-sitter syntax node and all of its descendants into a [`LabeledTree`].
///
/// The label of every node is produced by `node_to_label`. Any child for which
/// `is_trivial_node` returns `true` is dropped together with its whole subtree;
/// the remaining children are converted recursively, in document order.
///
/// `source` is only forwarded to `node_to_label` and to the recursive calls.
pub fn ast_to_labeled_tree(node: tree_sitter::Node<'_>, source: &str) -> LabeledTree {
    let label = node_to_label(node, source);

    // `Node::children` visits every direct child (named and anonymous alike),
    // which is what a first-child / next-sibling cursor walk does as well.
    let mut walker = node.walk();
    let subtrees: Vec<LabeledTree> = node
        .children(&mut walker)
        .filter(|child| !is_trivial_node(*child))
        .map(|child| ast_to_labeled_tree(child, source))
        .collect();

    LabeledTree::with_children(label, subtrees)
}
/// Maps a tree-sitter node kind onto a normalized label.
///
/// Several grammar-specific kinds that express the same construct are folded
/// into one label (for example `function_definition`, `function_declaration`
/// and `function_item` all become `funcdef`). Identifier and literal kinds are
/// replaced by the placeholders `$ID`, `$NUM`, `$STR`, `$BOOL` and `$NULL`, and
/// operator tokens are bucketed into `op_assign`, `op_arith`, `op_cmp` and
/// `op_logic`. A kind without a mapping is returned unchanged.
///
/// `_source` is accepted but not read.
fn node_to_label(node: tree_sitter::Node<'_>, _source: &str) -> String {
    let kind = node.kind();
    let normalized = match kind {
        // Control flow.
        "if_statement" | "if_expression" => "if",
        "else_clause" | "else" => "else",
        "elif_clause" => "elif",
        "for_statement" | "for_in_statement" | "for_expression" => "for",
        "while_statement" | "while_expression" => "while",
        "do_statement" => "do_while",
        "switch_statement" | "match_expression" => "switch",
        "case_clause" | "match_arm" => "case",
        "try_statement" => "try",
        "catch_clause" | "except_clause" => "catch",
        "finally_clause" => "finally",

        // Definitions.
        "function_definition"
        | "function_declaration"
        | "method_definition"
        | "function_item"
        | "method_declaration" => "funcdef",
        "arrow_function" | "lambda" | "lambda_expression" => "lambda",
        "class_definition" | "class_declaration" => "classdef",

        // Statements.
        "return_statement" => "return",
        "expression_statement" => "expr_stmt",
        "assignment" | "assignment_expression" => "assign",
        "augmented_assignment" => "aug_assign",
        "variable_declaration" | "lexical_declaration" | "let_declaration" => "declare",

        // Expressions.
        "call_expression" | "call" => "call",
        "binary_expression" | "binary_operator" => "binop",
        "unary_expression" | "unary_operator" => "unop",
        "comparison_operator" => "compare",
        "boolean_operator" => "boolop",
        "subscript" | "subscript_expression" => "subscript",
        "attribute" | "member_expression" | "field_expression" => "member",

        // Identifiers and literals collapse to placeholders.
        "identifier" | "property_identifier" | "field_identifier" | "type_identifier" => "$ID",
        "integer" | "integer_literal" | "number" | "float" | "float_literal" => "$NUM",
        "string" | "string_literal" | "template_string" | "string_content" => "$STR",
        "true" | "false" | "boolean" => "$BOOL",
        "none" | "null" | "nil" => "$NULL",

        // Parameters and arguments.
        "parameters" | "formal_parameters" | "parameter_list" => "params",
        "parameter" | "simple_parameter" | "typed_parameter" | "typed_default_parameter" => {
            "param"
        }
        "argument_list" | "arguments" => "args",
        "argument" | "keyword_argument" => "arg",

        // Blocks and imports.
        "block" | "statement_block" | "compound_statement" => "block",
        "import_statement" | "import_declaration" => "import",

        // Operator tokens, bucketed by family.
        "=" | "+=" | "-=" | "*=" | "/=" => "op_assign",
        "+" | "-" | "*" | "/" | "%" | "**" => "op_arith",
        "==" | "!=" | "<" | ">" | "<=" | ">=" => "op_cmp",
        "&&" | "||" | "and" | "or" | "not" | "!" => "op_logic",

        // Anything else keeps the grammar's own kind name.
        unmapped => unmapped,
    };
    normalized.to_owned()
}
/// Returns `true` when the node's kind is one of the kinds excluded from labeled trees.
///
/// The excluded kinds are delimiter and separator tokens, the comment kinds, and
/// the newline / indent / dedent kinds (in both lower- and upper-case spellings).
fn is_trivial_node(node: tree_sitter::Node<'_>) -> bool {
    const TRIVIAL_KINDS: &[&str] = &[
        // Delimiters and separators.
        "(",
        ")",
        "{",
        "}",
        "[",
        "]",
        ";",
        ",",
        ":",
        ".",
        "->",
        "=>",
        "::",
        // Comments.
        "comment",
        "line_comment",
        "block_comment",
        // Layout markers.
        "newline",
        "indent",
        "dedent",
        "NEWLINE",
        "INDENT",
        "DEDENT",
        "\n",
    ];

    TRIVIAL_KINDS.contains(&node.kind())
}
/// Parses `source` with the given tree-sitter grammar and converts the result
/// into a [`LabeledTree`] rooted at the parse tree's root node.
///
/// Returns `None` when the parser rejects `ts_language` or when parsing yields
/// no tree.
pub fn parse_to_labeled_tree(
    source: &str,
    ts_language: tree_sitter::Language,
) -> Option<LabeledTree> {
    let mut parser = tree_sitter::Parser::new();
    if parser.set_language(&ts_language).is_err() {
        return None;
    }

    parser
        .parse(source, None)
        .map(|parsed| ast_to_labeled_tree(parsed.root_node(), source))
}
/// Converts a syntax node into a [`LabeledTree`], passing `max_depth` on to
/// `ast_to_labeled_tree_inner` as the depth limit.
///
/// The conversion starts with `node` at depth 0.
pub fn ast_to_labeled_tree_bounded(
    node: tree_sitter::Node<'_>,
    source: &str,
    max_depth: usize,
) -> LabeledTree {
    // The node handed in by the caller is the root of the conversion.
    const ROOT_DEPTH: usize = 0;
    ast_to_labeled_tree_inner(node, source, max_depth, ROOT_DEPTH)
}
/// Recursive worker that builds a [`LabeledTree`] with an optional depth limit.
///
/// * `max_depth` — `0` disables the limit; otherwise a node whose
///   `current_depth` has reached `max_depth` is emitted as a childless tree.
/// * `current_depth` — depth of `node` relative to the conversion root; each
///   recursive call passes `current_depth + 1`.
///
/// Children for which `is_trivial_node` returns `true` are skipped.
fn ast_to_labeled_tree_inner(
    node: tree_sitter::Node<'_>,
    source: &str,
    max_depth: usize,
    current_depth: usize,
) -> LabeledTree {
    let label = node_to_label(node, source);

    let limit_enabled = max_depth != 0;
    if limit_enabled && current_depth >= max_depth {
        // Limit reached: keep the node's label but do not descend further.
        return LabeledTree::new(label);
    }

    let mut walker = node.walk();
    let subtrees: Vec<LabeledTree> = node
        .children(&mut walker)
        .filter(|child| !is_trivial_node(*child))
        .map(|child| ast_to_labeled_tree_inner(child, source, max_depth, current_depth + 1))
        .collect();

    LabeledTree::with_children(label, subtrees)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the given grammar, panicking if no tree is produced.
    fn parse_tree(source: &str, ts_language: tree_sitter::Language) -> LabeledTree {
        parse_to_labeled_tree(source, ts_language).expect("Parsing should succeed")
    }

    /// Grammar handle for Python fixtures.
    fn python_lang() -> tree_sitter::Language {
        tree_sitter_python::LANGUAGE.into()
    }

    /// Grammar handle for JavaScript fixtures.
    fn js_lang() -> tree_sitter::Language {
        tree_sitter_javascript::LANGUAGE.into()
    }

    /// Collects every label in the tree in pre-order.
    fn collect_labels(t: &LabeledTree) -> Vec<String> {
        let mut labels = vec![t.label.clone()];
        for child in &t.children {
            labels.extend(collect_labels(child));
        }
        labels
    }

    /// Returns `true` if any node in the tree carries exactly the label `target`.
    fn has_label(t: &LabeledTree, target: &str) -> bool {
        t.label == target || t.children.iter().any(|c| has_label(c, target))
    }

    /// Number of nodes on the longest root-to-leaf path (a lone root has depth 1).
    fn max_depth(t: &LabeledTree) -> usize {
        1 + t.children.iter().map(max_depth).max().unwrap_or(0)
    }

    #[test]
    fn test_ast_tree_has_proper_nesting() {
        // Nested blocks need increasing indentation to be valid Python.
        let source = "def foo(x):\n    if x > 0:\n        return x\n    return 0\n";
        let tree = parse_tree(source, python_lang());
        assert!(
            tree.size() > 4,
            "AST tree should have multiple nested nodes, got size {}",
            tree.size()
        );
        let depth = max_depth(&tree);
        assert!(
            depth >= 3,
            "AST tree should have depth >= 3 for nested code, got {depth}"
        );
    }

    #[test]
    fn test_filtering_removes_punctuation() {
        let source = "x = [1, 2, 3]";
        let tree = parse_tree(source, python_lang());
        let punct = ["(", ")", "{", "}", "[", "]", ";", ",", ":", "."];
        let labels = collect_labels(&tree);
        let leaked: Vec<&String> = labels
            .iter()
            .filter(|label| punct.contains(&label.as_str()))
            .collect();
        assert!(
            leaked.is_empty(),
            "AST tree should not contain punctuation nodes, found {leaked:?}"
        );
    }

    #[test]
    fn test_filtering_removes_comments() {
        let source = "# This is a comment\nx = 1\n# Another comment\ny = 2\n";
        let tree = parse_tree(source, python_lang());
        let has_comment = ["comment", "line_comment", "block_comment"]
            .iter()
            .any(|kind| has_label(&tree, kind));
        assert!(!has_comment, "AST tree should not contain comment nodes");
    }

    #[test]
    fn test_labels_are_abstract() {
        let source = "def foo(x):\n return x + 1\n";
        let tree = parse_tree(source, python_lang());
        let labels = collect_labels(&tree);
        // Require every abstract label; a single match is not enough to show
        // that normalization happened across the whole tree.
        for expected in ["funcdef", "$ID", "return", "params"] {
            assert!(
                labels.iter().any(|l| l == expected),
                "Should contain abstract label '{expected}'. Got: {labels:?}"
            );
        }
    }

    #[test]
    fn test_identifiers_normalized() {
        let source = "x = foo(bar)";
        let tree = parse_tree(source, python_lang());
        assert!(
            has_label(&tree, "$ID"),
            "Identifiers should be normalized to $ID"
        );
    }

    #[test]
    fn test_literals_normalized() {
        let source = "x = 42\ny = \"hello\"\nz = True\n";
        let tree = parse_tree(source, python_lang());
        let labels = collect_labels(&tree);
        assert!(
            labels.contains(&"$NUM".to_string()),
            "Numeric literals should be normalized to $NUM. Labels: {labels:?}"
        );
        assert!(
            labels.contains(&"$STR".to_string()),
            "String literals should be normalized to $STR. Labels: {labels:?}"
        );
        // The fixture also contains a boolean literal; check it is normalized too.
        assert!(
            labels.contains(&"$BOOL".to_string()),
            "Boolean literals should be normalized to $BOOL. Labels: {labels:?}"
        );
    }

    #[test]
    fn test_python_and_js_similar_labels() {
        let python_source = "def add(a, b):\n return a + b\n";
        let js_source = "function add(a, b) { return a + b; }";
        let py_tree = parse_tree(python_source, python_lang());
        let js_tree = parse_tree(js_source, js_lang());
        let py_labels = collect_labels(&py_tree);
        let js_labels = collect_labels(&js_tree);
        let common_labels = ["funcdef", "return", "$ID", "params"];
        for expected in &common_labels {
            let expected_str = expected.to_string();
            assert!(
                py_labels.contains(&expected_str),
                "Python tree should contain '{expected}'. Labels: {py_labels:?}"
            );
            assert!(
                js_labels.contains(&expected_str),
                "JavaScript tree should contain '{expected}'. Labels: {js_labels:?}"
            );
        }
    }

    #[test]
    fn test_empty_source_produces_minimal_tree() {
        let source = "";
        let tree = parse_tree(source, python_lang());
        assert!(
            tree.size() >= 1,
            "Empty source should produce at least a root node"
        );
    }

    #[test]
    fn test_single_statement() {
        let source = "pass";
        let tree = parse_tree(source, python_lang());
        assert!(tree.size() >= 1);
    }

    #[test]
    fn test_bounded_depth_limits_tree() {
        // Each nested block is indented one level deeper so the fixture is valid Python.
        let source = "def foo(x):\n    if x > 0:\n        for i in range(x):\n            print(i)\n";
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&python_lang())
            .expect("Language should set");
        let parsed = parser.parse(source, None).expect("Parse should succeed");
        let full_tree = ast_to_labeled_tree(parsed.root_node(), source);
        let bounded_tree = ast_to_labeled_tree_bounded(parsed.root_node(), source, 3);
        let full_depth = max_depth(&full_tree);
        let bounded_depth = max_depth(&bounded_tree);
        assert!(
            bounded_depth <= full_depth,
            "Bounded tree (depth {bounded_depth}) should not be deeper than full tree (depth {full_depth})"
        );
        assert!(
            bounded_depth <= 4,
            "Bounded tree depth should be <= max_depth+1 (4), got {bounded_depth}"
        );
        assert!(
            bounded_depth < full_depth,
            "Bounded tree (depth {bounded_depth}) should be shallower than full tree (depth {full_depth})"
        );
    }

    #[test]
    fn test_bounded_depth_zero_means_unlimited() {
        let source = "def foo(x):\n return x + 1\n";
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&python_lang())
            .expect("Language should set");
        let parsed = parser.parse(source, None).expect("Parse should succeed");
        let full_tree = ast_to_labeled_tree(parsed.root_node(), source);
        let unbounded_tree = ast_to_labeled_tree_bounded(parsed.root_node(), source, 0);
        assert_eq!(
            full_tree.size(),
            unbounded_tree.size(),
            "Depth 0 (unlimited) should produce the same tree"
        );
    }

    #[test]
    fn test_ast_trees_usable_with_ted() {
        use crate::clones::tree_edit_distance::tree_edit_distance;
        let source_a = "def add(a, b):\n return a + b\n";
        let source_b = "def sub(a, b):\n return a - b\n";
        let tree_a = parse_tree(source_a, python_lang());
        let tree_b = parse_tree(source_b, python_lang());
        let dist = tree_edit_distance(&tree_a, &tree_b);
        assert!(
            dist < tree_a.size().max(tree_b.size()),
            "Similar functions should have distance less than max size"
        );
    }

    #[test]
    fn test_ast_trees_usable_with_apted() {
        use crate::clones::apted::apted_distance;
        let source_a = "def add(a, b):\n return a + b\n";
        let source_b = "def sub(a, b):\n return a - b\n";
        let tree_a = parse_tree(source_a, python_lang());
        let tree_b = parse_tree(source_b, python_lang());
        let dist = apted_distance(&tree_a, &tree_b);
        assert!(
            dist < tree_a.size().max(tree_b.size()),
            "Similar functions should have small APTED distance"
        );
    }

    #[test]
    fn test_trivial_node_detection() {
        let source = "result = foo(a, b, c)";
        let tree = parse_tree(source, python_lang());
        assert!(!has_label(&tree, "("), "Should not contain '('");
        assert!(!has_label(&tree, ")"), "Should not contain ')'");
        assert!(!has_label(&tree, ","), "Should not contain ','");
    }

    #[test]
    fn test_control_flow_labels() {
        // The fixture starts at column 0 inside the raw string and indents each
        // nested block, as Python requires.
        let source = r#"
def foo(x):
    if x > 0:
        return 1
    else:
        return 0
    for i in range(x):
        pass
    while True:
        break
"#;
        let tree = parse_tree(source, python_lang());
        let labels = collect_labels(&tree);
        assert!(labels.contains(&"if".to_string()), "Should contain 'if'");
        assert!(
            labels.contains(&"else".to_string()),
            "Should contain 'else'"
        );
        assert!(labels.contains(&"for".to_string()), "Should contain 'for'");
        assert!(
            labels.contains(&"while".to_string()),
            "Should contain 'while'"
        );
        assert!(
            labels.contains(&"funcdef".to_string()),
            "Should contain 'funcdef'"
        );
        assert!(
            labels.contains(&"return".to_string()),
            "Should contain 'return'"
        );
    }

    #[test]
    fn test_js_function_labels() {
        let source = r#"
function greet(name) {
if (name === "world") {
return "Hello, World!";
}
return "Hello, " + name;
}
"#;
        let tree = parse_tree(source, js_lang());
        let labels = collect_labels(&tree);
        assert!(
            labels.contains(&"funcdef".to_string()),
            "JS function should have 'funcdef' label. Labels: {labels:?}"
        );
        assert!(
            labels.contains(&"if".to_string()),
            "JS if-statement should have 'if' label. Labels: {labels:?}"
        );
        assert!(
            labels.contains(&"return".to_string()),
            "JS return should have 'return' label. Labels: {labels:?}"
        );
    }
}