1#[inline]
31#[allow(clippy::too_many_arguments)]
32pub fn dequant_block_dct8(
33 quant_ac_x: &[i32; 64],
34 quant_ac_y: &[i32; 64],
35 quant_ac_b: &[i32; 64],
36 weights_x: &[f32; 64],
37 weights_y: &[f32; 64],
38 weights_b: &[f32; 64],
39 qac_qm: [f32; 3], x_factor: f32,
41 b_factor: f32,
42 output_x: &mut [f32; 64],
43 output_y: &mut [f32; 64],
44 output_b: &mut [f32; 64],
45) {
46 #[cfg(target_arch = "x86_64")]
47 {
48 use archmage::SimdToken;
49 if let Some(token) = archmage::X64V3Token::summon() {
50 dequant_dct8_avx2(
51 token, quant_ac_x, quant_ac_y, quant_ac_b, weights_x, weights_y, weights_b, qac_qm,
52 x_factor, b_factor, output_x, output_y, output_b,
53 );
54 return;
55 }
56 }
57
58 #[cfg(target_arch = "aarch64")]
59 {
60 use archmage::SimdToken;
61 if let Some(token) = archmage::NeonToken::summon() {
62 dequant_dct8_neon(
63 token, quant_ac_x, quant_ac_y, quant_ac_b, weights_x, weights_y, weights_b, qac_qm,
64 x_factor, b_factor, output_x, output_y, output_b,
65 );
66 return;
67 }
68 }
69
70 dequant_dct8_scalar(
71 quant_ac_x, quant_ac_y, quant_ac_b, weights_x, weights_y, weights_b, qac_qm, x_factor,
72 b_factor, output_x, output_y, output_b,
73 );
74}
75
// Quantization-bias constants used by every dequantization path in this file.
// NOTE(review): values presumably mirror libjxl's default dequant biases — confirm.

