Skip to main content

lattice_embed/simd/
quantized.rs

1//! INT8 quantization for efficient embedding storage and similarity computation.
2//!
3//! Quantized vectors provide ~3x speedup and 4x memory reduction with 99%+ accuracy.
4
5#[cfg(target_arch = "x86_64")]
6use std::arch::x86_64::*;
7
8#[cfg(target_arch = "aarch64")]
9use std::arch::aarch64::*;
10
11use std::sync::OnceLock;
12
13use super::simd_config;
14
/// **Unstable**: parameters of the INT8 quantization scheme; the scale/bias
/// layout may change.
///
/// Records how a float vector was mapped onto `i8` codes.
#[derive(Debug, Clone, Copy)]
pub struct QuantizationParams {
    /// **Unstable**: multiplier applied to a float before rounding to `i8`
    /// (`code = round(value * scale)`); the formula may change with a scheme update.
    pub scale: f32,
    /// **Unstable**: zero-point offset; may be removed for symmetric-only quantization.
    pub zero_point: i8,
    /// **Unstable**: smallest finite input value (0.0 when there is none); may be removed.
    pub min_val: f32,
    /// **Unstable**: largest finite input value (0.0 when there is none); may be removed.
    pub max_val: f32,
}
29
30impl QuantizationParams {
31    /// **Unstable**: parameter computation; may be folded into `QuantizedVector::from_f32`.
32    ///
33    /// Handles edge cases: empty vectors, NaN, Inf, near-zero vectors.
34    pub fn from_vector(vector: &[f32]) -> Self {
35        // Single pass over finite values to handle NaN/Inf gracefully.
36        let mut min_val = f32::INFINITY;
37        let mut max_val = f32::NEG_INFINITY;
38
39        for &v in vector {
40            if v.is_finite() {
41                min_val = min_val.min(v);
42                max_val = max_val.max(v);
43            }
44        }
45
46        // Handle edge case: empty or all non-finite.
47        if !min_val.is_finite() || !max_val.is_finite() {
48            min_val = 0.0;
49            max_val = 0.0;
50        }
51
52        // Symmetric quantization: map [-max_abs, max_abs] to [-127, 127]
53        let max_abs = min_val.abs().max(max_val.abs());
54
55        // Epsilon guard to avoid division by near-zero
56        let scale = if max_abs > 1e-10 {
57            127.0 / max_abs
58        } else {
59            1.0 // All zeros or near-zero case
60        };
61
62        Self {
63            scale,
64            zero_point: 0,
65            min_val,
66            max_val,
67        }
68    }
69}
70
71/// **Unstable**: INT8 quantized vector; struct layout and invariants may change.
72///
73/// Quantized int8 vector with its parameters.
74#[derive(Debug, Clone)]
75pub struct QuantizedVector {
76    /// **Unstable**: raw quantized data; invariant (`[-127, 127]`) enforced by constructor.
77    ///
78    /// # Invariant
79    /// All values must be in the range `[-127, 127]`. The value `-128` causes
80    /// incorrect results in AVX-512 VNNI and AVX2 SIMD paths due to `vpabsb`
81    /// saturation behavior. The `from_f32` constructor enforces this via clamping.
82    pub data: Vec<i8>,
83    /// **Unstable**: quantization parameters; may be separated from the vector.
84    pub params: QuantizationParams,
85    /// **Unstable**: L2 norm; may be removed or moved.
86    pub norm: f32,
87}
88
89impl QuantizedVector {
90    /// **Unstable**: quantization constructor; clamping behavior may change.
91    pub fn from_f32(vector: &[f32]) -> Self {
92        let mut params = QuantizationParams::from_vector(vector);
93
94        // Defensive guard: avoid NaN/Inf/zero scale.
95        if !params.scale.is_finite() || params.scale == 0.0 {
96            params.scale = 1.0;
97        }
98
99        // Compute L2 norm of finite values (NaN/Inf are treated as 0.0).
100        let mut norm_sq = 0.0f32;
101        for &v in vector {
102            if v.is_finite() {
103                norm_sq += v * v;
104            }
105        }
106        let norm = norm_sq.sqrt();
107
108        let data: Vec<i8> = vector
109            .iter()
110            .map(|&v| {
111                if !v.is_finite() {
112                    0
113                } else {
114                    (v * params.scale).round().clamp(-127.0, 127.0) as i8
115                }
116            })
117            .collect();
118
119        Self { data, params, norm }
120    }
121
122    /// **Unstable**: dequantization; output precision may change with scheme update.
123    ///
124    /// # Precision
125    ///
126    /// INT8 symmetric quantization maps `[-max_abs, max_abs]` to `[-127, 127]`,
127    /// so the quantization step size is `max_abs / 127`. The maximum per-element
128    /// round-trip error is bounded by half a quantization step: `max_abs / 254`.
129    ///
130    /// For a 384-dim unit-norm embedding (`max_abs` ≈ 1.0), expect element-wise
131    /// absolute error ≤ 0.004 and cosine-similarity error ≤ 0.5%.
132    pub fn to_f32(&self) -> Vec<f32> {
133        let scale = if self.params.scale.is_finite() && self.params.scale != 0.0 {
134            self.params.scale
135        } else {
136            1.0
137        };
138
139        self.data.iter().map(|&v| v as f32 / scale).collect()
140    }
141
142    /// **Unstable**: delegates to `dot_product_i8`; SIMD dispatch may change.
143    #[inline]
144    pub fn dot_product(&self, other: &QuantizedVector) -> f32 {
145        dot_product_i8(self, other)
146    }
147
148    /// **Unstable**: delegates to `cosine_similarity_i8`; SIMD dispatch may change.
149    #[inline]
150    pub fn cosine_similarity(&self, other: &QuantizedVector) -> f32 {
151        cosine_similarity_i8(self, other)
152    }
153}
154
155/// **Unstable**: SIMD INT8 dot product; VNNI/AVX2/NEON dispatch may change.
156///
157/// Returns the approximate float dot product.
158/// Returns 0.0 if vectors have different lengths.
159///
160/// # Feature gate asymmetry
161///
162/// The float32 AVX-512F path (dot_product, cosine, normalize, distance) activates
163/// unconditionally via runtime `is_x86_feature_detected!("avx512f")` -- no Cargo
164/// feature gate is needed because `_mm512_loadu_ps` / `_mm512_fmadd_ps` etc. are
165/// part of the base AVX-512F ISA and Rust's `#[target_feature(enable = "avx512f")]`
166/// annotation is sufficient.
167///
168/// The integer VNNI path below requires `--features avx512` at compile time because
169/// it uses `_mm512_dpbusd_epi32` (AVX-512 VNNI) and `_mm512_cmplt_epi8_mask`
170/// (AVX-512BW), which are behind nightly-gated intrinsics that need an explicit
171/// Cargo feature to opt in to the extended instruction sets at compile time.
172#[inline]
173pub fn dot_product_i8(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
174    // FP-033: enforce at call time (not just debug) — -128 causes incorrect results
175    // in AVX-512 VNNI via vpabsb saturation; from_f32 clamps to [-127, 127] but
176    // the data field is pub so callers can bypass the constructor.
177    assert!(
178        a.data.iter().all(|&v| v != -128i8),
179        "QuantizedVector a contains -128, which violates the [-127, 127] VNNI invariant"
180    );
181    assert!(
182        b.data.iter().all(|&v| v != -128i8),
183        "QuantizedVector b contains -128, which violates the [-127, 127] VNNI invariant"
184    );
185
186    // Runtime length check to prevent UB in release builds
187    if a.data.len() != b.data.len() {
188        return 0.0;
189    }
190    debug_assert_eq!(a.data.len(), b.data.len());
191
192    let denom = a.params.scale * b.params.scale;
193    if denom == 0.0 || !denom.is_finite() {
194        return 0.0;
195    }
196
197    dot_product_i8_raw(&a.data, &b.data) / denom
198}
199
200/// Trusted INT8 dot product for constructor-owned vectors in prepared-query paths.
201///
202/// Uses `debug_assert!` instead of `assert!`; callers must guarantee vectors
203/// were produced by `QuantizedVector::from_f32` or equivalent (clamped to [-127,127]).
204#[inline]
205pub(crate) fn dot_product_i8_trusted(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
206    if a.data.len() != b.data.len() {
207        return 0.0;
208    }
209    let denom = a.params.scale * b.params.scale;
210    if denom == 0.0 || !denom.is_finite() {
211        return 0.0;
212    }
213    debug_assert!(a.data.iter().all(|&v| v != i8::MIN));
214    debug_assert!(b.data.iter().all(|&v| v != i8::MIN));
215    dot_product_i8_raw(&a.data, &b.data) / denom
216}
217
218/// **Unstable**: SIMD INT8 cosine similarity; norm storage approach may change.
219///
220/// Uses pre-computed norms for efficiency.
221#[inline]
222pub fn cosine_similarity_i8(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
223    let denom = a.norm * b.norm;
224    if denom == 0.0 || !denom.is_finite() {
225        return 0.0;
226    }
227    dot_product_i8(a, b) / denom
228}
229
230/// Trusted INT8 cosine similarity for constructor-owned vectors in prepared-query paths.
231///
232/// Uses `dot_product_i8_trusted` instead of `dot_product_i8` to skip release-mode
233/// O(N) invariant scans. Callers must guarantee vectors were produced by
234/// `QuantizedVector::from_f32` or equivalent (clamped to [-127, 127]).
235#[inline]
236pub(crate) fn cosine_similarity_i8_trusted(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
237    let denom = a.norm * b.norm;
238    if denom == 0.0 || !denom.is_finite() {
239        return 0.0;
240    }
241    dot_product_i8_trusted(a, b) / denom
242}
243
/// NEON int8 dot product built from `vmull_s8` + `vpadalq_s16`, unrolled 4x.
///
/// Work is split into three phases:
/// 1. blocks of 64 `i8`s, spread over 4 independent `i32x4` accumulators;
/// 2. remaining whole 16-byte vectors, folded into the combined accumulator;
/// 3. a scalar loop over the final `< 16` elements.
///
/// Returns the exact integer dot product converted to `f32`.
///
/// # Safety
///
/// Caller must ensure:
/// - Running on aarch64 (NEON is mandatory, always available)
/// - `a` and `b` have equal length (only checked with `debug_assert_eq!` here)
///
/// Memory safety:
/// - Loads use `vld1q_s8`, which has no alignment requirement
/// - Every load offset is derived from `a.len()`, so it stays inside `a`, and
///   inside `b` as long as the lengths are equal
/// - The scalar remainder uses bounds-checked slice iteration
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_i8_neon_unrolled(a: &[i8], b: &[i8]) -> f32 {
    const LANES: usize = 16;
    const UNROLL: usize = 4;
    const BLOCK: usize = LANES * UNROLL;

    /// Loads 16 `i8`s from each pointer and adds their 16 products into `acc`
    /// (widening multiply to `i16`, then pairwise add-accumulate into `i32`).
    #[inline(always)]
    unsafe fn mac16(acc: int32x4_t, pa: *const i8, pb: *const i8) -> int32x4_t {
        let va = vld1q_s8(pa);
        let vb = vld1q_s8(pb);
        let prod_lo = vmull_s8(vget_low_s8(va), vget_low_s8(vb));
        let prod_hi = vmull_s8(vget_high_s8(va), vget_high_s8(vb));
        vpadalq_s16(vpadalq_s16(acc, prod_lo), prod_hi)
    }

    let len = a.len();
    debug_assert_eq!(len, b.len());
    let pa = a.as_ptr();
    let pb = b.as_ptr();

    // Phase 1: 64-element blocks with four independent accumulators.
    let full_blocks = len / BLOCK;
    let mut acc0 = vdupq_n_s32(0);
    let mut acc1 = vdupq_n_s32(0);
    let mut acc2 = vdupq_n_s32(0);
    let mut acc3 = vdupq_n_s32(0);

    for blk in 0..full_blocks {
        let off = blk * BLOCK;
        acc0 = mac16(acc0, pa.add(off), pb.add(off));
        acc1 = mac16(acc1, pa.add(off + LANES), pb.add(off + LANES));
        acc2 = mac16(acc2, pa.add(off + LANES * 2), pb.add(off + LANES * 2));
        acc3 = mac16(acc3, pa.add(off + LANES * 3), pb.add(off + LANES * 3));
    }

    let mut total = vaddq_s32(vaddq_s32(acc0, acc1), vaddq_s32(acc2, acc3));

    // Phase 2: leftover whole 16-byte vectors (at most 3), e.g. dimension 127
    // has 3 of them while dimension 129 has none.
    let simd_tail_start = full_blocks * BLOCK;
    let simd_tail_vectors = (len - simd_tail_start) / LANES;
    for t in 0..simd_tail_vectors {
        let off = simd_tail_start + t * LANES;
        total = mac16(total, pa.add(off), pb.add(off));
    }

    // Horizontal reduction of the four i32 lanes.
    let simd_sum = vgetq_lane_s32(total, 0)
        + vgetq_lane_s32(total, 1)
        + vgetq_lane_s32(total, 2)
        + vgetq_lane_s32(total, 3);

    // Phase 3: scalar loop over the final < 16 elements.
    let scalar_start = simd_tail_start + simd_tail_vectors * LANES;
    let scalar_sum: i32 = a[scalar_start..]
        .iter()
        .zip(b[scalar_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();

    (simd_sum + scalar_sum) as f32
}
365
/// Stand-in for a 512-bit `sign_epi8(b, a)`, which AVX-512 does not provide.
///
/// Per byte lane the result is `b[i]` where `a[i] > 0`, `-b[i]` where
/// `a[i] < 0`, and `0` where `a[i] == 0`. The negation is a wrapping
/// `0 - b[i]`, so a lane holding `-128` in `b` stays `-128`.
///
/// # Safety
/// Requires AVX-512BW.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
#[inline]
unsafe fn mm512_sign_epi8(b: __m512i, a: __m512i) -> __m512i {
    let zeros = _mm512_setzero_si512();

    // Per-lane classification of `a`.
    let a_is_zero = _mm512_cmpeq_epi8_mask(a, zeros);
    let a_is_negative = _mm512_cmplt_epi8_mask(a, zeros);

    // Pick -b where a < 0, otherwise b ...
    let negated_b = _mm512_sub_epi8(zeros, b);
    let signed_b = _mm512_mask_blend_epi8(a_is_negative, b, negated_b);

    // ... then force lanes with a == 0 to zero.
    _mm512_mask_blend_epi8(a_is_zero, signed_b, zeros)
}
387
/// AVX-512 VNNI int8 dot product using `_mm512_dpbusd_epi32`.
///
/// `dpbusd` multiplies *unsigned* bytes by *signed* bytes, so each signed
/// product `a * b` is computed as `|a| * sign(b, a)`. That identity breaks
/// when a lane holds `-128`, hence the `[-127, 127]` input invariant
/// (checked with `debug_assert!` only).
///
/// Work is split into three phases:
/// 1. blocks of 256 `i8`s (4 x 64) spread over 4 independent accumulators;
/// 2. remaining whole 64-byte vectors (at most 3), folded into the combined
///    accumulator, so dimensions such as 128 or 384 do not fall back to the
///    scalar loop for most of their length;
/// 3. a scalar loop over the final `< 64` elements.
///
/// Returns the exact integer dot product converted to `f32`.
///
/// # Safety
///
/// Caller must ensure:
/// - CPU supports AVX-512F, AVX-512VNNI, and AVX-512BW (verified via `simd_config()`)
/// - `a` and `b` have equal length (checked by caller)
///
/// Memory safety:
/// - Uses `_mm512_loadu_si512` for unaligned loads (safe for any alignment)
/// - Every load offset is derived from `a.len()`, so it stays inside `a`, and
///   inside `b` as long as the lengths are equal
/// - Remainder handled via safe slice iteration
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
#[target_feature(enable = "avx512f", enable = "avx512vnni", enable = "avx512bw")]
unsafe fn dot_product_i8_avx512vnni(a: &[i8], b: &[i8]) -> f32 {
    const SIMD_WIDTH: usize = 64; // 64 int8s per 512-bit register
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    let chunks = n / CHUNK_SIZE;

    // 4 independent int32 accumulators (16 int32s each)
    let mut sum0 = _mm512_setzero_si512();
    let mut sum1 = _mm512_setzero_si512();
    let mut sum2 = _mm512_setzero_si512();
    let mut sum3 = _mm512_setzero_si512();

    for i in 0..chunks {
        let base = i * CHUNK_SIZE;

        // VNNI: dpbusd computes sum += a[unsigned] * b[signed]
        // For signed * signed, we use: abs(a) * sign(b, a)
        let a0 = _mm512_loadu_si512(a.as_ptr().add(base) as *const __m512i);
        let b0 = _mm512_loadu_si512(b.as_ptr().add(base) as *const __m512i);
        let a0_abs = _mm512_abs_epi8(a0);
        let b0_signed = mm512_sign_epi8(b0, a0);
        sum0 = _mm512_dpbusd_epi32(sum0, a0_abs, b0_signed);

        let a1 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH) as *const __m512i);
        let b1 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH) as *const __m512i);
        let a1_abs = _mm512_abs_epi8(a1);
        let b1_signed = mm512_sign_epi8(b1, a1);
        sum1 = _mm512_dpbusd_epi32(sum1, a1_abs, b1_signed);

        let a2 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m512i);
        let b2 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m512i);
        let a2_abs = _mm512_abs_epi8(a2);
        let b2_signed = mm512_sign_epi8(b2, a2);
        sum2 = _mm512_dpbusd_epi32(sum2, a2_abs, b2_signed);

        let a3 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m512i);
        let b3 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m512i);
        let a3_abs = _mm512_abs_epi8(a3);
        let b3_signed = mm512_sign_epi8(b3, a3);
        sum3 = _mm512_dpbusd_epi32(sum3, a3_abs, b3_signed);
    }

    // Combine accumulators
    let sum01 = _mm512_add_epi32(sum0, sum1);
    let sum23 = _mm512_add_epi32(sum2, sum3);
    let mut sum_vec = _mm512_add_epi32(sum01, sum23);

    // Tail SIMD chunks: process the remaining whole 64-byte vectors before the
    // scalar tail (mirrors the NEON kernel). `tail_chunks` is at most 3.
    let tail_start = chunks * CHUNK_SIZE;
    let tail_chunks = (n - tail_start) / SIMD_WIDTH;
    for j in 0..tail_chunks {
        let base = tail_start + j * SIMD_WIDTH;
        let at = _mm512_loadu_si512(a.as_ptr().add(base) as *const __m512i);
        let bt = _mm512_loadu_si512(b.as_ptr().add(base) as *const __m512i);
        let at_abs = _mm512_abs_epi8(at);
        let bt_signed = mm512_sign_epi8(bt, at);
        sum_vec = _mm512_dpbusd_epi32(sum_vec, at_abs, bt_signed);
    }

    // Horizontal sum of 16 int32s
    let sum = _mm512_reduce_add_epi32(sum_vec);

    // Scalar tail: only the final < SIMD_WIDTH elements
    let remainder_start = tail_start + tail_chunks * SIMD_WIDTH;
    let remainder: i32 = a[remainder_start..]
        .iter()
        .zip(b[remainder_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();

    (sum + remainder) as f32
}
469
/// AVX2 int8 dot product with 4x unrolling.
///
/// `_mm256_maddubs_epi16` multiplies *unsigned* bytes by *signed* bytes, so
/// each signed product `a * b` is computed as `|a| * sign(b, a)`; the `i16`
/// pair sums are then widened to `i32` with `_mm256_madd_epi16(.., 1)`.
/// That identity breaks when a lane holds `-128`, hence the `[-127, 127]`
/// input invariant (checked with `debug_assert!` only).
///
/// Work is split into three phases:
/// 1. blocks of 128 `i8`s (4 x 32) spread over 4 independent accumulators;
/// 2. remaining whole 32-byte vectors (at most 3), folded into the combined
///    accumulator, so short or non-multiple-of-128 dimensions still use SIMD;
/// 3. a scalar loop over the final `< 32` elements.
///
/// Returns the exact integer dot product converted to `f32`.
///
/// # Safety
///
/// Caller must ensure:
/// - CPU supports AVX2 (verified via `simd_config()`)
/// - `a` and `b` have equal length (checked by caller)
///
/// Memory safety:
/// - Uses `_mm256_loadu_si256` for unaligned loads (safe for any alignment)
/// - Every load offset is derived from `a.len()`, so it stays inside `a`, and
///   inside `b` as long as the lengths are equal
/// - Remainder handled via safe slice iteration
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2_unrolled(a: &[i8], b: &[i8]) -> f32 {
    const SIMD_WIDTH: usize = 32;
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    let chunks = n / CHUNK_SIZE;

    // 4 independent int32 accumulators
    let mut sum0 = _mm256_setzero_si256();
    let mut sum1 = _mm256_setzero_si256();
    let mut sum2 = _mm256_setzero_si256();
    let mut sum3 = _mm256_setzero_si256();

    // Multiplying i16 pairs by 1 and adding them widens the products to i32.
    let ones = _mm256_set1_epi16(1);

    for i in 0..chunks {
        let base = i * CHUNK_SIZE;

        // Unroll 0
        let a0 = _mm256_loadu_si256(a.as_ptr().add(base) as *const __m256i);
        let b0 = _mm256_loadu_si256(b.as_ptr().add(base) as *const __m256i);
        let prod0 = _mm256_maddubs_epi16(_mm256_abs_epi8(a0), _mm256_sign_epi8(b0, a0));
        let prod0_32 = _mm256_madd_epi16(prod0, ones);
        sum0 = _mm256_add_epi32(sum0, prod0_32);

        // Unroll 1
        let a1 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH) as *const __m256i);
        let b1 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH) as *const __m256i);
        let prod1 = _mm256_maddubs_epi16(_mm256_abs_epi8(a1), _mm256_sign_epi8(b1, a1));
        let prod1_32 = _mm256_madd_epi16(prod1, ones);
        sum1 = _mm256_add_epi32(sum1, prod1_32);

        // Unroll 2
        let a2 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m256i);
        let b2 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m256i);
        let prod2 = _mm256_maddubs_epi16(_mm256_abs_epi8(a2), _mm256_sign_epi8(b2, a2));
        let prod2_32 = _mm256_madd_epi16(prod2, ones);
        sum2 = _mm256_add_epi32(sum2, prod2_32);

        // Unroll 3
        let a3 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m256i);
        let b3 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m256i);
        let prod3 = _mm256_maddubs_epi16(_mm256_abs_epi8(a3), _mm256_sign_epi8(b3, a3));
        let prod3_32 = _mm256_madd_epi16(prod3, ones);
        sum3 = _mm256_add_epi32(sum3, prod3_32);
    }

    // Combine accumulators
    let sum01 = _mm256_add_epi32(sum0, sum1);
    let sum23 = _mm256_add_epi32(sum2, sum3);
    let mut sum_vec = _mm256_add_epi32(sum01, sum23);

    // Tail SIMD chunks: process the remaining whole 32-byte vectors before the
    // scalar tail (mirrors the NEON kernel). `tail_chunks` is at most 3.
    let tail_start = chunks * CHUNK_SIZE;
    let tail_chunks = (n - tail_start) / SIMD_WIDTH;
    for j in 0..tail_chunks {
        let base = tail_start + j * SIMD_WIDTH;
        let at = _mm256_loadu_si256(a.as_ptr().add(base) as *const __m256i);
        let bt = _mm256_loadu_si256(b.as_ptr().add(base) as *const __m256i);
        let prod_t = _mm256_maddubs_epi16(_mm256_abs_epi8(at), _mm256_sign_epi8(bt, at));
        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(prod_t, ones));
    }

    // Horizontal sum: 256 -> 128 -> 64 -> 32 bits
    let sum128_lo = _mm256_castsi256_si128(sum_vec);
    let sum128_hi = _mm256_extracti128_si256(sum_vec, 1);
    let sum128 = _mm_add_epi32(sum128_lo, sum128_hi);
    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let sum = _mm_cvtsi128_si32(sum32);

    // Scalar tail: only the final < SIMD_WIDTH elements
    let remainder_start = tail_start + tail_chunks * SIMD_WIDTH;
    let remainder: i32 = a[remainder_start..]
        .iter()
        .zip(b[remainder_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();

    (sum + remainder) as f32
}
557
558// ============================================================================
559// INT8 kernel dispatch cache (mirrors f32 DotKernel pattern in dot_product.rs)
560// ============================================================================
561
/// Signature shared by every INT8 dot-product kernel: two `i8` slices in, the
/// unscaled integer dot product out (as `f32`).
pub type I8DotKernel = fn(&[i8], &[i8]) -> f32;

/// Process-wide cache of the selected kernel; initialised on first use by
/// `resolved_i8_dot_kernel`.
static I8_DOT_KERNEL: OnceLock<I8DotKernel> = OnceLock::new();
566
567/// Return the cached INT8 dot-product kernel.
568///
569/// Callers that invoke INT8 dot product in a tight loop can hoist this call
570/// outside the loop so the OnceLock check runs once, not per-iteration.
571#[inline]
572pub fn resolved_i8_dot_kernel() -> I8DotKernel {
573    *I8_DOT_KERNEL.get_or_init(resolve_i8_dot_kernel)
574}
575
576fn resolve_i8_dot_kernel() -> I8DotKernel {
577    let config = simd_config();
578
579    #[cfg(target_arch = "aarch64")]
580    {
581        if config.neon_enabled {
582            return dot_product_i8_neon_kernel;
583        }
584    }
585
586    #[cfg(target_arch = "x86_64")]
587    {
588        #[cfg(feature = "avx512")]
589        {
590            if config.avx512vnni_enabled {
591                return dot_product_i8_avx512vnni_kernel;
592            }
593        }
594        if config.avx2_enabled {
595            return dot_product_i8_avx2_kernel;
596        }
597    }
598
599    dot_product_i8_scalar_kernel
600}
601
/// Safe entry point for the NEON kernel, stored in the dispatch cache.
///
/// Returns `0.0` when the slices differ in length. The guard is required for
/// soundness: this is a safe `fn` that can be reached through the public
/// `resolved_i8_dot_kernel()`, and the underlying kernel reads both slices
/// through raw pointers using offsets derived from `a.len()` only.
#[cfg(target_arch = "aarch64")]
fn dot_product_i8_neon_kernel(a: &[i8], b: &[i8]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // SAFETY: stored only when NEON was detected at init time (always true on
    // aarch64), and the slices were just checked to have equal length.
    unsafe { dot_product_i8_neon_unrolled(a, b) }
}
607
/// Safe entry point for the AVX-512 VNNI kernel, stored in the dispatch cache.
///
/// Returns `0.0` when the slices differ in length. The guard is required for
/// soundness: this is a safe `fn` that can be reached through the public
/// `resolved_i8_dot_kernel()`, and the underlying kernel reads both slices
/// through raw pointers using offsets derived from `a.len()` only.
///
/// Inputs must not contain `-128`; this is checked in debug builds only.
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
fn dot_product_i8_avx512vnni_kernel(a: &[i8], b: &[i8]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    // SAFETY: stored only when AVX-512F+VNNI+BW were detected at init time,
    // and the slices were just checked to have equal length.
    unsafe { dot_product_i8_avx512vnni(a, b) }
}
615
616#[cfg(target_arch = "x86_64")]
617fn dot_product_i8_avx2_kernel(a: &[i8], b: &[i8]) -> f32 {
618    debug_assert!(a.iter().all(|&v| v != i8::MIN));
619    debug_assert!(b.iter().all(|&v| v != i8::MIN));
620    // SAFETY: stored only when AVX2 was detected at init time.
621    unsafe { dot_product_i8_avx2_unrolled(a, b) }
622}
623
/// Portable fallback kernel: exact integer dot product of two `i8` slices,
/// accumulated in `i32` and returned as `f32`.
///
/// Returns `0.0` when the slices differ in length, matching
/// `dot_product_i8_raw`, instead of silently returning the dot product of the
/// common prefix (which is what a bare `zip` would do).
fn dot_product_i8_scalar_kernel(a: &[i8], b: &[i8]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| i32::from(x) * i32::from(y))
        .sum::<i32>() as f32
}
630
631/// **Unstable**: raw SIMD INT8 hot path; signature and scaling semantics may change.
632///
633/// This is the hot-path function for HNSW quantized search. Unlike `dot_product_i8`,
634/// it takes raw `&[i8]` slices and does NOT divide by scale factors -- the caller
635/// handles scaling. This avoids allocating `QuantizedVector` wrappers.
636///
637/// Returns 0.0 if slices have different lengths.
638///
639/// # Performance
640///
641/// Uses the same SIMD paths as `dot_product_i8`:
642/// - aarch64: NEON with 4x unrolling + tail SIMD chunks
643/// - x86_64: AVX-512 VNNI > AVX2 > scalar
644///
645/// The key difference is zero allocation overhead: no `Vec<i8>`, no
646/// `QuantizedVector`, no `QuantizationParams`. Just raw slices in, f32 out.
647#[inline]
648pub fn dot_product_i8_raw(a: &[i8], b: &[i8]) -> f32 {
649    if a.len() != b.len() {
650        return 0.0;
651    }
652    debug_assert_eq!(a.len(), b.len());
653    resolved_i8_dot_kernel()(a, b)
654}
655
#[cfg(test)]
mod simd_parity_tests {
    use super::*;

