nodedb_codec/vector_quant/hamming.rs

#![allow(unsafe_op_in_unsafe_fn)]

use std::sync::OnceLock;

/// Signature shared by every Hamming-distance kernel in this module.
type HammingFn = fn(a: &[u8], b: &[u8]) -> u32;

/// Kernel chosen on the first call to `hamming_distance`; see `resolve()`.
static DISPATCH: OnceLock<HammingFn> = OnceLock::new();
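
/// Portable fallback: XOR eight bytes at a time and popcount via `count_ones`.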
fn scalar(a: &[u8], b: &[u8]) -> u32 {
    let chunks = a.len() / 8;
    let rem = a.len() % 8;
    let mut count = 0u32;
    for i in 0..chunks {
        // Each sub-slice is exactly 8 bytes, so the conversion cannot fail.
        let av = u64::from_ne_bytes(a[i * 8..i * 8 + 8].try_into().unwrap());
        let bv = u64::from_ne_bytes(b[i * 8..i * 8 + 8].try_into().unwrap());
        count += (av ^ bv).count_ones();
    }
    let base = chunks * 8;
    for i in 0..rem {
        count += (a[base + i] ^ b[base + i]).count_ones();
    }
    count
}

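/// AVX-512 kernel: processes 64 bytes per iteration with VPOPCNTQ.
///
/// # Safety
/// The caller must have verified `avx512f` and `avx512vpopcntdq` support at
/// runtime (done by `resolve()` before this kernel is installed).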
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512vpopcntdq")]
unsafe fn avx512(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::x86_64::*;
    let mut count = 0u32;
    let chunks = a.len() / 64;
    let rem = a.len() % 64;
    for i in 0..chunks {
        let va = _mm512_loadu_si512(a.as_ptr().add(i * 64) as *const __m512i);
        let vb = _mm512_loadu_si512(b.as_ptr().add(i * 64) as *const __m512i);
        let xored = _mm512_xor_si512(va, vb);
        // Per-lane 64-bit popcounts, then a horizontal reduction of the eight lanes.
        let popcnt = _mm512_popcnt_epi64(xored);
        let lo = _mm512_extracti64x4_epi64::<0>(popcnt);
        let hi = _mm512_extracti64x4_epi64::<1>(popcnt);
        let sum4 = _mm256_add_epi64(lo, hi);
        // Add the swapped 128-bit halves, then the swapped 64-bit lanes; lane 0 holds the total.
        let sum2 = _mm256_add_epi64(sum4, _mm256_permute4x64_epi64::<0b0100_1110>(sum4));
        let sum1 = _mm256_add_epi64(sum2, _mm256_permute4x64_epi64::<0b0001_0001>(sum2));
        count += _mm256_extract_epi64::<0>(sum1) as u32;
    }
    let base = chunks * 64;
    for i in 0..rem {
        count += (a[base + i] ^ b[base + i]).count_ones();
    }
    count
}

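/// AVX2 tier: unrolled 64-bit XOR plus `count_ones` (hardware POPCNT on
/// AVX2-era CPUs), reading through unaligned 64-bit loads.
///
/// # Safety
/// The caller must have verified `avx2` support at runtime (done by
/// `resolve()` before this kernel is installed).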
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2(a: &[u8], b: &[u8]) -> u32 {
    use std::ptr::read_unaligned;
    let mut count = 0u32;
    let n8 = a.len() / 8;
    let ap = a.as_ptr() as *const u64;
    let bp = b.as_ptr() as *const u64;
    let mut i = 0usize;
    while i + 4 <= n8 {
        // The byte slices carry no 8-byte alignment guarantee, so plain
        // dereferences would be undefined behaviour; use unaligned reads.
        let x0 = read_unaligned(ap.add(i)) ^ read_unaligned(bp.add(i));
        let x1 = read_unaligned(ap.add(i + 1)) ^ read_unaligned(bp.add(i + 1));
        let x2 = read_unaligned(ap.add(i + 2)) ^ read_unaligned(bp.add(i + 2));
        let x3 = read_unaligned(ap.add(i + 3)) ^ read_unaligned(bp.add(i + 3));
        count += x0.count_ones() + x1.count_ones() + x2.count_ones() + x3.count_ones();
        i += 4;
    }
    while i < n8 {
        count += (read_unaligned(ap.add(i)) ^ read_unaligned(bp.add(i))).count_ones();
        i += 1;
    }
    let tail_start = n8 * 8;
    for j in tail_start..a.len() {
        count += (a[j] ^ b[j]).count_ones();
    }
    count
}

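// Safe trampolines: `DISPATCH` stores plain `fn` pointers, so each unsafe
// target-feature kernel is wrapped in a safe function. The only precondition
// (CPU support) is established by `resolve()` before a trampoline is installed.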
#[cfg(target_arch = "x86_64")]
fn avx512_trampoline(a: &[u8], b: &[u8]) -> u32 {
    unsafe { avx512(a, b) }
}

#[cfg(target_arch = "x86_64")]
fn avx2_trampoline(a: &[u8], b: &[u8]) -> u32 {
    unsafe { avx2(a, b) }
}

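/// NEON kernel: 16 bytes per iteration via CNT plus a horizontal add. NEON is
/// part of the aarch64 baseline, so no runtime detection or `target_feature`
/// gate is needed and the function can stay safe.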
#[cfg(target_arch = "aarch64")]
fn neon(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::aarch64::*;
    let mut count = 0u32;
    let chunks = a.len() / 16;
    let rem = a.len() % 16;
    unsafe {
        for i in 0..chunks {
            let va = vld1q_u8(a.as_ptr().add(i * 16));
            let vb = vld1q_u8(b.as_ptr().add(i * 16));
            let xored = veorq_u8(va, vb);
            // Per-byte popcount; the sum over 16 lanes is at most 128, so it fits in a u8.
            let popcnt = vcntq_u8(xored);
            count += vaddvq_u8(popcnt) as u32;
        }
    }
    let base = chunks * 16;
    for i in 0..rem {
        count += (a[base + i] ^ b[base + i]).count_ones();
    }
    count
}

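/// Picks the widest kernel the running CPU supports. Called at most once
/// through `DISPATCH`, so the feature-detection cost is paid a single time.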
fn resolve() -> HammingFn {
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512vpopcntdq")
        {
            return avx512_trampoline;
        }
        if std::is_x86_feature_detected!("avx2") {
            return avx2_trampoline;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        return neon;
    }
    #[allow(unreachable_code)]
    {
        scalar
    }
}

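/// Hamming distance (number of differing bits) between two bit-packed vectors,
/// e.g. `hamming_distance(&[0xFF], &[0x0F]) == 4`.
///
/// Debug builds assert that the lengths match; release builds compare only the
/// common prefix of the two slices.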
#[inline]
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len(), "hamming_distance: slice length mismatch");
    let len = a.len().min(b.len());
    let f = DISPATCH.get_or_init(resolve);
    f(&a[..len], &b[..len])
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_distance_to_self() {
        let bits = vec![0b1010_1010u8, 0b1100_1100, 0b1111_0000];
        assert_eq!(hamming_distance(&bits, &bits), 0);
    }

    #[test]
    fn full_inversion_byte() {
        let a = vec![0xFFu8];
        let b = vec![0x00u8];
        assert_eq!(hamming_distance(&a, &b), 8);
    }

    #[test]
    fn full_inversion_multibyte() {
        // Inverting every bit of a 64-byte vector flips 64 * 8 = 512 bits.
        let dim = 64;
        let a: Vec<u8> = (0..dim as u8).collect();
        let b: Vec<u8> = a.iter().map(|&x| !x).collect();
        assert_eq!(hamming_distance(&a, &b), 512);
    }

    #[test]
    fn agrees_across_lengths() {
        // Lengths straddle every kernel's chunk size (8, 16, 64 bytes) and its tail path.
        for len in [1usize, 7, 8, 9, 15, 16, 31, 32, 63, 64, 65, 127, 128, 255] {
            let a: Vec<u8> = (0..len).map(|i| (i * 31 + 7) as u8).collect();
            let b: Vec<u8> = (0..len).map(|i| (i * 17 + 3) as u8).collect();
            assert_eq!(hamming_distance(&a, &b), scalar(&a, &b), "len={len}");
        }
    }
}