Skip to main content

semantic/analysis/
analysis_similarity.rs

// SPDX-License-Identifier: Apache-2.0
//! Similarity computation utilities.
3
4use std::collections::{HashMap, HashSet};
5
6use crate::parser::{Language, ParsedFile};
7
/// Method for computing content similarity.
///
/// Every method yields a score in `[0.0, 1.0]`: `1.0` for identical content
/// (or two empty inputs) and `0.0` when nothing is shared.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SimilarityMethod {
    /// Simple line-by-line comparison over the set of non-blank lines.
    ///
    /// Lines are compared verbatim, so indentation differences make two lines
    /// unequal. When no line is shared, the token comparison is used instead.
    Lines,
    /// Token-based comparison (ignores whitespace) over the set of
    /// whitespace-separated tokens.
    Tokens,
    /// AST-based comparison (structure only).
    ///
    /// [`compute_similarity`] has no language to parse with and treats this as
    /// [`SimilarityMethod::Tokens`]; `compute_similarity_with_language`
    /// performs the structural comparison.
    Ast,
}

/// Compute similarity between two strings (0.0 to 1.0).
///
/// Both inputs are reduced to sets (duplicates are ignored) and compared with
/// the Jaccard index `|A ∩ B| / |A ∪ B|`. Two empty sets score `1.0`; exactly
/// one empty set scores `0.0`.
///
/// - [`SimilarityMethod::Lines`] compares non-blank lines. If both sides have
///   lines but share none, the token comparison is returned instead, so
///   re-wrapped but otherwise similar text still gets a non-zero score.
/// - [`SimilarityMethod::Tokens`] compares whitespace-separated tokens.
/// - [`SimilarityMethod::Ast`] degrades to the token comparison here.
pub fn compute_similarity(a: &str, b: &str, method: SimilarityMethod) -> f64 {
    match method {
        SimilarityMethod::Lines => {
            let lines_a: HashSet<&str> = a.lines().filter(|l| !l.trim().is_empty()).collect();
            let lines_b: HashSet<&str> = b.lines().filter(|l| !l.trim().is_empty()).collect();

            // Fall back only when both sides have content yet share no line;
            // an empty side keeps its fixed 1.0 / 0.0 score.
            if !lines_a.is_empty() && !lines_b.is_empty() && lines_a.is_disjoint(&lines_b) {
                return compute_similarity(a, b, SimilarityMethod::Tokens);
            }

            jaccard_index(&lines_a, &lines_b)
        }
        SimilarityMethod::Tokens => {
            let tokens_a: HashSet<&str> = a.split_whitespace().collect();
            let tokens_b: HashSet<&str> = b.split_whitespace().collect();
            jaccard_index(&tokens_a, &tokens_b)
        }
        SimilarityMethod::Ast => compute_similarity(a, b, SimilarityMethod::Tokens),
    }
}

/// Jaccard index of two string sets: `1.0` when both are empty, `0.0` when
/// exactly one is empty, otherwise `|A ∩ B| / |A ∪ B|`.
fn jaccard_index(set_a: &HashSet<&str>, set_b: &HashSet<&str>) -> f64 {
    match (set_a.is_empty(), set_b.is_empty()) {
        (true, true) => 1.0,
        (true, false) | (false, true) => 0.0,
        (false, false) => {
            // Count rather than materialise the intersection/union sets:
            // |A ∪ B| = |A| + |B| - |A ∩ B|.
            let shared = set_a.intersection(set_b).count();
            let union = set_a.len() + set_b.len() - shared;
            shared as f64 / union as f64
        }
    }
}
62
63pub fn compute_similarity_with_language(
64    a: &str,
65    b: &str,
66    method: SimilarityMethod,
67    language: Language,
68) -> f64 {
69    match method {
70        SimilarityMethod::Ast => {
71            if let Some(score) = compute_ast_similarity(a, b, language) {
72                return score;
73            }
74            compute_similarity(a, b, SimilarityMethod::Tokens)
75        }
76        _ => compute_similarity(a, b, method),
77    }
78}
79
80fn compute_ast_similarity(a: &str, b: &str, language: Language) -> Option<f64> {
81    let parsed_a = ParsedFile::parse(a, language)?;
82    let parsed_b = ParsedFile::parse(b, language)?;
83
84    let mut counts_a = HashMap::new();
85    let mut counts_b = HashMap::new();
86
87    collect_node_kinds(parsed_a.root_node(), &mut counts_a);
88    collect_node_kinds(parsed_b.root_node(), &mut counts_b);
89
90    if counts_a.is_empty() && counts_b.is_empty() {
91        return Some(1.0);
92    }
93    if counts_a.is_empty() || counts_b.is_empty() {
94        return Some(0.0);
95    }
96
97    let mut intersection = 0usize;
98    let mut union = 0usize;
99    let mut keys: HashSet<&str> = HashSet::new();
100    keys.extend(counts_a.keys().map(|k| k.as_str()));
101    keys.extend(counts_b.keys().map(|k| k.as_str()));
102
103    for key in keys {
104        let count_a = counts_a.get(key).copied().unwrap_or(0);
105        let count_b = counts_b.get(key).copied().unwrap_or(0);
106        intersection += count_a.min(count_b);
107        union += count_a.max(count_b);
108    }
109
110    if union == 0 {
111        Some(0.0)
112    } else {
113        Some(intersection as f64 / union as f64)
114    }
115}
116
117fn collect_node_kinds(node: tree_sitter::Node<'_>, counts: &mut HashMap<String, usize>) {
118    let mut stack = vec![node];
119
120    while let Some(current) = stack.pop() {
121        let kind = current.kind();
122        let entry = counts.entry(kind.to_string()).or_insert(0);
123        *entry += 1;
124
125        let child_count = current.child_count();
126        for index in (0..child_count).rev() {
127            if let Some(child) = current.child(index as u32) {
128                stack.push(child);
129            }
130        }
131    }
132}