// batuta/oracle/rag/quantization/simd.rs
//! SIMD backend selection and dot product implementations
//!
//! Provides hardware-accelerated dot products for int8 embeddings.
//! Supports AVX2, AVX-512, ARM NEON, and scalar fallback.
//!
//! # Safety
//!
//! This module uses `unsafe` exclusively for CPU SIMD intrinsics
//! (AVX2, AVX-512, NEON). All unsafe calls are guarded by runtime
//! feature detection (`is_x86_feature_detected!`) or compile-time
//! target gates (`#[cfg(target_arch)]`). No raw pointer arithmetic,
//! no transmutes, no FFI — only vendor-provided intrinsics.

// Library code - usage from examples and integration tests
/// SIMD backend selection (Jidoka auto-detection)
///
/// Returned by [`SimdBackend::detect`]; each variant names the widest
/// instruction set the chosen dot-product kernel relies on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdBackend {
    /// AVX2: 256-bit vectors, 32 int8 ops/cycle
    Avx2,
    /// AVX-512: 512-bit vectors, 64 int8 ops/cycle
    Avx512,
    /// ARM NEON: 128-bit vectors, 16 int8 ops/cycle
    Neon,
    /// Scalar fallback (Jidoka degradation)
    Scalar,
}

28impl SimdBackend {
29    /// Auto-detect best available SIMD backend (Jidoka)
30    pub fn detect() -> Self {
31        #[cfg(target_arch = "x86_64")]
32        {
33            let has_avx512 =
34                is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw");
35            let has_avx2 = is_x86_feature_detected!("avx2");
36            return Self::from_x86_features(has_avx512, has_avx2);
37        }
38        #[cfg(target_arch = "aarch64")]
39        {
40            // NEON is always available on aarch64
41            return Self::Neon;
42        }
43        #[allow(unreachable_code)]
44        Self::Scalar
45    }
46
47    /// Select backend from x86 feature flags (testable)
48    #[cfg(target_arch = "x86_64")]
49    pub fn from_x86_features(has_avx512: bool, has_avx2: bool) -> Self {
50        if has_avx512 {
51            Self::Avx512
52        } else if has_avx2 {
53            Self::Avx2
54        } else {
55            Self::Scalar
56        }
57    }
58
59    /// Compute dot product of two i8 vectors
60    ///
61    /// Returns i32 to prevent overflow (127^2 x 4096 < i32::MAX)
62    ///
63    /// # Safety rationale
64    ///
65    /// Calls unsafe SIMD intrinsics guarded by `is_x86_feature_detected!`.
66    #[allow(unsafe_code)]
67    pub fn dot_i8(&self, a: &[i8], b: &[i8]) -> i32 {
68        debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");
69
70        match self {
71            #[cfg(target_arch = "x86_64")]
72            Self::Avx2 => {
73                if is_x86_feature_detected!("avx2") {
74                    // Safety: AVX2 feature check above
75                    return unsafe { dot_i8_avx2(a, b) };
76                }
77                dot_i8_scalar(a, b)
78            }
79            #[cfg(target_arch = "x86_64")]
80            Self::Avx512 => {
81                if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512bw") {
82                    // Safety: AVX-512 feature check above
83                    return unsafe { dot_i8_avx512(a, b) };
84                }
85                dot_i8_scalar(a, b)
86            }
87            #[cfg(target_arch = "aarch64")]
88            Self::Neon => {
89                // SAFETY: NEON is always available on aarch64 (ARMv8+).
90                // The dot_i8_neon function uses only NEON intrinsics which are
91                // guaranteed to be supported on all aarch64 targets.
92                unsafe { dot_i8_neon(a, b) }
93            }
94            _ => dot_i8_scalar(a, b),
95        }
96    }
97
98    /// Compute dot product of f32 query with i8 document (rescoring)
99    ///
100    /// Used in stage 2 for 99%+ accuracy retention.
101    pub fn dot_f32_i8(&self, query: &[f32], doc: &[i8], scale: f32) -> f32 {
102        debug_assert_eq!(query.len(), doc.len(), "Vectors must have same length");
103
104        // For rescoring, we use f32 accumulation for precision
105        let mut sum: f32 = 0.0;
106        for (&q, &d) in query.iter().zip(doc.iter()) {
107            sum += q * (d as f32 * scale);
108        }
109        sum
110    }
111}
112
/// Scalar dot product fallback
///
/// Widens each pair to i32 before multiplying so the per-term product
/// cannot overflow; truncates to the shorter slice (via `zip`).
pub fn dot_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let mut acc: i32 = 0;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += i32::from(x) * i32::from(y);
    }
    acc
}

/// Scalar tail computation for SIMD remainder elements
///
/// Computes the i8 dot product over elements from `start` to the end
/// of the shorter slice; used by the AVX2/AVX-512/NEON kernels to
/// finish whatever their vector loops left over.
///
/// Panics if `start` exceeds either slice length (same as slicing).
#[inline]
fn dot_i8_scalar_tail(a: &[i8], b: &[i8], start: usize) -> i32 {
    let (ta, tb) = (&a[start..], &b[start..]);
    let mut acc: i32 = 0;
    let len = ta.len().min(tb.len());
    for idx in 0..len {
        acc += i32::from(ta[idx]) * i32::from(tb[idx]);
    }
    acc
}

