nodedb_codec/vector_quant/ternary/
simd.rs1#![allow(unsafe_op_in_unsafe_fn)]
6
7use std::sync::OnceLock;
8
9use super::packing::unpack_hot;
10
11fn ternary_dot_scalar(a_hot: &[u8], b_hot: &[u8], dim: usize) -> i32 {
13 let a = unpack_hot(a_hot, dim);
14 let b = unpack_hot(b_hot, dim);
15 a.iter()
16 .zip(b.iter())
17 .map(|(&x, &y)| x as i32 * y as i32)
18 .sum()
19}
20
21#[cfg(target_arch = "x86_64")]
22#[target_feature(enable = "avx512f,avx512bw")]
23unsafe fn ternary_dot_avx512(a_hot: &[u8], b_hot: &[u8], dim: usize) -> i32 {
24 use std::arch::x86_64::*;
25
26 let a_trits = unpack_hot(a_hot, dim);
27 let b_trits = unpack_hot(b_hot, dim);
28
29 let len = a_trits.len();
30 let chunks = len / 64;
31 let mut sum = _mm512_setzero_si512();
32
33 for i in 0..chunks {
34 let a_ptr = a_trits.as_ptr().add(i * 64) as *const __m512i;
35 let b_ptr = b_trits.as_ptr().add(i * 64) as *const __m512i;
36 let a_vec = _mm512_loadu_si512(a_ptr);
37 let b_vec = _mm512_loadu_si512(b_ptr);
38 let a_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(a_vec));
39 let a_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a_vec, 1));
40 let b_lo = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(b_vec));
41 let b_hi = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b_vec, 1));
42 let prod_lo = _mm512_mullo_epi16(a_lo, b_lo);
43 let prod_hi = _mm512_mullo_epi16(a_hi, b_hi);
44 let ones = _mm512_set1_epi16(1);
45 let sum_lo = _mm512_madd_epi16(prod_lo, ones);
46 let sum_hi = _mm512_madd_epi16(prod_hi, ones);
47 sum = _mm512_add_epi32(sum, _mm512_add_epi32(sum_lo, sum_hi));
48 }
49
50 let mut acc = _mm512_reduce_add_epi32(sum);
51 for i in (chunks * 64)..len {
52 acc += a_trits[i] as i32 * b_trits[i] as i32;
53 }
54 acc
55}
56
/// NEON kernel for the ternary dot product.
///
/// Both operands are expanded with `unpack_hot` and then consumed 16 elements
/// per iteration: the two vectors are multiplied lane-wise in 8 bits, the
/// products are widened to 32 bits in four groups, and the groups are added
/// into a running 32-bit accumulator. Elements after the last full chunk are
/// handled by a scalar loop.
///
/// Only the common prefix of the two expanded operands is used, matching the
/// `zip` behaviour of the scalar kernel and guaranteeing that the raw vector
/// loads never run past the end of the shorter buffer.
#[cfg(target_arch = "aarch64")]
fn ternary_dot_neon(a_hot: &[u8], b_hot: &[u8], dim: usize) -> i32 {
    use std::arch::aarch64::*;

    /// Number of 8-bit elements consumed per 128-bit load.
    const LANES: usize = 16;

    let a_trits = unpack_hot(a_hot, dim);
    let b_trits = unpack_hot(b_hot, dim);

    // Clamp to the shared length so neither buffer is ever over-read.
    let len = a_trits.len().min(b_trits.len());
    let a = &a_trits[..len];
    let b = &b_trits[..len];

    let a_chunks = a.chunks_exact(LANES);
    let b_chunks = b.chunks_exact(LANES);
    // Grab the leftovers before `zip` consumes the chunk iterators.
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // SAFETY: `chunks_exact(LANES)` only yields slices of exactly 16 `i8`
    // elements, so every `vld1q_s8` reads 16 bytes that lie inside its chunk.
    // The remaining intrinsics operate on register values only. NEON itself is
    // assumed to be available on every aarch64 build, as `resolve` selects
    // this kernel without runtime detection.
    let vector_sum = unsafe {
        let mut sum = vdupq_n_s32(0i32);
        for (a_chunk, b_chunk) in a_chunks.zip(b_chunks) {
            let a_vec = vld1q_s8(a_chunk.as_ptr());
            let b_vec = vld1q_s8(b_chunk.as_ptr());

            // 8-bit lane-wise product.
            // NOTE(review): exact only while each product fits in an i8,
            // e.g. operands limited to -1, 0, 1 — confirm against `unpack_hot`.
            let prod = vmulq_s8(a_vec, b_vec);

            // Widen 8 -> 16 -> 32 bits, producing four 4-lane groups.
            let prod16_lo = vmovl_s8(vget_low_s8(prod));
            let prod16_hi = vmovl_s8(vget_high_s8(prod));
            let q0 = vmovl_s16(vget_low_s16(prod16_lo));
            let q1 = vmovl_s16(vget_high_s16(prod16_lo));
            let q2 = vmovl_s16(vget_low_s16(prod16_hi));
            let q3 = vmovl_s16(vget_high_s16(prod16_hi));

            sum = vaddq_s32(sum, vaddq_s32(vaddq_s32(q0, q1), vaddq_s32(q2, q3)));
        }
        // Horizontal add of the four 32-bit lanes.
        vaddvq_s32(sum)
    };

    let tail: i32 = a_tail
        .iter()
        .zip(b_tail)
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();

    vector_sum + tail
}
97
/// Common signature of every dot-product kernel in this module:
/// `(a_hot, b_hot, dim) -> dot product`.
type DotFn = fn(&[u8], &[u8], usize) -> i32;

