// jxl_encoder_simd/dequant.rs

// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
// Algorithms and constants derived from libjxl (BSD-3-Clause).
// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing

//! SIMD-accelerated AC coefficient dequantization for DCT8 blocks.
//!
//! The reconstruct_xyb inner dequant loop is ~7% of encoder CPU. For DCT8
//! (the most common strategy), we process 64 coefficients in one pass:
//! dequantize + adjust_quant_bias + CfL restore.

11/// Dequantize a DCT8 block and apply CfL (chroma-from-luma) in one pass.
12///
13/// For each channel c and coefficient i (except DC at index 0):
14///   biased = adjust_quant_bias(quant_ac[c][i], c)
15///   dequant[c][i] = biased * weights[c][i] / (qac * qm_mul[c])
16///
17/// Then CfL restore (AC positions only):
18///   dequant[X][i] += x_factor * dequant[Y][i]
19///   dequant[B][i] += b_factor * dequant[Y][i]
20///
21/// DC (index 0) is left as-is in the output (caller must restore LLF from DC).
22///
23/// # Parameters
24/// - `quant_ac_x/y/b`: Quantized AC coefficients per channel, [i32; 64]
25/// - `weights_x/y/b`: Dequantization weights per channel, [f32; 64]
26/// - `qac_qm`: Per-channel `qac * qm_mul` values [x, y, b]
27/// - `x_factor`: CfL ytox ratio for this tile
28/// - `b_factor`: CfL ytob ratio for this tile
29/// - `output_x/y/b`: Output dequantized coefficients per channel, [f32; 64]
30#[inline]
31#[allow(clippy::too_many_arguments)]
32pub fn dequant_block_dct8(
33    quant_ac_x: &[i32; 64],
34    quant_ac_y: &[i32; 64],
35    quant_ac_b: &[i32; 64],
36    weights_x: &[f32; 64],
37    weights_y: &[f32; 64],
38    weights_b: &[f32; 64],
39    qac_qm: [f32; 3], // [x, y, b]
40    x_factor: f32,
41    b_factor: f32,
42    output_x: &mut [f32; 64],
43    output_y: &mut [f32; 64],
44    output_b: &mut [f32; 64],
45) {
46    #[cfg(target_arch = "x86_64")]
47    {
48        use archmage::SimdToken;
49        if let Some(token) = archmage::X64V3Token::summon() {
50            dequant_dct8_avx2(
51                token, quant_ac_x, quant_ac_y, quant_ac_b, weights_x, weights_y, weights_b, qac_qm,
52                x_factor, b_factor, output_x, output_y, output_b,
53            );
54            return;
55        }
56    }
57
58    #[cfg(target_arch = "aarch64")]
59    {
60        use archmage::SimdToken;
61        if let Some(token) = archmage::NeonToken::summon() {
62            dequant_dct8_neon(
63                token, quant_ac_x, quant_ac_y, quant_ac_b, weights_x, weights_y, weights_b, qac_qm,
64                x_factor, b_factor, output_x, output_y, output_b,
65            );
66            return;
67        }
68    }
69
70    dequant_dct8_scalar(
71        quant_ac_x, quant_ac_y, quant_ac_b, weights_x, weights_y, weights_b, qac_qm, x_factor,
72        b_factor, output_x, output_y, output_b,
73    );
74}
75
// AdjustQuantBias constants (pre-computed to avoid clippy lossy_float_literal)
const BIAS_X: f32 = 0.945_349_93; // 1.0 - 0.054_650_073
const BIAS_Y: f32 = 0.929_945_5; // 1.0 - 0.070_054_499
const BIAS_B: f32 = 0.950_064_9; // 1.0 - 0.049_935_103
// Numerator of the `q - 0.145 / q` correction applied when |q| >= 2
// (despite the name, this value is divided by q; it is not itself a reciprocal).
const BIAS_RECIP: f32 = 0.145;

