/// Selects which scoring function `EmbeddingSimilarity::compute` applies to a
/// pair of vectors.
///
/// `Hash` is derived alongside `Eq` so a metric can be used directly as a
/// `HashMap` / `HashSet` key (for example to cache results per metric).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimilarityMetric {
    /// Cosine of the angle between the two vectors.
    Cosine,
    /// Raw dot product; not normalised, so it has no fixed range.
    DotProduct,
    /// `1 / (1 + d)` where `d` is the Euclidean (L2) distance.
    Euclidean,
    /// `1 / (1 + d)` where `d` is the Manhattan (L1) distance.
    Manhattan,
    /// `1 / (1 + d)` where `d` is the Chebyshev (L-infinity) distance.
    Chebyshev,
}
22
/// One scored entry of a similarity search.
///
/// Only `PartialEq` (not `Eq`) is derived because `score` is an `f64`.
#[derive(Clone, Debug, PartialEq)]
pub struct SimilarityResult {
    /// Position of the matched vector in the corpus that was searched.
    pub index: usize,

    /// Score assigned by the selected metric.
    pub score: f64,

    /// Optional caller-supplied label for the match.
    pub label: Option<String>,
}
35
/// Stateless namespace type for the vector similarity functions.
///
/// The type carries no data, so the common marker traits are derived to let it
/// be printed, copied, and default-constructed like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmbeddingSimilarity;
40
41impl EmbeddingSimilarity {
42 pub fn cosine(a: &[f64], b: &[f64]) -> f64 {
48 let dot = Self::dot_product(a, b);
49 let norm_a = Self::l2_norm(a);
50 let norm_b = Self::l2_norm(b);
51 if norm_a == 0.0 || norm_b == 0.0 {
52 return 0.0;
53 }
54 dot / (norm_a * norm_b)
55 }
56
57 pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
59 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
60 }
61
62 pub fn euclidean(a: &[f64], b: &[f64]) -> f64 {
64 let dist: f64 = a
65 .iter()
66 .zip(b.iter())
67 .map(|(x, y)| (x - y).powi(2))
68 .sum::<f64>()
69 .sqrt();
70 1.0 / (1.0 + dist)
71 }
72
73 pub fn manhattan(a: &[f64], b: &[f64]) -> f64 {
75 let dist: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
76 1.0 / (1.0 + dist)
77 }
78
79 pub fn chebyshev(a: &[f64], b: &[f64]) -> f64 {
81 let dist = a
82 .iter()
83 .zip(b.iter())
84 .map(|(x, y)| (x - y).abs())
85 .fold(0.0_f64, f64::max);
86 1.0 / (1.0 + dist)
87 }
88
89 pub fn compute(a: &[f64], b: &[f64], metric: SimilarityMetric) -> f64 {
93 match metric {
94 SimilarityMetric::Cosine => Self::cosine(a, b),
95 SimilarityMetric::DotProduct => Self::dot_product(a, b),
96 SimilarityMetric::Euclidean => Self::euclidean(a, b),
97 SimilarityMetric::Manhattan => Self::manhattan(a, b),
98 SimilarityMetric::Chebyshev => Self::chebyshev(a, b),
99 }
100 }
101
102 pub fn top_k(
109 query: &[f64],
110 corpus: &[Vec<f64>],
111 k: usize,
112 metric: SimilarityMetric,
113 ) -> Vec<SimilarityResult> {
114 let mut scored: Vec<SimilarityResult> = corpus
115 .iter()
116 .enumerate()
117 .map(|(i, v)| SimilarityResult {
118 index: i,
119 score: Self::compute(query, v, metric),
120 label: None,
121 })
122 .collect();
123
124 scored.sort_by(|a, b| {
126 b.score
127 .partial_cmp(&a.score)
128 .unwrap_or(std::cmp::Ordering::Equal)
129 });
130 scored.truncate(k);
131 scored
132 }
133
134 pub fn normalize(v: &[f64]) -> Vec<f64> {
140 let norm = Self::l2_norm(v);
141 if norm == 0.0 {
142 return vec![0.0; v.len()];
143 }
144 v.iter().map(|x| x / norm).collect()
145 }
146
147 pub fn pairwise(corpus: &[Vec<f64>], metric: SimilarityMetric) -> Vec<Vec<f64>> {
153 let n = corpus.len();
154 let mut matrix = vec![vec![0.0_f64; n]; n];
155 for i in 0..n {
156 for j in 0..n {
157 matrix[i][j] = Self::compute(&corpus[i], &corpus[j], metric);
158 }
159 }
160 matrix
161 }
162
163 fn l2_norm(v: &[f64]) -> f64 {
166 v.iter().map(|x| x * x).sum::<f64>().sqrt()
167 }
168}
169
/// Unit tests for the similarity metrics, ranking, normalisation, and the
/// public data types.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tight tolerance for assertions on values that should be near-exact.
    const EPS: f64 = 1e-9;

    /// Looser tolerance used by `approx_eq`.
    const APPROX_TOL: f64 = 1e-6;

    /// Returns `true` when `a` and `b` differ by less than `APPROX_TOL`.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < APPROX_TOL
    }

    // ---- cosine ----

    #[test]
    fn test_cosine_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let s = EmbeddingSimilarity::cosine(&v, &v);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_cosine_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let s = EmbeddingSimilarity::cosine(&a, &b);
        assert!(approx_eq(s, -1.0));
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let s = EmbeddingSimilarity::cosine(&a, &b);
        assert!(approx_eq(s, 0.0));
    }

    #[test]
    fn test_cosine_zero_vector() {
        // A zero vector short-circuits to exactly 0.0, so compare exactly.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        let s = EmbeddingSimilarity::cosine(&a, &b);
        assert_eq!(s, 0.0);
    }

    #[test]
    fn test_cosine_range() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let s = EmbeddingSimilarity::cosine(&a, &b);
        assert!((-1.0..=1.0).contains(&s));
    }

    // ---- dot product ----

    #[test]
    fn test_dot_product_basic() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let d = EmbeddingSimilarity::dot_product(&a, &b);
        assert!(approx_eq(d, 32.0));
    }

    #[test]
    fn test_dot_product_zero() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(EmbeddingSimilarity::dot_product(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_negative() {
        let a = vec![1.0, -1.0];
        let b = vec![1.0, 1.0];
        assert!(approx_eq(EmbeddingSimilarity::dot_product(&a, &b), 0.0));
    }

    // ---- euclidean ----

    #[test]
    fn test_euclidean_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let s = EmbeddingSimilarity::euclidean(&v, &v);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_euclidean_unit_apart() {
        // Distance 1 maps to 1 / (1 + 1).
        let a = vec![0.0];
        let b = vec![1.0];
        let s = EmbeddingSimilarity::euclidean(&a, &b);
        assert!(approx_eq(s, 0.5));
    }

    #[test]
    fn test_euclidean_positive() {
        let a = vec![1.0, 2.0];
        let b = vec![4.0, 6.0];
        let s = EmbeddingSimilarity::euclidean(&a, &b);
        assert!(s > 0.0 && s < 1.0);
    }

    // ---- manhattan ----

    #[test]
    fn test_manhattan_identical() {
        let v = vec![1.0, 2.0];
        let s = EmbeddingSimilarity::manhattan(&v, &v);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_manhattan_unit_apart() {
        let a = vec![0.0];
        let b = vec![1.0];
        assert!(approx_eq(EmbeddingSimilarity::manhattan(&a, &b), 0.5));
    }

    #[test]
    fn test_manhattan_positive() {
        // L1 distance is 3 + 4 = 7, so the similarity is 1 / 8.
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let s = EmbeddingSimilarity::manhattan(&a, &b);
        assert!(approx_eq(s, 1.0 / 8.0));
    }

    // ---- chebyshev ----

    #[test]
    fn test_chebyshev_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let s = EmbeddingSimilarity::chebyshev(&v, &v);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_chebyshev_picks_max() {
        // Largest component difference is 5, so the similarity is 1 / 6.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 5.0];
        let s = EmbeddingSimilarity::chebyshev(&a, &b);
        assert!(approx_eq(s, 1.0 / 6.0));
    }

    #[test]
    fn test_chebyshev_positive() {
        let a = vec![1.0, 2.0];
        let b = vec![4.0, 3.0];
        let s = EmbeddingSimilarity::chebyshev(&a, &b);
        assert!(s > 0.0 && s < 1.0);
    }

    // ---- compute dispatch ----

    #[test]
    fn test_compute_cosine() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Cosine);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_compute_dot_product() {
        let a = vec![2.0, 3.0];
        let b = vec![4.0, 5.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::DotProduct);
        assert!(approx_eq(s, 23.0));
    }

    #[test]
    fn test_compute_euclidean() {
        let a = vec![0.0];
        let b = vec![1.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Euclidean);
        assert!(approx_eq(s, 0.5));
    }

    #[test]
    fn test_compute_manhattan() {
        let a = vec![0.0];
        let b = vec![1.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Manhattan);
        assert!(approx_eq(s, 0.5));
    }

    #[test]
    fn test_compute_chebyshev() {
        let a = vec![0.0, 0.0];
        let b = vec![2.0, 3.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Chebyshev);
        assert!(approx_eq(s, 1.0 / 4.0));
    }

    // ---- normalize ----

    #[test]
    fn test_normalize_unit_length() {
        let v = vec![3.0, 4.0];
        let n = EmbeddingSimilarity::normalize(&v);
        let norm: f64 = n.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(approx_eq(norm, 1.0));
    }

    #[test]
    fn test_normalize_zero_vector() {
        let v = vec![0.0, 0.0];
        let n = EmbeddingSimilarity::normalize(&v);
        assert!(n.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_normalize_already_unit() {
        let v = vec![1.0, 0.0];
        let n = EmbeddingSimilarity::normalize(&v);
        assert!(approx_eq(n[0], 1.0));
        assert!(approx_eq(n[1], 0.0));
    }

    #[test]
    fn test_normalize_preserves_direction() {
        let v = vec![1.0, 1.0];
        let n = EmbeddingSimilarity::normalize(&v);
        assert!(approx_eq(n[0], n[1]));
    }

    // ---- top_k ----

    #[test]
    fn test_top_k_returns_k_results() {
        let query = vec![1.0, 0.0];
        let corpus = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.5, 0.5],
        ];
        let results = EmbeddingSimilarity::top_k(&query, &corpus, 2, SimilarityMetric::Cosine);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_top_k_sorted_descending() {
        let query = vec![1.0, 0.0];
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![-1.0, 0.0]];
        let results = EmbeddingSimilarity::top_k(&query, &corpus, 3, SimilarityMetric::Cosine);
        assert_eq!(results.len(), 3);
        // `windows(2)` visits each adjacent pair and, unlike `0..len - 1`,
        // cannot underflow if the result list were ever empty.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn test_top_k_best_is_identical() {
        let query = vec![1.0, 2.0, 3.0];
        let corpus = vec![vec![1.0, 2.0, 3.0], vec![0.0, 0.0, 1.0]];
        let results = EmbeddingSimilarity::top_k(&query, &corpus, 1, SimilarityMetric::Cosine);
        assert_eq!(results[0].index, 0);
    }

    #[test]
    fn test_top_k_empty_corpus() {
        let query = vec![1.0, 0.0];
        let results = EmbeddingSimilarity::top_k(&query, &[], 5, SimilarityMetric::Euclidean);
        assert!(results.is_empty());
    }

    #[test]
    fn test_top_k_k_larger_than_corpus() {
        let query = vec![1.0];
        let corpus = vec![vec![1.0], vec![2.0]];
        let results =
            EmbeddingSimilarity::top_k(&query, &corpus, 100, SimilarityMetric::Euclidean);
        assert_eq!(results.len(), 2);
    }

    // ---- pairwise ----

    #[test]
    fn test_pairwise_dimensions() {
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let m = EmbeddingSimilarity::pairwise(&corpus, SimilarityMetric::Cosine);
        assert_eq!(m.len(), 3);
        assert_eq!(m[0].len(), 3);
    }

    #[test]
    fn test_pairwise_diagonal_is_max_cosine() {
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let m = EmbeddingSimilarity::pairwise(&corpus, SimilarityMetric::Cosine);
        assert!(approx_eq(m[0][0], 1.0));
        assert!(approx_eq(m[1][1], 1.0));
    }

    #[test]
    fn test_pairwise_symmetric_cosine() {
        let corpus = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let m = EmbeddingSimilarity::pairwise(&corpus, SimilarityMetric::Cosine);
        assert!(approx_eq(m[0][1], m[1][0]));
    }

    #[test]
    fn test_pairwise_empty_corpus() {
        let m = EmbeddingSimilarity::pairwise(&[], SimilarityMetric::Cosine);
        assert!(m.is_empty());
    }

    // ---- SimilarityResult ----

    #[test]
    fn test_similarity_result_fields() {
        let r = SimilarityResult {
            index: 5,
            score: 0.95,
            label: Some("example".to_string()),
        };
        assert_eq!(r.index, 5);
        assert!((r.score - 0.95).abs() < EPS);
        assert_eq!(r.label, Some("example".to_string()));
    }

    #[test]
    fn test_similarity_result_no_label() {
        let r = SimilarityResult {
            index: 0,
            score: 1.0,
            label: None,
        };
        assert!(r.label.is_none());
    }

    #[test]
    fn test_similarity_result_clone() {
        let r = SimilarityResult {
            index: 1,
            score: 0.5,
            label: None,
        };
        assert_eq!(r, r.clone());
    }

    // ---- SimilarityMetric ----

    #[test]
    fn test_metric_copy() {
        let m = SimilarityMetric::Cosine;
        let m2 = m;
        assert_eq!(m, m2);
    }

    #[test]
    fn test_metric_debug() {
        let s = format!("{:?}", SimilarityMetric::DotProduct);
        assert!(s.contains("DotProduct"));
    }

    // ---- additional coverage ----

    #[test]
    fn test_chebyshev_identical_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let sim = EmbeddingSimilarity::compute(&a, &a, SimilarityMetric::Chebyshev);
        assert!(approx_eq(sim, 1.0));
    }

    #[test]
    fn test_manhattan_orthogonal() {
        // L1 distance between the two unit axes is 2, so similarity is 1 / 3.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Manhattan);
        assert!((sim - 1.0 / 3.0).abs() < EPS);
    }

    #[test]
    fn test_dot_product_negative_components() {
        let a = vec![-1.0, -1.0];
        let b = vec![-1.0, -1.0];
        let sim = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::DotProduct);
        assert!(approx_eq(sim, 2.0));
    }

    #[test]
    fn test_top_k_all_metrics() {
        let query = vec![1.0, 0.0];
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        // Every variant is listed so the test matches its name.
        for metric in [
            SimilarityMetric::Cosine,
            SimilarityMetric::DotProduct,
            SimilarityMetric::Euclidean,
            SimilarityMetric::Manhattan,
            SimilarityMetric::Chebyshev,
        ] {
            let results = EmbeddingSimilarity::top_k(&query, &corpus, 2, metric);
            assert_eq!(
                results.len(),
                2,
                "metric {:?} should return 2 results",
                metric
            );
        }
    }

    #[test]
    fn test_similarity_result_debug() {
        let r = SimilarityResult {
            index: 0,
            score: 0.9,
            label: None,
        };
        let s = format!("{r:?}");
        assert!(s.contains("SimilarityResult"));
    }
}