1#[inline]
29pub fn quantize_block_dct8(
30 dct_coeffs: &[f32; 64],
31 weights: &[f32; 64],
32 qac_qm: f32,
33 thresholds: &[f32; 4],
34 output: &mut [i32; 64],
35) {
36 #[cfg(target_arch = "x86_64")]
37 {
38 use archmage::SimdToken;
39 if let Some(token) = archmage::X64V3Token::summon() {
40 quantize_dct8_avx2(token, dct_coeffs, weights, qac_qm, thresholds, output);
41 return;
42 }
43 }
44
45 #[cfg(target_arch = "aarch64")]
46 {
47 use archmage::SimdToken;
48 if let Some(token) = archmage::NeonToken::summon() {
49 quantize_dct8_neon(token, dct_coeffs, weights, qac_qm, thresholds, output);
50 return;
51 }
52 }
53
54 #[cfg(target_arch = "wasm32")]
55 {
56 use archmage::SimdToken;
57 if let Some(token) = archmage::Wasm128Token::summon() {
58 quantize_dct8_wasm128(token, dct_coeffs, weights, qac_qm, thresholds, output);
59 return;
60 }
61 }
62
63 quantize_dct8_scalar(dct_coeffs, weights, qac_qm, thresholds, output);
64}
65
/// Portable reference quantizer for one row-major 8x8 DCT block.
///
/// Each AC coefficient is scaled by `qac_qm / weight` and rounded to the
/// nearest integer, with ties going to even. A value whose magnitude is below
/// the threshold of its quadrant becomes 0 instead.
///
/// `thresholds` is indexed as `[top-left, top-right, bottom-left,
/// bottom-right]`:
/// - "bottom" means row >= 4;
/// - "right" means column >= 4.
///
/// The DC slot `output[0]` is always written as 0.
///
/// NOTE(review): this evaluates `coeff * (1.0 / weight) * qac_qm`, while the
/// SIMD kernels evaluate `coeff / weight * qac_qm`. The two can round
/// differently, which is why the tests only require agreement within ±1.
#[inline]
pub fn quantize_dct8_scalar(
    dct_coeffs: &[f32; 64],
    weights: &[f32; 64],
    qac_qm: f32,
    thresholds: &[f32; 4],
    output: &mut [i32; 64],
) {
    output[0] = 0;

    // Position 0 (DC) is skipped; it was zeroed above.
    let ac_terms = dct_coeffs
        .iter()
        .zip(weights.iter())
        .zip(output.iter_mut())
        .enumerate()
        .skip(1);

    for (pos, ((&coeff, &weight), slot)) in ac_terms {
        let (row, col) = (pos / 8, pos % 8);
        let quadrant = 2 * usize::from(row >= 4) + usize::from(col >= 4);
        let scaled = coeff * (1.0 / weight) * qac_qm;
        *slot = if scaled.abs() < thresholds[quadrant] {
            0
        } else {
            scaled.round_ties_even() as i32
        };
    }
}
87
88#[cfg(target_arch = "x86_64")]
89#[inline]
90#[archmage::arcane]
91pub fn quantize_dct8_avx2(
92 token: archmage::X64V3Token,
93 dct_coeffs: &[f32; 64],
94 weights: &[f32; 64],
95 qac_qm: f32,
96 thresholds: &[f32; 4],
97 output: &mut [i32; 64],
98) {
99 use magetypes::simd::f32x8;
100
101 let qac_qm_v = f32x8::splat(token, qac_qm);
102 let zero_f = f32x8::zero(token);
103
104 let thr_top = f32x8::from_array(
108 token,
109 [
110 thresholds[0],
111 thresholds[0],
112 thresholds[0],
113 thresholds[0],
114 thresholds[1],
115 thresholds[1],
116 thresholds[1],
117 thresholds[1],
118 ],
119 );
120 let thr_bot = f32x8::from_array(
121 token,
122 [
123 thresholds[2],
124 thresholds[2],
125 thresholds[2],
126 thresholds[2],
127 thresholds[3],
128 thresholds[3],
129 thresholds[3],
130 thresholds[3],
131 ],
132 );
133
134 for chunk in 0..8 {
136 let base = chunk * 8;
137 let coeffs = f32x8::from_slice(token, &dct_coeffs[base..]);
138 let w = f32x8::from_slice(token, &weights[base..]);
139 let thr = if chunk < 4 { thr_top } else { thr_bot };
140
141 let val = coeffs / w * qac_qm_v;
143
144 let abs_val = val.abs();
146 let mask = abs_val.simd_ge(thr); let rounded = val.round();
150 let result = f32x8::blend(mask, rounded, zero_f);
151
152 let result_i32 = result.to_i32x8();
154 result_i32.store((&mut output[base..base + 8]).try_into().unwrap());
155 }
156
157 output[0] = 0;
159}
160
/// NEON kernel for [`quantize_block_dct8`].
///
/// Walks the block in 4-lane quarter-rows: quarter `q` covers row `q / 2`,
/// left half when `q` is even and right half when `q` is odd. Each quarter
/// uses the single threshold of its quadrant
/// (`[top-left, top-right, bottom-left, bottom-right]`).
///
/// Values are `coeff / weight * qac_qm`. Lanes with magnitude `>= threshold`
/// are rounded; the others become 0. `output[0]` (DC) is forced to 0.
///
/// NOTE(review): confirm that `magetypes` `round()` is ties-to-even on this
/// target, to match the scalar kernel.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
pub fn quantize_dct8_neon(
    token: archmage::NeonToken,
    dct_coeffs: &[f32; 64],
    weights: &[f32; 64],
    qac_qm: f32,
    thresholds: &[f32; 4],
    output: &mut [i32; 64],
) {
    use magetypes::simd::f32x4;

    let scale = f32x4::splat(token, qac_qm);
    let zeros = f32x4::zero(token);
    let quadrant_thr = [
        f32x4::splat(token, thresholds[0]),
        f32x4::splat(token, thresholds[1]),
        f32x4::splat(token, thresholds[2]),
        f32x4::splat(token, thresholds[3]),
    ];

    let quarters = dct_coeffs
        .chunks_exact(4)
        .zip(weights.chunks_exact(4))
        .zip(output.chunks_exact_mut(4))
        .enumerate();

    for (q, ((coeff_part, weight_part), out_part)) in quarters {
        // Rows 4..8 use the bottom pair of thresholds; odd quarters are right halves.
        let vertical = if q / 2 >= 4 { 2 } else { 0 };
        let t = quadrant_thr[vertical + q % 2];

        let scaled =
            f32x4::from_slice(token, coeff_part) / f32x4::from_slice(token, weight_part) * scale;
        let keep = scaled.abs().simd_ge(t);

        f32x4::blend(keep, scaled.round(), zeros)
            .to_i32x4()
            .store(out_part.try_into().unwrap());
    }

    // DC is never emitted by this routine.
    output[0] = 0;
}
212
/// WebAssembly SIMD128 kernel for [`quantize_block_dct8`].
///
/// Walks the block in 4-lane quarter-rows: quarter `q` covers row `q / 2`,
/// left half when `q` is even and right half when `q` is odd. Each quarter
/// uses the single threshold of its quadrant
/// (`[top-left, top-right, bottom-left, bottom-right]`).
///
/// Values are `coeff / weight * qac_qm`. Lanes with magnitude `>= threshold`
/// are rounded; the others become 0. `output[0]` (DC) is forced to 0.
///
/// NOTE(review): confirm that `magetypes` `round()` is ties-to-even on this
/// target, to match the scalar kernel.
#[cfg(target_arch = "wasm32")]
#[inline]
#[archmage::arcane]
pub fn quantize_dct8_wasm128(
    token: archmage::Wasm128Token,
    dct_coeffs: &[f32; 64],
    weights: &[f32; 64],
    qac_qm: f32,
    thresholds: &[f32; 4],
    output: &mut [i32; 64],
) {
    use magetypes::simd::f32x4;

    let scale = f32x4::splat(token, qac_qm);
    let zeros = f32x4::zero(token);
    let quadrant_thr = [
        f32x4::splat(token, thresholds[0]),
        f32x4::splat(token, thresholds[1]),
        f32x4::splat(token, thresholds[2]),
        f32x4::splat(token, thresholds[3]),
    ];

    let quarters = dct_coeffs
        .chunks_exact(4)
        .zip(weights.chunks_exact(4))
        .zip(output.chunks_exact_mut(4))
        .enumerate();

    for (q, ((coeff_part, weight_part), out_part)) in quarters {
        // Rows 4..8 use the bottom pair of thresholds; odd quarters are right halves.
        let vertical = if q / 2 >= 4 { 2 } else { 0 };
        let t = quadrant_thr[vertical + q % 2];

        let scaled =
            f32x4::from_slice(token, coeff_part) / f32x4::from_slice(token, weight_part) * scale;
        let keep = scaled.abs().simd_ge(t);

        f32x4::blend(keep, scaled.round(), zeros)
            .to_i32x4()
            .store(out_part.try_into().unwrap());
    }

    // DC is never emitted by this routine.
    output[0] = 0;
}
259
260#[allow(clippy::too_many_arguments)]
276#[inline]
277pub fn quantize_block_large(
278 dct_coeffs: &[f32],
279 weights: &[f32],
280 qac_qm: f32,
281 thresholds: &[f32; 4],
282 grid_width: usize,
283 grid_height: usize,
284 llf_x: usize,
285 llf_y: usize,
286 output: &mut [i32],
287) {
288 debug_assert_eq!(grid_width % 8, 0, "grid_width must be a multiple of 8");
289 let size = grid_width * grid_height;
290 debug_assert!(dct_coeffs.len() >= size);
291 debug_assert!(weights.len() >= size);
292 debug_assert!(output.len() >= size);
293
294 #[cfg(target_arch = "x86_64")]
295 {
296 use archmage::SimdToken;
297 if let Some(token) = archmage::X64V3Token::summon() {
298 quantize_large_avx2(
299 token,
300 dct_coeffs,
301 weights,
302 qac_qm,
303 thresholds,
304 grid_width,
305 grid_height,
306 llf_x,
307 llf_y,
308 output,
309 );
310 return;
311 }
312 }
313
314 #[cfg(target_arch = "aarch64")]
315 {
316 use archmage::SimdToken;
317 if let Some(token) = archmage::NeonToken::summon() {
318 quantize_large_neon(
319 token,
320 dct_coeffs,
321 weights,
322 qac_qm,
323 thresholds,
324 grid_width,
325 grid_height,
326 llf_x,
327 llf_y,
328 output,
329 );
330 return;
331 }
332 }
333
334 #[cfg(target_arch = "wasm32")]
335 {
336 use archmage::SimdToken;
337 if let Some(token) = archmage::Wasm128Token::summon() {
338 quantize_large_wasm128(
339 token,
340 dct_coeffs,
341 weights,
342 qac_qm,
343 thresholds,
344 grid_width,
345 grid_height,
346 llf_x,
347 llf_y,
348 output,
349 );
350 return;
351 }
352 }
353
354 quantize_large_scalar(
355 dct_coeffs,
356 weights,
357 qac_qm,
358 thresholds,
359 grid_width,
360 grid_height,
361 llf_x,
362 llf_y,
363 output,
364 );
365}
366
/// Portable reference quantizer for a row-major `grid_width` x `grid_height`
/// block of DCT coefficients.
///
/// Cells with `y < llf_y && x < llf_x` are written as 0. Every other cell is
/// `coeff * (1.0 / weight) * qac_qm`, which becomes:
/// - 0 when its magnitude is below the quadrant threshold;
/// - otherwise the value rounded to the nearest integer, ties to even.
///
/// Quadrants split at `grid_height / 2` and `grid_width / 2`, indexed as
/// `[top-left, top-right, bottom-left, bottom-right]`.
#[allow(clippy::too_many_arguments)]
#[inline]
pub fn quantize_large_scalar(
    dct_coeffs: &[f32],
    weights: &[f32],
    qac_qm: f32,
    thresholds: &[f32; 4],
    grid_width: usize,
    grid_height: usize,
    llf_x: usize,
    llf_y: usize,
    output: &mut [i32],
) {
    let mid_row = grid_height / 2;
    let mid_col = grid_width / 2;

    for y in 0..grid_height {
        let row_start = y * grid_width;
        let vertical = if y >= mid_row { 2 } else { 0 };

        for x in 0..grid_width {
            let i = row_start + x;
            let in_llf_corner = y < llf_y && x < llf_x;

            output[i] = if in_llf_corner {
                0
            } else {
                let thr = thresholds[vertical + usize::from(x >= mid_col)];
                let scaled = dct_coeffs[i] * (1.0 / weights[i]) * qac_qm;
                if scaled.abs() < thr {
                    0
                } else {
                    scaled.round_ties_even() as i32
                }
            };
        }
    }
}
403
404#[allow(clippy::too_many_arguments)]
405#[cfg(target_arch = "x86_64")]
406#[inline]
407#[archmage::arcane]
408pub fn quantize_large_avx2(
409 token: archmage::X64V3Token,
410 dct_coeffs: &[f32],
411 weights: &[f32],
412 qac_qm: f32,
413 thresholds: &[f32; 4],
414 grid_width: usize,
415 grid_height: usize,
416 llf_x: usize,
417 llf_y: usize,
418 output: &mut [i32],
419) {
420 use magetypes::simd::f32x8;
421
422 let qac_v = f32x8::splat(token, qac_qm);
423 let zero_f = f32x8::zero(token);
424
425 let half_h = grid_height / 2;
426 let half_w = grid_width / 2;
427 let chunks_per_row = grid_width / 8;
428
429 let thr_splat = [
431 f32x8::splat(token, thresholds[0]),
432 f32x8::splat(token, thresholds[1]),
433 f32x8::splat(token, thresholds[2]),
434 f32x8::splat(token, thresholds[3]),
435 ];
436
437 let coeffs = &dct_coeffs[..grid_width * grid_height];
439 let wts = &weights[..grid_width * grid_height];
440 let out = &mut output[..grid_width * grid_height];
441
442 for y in 0..grid_height {
443 let row_thr_base = if y >= half_h { 2 } else { 0 };
444 let row_off = y * grid_width;
445
446 for chunk in 0..chunks_per_row {
447 let x_base = chunk * 8;
448 let base = row_off + x_base;
449 let thr_idx = row_thr_base + if x_base >= half_w { 1 } else { 0 };
450
451 let c = crate::load_f32x8(token, coeffs, base);
452 let w = crate::load_f32x8(token, wts, base);
453 let thr = thr_splat[thr_idx];
454
455 let val = c / w * qac_v;
457
458 let abs_val = val.abs();
460 let mask = abs_val.simd_ge(thr);
461 let rounded = val.round();
462 let result = f32x8::blend(mask, rounded, zero_f);
463
464 let result_i32 = result.to_i32x8();
465 result_i32.store((&mut out[base..base + 8]).try_into().unwrap());
466 }
467 }
468
469 for y in 0..llf_y {
471 for x in 0..llf_x {
472 out[y * grid_width + x] = 0;
473 }
474 }
475}
476
/// NEON kernel for [`quantize_block_large`].
///
/// Processes each row in 4-lane chunks.
/// - Quadrant thresholds split at `grid_height / 2` and `grid_width / 2`
///   (`[top-left, top-right, bottom-left, bottom-right]`).
/// - A chunk straddling the column midpoint gets per-lane thresholds, in line
///   with the x86 kernel.
/// - Each value is `coeff / weight * qac_qm`. Lanes with magnitude
///   `>= threshold` are rounded; the rest become 0.
/// - Finally the `llf_x` x `llf_y` corner is zeroed, clamped to the grid like
///   the scalar reference.
///
/// Only whole 4-lane chunks are written. Callers must pass a `grid_width`
/// that is a multiple of 8, as asserted by [`quantize_block_large`].
#[allow(clippy::too_many_arguments)]
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
pub fn quantize_large_neon(
    token: archmage::NeonToken,
    dct_coeffs: &[f32],
    weights: &[f32],
    qac_qm: f32,
    thresholds: &[f32; 4],
    grid_width: usize,
    grid_height: usize,
    llf_x: usize,
    llf_y: usize,
    output: &mut [i32],
) {
    use magetypes::simd::f32x4;

    const LANES: usize = 4;

    let qac_v = f32x4::splat(token, qac_qm);
    let zero_f = f32x4::zero(token);

    let half_h = grid_height / 2;
    let half_w = grid_width / 2;
    let chunks_per_row = grid_width / LANES;

    let thr_splat = [
        f32x4::splat(token, thresholds[0]),
        f32x4::splat(token, thresholds[1]),
        f32x4::splat(token, thresholds[2]),
        f32x4::splat(token, thresholds[3]),
    ];

    // Re-slice once so the per-chunk indexing below is bounded by `size`.
    let size = grid_width * grid_height;
    let coeffs = &dct_coeffs[..size];
    let wts = &weights[..size];
    let out = &mut output[..size];

    for y in 0..grid_height {
        let row_thr_base = if y >= half_h { 2 } else { 0 };
        let row_off = y * grid_width;

        for chunk in 0..chunks_per_row {
            let x_base = chunk * LANES;
            let base = row_off + x_base;

            let thr = if x_base >= half_w {
                thr_splat[row_thr_base + 1]
            } else if x_base + LANES <= half_w {
                thr_splat[row_thr_base]
            } else {
                // The chunk crosses the column midpoint: pick per lane.
                let mut lanes = [0.0f32; LANES];
                for (lane, t) in lanes.iter_mut().enumerate() {
                    *t = thresholds[row_thr_base + usize::from(x_base + lane >= half_w)];
                }
                f32x4::from_slice(token, &lanes[..])
            };

            let c = f32x4::from_slice(token, &coeffs[base..]);
            let w = f32x4::from_slice(token, &wts[base..]);

            let val = c / w * qac_v;
            let mask = val.abs().simd_ge(thr);
            let result = f32x4::blend(mask, val.round(), zero_f);

            result
                .to_i32x4()
                .store((&mut out[base..base + LANES]).try_into().unwrap());
        }
    }

    // Zero the lowest-frequency corner, clamped to the grid so an oversized
    // `llf_x`/`llf_y` behaves like the scalar kernel instead of spilling into
    // the next row or indexing past the block.
    for y in 0..llf_y.min(grid_height) {
        let row_off = y * grid_width;
        out[row_off..row_off + llf_x.min(grid_width)].fill(0);
    }
}
545
/// WebAssembly SIMD128 kernel for [`quantize_block_large`].
///
/// Processes each row in 4-lane chunks.
/// - Quadrant thresholds split at `grid_height / 2` and `grid_width / 2`
///   (`[top-left, top-right, bottom-left, bottom-right]`).
/// - A chunk straddling the column midpoint gets per-lane thresholds, in line
///   with the x86 kernel.
/// - Each value is `coeff / weight * qac_qm`. Lanes with magnitude
///   `>= threshold` are rounded; the rest become 0.
/// - Finally the `llf_x` x `llf_y` corner is zeroed, clamped to the grid like
///   the scalar reference.
///
/// Only whole 4-lane chunks are written. Callers must pass a `grid_width`
/// that is a multiple of 8, as asserted by [`quantize_block_large`].
#[allow(clippy::too_many_arguments)]
#[cfg(target_arch = "wasm32")]
#[inline]
#[archmage::arcane]
pub fn quantize_large_wasm128(
    token: archmage::Wasm128Token,
    dct_coeffs: &[f32],
    weights: &[f32],
    qac_qm: f32,
    thresholds: &[f32; 4],
    grid_width: usize,
    grid_height: usize,
    llf_x: usize,
    llf_y: usize,
    output: &mut [i32],
) {
    use magetypes::simd::f32x4;

    const LANES: usize = 4;

    let qac_v = f32x4::splat(token, qac_qm);
    let zero_f = f32x4::zero(token);

    let half_h = grid_height / 2;
    let half_w = grid_width / 2;
    let chunks_per_row = grid_width / LANES;

    let thr_splat = [
        f32x4::splat(token, thresholds[0]),
        f32x4::splat(token, thresholds[1]),
        f32x4::splat(token, thresholds[2]),
        f32x4::splat(token, thresholds[3]),
    ];

    // Re-slice once so the per-chunk indexing below is bounded by `size`.
    let size = grid_width * grid_height;
    let coeffs = &dct_coeffs[..size];
    let wts = &weights[..size];
    let out = &mut output[..size];

    for y in 0..grid_height {
        let row_thr_base = if y >= half_h { 2 } else { 0 };
        let row_off = y * grid_width;

        for chunk in 0..chunks_per_row {
            let x_base = chunk * LANES;
            let base = row_off + x_base;

            let thr = if x_base >= half_w {
                thr_splat[row_thr_base + 1]
            } else if x_base + LANES <= half_w {
                thr_splat[row_thr_base]
            } else {
                // The chunk crosses the column midpoint: pick per lane.
                let mut lanes = [0.0f32; LANES];
                for (lane, t) in lanes.iter_mut().enumerate() {
                    *t = thresholds[row_thr_base + usize::from(x_base + lane >= half_w)];
                }
                f32x4::from_slice(token, &lanes[..])
            };

            let c = f32x4::from_slice(token, &coeffs[base..]);
            let w = f32x4::from_slice(token, &wts[base..]);

            let val = c / w * qac_v;
            let mask = val.abs().simd_ge(thr);
            let result = f32x4::blend(mask, val.round(), zero_f);

            result
                .to_i32x4()
                .store((&mut out[base..base + LANES]).try_into().unwrap());
        }
    }

    // Zero the lowest-frequency corner, clamped to the grid so an oversized
    // `llf_x`/`llf_y` behaves like the scalar kernel instead of spilling into
    // the next row or indexing past the block.
    for y in 0..llf_y.min(grid_height) {
        let row_off = y * grid_width;
        out[row_off..row_off + llf_x.min(grid_width)].fill(0);
    }
}
613
// Tests compare the dispatched (SIMD-or-scalar) quantizers against the scalar
// reference kernels, and check the DC / LLF zeroing contracts.
//
// Tolerances of ±1 per coefficient exist because the SIMD kernels compute
// `coeff / weight * qac_qm`, while the scalar kernels compute
// `coeff * (1.0 / weight) * qac_qm`. Values near a rounding boundary can
// therefore land one step apart.
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;
    extern crate std;

    /// The dispatched 8x8 quantizer agrees with the scalar reference on a
    /// smooth ramp. Each coefficient may differ by at most 1, on at most 3 of
    /// the 63 AC positions.
    #[test]
    fn test_quantize_dct8_matches_scalar() {
        // Ramp of positive and negative coefficients with slowly growing weights.
        let mut coeffs = [0.0f32; 64];
        let mut weights = [0.0f32; 64];
        for i in 0..64 {
            coeffs[i] = ((i as f32) * 1.7 - 50.0) * 0.3;
            weights[i] = 0.01 + (i as f32) * 0.005;
        }

        let thresholds = [0.56f32, 0.62, 0.62, 0.62];
        let qac_qm = 3.5f32;

        let mut ref_out = [0i32; 64];
        quantize_dct8_scalar(&coeffs, &weights, qac_qm, &thresholds, &mut ref_out);

        // NOTE(review): `for_each_token_permutation` is assumed to re-run the
        // closure under different SIMD-token availability (`perm` labels the
        // run in failure messages) and return a printable summary — confirm
        // against the archmage docs.
        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = [0i32; 64];
                quantize_block_dct8(&coeffs, &weights, qac_qm, &thresholds, &mut simd_out);

                // Both paths must leave the DC slot zeroed.
                assert_eq!(simd_out[0], 0, "DC must be 0 [{perm}]");
                assert_eq!(ref_out[0], 0, "DC must be 0 (ref) [{perm}]");

                // Largest per-coefficient difference and number of AC positions that differ.
                let mut max_diff = 0i32;
                let mut diff_count = 0;
                for i in 1..64 {
                    let diff = (simd_out[i] - ref_out[i]).abs();
                    if diff > 0 {
                        diff_count += 1;
                    }
                    max_diff = max_diff.max(diff);
                }
                assert!(
                    max_diff <= 1,
                    "Max quantization diff: {} (at most 1 due to FP rounding boundary) [{perm}]",
                    max_diff
                );
                // Only a handful of values should sit close enough to a boundary to flip.
                assert!(
                    diff_count <= 3,
                    "Too many differing coefficients: {}/63 [{perm}]",
                    diff_count
                );
            },
        );
        std::eprintln!("{report}");
    }

    /// All-zero input yields all-zero output, overwriting the sentinel 99s.
    #[test]
    fn test_quantize_dct8_all_zeros() {
        let coeffs = [0.0f32; 64];
        let weights = [1.0f32; 64];
        let thresholds = [0.5f32; 4];
        // Pre-fill with a sentinel so every slot must actually be written.
        let mut output = [99i32; 64];
        quantize_block_dct8(&coeffs, &weights, 1.0, &thresholds, &mut output);

        for (i, &val) in output.iter().enumerate() {
            assert_eq!(val, 0, "Index {} should be 0", i);
        }
    }

    /// With unit weights and `qac_qm == 1.0`, integral coefficients pass through unchanged.
    #[test]
    fn test_quantize_dct8_large_coeffs() {
        let mut coeffs = [100.0f32; 64];
        coeffs[0] = 0.0;
        let weights = [1.0f32; 64];
        let thresholds = [0.5f32; 4];

        let mut output = [0i32; 64];
        quantize_block_dct8(&coeffs, &weights, 1.0, &thresholds, &mut output);

        assert_eq!(output[0], 0, "DC must be 0");
        for (i, &val) in output.iter().enumerate().skip(1) {
            assert_eq!(val, 100, "Index {} should be 100", i);
        }
    }

    /// 16x16 grid with a 2x2 LLF corner. The dispatched kernel matches the
    /// scalar reference within ±1 on at most 5% of cells, and the LLF corner
    /// is zero.
    #[test]
    fn test_quantize_large_dct16x16_matches_scalar() {
        let grid_w = 16;
        let grid_h = 16;
        let size = grid_w * grid_h;
        let llf_x = 2;
        let llf_y = 2;
        let mut coeffs = alloc::vec![0.0f32; size];
        let mut weights = alloc::vec![0.0f32; size];
        for i in 0..size {
            coeffs[i] = ((i as f32) * 0.37 - 40.0) * 0.5;
            weights[i] = 0.01 + (i as f32) * 0.002;
        }

        let thresholds = [0.56f32, 0.62, 0.62, 0.62];
        let qac_qm = 4.2f32;

        let mut ref_out = alloc::vec![0i32; size];
        quantize_large_scalar(
            &coeffs,
            &weights,
            qac_qm,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut ref_out,
        );

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = alloc::vec![0i32; size];
                quantize_block_large(
                    &coeffs,
                    &weights,
                    qac_qm,
                    &thresholds,
                    grid_w,
                    grid_h,
                    llf_x,
                    llf_y,
                    &mut simd_out,
                );

                // The low-frequency corner must be cleared regardless of backend.
                for y in 0..llf_y {
                    for x in 0..llf_x {
                        assert_eq!(
                            simd_out[y * grid_w + x],
                            0,
                            "LLF ({},{}) must be 0 [{perm}]",
                            y,
                            x
                        );
                    }
                }

                let mut max_diff = 0i32;
                let mut diff_count = 0;
                for i in 0..size {
                    let diff = (simd_out[i] - ref_out[i]).abs();
                    if diff > 0 {
                        diff_count += 1;
                    }
                    max_diff = max_diff.max(diff);
                }
                assert!(
                    max_diff <= 1,
                    "Max diff: {} (at most 1 due to FP rounding) [{perm}]",
                    max_diff
                );
                // Allow up to 5% of cells to differ by one step.
                let tolerance = size / 20;
                assert!(
                    diff_count <= tolerance,
                    "Too many diffs: {}/{} [{perm}]",
                    diff_count,
                    size
                );
            },
        );
        std::eprintln!("{report}");
    }

    /// 32x32 grid with a 4x4 LLF corner and four distinct quadrant thresholds.
    /// Every cell matches the scalar reference within ±1.
    #[test]
    fn test_quantize_large_dct32x32_matches_scalar() {
        let grid_w = 32;
        let grid_h = 32;
        let size = grid_w * grid_h;
        let llf_x = 4;
        let llf_y = 4;

        let mut coeffs = alloc::vec![0.0f32; size];
        let mut weights = alloc::vec![0.0f32; size];
        for i in 0..size {
            coeffs[i] = ((i as f32) * 0.19 - 80.0) * 0.3;
            weights[i] = 0.005 + (i as f32) * 0.001;
        }

        // Distinct per-quadrant thresholds exercise the quadrant selection.
        let thresholds = [0.54f32, 0.60, 0.58, 0.62];
        let qac_qm = 5.0f32;

        let mut ref_out = alloc::vec![0i32; size];
        quantize_large_scalar(
            &coeffs,
            &weights,
            qac_qm,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut ref_out,
        );

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = alloc::vec![0i32; size];
                quantize_block_large(
                    &coeffs,
                    &weights,
                    qac_qm,
                    &thresholds,
                    grid_w,
                    grid_h,
                    llf_x,
                    llf_y,
                    &mut simd_out,
                );

                for y in 0..llf_y {
                    for x in 0..llf_x {
                        assert_eq!(simd_out[y * grid_w + x], 0, "LLF ({},{}) [{perm}]", y, x);
                    }
                }

                let mut max_diff = 0i32;
                for i in 0..size {
                    let diff = (simd_out[i] - ref_out[i]).abs();
                    max_diff = max_diff.max(diff);
                }
                assert!(max_diff <= 1, "Max diff: {} [{perm}]", max_diff);
            },
        );
        std::eprintln!("{report}");
    }

    /// 64x64 grid with an 8x8 LLF corner. Every cell matches the scalar
    /// reference within ±1.
    #[test]
    fn test_quantize_large_dct64x64_matches_scalar() {
        let grid_w = 64;
        let grid_h = 64;
        let size = grid_w * grid_h;
        let llf_x = 8;
        let llf_y = 8;

        let mut coeffs = alloc::vec![0.0f32; size];
        let mut weights = alloc::vec![0.0f32; size];
        for i in 0..size {
            coeffs[i] = ((i as f32) * 0.07 - 120.0) * 0.2;
            weights[i] = 0.002 + (i as f32) * 0.0005;
        }

        let thresholds = [0.56f32, 0.62, 0.62, 0.62];
        let qac_qm = 3.0f32;

        let mut ref_out = alloc::vec![0i32; size];
        quantize_large_scalar(
            &coeffs,
            &weights,
            qac_qm,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut ref_out,
        );

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = alloc::vec![0i32; size];
                quantize_block_large(
                    &coeffs,
                    &weights,
                    qac_qm,
                    &thresholds,
                    grid_w,
                    grid_h,
                    llf_x,
                    llf_y,
                    &mut simd_out,
                );

                for y in 0..llf_y {
                    for x in 0..llf_x {
                        assert_eq!(simd_out[y * grid_w + x], 0, "LLF ({},{}) [{perm}]", y, x);
                    }
                }

                let mut max_diff = 0i32;
                for i in 0..size {
                    let diff = (simd_out[i] - ref_out[i]).abs();
                    max_diff = max_diff.max(diff);
                }
                assert!(max_diff <= 1, "Max diff: {} [{perm}]", max_diff);
            },
        );
        std::eprintln!("{report}");
    }

    /// Non-square 16-wide x 8-tall grid with a 2x1 LLF corner. Every cell
    /// matches the scalar reference within ±1.
    #[test]
    fn test_quantize_large_nonsquare_16x8() {
        let grid_w = 16;
        let grid_h = 8;
        let size = grid_w * grid_h;
        let llf_x = 2;
        let llf_y = 1;

        let mut coeffs = alloc::vec![0.0f32; size];
        let mut weights = alloc::vec![0.0f32; size];
        for i in 0..size {
            coeffs[i] = ((i as f32) * 0.53 - 30.0) * 0.8;
            weights[i] = 0.02 + (i as f32) * 0.004;
        }

        let thresholds = [0.56f32, 0.62, 0.62, 0.62];
        let qac_qm = 2.5f32;

        let mut ref_out = alloc::vec![0i32; size];
        quantize_large_scalar(
            &coeffs,
            &weights,
            qac_qm,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut ref_out,
        );

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = alloc::vec![0i32; size];
                quantize_block_large(
                    &coeffs,
                    &weights,
                    qac_qm,
                    &thresholds,
                    grid_w,
                    grid_h,
                    llf_x,
                    llf_y,
                    &mut simd_out,
                );

                let mut max_diff = 0i32;
                for i in 0..size {
                    let diff = (simd_out[i] - ref_out[i]).abs();
                    max_diff = max_diff.max(diff);
                }
                assert!(max_diff <= 1, "Max diff: {} [{perm}]", max_diff);
            },
        );
        std::eprintln!("{report}");
    }

    /// All-zero input over a 16x16 grid yields all zeros, overwriting the sentinel 99s.
    #[test]
    fn test_quantize_large_all_zeros() {
        let grid_w = 16;
        let grid_h = 16;
        let size = grid_w * grid_h;

        let coeffs = alloc::vec![0.0f32; size];
        let weights = alloc::vec![1.0f32; size];
        let thresholds = [0.5f32; 4];
        let mut output = alloc::vec![99i32; size];

        quantize_block_large(
            &coeffs,
            &weights,
            1.0,
            &thresholds,
            grid_w,
            grid_h,
            2,
            2,
            &mut output,
        );

        for (i, &val) in output.iter().enumerate() {
            assert_eq!(val, 0, "Index {} should be 0", i);
        }
    }

    /// Only the LLF corner is cleared. Cells just right of it (row 0) and just
    /// below it (column 0) keep their quantized value of 100.
    #[test]
    fn test_quantize_large_llf_zeroed() {
        let grid_w = 32;
        let grid_h = 32;
        let size = grid_w * grid_h;
        let llf_x = 4;
        let llf_y = 4;

        let coeffs = alloc::vec![100.0f32; size];
        let weights = alloc::vec![1.0f32; size];
        // Low threshold so every non-LLF cell survives quantization.
        let thresholds = [0.1f32; 4];
        let mut output = alloc::vec![0i32; size];

        quantize_block_large(
            &coeffs,
            &weights,
            1.0,
            &thresholds,
            grid_w,
            grid_h,
            llf_x,
            llf_y,
            &mut output,
        );

        for y in 0..llf_y {
            for x in 0..llf_x {
                assert_eq!(output[y * grid_w + x], 0, "LLF ({},{}) must be 0", y, x);
            }
        }
        // Boundary cells immediately outside the corner must be untouched by the LLF clear.
        assert_eq!(output[llf_x], 100, "First non-LLF position should be 100");
        assert_eq!(
            output[llf_y * grid_w],
            100,
            "First non-LLF row should be 100"
        );
    }
}