1#[cfg(target_arch = "x86_64")]
6use std::arch::x86_64::*;
7
8#[cfg(target_arch = "aarch64")]
9use std::arch::aarch64::*;
10
11use std::sync::OnceLock;
12
13use super::simd_config;
14
/// Parameters of a symmetric int8 quantization of an `f32` vector.
///
/// A finite value `v` is stored as `round(v * scale)` clamped to `[-127, 127]`
/// and recovered as `q / scale`.
#[derive(Clone, Copy, Debug)]
pub struct QuantizationParams {
    /// Multiplier applied to each `f32` element before rounding.
    pub scale: f32,
    /// Offset of the quantized zero; `from_vector` always sets this to `0`.
    pub zero_point: i8,
    /// Smallest finite input value (`0.0` if there were none).
    pub min_val: f32,
    /// Largest finite input value (`0.0` if there were none).
    pub max_val: f32,
}
29
30impl QuantizationParams {
31 pub fn from_vector(vector: &[f32]) -> Self {
35 let mut min_val = f32::INFINITY;
37 let mut max_val = f32::NEG_INFINITY;
38
39 for &v in vector {
40 if v.is_finite() {
41 min_val = min_val.min(v);
42 max_val = max_val.max(v);
43 }
44 }
45
46 if !min_val.is_finite() || !max_val.is_finite() {
48 min_val = 0.0;
49 max_val = 0.0;
50 }
51
52 let max_abs = min_val.abs().max(max_val.abs());
54
55 let scale = if max_abs > 1e-10 {
57 127.0 / max_abs
58 } else {
59 1.0 };
61
62 Self {
63 scale,
64 zero_point: 0,
65 min_val,
66 max_val,
67 }
68 }
69}
70
71#[derive(Debug, Clone)]
75pub struct QuantizedVector {
76 pub data: Vec<i8>,
83 pub params: QuantizationParams,
85 pub norm: f32,
87}
88
89impl QuantizedVector {
90 pub fn from_f32(vector: &[f32]) -> Self {
92 let mut params = QuantizationParams::from_vector(vector);
93
94 if !params.scale.is_finite() || params.scale == 0.0 {
96 params.scale = 1.0;
97 }
98
99 let mut norm_sq = 0.0f32;
101 for &v in vector {
102 if v.is_finite() {
103 norm_sq += v * v;
104 }
105 }
106 let norm = norm_sq.sqrt();
107
108 let data: Vec<i8> = vector
109 .iter()
110 .map(|&v| {
111 if !v.is_finite() {
112 0
113 } else {
114 (v * params.scale).round().clamp(-127.0, 127.0) as i8
115 }
116 })
117 .collect();
118
119 Self { data, params, norm }
120 }
121
122 pub fn to_f32(&self) -> Vec<f32> {
133 let scale = if self.params.scale.is_finite() && self.params.scale != 0.0 {
134 self.params.scale
135 } else {
136 1.0
137 };
138
139 self.data.iter().map(|&v| v as f32 / scale).collect()
140 }
141
142 #[inline]
144 pub fn dot_product(&self, other: &QuantizedVector) -> f32 {
145 dot_product_i8(self, other)
146 }
147
148 #[inline]
150 pub fn cosine_similarity(&self, other: &QuantizedVector) -> f32 {
151 cosine_similarity_i8(self, other)
152 }
153}
154
155#[inline]
173pub fn dot_product_i8(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
174 assert!(
178 a.data.iter().all(|&v| v != -128i8),
179 "QuantizedVector a contains -128, which violates the [-127, 127] VNNI invariant"
180 );
181 assert!(
182 b.data.iter().all(|&v| v != -128i8),
183 "QuantizedVector b contains -128, which violates the [-127, 127] VNNI invariant"
184 );
185
186 if a.data.len() != b.data.len() {
188 return 0.0;
189 }
190 debug_assert_eq!(a.data.len(), b.data.len());
191
192 let denom = a.params.scale * b.params.scale;
193 if denom == 0.0 || !denom.is_finite() {
194 return 0.0;
195 }
196
197 dot_product_i8_raw(&a.data, &b.data) / denom
198}
199
200#[inline]
205pub(crate) fn dot_product_i8_trusted(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
206 if a.data.len() != b.data.len() {
207 return 0.0;
208 }
209 let denom = a.params.scale * b.params.scale;
210 if denom == 0.0 || !denom.is_finite() {
211 return 0.0;
212 }
213 debug_assert!(a.data.iter().all(|&v| v != i8::MIN));
214 debug_assert!(b.data.iter().all(|&v| v != i8::MIN));
215 dot_product_i8_raw(&a.data, &b.data) / denom
216}
217
218#[inline]
222pub fn cosine_similarity_i8(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
223 let denom = a.norm * b.norm;
224 if denom == 0.0 || !denom.is_finite() {
225 return 0.0;
226 }
227 dot_product_i8(a, b) / denom
228}
229
230#[inline]
236pub(crate) fn cosine_similarity_i8_trusted(a: &QuantizedVector, b: &QuantizedVector) -> f32 {
237 let denom = a.norm * b.norm;
238 if denom == 0.0 || !denom.is_finite() {
239 return 0.0;
240 }
241 dot_product_i8_trusted(a, b) / denom
242}
243
#[cfg(target_arch = "aarch64")]
/// NEON int8 dot product: returns `sum(a[i] * b[i])` (accumulated in `i32`) as `f32`.
///
/// Structure: a main loop over 64-byte chunks using four independent `i32x4`
/// accumulators (`sum0`..`sum3`), then whole 16-byte tail chunks folded into a
/// single accumulator, then a scalar loop over the last `n % 16` elements.
///
/// Each 16-byte step widens `i8 x i8` products to `i16` with `vmull_s8` (any
/// `i8` product fits in `i16`). It then pairwise-adds them into the `i32` lanes
/// with `vpadalq_s16`, so no intermediate value saturates.
///
/// All accumulation is `i32`. Totals outside the `i32` range overflow (with
/// `±127` inputs that takes more than ~133k elements).
///
/// # Safety
///
/// * `b.len() >= a.len()`: every load bound is derived from `a.len()` and used
///   for `b` as well; equality is only checked by `debug_assert_eq!`.
/// * NEON must be usable on the running CPU.
#[inline]
unsafe fn dot_product_i8_neon_unrolled(a: &[i8], b: &[i8]) -> f32 {
    const SIMD_WIDTH: usize = 16;
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = a.len();
    debug_assert_eq!(n, b.len());
    let chunks = n / CHUNK_SIZE;

    // One accumulator per unrolled step; presumably split so consecutive
    // `vpadalq_s16` calls do not all depend on the same register.
    let mut sum0 = vdupq_n_s32(0);
    let mut sum1 = vdupq_n_s32(0);
    let mut sum2 = vdupq_n_s32(0);
    let mut sum3 = vdupq_n_s32(0);

    for i in 0..chunks {
        let base = i * CHUNK_SIZE;

        // Step 0: bytes [base, base + 16), split into two 8-lane halves so
        // `vmull_s8` can widen each half to 8 x i16.
        let a0 = vld1q_s8(a.as_ptr().add(base));
        let b0 = vld1q_s8(b.as_ptr().add(base));
        let a0_lo = vget_low_s8(a0);
        let a0_hi = vget_high_s8(a0);
        let b0_lo = vget_low_s8(b0);
        let b0_hi = vget_high_s8(b0);
        let prod0_lo = vmull_s8(a0_lo, b0_lo);
        let prod0_hi = vmull_s8(a0_hi, b0_hi);
        sum0 = vpadalq_s16(sum0, prod0_lo);
        sum0 = vpadalq_s16(sum0, prod0_hi);

        // Step 1: bytes [base + 16, base + 32).
        let a1 = vld1q_s8(a.as_ptr().add(base + SIMD_WIDTH));
        let b1 = vld1q_s8(b.as_ptr().add(base + SIMD_WIDTH));
        let a1_lo = vget_low_s8(a1);
        let a1_hi = vget_high_s8(a1);
        let b1_lo = vget_low_s8(b1);
        let b1_hi = vget_high_s8(b1);
        let prod1_lo = vmull_s8(a1_lo, b1_lo);
        let prod1_hi = vmull_s8(a1_hi, b1_hi);
        sum1 = vpadalq_s16(sum1, prod1_lo);
        sum1 = vpadalq_s16(sum1, prod1_hi);

        // Step 2: bytes [base + 32, base + 48).
        let a2 = vld1q_s8(a.as_ptr().add(base + SIMD_WIDTH * 2));
        let b2 = vld1q_s8(b.as_ptr().add(base + SIMD_WIDTH * 2));
        let a2_lo = vget_low_s8(a2);
        let a2_hi = vget_high_s8(a2);
        let b2_lo = vget_low_s8(b2);
        let b2_hi = vget_high_s8(b2);
        let prod2_lo = vmull_s8(a2_lo, b2_lo);
        let prod2_hi = vmull_s8(a2_hi, b2_hi);
        sum2 = vpadalq_s16(sum2, prod2_lo);
        sum2 = vpadalq_s16(sum2, prod2_hi);

        // Step 3: bytes [base + 48, base + 64).
        let a3 = vld1q_s8(a.as_ptr().add(base + SIMD_WIDTH * 3));
        let b3 = vld1q_s8(b.as_ptr().add(base + SIMD_WIDTH * 3));
        let a3_lo = vget_low_s8(a3);
        let a3_hi = vget_high_s8(a3);
        let b3_lo = vget_low_s8(b3);
        let b3_hi = vget_high_s8(b3);
        let prod3_lo = vmull_s8(a3_lo, b3_lo);
        let prod3_hi = vmull_s8(a3_hi, b3_hi);
        sum3 = vpadalq_s16(sum3, prod3_lo);
        sum3 = vpadalq_s16(sum3, prod3_hi);
    }

    // Merge the four accumulators into one.
    let sum01 = vaddq_s32(sum0, sum1);
    let sum23 = vaddq_s32(sum2, sum3);
    let mut sum_vec = vaddq_s32(sum01, sum23);

    // Whole 16-byte chunks left after the 64-byte main loop (at most three).
    let tail_start = chunks * CHUNK_SIZE;
    let tail_chunks = (n - tail_start) / SIMD_WIDTH;
    for j in 0..tail_chunks {
        let base = tail_start + j * SIMD_WIDTH;
        let at = vld1q_s8(a.as_ptr().add(base));
        let bt = vld1q_s8(b.as_ptr().add(base));
        let at_lo = vget_low_s8(at);
        let at_hi = vget_high_s8(at);
        let bt_lo = vget_low_s8(bt);
        let bt_hi = vget_high_s8(bt);
        let pt_lo = vmull_s8(at_lo, bt_lo);
        let pt_hi = vmull_s8(at_hi, bt_hi);
        sum_vec = vpadalq_s16(sum_vec, pt_lo);
        sum_vec = vpadalq_s16(sum_vec, pt_hi);
    }

    // Horizontal sum of the four i32 lanes.
    let sum = vgetq_lane_s32(sum_vec, 0)
        + vgetq_lane_s32(sum_vec, 1)
        + vgetq_lane_s32(sum_vec, 2)
        + vgetq_lane_s32(sum_vec, 3);

    // Scalar remainder: the final n % 16 elements.
    let remainder_start = tail_start + tail_chunks * SIMD_WIDTH;
    let remainder: i32 = a[remainder_start..]
        .iter()
        .zip(b[remainder_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();

    (sum + remainder) as f32
}
365
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
/// 512-bit emulation of the AVX2 `_mm256_sign_epi8`, which has no AVX-512 form.
///
/// For each byte lane it returns `-b` where `a < 0`, `0` where `a == 0`, and
/// `b` otherwise. Negation wraps, so a `b` lane of `-128` under a negative `a`
/// stays `-128`.
///
/// # Safety
///
/// The running CPU must support AVX-512F and AVX-512BW.
#[target_feature(enable = "avx512f", enable = "avx512bw")]
#[inline]
unsafe fn mm512_sign_epi8(b: __m512i, a: __m512i) -> __m512i {
    let zeros = _mm512_setzero_si512();
    let a_is_negative = _mm512_cmplt_epi8_mask(a, zeros);
    let a_is_zero = _mm512_cmpeq_epi8_mask(a, zeros);

    let negated_b = _mm512_sub_epi8(zeros, b);
    // Take `negated_b` in lanes where `a` is negative, `b` elsewhere...
    let signed_b = _mm512_mask_blend_epi8(a_is_negative, b, negated_b);
    // ...then force lanes where `a` is zero to zero.
    _mm512_mask_blend_epi8(a_is_zero, signed_b, zeros)
}
387
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
/// AVX-512 VNNI int8 dot product: returns `sum(a[i] * b[i])` (accumulated in
/// `i32`) as `f32`.
///
/// `_mm512_dpbusd_epi32` multiplies *unsigned* bytes of its second operand by
/// *signed* bytes of its third. It sums each group of four into the matching
/// `i32` lane of the accumulator. The signed x signed product is recovered with
/// the `|a| * sign(b, a)` identity: `_mm512_abs_epi8(a)` is the unsigned
/// operand, and `mm512_sign_epi8(b, a)` moves `a`'s sign onto `b`. A `-128` in
/// `b` breaks the identity because negating it wraps. That is why inputs must
/// stay in `[-127, 127]`.
///
/// Processes 256 bytes per iteration across four accumulators. The final
/// `n % 256` elements go through a scalar loop.
///
/// # Safety
///
/// * The CPU must support AVX-512F, AVX-512BW and AVX-512 VNNI.
/// * `b.len() >= a.len()`: loads from `b` are bounded only by `a.len()`; the
///   length check is a `debug_assert_eq!`.
#[target_feature(enable = "avx512f", enable = "avx512vnni", enable = "avx512bw")]
unsafe fn dot_product_i8_avx512vnni(a: &[i8], b: &[i8]) -> f32 {
    const SIMD_WIDTH: usize = 64;
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    let chunks = n / CHUNK_SIZE;

    // Four independent i32x16 accumulators, one per unrolled 64-byte step.
    let mut sum0 = _mm512_setzero_si512();
    let mut sum1 = _mm512_setzero_si512();
    let mut sum2 = _mm512_setzero_si512();
    let mut sum3 = _mm512_setzero_si512();

    for i in 0..chunks {
        let base = i * CHUNK_SIZE;

        // Step 0: bytes [base, base + 64).
        let a0 = _mm512_loadu_si512(a.as_ptr().add(base) as *const __m512i);
        let b0 = _mm512_loadu_si512(b.as_ptr().add(base) as *const __m512i);
        let a0_abs = _mm512_abs_epi8(a0);
        let b0_signed = mm512_sign_epi8(b0, a0);
        sum0 = _mm512_dpbusd_epi32(sum0, a0_abs, b0_signed);

        // Step 1: bytes [base + 64, base + 128).
        let a1 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH) as *const __m512i);
        let b1 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH) as *const __m512i);
        let a1_abs = _mm512_abs_epi8(a1);
        let b1_signed = mm512_sign_epi8(b1, a1);
        sum1 = _mm512_dpbusd_epi32(sum1, a1_abs, b1_signed);

        // Step 2: bytes [base + 128, base + 192).
        let a2 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m512i);
        let b2 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m512i);
        let a2_abs = _mm512_abs_epi8(a2);
        let b2_signed = mm512_sign_epi8(b2, a2);
        sum2 = _mm512_dpbusd_epi32(sum2, a2_abs, b2_signed);

        // Step 3: bytes [base + 192, base + 256).
        let a3 = _mm512_loadu_si512(a.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m512i);
        let b3 = _mm512_loadu_si512(b.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m512i);
        let a3_abs = _mm512_abs_epi8(a3);
        let b3_signed = mm512_sign_epi8(b3, a3);
        sum3 = _mm512_dpbusd_epi32(sum3, a3_abs, b3_signed);
    }

    // Merge the accumulators, then reduce the 16 i32 lanes to one scalar.
    let sum01 = _mm512_add_epi32(sum0, sum1);
    let sum23 = _mm512_add_epi32(sum2, sum3);
    let sum_vec = _mm512_add_epi32(sum01, sum23);

    let sum = _mm512_reduce_add_epi32(sum_vec);

    // Scalar remainder: the final n % 256 elements (no 64-byte tail loop here).
    let remainder_start = chunks * CHUNK_SIZE;
    let remainder: i32 = a[remainder_start..]
        .iter()
        .zip(b[remainder_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();

    (sum + remainder) as f32
}
469
#[cfg(target_arch = "x86_64")]
/// AVX2 int8 dot product: returns `sum(a[i] * b[i])` (accumulated in `i32`) as `f32`.
///
/// AVX2 has no signed x signed byte multiply-add, so each step uses the
/// `|a| * sign(b, a)` identity. `_mm256_abs_epi8(a)` supplies the unsigned
/// operand and `_mm256_sign_epi8(b, a)` moves `a`'s sign onto `b`.
/// `_mm256_maddubs_epi16` then multiplies and adds adjacent pairs into `i16`.
/// Finally, `_mm256_madd_epi16(_, ones)` widens adjacent `i16` pairs into
/// `i32` lanes.
///
/// `maddubs` saturates at `i16`. With inputs in `[-127, 127]` a pair sum is at
/// most `2 * 127 * 127 = 32258`, so it never saturates. A `-128` in `b` breaks
/// the sign identity because negating it wraps.
///
/// Processes 128 bytes per iteration across four accumulators. The final
/// `n % 128` elements go through a scalar loop.
///
/// # Safety
///
/// * The CPU must support AVX2.
/// * `b.len() >= a.len()`: loads from `b` are bounded only by `a.len()`; the
///   length check is a `debug_assert_eq!`.
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2_unrolled(a: &[i8], b: &[i8]) -> f32 {
    const SIMD_WIDTH: usize = 32;
    const UNROLL: usize = 4;
    const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
    let n = a.len();
    debug_assert_eq!(n, b.len());
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    let chunks = n / CHUNK_SIZE;

    // Four independent i32x8 accumulators, one per unrolled 32-byte step.
    let mut sum0 = _mm256_setzero_si256();
    let mut sum1 = _mm256_setzero_si256();
    let mut sum2 = _mm256_setzero_si256();
    let mut sum3 = _mm256_setzero_si256();

    // Multiplying by 1 in `madd_epi16` just sums adjacent i16 pairs into i32.
    let ones = _mm256_set1_epi16(1);

    for i in 0..chunks {
        let base = i * CHUNK_SIZE;

        // Step 0: bytes [base, base + 32).
        let a0 = _mm256_loadu_si256(a.as_ptr().add(base) as *const __m256i);
        let b0 = _mm256_loadu_si256(b.as_ptr().add(base) as *const __m256i);
        let prod0 = _mm256_maddubs_epi16(_mm256_abs_epi8(a0), _mm256_sign_epi8(b0, a0));
        let prod0_32 = _mm256_madd_epi16(prod0, ones);
        sum0 = _mm256_add_epi32(sum0, prod0_32);

        // Step 1: bytes [base + 32, base + 64).
        let a1 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH) as *const __m256i);
        let b1 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH) as *const __m256i);
        let prod1 = _mm256_maddubs_epi16(_mm256_abs_epi8(a1), _mm256_sign_epi8(b1, a1));
        let prod1_32 = _mm256_madd_epi16(prod1, ones);
        sum1 = _mm256_add_epi32(sum1, prod1_32);

        // Step 2: bytes [base + 64, base + 96).
        let a2 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m256i);
        let b2 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH * 2) as *const __m256i);
        let prod2 = _mm256_maddubs_epi16(_mm256_abs_epi8(a2), _mm256_sign_epi8(b2, a2));
        let prod2_32 = _mm256_madd_epi16(prod2, ones);
        sum2 = _mm256_add_epi32(sum2, prod2_32);

        // Step 3: bytes [base + 96, base + 128).
        let a3 = _mm256_loadu_si256(a.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m256i);
        let b3 = _mm256_loadu_si256(b.as_ptr().add(base + SIMD_WIDTH * 3) as *const __m256i);
        let prod3 = _mm256_maddubs_epi16(_mm256_abs_epi8(a3), _mm256_sign_epi8(b3, a3));
        let prod3_32 = _mm256_madd_epi16(prod3, ones);
        sum3 = _mm256_add_epi32(sum3, prod3_32);
    }

    let sum01 = _mm256_add_epi32(sum0, sum1);
    let sum23 = _mm256_add_epi32(sum2, sum3);
    let sum_vec = _mm256_add_epi32(sum01, sum23);

    // Horizontal reduction: fold 256 -> 128 bits, then shift-and-add the
    // upper 8 and 4 bytes so lane 0 ends up holding the total.
    let sum128_lo = _mm256_castsi256_si128(sum_vec);
    let sum128_hi = _mm256_extracti128_si256(sum_vec, 1);
    let sum128 = _mm_add_epi32(sum128_lo, sum128_hi);
    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let sum = _mm_cvtsi128_si32(sum32);

    // Scalar remainder: the final n % 128 elements.
    let remainder_start = chunks * CHUNK_SIZE;
    let remainder: i32 = a[remainder_start..]
        .iter()
        .zip(b[remainder_start..].iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();

    (sum + remainder) as f32
}
557
/// Signature shared by every i8 dot-product kernel: the integer dot product of
/// `a` and `b`, returned as `f32`.
pub type I8DotKernel = fn(a: &[i8], b: &[i8]) -> f32;
564
565static I8_DOT_KERNEL: OnceLock<I8DotKernel> = OnceLock::new();
566
567#[inline]
572pub fn resolved_i8_dot_kernel() -> I8DotKernel {
573 *I8_DOT_KERNEL.get_or_init(resolve_i8_dot_kernel)
574}
575
576fn resolve_i8_dot_kernel() -> I8DotKernel {
577 let config = simd_config();
578
579 #[cfg(target_arch = "aarch64")]
580 {
581 if config.neon_enabled {
582 return dot_product_i8_neon_kernel;
583 }
584 }
585
586 #[cfg(target_arch = "x86_64")]
587 {
588 #[cfg(feature = "avx512")]
589 {
590 if config.avx512vnni_enabled {
591 return dot_product_i8_avx512vnni_kernel;
592 }
593 }
594 if config.avx2_enabled {
595 return dot_product_i8_avx2_kernel;
596 }
597 }
598
599 dot_product_i8_scalar_kernel
600}
601
#[cfg(target_arch = "aarch64")]
/// Safe entry point for the NEON kernel. `resolve_i8_dot_kernel` installs it
/// only when `simd_config().neon_enabled` is set.
///
/// This is a safe `fn` reachable from any caller through the public
/// `resolved_i8_dot_kernel()` pointer, so it must not trust slice lengths. The
/// unsafe kernel reads `a.len()` bytes from `b`. Both inputs are therefore
/// truncated to their common length, which matches the scalar kernel's `zip`
/// semantics.
fn dot_product_i8_neon_kernel(a: &[i8], b: &[i8]) -> f32 {
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    // SAFETY: `a` and `b` now have equal length, satisfying the kernel's bounds
    // requirement, and this wrapper is only selected when NEON is enabled.
    unsafe { dot_product_i8_neon_unrolled(a, b) }
}
607
#[cfg(all(target_arch = "x86_64", feature = "avx512"))]
/// Safe entry point for the AVX-512 VNNI kernel. `resolve_i8_dot_kernel`
/// installs it only when `simd_config().avx512vnni_enabled` is set.
///
/// This is a safe `fn` reachable from any caller through the public
/// `resolved_i8_dot_kernel()` pointer. The unsafe kernel loads from `b` using
/// `a.len()` bounds, so both inputs are first truncated to their common length
/// (matching the scalar kernel's `zip` semantics). `-128` values are only
/// caught in debug builds; in release they yield wrong sums, not UB.
fn dot_product_i8_avx512vnni_kernel(a: &[i8], b: &[i8]) -> f32 {
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);
    debug_assert!(a.iter().all(|&v| v != i8::MIN));
    debug_assert!(b.iter().all(|&v| v != i8::MIN));
    // SAFETY: equal lengths satisfy the kernel's bounds requirement, and this
    // wrapper is only selected when AVX-512 VNNI support was reported.
    unsafe { dot_product_i8_avx512vnni(a, b) }
}
615
616#[cfg(target_arch = "x86_64")]
617fn dot_product_i8_avx2_kernel(a: &[i8], b: &[i8]) -> f32 {
618 debug_assert!(a.iter().all(|&v| v != i8::MIN));
619 debug_assert!(b.iter().all(|&v| v != i8::MIN));
620 unsafe { dot_product_i8_avx2_unrolled(a, b) }
622}
623
/// Portable fallback kernel: exact `i32` dot product over the common prefix of
/// `a` and `b`, converted to `f32`.
fn dot_product_i8_scalar_kernel(a: &[i8], b: &[i8]) -> f32 {
    let mut acc: i32 = 0;
    for (&x, &y) in a.iter().zip(b) {
        acc += i32::from(x) * i32::from(y);
    }
    acc as f32
}
630
631#[inline]
648pub fn dot_product_i8_raw(a: &[i8], b: &[i8]) -> f32 {
649 if a.len() != b.len() {
650 return 0.0;
651 }
652 debug_assert_eq!(a.len(), b.len());
653 resolved_i8_dot_kernel()(a, b)
654}
655
#[cfg(test)]
mod simd_parity_tests {
    // The shared helpers are only used by the arch-specific tests below.
    #![cfg_attr(
        not(any(target_arch = "x86_64", target_arch = "aarch64")),
        allow(dead_code)
    )]

    use super::*;

    /// Lengths that cover every code path: empty input, pure scalar remainder,
    /// exact SIMD multiples, partial tails and the main unrolled loop.
    const DIMS: [usize; 10] = [0, 1, 7, 16, 64, 100, 128, 384, 768, 1000];

    /// Deterministic pseudo-random vector with elements in `[-1, 1]` (64-bit LCG).
    fn gen_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    /// Portable reference: exact `i32` dot product converted to `f32`.
    fn scalar_reference(a: &[i8], b: &[i8]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| x as i32 * y as i32)
            .sum::<i32>() as f32
    }

    /// Parity inputs. For each length in `DIMS` there are two pairs:
    /// * quantized pseudo-random vectors (seeds `seed + dim` and
    ///   `seed + 100 + dim`);
    /// * alternating ±127 patterns, the extremes allowed by the `[-127, 127]`
    ///   invariant.
    fn parity_inputs(seed: u64) -> Vec<(Vec<i8>, Vec<i8>)> {
        let mut cases = Vec::with_capacity(DIMS.len() * 2);
        for &dim in DIMS.iter() {
            let a = QuantizedVector::from_f32(&gen_vec(dim, seed + dim as u64)).data;
            let b = QuantizedVector::from_f32(&gen_vec(dim, seed + 100 + dim as u64)).data;
            cases.push((a, b));

            let sat_a: Vec<i8> = (0..dim).map(|i| if i % 3 == 0 { -127 } else { 127 }).collect();
            let sat_b: Vec<i8> = (0..dim).map(|i| if i % 2 == 0 { -127 } else { 127 }).collect();
            cases.push((sat_a, sat_b));
        }
        cases
    }

    /// Every kernel produces the same exact `i32` total, so parity must be
    /// bit-exact. A tolerance of 1.0 would hide a dropped or duplicated ±1
    /// product.
    fn assert_parity(label: &str, a: &[i8], b: &[i8], simd: f32) {
        let scalar = scalar_reference(a, b);
        assert_eq!(
            simd,
            scalar,
            "{label} vs scalar i8 dot product dim={}: {label}={simd} scalar={scalar}",
            a.len()
        );
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_i8_neon_scalar_parity() {
        for (a, b) in parity_inputs(200) {
            // SAFETY: `a` and `b` have equal length; NEON is assumed available on
            // the aarch64 test host.
            let neon = unsafe { dot_product_i8_neon_unrolled(&a, &b) };
            assert_parity("NEON", &a, &b, neon);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_i8_avx2_scalar_parity() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        for (a, b) in parity_inputs(400) {
            // SAFETY: equal lengths; AVX2 support was checked above.
            let avx2 = unsafe { dot_product_i8_avx2_unrolled(&a, &b) };
            assert_parity("AVX2", &a, &b, avx2);
        }
    }

    #[cfg(all(target_arch = "x86_64", feature = "avx512"))]
    #[test]
    fn test_i8_avx512vnni_scalar_parity() {
        let supported = std::arch::is_x86_feature_detected!("avx512f")
            && std::arch::is_x86_feature_detected!("avx512bw")
            && std::arch::is_x86_feature_detected!("avx512vnni");
        if !supported {
            return;
        }
        for (a, b) in parity_inputs(600) {
            // SAFETY: equal lengths; all required AVX-512 features were checked above.
            let vnni = unsafe { dot_product_i8_avx512vnni(&a, &b) };
            assert_parity("AVX-512 VNNI", &a, &b, vnni);
        }
    }
}