1use crate::types::GpuDistanceMetric;
4
/// Element-wise math helpers for `f32` vectors represented as slices.
pub mod vector_math {
    use super::*;

    /// Returns the cosine similarity of `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length or when either one has
    /// zero magnitude. The quotient is clamped to `[-1.0, 1.0]`: rounding in
    /// `f32` can otherwise push the result slightly past the bounds (e.g.
    /// `[1, 1]` against itself evaluates to roughly `1.0000001`), which breaks
    /// callers that feed the value to `acos` or compare against `1.0`.
    /// A NaN quotient is passed through unchanged.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let norm_a = vector_magnitude(a);
        let norm_b = vector_magnitude(b);
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        (dot_product(a, b) / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }

    /// Returns the Euclidean (L2) distance between `a` and `b`.
    ///
    /// Returns `f32::INFINITY` when the slices differ in length.
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return f32::INFINITY;
        }

        let sum_squares: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        sum_squares.sqrt()
    }

    /// Returns the dot product of `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length.
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Scales `v` in place so that its Euclidean norm becomes 1.
    ///
    /// The slice is left untouched when the norm is not strictly positive
    /// (all-zero or empty input, or a NaN norm).
    pub fn normalize_vector(v: &mut [f32]) {
        let norm = vector_magnitude(v);
        if norm > 0.0 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
    }

    /// Returns the Euclidean (L2) norm of `v`.
    pub fn vector_magnitude(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }
}
61
62pub mod similarity_calculations {
64 use super::*;
65
66 pub fn calculate_similarity(a: &[f32], b: &[f32], metric: GpuDistanceMetric) -> f32 {
68 match metric {
69 GpuDistanceMetric::Cosine => vector_math::cosine_similarity(a, b),
70 GpuDistanceMetric::Euclidean => {
71 let distance = vector_math::euclidean_distance(a, b);
73 1.0 / (1.0 + distance)
74 }
75 GpuDistanceMetric::DotProduct => vector_math::dot_product(a, b),
76 }
77 }
78
79 pub fn batch_similarity(
81 query: &[f32],
82 vectors: &[Vec<f32>],
83 metric: GpuDistanceMetric,
84 ) -> Vec<f32> {
85 vectors
86 .iter()
87 .map(|v| calculate_similarity(query, v, metric))
88 .collect()
89 }
90
91 pub fn top_k_similar(
93 query: &[f32],
94 vectors: &[Vec<f32>],
95 k: usize,
96 metric: GpuDistanceMetric,
97 ) -> Vec<(usize, f32)> {
98 let similarities: Vec<(usize, f32)> = vectors
99 .iter()
100 .enumerate()
101 .map(|(i, v)| (i, calculate_similarity(query, v, metric)))
102 .collect();
103
104 let mut sorted = similarities;
105 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
106 sorted.truncate(k);
107 sorted
108 }
109}