Skip to main content

nodedb_codec/vector_quant/
hamming.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Shared SIMD Hamming-distance kernel with runtime CPU-feature dispatch.
4//!
5//! Used by 1-bit quantizers ([`bbq`](super::bbq), [`rabitq`](super::rabitq))
6//! and any future binary code that needs to count differing bits between
7//! two equal-length packed-bit slices.
8//!
//! Dispatch order (resolved once via [`OnceLock`]):
//!
//! 1. `x86_64` + `avx512f` + `avx512vpopcntdq` → AVX-512 VPOPCNTDQ kernel
//! 2. `x86_64` + `avx2`                        → scalar 64-bit `count_ones`, unrolled 4×,
//!    compiled with the `avx2` target feature enabled
//! 3. `aarch64`                                → NEON `vcntq_u8` (selected without a runtime probe)
//! 4. fallback                                 → scalar 8-byte `count_ones`
15
16#![allow(unsafe_op_in_unsafe_fn)]
17
18use std::sync::OnceLock;
19
/// Signature shared by every Hamming kernel in this module: takes two packed-bit
/// slices (expected to be equal length) and returns the number of differing bits.
type HammingFn = fn(&[u8], &[u8]) -> u32;

/// Kernel chosen on the first call to `hamming_distance` and reused afterwards.
static DISPATCH: OnceLock<HammingFn> = OnceLock::new();
23
/// Portable reference kernel: popcount of `a ^ b`, eight bytes at a time.
///
/// Only the first `a.len()` bytes of `b` are compared; any extra bytes in `b`
/// are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
fn scalar(a: &[u8], b: &[u8]) -> u32 {
    let b = &b[..a.len()];

    // `chunks_exact(8)` guarantees every chunk is exactly 8 bytes, so the
    // array conversion cannot fail (no silent zero fallback needed).
    let word = |w: &[u8]| u64::from_ne_bytes(w.try_into().expect("8-byte chunk"));

    let mut a_words = a.chunks_exact(8);
    let mut b_words = b.chunks_exact(8);
    let mut count = 0u32;
    for (wa, wb) in (&mut a_words).zip(&mut b_words) {
        count += (word(wa) ^ word(wb)).count_ones();
    }

    // Trailing 0..=7 bytes that do not fill a whole word.
    for (&x, &y) in a_words.remainder().iter().zip(b_words.remainder()) {
        count += (x ^ y).count_ones();
    }
    count
}
39
40#[cfg(target_arch = "x86_64")]
41#[target_feature(enable = "avx512f,avx512vpopcntdq")]
42unsafe fn avx512(a: &[u8], b: &[u8]) -> u32 {
43    use std::arch::x86_64::*;
44    let mut count = 0u32;
45    let chunks = a.len() / 64;
46    let rem = a.len() % 64;
47    for i in 0..chunks {
48        let va = _mm512_loadu_si512(a.as_ptr().add(i * 64) as *const __m512i);
49        let vb = _mm512_loadu_si512(b.as_ptr().add(i * 64) as *const __m512i);
50        let xored = _mm512_xor_si512(va, vb);
51        let popcnt = _mm512_popcnt_epi64(xored);
52        let lo = _mm512_extracti64x4_epi64(popcnt, 0);
53        let hi = _mm512_extracti64x4_epi64(popcnt, 1);
54        let sum4 = _mm256_add_epi64(lo, hi);
55        let sum2 = _mm256_add_epi64(sum4, _mm256_permute4x64_epi64(sum4, 0b0100_1110));
56        let sum1 = _mm256_add_epi64(sum2, _mm256_permute4x64_epi64(sum2, 0b0001_0001));
57        count += _mm256_extract_epi64(sum1, 0) as u32;
58    }
59    let base = chunks * 64;
60    for i in 0..rem {
61        count += (a[base + i] ^ b[base + i]).count_ones();
62    }
63    count
64}
65
/// 64-bit popcount kernel compiled with the `avx2` target feature.
///
/// The loop uses scalar `u64::count_ones` on four words per iteration (4×
/// unroll). Enabling `avx2` only widens what the compiler may emit for it.
/// NOTE(review): confirm in codegen whether `count_ones` lowers to `POPCNT`
/// here; `avx2` may not imply the `popcnt` target feature.
///
/// Bytes of `b` past `a.len()` are ignored.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2(a: &[u8], b: &[u8]) -> u32 {
    // Decode one 8-byte chunk as a native-endian word. Copying into an array
    // is an unaligned load. Dereferencing a `*const u64` cast from a byte
    // pointer is UB whenever the slice is not 8-byte aligned.
    #[inline]
    fn word(bytes: &[u8]) -> u64 {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(bytes);
        u64::from_ne_bytes(buf)
    }

    let b = &b[..a.len()];
    let mut count = 0u32;

    // Main loop: 32-byte blocks, i.e. four words per iteration.
    let mut a_blocks = a.chunks_exact(32);
    let mut b_blocks = b.chunks_exact(32);
    for (xa, xb) in (&mut a_blocks).zip(&mut b_blocks) {
        let x0 = word(&xa[0..8]) ^ word(&xb[0..8]);
        let x1 = word(&xa[8..16]) ^ word(&xb[8..16]);
        let x2 = word(&xa[16..24]) ^ word(&xb[16..24]);
        let x3 = word(&xa[24..32]) ^ word(&xb[24..32]);
        count += x0.count_ones() + x1.count_ones() + x2.count_ones() + x3.count_ones();
    }

    // Up to three leftover whole words.
    let mut a_words = a_blocks.remainder().chunks_exact(8);
    let mut b_words = b_blocks.remainder().chunks_exact(8);
    for (wa, wb) in (&mut a_words).zip(&mut b_words) {
        count += (word(wa) ^ word(wb)).count_ones();
    }

    // Trailing 0..=7 bytes that do not fill a word.
    for (&x, &y) in a_words.remainder().iter().zip(b_words.remainder()) {
        count += (x ^ y).count_ones();
    }
    count
}
92
93#[cfg(target_arch = "x86_64")]
94fn avx512_trampoline(a: &[u8], b: &[u8]) -> u32 {
95    unsafe { avx512(a, b) }
96}
97
98#[cfg(target_arch = "x86_64")]
99fn avx2_trampoline(a: &[u8], b: &[u8]) -> u32 {
100    unsafe { avx2(a, b) }
101}
102
/// NEON kernel: XORs 16-byte blocks, counts bits per byte with `vcntq_u8`, and
/// sums the 16 byte counts with `vaddvq_u8`.
///
/// Bytes of `b` past `a.len()` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`. The check happens before any unchecked
/// vector load, so a short `b` can never be read out of bounds.
#[cfg(target_arch = "aarch64")]
fn neon(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::aarch64::*;

    let b = &b[..a.len()];
    let mut a_blocks = a.chunks_exact(16);
    let mut b_blocks = b.chunks_exact(16);
    let mut count = 0u32;
    for (ca, cb) in (&mut a_blocks).zip(&mut b_blocks) {
        // SAFETY: `ca` and `cb` are exactly 16 bytes, so each `vld1q_u8` reads
        // in bounds. Assumes the `neon` target feature is available: this
        // module dispatches to this kernel on aarch64 without a runtime probe.
        let block_bits = unsafe {
            let xored = veorq_u8(vld1q_u8(ca.as_ptr()), vld1q_u8(cb.as_ptr()));
            vaddvq_u8(vcntq_u8(xored))
        };
        // At most 16 * 8 = 128 set bits per block, so the u8 sum cannot wrap.
        count += u32::from(block_bits);
    }

    // Trailing 0..=15 bytes.
    for (&x, &y) in a_blocks.remainder().iter().zip(b_blocks.remainder()) {
        count += (x ^ y).count_ones();
    }
    count
}
124
125fn resolve() -> HammingFn {
126    #[cfg(target_arch = "x86_64")]
127    {
128        if std::is_x86_feature_detected!("avx512f")
129            && std::is_x86_feature_detected!("avx512vpopcntdq")
130        {
131            return avx512_trampoline;
132        }
133        if std::is_x86_feature_detected!("avx2") {
134            return avx2_trampoline;
135        }
136    }
137    #[cfg(target_arch = "aarch64")]
138    {
139        return neon;
140    }
141    #[allow(unreachable_code)]
142    {
143        scalar
144    }
145}
146
147/// Count the number of differing bits between two equal-length packed-bit
148/// slices, dispatching to the best available SIMD kernel at runtime.
149///
150/// In debug builds, panics if `a.len() != b.len()`. In release builds, the
151/// shorter length is used and the tail of the longer slice is ignored.
152#[inline]
153pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
154    debug_assert_eq!(a.len(), b.len(), "hamming_distance: slice length mismatch");
155    let len = a.len().min(b.len());
156    let f = DISPATCH.get_or_init(resolve);
157    f(&a[..len], &b[..len])
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_distance_to_self() {
        let bits = vec![0b1010_1010u8, 0b1100_1100, 0b1111_0000];
        assert_eq!(hamming_distance(&bits, &bits), 0);
    }

    #[test]
    fn full_inversion_byte() {
        let a = vec![0xFFu8];
        let b = vec![0x00u8];
        assert_eq!(hamming_distance(&a, &b), 8);
    }

    #[test]
    fn full_inversion_multibyte() {
        let dim = 64;
        let a: Vec<u8> = (0..dim as u8).collect();
        let b: Vec<u8> = a.iter().map(|&x| !x).collect();
        assert_eq!(hamming_distance(&a, &b), 512);
    }

    #[test]
    fn agrees_across_lengths() {
        // Cross-check dispatched kernel against scalar reference for various
        // lengths so any SIMD path divergence is caught.
        for len in [1usize, 7, 8, 9, 15, 16, 31, 32, 63, 64, 65, 127, 128, 255] {
            let a: Vec<u8> = (0..len).map(|i| (i * 31 + 7) as u8).collect();
            let b: Vec<u8> = (0..len).map(|i| (i * 17 + 3) as u8).collect();
            assert_eq!(hamming_distance(&a, &b), scalar(&a, &b), "len={len}");
        }
    }

    #[test]
    fn empty_slices_have_zero_distance() {
        assert_eq!(hamming_distance(&[], &[]), 0);
    }

    #[test]
    fn agrees_on_unaligned_subslices() {
        // Slicing at odd offsets into a larger buffer hands the kernels
        // pointers that are not 8-, 16- or 64-byte aligned. Fresh `Vec`s are
        // usually well aligned, so the test above would not catch kernels
        // that mishandle misaligned input.
        let a_buf: Vec<u8> = (0..300usize).map(|i| (i * 29 + 1) as u8).collect();
        let b_buf: Vec<u8> = (0..300usize).map(|i| (i * 53 + 9) as u8).collect();
        for off in 1..8 {
            for len in [1usize, 8, 31, 33, 64, 65, 129, 200] {
                let a = &a_buf[off..off + len];
                let b = &b_buf[off..off + len];
                assert_eq!(hamming_distance(a, b), scalar(a, b), "off={off} len={len}");
            }
        }
    }

    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "slice length mismatch")]
    fn length_mismatch_panics_in_debug() {
        hamming_distance(&[0u8; 4], &[0u8; 3]);
    }

    #[test]
    #[cfg(not(debug_assertions))]
    fn length_mismatch_truncates_in_release() {
        // Only the first 3 bytes are compared; the extra byte of `a` is ignored.
        assert_eq!(hamming_distance(&[0xFFu8; 4], &[0x00u8; 3]), 24);
    }
}