/// Process-wide slot holding the selected kernel.
///
/// It is filled at most once; `ternary_dot` initialises it with `resolve` on
/// first use and every later call reads the stored function pointer.
static DOT_FN: OnceLock<DotFn> = OnceLock::new();
101
102#[cfg(target_arch = "x86_64")]
103fn ternary_dot_avx512_trampoline(a: &[u8], b: &[u8], dim: usize) -> i32 {
104 unsafe { ternary_dot_avx512(a, b, dim) }
105}
106
107fn resolve() -> DotFn {
108 #[cfg(target_arch = "x86_64")]
109 {
110 if std::is_x86_feature_detected!("avx512f") && std::is_x86_feature_detected!("avx512bw") {
111 return ternary_dot_avx512_trampoline;
112 }
113 }
114 #[cfg(target_arch = "aarch64")]
115 {
116 return ternary_dot_neon;
117 }
118 #[allow(unreachable_code)]
119 {
120 ternary_dot_scalar
121 }
122}
123
124#[inline]
126pub fn ternary_dot(a_hot: &[u8], b_hot: &[u8], dim: usize) -> i32 {
127 let f = DOT_FN.get_or_init(resolve);
128 f(a_hot, b_hot, dim)
129}
130
#[cfg(test)]
mod tests {
    use super::super::packing::pack_hot;
    use super::*;

    /// Deterministic pseudo-random values in `-1..=1` (linear congruential
    /// generator), so failures are reproducible without extra dependencies.
    fn sample_trits(seed: u32, len: usize) -> Vec<i8> {
        let mut state = seed;
        (0..len)
            .map(|_| {
                state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
                ((state >> 16) % 3) as i8 - 1
            })
            .collect()
    }

    /// The original hand-written 16-element case.
    #[test]
    fn simd_vs_scalar_agreement() {
        let trits_a: Vec<i8> = vec![1, -1, 0, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1];
        let trits_b: Vec<i8> = vec![-1, 1, 0, -1, 1, 0, -1, 1, 0, -1, 1, 0, -1, 1, 0, -1];
        let dim = trits_a.len();
        let hot_a = pack_hot(&trits_a);
        let hot_b = pack_hot(&trits_b);

        let scalar = ternary_dot_scalar(&hot_a, &hot_b, dim);
        let dispatched = ternary_dot(&hot_a, &hot_b, dim);

        let expected: i32 = trits_a
            .iter()
            .zip(trits_b.iter())
            .map(|(&a, &b)| a as i32 * b as i32)
            .sum();
        assert_eq!(scalar, expected);
        assert_eq!(scalar, dispatched);
    }

    /// Longer inputs, so that a 64-element-wide kernel actually runs its
    /// vector loop: 64 and 128 are whole 64-element chunks, while 80 and 208
    /// leave a 16-element remainder for the scalar tail.
    #[test]
    fn simd_vs_scalar_agreement_long_inputs() {
        for (case, dim) in [64usize, 80, 128, 208].into_iter().enumerate() {
            let trits_a = sample_trits(0x1234_5678 ^ case as u32, dim);
            let trits_b = sample_trits(0x9e37_79b9 ^ case as u32, dim);
            let hot_a = pack_hot(&trits_a);
            let hot_b = pack_hot(&trits_b);

            let expected: i32 = trits_a
                .iter()
                .zip(trits_b.iter())
                .map(|(&a, &b)| a as i32 * b as i32)
                .sum();

            assert_eq!(
                ternary_dot_scalar(&hot_a, &hot_b, dim),
                expected,
                "scalar kernel, dim = {dim}"
            );
            assert_eq!(
                ternary_dot(&hot_a, &hot_b, dim),
                expected,
                "dispatched kernel, dim = {dim}"
            );
        }
    }
}