1use crate::{Error, Result};
4
/// Owned, growable list of `f32` components.
///
/// The functions in this file that build a new vector (for example
/// [`normalize_vector`] and [`add_vectors`]) return this type.
pub type Vector = Vec<f32>;
6
7pub fn dot_product(a: &[f32], b: &[f32]) -> Result<f32> {
8 if a.len() != b.len() {
9 return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
10 }
11 Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
12}
13
/// Returns the Euclidean length of `v`: the square root of the sum of its squared components.
pub fn magnitude(v: &[f32]) -> f32 {
    let squared_sum: f32 = v.iter().map(|component| component * component).sum();
    squared_sum.sqrt()
}
17
18pub fn normalize_vector(v: &[f32]) -> Vector {
19 let mag = magnitude(v);
20 if mag == 0.0 { return v.to_vec(); }
21 v.iter().map(|x| x / mag).collect()
22}
23
24pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
25 if a.len() != b.len() {
26 return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
27 }
28 let dot = dot_product(a, b)?;
29 let mag_a = magnitude(a);
30 let mag_b = magnitude(b);
31 if mag_a == 0.0 || mag_b == 0.0 { return Ok(0.0); }
32 Ok(dot / (mag_a * mag_b))
33}
34
35pub fn euclidean_distance(a: &[f32], b: &[f32]) -> Result<f32> {
36 if a.len() != b.len() {
37 return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
38 }
39 Ok(a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt())
40}
41
42pub fn manhattan_distance(a: &[f32], b: &[f32]) -> Result<f32> {
43 if a.len() != b.len() {
44 return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
45 }
46 Ok(a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum())
47}
48
/// Scoring function that [`find_most_similar`] applies between the query and each candidate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Score with [`cosine_similarity`]; larger scores are ranked first.
    Cosine,
    /// Score with [`euclidean_distance`]; smaller scores are ranked first.
    Euclidean,
    /// Score with [`dot_product`]; larger scores are ranked first.
    DotProduct,
    /// Score with [`manhattan_distance`]; smaller scores are ranked first.
    Manhattan,
}
51
/// One scored candidate produced by [`find_most_similar`].
#[derive(Debug, Clone)]
pub struct SimilarityResult {
    /// Position of the candidate within the `candidates` slice that was searched.
    pub index: usize,
    /// Value of the selected metric between the query and this candidate.
    pub score: f32,
}
54
55pub fn find_most_similar(query: &[f32], candidates: &[Vec<f32>], top_k: usize, metric: SimilarityMetric) -> Result<Vec<SimilarityResult>> {
56 let mut scores: Vec<SimilarityResult> = candidates.iter().enumerate()
57 .filter_map(|(i, c)| {
58 let score = match metric {
59 SimilarityMetric::Cosine => cosine_similarity(query, c).ok(),
60 SimilarityMetric::Euclidean => euclidean_distance(query, c).ok(),
61 SimilarityMetric::DotProduct => dot_product(query, c).ok(),
62 SimilarityMetric::Manhattan => manhattan_distance(query, c).ok(),
63 };
64 score.map(|s| SimilarityResult { index: i, score: s })
65 }).collect();
66 match metric {
67 SimilarityMetric::Cosine | SimilarityMetric::DotProduct => scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)),
68 _ => scores.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)),
69 }
70 scores.truncate(top_k);
71 Ok(scores)
72}
73
74pub fn average_vectors(vectors: &[Vec<f32>]) -> Result<Vector> {
75 if vectors.is_empty() { return Err(Error::validation("Cannot average empty list")); }
76 let dim = vectors[0].len();
77 if !vectors.iter().all(|v| v.len() == dim) { return Err(Error::validation("All vectors must have same dimensions")); }
78 let n = vectors.len() as f32;
79 let mut result = vec![0.0; dim];
80 for v in vectors { for (i, val) in v.iter().enumerate() { result[i] += val; } }
81 for val in &mut result { *val /= n; }
82 Ok(result)
83}
84
85pub fn weighted_average_vectors(vectors: &[Vec<f32>], weights: &[f32]) -> Result<Vector> {
86 if vectors.is_empty() { return Err(Error::validation("Cannot average empty list")); }
87 if vectors.len() != weights.len() { return Err(Error::validation("Vectors and weights must match")); }
88 let total: f32 = weights.iter().sum();
89 if total == 0.0 { return Err(Error::validation("Total weight cannot be zero")); }
90 let dim = vectors[0].len();
91 let mut result = vec![0.0; dim];
92 for (v, w) in vectors.iter().zip(weights.iter()) {
93 let nw = w / total;
94 for (i, val) in v.iter().enumerate() { result[i] += val * nw; }
95 }
96 Ok(result)
97}
98
99pub fn add_vectors(a: &[f32], b: &[f32]) -> Result<Vector> {
100 if a.len() != b.len() { return Err(Error::validation("Dimensions must match")); }
101 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
102}
103
104pub fn subtract_vectors(a: &[f32], b: &[f32]) -> Result<Vector> {
105 if a.len() != b.len() { return Err(Error::validation("Dimensions must match")); }
106 Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
107}
108
109pub fn scale_vector(v: &[f32], scalar: f32) -> Vector {
110 v.iter().map(|x| x * scalar).collect()
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPSILON: f32 = 1e-6;

    /// Reports whether `a` and `b` differ by less than [`EPSILON`].
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    // ---- dot_product ----

    #[test]
    fn test_dot_product_basic() {
        // 1*4 + 2*5 + 3*6 = 32
        let product = dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).unwrap();
        assert!(approx_eq(product, 32.0));
    }

    #[test]
    fn test_dot_product_dimension_mismatch() {
        let outcome = dot_product(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let product = dot_product(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).unwrap();
        assert!(approx_eq(product, 0.0));
    }

    // ---- magnitude ----

    #[test]
    fn test_magnitude_basic() {
        // 3-4-5 right triangle.
        let length = magnitude(&[3.0, 4.0]);
        assert!(approx_eq(length, 5.0));
    }

    #[test]
    fn test_magnitude_unit_vector() {
        let length = magnitude(&[1.0, 0.0, 0.0]);
        assert!(approx_eq(length, 1.0));
    }

    #[test]
    fn test_magnitude_zero_vector() {
        let length = magnitude(&[0.0, 0.0, 0.0]);
        assert!(approx_eq(length, 0.0));
    }

    // ---- normalize_vector ----

    #[test]
    fn test_normalize_vector_basic() {
        let unit = normalize_vector(&[3.0, 4.0]);
        assert!(approx_eq(unit[0], 0.6));
        assert!(approx_eq(unit[1], 0.8));
        assert!(approx_eq(magnitude(&unit), 1.0));
    }

    #[test]
    fn test_normalize_vector_zero() {
        // A zero vector comes back unchanged.
        let zero = vec![0.0, 0.0, 0.0];
        let normalized = normalize_vector(&zero);
        assert_eq!(normalized, zero);
    }

    // ---- cosine_similarity ----

    #[test]
    fn test_cosine_similarity_identical() {
        let similarity = cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]).unwrap();
        assert!(approx_eq(similarity, 1.0));
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let similarity = cosine_similarity(&[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0]).unwrap();
        assert!(approx_eq(similarity, -1.0));
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let similarity = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).unwrap();
        assert!(approx_eq(similarity, 0.0));
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // A zero-magnitude operand yields 0.0 instead of a division by zero.
        let similarity = cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]).unwrap();
        assert!(approx_eq(similarity, 0.0));
    }

    // ---- euclidean_distance ----

    #[test]
    fn test_euclidean_distance_basic() {
        let distance = euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]).unwrap();
        assert!(approx_eq(distance, 5.0));
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let distance = euclidean_distance(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]).unwrap();
        assert!(approx_eq(distance, 0.0));
    }

    // ---- manhattan_distance ----

    #[test]
    fn test_manhattan_distance_basic() {
        // |3| + |4| = 7
        let distance = manhattan_distance(&[0.0, 0.0], &[3.0, 4.0]).unwrap();
        assert!(approx_eq(distance, 7.0));
    }

    #[test]
    fn test_manhattan_distance_negative() {
        // |1 - (-1)| + |2 - (-2)| = 6
        let distance = manhattan_distance(&[1.0, 2.0], &[-1.0, -2.0]).unwrap();
        assert!(approx_eq(distance, 6.0));
    }

    // ---- find_most_similar ----

    #[test]
    fn test_find_most_similar_cosine() {
        let query = [1.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.7, 0.7, 0.0],
        ];

        let top = find_most_similar(&query, &candidates, 2, SimilarityMetric::Cosine).unwrap();
        assert_eq!(top.len(), 2);
        // The candidate identical to the query must come first with similarity 1.
        assert_eq!(top[0].index, 0);
        assert!(approx_eq(top[0].score, 1.0));
    }

    #[test]
    fn test_find_most_similar_euclidean() {
        let query = [0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0],
            vec![3.0, 4.0],
            vec![0.5, 0.5],
        ];

        let top = find_most_similar(&query, &candidates, 2, SimilarityMetric::Euclidean).unwrap();
        assert_eq!(top.len(), 2);
        // (0.5, 0.5) is the closest point to the origin.
        assert_eq!(top[0].index, 2);
    }

    #[test]
    fn test_find_most_similar_top_k() {
        let query = [1.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.0, 1.0],
        ];

        let top = find_most_similar(&query, &candidates, 2, SimilarityMetric::Cosine).unwrap();
        assert_eq!(top.len(), 2);
    }

    // ---- average_vectors ----

    #[test]
    fn test_average_vectors_basic() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let mean = average_vectors(&inputs).unwrap();
        assert!(approx_eq(mean[0], 2.0));
        assert!(approx_eq(mean[1], 3.0));
    }

    #[test]
    fn test_average_vectors_empty() {
        let inputs: Vec<Vec<f32>> = Vec::new();
        assert!(average_vectors(&inputs).is_err());
    }

    #[test]
    fn test_average_vectors_dimension_mismatch() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0]];
        assert!(average_vectors(&inputs).is_err());
    }

    // ---- weighted_average_vectors ----

    #[test]
    fn test_weighted_average_vectors_basic() {
        let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let mean = weighted_average_vectors(&inputs, &[1.0, 1.0]).unwrap();
        assert!(approx_eq(mean[0], 0.5));
        assert!(approx_eq(mean[1], 0.5));
    }

    #[test]
    fn test_weighted_average_vectors_unequal() {
        let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        // Weights 3:1 normalise to 0.75 and 0.25.
        let mean = weighted_average_vectors(&inputs, &[3.0, 1.0]).unwrap();
        assert!(approx_eq(mean[0], 0.75));
        assert!(approx_eq(mean[1], 0.25));
    }

    #[test]
    fn test_weighted_average_vectors_zero_weights() {
        let inputs = vec![vec![1.0, 2.0]];
        assert!(weighted_average_vectors(&inputs, &[0.0]).is_err());
    }

    // ---- add_vectors / subtract_vectors ----

    #[test]
    fn test_add_vectors_basic() {
        let sum = add_vectors(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).unwrap();
        assert_eq!(sum, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_vectors_dimension_mismatch() {
        let outcome = add_vectors(&[1.0, 2.0], &[1.0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_subtract_vectors_basic() {
        let difference = subtract_vectors(&[5.0, 7.0, 9.0], &[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(difference, vec![4.0, 5.0, 6.0]);
    }

    // ---- scale_vector ----

    #[test]
    fn test_scale_vector_basic() {
        let scaled = scale_vector(&[1.0, 2.0, 3.0], 2.0);
        assert_eq!(scaled, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_scale_vector_zero() {
        let scaled = scale_vector(&[1.0, 2.0, 3.0], 0.0);
        assert_eq!(scaled, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_scale_vector_negative() {
        let scaled = scale_vector(&[1.0, 2.0], -1.0);
        assert_eq!(scaled, vec![-1.0, -2.0]);
    }
}