/// Apply the quantization-bias adjustment to one quantized coefficient.
///
/// - `q_int == 0`  -> `0.0`
/// - `|q_int| == 1` -> `sign(q_int) * channel_bias`
/// - otherwise     -> `q - BIAS_RECIP / q`
///
/// `channel_bias` is the per-channel magnitude used for the ±1 case
/// (callers pass `BIAS_X`, `BIAS_Y` or `BIAS_B`).
#[inline(always)]
fn adjust_quant_bias_scalar(q_int: i32, channel_bias: f32) -> f32 {
    if q_int == 0 {
        return 0.0;
    }
    let q = q_int as f32;
    // For integer inputs `|q| < 1.125` is exactly the q == ±1 case.
    if q.abs() < 1.125 {
        q.signum() * channel_bias
    } else {
        q - BIAS_RECIP / q
    }
}

95#[inline]
96#[allow(clippy::too_many_arguments)]
97pub fn dequant_dct8_scalar(
98    quant_ac_x: &[i32; 64],
99    quant_ac_y: &[i32; 64],
100    quant_ac_b: &[i32; 64],
101    weights_x: &[f32; 64],
102    weights_y: &[f32; 64],
103    weights_b: &[f32; 64],
104    qac_qm: [f32; 3],
105    x_factor: f32,
106    b_factor: f32,
107    output_x: &mut [f32; 64],
108    output_y: &mut [f32; 64],
109    output_b: &mut [f32; 64],
110) {
111    let inv_qac_x = 1.0 / qac_qm[0];
112    let inv_qac_y = 1.0 / qac_qm[1];
113    let inv_qac_b = 1.0 / qac_qm[2];
114
115    // DC (index 0) must be zeroed — it's restored from DC separately
116    output_x[0] = 0.0;
117    output_y[0] = 0.0;
118    output_b[0] = 0.0;
119
120    for i in 1..64 {
121        // Dequantize each channel
122        let biased_x = adjust_quant_bias_scalar(quant_ac_x[i], BIAS_X);
123        let biased_y = adjust_quant_bias_scalar(quant_ac_y[i], BIAS_Y);
124        let biased_b = adjust_quant_bias_scalar(quant_ac_b[i], BIAS_B);
125
126        let dq_y = biased_y * weights_y[i] * inv_qac_y;
127        output_y[i] = dq_y;
128
129        // CfL restore: X += ytox * Y, B += ytob * Y
130        output_x[i] = biased_x * weights_x[i] * inv_qac_x + x_factor * dq_y;
131        output_b[i] = biased_b * weights_b[i] * inv_qac_b + b_factor * dq_y;
132    }
133}
134
135#[cfg(target_arch = "x86_64")]
136#[inline]
137#[archmage::arcane]
138#[allow(clippy::too_many_arguments)]
139pub fn dequant_dct8_avx2(
140    token: archmage::X64V3Token,
141    quant_ac_x: &[i32; 64],
142    quant_ac_y: &[i32; 64],
143    quant_ac_b: &[i32; 64],
144    weights_x: &[f32; 64],
145    weights_y: &[f32; 64],
146    weights_b: &[f32; 64],
147    qac_qm: [f32; 3],
148    x_factor: f32,
149    b_factor: f32,
150    output_x: &mut [f32; 64],
151    output_y: &mut [f32; 64],
152    output_b: &mut [f32; 64],
153) {
154    use magetypes::simd::{f32x8, i32x8};
155
156    let inv_qac_x_v = f32x8::splat(token, 1.0 / qac_qm[0]);
157    let inv_qac_y_v = f32x8::splat(token, 1.0 / qac_qm[1]);
158    let inv_qac_b_v = f32x8::splat(token, 1.0 / qac_qm[2]);
159    let x_factor_v = f32x8::splat(token, x_factor);
160    let b_factor_v = f32x8::splat(token, b_factor);
161    let zero_f = f32x8::zero(token);
162    let zero_i = i32x8::zero(token);
163    let one_f = f32x8::splat(token, 1.0);
164    let neg_one_f = f32x8::splat(token, -1.0);
165    let threshold = f32x8::splat(token, 1.125);
166    let bias_recip_v = f32x8::splat(token, BIAS_RECIP);
167    let bias_x_v = f32x8::splat(token, BIAS_X);
168    let bias_y_v = f32x8::splat(token, BIAS_Y);
169    let bias_b_v = f32x8::splat(token, BIAS_B);
170
171    // Process 8 chunks of 8 coefficients
172    for chunk in 0..8 {
173        let base = chunk * 8;
174
175        // --- Channel Y ---
176        let q_i_y = i32x8::from_slice(token, &quant_ac_y[base..]);
177        let dq_y = dequant_8_avx2(
178            token,
179            q_i_y,
180            bias_y_v,
181            bias_recip_v,
182            threshold,
183            zero_i,
184            zero_f,
185            one_f,
186            neg_one_f,
187            &weights_y[base..],
188            inv_qac_y_v,
189        );
190        dq_y.store((&mut output_y[base..base + 8]).try_into().unwrap());
191
192        // --- Channel X + CfL ---
193        let q_i_x = i32x8::from_slice(token, &quant_ac_x[base..]);
194        let dq_x_raw = dequant_8_avx2(
195            token,
196            q_i_x,
197            bias_x_v,
198            bias_recip_v,
199            threshold,
200            zero_i,
201            zero_f,
202            one_f,
203            neg_one_f,
204            &weights_x[base..],
205            inv_qac_x_v,
206        );
207        let dq_x = dq_x_raw + x_factor_v * dq_y;
208        dq_x.store((&mut output_x[base..base + 8]).try_into().unwrap());
209
210        // --- Channel B + CfL ---
211        let q_i_b = i32x8::from_slice(token, &quant_ac_b[base..]);
212        let dq_b_raw = dequant_8_avx2(
213            token,
214            q_i_b,
215            bias_b_v,
216            bias_recip_v,
217            threshold,
218            zero_i,
219            zero_f,
220            one_f,
221            neg_one_f,
222            &weights_b[base..],
223            inv_qac_b_v,
224        );
225        let dq_b = dq_b_raw + b_factor_v * dq_y;
226        dq_b.store((&mut output_b[base..base + 8]).try_into().unwrap());
227    }
228
229    // DC (index 0) must be zeroed — it's restored from DC separately
230    output_x[0] = 0.0;
231    output_y[0] = 0.0;
232    output_b[0] = 0.0;
233}
234
235/// Dequantize 8 coefficients with adjust_quant_bias, branchless SIMD.
236///
237/// For each element:
238///   if q == 0: result = 0
239///   elif |q| == 1: result = sign(q) * channel_bias
240///   else: result = q - 0.145/q
241///   output = result * weight / (qac * qm_mul)
242#[cfg(target_arch = "x86_64")]
243#[archmage::arcane]
244#[inline(always)]
245#[allow(clippy::too_many_arguments)]
246fn dequant_8_avx2(
247    token: archmage::X64V3Token,
248    q_int: magetypes::simd::i32x8,
249    channel_bias: magetypes::simd::f32x8,
250    bias_recip: magetypes::simd::f32x8,
251    threshold: magetypes::simd::f32x8,
252    _zero_i: magetypes::simd::i32x8,
253    zero_f: magetypes::simd::f32x8,
254    one_f: magetypes::simd::f32x8,
255    neg_one_f: magetypes::simd::f32x8,
256    weights: &[f32],
257    inv_qac_qm: magetypes::simd::f32x8,
258) -> magetypes::simd::f32x8 {
259    use magetypes::simd::f32x8;
260
261    // Convert q to f32
262    let q_f = q_int.to_f32x8();
263    let abs_q = q_f.abs();
264
265    // Compute sign: 1.0 for positive, -1.0 for negative (0 handled by zero_mask)
266    let sign = f32x8::blend(q_f.simd_ge(zero_f), one_f, neg_one_f);
267
268    // Case 1: |q| < 1.125 (i.e., q == ±1 for integer inputs)
269    // Result: sign * channel_bias
270    let case_one = sign * channel_bias;
271
272    // Case 2: |q| >= 1.125 (i.e., |q| >= 2 for integer inputs)
273    // Result: q - 0.145 / q
274    let case_large = q_f - bias_recip / q_f;
275
276    // Select: if |q| < 1.125 use case_one, else case_large
277    let is_large = abs_q.simd_ge(threshold);
278    let biased = f32x8::blend(is_large, case_large, case_one);
279
280    // Zero out where q == 0 (compare in f32 space since blend needs f32 mask)
281    let is_nonzero = abs_q.simd_ge(f32x8::splat(token, 0.5)); // integers: |q|>=0.5 means nonzero
282    let biased = f32x8::blend(is_nonzero, biased, zero_f);
283
284    // Multiply by dequant weight and inverse qac_qm
285    let w = f32x8::from_slice(token, weights);
286    biased * w * inv_qac_qm
287}
288
// --- aarch64 NEON implementation ---

