Skip to main content

ai_lib_rust/embeddings/
vectors.rs

//! Vector operations for embeddings.
2
3use crate::{Error, Result};
4
/// An owned embedding vector: a heap-allocated sequence of `f32` components.
pub type Vector = Vec<f32>;
6
7pub fn dot_product(a: &[f32], b: &[f32]) -> Result<f32> {
8    if a.len() != b.len() {
9        return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
10    }
11    Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
12}
13
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of its
/// squared components.
pub fn magnitude(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
17
18pub fn normalize_vector(v: &[f32]) -> Vector {
19    let mag = magnitude(v);
20    if mag == 0.0 { return v.to_vec(); }
21    v.iter().map(|x| x / mag).collect()
22}
23
24pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
25    if a.len() != b.len() {
26        return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
27    }
28    let dot = dot_product(a, b)?;
29    let mag_a = magnitude(a);
30    let mag_b = magnitude(b);
31    if mag_a == 0.0 || mag_b == 0.0 { return Ok(0.0); }
32    Ok(dot / (mag_a * mag_b))
33}
34
35pub fn euclidean_distance(a: &[f32], b: &[f32]) -> Result<f32> {
36    if a.len() != b.len() {
37        return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
38    }
39    Ok(a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt())
40}
41
42pub fn manhattan_distance(a: &[f32], b: &[f32]) -> Result<f32> {
43    if a.len() != b.len() {
44        return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
45    }
46    Ok(a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum())
47}
48
/// Scoring function used to compare a query vector against candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Cosine of the angle between the two vectors.
    Cosine,
    /// Euclidean (straight-line) distance.
    Euclidean,
    /// Sum of element-wise products.
    DotProduct,
    /// Manhattan (sum of absolute differences) distance.
    Manhattan,
}
51
/// One scored candidate produced by a similarity search.
#[derive(Debug, Clone)]
pub struct SimilarityResult {
    /// Position of the candidate within the searched candidate list.
    pub index: usize,
    /// Score assigned to the candidate by the chosen metric.
    pub score: f32,
}
54
55pub fn find_most_similar(query: &[f32], candidates: &[Vec<f32>], top_k: usize, metric: SimilarityMetric) -> Result<Vec<SimilarityResult>> {
56    let mut scores: Vec<SimilarityResult> = candidates.iter().enumerate()
57        .filter_map(|(i, c)| {
58            let score = match metric {
59                SimilarityMetric::Cosine => cosine_similarity(query, c).ok(),
60                SimilarityMetric::Euclidean => euclidean_distance(query, c).ok(),
61                SimilarityMetric::DotProduct => dot_product(query, c).ok(),
62                SimilarityMetric::Manhattan => manhattan_distance(query, c).ok(),
63            };
64            score.map(|s| SimilarityResult { index: i, score: s })
65        }).collect();
66    match metric {
67        SimilarityMetric::Cosine | SimilarityMetric::DotProduct => scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)),
68        _ => scores.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)),
69    }
70    scores.truncate(top_k);
71    Ok(scores)
72}
73
74pub fn average_vectors(vectors: &[Vec<f32>]) -> Result<Vector> {
75    if vectors.is_empty() { return Err(Error::validation("Cannot average empty list")); }
76    let dim = vectors[0].len();
77    if !vectors.iter().all(|v| v.len() == dim) { return Err(Error::validation("All vectors must have same dimensions")); }
78    let n = vectors.len() as f32;
79    let mut result = vec![0.0; dim];
80    for v in vectors { for (i, val) in v.iter().enumerate() { result[i] += val; } }
81    for val in &mut result { *val /= n; }
82    Ok(result)
83}
84
85pub fn weighted_average_vectors(vectors: &[Vec<f32>], weights: &[f32]) -> Result<Vector> {
86    if vectors.is_empty() { return Err(Error::validation("Cannot average empty list")); }
87    if vectors.len() != weights.len() { return Err(Error::validation("Vectors and weights must match")); }
88    let total: f32 = weights.iter().sum();
89    if total == 0.0 { return Err(Error::validation("Total weight cannot be zero")); }
90    let dim = vectors[0].len();
91    let mut result = vec![0.0; dim];
92    for (v, w) in vectors.iter().zip(weights.iter()) {
93        let nw = w / total;
94        for (i, val) in v.iter().enumerate() { result[i] += val * nw; }
95    }
96    Ok(result)
97}
98
99pub fn add_vectors(a: &[f32], b: &[f32]) -> Result<Vector> {
100    if a.len() != b.len() { return Err(Error::validation("Dimensions must match")); }
101    Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
102}
103
104pub fn subtract_vectors(a: &[f32], b: &[f32]) -> Result<Vector> {
105    if a.len() != b.len() { return Err(Error::validation("Dimensions must match")); }
106    Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
107}
108
109pub fn scale_vector(v: &[f32], scalar: f32) -> Vector {
110    v.iter().map(|x| x * scalar).collect()
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const EPSILON: f32 = 1e-6;

    /// Returns `true` when `a` and `b` differ by less than `EPSILON`.
    fn approx_eq(a: f32, b: f32) -> bool {
        let gap = (a - b).abs();
        gap < EPSILON
    }

    // ----- dot_product -----

    #[test]
    fn test_dot_product_basic() {
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        let product = dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).unwrap();
        assert!(approx_eq(product, 32.0));
    }

    #[test]
    fn test_dot_product_dimension_mismatch() {
        let outcome = dot_product(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let product = dot_product(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).unwrap();
        assert!(approx_eq(product, 0.0));
    }

    // ----- magnitude -----

    #[test]
    fn test_magnitude_basic() {
        // sqrt(9 + 16) = sqrt(25) = 5
        assert!(approx_eq(magnitude(&[3.0, 4.0]), 5.0));
    }

    #[test]
    fn test_magnitude_unit_vector() {
        assert!(approx_eq(magnitude(&[1.0, 0.0, 0.0]), 1.0));
    }

    #[test]
    fn test_magnitude_zero_vector() {
        assert!(approx_eq(magnitude(&[0.0, 0.0, 0.0]), 0.0));
    }

    // ----- normalize_vector -----

    #[test]
    fn test_normalize_vector_basic() {
        let unit = normalize_vector(&[3.0, 4.0]);
        // [3/5, 4/5] = [0.6, 0.8]
        assert!(approx_eq(unit[0], 0.6));
        assert!(approx_eq(unit[1], 0.8));
        // The normalized vector must have unit length.
        assert!(approx_eq(magnitude(&unit), 1.0));
    }

    #[test]
    fn test_normalize_vector_zero() {
        let zero = vec![0.0, 0.0, 0.0];
        // A zero vector is returned unchanged.
        assert_eq!(normalize_vector(&zero), zero);
    }

    // ----- cosine_similarity -----

    #[test]
    fn test_cosine_similarity_identical() {
        let similarity = cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]).unwrap();
        assert!(approx_eq(similarity, 1.0));
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let similarity = cosine_similarity(&[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0]).unwrap();
        assert!(approx_eq(similarity, -1.0));
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let similarity = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).unwrap();
        assert!(approx_eq(similarity, 0.0));
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // A zero vector yields a similarity of 0.
        let similarity = cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]).unwrap();
        assert!(approx_eq(similarity, 0.0));
    }

    // ----- euclidean_distance -----

    #[test]
    fn test_euclidean_distance_basic() {
        let distance = euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]).unwrap();
        assert!(approx_eq(distance, 5.0));
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let distance = euclidean_distance(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]).unwrap();
        assert!(approx_eq(distance, 0.0));
    }

    // ----- manhattan_distance -----

    #[test]
    fn test_manhattan_distance_basic() {
        // |3-0| + |4-0| = 7
        let distance = manhattan_distance(&[0.0, 0.0], &[3.0, 4.0]).unwrap();
        assert!(approx_eq(distance, 7.0));
    }

    #[test]
    fn test_manhattan_distance_negative() {
        // |1-(-1)| + |2-(-2)| = 2 + 4 = 6
        let distance = manhattan_distance(&[1.0, 2.0], &[-1.0, -2.0]).unwrap();
        assert!(approx_eq(distance, 6.0));
    }

    // ----- find_most_similar -----

    #[test]
    fn test_find_most_similar_cosine() {
        let query = [1.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0], // identical
            vec![0.0, 1.0, 0.0], // orthogonal
            vec![0.7, 0.7, 0.0], // 45 degrees
        ];

        let ranked = find_most_similar(&query, &candidates, 2, SimilarityMetric::Cosine).unwrap();
        assert_eq!(ranked.len(), 2);
        // The identical candidate must rank first.
        assert_eq!(ranked[0].index, 0);
        assert!(approx_eq(ranked[0].score, 1.0));
    }

    #[test]
    fn test_find_most_similar_euclidean() {
        let query = [0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0], // distance 1
            vec![3.0, 4.0], // distance 5
            vec![0.5, 0.5], // distance ~0.7
        ];

        let ranked =
            find_most_similar(&query, &candidates, 2, SimilarityMetric::Euclidean).unwrap();
        assert_eq!(ranked.len(), 2);
        // Smaller distance is better, so the closest candidate comes first.
        assert_eq!(ranked[0].index, 2);
    }

    #[test]
    fn test_find_most_similar_top_k() {
        let query = [1.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.8, 0.2],
            vec![0.0, 1.0],
        ];

        let ranked = find_most_similar(&query, &candidates, 2, SimilarityMetric::Cosine).unwrap();
        assert_eq!(ranked.len(), 2);
    }

    // ----- average_vectors -----

    #[test]
    fn test_average_vectors_basic() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let mean = average_vectors(&inputs).unwrap();
        // [(1+3)/2, (2+4)/2] = [2, 3]
        assert!(approx_eq(mean[0], 2.0));
        assert!(approx_eq(mean[1], 3.0));
    }

    #[test]
    fn test_average_vectors_empty() {
        let inputs: Vec<Vec<f32>> = Vec::new();
        assert!(average_vectors(&inputs).is_err());
    }

    #[test]
    fn test_average_vectors_dimension_mismatch() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0]];
        assert!(average_vectors(&inputs).is_err());
    }

    // ----- weighted_average_vectors -----

    #[test]
    fn test_weighted_average_vectors_basic() {
        let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let mean = weighted_average_vectors(&inputs, &[1.0, 1.0]).unwrap();
        // Equal weights: [0.5, 0.5]
        assert!(approx_eq(mean[0], 0.5));
        assert!(approx_eq(mean[1], 0.5));
    }

    #[test]
    fn test_weighted_average_vectors_unequal() {
        let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        // 75% first, 25% second
        let mean = weighted_average_vectors(&inputs, &[3.0, 1.0]).unwrap();
        assert!(approx_eq(mean[0], 0.75));
        assert!(approx_eq(mean[1], 0.25));
    }

    #[test]
    fn test_weighted_average_vectors_zero_weights() {
        let inputs = vec![vec![1.0, 2.0]];
        assert!(weighted_average_vectors(&inputs, &[0.0]).is_err());
    }

    // ----- add_vectors / subtract_vectors -----

    #[test]
    fn test_add_vectors_basic() {
        let sum = add_vectors(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).unwrap();
        assert_eq!(sum, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_vectors_dimension_mismatch() {
        assert!(add_vectors(&[1.0, 2.0], &[1.0]).is_err());
    }

    #[test]
    fn test_subtract_vectors_basic() {
        let difference = subtract_vectors(&[5.0, 7.0, 9.0], &[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(difference, vec![4.0, 5.0, 6.0]);
    }

    // ----- scale_vector -----

    #[test]
    fn test_scale_vector_basic() {
        assert_eq!(scale_vector(&[1.0, 2.0, 3.0], 2.0), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_scale_vector_zero() {
        assert_eq!(scale_vector(&[1.0, 2.0, 3.0], 0.0), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_scale_vector_negative() {
        assert_eq!(scale_vector(&[1.0, 2.0], -1.0), vec![-1.0, -2.0]);
    }
}