1use crate::types::GpuDistanceMetric;
4
pub mod vector_math {
    //! Scalar math helpers for dense `f32` vectors.
    //!
    //! Every reduction in this module accumulates in `f64` and narrows back to
    //! `f32` only for the final result. Squaring or multiplying `f32`
    //! components directly overflows to infinity (or underflows to zero) for
    //! vectors whose true norm or dot product is still representable, which
    //! used to turn valid inputs into `inf`, `NaN` or a spurious `0.0`.

    use super::*;

    /// Returns the sum of squared components of `v`, accumulated in `f64`.
    ///
    /// The square of any finite `f32` fits comfortably in an `f64`, so this
    /// cannot overflow or flush to zero for finite input.
    fn squared_norm(v: &[f32]) -> f64 {
        v.iter()
            .map(|&x| {
                let x = f64::from(x);
                x * x
            })
            .sum()
    }

    /// Returns the dot product of `a` and `b`, accumulated in `f64`.
    ///
    /// Pairs are formed with `zip`, so callers are expected to have checked
    /// that both slices have the same length.
    fn dot_wide(a: &[f32], b: &[f32]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| f64::from(x) * f64::from(y))
            .sum()
    }

    /// Computes the cosine similarity between `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length or when either vector
    /// has zero magnitude (this includes two empty slices); otherwise returns
    /// `dot(a, b) / (|a| * |b|)`.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let norm_a = squared_norm(a).sqrt();
        let norm_b = squared_norm(b).sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        // Narrowing is intentional: the public contract is `f32`.
        (dot_wide(a, b) / (norm_a * norm_b)) as f32
    }

    /// Computes the Euclidean (L2) distance between `a` and `b`.
    ///
    /// Returns `f32::INFINITY` when the slices differ in length. The result
    /// is also infinite when the true distance exceeds `f32::MAX`.
    pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return f32::INFINITY;
        }

        let sum_squares: f64 = a
            .iter()
            .zip(b)
            .map(|(&x, &y)| {
                let diff = f64::from(x) - f64::from(y);
                diff * diff
            })
            .sum();

        sum_squares.sqrt() as f32
    }

    /// Computes the dot product of `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length. The result saturates
    /// to an infinity when the true value lies outside the `f32` range.
    pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        dot_wide(a, b) as f32
    }

    /// Scales `v` in place so that it has unit length.
    ///
    /// The slice is left untouched when its magnitude is zero (or is not
    /// greater than zero, e.g. NaN), so a zero vector never produces NaNs.
    pub fn normalize_vector(v: &mut [f32]) {
        let norm = squared_norm(v).sqrt();
        if norm > 0.0 {
            for x in v.iter_mut() {
                // Divide in f64 so a norm above `f32::MAX` still yields the
                // correct unit components instead of collapsing to zero.
                *x = (f64::from(*x) / norm) as f32;
            }
        }
    }

    /// Returns the Euclidean length (L2 norm) of `v`.
    ///
    /// An empty slice has magnitude zero. The result is infinite only when
    /// the true magnitude exceeds `f32::MAX` or an input is infinite.
    pub fn vector_magnitude(v: &[f32]) -> f32 {
        squared_norm(v).sqrt() as f32
    }
}
64
65pub mod similarity_calculations {
67 use super::*;
68
69 pub fn calculate_similarity(
71 a: &[f32],
72 b: &[f32],
73 metric: GpuDistanceMetric
74 ) -> f32 {
75 match metric {
76 GpuDistanceMetric::Cosine => vector_math::cosine_similarity(a, b),
77 GpuDistanceMetric::Euclidean => {
78 let distance = vector_math::euclidean_distance(a, b);
80 1.0 / (1.0 + distance)
81 },
82 GpuDistanceMetric::DotProduct => vector_math::dot_product(a, b),
83 }
84 }
85
86 pub fn batch_similarity(
88 query: &[f32],
89 vectors: &[Vec<f32>],
90 metric: GpuDistanceMetric
91 ) -> Vec<f32> {
92 vectors.iter()
93 .map(|v| calculate_similarity(query, v, metric))
94 .collect()
95 }
96
97 pub fn top_k_similar(
99 query: &[f32],
100 vectors: &[Vec<f32>],
101 k: usize,
102 metric: GpuDistanceMetric
103 ) -> Vec<(usize, f32)> {
104 let similarities: Vec<(usize, f32)> = vectors.iter()
105 .enumerate()
106 .map(|(i, v)| (i, calculate_similarity(query, v, metric)))
107 .collect();
108
109 let mut sorted = similarities;
110 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
111 sorted.truncate(k);
112 sorted
113 }
114}