/// Magnitude returned for an X coefficient quantized to ±1 (the sign of the input is kept).
const BIAS_X: f32 = 0.945_349_93;
/// Magnitude returned for a Y coefficient quantized to ±1 (the sign of the input is kept).
const BIAS_Y: f32 = 0.929_945_5;
/// Magnitude returned for a B coefficient quantized to ±1 (the sign of the input is kept).
const BIAS_B: f32 = 0.950_064_9;
/// Numerator of the `q - BIAS_RECIP / q` correction applied when |q| >= 2.
const BIAS_RECIP: f32 = 0.145;
81
82#[inline(always)]
83fn adjust_quant_bias_scalar(q_int: i32, channel_bias: f32) -> f32 {
84 if q_int == 0 {
85 return 0.0;
86 }
87 let q = q_int as f32;
88 if q.abs() < 1.125 {
89 q.signum() * channel_bias
90 } else {
91 q - BIAS_RECIP / q
92 }
93}
94
95#[inline]
96#[allow(clippy::too_many_arguments)]
97pub fn dequant_dct8_scalar(
98 quant_ac_x: &[i32; 64],
99 quant_ac_y: &[i32; 64],
100 quant_ac_b: &[i32; 64],
101 weights_x: &[f32; 64],
102 weights_y: &[f32; 64],
103 weights_b: &[f32; 64],
104 qac_qm: [f32; 3],
105 x_factor: f32,
106 b_factor: f32,
107 output_x: &mut [f32; 64],
108 output_y: &mut [f32; 64],
109 output_b: &mut [f32; 64],
110) {
111 let inv_qac_x = 1.0 / qac_qm[0];
112 let inv_qac_y = 1.0 / qac_qm[1];
113 let inv_qac_b = 1.0 / qac_qm[2];
114
115 output_x[0] = 0.0;
117 output_y[0] = 0.0;
118 output_b[0] = 0.0;
119
120 for i in 1..64 {
121 let biased_x = adjust_quant_bias_scalar(quant_ac_x[i], BIAS_X);
123 let biased_y = adjust_quant_bias_scalar(quant_ac_y[i], BIAS_Y);
124 let biased_b = adjust_quant_bias_scalar(quant_ac_b[i], BIAS_B);
125
126 let dq_y = biased_y * weights_y[i] * inv_qac_y;
127 output_y[i] = dq_y;
128
129 output_x[i] = biased_x * weights_x[i] * inv_qac_x + x_factor * dq_y;
131 output_b[i] = biased_b * weights_b[i] * inv_qac_b + b_factor * dq_y;
132 }
133}
134
135#[cfg(target_arch = "x86_64")]
136#[inline]
137#[archmage::arcane]
138#[allow(clippy::too_many_arguments)]
139pub fn dequant_dct8_avx2(
140 token: archmage::X64V3Token,
141 quant_ac_x: &[i32; 64],
142 quant_ac_y: &[i32; 64],
143 quant_ac_b: &[i32; 64],
144 weights_x: &[f32; 64],
145 weights_y: &[f32; 64],
146 weights_b: &[f32; 64],
147 qac_qm: [f32; 3],
148 x_factor: f32,
149 b_factor: f32,
150 output_x: &mut [f32; 64],
151 output_y: &mut [f32; 64],
152 output_b: &mut [f32; 64],
153) {
154 use magetypes::simd::{f32x8, i32x8};
155
156 let inv_qac_x_v = f32x8::splat(token, 1.0 / qac_qm[0]);
157 let inv_qac_y_v = f32x8::splat(token, 1.0 / qac_qm[1]);
158 let inv_qac_b_v = f32x8::splat(token, 1.0 / qac_qm[2]);
159 let x_factor_v = f32x8::splat(token, x_factor);
160 let b_factor_v = f32x8::splat(token, b_factor);
161 let zero_f = f32x8::zero(token);
162 let zero_i = i32x8::zero(token);
163 let one_f = f32x8::splat(token, 1.0);
164 let neg_one_f = f32x8::splat(token, -1.0);
165 let threshold = f32x8::splat(token, 1.125);
166 let bias_recip_v = f32x8::splat(token, BIAS_RECIP);
167 let bias_x_v = f32x8::splat(token, BIAS_X);
168 let bias_y_v = f32x8::splat(token, BIAS_Y);
169 let bias_b_v = f32x8::splat(token, BIAS_B);
170
171 for chunk in 0..8 {
173 let base = chunk * 8;
174
175 let q_i_y = i32x8::from_slice(token, &quant_ac_y[base..]);
177 let dq_y = dequant_8_avx2(
178 token,
179 q_i_y,
180 bias_y_v,
181 bias_recip_v,
182 threshold,
183 zero_i,
184 zero_f,
185 one_f,
186 neg_one_f,
187 &weights_y[base..],
188 inv_qac_y_v,
189 );
190 dq_y.store((&mut output_y[base..base + 8]).try_into().unwrap());
191
192 let q_i_x = i32x8::from_slice(token, &quant_ac_x[base..]);
194 let dq_x_raw = dequant_8_avx2(
195 token,
196 q_i_x,
197 bias_x_v,
198 bias_recip_v,
199 threshold,
200 zero_i,
201 zero_f,
202 one_f,
203 neg_one_f,
204 &weights_x[base..],
205 inv_qac_x_v,
206 );
207 let dq_x = dq_x_raw + x_factor_v * dq_y;
208 dq_x.store((&mut output_x[base..base + 8]).try_into().unwrap());
209
210 let q_i_b = i32x8::from_slice(token, &quant_ac_b[base..]);
212 let dq_b_raw = dequant_8_avx2(
213 token,
214 q_i_b,
215 bias_b_v,
216 bias_recip_v,
217 threshold,
218 zero_i,
219 zero_f,
220 one_f,
221 neg_one_f,
222 &weights_b[base..],
223 inv_qac_b_v,
224 );
225 let dq_b = dq_b_raw + b_factor_v * dq_y;
226 dq_b.store((&mut output_b[base..base + 8]).try_into().unwrap());
227 }
228
229 output_x[0] = 0.0;
231 output_y[0] = 0.0;
232 output_b[0] = 0.0;
233}
234
235#[cfg(target_arch = "x86_64")]
243#[archmage::arcane]
244#[inline(always)]
245#[allow(clippy::too_many_arguments)]
246fn dequant_8_avx2(
247 token: archmage::X64V3Token,
248 q_int: magetypes::simd::i32x8,
249 channel_bias: magetypes::simd::f32x8,
250 bias_recip: magetypes::simd::f32x8,
251 threshold: magetypes::simd::f32x8,
252 _zero_i: magetypes::simd::i32x8,
253 zero_f: magetypes::simd::f32x8,
254 one_f: magetypes::simd::f32x8,
255 neg_one_f: magetypes::simd::f32x8,
256 weights: &[f32],
257 inv_qac_qm: magetypes::simd::f32x8,
258) -> magetypes::simd::f32x8 {
259 use magetypes::simd::f32x8;
260
261 let q_f = q_int.to_f32x8();
263 let abs_q = q_f.abs();
264
265 let sign = f32x8::blend(q_f.simd_ge(zero_f), one_f, neg_one_f);
267
268 let case_one = sign * channel_bias;
271
272 let case_large = q_f - bias_recip / q_f;
275
276 let is_large = abs_q.simd_ge(threshold);
278 let biased = f32x8::blend(is_large, case_large, case_one);
279
280 let is_nonzero = abs_q.simd_ge(f32x8::splat(token, 0.5)); let biased = f32x8::blend(is_nonzero, biased, zero_f);
283
284 let w = f32x8::from_slice(token, weights);
286 biased * w * inv_qac_qm
287}
288
/// NEON implementation of [`dequant_block_dct8`], four coefficients per step.
///
/// Computes the same per-coefficient formula as [`dequant_dct8_scalar`]. All 64 lanes are
/// processed (including index 0), and coefficient 0 of every output is then overwritten with
/// `0.0`.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::too_many_arguments)]
pub fn dequant_dct8_neon(
    token: archmage::NeonToken,
    quant_ac_x: &[i32; 64],
    quant_ac_y: &[i32; 64],
    quant_ac_b: &[i32; 64],
    weights_x: &[f32; 64],
    weights_y: &[f32; 64],
    weights_b: &[f32; 64],
    qac_qm: [f32; 3],
    x_factor: f32,
    b_factor: f32,
    output_x: &mut [f32; 64],
    output_y: &mut [f32; 64],
    output_b: &mut [f32; 64],
) {
    use magetypes::simd::{f32x4, i32x4};

    // Per-channel reciprocal quant multipliers and Y cross-channel factors.
    let scale_x = f32x4::splat(token, 1.0 / qac_qm[0]);
    let scale_y = f32x4::splat(token, 1.0 / qac_qm[1]);
    let scale_b = f32x4::splat(token, 1.0 / qac_qm[2]);
    let y_into_x = f32x4::splat(token, x_factor);
    let y_into_b = f32x4::splat(token, b_factor);

    // Lane constants shared by every `neon_dequant_4` call.
    let zeros = f32x4::zero(token);
    let plus_one = f32x4::splat(token, 1.0);
    let minus_one = f32x4::splat(token, -1.0);
    let large_cutoff = f32x4::splat(token, 1.125);
    let recip = f32x4::splat(token, BIAS_RECIP);
    let unit_bias_x = f32x4::splat(token, BIAS_X);
    let unit_bias_y = f32x4::splat(token, BIAS_Y);
    let unit_bias_b = f32x4::splat(token, BIAS_B);
    let nonzero_cutoff = f32x4::splat(token, 0.5);

    for base in (0..64usize).step_by(4) {
        let end = base + 4;

        // Y first: X and B both add a scaled copy of it.
        let y = neon_dequant_4(
            token,
            i32x4::from_slice(token, &quant_ac_y[base..]),
            unit_bias_y,
            recip,
            large_cutoff,
            zeros,
            plus_one,
            minus_one,
            nonzero_cutoff,
            &weights_y[base..],
            scale_y,
        );
        y.store((&mut output_y[base..end]).try_into().unwrap());

        let x = neon_dequant_4(
            token,
            i32x4::from_slice(token, &quant_ac_x[base..]),
            unit_bias_x,
            recip,
            large_cutoff,
            zeros,
            plus_one,
            minus_one,
            nonzero_cutoff,
            &weights_x[base..],
            scale_x,
        ) + y_into_x * y;
        x.store((&mut output_x[base..end]).try_into().unwrap());

        let b = neon_dequant_4(
            token,
            i32x4::from_slice(token, &quant_ac_b[base..]),
            unit_bias_b,
            recip,
            large_cutoff,
            zeros,
            plus_one,
            minus_one,
            nonzero_cutoff,
            &weights_b[base..],
            scale_b,
        ) + y_into_b * y;
        b.store((&mut output_b[base..end]).try_into().unwrap());
    }

    // Index 0 was computed above only as a by-product of full-width loads; clear it.
    output_x[0] = 0.0;
    output_y[0] = 0.0;
    output_b[0] = 0.0;
}
389
/// Bias-adjusts and dequantizes four quantized coefficients of a single channel (NEON).
///
/// Lane-wise equivalent of `adjust_quant_bias_scalar(q) * weights[i] * inv_qac_qm`:
/// * `q == 0` gives `0.0`;
/// * `|q| == 1` gives `±channel_bias`;
/// * otherwise gives `q - bias_recip / q`.
///
/// `half_v` is expected to be a splat of `0.5`, which separates integer-zero lanes from the
/// rest. `weights` must hold at least four elements.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
#[allow(clippy::too_many_arguments)]
fn neon_dequant_4(
    token: archmage::NeonToken,
    q_int: magetypes::simd::i32x4,
    channel_bias: magetypes::simd::f32x4,
    bias_recip: magetypes::simd::f32x4,
    threshold: magetypes::simd::f32x4,
    zero_f: magetypes::simd::f32x4,
    one_f: magetypes::simd::f32x4,
    neg_one_f: magetypes::simd::f32x4,
    half_v: magetypes::simd::f32x4,
    weights: &[f32],
    inv_qac_qm: magetypes::simd::f32x4,
) -> magetypes::simd::f32x4 {
    use magetypes::simd::f32x4;

    let q = f32x4::from_i32x4(q_int);
    let magnitude = q.abs();

    // Candidate for |q| == 1: channel bias with the sign of q.
    let near = f32x4::blend(q.simd_ge(zero_f), one_f, neg_one_f) * channel_bias;
    // Candidate for |q| >= 2. Zero lanes yield a non-finite value here; it is masked out below.
    let far = q - bias_recip / q;

    let picked = f32x4::blend(magnitude.simd_ge(threshold), far, near);
    let adjusted = f32x4::blend(magnitude.simd_ge(half_v), picked, zero_f);

    adjusted * f32x4::from_slice(token, weights) * inv_qac_qm
}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` and `expected` differ by less than `eps` at every coefficient,
    /// naming the channel and index of the first mismatch.
    fn assert_close(channel: &str, actual: &[f32; 64], expected: &[f32; 64], eps: f32) {
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            let diff = (a - e).abs();
            assert!(
                diff < eps,
                "{}[{}] mismatch: simd={}, ref={}, diff={}",
                channel,
                i,
                a,
                e,
                diff
            );
        }
    }

    #[test]
    fn test_adjust_quant_bias_scalar() {
        // Zero stays exactly zero.
        assert_eq!(adjust_quant_bias_scalar(0, BIAS_X), 0.0);
        // ±1 map to ±channel bias.
        assert_eq!(adjust_quant_bias_scalar(1, BIAS_Y), BIAS_Y);
        assert_eq!(adjust_quant_bias_scalar(-1, BIAS_B), -BIAS_B);
        // Larger magnitudes get the reciprocal correction.
        let q = 3.0f32;
        assert_eq!(adjust_quant_bias_scalar(3, BIAS_X), q - BIAS_RECIP / q);
        let q = -2.0f32;
        assert_eq!(adjust_quant_bias_scalar(-2, BIAS_X), q - BIAS_RECIP / q);
    }

    #[test]
    fn test_dequant_dct8_matches_scalar() {
        let mut quant_x = [0i32; 64];
        let mut quant_y = [0i32; 64];
        let mut quant_b = [0i32; 64];
        let mut weights_x = [0.01f32; 64];
        let mut weights_y = [0.01f32; 64];
        let mut weights_b = [0.01f32; 64];

        for i in 0..64 {
            // Covers negative, zero, ±1 and larger magnitudes in every channel.
            let v = (i as i32) - 32;
            quant_x[i] = v;
            quant_y[i] = v / 2;
            quant_b[i] = -v;
            weights_x[i] = 0.01 + i as f32 * 0.001;
            weights_y[i] = 0.02 + i as f32 * 0.0005;
            weights_b[i] = 0.015 + i as f32 * 0.0008;
        }

        let qac_qm = [3.5f32, 4.0, 3.2];
        let x_factor = 0.15f32;
        let b_factor = 1.05f32;

        let mut ref_x = [0.0f32; 64];
        let mut ref_y = [0.0f32; 64];
        let mut ref_b = [0.0f32; 64];
        dequant_dct8_scalar(
            &quant_x, &quant_y, &quant_b, &weights_x, &weights_y, &weights_b, qac_qm, x_factor,
            b_factor, &mut ref_x, &mut ref_y, &mut ref_b,
        );

        let mut out_x = [0.0f32; 64];
        let mut out_y = [0.0f32; 64];
        let mut out_b = [0.0f32; 64];
        dequant_block_dct8(
            &quant_x, &quant_y, &quant_b, &weights_x, &weights_y, &weights_b, qac_qm, x_factor,
            b_factor, &mut out_x, &mut out_y, &mut out_b,
        );

        let eps = 1e-5;
        assert_close("X", &out_x, &ref_x, eps);
        assert_close("Y", &out_y, &ref_y, eps);
        assert_close("B", &out_b, &ref_b, eps);
    }

    #[test]
    fn test_dequant_dct8_all_zeros() {
        let quant = [0i32; 64];
        let weights = [1.0f32; 64];
        let qac_qm = [1.0f32; 3];

        // Non-zero sentinel so that index 0 and every AC slot must actually be written.
        let mut out_x = [99.0f32; 64];
        let mut out_y = [99.0f32; 64];
        let mut out_b = [99.0f32; 64];
        dequant_block_dct8(
            &quant, &quant, &quant, &weights, &weights, &weights, qac_qm, 0.1, 1.0, &mut out_x,
            &mut out_y, &mut out_b,
        );

        for i in 0..64 {
            assert_eq!(out_x[i], 0.0, "X[{}] should be 0 for zero input", i);
            assert_eq!(out_y[i], 0.0, "Y[{}] should be 0 for zero input", i);
            assert_eq!(out_b[i], 0.0, "B[{}] should be 0 for zero input", i);
        }
    }

    #[test]
    fn test_dequant_dct8_unit_values() {
        // Alternating ±1 exercises the per-channel bias lanes of every implementation.
        let mut quant = [0i32; 64];
        for (i, q) in quant.iter_mut().enumerate().skip(1) {
            *q = if i % 2 == 0 { 1 } else { -1 };
        }
        let weights = [1.0f32; 64];
        let qac_qm = [1.0f32, 1.0, 1.0];

        let mut out_x = [0.0f32; 64];
        let mut out_y = [0.0f32; 64];
        let mut out_b = [0.0f32; 64];
        let mut ref_x = [0.0f32; 64];
        let mut ref_y = [0.0f32; 64];
        let mut ref_b = [0.0f32; 64];

        dequant_block_dct8(
            &quant, &quant, &quant, &weights, &weights, &weights, qac_qm, 0.0, 0.0, &mut out_x,
            &mut out_y, &mut out_b,
        );
        dequant_dct8_scalar(
            &quant, &quant, &quant, &weights, &weights, &weights, qac_qm, 0.0, 0.0, &mut ref_x,
            &mut ref_y, &mut ref_b,
        );

        // All three channels are checked: each uses a different bias constant.
        let eps = 1e-6;
        assert_close("X", &out_x, &ref_x, eps);
        assert_close("Y", &out_y, &ref_y, eps);
        assert_close("B", &out_b, &ref_b, eps);
    }
}