Skip to main content

jxl_encoder_simd/
block_l2.rs

// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
// Algorithms and constants derived from libjxl (BSD-3-Clause).
// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing

//! SIMD-accelerated per-block masked weighted L2 error computation.
//!
//! For each 8x8 block, computes:
//!
//! ```text
//! error = sum over pixels of: mask[px]^2 * sum_c(weight[c] * (orig[c][px] - recon[c][px])^2)
//! ```
//!
//! Used by EPF sharpness selection to compare reconstruction quality.

12use alloc::vec;
13use alloc::vec::Vec;
14
/// Per-channel weights for the L2 error, indexed as `[X, Y, B]`:
/// X = 12.339445, Y = 1.0, B = 0.2.
///
/// The Y weight must stay exactly 1.0: the SIMD kernels in this module skip
/// the multiply by it. The const assertion below enforces that at build time.
const CHANNEL_WEIGHTS: [f32; 3] = [12.339_445, 1.0, 0.2];

// Fail the build, rather than let the SIMD and scalar kernels silently
// disagree, if the Y weight is ever changed away from 1.0.
const _: () = assert!(CHANNEL_WEIGHTS[1] == 1.0);
17
18/// Compute per-block masked weighted L2 error between original and reconstructed XYB planes.
19///
20/// Each block's error = sum over 8x8 pixels of:
21///   mask[px]^2 * (w_x * dx^2 + w_y * dy^2 + w_b * db^2)
22///
23/// All planes and mask have stride = `xsize_blocks * 8`.
24#[inline]
25pub fn compute_block_l2_errors(
26    original: [&[f32]; 3],
27    reconstructed: [&[f32]; 3],
28    mask1x1: &[f32],
29    xsize_blocks: usize,
30    ysize_blocks: usize,
31) -> Vec<f32> {
32    let padded_width = xsize_blocks * 8;
33    let nblocks = xsize_blocks * ysize_blocks;
34
35    #[cfg(target_arch = "x86_64")]
36    {
37        use archmage::SimdToken;
38        if let Some(token) = archmage::X64V3Token::summon() {
39            return compute_block_l2_errors_avx2(
40                token,
41                original,
42                reconstructed,
43                mask1x1,
44                xsize_blocks,
45                ysize_blocks,
46                padded_width,
47                nblocks,
48            );
49        }
50    }
51
52    #[cfg(target_arch = "aarch64")]
53    {
54        use archmage::SimdToken;
55        if let Some(token) = archmage::NeonToken::summon() {
56            return compute_block_l2_errors_neon(
57                token,
58                original,
59                reconstructed,
60                mask1x1,
61                xsize_blocks,
62                ysize_blocks,
63                padded_width,
64                nblocks,
65            );
66        }
67    }
68
69    compute_block_l2_errors_scalar(
70        original,
71        reconstructed,
72        mask1x1,
73        xsize_blocks,
74        ysize_blocks,
75        padded_width,
76        nblocks,
77    )
78}
79
80#[inline]
81pub fn compute_block_l2_errors_scalar(
82    original: [&[f32]; 3],
83    reconstructed: [&[f32]; 3],
84    mask1x1: &[f32],
85    xsize_blocks: usize,
86    ysize_blocks: usize,
87    padded_width: usize,
88    nblocks: usize,
89) -> Vec<f32> {
90    let mut errors = vec![0.0f32; nblocks];
91
92    for by in 0..ysize_blocks {
93        for bx in 0..xsize_blocks {
94            let block_idx = by * xsize_blocks + bx;
95            let mut total_err = 0.0f32;
96
97            for py in 0..8 {
98                for px in 0..8 {
99                    let y = by * 8 + py;
100                    let x = bx * 8 + px;
101                    let pixel_idx = y * padded_width + x;
102                    let mask = mask1x1[pixel_idx];
103                    let mask_sq = mask * mask;
104
105                    for c in 0..3 {
106                        let diff = original[c][pixel_idx] - reconstructed[c][pixel_idx];
107                        total_err += CHANNEL_WEIGHTS[c] * mask_sq * diff * diff;
108                    }
109                }
110            }
111
112            errors[block_idx] = total_err;
113        }
114    }
115
116    errors
117}
118
119#[cfg(target_arch = "x86_64")]
120#[inline]
121#[archmage::arcane]
122#[allow(clippy::too_many_arguments)]
123pub fn compute_block_l2_errors_avx2(
124    token: archmage::X64V3Token,
125    original: [&[f32]; 3],
126    reconstructed: [&[f32]; 3],
127    mask1x1: &[f32],
128    xsize_blocks: usize,
129    ysize_blocks: usize,
130    padded_width: usize,
131    nblocks: usize,
132) -> Vec<f32> {
133    use magetypes::simd::f32x8;
134
135    let w_x = f32x8::splat(token, CHANNEL_WEIGHTS[0]);
136    // w_y = 1.0, multiplication skipped in inner loop
137    let w_b = f32x8::splat(token, CHANNEL_WEIGHTS[2]);
138
139    let mut errors = vec![0.0f32; nblocks];
140
141    for by in 0..ysize_blocks {
142        for bx in 0..xsize_blocks {
143            let block_idx = by * xsize_blocks + bx;
144            let mut acc = f32x8::zero(token);
145
146            for py in 0..8 {
147                let row_start = (by * 8 + py) * padded_width + bx * 8;
148
149                // Load 8 mask values and square them
150                let mask_v = f32x8::from_slice(token, &mask1x1[row_start..]);
151                let mask_sq = mask_v * mask_v;
152
153                // X channel: w_x * mask_sq * (orig_x - recon_x)^2
154                let orig_x = f32x8::from_slice(token, &original[0][row_start..]);
155                let recon_x = f32x8::from_slice(token, &reconstructed[0][row_start..]);
156                let diff_x = orig_x - recon_x;
157                acc += w_x * mask_sq * diff_x * diff_x;
158
159                // Y channel: w_y * mask_sq * (orig_y - recon_y)^2
160                let orig_y = f32x8::from_slice(token, &original[1][row_start..]);
161                let recon_y = f32x8::from_slice(token, &reconstructed[1][row_start..]);
162                let diff_y = orig_y - recon_y;
163                // w_y = 1.0, so skip the multiply
164                acc += mask_sq * diff_y * diff_y;
165
166                // B channel: w_b * mask_sq * (orig_b - recon_b)^2
167                let orig_b = f32x8::from_slice(token, &original[2][row_start..]);
168                let recon_b = f32x8::from_slice(token, &reconstructed[2][row_start..]);
169                let diff_b = orig_b - recon_b;
170                acc += w_b * mask_sq * diff_b * diff_b;
171            }
172
173            // Horizontal sum of the 8-lane accumulator
174            errors[block_idx] = acc.reduce_add();
175        }
176    }
177
178    errors
179}
180
// ============================================================================
// aarch64 NEON implementation
// ============================================================================

