Skip to main content

ai_lib_rust/embeddings/
vectors.rs

//! Vector operations for embeddings.
2
3use crate::{Error, Result};
4
/// A dense embedding vector stored as single-precision floats.
pub type Vector = Vec<f32>;
6
7pub fn dot_product(a: &[f32], b: &[f32]) -> Result<f32> {
8    if a.len() != b.len() {
9        return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
10    }
11    Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
12}
13
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of its squared components.
pub fn magnitude(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
17
18pub fn normalize_vector(v: &[f32]) -> Vector {
19    let mag = magnitude(v);
20    if mag == 0.0 { return v.to_vec(); }
21    v.iter().map(|x| x / mag).collect()
22}
23
24pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
25    if a.len() != b.len() {
26        return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
27    }
28    let dot = dot_product(a, b)?;
29    let mag_a = magnitude(a);
30    let mag_b = magnitude(b);
31    if mag_a == 0.0 || mag_b == 0.0 { return Ok(0.0); }
32    Ok(dot / (mag_a * mag_b))
33}
34
35pub fn euclidean_distance(a: &[f32], b: &[f32]) -> Result<f32> {
36    if a.len() != b.len() {
37        return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
38    }
39    Ok(a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt())
40}
41
42pub fn manhattan_distance(a: &[f32], b: &[f32]) -> Result<f32> {
43    if a.len() != b.len() {
44        return Err(Error::validation(format!("Vector dimensions must match: {} != {}", a.len(), b.len())));
45    }
46    Ok(a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum())
47}
48
/// Scoring function used to compare a query vector against candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Cosine of the angle between the vectors.
    Cosine,
    /// Straight-line (L2) distance.
    Euclidean,
    /// Raw dot product.
    DotProduct,
    /// Sum of absolute component differences (L1 distance).
    Manhattan,
}
51
/// One scored candidate produced by a similarity search.
///
/// The type is plain data (`usize` + `f32`), so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SimilarityResult {
    /// Position of the candidate in the slice that was searched.
    pub index: usize,
    /// Score of the candidate under the metric that was used.
    pub score: f32,
}
54
55pub fn find_most_similar(query: &[f32], candidates: &[Vec<f32>], top_k: usize, metric: SimilarityMetric) -> Result<Vec<SimilarityResult>> {
56    let mut scores: Vec<SimilarityResult> = candidates.iter().enumerate()
57        .filter_map(|(i, c)| {
58            let score = match metric {
59                SimilarityMetric::Cosine => cosine_similarity(query, c).ok(),
60                SimilarityMetric::Euclidean => euclidean_distance(query, c).ok(),
61                SimilarityMetric::DotProduct => dot_product(query, c).ok(),
62                SimilarityMetric::Manhattan => manhattan_distance(query, c).ok(),
63            };
64            score.map(|s| SimilarityResult { index: i, score: s })
65        }).collect();
66    match metric {
67        SimilarityMetric::Cosine | SimilarityMetric::DotProduct => scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)),
68        _ => scores.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal)),
69    }
70    scores.truncate(top_k);
71    Ok(scores)
72}
73
74pub fn average_vectors(vectors: &[Vec<f32>]) -> Result<Vector> {
75    if vectors.is_empty() { return Err(Error::validation("Cannot average empty list")); }
76    let dim = vectors[0].len();
77    if !vectors.iter().all(|v| v.len() == dim) { return Err(Error::validation("All vectors must have same dimensions")); }
78    let n = vectors.len() as f32;
79    let mut result = vec![0.0; dim];
80    for v in vectors { for (i, val) in v.iter().enumerate() { result[i] += val; } }
81    for val in &mut result { *val /= n; }
82    Ok(result)
83}
84
85pub fn weighted_average_vectors(vectors: &[Vec<f32>], weights: &[f32]) -> Result<Vector> {
86    if vectors.is_empty() { return Err(Error::validation("Cannot average empty list")); }
87    if vectors.len() != weights.len() { return Err(Error::validation("Vectors and weights must match")); }
88    let total: f32 = weights.iter().sum();
89    if total == 0.0 { return Err(Error::validation("Total weight cannot be zero")); }
90    let dim = vectors[0].len();
91    let mut result = vec![0.0; dim];
92    for (v, w) in vectors.iter().zip(weights.iter()) {
93        let nw = w / total;
94        for (i, val) in v.iter().enumerate() { result[i] += val * nw; }
95    }
96    Ok(result)
97}
98
99pub fn add_vectors(a: &[f32], b: &[f32]) -> Result<Vector> {
100    if a.len() != b.len() { return Err(Error::validation("Dimensions must match")); }
101    Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
102}
103
104pub fn subtract_vectors(a: &[f32], b: &[f32]) -> Result<Vector> {
105    if a.len() != b.len() { return Err(Error::validation("Dimensions must match")); }
106    Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
107}
108
109pub fn scale_vector(v: &[f32], scalar: f32) -> Vector {
110    v.iter().map(|x| x * scalar).collect()
111}