Skip to main content

jxl_encoder_simd/
quantize.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! SIMD-accelerated AC coefficient quantization.
6//!
7//! Two kernels:
8//! - `quantize_block_dct8`: Fixed 64-element DCT8 path (~4% of encoder CPU)
9//! - `quantize_block_large`: Generic path for DCT16+ blocks (128–4096 coefficients)
10//!
11//! Both use dead-zone thresholding: coefficients below a per-quadrant threshold
12//! are zeroed. SIMD processes 8 (AVX2) or 4 (NEON/WASM) coefficients at a time.
13
14/// Quantize a DCT8 block (64 coefficients) with dead-zone thresholding.
15///
16/// For each coefficient `i` (except DC at index 0):
17///   `val = dct_coeffs[i] / weights[i] * qac_qm`
18///   if `|val| < threshold[quadrant]`: output 0
19///   else: output round(val) as i32
20///
21/// DC (index 0) is always set to 0 (handled separately by LLF coding).
22///
23/// `thresholds` are the 4 quadrant thresholds:
24///   `[0]` = top-left (y<4, x<4),
25///   `[1]` = top-right (y<4, x>=4),
26///   `[2]` = bottom-left (y>=4, x<4),
27///   `[3]` = bottom-right (y>=4, x>=4)
28#[inline]
29pub fn quantize_block_dct8(
30    dct_coeffs: &[f32; 64],
31    weights: &[f32; 64],
32    qac_qm: f32,
33    thresholds: &[f32; 4],
34    output: &mut [i32; 64],
35) {
36    #[cfg(target_arch = "x86_64")]
37    {
38        use archmage::SimdToken;
39        if let Some(token) = archmage::X64V3Token::summon() {
40            quantize_dct8_avx2(token, dct_coeffs, weights, qac_qm, thresholds, output);
41            return;
42        }
43    }
44
45    #[cfg(target_arch = "aarch64")]
46    {
47        use archmage::SimdToken;
48        if let Some(token) = archmage::NeonToken::summon() {
49            quantize_dct8_neon(token, dct_coeffs, weights, qac_qm, thresholds, output);
50            return;
51        }
52    }
53
54    #[cfg(target_arch = "wasm32")]
55    {
56        use archmage::SimdToken;
57        if let Some(token) = archmage::Wasm128Token::summon() {
58            quantize_dct8_wasm128(token, dct_coeffs, weights, qac_qm, thresholds, output);
59            return;
60        }
61    }
62
63    quantize_dct8_scalar(dct_coeffs, weights, qac_qm, thresholds, output);
64}
65
/// Portable reference implementation of DCT8 dead-zone quantization.
///
/// Index 0 (DC) is written as 0. Every other coefficient is scaled as
/// `dct_coeffs[i] * (1 / weights[i]) * qac_qm`. It becomes 0 when its magnitude is below the
/// threshold of its 4x4 quadrant; otherwise it is rounded half-to-even and cast to `i32`.
#[inline]
pub fn quantize_dct8_scalar(
    dct_coeffs: &[f32; 64],
    weights: &[f32; 64],
    qac_qm: f32,
    thresholds: &[f32; 4],
    output: &mut [i32; 64],
) {
    // DC is coded separately.
    output[0] = 0;

    for row in 0..8usize {
        // Rows 0..4 use the top quadrants (0/1), rows 4..8 the bottom ones (2/3).
        let vertical = if row < 4 { 0 } else { 2 };
        for col in 0..8usize {
            let pos = row * 8 + col;
            if pos == 0 {
                continue;
            }
            let quadrant = vertical + usize::from(col >= 4);
            let scaled = dct_coeffs[pos] * (1.0 / weights[pos]) * qac_qm;
            // Keep the `<` form: a NaN threshold must fall through to the rounding branch.
            output[pos] = if scaled.abs() < thresholds[quadrant] {
                0
            } else {
                scaled.round_ties_even() as i32
            };
        }
    }
}
87
88#[cfg(target_arch = "x86_64")]
89#[inline]
90#[archmage::arcane]
91pub fn quantize_dct8_avx2(
92    token: archmage::X64V3Token,
93    dct_coeffs: &[f32; 64],
94    weights: &[f32; 64],
95    qac_qm: f32,
96    thresholds: &[f32; 4],
97    output: &mut [i32; 64],
98) {
99    use magetypes::simd::f32x8;
100
101    let qac_qm_v = f32x8::splat(token, qac_qm);
102    let zero_f = f32x8::zero(token);
103
104    // Pre-build threshold vectors for each row group:
105    // Rows 0-3: [t[0], t[0], t[0], t[0], t[1], t[1], t[1], t[1]]
106    // Rows 4-7: [t[2], t[2], t[2], t[2], t[3], t[3], t[3], t[3]]
107    let thr_top = f32x8::from_array(
108        token,
109        [
110            thresholds[0],
111            thresholds[0],
112            thresholds[0],
113            thresholds[0],
114            thresholds[1],
115            thresholds[1],
116            thresholds[1],
117            thresholds[1],
118        ],
119    );
120    let thr_bot = f32x8::from_array(
121        token,
122        [
123            thresholds[2],
124            thresholds[2],
125            thresholds[2],
126            thresholds[2],
127            thresholds[3],
128            thresholds[3],
129            thresholds[3],
130            thresholds[3],
131        ],
132    );
133
134    // Process 8 chunks of 8 elements (one row each)
135    for chunk in 0..8 {
136        let base = chunk * 8;
137        let coeffs = f32x8::from_slice(token, &dct_coeffs[base..]);
138        let w = f32x8::from_slice(token, &weights[base..]);
139        let thr = if chunk < 4 { thr_top } else { thr_bot };
140
141        // val = coeffs / weights * qac_qm
142        let val = coeffs / w * qac_qm_v;
143
144        // Dead-zone thresholding: if |val| < thr, output 0
145        let abs_val = val.abs();
146        let mask = abs_val.simd_ge(thr); // all-ones where |val| >= threshold
147
148        // Round and select (0 where below threshold)
149        let rounded = val.round();
150        let result = f32x8::blend(mask, rounded, zero_f);
151
152        // Convert to i32 (truncate — result is already at integer values)
153        let result_i32 = result.to_i32x8();
154        result_i32.store((&mut output[base..base + 8]).try_into().unwrap());
155    }
156
157    // DC is always 0 (overwrite whatever SIMD produced for index 0)
158    output[0] = 0;
159}
160
161// --- aarch64 NEON implementation ---
162
/// AArch64 NEON kernel for [`quantize_block_dct8`].
///
/// Walks the block as 16 half-rows of 4 lanes. Half-row `q` covers row `q / 2`, columns
/// `(q % 2) * 4 .. +4`. Every half-row therefore lies inside a single quadrant and uses one
/// splatted threshold.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
pub fn quantize_dct8_neon(
    token: archmage::NeonToken,
    dct_coeffs: &[f32; 64],
    weights: &[f32; 64],
    qac_qm: f32,
    thresholds: &[f32; 4],
    output: &mut [i32; 64],
) {
    use magetypes::simd::f32x4;

    let scale = f32x4::splat(token, qac_qm);
    let zeros = f32x4::zero(token);

    let [t0, t1, t2, t3] = *thresholds;
    let quadrant_thr = [
        f32x4::splat(token, t0),
        f32x4::splat(token, t1),
        f32x4::splat(token, t2),
        f32x4::splat(token, t3),
    ];

    let halves = dct_coeffs
        .chunks_exact(4)
        .zip(weights.chunks_exact(4))
        .zip(output.chunks_exact_mut(4));
    for (q, ((c4, w4), o4)) in halves.enumerate() {
        // q >= 8 means row >= 4 (bottom quadrants); odd q is the right half of its row.
        let quadrant = (if q < 8 { 0 } else { 2 }) + (q & 1);

        let val = f32x4::from_slice(token, c4) / f32x4::from_slice(token, w4) * scale;
        let keep = val.abs().simd_ge(quadrant_thr[quadrant]);
        let quantized = f32x4::blend(keep, val.round(), zeros);
        quantized.to_i32x4().store(o4.try_into().unwrap());
    }

    // DC is coded separately.
    output[0] = 0;
}
212
213// --- wasm32 SIMD128 implementation ---
214
/// WebAssembly SIMD128 kernel for [`quantize_block_dct8`].
///
/// Iterates row by row and handles the left (cols 0..4) and right (cols 4..8) halves of each
/// row as separate 4-lane vectors, so each vector sits inside a single quadrant.
#[cfg(target_arch = "wasm32")]
#[inline]
#[archmage::arcane]
pub fn quantize_dct8_wasm128(
    token: archmage::Wasm128Token,
    dct_coeffs: &[f32; 64],
    weights: &[f32; 64],
    qac_qm: f32,
    thresholds: &[f32; 4],
    output: &mut [i32; 64],
) {
    use magetypes::simd::f32x4;

    let scale = f32x4::splat(token, qac_qm);
    let zeros = f32x4::zero(token);

    let [t0, t1, t2, t3] = *thresholds;
    let quadrant_thr = [
        f32x4::splat(token, t0),
        f32x4::splat(token, t1),
        f32x4::splat(token, t2),
        f32x4::splat(token, t3),
    ];

    let rows = dct_coeffs
        .chunks_exact(8)
        .zip(weights.chunks_exact(8))
        .zip(output.chunks_exact_mut(8));
    for (row, ((c_row, w_row), o_row)) in rows.enumerate() {
        let vertical = if row < 4 { 0 } else { 2 };
        for (half, o4) in o_row.chunks_exact_mut(4).enumerate() {
            let lanes = half * 4..half * 4 + 4;
            let val = f32x4::from_slice(token, &c_row[lanes.clone()])
                / f32x4::from_slice(token, &w_row[lanes])
                * scale;
            let keep = val.abs().simd_ge(quadrant_thr[vertical + half]);
            let quantized = f32x4::blend(keep, val.round(), zeros);
            quantized.to_i32x4().store(o4.try_into().unwrap());
        }
    }

    // DC is coded separately.
    output[0] = 0;
}
259
260// ============================================================================
261// Generic large-block quantization (DCT16+)
262// ============================================================================
263
264/// Quantize AC coefficients for a large block (DCT16+) to a flat output buffer.
265///
266/// For each coefficient at position (y, x) in the grid:
267///   val = dct_coeffs[y*grid_width + x] / weights[y*grid_width + x] * qac_qm
268///   if y < llf_y && x < llf_x: output 0 (LLF handled separately)
269///   elif `|val| < threshold[quadrant]`: output 0
270///   else: output round_ties_even(val) as i32
271///
272/// `grid_width` MUST be a multiple of 8.
273/// `thresholds[0..4]` map to quadrants: [top-left, top-right, bottom-left, bottom-right]
274/// where the split is at grid_height/2 and grid_width/2.
275#[allow(clippy::too_many_arguments)]
276#[inline]
277pub fn quantize_block_large(
278    dct_coeffs: &[f32],
279    weights: &[f32],
280    qac_qm: f32,
281    thresholds: &[f32; 4],
282    grid_width: usize,
283    grid_height: usize,
284    llf_x: usize,
285    llf_y: usize,
286    output: &mut [i32],
287) {
288    debug_assert_eq!(grid_width % 8, 0, "grid_width must be a multiple of 8");
289    let size = grid_width * grid_height;
290    debug_assert!(dct_coeffs.len() >= size);
291    debug_assert!(weights.len() >= size);
292    debug_assert!(output.len() >= size);
293
294    #[cfg(target_arch = "x86_64")]
295    {
296        use archmage::SimdToken;
297        if let Some(token) = archmage::X64V3Token::summon() {
298            quantize_large_avx2(
299                token,
300                dct_coeffs,
301                weights,
302                qac_qm,
303                thresholds,
304                grid_width,
305                grid_height,
306                llf_x,
307                llf_y,
308                output,
309            );
310            return;
311        }
312    }
313
314    #[cfg(target_arch = "aarch64")]
315    {
316        use archmage::SimdToken;
317        if let Some(token) = archmage::NeonToken::summon() {
318            quantize_large_neon(
319                token,
320                dct_coeffs,
321                weights,
322                qac_qm,
323                thresholds,
324                grid_width,
325                grid_height,
326                llf_x,
327                llf_y,
328                output,
329            );
330            return;
331        }
332    }
333
334    #[cfg(target_arch = "wasm32")]
335    {
336        use archmage::SimdToken;
337        if let Some(token) = archmage::Wasm128Token::summon() {
338            quantize_large_wasm128(
339                token,
340                dct_coeffs,
341                weights,
342                qac_qm,
343                thresholds,
344                grid_width,
345                grid_height,
346                llf_x,
347                llf_y,
348                output,
349            );
350            return;
351        }
352    }
353
354    quantize_large_scalar(
355        dct_coeffs,
356        weights,
357        qac_qm,
358        thresholds,
359        grid_width,
360        grid_height,
361        llf_x,
362        llf_y,
363        output,
364    );
365}
366
/// Scalar reference for [`quantize_block_large`].
///
/// Walks the `grid_height x grid_width` row-major grid:
/// - entries in the top-left `llf_y x llf_x` corner are written as 0;
/// - every other entry is scaled as `coeff * (1 / weight) * qac_qm`;
/// - a scaled entry becomes 0 when its magnitude is below its quadrant's threshold, and is
///   otherwise rounded half-to-even and cast to `i32`.
///
/// Quadrants split at `grid_height / 2` and `grid_width / 2`. Only the first
/// `grid_width * grid_height` entries of `output` are written.
#[allow(clippy::too_many_arguments)]
#[inline]
pub fn quantize_large_scalar(
    dct_coeffs: &[f32],
    weights: &[f32],
    qac_qm: f32,
    thresholds: &[f32; 4],
    grid_width: usize,
    grid_height: usize,
    llf_x: usize,
    llf_y: usize,
    output: &mut [i32],
) {
    let mid_row = grid_height / 2;
    let mid_col = grid_width / 2;

    for row in 0..grid_height {
        let row_start = row * grid_width;
        let row_in_llf = row < llf_y;
        let vertical = if row < mid_row { 0 } else { 2 };

        for col in 0..grid_width {
            let pos = row_start + col;

            // The LLF corner is coded separately.
            if row_in_llf && col < llf_x {
                output[pos] = 0;
                continue;
            }

            let quadrant = vertical + usize::from(col >= mid_col);
            let scaled = dct_coeffs[pos] * (1.0 / weights[pos]) * qac_qm;
            output[pos] = if scaled.abs() < thresholds[quadrant] {
                0
            } else {
                scaled.round_ties_even() as i32
            };
        }
    }
}
403
404#[allow(clippy::too_many_arguments)]
405#[cfg(target_arch = "x86_64")]
406#[inline]
407#[archmage::arcane]
408pub fn quantize_large_avx2(
409    token: archmage::X64V3Token,
410    dct_coeffs: &[f32],
411    weights: &[f32],
412    qac_qm: f32,
413    thresholds: &[f32; 4],
414    grid_width: usize,
415    grid_height: usize,
416    llf_x: usize,
417    llf_y: usize,
418    output: &mut [i32],
419) {
420    use magetypes::simd::f32x8;
421
422    let qac_v = f32x8::splat(token, qac_qm);
423    let zero_f = f32x8::zero(token);
424
425    let half_h = grid_height / 2;
426    let half_w = grid_width / 2;
427    let chunks_per_row = grid_width / 8;
428
429    // Pre-build threshold splats for each quadrant
430    let thr_splat = [
431        f32x8::splat(token, thresholds[0]),
432        f32x8::splat(token, thresholds[1]),
433        f32x8::splat(token, thresholds[2]),
434        f32x8::splat(token, thresholds[3]),
435    ];
436
437    // Pre-slice to help bounds check elimination
438    let coeffs = &dct_coeffs[..grid_width * grid_height];
439    let wts = &weights[..grid_width * grid_height];
440    let out = &mut output[..grid_width * grid_height];
441
442    for y in 0..grid_height {
443        let row_thr_base = if y >= half_h { 2 } else { 0 };
444        let row_off = y * grid_width;
445
446        for chunk in 0..chunks_per_row {
447            let x_base = chunk * 8;
448            let base = row_off + x_base;
449            let thr_idx = row_thr_base + if x_base >= half_w { 1 } else { 0 };
450
451            let c = crate::load_f32x8(token, coeffs, base);
452            let w = crate::load_f32x8(token, wts, base);
453            let thr = thr_splat[thr_idx];
454
455            // val = coeff / weight * qac_qm
456            let val = c / w * qac_v;
457
458            // Dead-zone thresholding
459            let abs_val = val.abs();
460            let mask = abs_val.simd_ge(thr);
461            let rounded = val.round();
462            let result = f32x8::blend(mask, rounded, zero_f);
463
464            let result_i32 = result.to_i32x8();
465            result_i32.store((&mut out[base..base + 8]).try_into().unwrap());
466        }
467    }
468
469    // Zero out LLF positions
470    for y in 0..llf_y {
471        for x in 0..llf_x {
472            out[y * grid_width + x] = 0;
473        }
474    }
475}
476
/// AArch64 NEON kernel for [`quantize_block_large`].
///
/// Each row is split into 4-wide chunks (any remainder columns are not visited). With
/// `grid_width` a multiple of 8, the vertical quadrant split is a multiple of 4, so every chunk
/// lies inside one quadrant and uses a single splatted threshold. The top-left
/// `llf_y x llf_x` corner is cleared at the end.
#[allow(clippy::too_many_arguments)]
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
pub fn quantize_large_neon(
    token: archmage::NeonToken,
    dct_coeffs: &[f32],
    weights: &[f32],
    qac_qm: f32,
    thresholds: &[f32; 4],
    grid_width: usize,
    grid_height: usize,
    llf_x: usize,
    llf_y: usize,
    output: &mut [i32],
) {
    use magetypes::simd::f32x4;

    let scale = f32x4::splat(token, qac_qm);
    let zeros = f32x4::zero(token);
    let mid_row = grid_height / 2;
    let mid_col = grid_width / 2;

    let [t0, t1, t2, t3] = *thresholds;
    let quadrant_thr = [
        f32x4::splat(token, t0),
        f32x4::splat(token, t1),
        f32x4::splat(token, t2),
        f32x4::splat(token, t3),
    ];

    let total = grid_width * grid_height;
    let coeffs = &dct_coeffs[..total];
    let wts = &weights[..total];
    let out = &mut output[..total];

    for row in 0..grid_height {
        let vertical = if row < mid_row { 0 } else { 2 };
        let start = row * grid_width;
        let end = start + grid_width;

        let lanes = coeffs[start..end]
            .chunks_exact(4)
            .zip(wts[start..end].chunks_exact(4))
            .zip(out[start..end].chunks_exact_mut(4));
        for (chunk, ((c4, w4), o4)) in lanes.enumerate() {
            let quadrant = vertical + usize::from(chunk * 4 >= mid_col);

            let val = f32x4::from_slice(token, c4) / f32x4::from_slice(token, w4) * scale;
            let keep = val.abs().simd_ge(quadrant_thr[quadrant]);
            let quantized = f32x4::blend(keep, val.round(), zeros);
            quantized.to_i32x4().store(o4.try_into().unwrap());
        }
    }

    // The LLF corner is coded separately.
    for row in 0..llf_y {
        let start = row * grid_width;
        out[start..start + llf_x].fill(0);
    }
}
545
/// WebAssembly SIMD128 kernel for [`quantize_block_large`].
///
/// Same strategy as the NEON kernel: each row is processed in 4-wide chunks. Each chunk uses
/// the splatted threshold of the quadrant it falls in. The top-left `llf_y x llf_x` corner is
/// cleared afterwards.
#[allow(clippy::too_many_arguments)]
#[cfg(target_arch = "wasm32")]
#[inline]
#[archmage::arcane]
pub fn quantize_large_wasm128(
    token: archmage::Wasm128Token,
    dct_coeffs: &[f32],
    weights: &[f32],
    qac_qm: f32,
    thresholds: &[f32; 4],
    grid_width: usize,
    grid_height: usize,
    llf_x: usize,
    llf_y: usize,
    output: &mut [i32],
) {
    use magetypes::simd::f32x4;

    let scale = f32x4::splat(token, qac_qm);
    let zeros = f32x4::zero(token);
    let mid_row = grid_height / 2;
    let mid_col = grid_width / 2;

    let [t0, t1, t2, t3] = *thresholds;
    let quadrant_thr = [
        f32x4::splat(token, t0),
        f32x4::splat(token, t1),
        f32x4::splat(token, t2),
        f32x4::splat(token, t3),
    ];

    let total = grid_width * grid_height;
    let coeffs = &dct_coeffs[..total];
    let wts = &weights[..total];
    let out = &mut output[..total];

    for row in 0..grid_height {
        let vertical = if row < mid_row { 0 } else { 2 };
        let span = row * grid_width..(row + 1) * grid_width;

        let lanes = coeffs[span.clone()]
            .chunks_exact(4)
            .zip(wts[span.clone()].chunks_exact(4))
            .zip(out[span].chunks_exact_mut(4));
        for (chunk, ((c4, w4), o4)) in lanes.enumerate() {
            let right_half = chunk * 4 >= mid_col;
            let thr = quadrant_thr[vertical + usize::from(right_half)];

            let val = f32x4::from_slice(token, c4) / f32x4::from_slice(token, w4) * scale;
            let keep = val.abs().simd_ge(thr);
            let quantized = f32x4::blend(keep, val.round(), zeros);
            quantized.to_i32x4().store(o4.try_into().unwrap());
        }
    }

    // The LLF corner is coded separately.
    for row in 0..llf_y {
        let start = row * grid_width;
        out[start..start + llf_x].fill(0);
    }
}
613
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;
    extern crate std;

    /// Synthetic coefficient ramp: element `i` is `((i as f32) * slope - offset) * scale`.
    fn coeff_ramp(len: usize, slope: f32, offset: f32, scale: f32) -> alloc::vec::Vec<f32> {
        (0..len)
            .map(|i| ((i as f32) * slope - offset) * scale)
            .collect()
    }

    /// Synthetic positive weight ramp: element `i` is `base + (i as f32) * step`.
    fn weight_ramp(len: usize, base: f32, step: f32) -> alloc::vec::Vec<f32> {
        (0..len).map(|i| base + (i as f32) * step).collect()
    }

    /// Returns `(max |a[i] - b[i]|, number of indices where a[i] != b[i])` over the zipped range.
    fn diff_stats(a: &[i32], b: &[i32]) -> (i32, usize) {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| (x - y).abs())
            .fold((0, 0), |(max, count), d| {
                (max.max(d), count + usize::from(d > 0))
            })
    }

    #[test]
    fn test_quantize_dct8_matches_scalar() {
        // Synthetic DCT8 block: a linear coefficient ramp over rising weights.
        let coeffs: [f32; 64] = coeff_ramp(64, 1.7, 50.0, 0.3).try_into().unwrap();
        let weights: [f32; 64] = weight_ramp(64, 0.01, 0.005).try_into().unwrap();
        let thresholds = [0.56f32, 0.62, 0.62, 0.62];
        let qac_qm = 3.5f32;

        let mut ref_out = [0i32; 64];
        quantize_dct8_scalar(&coeffs, &weights, qac_qm, &thresholds, &mut ref_out);

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = [0i32; 64];
                quantize_block_dct8(&coeffs, &weights, qac_qm, &thresholds, &mut simd_out);

                assert_eq!(simd_out[0], 0, "DC must be 0 [{perm}]");
                assert_eq!(ref_out[0], 0, "DC must be 0 (ref) [{perm}]");

                // AC only: paths may land on opposite sides of a rounding boundary.
                let (max_diff, diff_count) = diff_stats(&simd_out[1..], &ref_out[1..]);
                assert!(
                    max_diff <= 1,
                    "Max quantization diff: {} (at most 1 due to FP rounding boundary) [{perm}]",
                    max_diff
                );
                // At most 3 of 63 (~5%) coefficients may differ by one.
                assert!(
                    diff_count <= 3,
                    "Too many differing coefficients: {}/63 [{perm}]",
                    diff_count
                );
            },
        );
        std::eprintln!("{report}");
    }

    #[test]
    fn test_quantize_dct8_all_zeros() {
        let coeffs = [0.0f32; 64];
        let weights = [1.0f32; 64];
        let thresholds = [0.5f32; 4];
        // Pre-fill with a non-zero sentinel so every slot must actually be written.
        let mut output = [99i32; 64];

        quantize_block_dct8(&coeffs, &weights, 1.0, &thresholds, &mut output);

        for (i, &val) in output.iter().enumerate() {
            assert_eq!(val, 0, "Index {} should be 0", i);
        }
    }

    #[test]
    fn test_quantize_dct8_large_coeffs() {
        // Every AC coefficient is far above threshold and must survive unchanged.
        let coeffs: [f32; 64] = core::array::from_fn(|i| if i == 0 { 0.0 } else { 100.0 });
        let weights = [1.0f32; 64];
        let thresholds = [0.5f32; 4];

        let mut output = [0i32; 64];
        quantize_block_dct8(&coeffs, &weights, 1.0, &thresholds, &mut output);

        assert_eq!(output[0], 0, "DC must be 0");
        for (i, &val) in output.iter().enumerate().skip(1) {
            assert_eq!(val, 100, "Index {} should be 100", i);
        }
    }

    // =====================================================================
    // Large-block quantize tests
    // =====================================================================

    #[test]
    fn test_quantize_large_dct16x16_matches_scalar() {
        let (grid_w, grid_h) = (16, 16);
        let size = grid_w * grid_h;
        let (llf_x, llf_y) = (2, 2); // (cx, cy) for DCT16x16

        let coeffs = coeff_ramp(size, 0.37, 40.0, 0.5);
        let weights = weight_ramp(size, 0.01, 0.002);
        let thresholds = [0.56f32, 0.62, 0.62, 0.62];
        let qac_qm = 4.2f32;

        let mut ref_out = alloc::vec![0i32; size];
        quantize_large_scalar(
            &coeffs,
            &weights,
            qac_qm,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut ref_out,
        );

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = alloc::vec![0i32; size];
                quantize_block_large(
                    &coeffs,
                    &weights,
                    qac_qm,
                    &thresholds,
                    grid_w,
                    grid_h,
                    llf_x,
                    llf_y,
                    &mut simd_out,
                );

                // The LLF corner must be zeroed.
                for (y, row) in simd_out.chunks(grid_w).take(llf_y).enumerate() {
                    for (x, &v) in row[..llf_x].iter().enumerate() {
                        assert_eq!(v, 0, "LLF ({},{}) must be 0 [{perm}]", y, x);
                    }
                }

                let (max_diff, diff_count) = diff_stats(&simd_out, &ref_out);
                assert!(
                    max_diff <= 1,
                    "Max diff: {} (at most 1 due to FP rounding) [{perm}]",
                    max_diff
                );
                let tolerance = size / 20; // 5%
                assert!(
                    diff_count <= tolerance,
                    "Too many diffs: {}/{} [{perm}]",
                    diff_count,
                    size
                );
            },
        );
        std::eprintln!("{report}");
    }

    #[test]
    fn test_quantize_large_dct32x32_matches_scalar() {
        let (grid_w, grid_h) = (32, 32);
        let size = grid_w * grid_h;
        let (llf_x, llf_y) = (4, 4);

        let coeffs = coeff_ramp(size, 0.19, 80.0, 0.3);
        let weights = weight_ramp(size, 0.005, 0.001);
        let thresholds = [0.54f32, 0.60, 0.58, 0.62];
        let qac_qm = 5.0f32;

        let mut ref_out = alloc::vec![0i32; size];
        quantize_large_scalar(
            &coeffs,
            &weights,
            qac_qm,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut ref_out,
        );

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = alloc::vec![0i32; size];
                quantize_block_large(
                    &coeffs,
                    &weights,
                    qac_qm,
                    &thresholds,
                    grid_w,
                    grid_h,
                    llf_x,
                    llf_y,
                    &mut simd_out,
                );

                for (y, row) in simd_out.chunks(grid_w).take(llf_y).enumerate() {
                    for (x, &v) in row[..llf_x].iter().enumerate() {
                        assert_eq!(v, 0, "LLF ({},{}) [{perm}]", y, x);
                    }
                }

                let (max_diff, _) = diff_stats(&simd_out, &ref_out);
                assert!(max_diff <= 1, "Max diff: {} [{perm}]", max_diff);
            },
        );
        std::eprintln!("{report}");
    }

    #[test]
    fn test_quantize_large_dct64x64_matches_scalar() {
        let (grid_w, grid_h) = (64, 64);
        let size = grid_w * grid_h;
        let (llf_x, llf_y) = (8, 8);

        let coeffs = coeff_ramp(size, 0.07, 120.0, 0.2);
        let weights = weight_ramp(size, 0.002, 0.0005);
        let thresholds = [0.56f32, 0.62, 0.62, 0.62];
        let qac_qm = 3.0f32;

        let mut ref_out = alloc::vec![0i32; size];
        quantize_large_scalar(
            &coeffs,
            &weights,
            qac_qm,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut ref_out,
        );

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = alloc::vec![0i32; size];
                quantize_block_large(
                    &coeffs,
                    &weights,
                    qac_qm,
                    &thresholds,
                    grid_w,
                    grid_h,
                    llf_x,
                    llf_y,
                    &mut simd_out,
                );

                for (y, row) in simd_out.chunks(grid_w).take(llf_y).enumerate() {
                    for (x, &v) in row[..llf_x].iter().enumerate() {
                        assert_eq!(v, 0, "LLF ({},{}) [{perm}]", y, x);
                    }
                }

                let (max_diff, _) = diff_stats(&simd_out, &ref_out);
                assert!(max_diff <= 1, "Max diff: {} [{perm}]", max_diff);
            },
        );
        std::eprintln!("{report}");
    }

    #[test]
    fn test_quantize_large_nonsquare_16x8() {
        let (grid_w, grid_h) = (16, 8);
        let size = grid_w * grid_h;
        let (llf_x, llf_y) = (2, 1);

        let coeffs = coeff_ramp(size, 0.53, 30.0, 0.8);
        let weights = weight_ramp(size, 0.02, 0.004);
        let thresholds = [0.56f32, 0.62, 0.62, 0.62];
        let qac_qm = 2.5f32;

        let mut ref_out = alloc::vec![0i32; size];
        quantize_large_scalar(
            &coeffs,
            &weights,
            qac_qm,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut ref_out,
        );

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = alloc::vec![0i32; size];
                quantize_block_large(
                    &coeffs,
                    &weights,
                    qac_qm,
                    &thresholds,
                    grid_w,
                    grid_h,
                    llf_x,
                    llf_y,
                    &mut simd_out,
                );

                let (max_diff, _) = diff_stats(&simd_out, &ref_out);
                assert!(max_diff <= 1, "Max diff: {} [{perm}]", max_diff);
            },
        );
        std::eprintln!("{report}");
    }

    #[test]
    fn test_quantize_large_all_zeros() {
        let (grid_w, grid_h) = (16, 16);
        let size = grid_w * grid_h;

        let coeffs = alloc::vec![0.0f32; size];
        let weights = alloc::vec![1.0f32; size];
        let thresholds = [0.5f32; 4];
        // Non-zero sentinel: every slot must be overwritten.
        let mut output = alloc::vec![99i32; size];

        quantize_block_large(
            &coeffs,
            &weights,
            1.0,
            &thresholds,
            grid_w,
            grid_h,
            2,
            2,
            &mut output,
        );

        for (i, &val) in output.iter().enumerate() {
            assert_eq!(val, 0, "Index {} should be 0", i);
        }
    }

    #[test]
    fn test_quantize_large_llf_zeroed() {
        // LLF positions must be zeroed even when their coefficients are large.
        let (grid_w, grid_h) = (32, 32);
        let size = grid_w * grid_h;
        let (llf_x, llf_y) = (4, 4);

        let coeffs = alloc::vec![100.0f32; size];
        let weights = alloc::vec![1.0f32; size];
        let thresholds = [0.1f32; 4]; // low threshold so everything survives
        let mut output = alloc::vec![0i32; size];

        quantize_block_large(
            &coeffs,
            &weights,
            1.0,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut output,
        );

        for (y, row) in output.chunks(grid_w).take(llf_y).enumerate() {
            for (x, &v) in row[..llf_x].iter().enumerate() {
                assert_eq!(v, 0, "LLF ({},{}) must be 0", y, x);
            }
        }
        // Positions just outside the LLF corner keep their value.
        assert_eq!(output[llf_x], 100, "First non-LLF position should be 100");
        assert_eq!(
            output[llf_y * grid_w],
            100,
            "First non-LLF row should be 100"
        );
    }
}