1use nodedb_types::vector_distance::DistanceMetric;
14use nodedb_types::vector_dtype::VectorStorageDtype;
15
16use crate::dtype::{DtypeError, cast_to_f32, validate_byte_len};
17
18#[derive(thiserror::Error, Debug)]
20pub enum DistanceError {
21 #[error("distance: dim mismatch (a: {a_dim}, b: {b_dim})")]
23 DimMismatch { a_dim: usize, b_dim: usize },
24
25 #[error("distance: dtype byte-length error: {0}")]
27 Dtype(#[from] DtypeError),
28}
29
30pub fn distance_typed(
49 metric: DistanceMetric,
50 dtype: VectorStorageDtype,
51 a_bytes: &[u8],
52 b_bytes: &[u8],
53 dim: usize,
54) -> Result<f32, DistanceError> {
55 validate_byte_len(a_bytes, dtype, dim)?;
56 validate_byte_len(b_bytes, dtype, dim)?;
57
58 match dtype {
59 VectorStorageDtype::F32 => {
60 let a_f32 = cast_to_f32(a_bytes, dtype, dim)?;
61 let b_f32 = cast_to_f32(b_bytes, dtype, dim)?;
62 Ok(crate::distance::distance(&a_f32, &b_f32, metric))
63 }
64 VectorStorageDtype::F16 => Ok(match metric {
65 DistanceMetric::L2 => {
66 (crate::distance::simd::runtime().l2_squared_f16)(a_bytes, b_bytes, dim)
67 }
68 DistanceMetric::Cosine => {
69 (crate::distance::simd::runtime().cosine_distance_f16)(a_bytes, b_bytes, dim)
70 }
71 DistanceMetric::InnerProduct => {
72 (crate::distance::simd::runtime().neg_inner_product_f16)(a_bytes, b_bytes, dim)
73 }
74 _ => {
78 let a_f32 = cast_to_f32(a_bytes, dtype, dim)?;
79 let b_f32 = cast_to_f32(b_bytes, dtype, dim)?;
80 crate::distance::distance(&a_f32, &b_f32, metric)
81 }
82 }),
83 VectorStorageDtype::BF16 => Ok(match metric {
84 DistanceMetric::L2 => {
85 (crate::distance::simd::runtime().l2_squared_bf16)(a_bytes, b_bytes, dim)
86 }
87 DistanceMetric::Cosine => {
88 (crate::distance::simd::runtime().cosine_distance_bf16)(a_bytes, b_bytes, dim)
89 }
90 DistanceMetric::InnerProduct => {
91 (crate::distance::simd::runtime().neg_inner_product_bf16)(a_bytes, b_bytes, dim)
92 }
93 _ => {
95 let a_f32 = cast_to_f32(a_bytes, dtype, dim)?;
96 let b_f32 = cast_to_f32(b_bytes, dtype, dim)?;
97 crate::distance::distance(&a_f32, &b_f32, metric)
98 }
99 }),
100 _ => {
104 let a_f32 = cast_to_f32(a_bytes, dtype, dim)?;
105 let b_f32 = cast_to_f32(b_bytes, dtype, dim)?;
106 Ok(crate::distance::distance(&a_f32, &b_f32, metric))
107 }
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::cast_from_f32;

    // Absolute tolerances used when comparing against the f32 reference.
    const EPS_F32: f32 = 1e-6;
    const EPS_F16: f32 = 1e-2;
    const EPS_BF16: f32 = 1e-1;

    // Fixture vectors shared by every test.
    const A: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    const B: [f32; 4] = [4.0, 3.0, 2.0, 1.0];

    /// Number of components in the fixture vectors.
    const DIM: usize = 4;

    /// Metrics exercised by the parameterised tests.
    const METRICS: [DistanceMetric; 3] = [
        DistanceMetric::L2,
        DistanceMetric::Cosine,
        DistanceMetric::InnerProduct,
    ];

    /// Reference value: the plain f32 distance between `A` and `B`.
    fn f32_ref(metric: DistanceMetric) -> f32 {
        crate::distance::distance(&A, &B, metric)
    }

    /// The F32 typed path must agree exactly with calling `distance` directly.
    #[test]
    fn f32_path_matches_direct_distance() {
        let lhs = cast_from_f32(&A, VectorStorageDtype::F32);
        let rhs = cast_from_f32(&B, VectorStorageDtype::F32);

        for metric in METRICS {
            let via_typed = distance_typed(metric, VectorStorageDtype::F32, &lhs, &rhs, DIM)
                .expect("F32 typed distance must not fail");
            let via_direct = f32_ref(metric);
            assert_eq!(
                via_typed, via_direct,
                "F32 typed vs direct mismatch for {metric:?}"
            );
        }
    }

    /// F16-encoded inputs stay within `EPS_F16` of the f32 reference.
    #[test]
    fn f16_round_trip_within_tolerance() {
        let lhs = cast_from_f32(&A, VectorStorageDtype::F16);
        let rhs = cast_from_f32(&B, VectorStorageDtype::F16);

        for metric in METRICS {
            let via_typed = distance_typed(metric, VectorStorageDtype::F16, &lhs, &rhs, DIM)
                .expect("F16 typed distance must not fail");
            let reference = f32_ref(metric);
            let diff = (via_typed - reference).abs();
            assert!(
                diff < EPS_F16,
                "F16 typed distance for {metric:?}: got {via_typed}, ref {reference}, diff {}",
                diff
            );
        }
    }

    /// BF16-encoded inputs stay within `EPS_BF16` of the f32 reference.
    #[test]
    fn bf16_round_trip_within_tolerance() {
        let lhs = cast_from_f32(&A, VectorStorageDtype::BF16);
        let rhs = cast_from_f32(&B, VectorStorageDtype::BF16);

        for metric in METRICS {
            let via_typed = distance_typed(metric, VectorStorageDtype::BF16, &lhs, &rhs, DIM)
                .expect("BF16 typed distance must not fail");
            let reference = f32_ref(metric);
            let diff = (via_typed - reference).abs();
            assert!(
                diff < EPS_BF16,
                "BF16 typed distance for {metric:?}: got {via_typed}, ref {reference}, diff {}",
                diff
            );
        }
    }

    /// A wrongly sized second buffer is reported as a dtype byte-length error.
    #[test]
    fn dim_mismatch_returns_dtype_error() {
        // The first buffer is 8 bytes, which this test expects to be valid for
        // F32 with dim 2; the second is 16 bytes and must be rejected.
        let well_sized = [0u8; 8];
        let oversized = [0u8; 16];

        let outcome = distance_typed(
            DistanceMetric::L2,
            VectorStorageDtype::F32,
            &well_sized,
            &oversized,
            2,
        );
        let err = outcome.expect_err("mismatched buffer must return an error");

        match err {
            DistanceError::Dtype(DtypeError::BadByteLen {
                dtype,
                dim,
                expected,
                actual,
            }) => {
                assert_eq!(dtype, VectorStorageDtype::F32);
                assert_eq!(dim, 2);
                assert_eq!(expected, 8);
                assert_eq!(actual, 16);
            }
            other => panic!("expected DistanceError::Dtype(BadByteLen), got {other:?}"),
        }
    }

    /// Every metric/dtype pairing must yield a finite number for the fixtures.
    #[test]
    fn all_metrics_all_dtypes_finite_non_nan() {
        let dtypes = [
            VectorStorageDtype::F32,
            VectorStorageDtype::F16,
            VectorStorageDtype::BF16,
        ];

        for metric in METRICS {
            for dtype in dtypes {
                let lhs = cast_from_f32(&A, dtype);
                let rhs = cast_from_f32(&B, dtype);
                let result = match distance_typed(metric, dtype, &lhs, &rhs, DIM) {
                    Ok(value) => value,
                    Err(e) => panic!("distance_typed({metric:?}, {dtype:?}) failed: {e}"),
                };
                // `is_finite` is already false for NaN, so one check covers both.
                assert!(
                    result.is_finite(),
                    "distance_typed({metric:?}, {dtype:?}) returned non-finite/NaN: {result}"
                );
            }
        }
    }

    /// The F32 L2 result is finite and within `EPS_F32` of the reference.
    #[test]
    fn f32_result_finite() {
        let lhs = cast_from_f32(&A, VectorStorageDtype::F32);
        let rhs = cast_from_f32(&B, VectorStorageDtype::F32);

        let result = distance_typed(
            DistanceMetric::L2,
            VectorStorageDtype::F32,
            &lhs,
            &rhs,
            DIM,
        )
        .expect("F32 distance must succeed");

        assert!(
            result.is_finite(),
            "F32 L2 result must be finite, got {result}"
        );
        assert!((result - f32_ref(DistanceMetric::L2)).abs() < EPS_F32);
    }
}