Skip to main content

nodedb_vector/distance/
dispatch.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Dtype-aware distance dispatch shim.
//!
//! Routes byte-encoded vector pairs to the correct distance kernel based on
//! dtype and metric. F16 / BF16 inputs on the three hot metrics (L2, Cosine,
//! InnerProduct) use fused decode-and-compute kernels from `typed_scalar`,
//! eliminating the intermediate `Vec<f32>` allocation. Rare metrics (Manhattan,
//! Chebyshev, Hamming, Jaccard, Pearson) fall through to `cast_to_f32` +
//! `distance()` — correctness is identical and they are not on the hot path.
//! The F32 path is always a straight cast + delegate (no conversion needed).
12
13use nodedb_types::vector_distance::DistanceMetric;
14use nodedb_types::vector_dtype::VectorStorageDtype;
15
16use crate::dtype::{DtypeError, cast_to_f32, validate_byte_len};
17
18/// Error type for [`distance_typed`].
19#[derive(thiserror::Error, Debug)]
20pub enum DistanceError {
21    /// The two input buffers encode different numbers of dimensions.
22    #[error("distance: dim mismatch (a: {a_dim}, b: {b_dim})")]
23    DimMismatch { a_dim: usize, b_dim: usize },
24
25    /// A byte buffer has the wrong length for the given dtype and dim.
26    #[error("distance: dtype byte-length error: {0}")]
27    Dtype(#[from] DtypeError),
28}
29
30/// Compute distance between two byte-encoded vectors of the given dtype.
31///
32/// Both buffers must encode `dim` dimensions in `dtype`.
33///
34/// - **F16 / BF16 + L2 / Cosine / InnerProduct**: fused decode-and-compute via
35///   [`typed_scalar`] — no intermediate `Vec<f32>` allocation.
36/// - **F16 / BF16 + other metrics**: up-converts to F32 via [`cast_to_f32`]
37///   then delegates to [`crate::distance::distance`]. These metrics are not
38///   used in embedding search hot paths; the allocation is acceptable.
39/// - **F32**: up-converts via [`cast_to_f32`] (memcopy) then delegates.
40///
41/// # Errors
42///
43/// - [`DistanceError::Dtype`] — if either buffer's length does not match
44///   `dtype.bytes_for_dim(dim)`.
45/// - [`DistanceError::DimMismatch`] — reserved for asymmetric validation;
46///   currently both sides are checked against the same `dim`, so this variant
47///   is unreachable in normal use but is preserved for future asymmetric checks.
48pub fn distance_typed(
49    metric: DistanceMetric,
50    dtype: VectorStorageDtype,
51    a_bytes: &[u8],
52    b_bytes: &[u8],
53    dim: usize,
54) -> Result<f32, DistanceError> {
55    validate_byte_len(a_bytes, dtype, dim)?;
56    validate_byte_len(b_bytes, dtype, dim)?;
57
58    match dtype {
59        VectorStorageDtype::F32 => {
60            let a_f32 = cast_to_f32(a_bytes, dtype, dim)?;
61            let b_f32 = cast_to_f32(b_bytes, dtype, dim)?;
62            Ok(crate::distance::distance(&a_f32, &b_f32, metric))
63        }
64        VectorStorageDtype::F16 => Ok(match metric {
65            DistanceMetric::L2 => {
66                (crate::distance::simd::runtime().l2_squared_f16)(a_bytes, b_bytes, dim)
67            }
68            DistanceMetric::Cosine => {
69                (crate::distance::simd::runtime().cosine_distance_f16)(a_bytes, b_bytes, dim)
70            }
71            DistanceMetric::InnerProduct => {
72                (crate::distance::simd::runtime().neg_inner_product_f16)(a_bytes, b_bytes, dim)
73            }
74            // Non-hot metrics (Manhattan, Chebyshev, Hamming, Jaccard, Pearson)
75            // are not used in embedding search; cast+delegate is correct and the
76            // allocation overhead is acceptable on this path.
77            _ => {
78                let a_f32 = cast_to_f32(a_bytes, dtype, dim)?;
79                let b_f32 = cast_to_f32(b_bytes, dtype, dim)?;
80                crate::distance::distance(&a_f32, &b_f32, metric)
81            }
82        }),
83        VectorStorageDtype::BF16 => Ok(match metric {
84            DistanceMetric::L2 => {
85                (crate::distance::simd::runtime().l2_squared_bf16)(a_bytes, b_bytes, dim)
86            }
87            DistanceMetric::Cosine => {
88                (crate::distance::simd::runtime().cosine_distance_bf16)(a_bytes, b_bytes, dim)
89            }
90            DistanceMetric::InnerProduct => {
91                (crate::distance::simd::runtime().neg_inner_product_bf16)(a_bytes, b_bytes, dim)
92            }
93            // Non-hot metrics fall through to cast+delegate (same rationale as F16 arm).
94            _ => {
95                let a_f32 = cast_to_f32(a_bytes, dtype, dim)?;
96                let b_f32 = cast_to_f32(b_bytes, dtype, dim)?;
97                crate::distance::distance(&a_f32, &b_f32, metric)
98            }
99        }),
100        // `VectorStorageDtype` is #[non_exhaustive]; any future variant falls
101        // back to the cast path, which will return DtypeError if it cannot handle
102        // the variant. This arm is not reachable with any currently-defined dtype.
103        _ => {
104            let a_f32 = cast_to_f32(a_bytes, dtype, dim)?;
105            let b_f32 = cast_to_f32(b_bytes, dtype, dim)?;
106            Ok(crate::distance::distance(&a_f32, &b_f32, metric))
107        }
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::cast_from_f32;

    // Absolute tolerances used when comparing against the f32 reference.
    const EPS_F32: f32 = 1e-6;
    const EPS_F16: f32 = 1e-2;
    const EPS_BF16: f32 = 1e-1;

    // Fixed vector pair shared by every test.
    const A: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    const B: [f32; 4] = [4.0, 3.0, 2.0, 1.0];

    // The three metrics that have dedicated F16 / BF16 kernels.
    const HOT_METRICS: [DistanceMetric; 3] = [
        DistanceMetric::L2,
        DistanceMetric::Cosine,
        DistanceMetric::InnerProduct,
    ];

    /// Reference value: `distance()` applied directly to the f32 constants.
    fn f32_ref(metric: DistanceMetric) -> f32 {
        crate::distance::distance(&A, &B, metric)
    }

    /// Encodes `A` / `B` as `dtype` once, then checks every hot metric against
    /// the f32 reference with absolute tolerance `eps`. `label` is the dtype
    /// name used in failure messages.
    fn assert_hot_metrics_close(dtype: VectorStorageDtype, label: &str, eps: f32) {
        let lhs = cast_from_f32(&A, dtype);
        let rhs = cast_from_f32(&B, dtype);
        for metric in HOT_METRICS {
            let got = distance_typed(metric, dtype, &lhs, &rhs, 4)
                .unwrap_or_else(|e| panic!("{label} typed distance must not fail: {e:?}"));
            let want = f32_ref(metric);
            let diff = (got - want).abs();
            assert!(
                diff < eps,
                "{label} typed distance for {metric:?}: got {got}, ref {want}, diff {diff}"
            );
        }
    }

    // ── F32 path: exact match with direct distance() ─────────────────────────

    #[test]
    fn f32_path_matches_direct_distance() {
        let dtype = VectorStorageDtype::F32;
        let lhs = cast_from_f32(&A, dtype);
        let rhs = cast_from_f32(&B, dtype);
        for metric in HOT_METRICS {
            let via_typed = distance_typed(metric, dtype, &lhs, &rhs, 4)
                .expect("F32 typed distance must not fail");
            assert_eq!(
                via_typed,
                f32_ref(metric),
                "F32 typed vs direct mismatch for {metric:?}"
            );
        }
    }

    // ── F16 / BF16: result stays within the dtype's tolerance ────────────────

    #[test]
    fn f16_round_trip_within_tolerance() {
        assert_hot_metrics_close(VectorStorageDtype::F16, "F16", EPS_F16);
    }

    #[test]
    fn bf16_round_trip_within_tolerance() {
        assert_hot_metrics_close(VectorStorageDtype::BF16, "BF16", EPS_BF16);
    }

    // ── Buffer whose length disagrees with `dim` is rejected ─────────────────

    #[test]
    fn dim_mismatch_returns_dtype_error() {
        // With dim = 2 and F32, 8 bytes are expected: `two_dims` passes
        // validation, `four_dims` (16 bytes) must be reported as BadByteLen.
        let two_dims = [0u8; 8];
        let four_dims = [0u8; 16];
        let outcome = distance_typed(
            DistanceMetric::L2,
            VectorStorageDtype::F32,
            &two_dims,
            &four_dims,
            2,
        );
        match outcome.expect_err("mismatched buffer must return an error") {
            DistanceError::Dtype(DtypeError::BadByteLen {
                dtype,
                dim,
                expected,
                actual,
            }) => {
                assert_eq!(dtype, VectorStorageDtype::F32);
                assert_eq!(dim, 2);
                assert_eq!(expected, 8);
                assert_eq!(actual, 16);
            }
            other => panic!("expected DistanceError::Dtype(BadByteLen), got {other:?}"),
        }
    }

    // ── Every hot metric × every dtype yields a finite value ─────────────────

    #[test]
    fn all_metrics_all_dtypes_finite_non_nan() {
        let dtypes = [
            VectorStorageDtype::F32,
            VectorStorageDtype::F16,
            VectorStorageDtype::BF16,
        ];
        for metric in HOT_METRICS {
            for dtype in dtypes {
                let lhs = cast_from_f32(&A, dtype);
                let rhs = cast_from_f32(&B, dtype);
                match distance_typed(metric, dtype, &lhs, &rhs, 4) {
                    // `is_finite()` is already false for NaN, so one check
                    // covers both the infinite and the NaN case.
                    Ok(result) => assert!(
                        result.is_finite(),
                        "distance_typed({metric:?}, {dtype:?}) returned non-finite/NaN: {result}"
                    ),
                    Err(e) => panic!("distance_typed({metric:?}, {dtype:?}) failed: {e}"),
                }
            }
        }
    }

    // ── F32 L2: finite and within EPS_F32 of the reference ───────────────────

    #[test]
    fn f32_result_finite() {
        let dtype = VectorStorageDtype::F32;
        let lhs = cast_from_f32(&A, dtype);
        let rhs = cast_from_f32(&B, dtype);
        let result = distance_typed(DistanceMetric::L2, dtype, &lhs, &rhs, 4)
            .expect("F32 distance must succeed");
        assert!(
            result.is_finite(),
            "F32 L2 result must be finite, got {result}"
        );
        assert!((result - f32_ref(DistanceMetric::L2)).abs() < EPS_F32);
    }
}