use crate::apted::{compute_edit_distance, APTEDOptions};
use crate::tree::TreeNode;
use std::rc::Rc;
pub struct EnhancedSimilarityOptions {
pub structural_weight: f64,
pub size_weight: f64,
pub type_distribution_weight: f64,
pub min_size_ratio: f64,
pub apted_options: APTEDOptions,
}
impl Default for EnhancedSimilarityOptions {
fn default() -> Self {
Self {
structural_weight: 0.4,
size_weight: 0.3,
type_distribution_weight: 0.3,
min_size_ratio: 0.5,
apted_options: APTEDOptions::default(),
}
}
}
pub fn calculate_enhanced_similarity(
tree1: &Rc<TreeNode>,
tree2: &Rc<TreeNode>,
options: &EnhancedSimilarityOptions,
) -> f64 {
let distance = compute_edit_distance(tree1, tree2, &options.apted_options);
let max_size = tree1.get_subtree_size().max(tree2.get_subtree_size()) as f64;
let structural_similarity = if max_size > 0.0 { 1.0 - (distance / max_size) } else { 1.0 };
let size1 = tree1.get_subtree_size() as f64;
let size2 = tree2.get_subtree_size() as f64;
let size_ratio = size1.min(size2) / size1.max(size2).max(1.0);
let size_similarity = if size_ratio < options.min_size_ratio {
size_ratio / options.min_size_ratio } else {
1.0
};
let dist1 = get_node_type_distribution(tree1);
let dist2 = get_node_type_distribution(tree2);
let type_similarity = calculate_distribution_similarity(&dist1, &dist2);
let mut penalty_factor = 1.0;
if size_ratio < 0.5 {
penalty_factor *= 0.8;
}
let semantic_sim = calculate_semantic_similarity(tree1, tree2);
if structural_similarity > 0.7 && semantic_sim < 0.2 {
penalty_factor *= 0.7;
}
let complexity_ratio = calculate_complexity_ratio(tree1, tree2);
if complexity_ratio < 0.5 {
penalty_factor *= 0.85;
}
let total_weight =
options.structural_weight + options.size_weight + options.type_distribution_weight;
let combined_similarity = (structural_similarity * options.structural_weight
+ size_similarity * options.size_weight
+ type_similarity * options.type_distribution_weight)
/ total_weight;
combined_similarity * penalty_factor
}
fn get_node_type_distribution(tree: &TreeNode) -> std::collections::HashMap<String, usize> {
let mut distribution = std::collections::HashMap::new();
count_node_types(tree, &mut distribution);
distribution
}
fn count_node_types(node: &TreeNode, distribution: &mut std::collections::HashMap<String, usize>) {
*distribution.entry(node.label.clone()).or_insert(0) += 1;
for child in &node.children {
count_node_types(child, distribution);
}
}
fn calculate_distribution_similarity(
dist1: &std::collections::HashMap<String, usize>,
dist2: &std::collections::HashMap<String, usize>,
) -> f64 {
let all_types: std::collections::HashSet<_> =
dist1.keys().chain(dist2.keys()).cloned().collect();
if all_types.is_empty() {
return 1.0;
}
let mut intersection = 0;
let mut union = 0;
for node_type in all_types {
let count1 = dist1.get(&node_type).copied().unwrap_or(0);
let count2 = dist2.get(&node_type).copied().unwrap_or(0);
intersection += count1.min(count2);
union += count1.max(count2);
}
if union == 0 {
1.0
} else {
intersection as f64 / union as f64
}
}
pub fn calculate_semantic_similarity(tree1: &TreeNode, tree2: &TreeNode) -> f64 {
let features1 = extract_semantic_features(tree1);
let features2 = extract_semantic_features(tree2);
let mut matches = 0;
let mut total = 0;
let common_identifiers = features1.identifiers.intersection(&features2.identifiers).count();
let total_identifiers = features1.identifiers.union(&features2.identifiers).count();
if total_identifiers > 0 {
matches += common_identifiers;
total += total_identifiers;
}
let common_operators = features1.operators.intersection(&features2.operators).count();
let total_operators = features1.operators.union(&features2.operators).count();
if total_operators > 0 {
matches += common_operators;
total += total_operators;
}
let common_control = features1.control_flow.intersection(&features2.control_flow).count();
let total_control = features1.control_flow.union(&features2.control_flow).count();
if total_control > 0 {
matches += common_control;
total += total_control;
}
if total == 0 {
0.0 } else {
matches as f64 / total as f64
}
}
#[derive(Debug)]
struct SemanticFeatures {
identifiers: std::collections::HashSet<String>,
operators: std::collections::HashSet<String>,
control_flow: std::collections::HashSet<String>,
function_calls: std::collections::HashSet<String>,
}
fn extract_semantic_features(tree: &TreeNode) -> SemanticFeatures {
let mut features = SemanticFeatures {
identifiers: std::collections::HashSet::new(),
operators: std::collections::HashSet::new(),
control_flow: std::collections::HashSet::new(),
function_calls: std::collections::HashSet::new(),
};
extract_features_recursive(tree, &mut features);
features
}
fn calculate_complexity_ratio(tree1: &TreeNode, tree2: &TreeNode) -> f64 {
let complexity1 = calculate_tree_complexity(tree1);
let complexity2 = calculate_tree_complexity(tree2);
if complexity1 == 0 && complexity2 == 0 {
return 1.0;
}
let min_complexity = complexity1.min(complexity2) as f64;
let max_complexity = complexity1.max(complexity2) as f64;
min_complexity / max_complexity
}
fn calculate_tree_complexity(tree: &TreeNode) -> usize {
calculate_complexity_recursive(tree, 0)
}
fn calculate_complexity_recursive(node: &TreeNode, depth: usize) -> usize {
let mut complexity = depth + 1;
match node.label.as_str() {
"if_expression" | "if_statement" | "loop_expression" | "while_expression"
| "for_expression" | "match_expression" => {
complexity += 3;
}
"function_item" | "closure_expression" => {
complexity += 2;
}
_ => {}
}
for child in &node.children {
complexity += calculate_complexity_recursive(child, depth + 1);
}
complexity
}
fn extract_features_recursive(node: &TreeNode, features: &mut SemanticFeatures) {
match node.label.as_str() {
"identifier" => {
if !node.value.is_empty() {
features.identifiers.insert(node.value.clone());
}
}
"+" | "-" | "*" | "/" | "%" | "==" | "!=" | "<" | ">" | "<=" | ">=" | "&&" | "||" | "!"
| "&" | "|" | "^" | "<<" | ">>" => {
features.operators.insert(node.label.clone());
}
"if_expression" | "if_statement" => {
features.control_flow.insert("if".to_string());
}
"loop_expression" | "while_expression" | "for_expression" => {
features.control_flow.insert("loop".to_string());
}
"return_expression" | "return_statement" => {
features.control_flow.insert("return".to_string());
}
"call_expression" => {
features.control_flow.insert("call".to_string());
if let Some(first_child) = node.children.first() {
if first_child.label == "identifier" && !first_child.value.is_empty() {
features.function_calls.insert(first_child.value.clone());
}
}
}
_ => {}
}
for child in &node.children {
extract_features_recursive(child, features);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_size_similarity() {
let options = EnhancedSimilarityOptions::default();
let mut tree1 = TreeNode::new("root".to_string(), "".to_string(), 1);
tree1.add_child(Rc::new(TreeNode::new("child".to_string(), "".to_string(), 2)));
let mut tree2 = TreeNode::new("root".to_string(), "".to_string(), 3);
for i in 0..10 {
tree2.add_child(Rc::new(TreeNode::new("child".to_string(), "".to_string(), 4 + i)));
}
let tree1_rc = Rc::new(tree1);
let tree2_rc = Rc::new(tree2);
let similarity = calculate_enhanced_similarity(&tree1_rc, &tree2_rc, &options);
assert!(
similarity < 0.5,
"Expected low similarity due to size difference, got {}",
similarity
);
}
}