/// Dense-vector arithmetic on `f32` slices.
///
/// Every reduction is accumulated in `f64` and only the final result is
/// narrowed back to `f32`. The square of any finite `f32` fits comfortably in
/// an `f64`, so intermediate sums neither overflow to infinity nor underflow
/// to zero while the true result is still representable as an `f32`.
pub mod vector_math {
    /// Sum of the squared components of `v`, accumulated in `f64`.
    fn sum_of_squares(v: &[f32]) -> f64 {
        v.iter()
            .map(|&x| {
                let x = f64::from(x);
                x * x
            })
            .sum()
    }

    /// Dot product of `a` and `b` accumulated in `f64`.
    ///
    /// Callers are expected to have checked that the lengths match; `zip`
    /// stops at the shorter slice otherwise.
    fn dot_f64(a: &[f32], b: &[f32]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| f64::from(x) * f64::from(y))
            .sum()
    }

    /// Cosine of the angle between `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length or when either vector
    /// has zero magnitude (this includes two empty slices). For all other
    /// finite inputs the result is clamped to `[-1.0, 1.0]`, so rounding can
    /// never push it outside the mathematical range. NaN components propagate
    /// to a NaN result.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let norm_a = sum_of_squares(a).sqrt();
        let norm_b = sum_of_squares(b).sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        let cosine = dot_f64(a, b) / (norm_a * norm_b);
        // `clamp` keeps NaN as NaN, so invalid inputs stay visible to callers.
        cosine.clamp(-1.0, 1.0) as f32
    }

    /// Euclidean (L2) distance between `a` and `b`.
    ///
    /// Returns `f32::INFINITY` when the slices differ in length.
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return f32::INFINITY;
        }

        let sum_squares: f64 = a
            .iter()
            .zip(b)
            .map(|(&x, &y)| {
                let diff = f64::from(x) - f64::from(y);
                diff * diff
            })
            .sum();

        sum_squares.sqrt() as f32
    }

    /// Dot product of `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length. A result whose
    /// magnitude exceeds the `f32` range is returned as an infinity.
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        dot_f64(a, b) as f32
    }

    /// Scales `v` in place to unit length.
    ///
    /// The vector is left untouched when its norm is zero or NaN.
    pub fn normalize_vector(v: &mut [f32]) {
        let norm = sum_of_squares(v).sqrt();
        if norm > 0.0 {
            for x in v.iter_mut() {
                // Divide in f64: the norm itself may exceed the f32 range even
                // though every normalized component lies within [-1, 1].
                *x = (f64::from(*x) / norm) as f32;
            }
        }
    }

    /// Euclidean (L2) norm of `v`; `0.0` for an empty slice.
    pub fn vector_magnitude(v: &[f32]) -> f32 {
        sum_of_squares(v).sqrt() as f32
    }
}
57
58pub mod similarity_calculations {
60 use crate::types::GpuDistanceMetric;
61
62 use super::vector_math;
63
64 pub fn calculate_similarity(a: &[f32], b: &[f32], metric: GpuDistanceMetric) -> f32 {
66 match metric {
67 GpuDistanceMetric::Cosine => vector_math::cosine_similarity(a, b),
68 GpuDistanceMetric::Euclidean => {
69 let distance = vector_math::euclidean_distance(a, b);
71 1.0 / (1.0 + distance)
72 }
73 GpuDistanceMetric::DotProduct => vector_math::dot_product(a, b),
74 }
75 }
76
77 pub fn batch_similarity(
79 query: &[f32],
80 vectors: &[Vec<f32>],
81 metric: GpuDistanceMetric,
82 ) -> Vec<f32> {
83 vectors
84 .iter()
85 .map(|v| calculate_similarity(query, v, metric))
86 .collect()
87 }
88
89 pub fn top_k_similar(
91 query: &[f32],
92 vectors: &[Vec<f32>],
93 k: usize,
94 metric: GpuDistanceMetric,
95 ) -> Vec<(usize, f32)> {
96 let similarities: Vec<(usize, f32)> = vectors
97 .iter()
98 .enumerate()
99 .map(|(i, v)| (i, calculate_similarity(query, v, metric)))
100 .collect();
101
102 let mut sorted = similarities;
103 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
104 sorted.truncate(k);
105 sorted
106 }
107}