Skip to main content

veclite_index/
lib.rs

1use serde::{Deserialize, Serialize};
2use wide::f32x8;
3
4#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
5pub enum Metric {
6    Cosine,
7    DotProduct,
8    Euclidean,
9    Manhattan,
10}
11
/// A scoring function over pairs of `f32` vectors.
///
/// Both methods are associated functions (no `self`), so an implementation is
/// selected purely by type, e.g. `CosineMetric::distance(a, b)`.
pub trait SimilarityMetric {
    /// Scores the pair `(a, b)`.
    ///
    /// Despite the name, the returned value is not necessarily a distance:
    /// whether a larger or a smaller score means "closer" is reported by
    /// [`SimilarityMetric::higher_is_better`].
    fn distance(a: &[f32], b: &[f32]) -> f32;

    /// Returns `true` when larger scores from [`SimilarityMetric::distance`]
    /// indicate a better match, and `false` when smaller scores do.
    fn higher_is_better() -> bool;
}
16
17pub struct CosineMetric;
18impl SimilarityMetric for CosineMetric {
19    fn distance(a: &[f32], b: &[f32]) -> f32 {
20        let mut dot_simd = f32x8::ZERO;
21        let mut norm_a_simd = f32x8::ZERO;
22        let mut norm_b_simd = f32x8::ZERO;
23
24        let chunks_a = a.chunks_exact(8);
25        let chunks_b = b.chunks_exact(8);
26        let rem_a = chunks_a.remainder();
27        let rem_b = chunks_b.remainder();
28
29        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
30            let va = f32x8::new([
31                chunk_a[0], chunk_a[1], chunk_a[2], chunk_a[3], chunk_a[4], chunk_a[5], chunk_a[6],
32                chunk_a[7],
33            ]);
34            let vb = f32x8::new([
35                chunk_b[0], chunk_b[1], chunk_b[2], chunk_b[3], chunk_b[4], chunk_b[5], chunk_b[6],
36                chunk_b[7],
37            ]);
38            dot_simd += va * vb;
39            norm_a_simd += va * va;
40            norm_b_simd += vb * vb;
41        }
42
43        let mut dot = dot_simd.reduce_add();
44        let mut norm_a = norm_a_simd.reduce_add();
45        let mut norm_b = norm_b_simd.reduce_add();
46
47        for (x, y) in rem_a.iter().zip(rem_b.iter()) {
48            dot += x * y;
49            norm_a += x * x;
50            norm_b += y * y;
51        }
52
53        if norm_a == 0.0 || norm_b == 0.0 {
54            return 0.0;
55        }
56        dot / (norm_a.sqrt() * norm_b.sqrt())
57    }
58    fn higher_is_better() -> bool {
59        true
60    }
61}
62
63pub struct EuclideanMetric;
64impl SimilarityMetric for EuclideanMetric {
65    fn distance(a: &[f32], b: &[f32]) -> f32 {
66        let mut sum_sq_simd = f32x8::ZERO;
67
68        let chunks_a = a.chunks_exact(8);
69        let chunks_b = b.chunks_exact(8);
70        let rem_a = chunks_a.remainder();
71        let rem_b = chunks_b.remainder();
72
73        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
74            let va = f32x8::new([
75                chunk_a[0], chunk_a[1], chunk_a[2], chunk_a[3], chunk_a[4], chunk_a[5], chunk_a[6],
76                chunk_a[7],
77            ]);
78            let vb = f32x8::new([
79                chunk_b[0], chunk_b[1], chunk_b[2], chunk_b[3], chunk_b[4], chunk_b[5], chunk_b[6],
80                chunk_b[7],
81            ]);
82            let diff = va - vb;
83            sum_sq_simd += diff * diff;
84        }
85
86        let mut sum_sq = sum_sq_simd.reduce_add();
87        for (x, y) in rem_a.iter().zip(rem_b.iter()) {
88            let diff = x - y;
89            sum_sq += diff * diff;
90        }
91        sum_sq.sqrt()
92    }
93    fn higher_is_better() -> bool {
94        false
95    }
96}
97
98pub struct DotMetric;
99impl SimilarityMetric for DotMetric {
100    fn distance(a: &[f32], b: &[f32]) -> f32 {
101        let mut dot_simd = f32x8::ZERO;
102        let chunks_a = a.chunks_exact(8);
103        let chunks_b = b.chunks_exact(8);
104        let rem_a = chunks_a.remainder();
105        let rem_b = chunks_b.remainder();
106
107        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
108            let va = f32x8::new([
109                chunk_a[0], chunk_a[1], chunk_a[2], chunk_a[3], chunk_a[4], chunk_a[5], chunk_a[6],
110                chunk_a[7],
111            ]);
112            let vb = f32x8::new([
113                chunk_b[0], chunk_b[1], chunk_b[2], chunk_b[3], chunk_b[4], chunk_b[5], chunk_b[6],
114                chunk_b[7],
115            ]);
116            dot_simd += va * vb;
117        }
118        let mut dot = dot_simd.reduce_add();
119        for (x, y) in rem_a.iter().zip(rem_b.iter()) {
120            dot += x * y;
121        }
122        dot
123    }
124    fn higher_is_better() -> bool {
125        true
126    }
127}
128
129pub struct ManhattanMetric;
130impl SimilarityMetric for ManhattanMetric {
131    fn distance(a: &[f32], b: &[f32]) -> f32 {
132        let mut dist = 0.0;
133        for (x, y) in a.iter().zip(b.iter()) {
134            dist += (x - y).abs();
135        }
136        dist
137    }
138    fn higher_is_better() -> bool {
139        false
140    }
141}
142
143pub mod hnsw;