127/// AVX2 SIMD dot product (x86_64)
128// SAFETY: caller verifies AVX2 support via is_x86_feature_detected!, slices have equal length
129#[cfg(target_arch = "x86_64")]
130#[target_feature(enable = "avx2")]
131#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
132unsafe fn dot_i8_avx2(a: &[i8], b: &[i8]) -> i32 {
133    use std::arch::x86_64::*;
134
135    let n = a.len();
136    let mut sum = _mm256_setzero_si256();
137
138    // Process 32 elements at a time
139    let mut i = 0;
140    while i + 32 <= n {
141        let va = _mm256_loadu_si256(a[i..].as_ptr().cast::<__m256i>());
142        let vb = _mm256_loadu_si256(b[i..].as_ptr().cast::<__m256i>());
143
144        // Unpack to i16 and multiply
145        let lo_a = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 0));
146        let lo_b = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 0));
147        let hi_a = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
148        let hi_b = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
149
150        // madd: multiply adjacent pairs and sum to i32
151        let prod_lo = _mm256_madd_epi16(lo_a, lo_b);
152        let prod_hi = _mm256_madd_epi16(hi_a, hi_b);
153
154        sum = _mm256_add_epi32(sum, prod_lo);
155        sum = _mm256_add_epi32(sum, prod_hi);
156
157        i += 32;
158    }
159
160    // Horizontal sum
161    let sum128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1));
162    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
163    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
164    let result = _mm_cvtsi128_si32(sum32);
165
166    // Handle remaining elements
167    result + dot_i8_scalar_tail(a, b, i)
168}
169
170/// AVX-512 SIMD dot product (x86_64)
171/// Note: AVX-512 support varies by CPU; falls back to scalar for remaining elements
172// SAFETY: caller verifies AVX-512F+BW support via is_x86_feature_detected!, equal-length slices
173#[cfg(target_arch = "x86_64")]
174#[target_feature(enable = "avx512f", enable = "avx512bw")]
175#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
176unsafe fn dot_i8_avx512(a: &[i8], b: &[i8]) -> i32 {
177    use std::arch::x86_64::*;
178
179    let n = a.len();
180    let mut sum = _mm512_setzero_si512();
181
182    // Process 64 elements at a time
183    let mut i = 0;
184    while i + 64 <= n {
185        let va = _mm512_loadu_si512(a[i..].as_ptr().cast::<__m512i>());
186        let vb = _mm512_loadu_si512(b[i..].as_ptr().cast::<__m512i>());
187
188        // Extract 256-bit halves and process
189        let lo_a = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(va, 0));
190        let lo_b = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(vb, 0));
191        let hi_a = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(va, 1));
192        let hi_b = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(vb, 1));
193
194        let prod_lo = _mm512_madd_epi16(lo_a, lo_b);
195        let prod_hi = _mm512_madd_epi16(hi_a, hi_b);
196
197        sum = _mm512_add_epi32(sum, prod_lo);
198        sum = _mm512_add_epi32(sum, prod_hi);
199
200        i += 64;
201    }
202
203    // Reduce 512-bit to scalar
204    let result = _mm512_reduce_add_epi32(sum);
205
206    // Handle remaining elements with scalar
207    result + dot_i8_scalar_tail(a, b, i)
208}
209
/// ARM NEON SIMD dot product (aarch64)
///
/// Loads 16 i8 lanes per iteration, widens to i16 with `vmovl_s8`,
/// uses widening multiplies (`vmull_s16`) into i32 accumulators,
/// horizontally sums with `vaddvq_s32`, then finishes the remainder
/// with a scalar tail.
// SAFETY: NEON is always available on aarch64 (baseline ARMv8 feature)
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn dot_i8_neon(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::aarch64::*;

    // Bound the vector loop by the shorter slice so unequal lengths can
    // never cause an out-of-bounds load (defense in depth: the public
    // entry point checks lengths, but this fn performs raw loads and
    // must stay sound on its own).
    let n = a.len().min(b.len());
    let mut sum = vdupq_n_s32(0);

    // Process 16 elements at a time
    let mut i = 0;
    while i + 16 <= n {
        let va = vld1q_s8(a[i..].as_ptr());
        let vb = vld1q_s8(b[i..].as_ptr());

        // Multiply-accumulate low and high halves
        let lo_a = vmovl_s8(vget_low_s8(va));
        let lo_b = vmovl_s8(vget_low_s8(vb));
        let hi_a = vmovl_s8(vget_high_s8(va));
        let hi_b = vmovl_s8(vget_high_s8(vb));

        let prod_lo = vmull_s16(vget_low_s16(lo_a), vget_low_s16(lo_b));
        let prod_lo2 = vmull_s16(vget_high_s16(lo_a), vget_high_s16(lo_b));
        let prod_hi = vmull_s16(vget_low_s16(hi_a), vget_low_s16(hi_b));
        let prod_hi2 = vmull_s16(vget_high_s16(hi_a), vget_high_s16(hi_b));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_lo2);
        sum = vaddq_s32(sum, prod_hi);
        sum = vaddq_s32(sum, prod_hi2);

        i += 16;
    }

    // Horizontal sum across the four i32 lanes
    let result = vaddvq_s32(sum);

    // Handle remaining elements
    result + dot_i8_scalar_tail(a, b, i)
}