/// NEON (`NeonToken`) implementation of [`compute_block_l2_errors`].
///
/// Takes the same arguments as [`compute_block_l2_errors_scalar`] plus the
/// archmage `NeonToken`.
///
/// Each 8-sample block row is processed as two `f32x4` halves. Per-lane
/// partial sums are reduced once per block, so results can differ from the
/// scalar kernel by floating-point rounding.
///
/// # Panics
///
/// Panics if `padded_width < xsize_blocks * 8` while there is at least one block
/// row, or if `nblocks < xsize_blocks * ysize_blocks`. A plane or mask that is
/// too short also panics (slice indexing; presumably also in `from_slice` —
/// confirm against magetypes).
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
// The SIMD token plus the seven parameters shared with the scalar kernel.
#[allow(clippy::too_many_arguments)]
pub fn compute_block_l2_errors_neon(
    token: archmage::NeonToken,
    original: [&[f32]; 3],
    reconstructed: [&[f32]; 3],
    mask1x1: &[f32],
    xsize_blocks: usize,
    ysize_blocks: usize,
    padded_width: usize,
    nblocks: usize,
) -> Vec<f32> {
    use magetypes::simd::f32x4;

    // Same preconditions as the scalar kernel: an undersized stride would make
    // block rows overlap and silently produce wrong errors.
    assert!(
        ysize_blocks == 0 || padded_width >= xsize_blocks * 8,
        "padded_width {} is smaller than xsize_blocks * 8 = {}",
        padded_width,
        xsize_blocks * 8
    );
    assert!(
        nblocks >= xsize_blocks * ysize_blocks,
        "nblocks {} is smaller than xsize_blocks * ysize_blocks = {}",
        nblocks,
        xsize_blocks * ysize_blocks
    );
    // The Y term below is accumulated without multiplying by its weight.
    debug_assert!(CHANNEL_WEIGHTS[1] == 1.0, "Y channel weight must be 1.0");

    let w_x = f32x4::splat(token, CHANNEL_WEIGHTS[0]);
    let w_b = f32x4::splat(token, CHANNEL_WEIGHTS[2]);

    let mut errors = vec![0.0f32; nblocks];

    for by in 0..ysize_blocks {
        for bx in 0..xsize_blocks {
            let block_idx = by * xsize_blocks + bx;
            let mut acc = f32x4::zero(token);

            for py in 0..8 {
                let row_start = (by * 8 + py) * padded_width + bx * 8;

                // Process the 8 pixels of this block row as two f32x4 chunks.
                for half in 0..2usize {
                    let off = row_start + half * 4;

                    let mask_v = f32x4::from_slice(token, &mask1x1[off..]);
                    let mask_sq = mask_v * mask_v;

                    let orig_x = f32x4::from_slice(token, &original[0][off..]);
                    let recon_x = f32x4::from_slice(token, &reconstructed[0][off..]);
                    let diff_x = orig_x - recon_x;
                    acc += w_x * mask_sq * diff_x * diff_x;

                    // w_y is 1.0, so the weight multiply is skipped.
                    let orig_y = f32x4::from_slice(token, &original[1][off..]);
                    let recon_y = f32x4::from_slice(token, &reconstructed[1][off..]);
                    let diff_y = orig_y - recon_y;
                    acc += mask_sq * diff_y * diff_y;

                    let orig_b = f32x4::from_slice(token, &original[2][off..]);
                    let recon_b = f32x4::from_slice(token, &reconstructed[2][off..]);
                    let diff_b = orig_b - recon_b;
                    acc += w_b * mask_sq * diff_b * diff_b;
                }
            }

            // Horizontal sum of the 4-lane accumulator.
            errors[block_idx] = acc.reduce_add();
        }
    }

    errors
}
244
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Builds smoothly varying original/reconstructed XYB planes plus a mask with
    /// values in [0.5, 1.0], laid out with row stride `xsize_blocks * 8`.
    fn varied_planes(
        xsize_blocks: usize,
        ysize_blocks: usize,
    ) -> ([Vec<f32>; 3], [Vec<f32>; 3], Vec<f32>) {
        let n = xsize_blocks * 8 * ysize_blocks * 8;
        let mut orig = [vec![0.0f32; n], vec![0.0f32; n], vec![0.0f32; n]];
        let mut recon = [vec![0.0f32; n], vec![0.0f32; n], vec![0.0f32; n]];
        let mut mask = vec![0.0f32; n];

        for i in 0..n {
            let f = i as f32;
            orig[0][i] = (f * 0.013).sin() * 0.5;
            orig[1][i] = (f * 0.017).cos() * 0.8;
            orig[2][i] = (f * 0.023).sin() * 0.3;
            recon[0][i] = orig[0][i] + (f * 0.031).sin() * 0.1;
            recon[1][i] = orig[1][i] + (f * 0.037).cos() * 0.05;
            recon[2][i] = orig[2][i] + (f * 0.041).sin() * 0.02;
            mask[i] = 0.5 + (f * 0.007).sin().abs() * 0.5;
        }

        (orig, recon, mask)
    }

    /// Borrows three owned planes as the `[&[f32]; 3]` the API expects.
    fn as_slices(planes: &[Vec<f32>; 3]) -> [&[f32]; 3] {
        [&planes[0], &planes[1], &planes[2]]
    }

    #[test]
    fn test_block_l2_errors_uniform() {
        // Non-square, so a row/column mix-up in block indexing cannot cancel out.
        let xsize_blocks = 3;
        let ysize_blocks = 2;
        let n = xsize_blocks * 8 * ysize_blocks * 8;

        // Uniform original, zero reconstructed → diff = original
        let original = [vec![1.0f32; n], vec![1.0f32; n], vec![1.0f32; n]];
        let reconstructed = [vec![0.0f32; n], vec![0.0f32; n], vec![0.0f32; n]];
        let mask = vec![1.0f32; n];

        let errors = compute_block_l2_errors(
            as_slices(&original),
            as_slices(&reconstructed),
            &mask,
            xsize_blocks,
            ysize_blocks,
        );

        // Each pixel: mask^2 * (w_x * 1^2 + w_y * 1^2 + w_b * 1^2)
        //           = 1.0 * (12.339445 + 1.0 + 0.2) = 13.539445
        // 64 pixels per block: 64 * 13.539445 = 866.52448
        let expected = 64.0 * (CHANNEL_WEIGHTS[0] + CHANNEL_WEIGHTS[1] + CHANNEL_WEIGHTS[2]);
        assert_eq!(errors.len(), xsize_blocks * ysize_blocks);
        for (i, &err) in errors.iter().enumerate() {
            assert!(
                (err - expected).abs() < 0.1,
                "Block {} error {} != expected {}",
                i,
                err,
                expected
            );
        }
    }

    #[test]
    fn test_block_l2_errors_matches_scalar() {
        // Cover a single block and non-square layouts, not only square ones.
        for &(xsize_blocks, ysize_blocks) in &[(1, 1), (4, 4), (3, 2), (2, 5), (7, 3)] {
            let (original, reconstructed, mask) = varied_planes(xsize_blocks, ysize_blocks);

            let simd_result = compute_block_l2_errors(
                as_slices(&original),
                as_slices(&reconstructed),
                &mask,
                xsize_blocks,
                ysize_blocks,
            );

            let scalar_result = compute_block_l2_errors_scalar(
                as_slices(&original),
                as_slices(&reconstructed),
                &mask,
                xsize_blocks,
                ysize_blocks,
                xsize_blocks * 8,
                xsize_blocks * ysize_blocks,
            );

            assert_eq!(simd_result.len(), scalar_result.len());
            for (i, (&s, &sc)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
                let rel_err = if sc.abs() > 1e-10 {
                    ((s - sc) / sc).abs()
                } else {
                    (s - sc).abs()
                };
                assert!(
                    rel_err < 1e-5,
                    "{}x{} block {} SIMD {} vs scalar {} rel_err {}",
                    xsize_blocks,
                    ysize_blocks,
                    i,
                    s,
                    sc,
                    rel_err
                );
            }
        }
    }

    #[test]
    fn test_block_l2_errors_block_order() {
        // Give every block a distinct X offset, so the error of block (bx, by)
        // must land at index by * xsize_blocks + bx.
        let xsize_blocks = 3;
        let ysize_blocks = 2;
        let width = xsize_blocks * 8;
        let n = width * ysize_blocks * 8;
        let block_value = |idx: usize| 0.1 * (idx as f32 + 1.0);

        let mut x = vec![0.0f32; n];
        for (i, sample) in x.iter_mut().enumerate() {
            let (y, col) = (i / width, i % width);
            *sample = block_value((y / 8) * xsize_blocks + col / 8);
        }
        let zeros = vec![0.0f32; n];
        let mask = vec![1.0f32; n];

        let errors = compute_block_l2_errors(
            [&x, &zeros, &zeros],
            [&zeros, &zeros, &zeros],
            &mask,
            xsize_blocks,
            ysize_blocks,
        );

        assert_eq!(errors.len(), xsize_blocks * ysize_blocks);
        for (i, &err) in errors.iter().enumerate() {
            let v = block_value(i);
            let expected = 64.0 * CHANNEL_WEIGHTS[0] * v * v;
            assert!(
                (err - expected).abs() <= 1e-4 * expected,
                "Block {} error {} != expected {}",
                i,
                err,
                expected
            );
        }
    }

    #[test]
    fn test_block_l2_errors_empty() {
        let empty: &[f32] = &[];
        for &(xsize_blocks, ysize_blocks) in &[(0, 0), (0, 3), (4, 0)] {
            let errors =
                compute_block_l2_errors([empty; 3], [empty; 3], empty, xsize_blocks, ysize_blocks);
            assert!(errors.is_empty());
        }
    }
}