// harn_hostlib/embed/similarity.rs
/// Cosine similarity between two equal-length vectors, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` instead of an error or `NaN` for every degenerate case:
/// - either slice is empty, or the slices differ in length;
/// - either vector has zero norm;
/// - the result is not finite (for example a `NaN` or infinite component).
///
/// The dot product and squared norms are accumulated in `f64`. Squaring an
/// `f32` can overflow to infinity (components around `1e20`) or underflow to
/// zero (components around `1e-30`); in `f64` every finite `f32` square is
/// representable, so vectors of extreme magnitude still get a correct score.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    // A NaN norm fails this comparison and is caught by the `is_finite`
    // check below instead.
    if norm_a <= 0.0 || norm_b <= 0.0 {
        return 0.0;
    }

    let sim = dot / (norm_a.sqrt() * norm_b.sqrt());
    if sim.is_finite() {
        // The value is already within [-1, 1], so narrowing to f32 only
        // rounds; it cannot overflow.
        sim.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
42
/// One ranked result: a position in the corpus together with its score.
#[derive(Debug, Clone, PartialEq)]
pub struct Scored {
    /// Zero-based position of the vector within the corpus slice.
    pub index: usize,
    /// Similarity score assigned to that vector.
    pub score: f32,
}
51
52pub fn top_k(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<Scored> {
61 if k == 0 || corpus.is_empty() {
62 return Vec::new();
63 }
64 let mut scored: Vec<Scored> = corpus
65 .iter()
66 .enumerate()
67 .map(|(index, vec)| Scored {
68 index,
69 score: cosine(query, vec),
70 })
71 .collect();
72 scored.sort_by(|a, b| {
75 b.score
76 .total_cmp(&a.score)
77 .then_with(|| a.index.cmp(&b.index))
78 });
79 scored.truncate(k);
80 scored
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical_vectors_is_one() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        assert!(cosine(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
    }

    #[test]
    fn cosine_opposite_is_negative_one() {
        assert!((cosine(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_handles_degenerate_inputs() {
        // Empty input.
        assert_eq!(cosine(&[], &[]), 0.0);
        // Length mismatch.
        assert_eq!(cosine(&[1.0, 2.0], &[1.0]), 0.0);
        // Zero-norm vector.
        assert_eq!(cosine(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        // NaN component.
        assert_eq!(cosine(&[f32::NAN, 1.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn top_k_ranks_highest_first() {
        let query = vec![1.0, 0.0];
        let corpus = vec![
            vec![0.0, 1.0],
            vec![1.0, 0.0],
            vec![1.0, 1.0],
            vec![-1.0, 0.0],
        ];
        let ranked = top_k(&query, &corpus, 3);
        assert_eq!(ranked.len(), 3);
        assert_eq!(ranked[0].index, 1);
        assert_eq!(ranked[1].index, 2);
        assert_eq!(ranked[2].index, 0);
    }

    #[test]
    fn top_k_is_deterministic_on_ties() {
        let query = vec![1.0, 0.0];
        let corpus = vec![vec![1.0, 0.0], vec![1.0, 0.0], vec![1.0, 0.0]];
        let ranked = top_k(&query, &corpus, 3);
        assert_eq!(
            ranked.iter().map(|s| s.index).collect::<Vec<_>>(),
            vec![0, 1, 2]
        );
    }

    #[test]
    fn top_k_empty_and_zero_k() {
        assert!(top_k(&[1.0], &[], 5).is_empty());
        assert!(top_k(&[1.0], &[vec![1.0]], 0).is_empty());
    }

    #[test]
    fn top_k_clamps_to_corpus_size() {
        let ranked = top_k(&[1.0], &[vec![1.0], vec![0.5]], 100);
        assert_eq!(ranked.len(), 2);
    }
}