semantic/analysis/
analysis_similarity.rs1use std::collections::{HashMap, HashSet};
5
6use crate::parser::{Language, ParsedFile};
7
/// Strategy used to score how similar two snippets of source text are.
///
/// Every strategy produces a score in `[0.0, 1.0]`, where `1.0` means the
/// inputs are identical under that strategy and `0.0` means nothing is shared.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SimilarityMethod {
    /// Jaccard index over the sets of non-blank lines. Lines are compared
    /// verbatim (surrounding whitespace included); when both inputs have lines
    /// but share none of them, the `Tokens` score is returned instead.
    Lines,
    /// Jaccard index over the sets of whitespace-separated tokens.
    Tokens,
    /// Weighted Jaccard index over syntax-node-kind histograms when a language
    /// is supplied via `compute_similarity_with_language` and both inputs parse;
    /// everywhere else it is scored exactly like `Tokens`.
    Ast,
}

19pub fn compute_similarity(a: &str, b: &str, method: SimilarityMethod) -> f64 {
21 match method {
22 SimilarityMethod::Lines => {
23 let lines_a: HashSet<&str> = a.lines().filter(|l| !l.trim().is_empty()).collect();
24 let lines_b: HashSet<&str> = b.lines().filter(|l| !l.trim().is_empty()).collect();
25
26 if lines_a.is_empty() && lines_b.is_empty() {
27 return 1.0;
28 }
29 if lines_a.is_empty() || lines_b.is_empty() {
30 return 0.0;
31 }
32
33 let intersection: HashSet<_> = lines_a.intersection(&lines_b).collect();
34 let union: HashSet<_> = lines_a.union(&lines_b).collect();
35
36 let line_similarity = intersection.len() as f64 / union.len() as f64;
37 if line_similarity == 0.0 {
38 return compute_similarity(a, b, SimilarityMethod::Tokens);
39 }
40
41 line_similarity
42 }
43 SimilarityMethod::Tokens => {
44 let tokens_a: HashSet<&str> = a.split_whitespace().collect();
45 let tokens_b: HashSet<&str> = b.split_whitespace().collect();
46
47 if tokens_a.is_empty() && tokens_b.is_empty() {
48 return 1.0;
49 }
50 if tokens_a.is_empty() || tokens_b.is_empty() {
51 return 0.0;
52 }
53
54 let intersection: HashSet<_> = tokens_a.intersection(&tokens_b).collect();
55 let union: HashSet<_> = tokens_a.union(&tokens_b).collect();
56
57 intersection.len() as f64 / union.len() as f64
58 }
59 SimilarityMethod::Ast => compute_similarity(a, b, SimilarityMethod::Tokens),
60 }
61}
62
63pub fn compute_similarity_with_language(
64 a: &str,
65 b: &str,
66 method: SimilarityMethod,
67 language: Language,
68) -> f64 {
69 match method {
70 SimilarityMethod::Ast => {
71 if let Some(score) = compute_ast_similarity(a, b, language) {
72 return score;
73 }
74 compute_similarity(a, b, SimilarityMethod::Tokens)
75 }
76 _ => compute_similarity(a, b, method),
77 }
78}
79
80fn compute_ast_similarity(a: &str, b: &str, language: Language) -> Option<f64> {
81 let parsed_a = ParsedFile::parse(a, language)?;
82 let parsed_b = ParsedFile::parse(b, language)?;
83
84 let mut counts_a = HashMap::new();
85 let mut counts_b = HashMap::new();
86
87 collect_node_kinds(parsed_a.root_node(), &mut counts_a);
88 collect_node_kinds(parsed_b.root_node(), &mut counts_b);
89
90 if counts_a.is_empty() && counts_b.is_empty() {
91 return Some(1.0);
92 }
93 if counts_a.is_empty() || counts_b.is_empty() {
94 return Some(0.0);
95 }
96
97 let mut intersection = 0usize;
98 let mut union = 0usize;
99 let mut keys: HashSet<&str> = HashSet::new();
100 keys.extend(counts_a.keys().map(|k| k.as_str()));
101 keys.extend(counts_b.keys().map(|k| k.as_str()));
102
103 for key in keys {
104 let count_a = counts_a.get(key).copied().unwrap_or(0);
105 let count_b = counts_b.get(key).copied().unwrap_or(0);
106 intersection += count_a.min(count_b);
107 union += count_a.max(count_b);
108 }
109
110 if union == 0 {
111 Some(0.0)
112 } else {
113 Some(intersection as f64 / union as f64)
114 }
115}
116
117fn collect_node_kinds(node: tree_sitter::Node<'_>, counts: &mut HashMap<String, usize>) {
118 let mut stack = vec![node];
119
120 while let Some(current) = stack.pop() {
121 let kind = current.kind();
122 let entry = counts.entry(kind.to_string()).or_insert(0);
123 *entry += 1;
124
125 let child_count = current.child_count();
126 for index in (0..child_count).rev() {
127 if let Some(child) = current.child(index as u32) {
128 stack.push(child);
129 }
130 }
131 }
132}