// ai_lib_rust/embeddings/vectors.rs
use crate::{Error, Result};
4
/// Owned vector of `f32` components; the return type of the vector-producing
/// helpers in this file.
pub type Vector = Vec<f32>;
6
7pub fn dot_product(a: &[f32], b: &[f32]) -> Result<f32> {
8 if a.len() != b.len() {
9 return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
10 }
11 Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
12}
13
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of its
/// squared components.
pub fn magnitude(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
17
18pub fn normalize_vector(v: &[f32]) -> Vector {
19 let mag = magnitude(v);
20 if mag == 0.0 { return v.to_vec(); }
21 v.iter().map(|x| x / mag).collect()
22}
23
24pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
25 if a.len() != b.len() {
26 return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
27 }
28 let dot = dot_product(a, b)?;
29 let mag_a = magnitude(a);
30 let mag_b = magnitude(b);
31 if mag_a == 0.0 || mag_b == 0.0 { return Ok(0.0); }
32 Ok(dot / (mag_a * mag_b))
33}
34
35pub fn euclidean_distance(a: &[f32], b: &[f32]) -> Result<f32> {
36 if a.len() != b.len() {
37 return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
38 }
39 Ok(a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt())
40}
41
42pub fn manhattan_distance(a: &[f32], b: &[f32]) -> Result<f32> {
43 if a.len() != b.len() {
44 return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
45 }
46 Ok(a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum())
47}
48
/// Scoring function selected by the caller of `find_most_similar`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Score with `cosine_similarity`; larger scores rank first.
    Cosine,
    /// Score with `euclidean_distance`; smaller scores rank first.
    Euclidean,
    /// Score with `dot_product`; larger scores rank first.
    DotProduct,
    /// Score with `manhattan_distance`; smaller scores rank first.
    Manhattan,
}
51
/// One scored candidate produced by `find_most_similar`.
#[derive(Debug, Clone)]
pub struct SimilarityResult {
    /// Position of the candidate in the slice that was searched.
    pub index: usize,
    /// Score of the candidate under the metric that was requested.
    pub score: f32,
}
54
55pub fn find_most_similar(query: &[f32], candidates: &[Vec<f32>], top_k: usize, metric: SimilarityMetric) -> Result<Vec<SimilarityResult>> {
56 let mut scores: Vec<SimilarityResult> = candidates.iter().enumerate()
57 .filter_map(|(i, c)| {
58 let score = match metric {
59 SimilarityMetric::Cosine => cosine_similarity(query, c).ok(),
60 SimilarityMetric::Euclidean => euclidean_distance(query, c).ok(),
61 SimilarityMetric::DotProduct => dot_product(query, c).ok(),
62 SimilarityMetric::Manhattan => manhattan_distance(query, c).ok(),
63 };
64 score.map(|s| SimilarityResult { index: i, score: s })
65 }).collect();
66 match metric {
67 SimilarityMetric::Cosine | SimilarityMetric::DotProduct => scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)),
68 _ => scores.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)),
69 }
70 scores.truncate(top_k);
71 Ok(scores)
72}
73
74pub fn average_vectors(vectors: &[Vec<f32>]) -> Result<Vector> {
75 if vectors.is_empty() { return Err(Error::validation("Cannot average empty list")); }
76 let dim = vectors[0].len();
77 if !vectors.iter().all(|v| v.len() == dim) { return Err(Error::validation("All vectors must have same dimensions")); }
78 let n = vectors.len() as f32;
79 let mut result = vec![0.0; dim];
80 for v in vectors { for (i, val) in v.iter().enumerate() { result[i] += val; } }
81 for val in &mut result { *val /= n; }
82 Ok(result)
83}
84
85pub fn weighted_average_vectors(vectors: &[Vec<f32>], weights: &[f32]) -> Result<Vector> {
86 if vectors.is_empty() { return Err(Error::validation("Cannot average empty list")); }
87 if vectors.len() != weights.len() { return Err(Error::validation("Vectors and weights must match")); }
88 let total: f32 = weights.iter().sum();
89 if total == 0.0 { return Err(Error::validation("Total weight cannot be zero")); }
90 let dim = vectors[0].len();
91 let mut result = vec![0.0; dim];
92 for (v, w) in vectors.iter().zip(weights.iter()) {
93 let nw = w / total;
94 for (i, val) in v.iter().enumerate() { result[i] += val * nw; }
95 }
96 Ok(result)
97}
98
99pub fn add_vectors(a: &[f32], b: &[f32]) -> Result<Vector> {
100 if a.len() != b.len() { return Err(Error::validation("Dimensions must match")); }
101 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
102}
103
104pub fn subtract_vectors(a: &[f32], b: &[f32]) -> Result<Vector> {
105 if a.len() != b.len() { return Err(Error::validation("Dimensions must match")); }
106 Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
107}
108
109pub fn scale_vector(v: &[f32], scalar: f32) -> Vector {
110 v.iter().map(|x| x * scalar).collect()
111}