// nodedb_codec/vector_quant/ternary/simd.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! SIMD ternary dot product with runtime CPU-feature dispatch.
4
5#![allow(unsafe_op_in_unsafe_fn)]
6
7use std::sync::OnceLock;
8
9use super::packing::unpack_hot;
10
11/// Scalar fallback ternary dot product.
12fn ternary_dot_scalar(a_hot: &[u8], b_hot: &[u8], dim: usize) -> i32 {
13    let a = unpack_hot(a_hot, dim);
14    let b = unpack_hot(b_hot, dim);
15    a.iter()
16        .zip(b.iter())
17        .map(|(&x, &y)| x as i32 * y as i32)
18        .sum()
19}
20
/// AVX-512 ternary dot product over hot-packed operands.
///
/// Unpacks both operands into trit vectors, processes 64 trits per iteration
/// (sign-extended to `i16` lanes, multiplied, then pair-summed into `i32`
/// lanes), and finishes any non-multiple-of-64 tail with scalar code.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx512f` and
/// `avx512bw` target features; calling this on hardware without them is
/// undefined behavior. `resolve` performs that runtime detection.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512bw")]
unsafe fn ternary_dot_avx512(a_hot: &[u8], b_hot: &[u8], dim: usize) -> i32 {
    use std::arch::x86_64::*;

    let a_trits = unpack_hot(a_hot, dim);
    let b_trits = unpack_hot(b_hot, dim);

    let len = a_trits.len();
    let chunks = len / 64;
    let mut sum = _mm512_setzero_si512();

    for i in 0..chunks {
        // Unaligned loads: a Vec allocation carries no 64-byte alignment
        // guarantee, so `loadu` is required here.
        let a_ptr = a_trits.as_ptr().add(i * 64) as *const __m512i;
        let b_ptr = b_trits.as_ptr().add(i * 64) as *const __m512i;
        let a_vec = _mm512_loadu_si512(a_ptr);
        let b_vec = _mm512_loadu_si512(b_ptr);
        // Sign-extend the low and high 32-byte halves from i8 to i16 lanes
        // before multiplying; trit products fit comfortably in i16.
        let a_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(a_vec));
        let a_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a_vec, 1));
        let b_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(b_vec));
        let b_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b_vec, 1));
        let prod_lo = _mm512_mullo_epi16(a_lo, b_lo);
        let prod_hi = _mm512_mullo_epi16(a_hi, b_hi);
        // `madd` against a vector of ones sums adjacent i16 pairs into i32
        // lanes, widening before accumulation so long inputs cannot
        // overflow the 16-bit products.
        let ones = _mm512_set1_epi16(1);
        let sum_lo = _mm512_madd_epi16(prod_lo, ones);
        let sum_hi = _mm512_madd_epi16(prod_hi, ones);
        sum = _mm512_add_epi32(sum, _mm512_add_epi32(sum_lo, sum_hi));
    }

    // Horizontal reduction of the 16 i32 lanes, then the scalar tail for
    // the last `len % 64` trits.
    let mut acc = _mm512_reduce_add_epi32(sum);
    for i in (chunks * 64)..len {
        acc += a_trits[i] as i32 * b_trits[i] as i32;
    }
    acc
}
56
/// NEON ternary dot product over hot-packed operands.
///
/// Unpacks both operands into trit vectors, processes 16 trits per iteration
/// (i8 multiply, then widening through i16 to i32 before accumulation), and
/// finishes any non-multiple-of-16 tail with scalar code. NEON is a baseline
/// feature on `aarch64`, so no runtime detection is needed and the function
/// itself is safe to call.
#[cfg(target_arch = "aarch64")]
fn ternary_dot_neon(a_hot: &[u8], b_hot: &[u8], dim: usize) -> i32 {
    use std::arch::aarch64::*;

    let a_trits = unpack_hot(a_hot, dim);
    let b_trits = unpack_hot(b_hot, dim);

    let len = a_trits.len();
    let chunks = len / 16;
    let mut acc: i32;

    // SAFETY: every load reads 16 bytes at offset i * 16 with
    // i * 16 + 16 <= chunks * 16 <= len, so all pointers stay in bounds,
    // and NEON intrinsics are always available on aarch64 targets.
    unsafe {
        let mut sum = vdupq_n_s32(0i32);
        for i in 0..chunks {
            let a_ptr = a_trits.as_ptr().add(i * 16);
            let b_ptr = b_trits.as_ptr().add(i * 16);
            let a_vec = vld1q_s8(a_ptr);
            let b_vec = vld1q_s8(b_ptr);
            // Trit products lie in {-1, 0, 1}, so the i8 multiply cannot
            // overflow; widening happens only for accumulation.
            let prod = vmulq_s8(a_vec, b_vec);
            let prod_lo = vmovl_s8(vget_low_s8(prod));
            let prod_hi = vmovl_s8(vget_high_s8(prod));
            // Widen each i16 half to i32 quarters before summing so long
            // vectors cannot overflow 16-bit accumulation.
            let prod32_lo = vmovl_s16(vget_low_s16(prod_lo));
            let prod32_hi = vmovl_s16(vget_high_s16(prod_lo));
            let prod32_lo2 = vmovl_s16(vget_low_s16(prod_hi));
            let prod32_hi2 = vmovl_s16(vget_high_s16(prod_hi));
            sum = vaddq_s32(
                sum,
                vaddq_s32(
                    vaddq_s32(prod32_lo, prod32_hi),
                    vaddq_s32(prod32_lo2, prod32_hi2),
                ),
            );
        }
        // Horizontal add of the four i32 lanes, then the scalar tail.
        acc = vaddvq_s32(sum);
        for i in (chunks * 16)..len {
            acc += a_trits[i] as i32 * b_trits[i] as i32;
        }
    }
    acc
}
97
/// Signature shared by every dot-product kernel so they can all be stored
/// behind one plain function pointer.
type DotFn = fn(&[u8], &[u8], usize) -> i32;