    /// Vector lengths used by the parity tests.
    ///
    /// Besides the common embedding sizes they include lengths that combine an
    /// unrolled block, whole-vector SIMD tail chunks and a scalar tail in one
    /// run (127, 129, 200, 257), so every phase of a kernel is cross-checked.
    /// All lengths keep the exact integer sum below 2^24, so it is exactly
    /// representable as `f32`.
    #[allow(dead_code)] // unused on targets that have no SIMD kernel to compare
    const PARITY_DIMS: [usize; 10] = [7, 16, 64, 127, 128, 129, 200, 257, 384, 768];

    /// Deterministic pseudo-random vector with entries in `[-1, 1]`.
    fn gen_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    /// Reference implementation: exact `i32` dot product, converted to `f32`.
    #[allow(dead_code)] // unused on targets that have no SIMD kernel to compare
    fn scalar_reference(a: &[i8], b: &[i8]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| x as i32 * y as i32)
            .sum::<i32>() as f32
    }

    // FP-034: NEON vs scalar parity for INT8 dot product.
    #[test]
    fn test_i8_neon_scalar_parity() {
        #[cfg(target_arch = "aarch64")]
        {
            for dim in PARITY_DIMS {
                let a_q = QuantizedVector::from_f32(&gen_vec(dim, 200 + dim as u64));
                let b_q = QuantizedVector::from_f32(&gen_vec(dim, 300 + dim as u64));

                // SAFETY: NEON is mandatory on aarch64; slices have equal length from from_f32.
                let neon = unsafe { dot_product_i8_neon_unrolled(&a_q.data, &b_q.data) };
                let scalar = scalar_reference(&a_q.data, &b_q.data);

                // Both sides are exact integer sums, so they must match bit for bit.
                assert_eq!(
                    neon, scalar,
                    "NEON vs scalar i8 dot product dim={dim}: neon={neon} scalar={scalar}"
                );
            }
        }
    }

    // FP-034: AVX2 vs scalar parity for INT8 dot product.
    #[test]
    fn test_i8_avx2_scalar_parity() {
        #[cfg(target_arch = "x86_64")]
        {
            if std::arch::is_x86_feature_detected!("avx2") {
                for dim in PARITY_DIMS {
                    let a_q = QuantizedVector::from_f32(&gen_vec(dim, 400 + dim as u64));
                    let b_q = QuantizedVector::from_f32(&gen_vec(dim, 500 + dim as u64));

                    // SAFETY: AVX2 verified by is_x86_feature_detected! above; slices have equal length.
                    let avx2 = unsafe { dot_product_i8_avx2_unrolled(&a_q.data, &b_q.data) };
                    let scalar = scalar_reference(&a_q.data, &b_q.data);

                    // Both sides are exact integer sums, so they must match bit for bit.
                    assert_eq!(
                        avx2, scalar,
                        "AVX2 vs scalar i8 dot product dim={dim}: avx2={avx2} scalar={scalar}"
                    );
                }
            }
        }
    }
}