/// aarch64 vector implementation of the DCT8 dequantization + CfL restore,
/// processing the 64 coefficients as 16 chunks of 4 lanes.
///
/// Per chunk, Y is dequantized first and stored; X and B are dequantized and
/// then have `x_factor * Y` / `b_factor * Y` added before being stored.
/// The first chunk also computes a value for index 0 from the quantized DC
/// input, so index 0 of all three outputs is overwritten with `0.0` after
/// the loop.
///
/// `token` is the capability token obtained from
/// `archmage::NeonToken::summon()`; the remaining parameters have the same
/// meaning as in `dequant_block_dct8`.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::too_many_arguments)]
pub fn dequant_dct8_neon(
    token: archmage::NeonToken,
    quant_ac_x: &[i32; 64],
    quant_ac_y: &[i32; 64],
    quant_ac_b: &[i32; 64],
    weights_x: &[f32; 64],
    weights_y: &[f32; 64],
    weights_b: &[f32; 64],
    qac_qm: [f32; 3],
    x_factor: f32,
    b_factor: f32,
    output_x: &mut [f32; 64],
    output_y: &mut [f32; 64],
    output_b: &mut [f32; 64],
) {
    use magetypes::simd::{f32x4, i32x4};

    // Loop-invariant broadcasts, built once.
    let inv_qac_x_v = f32x4::splat(token, 1.0 / qac_qm[0]);
    let inv_qac_y_v = f32x4::splat(token, 1.0 / qac_qm[1]);
    let inv_qac_b_v = f32x4::splat(token, 1.0 / qac_qm[2]);
    let x_factor_v = f32x4::splat(token, x_factor);
    let b_factor_v = f32x4::splat(token, b_factor);
    let zero_f = f32x4::zero(token);
    let one_f = f32x4::splat(token, 1.0);
    let neg_one_f = f32x4::splat(token, -1.0);
    let threshold = f32x4::splat(token, 1.125);
    let bias_recip_v = f32x4::splat(token, BIAS_RECIP);
    let bias_x_v = f32x4::splat(token, BIAS_X);
    let bias_y_v = f32x4::splat(token, BIAS_Y);
    let bias_b_v = f32x4::splat(token, BIAS_B);
    // Threshold for the "q is nonzero" test on integer-valued lanes.
    let half_v = f32x4::splat(token, 0.5);

    // Process 16 chunks of 4 coefficients
    for chunk in 0..16 {
        let base = chunk * 4;

        // Channel Y
        let q_i_y = i32x4::from_slice(token, &quant_ac_y[base..]);
        let dq_y = neon_dequant_4(
            token,
            q_i_y,
            bias_y_v,
            bias_recip_v,
            threshold,
            zero_f,
            one_f,
            neg_one_f,
            half_v,
            &weights_y[base..],
            inv_qac_y_v,
        );
        dq_y.store((&mut output_y[base..base + 4]).try_into().unwrap());

        // Channel X + CfL
        let q_i_x = i32x4::from_slice(token, &quant_ac_x[base..]);
        let dq_x_raw = neon_dequant_4(
            token,
            q_i_x,
            bias_x_v,
            bias_recip_v,
            threshold,
            zero_f,
            one_f,
            neg_one_f,
            half_v,
            &weights_x[base..],
            inv_qac_x_v,
        );
        let dq_x = dq_x_raw + x_factor_v * dq_y;
        dq_x.store((&mut output_x[base..base + 4]).try_into().unwrap());

        // Channel B + CfL
        let q_i_b = i32x4::from_slice(token, &quant_ac_b[base..]);
        let dq_b_raw = neon_dequant_4(
            token,
            q_i_b,
            bias_b_v,
            bias_recip_v,
            threshold,
            zero_f,
            one_f,
            neg_one_f,
            half_v,
            &weights_b[base..],
            inv_qac_b_v,
        );
        let dq_b = dq_b_raw + b_factor_v * dq_y;
        dq_b.store((&mut output_b[base..base + 4]).try_into().unwrap());
    }

    // DC (index 0) must be zeroed — chunk 0 wrote a computed value there.
    output_x[0] = 0.0;
    output_y[0] = 0.0;
    output_b[0] = 0.0;
}

