layer_conform_core/
similarity.rs1use compact_str::CompactString;
7
/// Similarity between two items, broken down per signal plus a combined value.
///
/// `overall` is the weighted combination that [`aggregate`] computes from the
/// four component scores. The remaining fields are those component scores,
/// kept so callers can see what drove the combined value.
#[derive(Debug, Default, PartialEq, Clone, Copy)]
pub struct SimilarityScore {
    /// Weighted mean of the component scores (`0.0` when the weights sum to
    /// a non-positive value; see [`aggregate`]).
    pub overall: f64,
    /// Component score for the `shape` signal.
    pub shape: f64,
    /// Component score for the `calls` signal.
    pub calls: f64,
    /// Component score for the `imports` signal.
    pub imports: f64,
    /// Component score for the `signature` signal.
    pub signature: f64,
}
16
/// Relative importance of each similarity signal when computing
/// [`SimilarityScore::overall`] via [`aggregate`].
///
/// `aggregate` divides by the sum of these weights, so only their ratios
/// matter; they do not need to add up to 1. A weight of `0.0` excludes that
/// signal from the combined score.
#[derive(Debug, Clone, Copy)]
pub struct Weights {
    /// Weight of the `shape` component.
    pub shape: f64,
    /// Weight of the `calls` component.
    pub calls: f64,
    /// Weight of the `imports` component.
    pub imports: f64,
    /// Weight of the `signature` component.
    pub signature: f64,
}
24
25impl Default for Weights {
26 fn default() -> Self {
27 Self { shape: 0.6, calls: 0.3, imports: 0.1, signature: 0.0 }
28 }
29}
30
31pub fn jaccard_sorted(a: &[CompactString], b: &[CompactString]) -> f64 {
33 if a.is_empty() && b.is_empty() {
34 return 1.0;
35 }
36 let (mut i, mut j, mut intersect, mut union_n) = (0_usize, 0_usize, 0_usize, 0_usize);
37 while i < a.len() && j < b.len() {
38 match a[i].cmp(&b[j]) {
39 std::cmp::Ordering::Equal => {
40 intersect += 1;
41 union_n += 1;
42 i += 1;
43 j += 1;
44 }
45 std::cmp::Ordering::Less => {
46 union_n += 1;
47 i += 1;
48 }
49 std::cmp::Ordering::Greater => {
50 union_n += 1;
51 j += 1;
52 }
53 }
54 }
55 union_n += a.len() - i;
56 union_n += b.len() - j;
57 if union_n == 0 {
58 return 1.0;
59 }
60 intersect as f64 / union_n as f64
61}
62
63pub fn aggregate(
65 shape: f64,
66 calls: f64,
67 imports: f64,
68 signature: f64,
69 w: Weights,
70) -> SimilarityScore {
71 let total_w = w.shape + w.calls + w.imports + w.signature;
72 let overall = if total_w > 0.0 {
73 (shape * w.shape + calls * w.calls + imports * w.imports + signature * w.signature)
74 / total_w
75 } else {
76 0.0
77 };
78 SimilarityScore { overall, shape, calls, imports, signature }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-9;

    /// Builds a sorted `Vec<CompactString>`, satisfying the sorted-input
    /// precondition of `jaccard_sorted`.
    fn cs(items: &[&str]) -> Vec<CompactString> {
        let mut v: Vec<CompactString> = items.iter().map(|s| (*s).into()).collect();
        v.sort();
        v
    }

    /// Asserts `actual` is within `EPS` of `expected`, reporting both on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn jaccard_empty_inputs() {
        assert_close(jaccard_sorted(&[], &[]), 1.0);
    }

    #[test]
    fn jaccard_one_side_empty() {
        let a = cs(&["useSWR"]);
        assert_close(jaccard_sorted(&a, &[]), 0.0);
        assert_close(jaccard_sorted(&[], &a), 0.0);
    }

    #[test]
    fn jaccard_identical_sets() {
        let a = cs(&["useSWR", "axios"]);
        let b = cs(&["useSWR", "axios"]);
        assert_close(jaccard_sorted(&a, &b), 1.0);
    }

    #[test]
    fn jaccard_disjoint_sets() {
        let a = cs(&["useSWR"]);
        let b = cs(&["axios"]);
        assert_close(jaccard_sorted(&a, &b), 0.0);
    }

    #[test]
    fn jaccard_partial_overlap() {
        let a = cs(&["useSWR", "axios"]);
        let b = cs(&["useSWR", "fetch"]);
        assert_close(jaccard_sorted(&a, &b), 1.0 / 3.0);
    }

    #[test]
    fn jaccard_is_symmetric() {
        let a = cs(&["useSWR", "axios", "lodash"]);
        let b = cs(&["useSWR", "fetch"]);
        assert_close(jaccard_sorted(&a, &b), jaccard_sorted(&b, &a));
    }

    #[test]
    fn jaccard_duplicates_count_as_multiset() {
        // "x" appears twice vs once: min(2, 1) / max(2, 1).
        let a = cs(&["x", "x"]);
        let b = cs(&["x"]);
        assert_close(jaccard_sorted(&a, &b), 0.5);
    }

    #[test]
    fn aggregate_uses_weights() {
        let s = aggregate(1.0, 0.0, 0.0, 0.0, Weights::default());
        assert_close(s.overall, 0.6);
        assert_close(s.shape, 1.0);
    }

    #[test]
    fn aggregate_normalizes_by_total_weight() {
        let w = Weights { shape: 2.0, calls: 2.0, imports: 0.0, signature: 0.0 };
        let s = aggregate(1.0, 0.0, 0.0, 0.0, w);
        assert_close(s.overall, 0.5);
    }

    #[test]
    fn aggregate_zero_total_weight_gives_zero_overall() {
        let w = Weights { shape: 0.0, calls: 0.0, imports: 0.0, signature: 0.0 };
        let s = aggregate(1.0, 1.0, 1.0, 1.0, w);
        assert_close(s.overall, 0.0);
        // Component scores are still passed through untouched.
        assert_close(s.calls, 1.0);
    }
}