Skip to main content

jxl_encoder_simd/
entropy.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! SIMD-accelerated coefficient processing for entropy estimation.
6//!
7//! The inner coefficient loop of `estimate_entropy_full` is the single biggest
8//! encoder hotspot (~7.5% CPU). This kernel vectorizes the per-coefficient math:
9//!   val = (block_c[i] - block_y[i] * cmap_factor) / weights[i] * quant
10//!   rval = round(val)
11//!   entropy_sum += sqrt(|rval|) * cost_delta
12//!   nzeros += (rval != 0)
13
/// Accumulated statistics returned by the entropy coefficient kernels.
///
/// `Default` yields the all-zero accumulator (the result for `n == 0`), and
/// `PartialEq` allows results from different kernels to be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct EntropyCoeffResult {
    /// Sum of `sqrt(|round(val)|) * cost_delta` over all coefficients; in
    /// coefficient-domain mode it also includes the `k_cost2` penalty added
    /// for every coefficient with `|round(val)| >= 1.5`.
    pub entropy_sum: f32,
    /// Count of non-zero quantized coefficients (kept as `f32`).
    pub nzeros_sum: f32,
    /// Sum of `|val - round(val)|` (coefficient-domain mode only).
    pub info_loss_sum: f32,
    /// Sum of `(val - round(val))^2` (coefficient-domain mode only).
    pub info_loss2_sum: f32,
}
26
27/// Vectorized entropy coefficient processing.
28///
29/// For each coefficient i in 0..n:
30///   val = (block_c[i] - block_y[i] * cmap_factor) / weights[i] * quant
31///   rval = round(val)
32///   entropy_sum += sqrt(|rval|) * k_cost_delta
33///   nzeros += (rval != 0)
34///
35/// In pixel-domain mode: writes `error_coeffs[i] = weights[i] * (val - rval)`
36/// In coefficient-domain mode: accumulates info_loss stats and k_cost2 penalty.
37#[inline]
38#[allow(clippy::too_many_arguments)]
39pub fn entropy_estimate_coeffs(
40    block_c: &[f32],
41    block_y: &[f32],
42    weights: &[f32],
43    n: usize,
44    cmap_factor: f32,
45    quant: f32,
46    k_cost_delta: f32,
47    k_cost2: f32,
48    pixel_domain: bool,
49    error_coeffs: &mut [f32],
50) -> EntropyCoeffResult {
51    #[cfg(target_arch = "x86_64")]
52    {
53        use archmage::SimdToken;
54        if let Some(token) = archmage::X64V3Token::summon() {
55            return entropy_coeffs_avx2(
56                token,
57                block_c,
58                block_y,
59                weights,
60                n,
61                cmap_factor,
62                quant,
63                k_cost_delta,
64                k_cost2,
65                pixel_domain,
66                error_coeffs,
67            );
68        }
69    }
70
71    #[cfg(target_arch = "aarch64")]
72    {
73        use archmage::SimdToken;
74        if let Some(token) = archmage::NeonToken::summon() {
75            return entropy_coeffs_neon(
76                token,
77                block_c,
78                block_y,
79                weights,
80                n,
81                cmap_factor,
82                quant,
83                k_cost_delta,
84                k_cost2,
85                pixel_domain,
86                error_coeffs,
87            );
88        }
89    }
90
91    entropy_coeffs_scalar(
92        block_c,
93        block_y,
94        weights,
95        n,
96        cmap_factor,
97        quant,
98        k_cost_delta,
99        k_cost2,
100        pixel_domain,
101        error_coeffs,
102    )
103}
104
105#[inline]
106#[allow(clippy::too_many_arguments)]
107pub fn entropy_coeffs_scalar(
108    block_c: &[f32],
109    block_y: &[f32],
110    weights: &[f32],
111    n: usize,
112    cmap_factor: f32,
113    quant: f32,
114    k_cost_delta: f32,
115    k_cost2: f32,
116    pixel_domain: bool,
117    error_coeffs: &mut [f32],
118) -> EntropyCoeffResult {
119    let mut entropy_sum = 0.0f32;
120    let mut nzeros_sum = 0.0f32;
121    let mut info_loss_sum = 0.0f32;
122    let mut info_loss2_sum = 0.0f32;
123
124    for i in 0..n {
125        let val_in = block_c[i];
126        let val_y = block_y[i] * cmap_factor;
127        let val = (val_in - val_y) * (1.0 / weights[i]) * quant;
128        let rval = val.round();
129        let diff = val - rval;
130
131        if pixel_domain {
132            error_coeffs[i] = weights[i] * diff;
133        }
134
135        let q = rval.abs();
136        entropy_sum += q.sqrt() * k_cost_delta;
137        if q != 0.0 {
138            nzeros_sum += 1.0;
139        }
140
141        if !pixel_domain {
142            let diff_abs = diff.abs();
143            info_loss_sum += diff_abs;
144            info_loss2_sum += diff_abs * diff_abs;
145            if q >= 1.5 {
146                entropy_sum += k_cost2;
147            }
148        }
149    }
150
151    EntropyCoeffResult {
152        entropy_sum,
153        nzeros_sum,
154        info_loss_sum,
155        info_loss2_sum,
156    }
157}
158
159#[cfg(target_arch = "x86_64")]
160#[inline]
161#[archmage::arcane]
162#[allow(clippy::too_many_arguments)]
163pub fn entropy_coeffs_avx2(
164    token: archmage::X64V3Token,
165    block_c: &[f32],
166    block_y: &[f32],
167    weights: &[f32],
168    n: usize,
169    cmap_factor: f32,
170    quant: f32,
171    k_cost_delta: f32,
172    k_cost2: f32,
173    pixel_domain: bool,
174    error_coeffs: &mut [f32],
175) -> EntropyCoeffResult {
176    use magetypes::simd::f32x8;
177
178    let cmap_v = f32x8::splat(token, cmap_factor);
179    let quant_v = f32x8::splat(token, quant);
180    let cost_delta_v = f32x8::splat(token, k_cost_delta);
181    let cost2_v = f32x8::splat(token, k_cost2);
182    let zero = f32x8::zero(token);
183    let one = f32x8::splat(token, 1.0);
184    let thr_1_5 = f32x8::splat(token, 1.5);
185
186    let mut entropy_acc = f32x8::zero(token);
187    let mut nzeros_acc = f32x8::zero(token);
188    let mut info_loss_acc = f32x8::zero(token);
189    let mut info_loss2_acc = f32x8::zero(token);
190    let mut cost2_acc = f32x8::zero(token);
191
192    let chunks = n / 8;
193    // Pre-slice to exact SIMD length so the compiler can prove
194    // all from_slice loads are in bounds (base + 8 <= chunks * 8).
195    let simd_n = chunks * 8;
196    let block_c_s = &block_c[..simd_n];
197    let block_y_s = &block_y[..simd_n];
198    let weights_s = &weights[..simd_n];
199    for chunk in 0..chunks {
200        let base = chunk * 8;
201
202        let bc = f32x8::from_slice(token, &block_c_s[base..]);
203        let by_v = f32x8::from_slice(token, &block_y_s[base..]);
204        let w = f32x8::from_slice(token, &weights_s[base..]);
205
206        // val = (block_c - block_y * cmap_factor) / weights * quant
207        let adjusted = bc - by_v * cmap_v;
208        let val = adjusted / w * quant_v;
209
210        let rval = val.round();
211        let diff = val - rval;
212
213        // Write error coefficients for pixel-domain loss
214        if pixel_domain {
215            let err = w * diff;
216            let out: &mut [f32; 8] = (&mut error_coeffs[base..base + 8]).try_into().unwrap();
217            err.store(out);
218        }
219
220        // Entropy accumulation: entropy += sqrt(|rval|) * cost_delta
221        let q = rval.abs();
222        entropy_acc = q.sqrt().mul_add(cost_delta_v, entropy_acc);
223
224        // nzeros: count non-zero rounded values
225        let nz_mask = q.simd_ne(zero);
226        nzeros_acc += f32x8::blend(nz_mask, one, zero);
227
228        // Coefficient-domain statistics
229        if !pixel_domain {
230            let diff_abs = diff.abs();
231            info_loss_acc += diff_abs;
232            info_loss2_acc = diff_abs.mul_add(diff_abs, info_loss2_acc);
233
234            // q >= 1.5 penalty
235            let ge_mask = q.simd_ge(thr_1_5);
236            cost2_acc += f32x8::blend(ge_mask, cost2_v, zero);
237        }
238    }
239
240    // Handle remainder with scalar fallback
241    let start = chunks * 8;
242    let remainder = entropy_coeffs_scalar(
243        &block_c[start..n],
244        &block_y[start..n],
245        &weights[start..n],
246        n - start,
247        cmap_factor,
248        quant,
249        k_cost_delta,
250        k_cost2,
251        pixel_domain,
252        &mut error_coeffs[start..n],
253    );
254
255    let mut entropy_sum = entropy_acc.reduce_add() + remainder.entropy_sum;
256    if !pixel_domain {
257        entropy_sum += cost2_acc.reduce_add();
258    }
259
260    EntropyCoeffResult {
261        entropy_sum,
262        nzeros_sum: nzeros_acc.reduce_add() + remainder.nzeros_sum,
263        info_loss_sum: info_loss_acc.reduce_add() + remainder.info_loss_sum,
264        info_loss2_sum: info_loss2_acc.reduce_add() + remainder.info_loss2_sum,
265    }
266}
267
268// ============================================================================
269// aarch64 NEON implementation
270// ============================================================================
271
/// aarch64 kernel processing four coefficients per iteration with `f32x4`.
///
/// Computes the same statistics as `entropy_coeffs_scalar`. The first
/// `n / 4 * 4` coefficients go through the vector loop; the remaining
/// `n % 4` are delegated to `entropy_coeffs_scalar` and the partial results
/// are summed. The vector loop divides by `weights` while the scalar tail
/// multiplies by the reciprocal, so the two can differ in the last bits.
///
/// `token` is the capability token obtained from `NeonToken::summon()`.
///
/// `error_coeffs` is read/written only when `pixel_domain` is true; in
/// coefficient-domain mode it may be shorter than `n` (even empty), matching
/// the scalar kernel.
///
/// # Panics
///
/// Panics if `block_c`, `block_y` or `weights` hold fewer than `n` elements,
/// or (pixel-domain mode only) if `error_coeffs` does.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::too_many_arguments)]
pub fn entropy_coeffs_neon(
    token: archmage::NeonToken,
    block_c: &[f32],
    block_y: &[f32],
    weights: &[f32],
    n: usize,
    cmap_factor: f32,
    quant: f32,
    k_cost_delta: f32,
    k_cost2: f32,
    pixel_domain: bool,
    error_coeffs: &mut [f32],
) -> EntropyCoeffResult {
    use magetypes::simd::f32x4;

    let cmap_v = f32x4::splat(token, cmap_factor);
    let quant_v = f32x4::splat(token, quant);
    let cost_delta_v = f32x4::splat(token, k_cost_delta);
    let cost2_v = f32x4::splat(token, k_cost2);
    let zero = f32x4::zero(token);
    let one = f32x4::splat(token, 1.0);
    let thr_1_5 = f32x4::splat(token, 1.5);

    let mut entropy_acc = f32x4::zero(token);
    let mut nzeros_acc = f32x4::zero(token);
    let mut info_loss_acc = f32x4::zero(token);
    let mut info_loss2_acc = f32x4::zero(token);
    // k_cost2 penalties are kept in their own accumulator and folded into
    // entropy_sum after the loop.
    let mut cost2_acc = f32x4::zero(token);

    let chunks = n / 4;
    // Pre-slice to the exact SIMD length so every from_slice load below is
    // taken from a slice that holds at least 4 elements past `base`.
    let simd_n = chunks * 4;
    let block_c_s = &block_c[..simd_n];
    let block_y_s = &block_y[..simd_n];
    let weights_s = &weights[..simd_n];
    for chunk in 0..chunks {
        let base = chunk * 4;

        let bc = f32x4::from_slice(token, &block_c_s[base..]);
        let by_v = f32x4::from_slice(token, &block_y_s[base..]);
        let w = f32x4::from_slice(token, &weights_s[base..]);

        // val = (block_c - block_y * cmap_factor) / weights * quant
        let adjusted = bc - by_v * cmap_v;
        let val = adjusted / w * quant_v;

        let rval = val.round();
        let diff = val - rval;

        // Write error coefficients for pixel-domain loss
        if pixel_domain {
            let err = w * diff;
            let out: &mut [f32; 4] = (&mut error_coeffs[base..base + 4]).try_into().unwrap();
            err.store(out);
        }

        // Entropy accumulation: entropy += sqrt(|rval|) * cost_delta
        let q = rval.abs();
        entropy_acc = q.sqrt().mul_add(cost_delta_v, entropy_acc);

        // nzeros: add 1.0 in every lane whose rounded value is non-zero
        let nz_mask = q.simd_ne(zero);
        nzeros_acc += f32x4::blend(nz_mask, one, zero);

        // Coefficient-domain statistics
        if !pixel_domain {
            let diff_abs = diff.abs();
            info_loss_acc += diff_abs;
            info_loss2_acc = diff_abs.mul_add(diff_abs, info_loss2_acc);

            // q >= 1.5 penalty
            let ge_mask = q.simd_ge(thr_1_5);
            cost2_acc += f32x4::blend(ge_mask, cost2_v, zero);
        }
    }

    // Scalar tail for the last n % 4 coefficients. Only pixel-domain mode
    // writes error coefficients, so the output buffer is sliced only then;
    // slicing it unconditionally would panic for coefficient-domain callers
    // that pass a short or empty buffer, which the scalar kernel accepts.
    let error_tail: &mut [f32] = if pixel_domain {
        &mut error_coeffs[simd_n..n]
    } else {
        &mut []
    };
    let remainder = entropy_coeffs_scalar(
        &block_c[simd_n..n],
        &block_y[simd_n..n],
        &weights[simd_n..n],
        n - simd_n,
        cmap_factor,
        quant,
        k_cost_delta,
        k_cost2,
        pixel_domain,
        error_tail,
    );

    let mut entropy_sum = entropy_acc.reduce_add() + remainder.entropy_sum;
    if !pixel_domain {
        entropy_sum += cost2_acc.reduce_add();
    }

    EntropyCoeffResult {
        entropy_sum,
        nzeros_sum: nzeros_acc.reduce_add() + remainder.nzeros_sum,
        info_loss_sum: info_loss_acc.reduce_add() + remainder.info_loss_sum,
        info_loss2_sum: info_loss2_acc.reduce_add() + remainder.info_loss2_sum,
    }
}
373
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Cost constants used by every test case.
    const K_COST_DELTA: f32 = 5.335;
    const K_COST2: f32 = 4.463;

    /// Relative tolerance (0.5%) for the accumulated sums.
    ///
    /// FP differences from division ordering (x*(1/w) vs x/w) can cause
    /// different rounding decisions, so exact equality is not expected.
    const REL_EPS: f32 = 0.005;

    /// Relative tolerance (5%) for the non-zero count, which can differ
    /// slightly due to rounding boundary differences.
    const NZ_EPS: f32 = 0.05;

    /// Results of running the scalar reference and the dispatched kernel on
    /// the same inputs.
    struct Comparison {
        reference: EntropyCoeffResult,
        simd: EntropyCoeffResult,
        error_ref: Vec<f32>,
        error_simd: Vec<f32>,
    }

    /// Builds `n` samples by applying `f` to the index converted to `f32`.
    fn ramp(n: usize, f: impl Fn(f32) -> f32) -> Vec<f32> {
        (0..n).map(|i| f(i as f32)).collect()
    }

    /// Runs `entropy_coeffs_scalar` and `entropy_estimate_coeffs` over the
    /// full length of `block_c`, each with its own zeroed error buffer.
    fn run_both(
        block_c: &[f32],
        block_y: &[f32],
        weights: &[f32],
        cmap_factor: f32,
        quant: f32,
        pixel_domain: bool,
    ) -> Comparison {
        let n = block_c.len();

        let mut error_ref = vec![0.0f32; n];
        let reference = entropy_coeffs_scalar(
            block_c,
            block_y,
            weights,
            n,
            cmap_factor,
            quant,
            K_COST_DELTA,
            K_COST2,
            pixel_domain,
            &mut error_ref,
        );

        let mut error_simd = vec![0.0f32; n];
        let simd = entropy_estimate_coeffs(
            block_c,
            block_y,
            weights,
            n,
            cmap_factor,
            quant,
            K_COST_DELTA,
            K_COST2,
            pixel_domain,
            &mut error_simd,
        );

        Comparison {
            reference,
            simd,
            error_ref,
            error_simd,
        }
    }

    /// Relative error of `simd` against `reference`; the denominator is
    /// clamped from below by `min_denom` (pass 0.0 for no clamping).
    fn rel_err(simd: f32, reference: f32, min_denom: f32) -> f32 {
        (simd - reference).abs() / reference.abs().max(min_denom)
    }

    /// Asserts that the relative error of `simd` against `reference` is
    /// strictly below `tol`.
    fn assert_rel_close(name: &str, simd: f32, reference: f32, min_denom: f32, tol: f32) {
        let rel = rel_err(simd, reference, min_denom);
        assert!(
            rel < tol,
            "{}: SIMD={}, ref={}, rel_err={:.4}%",
            name,
            simd,
            reference,
            rel * 100.0
        );
    }

    /// Largest absolute element-wise difference between two slices.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max)
    }

    /// Verify SIMD matches scalar for pixel-domain mode.
    #[test]
    fn test_entropy_coeffs_pixel_domain() {
        let n = 64;
        let block_c = ramp(n, |x| (x * 0.7 - 20.0) * 0.1);
        let block_y = ramp(n, |x| (x * 0.5 - 15.0) * 0.1);
        let weights = ramp(n, |x| 0.01 + x * 0.005);

        let out = run_both(&block_c, &block_y, &weights, 0.15, 3.5, true);

        assert_rel_close(
            "entropy_sum",
            out.simd.entropy_sum,
            out.reference.entropy_sum,
            0.0,
            REL_EPS,
        );
        assert_rel_close(
            "nzeros_sum",
            out.simd.nzeros_sum,
            out.reference.nzeros_sum,
            1.0,
            NZ_EPS,
        );

        // Error coeffs can differ by up to ~weight when rounding boundary
        // decisions change (x*(1/w) vs x/w gives different ULPs near 0.5)
        let max_err = max_abs_diff(&out.error_simd, &out.error_ref);
        assert!(max_err < 0.5, "Error coeffs max diff: {:.2e}", max_err);
    }

    /// Verify SIMD matches scalar for coefficient-domain mode.
    #[test]
    fn test_entropy_coeffs_coeff_domain() {
        let n = 64;
        let block_c = ramp(n, |x| (x * 1.3 - 40.0) * 0.05);
        let block_y = ramp(n, |x| (x * 0.9 - 30.0) * 0.05);
        let weights = ramp(n, |x| 0.02 + x * 0.003);

        let out = run_both(&block_c, &block_y, &weights, 0.0, 5.0, false);

        assert_rel_close(
            "entropy_sum",
            out.simd.entropy_sum,
            out.reference.entropy_sum,
            0.0,
            REL_EPS,
        );
        assert_rel_close(
            "nzeros_sum",
            out.simd.nzeros_sum,
            out.reference.nzeros_sum,
            1.0,
            NZ_EPS,
        );
        assert_rel_close(
            "info_loss_sum",
            out.simd.info_loss_sum,
            out.reference.info_loss_sum,
            1.0,
            REL_EPS,
        );
        assert_rel_close(
            "info_loss2_sum",
            out.simd.info_loss2_sum,
            out.reference.info_loss2_sum,
            1.0,
            REL_EPS,
        );
    }

    /// Test with non-multiple-of-8 sizes (remainder handling).
    #[test]
    fn test_entropy_coeffs_remainder() {
        let n = 67;
        let block_c = ramp(n, |x| x * 0.1 - 3.0);
        let block_y = ramp(n, |x| x * 0.08 - 2.5);
        let weights = ramp(n, |x| 0.01 + x * 0.002);

        let out = run_both(&block_c, &block_y, &weights, 0.2, 4.0, true);

        assert_rel_close(
            "entropy_sum",
            out.simd.entropy_sum,
            out.reference.entropy_sum,
            1.0,
            REL_EPS,
        );
        let nz_rel = rel_err(out.simd.nzeros_sum, out.reference.nzeros_sum, 1.0);
        assert!(
            nz_rel < NZ_EPS,
            "nzeros_sum: SIMD={}, ref={}",
            out.simd.nzeros_sum,
            out.reference.nzeros_sum
        );

        let max_err = max_abs_diff(&out.error_simd, &out.error_ref);
        assert!(max_err < 0.01, "Error coeffs max diff: {:.2e}", max_err);
    }

    /// Test with large blocks (DCT64x64 = 4096 coefficients).
    #[test]
    fn test_entropy_coeffs_large_block() {
        let n = 4096;
        let block_c = ramp(n, |x| (x * 0.01).sin() * 5.0);
        let block_y = ramp(n, |x| (x * 0.013).cos() * 4.0);
        let weights = ramp(n, |x| 0.005 + x * 0.001);

        let out = run_both(&block_c, &block_y, &weights, 0.1, 2.0, true);

        // Large block: use relative tolerance
        assert_rel_close(
            "entropy_sum",
            out.simd.entropy_sum,
            out.reference.entropy_sum,
            0.0,
            REL_EPS,
        );

        let max_err = max_abs_diff(&out.error_simd, &out.error_ref);
        assert!(max_err < 1e-3, "Error coeffs max diff: {:.2e}", max_err);
    }
}