/// Dequantize 4 coefficients with adjust_quant_bias, branchless NEON.
///
/// For each element:
///   if q == 0: result = 0
///   elif |q| == 1: result = sign(q) * channel_bias
///   else: result = q - bias_recip / q
///   output = result * weight * inv_qac_qm
///
/// All candidate results are computed for every lane and the final value is
/// chosen with blends. `half_v` is the comparison threshold used to detect
/// nonzero lanes (the caller passes 0.5). `weights` is loaded with
/// `f32x4::from_slice`; the caller passes a slice starting at the chunk base.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
#[allow(clippy::too_many_arguments)]
fn neon_dequant_4(
    token: archmage::NeonToken,
    q_int: magetypes::simd::i32x4,
    channel_bias: magetypes::simd::f32x4,
    bias_recip: magetypes::simd::f32x4,
    threshold: magetypes::simd::f32x4,
    zero_f: magetypes::simd::f32x4,
    one_f: magetypes::simd::f32x4,
    neg_one_f: magetypes::simd::f32x4,
    half_v: magetypes::simd::f32x4,
    weights: &[f32],
    inv_qac_qm: magetypes::simd::f32x4,
) -> magetypes::simd::f32x4 {
    use magetypes::simd::f32x4;

    let q_f = f32x4::from_i32x4(q_int);
    let abs_q = q_f.abs();

    // +1.0 where q >= 0, -1.0 otherwise; the q == 0 lanes are zeroed below.
    let sign = f32x4::blend(q_f.simd_ge(zero_f), one_f, neg_one_f);

    // Case 1: |q| < 1.125 → sign * channel_bias
    let case_one = sign * channel_bias;

    // Case 2: |q| >= 1.125 → q - 0.145/q
    // Lanes with q == 0 divide by zero here; they are replaced by 0 below.
    let case_large = q_f - bias_recip / q_f;

    let is_large = abs_q.simd_ge(threshold);
    let biased = f32x4::blend(is_large, case_large, case_one);

    // Zero out where q == 0
    let is_nonzero = abs_q.simd_ge(half_v);
    let biased = f32x4::blend(is_nonzero, biased, zero_f);

    // Multiply by dequant weight and inverse qac_qm
    let w = f32x4::from_slice(token, weights);
    biased * w * inv_qac_qm
}