/// Kernel selected by `resolve` on first use and cached for the lifetime of
/// the process; CPU features cannot change at runtime, so one probe suffices.
static DOT_FN: OnceLock<DotFn> = OnceLock::new();
101
/// Safe-signature shim so [`ternary_dot_avx512`] (an `unsafe`
/// `#[target_feature]` fn) can be stored in a plain [`DotFn`] pointer.
#[cfg(target_arch = "x86_64")]
fn ternary_dot_avx512_trampoline(a: &[u8], b: &[u8], dim: usize) -> i32 {
    // SAFETY: `resolve` only returns this trampoline after runtime detection
    // of avx512f and avx512bw, satisfying the callee's feature contract.
    // NOTE(review): the fn itself could be called without that check from
    // elsewhere in this module — it must stay private.
    unsafe { ternary_dot_avx512(a, b, dim) }
}
106
107fn resolve() -> DotFn {
108    #[cfg(target_arch = "x86_64")]
109    {
110        if std::is_x86_feature_detected!("avx512f") && std::is_x86_feature_detected!("avx512bw") {
111            return ternary_dot_avx512_trampoline;
112        }
113    }
114    #[cfg(target_arch = "aarch64")]
115    {
116        return ternary_dot_neon;
117    }
118    #[allow(unreachable_code)]
119    {
120        ternary_dot_scalar
121    }
122}
123
124/// Compute ternary dot product between two hot-packed byte slices.
125#[inline]
126pub fn ternary_dot(a_hot: &[u8], b_hot: &[u8], dim: usize) -> i32 {
127    let f = DOT_FN.get_or_init(resolve);
128    f(a_hot, b_hot, dim)
129}
130
#[cfg(test)]
mod tests {
    use super::super::packing::pack_hot;
    use super::*;

    /// The dispatched kernel and the scalar reference must both agree with a
    /// directly computed dot product on a 16-trit input.
    #[test]
    fn simd_vs_scalar_agreement() {
        let lhs: Vec<i8> = vec![1, -1, 0, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1];
        // Second operand is the elementwise negation of the first.
        let rhs: Vec<i8> = lhs.iter().map(|&t| -t).collect();
        let dim = lhs.len();
        let packed_a = pack_hot(&lhs);
        let packed_b = pack_hot(&rhs);

        let reference: i32 = lhs
            .iter()
            .zip(rhs.iter())
            .map(|(&a, &b)| a as i32 * b as i32)
            .sum();

        assert_eq!(ternary_dot_scalar(&packed_a, &packed_b, dim), reference);
        assert_eq!(ternary_dot(&packed_a, &packed_b, dim), reference);
    }
}
155}