1use serde::{Deserialize, Serialize};
2use wide::f32x8;
3
4#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
5pub enum Metric {
6 Cosine,
7 DotProduct,
8 Euclidean,
9 Manhattan,
10}
11
/// A pairwise comparison function over two `f32` vectors.
///
/// Despite its name, `distance` returns whatever score the implementation
/// defines. For some metrics it is a true distance, where smaller means closer.
/// For others it is a similarity, where larger means closer. Check
/// `higher_is_better` to learn which direction ranks better matches first.
pub trait SimilarityMetric {
    /// `true` if larger `distance` values mean more similar vectors (a
    /// similarity score); `false` if smaller values do (a true distance).
    fn higher_is_better() -> bool;

    /// Scores `a` against `b`; interpret the result using `higher_is_better`.
    fn distance(a: &[f32], b: &[f32]) -> f32;
}
16
17pub struct CosineMetric;
18impl SimilarityMetric for CosineMetric {
19 fn distance(a: &[f32], b: &[f32]) -> f32 {
20 let mut dot_simd = f32x8::ZERO;
21 let mut norm_a_simd = f32x8::ZERO;
22 let mut norm_b_simd = f32x8::ZERO;
23
24 let chunks_a = a.chunks_exact(8);
25 let chunks_b = b.chunks_exact(8);
26 let rem_a = chunks_a.remainder();
27 let rem_b = chunks_b.remainder();
28
29 for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
30 let va = f32x8::new([
31 chunk_a[0], chunk_a[1], chunk_a[2], chunk_a[3], chunk_a[4], chunk_a[5], chunk_a[6],
32 chunk_a[7],
33 ]);
34 let vb = f32x8::new([
35 chunk_b[0], chunk_b[1], chunk_b[2], chunk_b[3], chunk_b[4], chunk_b[5], chunk_b[6],
36 chunk_b[7],
37 ]);
38 dot_simd += va * vb;
39 norm_a_simd += va * va;
40 norm_b_simd += vb * vb;
41 }
42
43 let mut dot = dot_simd.reduce_add();
44 let mut norm_a = norm_a_simd.reduce_add();
45 let mut norm_b = norm_b_simd.reduce_add();
46
47 for (x, y) in rem_a.iter().zip(rem_b.iter()) {
48 dot += x * y;
49 norm_a += x * x;
50 norm_b += y * y;
51 }
52
53 if norm_a == 0.0 || norm_b == 0.0 {
54 return 0.0;
55 }
56 dot / (norm_a.sqrt() * norm_b.sqrt())
57 }
58 fn higher_is_better() -> bool {
59 true
60 }
61}
62
63pub struct EuclideanMetric;
64impl SimilarityMetric for EuclideanMetric {
65 fn distance(a: &[f32], b: &[f32]) -> f32 {
66 let mut sum_sq_simd = f32x8::ZERO;
67
68 let chunks_a = a.chunks_exact(8);
69 let chunks_b = b.chunks_exact(8);
70 let rem_a = chunks_a.remainder();
71 let rem_b = chunks_b.remainder();
72
73 for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
74 let va = f32x8::new([
75 chunk_a[0], chunk_a[1], chunk_a[2], chunk_a[3], chunk_a[4], chunk_a[5], chunk_a[6],
76 chunk_a[7],
77 ]);
78 let vb = f32x8::new([
79 chunk_b[0], chunk_b[1], chunk_b[2], chunk_b[3], chunk_b[4], chunk_b[5], chunk_b[6],
80 chunk_b[7],
81 ]);
82 let diff = va - vb;
83 sum_sq_simd += diff * diff;
84 }
85
86 let mut sum_sq = sum_sq_simd.reduce_add();
87 for (x, y) in rem_a.iter().zip(rem_b.iter()) {
88 let diff = x - y;
89 sum_sq += diff * diff;
90 }
91 sum_sq.sqrt()
92 }
93 fn higher_is_better() -> bool {
94 false
95 }
96}
97
98pub struct DotMetric;
99impl SimilarityMetric for DotMetric {
100 fn distance(a: &[f32], b: &[f32]) -> f32 {
101 let mut dot_simd = f32x8::ZERO;
102 let chunks_a = a.chunks_exact(8);
103 let chunks_b = b.chunks_exact(8);
104 let rem_a = chunks_a.remainder();
105 let rem_b = chunks_b.remainder();
106
107 for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
108 let va = f32x8::new([
109 chunk_a[0], chunk_a[1], chunk_a[2], chunk_a[3], chunk_a[4], chunk_a[5], chunk_a[6],
110 chunk_a[7],
111 ]);
112 let vb = f32x8::new([
113 chunk_b[0], chunk_b[1], chunk_b[2], chunk_b[3], chunk_b[4], chunk_b[5], chunk_b[6],
114 chunk_b[7],
115 ]);
116 dot_simd += va * vb;
117 }
118 let mut dot = dot_simd.reduce_add();
119 for (x, y) in rem_a.iter().zip(rem_b.iter()) {
120 dot += x * y;
121 }
122 dot
123 }
124 fn higher_is_better() -> bool {
125 true
126 }
127}
128
129pub struct ManhattanMetric;
130impl SimilarityMetric for ManhattanMetric {
131 fn distance(a: &[f32], b: &[f32]) -> f32 {
132 let mut dist = 0.0;
133 for (x, y) in a.iter().zip(b.iter()) {
134 dist += (x - y).abs();
135 }
136 dist
137 }
138 fn higher_is_better() -> bool {
139 false
140 }
141}
142
143pub mod hnsw;