#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;

    /// The dispatched implementation must agree with the scalar reference on
    /// a mix of zero, ±1 and larger quantized values, with CfL enabled.
    #[test]
    fn test_dequant_dct8_matches_scalar() {
        let mut quant_x = [0i32; 64];
        let mut quant_y = [0i32; 64];
        let mut quant_b = [0i32; 64];
        let mut weights_x = [0.01f32; 64];
        let mut weights_y = [0.01f32; 64];
        let mut weights_b = [0.01f32; 64];

        // Fill with varied values: 0, ±1, and larger
        for i in 0..64 {
            let v = (i as i32) - 32; // -32..+31
            quant_x[i] = v;
            quant_y[i] = v / 2;
            quant_b[i] = -v;
            weights_x[i] = 0.01 + i as f32 * 0.001;
            weights_y[i] = 0.02 + i as f32 * 0.0005;
            weights_b[i] = 0.015 + i as f32 * 0.0008;
        }

        let qac_qm = [3.5f32, 4.0, 3.2];
        let x_factor = 0.15f32;
        let b_factor = 1.05f32;

        // Scalar reference
        let mut ref_x = [0.0f32; 64];
        let mut ref_y = [0.0f32; 64];
        let mut ref_b = [0.0f32; 64];
        dequant_dct8_scalar(
            &quant_x, &quant_y, &quant_b, &weights_x, &weights_y, &weights_b, qac_qm, x_factor,
            b_factor, &mut ref_x, &mut ref_y, &mut ref_b,
        );

        // SIMD
        let mut out_x = [0.0f32; 64];
        let mut out_y = [0.0f32; 64];
        let mut out_b = [0.0f32; 64];
        dequant_block_dct8(
            &quant_x, &quant_y, &quant_b, &weights_x, &weights_y, &weights_b, qac_qm, x_factor,
            b_factor, &mut out_x, &mut out_y, &mut out_b,
        );

        // Compare — DC (index 0) should be 0 from both paths
        let eps = 1e-5;
        for i in 0..64 {
            let diff_x = (out_x[i] - ref_x[i]).abs();
            let diff_y = (out_y[i] - ref_y[i]).abs();
            let diff_b = (out_b[i] - ref_b[i]).abs();
            assert!(
                diff_x < eps,
                "X[{}] mismatch: simd={}, ref={}, diff={}",
                i,
                out_x[i],
                ref_x[i],
                diff_x
            );
            assert!(
                diff_y < eps,
                "Y[{}] mismatch: simd={}, ref={}, diff={}",
                i,
                out_y[i],
                ref_y[i],
                diff_y
            );
            assert!(
                diff_b < eps,
                "B[{}] mismatch: simd={}, ref={}, diff={}",
                i,
                out_b[i],
                ref_b[i],
                diff_b
            );
        }
    }

    /// All-zero quantized input must produce all-zero output, overwriting
    /// whatever the output buffers held before.
    #[test]
    fn test_dequant_dct8_all_zeros() {
        let quant = [0i32; 64];
        let weights = [1.0f32; 64];
        let qac_qm = [1.0f32; 3];

        // 99.0 sentinel: any position left unwritten would fail below.
        let mut out_x = [99.0f32; 64];
        let mut out_y = [99.0f32; 64];
        let mut out_b = [99.0f32; 64];
        dequant_block_dct8(
            &quant, &quant, &quant, &weights, &weights, &weights, qac_qm, 0.1, 1.0, &mut out_x,
            &mut out_y, &mut out_b,
        );

        // Every backend writes 0.0 to DC (index 0) explicitly, so the check
        // covers all 64 positions, DC included.
        for i in 0..64 {
            assert_eq!(out_x[i], 0.0, "X[{}] should be 0 for zero input", i);
            assert_eq!(out_y[i], 0.0, "Y[{}] should be 0 for zero input", i);
            assert_eq!(out_b[i], 0.0, "B[{}] should be 0 for zero input", i);
        }
    }

    /// ±1 inputs exercise the per-channel bias path; each channel uses a
    /// different bias constant, so all three channels are compared.
    #[test]
    fn test_dequant_dct8_unit_values() {
        // All ±1 values — tests the channel bias path
        let mut quant = [0i32; 64];
        for (i, q) in quant.iter_mut().enumerate().skip(1) {
            *q = if i % 2 == 0 { 1 } else { -1 };
        }
        let weights = [1.0f32; 64];
        let qac_qm = [1.0f32, 1.0, 1.0];

        let mut out_x = [0.0f32; 64];
        let mut out_y = [0.0f32; 64];
        let mut out_b = [0.0f32; 64];
        let mut ref_x = [0.0f32; 64];
        let mut ref_y = [0.0f32; 64];
        let mut ref_b = [0.0f32; 64];

        dequant_block_dct8(
            &quant, &quant, &quant, &weights, &weights, &weights, qac_qm, 0.0, 0.0, &mut out_x,
            &mut out_y, &mut out_b,
        );
        dequant_dct8_scalar(
            &quant, &quant, &quant, &weights, &weights, &weights, qac_qm, 0.0, 0.0, &mut ref_x,
            &mut ref_y, &mut ref_b,
        );

        let eps = 1e-6;
        for i in 0..64 {
            assert!(
                (out_x[i] - ref_x[i]).abs() < eps,
                "X[{}]: simd={}, ref={}",
                i,
                out_x[i],
                ref_x[i]
            );
            assert!(
                (out_y[i] - ref_y[i]).abs() < eps,
                "Y[{}]: simd={}, ref={}",
                i,
                out_y[i],
                ref_y[i]
            );
            assert!(
                (out_b[i] - ref_b[i]).abs() < eps,
                "B[{}]: simd={}, ref={}",
                i,
                out_b[i],
                ref_b[i]
            );
        }
    }
}
571}