omega_agentdb/
simd_ops.rs

1//! SIMD-Optimized Vector Operations
2//!
3//! High-performance distance and similarity computations:
4//! - Cosine similarity/distance
5//! - Euclidean (L2) distance
6//! - Dot product (inner product)
7//! - Manhattan (L1) distance
8//! - Batch operations
9
10use simsimd::SpatialSimilarity;
11
12/// SIMD-accelerated cosine similarity
13/// Returns: 1.0 = identical, 0.0 = orthogonal, -1.0 = opposite
14#[inline]
15pub fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
16    match f32::cosine(a, b) {
17        Some(distance) => (1.0 - distance) as f32,
18        None => 0.0,
19    }
20}
21
22/// SIMD-accelerated cosine distance
23/// Returns: 0.0 = identical, 1.0 = orthogonal, 2.0 = opposite
24#[inline]
25pub fn cosine_distance_f32(a: &[f32], b: &[f32]) -> f32 {
26    match f32::cosine(a, b) {
27        Some(distance) => distance as f32,
28        None => 1.0,
29    }
30}
31
32/// SIMD-accelerated L2 (Euclidean) squared distance
33#[inline]
34pub fn l2_squared_f32(a: &[f32], b: &[f32]) -> f32 {
35    match f32::sqeuclidean(a, b) {
36        Some(distance) => distance as f32,
37        None => f32::MAX,
38    }
39}
40
41/// SIMD-accelerated L2 (Euclidean) distance
42#[inline]
43pub fn l2_distance_f32(a: &[f32], b: &[f32]) -> f32 {
44    l2_squared_f32(a, b).sqrt()
45}
46
47/// SIMD-accelerated dot product (inner product)
48#[inline]
49pub fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
50    match f32::dot(a, b) {
51        Some(product) => product as f32,
52        None => 0.0,
53    }
54}
55
/// Distance metric selector used by [`DistanceMetric::compute`] and
/// [`DistanceMetric::to_similarity`].
///
/// Every variant produces a value where *smaller means closer*, so results
/// from any metric can be fed to a minimum search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Cosine distance (1 - cosine similarity); evaluated by
    /// `cosine_distance_f32`.
    Cosine,
    /// L2 (Euclidean) distance; evaluated by `l2_distance_f32`.
    L2,
    /// L2 squared (faster, avoids sqrt); evaluated by `l2_squared_f32`.
    L2Squared,
    /// Inner product (dot product). `compute` returns the *negated* dot
    /// product so that larger products count as smaller distances.
    InnerProduct,
    /// Manhattan (L1) distance - scalar fallback, evaluated by
    /// `manhattan_distance_f32` rather than a `simsimd` kernel.
    Manhattan,
}
70
71impl DistanceMetric {
72    /// Compute distance between two vectors
73    #[inline]
74    pub fn compute(&self, a: &[f32], b: &[f32]) -> f32 {
75        match self {
76            Self::Cosine => cosine_distance_f32(a, b),
77            Self::L2 => l2_distance_f32(a, b),
78            Self::L2Squared => l2_squared_f32(a, b),
79            Self::InnerProduct => -dot_product_f32(a, b), // Negative for min-search
80            Self::Manhattan => manhattan_distance_f32(a, b),
81        }
82    }
83
84    /// Convert distance to similarity (for ranking)
85    #[inline]
86    pub fn to_similarity(&self, distance: f32) -> f32 {
87        match self {
88            Self::Cosine => 1.0 - distance,
89            Self::L2 | Self::L2Squared | Self::Manhattan => 1.0 / (1.0 + distance),
90            Self::InnerProduct => -distance, // Already negated
91        }
92    }
93}
94
/// Manhattan (L1) distance - scalar implementation.
///
/// Sums `|a[i] - b[i]|` over all components.
///
/// # Returns
/// The L1 distance, or `f32::MAX` when `a` and `b` have different lengths.
/// Without that guard `zip` would silently compare only the common prefix and
/// report a misleadingly small distance for incompatible vectors; `f32::MAX`
/// is the same "infinitely far" sentinel that `l2_squared_f32` uses when it
/// cannot produce a result.
pub fn manhattan_distance_f32(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }

    a.iter().zip(b).map(|(&x, &y)| (x - y).abs()).sum()
}
102
103/// Batch cosine similarities - compute similarity of query against multiple vectors
104pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]]) -> Vec<f32> {
105    vectors
106        .iter()
107        .map(|v| cosine_similarity_f32(query, v))
108        .collect()
109}
110
111/// Batch L2 distances - compute L2 distance of query against multiple vectors
112pub fn batch_l2_distance(query: &[f32], vectors: &[&[f32]]) -> Vec<f32> {
113    vectors.iter().map(|v| l2_distance_f32(query, v)).collect()
114}
115
116/// Find top-k most similar vectors
117pub fn top_k_similar(
118    query: &[f32],
119    vectors: &[(&str, &[f32])],
120    k: usize,
121    metric: DistanceMetric,
122) -> Vec<(String, f32)> {
123    let mut scored: Vec<(String, f32)> = vectors
124        .iter()
125        .map(|(id, v)| {
126            let dist = metric.compute(query, v);
127            let sim = metric.to_similarity(dist);
128            (id.to_string(), sim)
129        })
130        .collect();
131
132    // Sort by similarity descending
133    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
134    scored.truncate(k);
135    scored
136}
137
/// Scale `v` in place to unit L2 length (useful before cosine comparisons).
///
/// The vector is left untouched unless its L2 norm is greater than `1e-10`,
/// which avoids dividing by (near) zero for zero or vanishingly small vectors.
pub fn normalize_f32(v: &mut [f32]) {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    let norm = sum_of_squares.sqrt();

    if norm > 1e-10 {
        v.iter_mut().for_each(|component| *component /= norm);
    }
}
147
/// Report whether `v` has unit L2 length.
///
/// # Returns
/// `true` when the absolute difference between the L2 norm of `v` and `1.0`
/// is strictly less than `tolerance`.
pub fn is_normalized(v: &[f32], tolerance: f32) -> bool {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    let deviation = sum_of_squares.sqrt() - 1.0;
    deviation.abs() < tolerance
}
153
/// L2 norm (Euclidean length) of `v`: the square root of the sum of squared
/// components.
pub fn magnitude_f32(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
158
/// Component-wise mean of `vectors`.
///
/// The output dimension is the length of the first vector. Components of
/// later vectors beyond that dimension are ignored, and vectors shorter than
/// it simply contribute nothing to the components they lack. Every component
/// sum is divided by the total number of vectors.
///
/// # Returns
/// The centroid, or an empty `Vec` when `vectors` is empty.
pub fn centroid(vectors: &[&[f32]]) -> Vec<f32> {
    let Some(first) = vectors.first() else {
        return Vec::new();
    };

    let count = vectors.len() as f32;
    let mut sums = vec![0.0f32; first.len()];

    for vector in vectors {
        // `zip` stops at the shorter side, so out-of-range components are
        // skipped without any explicit index check.
        for (acc, &component) in sums.iter_mut().zip(vector.iter()) {
            *acc += component;
        }
    }

    sums.iter_mut().for_each(|acc| *acc /= count);
    sums
}
183
184/// Compute variance of vectors around their centroid
185pub fn variance(vectors: &[&[f32]], centroid: &[f32]) -> f32 {
186    if vectors.is_empty() {
187        return 0.0;
188    }
189
190    vectors
191        .iter()
192        .map(|v| l2_squared_f32(v, centroid))
193        .sum::<f32>()
194        / vectors.len() as f32
195}
196
// Unit tests for the vector helpers in this module. All floating-point
// comparisons use an absolute tolerance of 0.001.
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical vectors score 1.0; orthogonal vectors score 0.0.
    #[test]
    fn test_cosine_similarity() {
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        assert!((cosine_similarity_f32(&a, &b) - 1.0).abs() < 0.001);

        let c = [1.0, 0.0, 0.0];
        let d = [0.0, 1.0, 0.0];
        assert!(cosine_similarity_f32(&c, &d).abs() < 0.001);
    }

    /// 3-4-5 right triangle: distance from the origin to (3, 4, 0) is 5.
    #[test]
    fn test_l2_distance() {
        let a = [0.0, 0.0, 0.0];
        let b = [3.0, 4.0, 0.0];
        assert!((l2_distance_f32(&a, &b) - 5.0).abs() < 0.001);
    }

    /// Dot product of two small integer-valued vectors.
    #[test]
    fn test_dot_product() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert!((dot_product_f32(&a, &b) - 32.0).abs() < 0.001);
    }

    /// Sum of absolute component differences.
    #[test]
    fn test_manhattan_distance() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 6.0, 3.0];
        // |1-4| + |2-6| + |3-3| = 3 + 4 + 0 = 7
        assert!((manhattan_distance_f32(&a, &b) - 7.0).abs() < 0.001);
    }

    /// After normalization the vector has unit magnitude and is reported as
    /// normalized.
    #[test]
    fn test_normalize() {
        let mut v = [3.0, 4.0, 0.0];
        normalize_f32(&mut v);
        assert!((magnitude_f32(&v) - 1.0).abs() < 0.001);
        assert!(is_normalized(&v, 0.001));
    }

    /// With k = 2 only two results come back, and the candidate equal to the
    /// query ranks first.
    #[test]
    fn test_top_k() {
        let query = [1.0, 0.0, 0.0];
        let vectors: Vec<(&str, &[f32])> = vec![
            ("a", &[1.0, 0.0, 0.0][..]),
            ("b", &[0.0, 1.0, 0.0][..]),
            ("c", &[0.7, 0.7, 0.0][..]),
        ];

        let result = top_k_similar(&query, &vectors, 2, DistanceMetric::Cosine);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].0, "a"); // Most similar
    }

    /// The centroid of (0, 0) and (2, 2) is (1, 1).
    #[test]
    fn test_centroid() {
        let v1 = [0.0, 0.0];
        let v2 = [2.0, 2.0];
        let vectors: Vec<&[f32]> = vec![&v1, &v2];

        let c = centroid(&vectors);
        assert!((c[0] - 1.0).abs() < 0.001);
        assert!((c[1] - 1.0).abs() < 0.001);
    }

    /// Orthogonal unit vectors: cosine distance is 1 and L2 distance is
    /// sqrt(2) when dispatched through `DistanceMetric::compute`.
    #[test]
    fn test_distance_metrics() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];

        let cosine_dist = DistanceMetric::Cosine.compute(&a, &b);
        assert!((cosine_dist - 1.0).abs() < 0.001);

        let l2_dist = DistanceMetric::L2.compute(&a, &b);
        assert!((l2_dist - std::f32::consts::SQRT_2).abs() < 0.001);
    }
}