use crate::compression::ADCTable;

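/// Number of database codes scored per FastScan batch; the NEON and AVX2
/// kernels below consume exactly 32 interleaved codes per iteration.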
pub const BATCH_SIZE: usize = 32;

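/// Quantized lookup tables for 4-bit FastScan ADC scoring.
///
/// The f32 ADC distances are quantized to `u8` so that per-code partial
/// distances can be gathered with byte shuffles and accumulated in `u16`
/// without overflow.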
#[derive(Debug, Clone)]
pub struct FastScanLUT {
    /// One 16-entry table per sub-quantizer pair for the low nibble of each code.
    luts_lo: Vec<[u8; 16]>,

    /// One 16-entry table per sub-quantizer pair for the high nibble of each code.
    luts_hi: Vec<[u8; 16]>,

    /// Inverse of the quantization scale factor, used to map `u16` scores back to f32.
    scale: f32,

    /// Minimum per-pair distance removed from each LUT pair during quantization.
    offset: f32,
}

impl FastScanLUT {
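    /// Builds FastScan LUTs from a 4-bit [`ADCTable`].
    ///
    /// Returns `None` if the table is not 4-bit or if its dimension count is
    /// zero or odd. Entries are scaled so that the sum of all quantized
    /// nibbles for one code fits in a `u16`.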
    #[must_use]
    pub fn from_adc_table(adc: &ADCTable) -> Option<Self> {
        if adc.bits() != 4 {
            return None;
        }

        let dimensions = adc.dimensions();
        if dimensions == 0 || !dimensions.is_multiple_of(2) {
            return None;
        }

        let num_sq = dimensions / 2;

        let mut global_min = f32::MAX;
        let mut global_max = f32::MIN;

        for sq in 0..num_sq {
            for lo_code in 0..16 {
                for hi_code in 0..16 {
                    let dist_lo = adc.get(sq * 2, lo_code);
                    let dist_hi = adc.get(sq * 2 + 1, hi_code);
                    let sum = dist_lo + dist_hi;
                    global_min = global_min.min(sum);
                    global_max = global_max.max(sum);
                }
            }
        }

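        // Cap each quantized nibble so that 2 * num_sq of them can never
        // overflow the u16 accumulator; the extra cap at 127 also keeps every
        // LUT entry within the positive i8 range.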
        let safe_max_per_nibble = (65535.0 / (num_sq * 2) as f32).floor().min(127.0);

        let range = global_max - global_min;
        let scale_factor = if range > 1e-7 {
            safe_max_per_nibble / (range / 2.0)
        } else {
            1.0
        };

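        // Fold the minimum pair distance into the tables: each nibble gives up
        // offset / 2, so one low/high pair removes the full offset, and
        // `to_f32` adds it back once per pair.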
        let offset = global_min;

        let mut luts_lo = Vec::with_capacity(num_sq);
        let mut luts_hi = Vec::with_capacity(num_sq);

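        // Quantize each sub-quantizer's 16 low-nibble and 16 high-nibble
        // distances into u8 LUT entries.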
        for sq in 0..num_sq {
            let dim_lo = sq * 2;
            let dim_hi = sq * 2 + 1;

            let mut lut_lo = [0u8; 16];
            for (code, entry) in lut_lo.iter_mut().enumerate() {
                let dist = adc.get(dim_lo, code);
                *entry = ((dist - offset / 2.0) * scale_factor)
                    .round()
                    .clamp(0.0, safe_max_per_nibble) as u8;
            }

            let mut lut_hi = [0u8; 16];
            for (code, entry) in lut_hi.iter_mut().enumerate() {
                let dist = adc.get(dim_hi, code);
                *entry = ((dist - offset / 2.0) * scale_factor)
                    .round()
                    .clamp(0.0, safe_max_per_nibble) as u8;
            }

            luts_lo.push(lut_lo);
            luts_hi.push(lut_hi);
        }

        Some(Self {
            luts_lo,
            luts_hi,
            scale: 1.0 / scale_factor,
            offset,
        })
    }
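    /// Number of sub-quantizer pairs (two ADC dimensions per pair).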
    #[must_use]
    pub fn num_sq(&self) -> usize {
        self.luts_lo.len()
    }

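    /// Low-nibble LUTs, one 16-entry table per sub-quantizer pair.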
    #[must_use]
    pub fn luts_lo(&self) -> &[[u8; 16]] {
        &self.luts_lo
    }

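    /// High-nibble LUTs, one 16-entry table per sub-quantizer pair.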
    #[must_use]
    pub fn luts_hi(&self) -> &[[u8; 16]] {
        &self.luts_hi
    }

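    /// Converts an accumulated `u16` score back to an approximate f32 distance.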
    #[must_use]
    pub fn to_f32(&self, accumulated: u16) -> f32 {
        // Quantization removed `offset` from each of the `num_sq` code pairs
        // (offset / 2 per nibble), so it is added back once per pair here.
        accumulated as f32 * self.scale + self.offset * self.num_sq() as f32
    }
}
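/// FastScan kernel for aarch64 NEON: scores one batch of `BATCH_SIZE`
/// interleaved 4-bit codes against the per-sub-quantizer LUTs. The low nibble
/// of each code indexes `luts_lo`, the high nibble `luts_hi`.
///
/// `interleaved_codes` must provide `BATCH_SIZE` bytes per sub-quantizer;
/// this is asserted before the unchecked SIMD loads.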
#[cfg(target_arch = "aarch64")]
#[must_use]
pub fn fastscan_batch_neon(
    luts_lo: &[[u8; 16]],
    luts_hi: &[[u8; 16]],
    interleaved_codes: &[u8],
) -> [u16; BATCH_SIZE] {
    use std::arch::aarch64::{
        uint16x8_t, vaddl_u8, vaddq_u16, vandq_u8, vdupq_n_u16, vdupq_n_u8, vget_high_u8,
        vget_low_u8, vld1q_u8, vqtbl1q_u8, vshrq_n_u8, vst1q_u16,
    };

    // The unchecked loads below read BATCH_SIZE bytes per sub-quantizer.
    assert!(
        interleaved_codes.len() >= luts_lo.len() * BATCH_SIZE,
        "interleaved_codes must hold BATCH_SIZE codes per sub-quantizer"
    );

    unsafe {
        let low_mask = vdupq_n_u8(0x0F);

        // Four u16x8 accumulators cover the 32 neighbors of the batch.
        let mut accum0: uint16x8_t = vdupq_n_u16(0);
        let mut accum1: uint16x8_t = vdupq_n_u16(0);
        let mut accum2: uint16x8_t = vdupq_n_u16(0);
        let mut accum3: uint16x8_t = vdupq_n_u16(0);

        for sq in 0..luts_lo.len() {
            let base = sq * BATCH_SIZE;

            let lut_lo_vec = vld1q_u8(luts_lo[sq].as_ptr());
            let lut_hi_vec = vld1q_u8(luts_hi[sq].as_ptr());

            let codes_0_15 = vld1q_u8(interleaved_codes.as_ptr().add(base));
            let codes_16_31 = vld1q_u8(interleaved_codes.as_ptr().add(base + 16));

            // Low nibbles index the low-nibble LUT via a table lookup.
            let idx_lo_0 = vandq_u8(codes_0_15, low_mask);
            let idx_lo_1 = vandq_u8(codes_16_31, low_mask);
            let vals_lo_0 = vqtbl1q_u8(lut_lo_vec, idx_lo_0);
            let vals_lo_1 = vqtbl1q_u8(lut_lo_vec, idx_lo_1);

            // High nibbles (shifted down by 4) index the high-nibble LUT.
            let idx_hi_0 = vshrq_n_u8::<4>(codes_0_15);
            let idx_hi_1 = vshrq_n_u8::<4>(codes_16_31);
            let vals_hi_0 = vqtbl1q_u8(lut_hi_vec, idx_hi_0);
            let vals_hi_1 = vqtbl1q_u8(lut_hi_vec, idx_hi_1);

            // Widen the u8 partial distances to u16 and accumulate.
            accum0 = vaddq_u16(
                accum0,
                vaddl_u8(vget_low_u8(vals_lo_0), vget_low_u8(vals_hi_0)),
            );
            accum1 = vaddq_u16(
                accum1,
                vaddl_u8(vget_high_u8(vals_lo_0), vget_high_u8(vals_hi_0)),
            );
            accum2 = vaddq_u16(
                accum2,
                vaddl_u8(vget_low_u8(vals_lo_1), vget_low_u8(vals_hi_1)),
            );
            accum3 = vaddq_u16(
                accum3,
                vaddl_u8(vget_high_u8(vals_lo_1), vget_high_u8(vals_hi_1)),
            );
        }

        let mut results = [0u16; BATCH_SIZE];
        vst1q_u16(results.as_mut_ptr(), accum0);
        vst1q_u16(results.as_mut_ptr().add(8), accum1);
        vst1q_u16(results.as_mut_ptr().add(16), accum2);
        vst1q_u16(results.as_mut_ptr().add(24), accum3);

        results
    }
}
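/// FastScan kernel for x86_64 AVX2 with the same contract as
/// [`fastscan_batch_neon`]; falls back to [`fastscan_batch_scalar`] when AVX2
/// is not detected at runtime.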
#[cfg(target_arch = "x86_64")]
#[allow(clippy::cast_ptr_alignment)]
#[must_use]
pub fn fastscan_batch_avx2(
    luts_lo: &[[u8; 16]],
    luts_hi: &[[u8; 16]],
    interleaved_codes: &[u8],
) -> [u16; BATCH_SIZE] {
    use std::arch::x86_64::{
        __m128i, __m256i, _mm256_add_epi16, _mm256_and_si256, _mm256_broadcastsi128_si256,
        _mm256_castsi256_si128, _mm256_cvtepu8_epi16, _mm256_extracti128_si256,
        _mm256_loadu_si256, _mm256_set1_epi8, _mm256_setzero_si256, _mm256_shuffle_epi8,
        _mm256_srli_epi16, _mm256_storeu_si256, _mm_loadu_si128,
    };

    if !std::is_x86_feature_detected!("avx2") {
        return fastscan_batch_scalar(luts_lo, luts_hi, interleaved_codes);
    }

    // The unchecked loads below read BATCH_SIZE bytes per sub-quantizer.
    assert!(
        interleaved_codes.len() >= luts_lo.len() * BATCH_SIZE,
        "interleaved_codes must hold BATCH_SIZE codes per sub-quantizer"
    );

    unsafe {
        let low_mask = _mm256_set1_epi8(0x0F);

        // Two u16x16 accumulators cover the 32 neighbors of the batch.
        let mut accum_lo = _mm256_setzero_si256();
        let mut accum_hi = _mm256_setzero_si256();

        for sq in 0..luts_lo.len() {
            let base = sq * BATCH_SIZE;

            // Broadcast the 16-byte LUTs into both 128-bit lanes so
            // `_mm256_shuffle_epi8` sees the same table in each lane.
            let lut_lo_128 = _mm_loadu_si128(luts_lo[sq].as_ptr() as *const __m128i);
            let lut_hi_128 = _mm_loadu_si128(luts_hi[sq].as_ptr() as *const __m128i);
            let lut_lo_vec = _mm256_broadcastsi128_si256(lut_lo_128);
            let lut_hi_vec = _mm256_broadcastsi128_si256(lut_hi_128);

            let codes = _mm256_loadu_si256(interleaved_codes.as_ptr().add(base) as *const __m256i);

            // Low nibbles index the low-nibble LUT via byte shuffle.
            let idx_lo = _mm256_and_si256(codes, low_mask);
            let vals_lo = _mm256_shuffle_epi8(lut_lo_vec, idx_lo);

            // High nibbles (shifted down by 4, then masked) index the high-nibble LUT.
            let idx_hi = _mm256_and_si256(_mm256_srli_epi16::<4>(codes), low_mask);
            let vals_hi = _mm256_shuffle_epi8(lut_hi_vec, idx_hi);

            // Widen the per-nibble values for neighbors 0..16 to u16 and accumulate.
            let vals_lo_128 = _mm256_castsi256_si128(vals_lo);
            let vals_hi_128 = _mm256_castsi256_si128(vals_hi);
            let sum_lo_16 = _mm256_cvtepu8_epi16(vals_lo_128);
            let sum_hi_16 = _mm256_cvtepu8_epi16(vals_hi_128);
            accum_lo = _mm256_add_epi16(accum_lo, sum_lo_16);
            accum_lo = _mm256_add_epi16(accum_lo, sum_hi_16);

            // Same for neighbors 16..32 (the upper 128-bit lane).
            let vals_lo_high = _mm256_extracti128_si256::<1>(vals_lo);
            let vals_hi_high = _mm256_extracti128_si256::<1>(vals_hi);
            let sum_lo_high_16 = _mm256_cvtepu8_epi16(vals_lo_high);
            let sum_hi_high_16 = _mm256_cvtepu8_epi16(vals_hi_high);
            accum_hi = _mm256_add_epi16(accum_hi, sum_lo_high_16);
            accum_hi = _mm256_add_epi16(accum_hi, sum_hi_high_16);
        }

        let mut results = [0u16; BATCH_SIZE];
        _mm256_storeu_si256(results.as_mut_ptr() as *mut __m256i, accum_lo);
        _mm256_storeu_si256(results.as_mut_ptr().add(16) as *mut __m256i, accum_hi);
        results
    }
}
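/// Portable scalar reference implementation of the FastScan batch kernel;
/// also used as the fallback when no SIMD path is available.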
#[must_use]
pub fn fastscan_batch_scalar(
    luts_lo: &[[u8; 16]],
    luts_hi: &[[u8; 16]],
    interleaved_codes: &[u8],
) -> [u16; BATCH_SIZE] {
    let mut results = [0u16; BATCH_SIZE];

    for (sq, (lut_lo, lut_hi)) in luts_lo.iter().zip(luts_hi.iter()).enumerate() {
        let base = sq * BATCH_SIZE;
        for n in 0..BATCH_SIZE {
            let code = interleaved_codes[base + n];
            let lo_idx = (code & 0x0F) as usize;
            let hi_idx = ((code >> 4) & 0x0F) as usize;
            results[n] += lut_lo[lo_idx] as u16 + lut_hi[hi_idx] as u16;
        }
    }

    results
}
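/// Scores one batch of `BATCH_SIZE` interleaved 4-bit codes, dispatching to
/// the NEON, AVX2, or scalar kernel depending on the target architecture.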
#[inline]
#[must_use]
pub fn fastscan_batch(
    luts_lo: &[[u8; 16]],
    luts_hi: &[[u8; 16]],
    interleaved_codes: &[u8],
) -> [u16; BATCH_SIZE] {
    #[cfg(target_arch = "aarch64")]
    {
        fastscan_batch_neon(luts_lo, luts_hi, interleaved_codes)
    }
    #[cfg(target_arch = "x86_64")]
    {
        fastscan_batch_avx2(luts_lo, luts_hi, interleaved_codes)
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        fastscan_batch_scalar(luts_lo, luts_hi, interleaved_codes)
    }
}
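/// Convenience wrapper that runs [`fastscan_batch`] with the LUTs stored in a
/// [`FastScanLUT`].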
#[inline]
#[must_use]
pub fn fastscan_batch_with_lut(lut: &FastScanLUT, interleaved_codes: &[u8]) -> [u16; BATCH_SIZE] {
    fastscan_batch(lut.luts_lo(), lut.luts_hi(), interleaved_codes)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fastscan_scalar() {
        let luts_lo: Vec<[u8; 16]> = (0..4).map(|_| core::array::from_fn(|i| i as u8)).collect();
        let luts_hi: Vec<[u8; 16]> = (0..4).map(|_| core::array::from_fn(|i| i as u8)).collect();

        let codes = vec![0u8; 4 * BATCH_SIZE];

        let results = fastscan_batch_scalar(&luts_lo, &luts_hi, &codes);

        for &r in &results {
            assert_eq!(r, 0);
        }
    }

    #[test]
    fn test_fastscan_scalar_nonzero() {
        let luts_lo: Vec<[u8; 16]> = (0..2).map(|_| core::array::from_fn(|i| i as u8)).collect();
        let luts_hi: Vec<[u8; 16]> = (0..2).map(|_| core::array::from_fn(|i| i as u8)).collect();

        let mut codes = vec![0u8; 2 * BATCH_SIZE];
        codes[0] = 0x11;
        codes[BATCH_SIZE] = 0x22;

        let results = fastscan_batch_scalar(&luts_lo, &luts_hi, &codes);

        // Neighbor 0 with identity LUTs: (1 + 1) from sq 0 plus (2 + 2) from sq 1 = 6.
        assert_eq!(results[0], 6);
        assert_eq!(results[1], 0);
    }

    #[test]
    fn test_fastscan_matches_scalar() {
        let luts_lo: Vec<[u8; 16]> = (0..8)
            .map(|sq| core::array::from_fn(|i| ((sq * 17 + i * 7) % 100) as u8))
            .collect();
        let luts_hi: Vec<[u8; 16]> = (0..8)
            .map(|sq| core::array::from_fn(|i| ((sq * 13 + i * 11) % 100) as u8))
            .collect();

        let mut codes = vec![0u8; 8 * BATCH_SIZE];
        for (i, code) in codes.iter_mut().enumerate() {
            *code = ((i * 31 + 17) % 256) as u8;
        }

        let scalar_results = fastscan_batch_scalar(&luts_lo, &luts_hi, &codes);
        let simd_results = fastscan_batch(&luts_lo, &luts_hi, &codes);

        for (i, (&scalar, &simd)) in scalar_results.iter().zip(simd_results.iter()).enumerate() {
            assert_eq!(
                scalar, simd,
                "Mismatch at neighbor {i}: scalar={scalar}, simd={simd}"
            );
        }
    }
}