djvu_iw44/lib.rs

//! IW44 wavelet image decoder — pure-Rust clean-room implementation (phase 2c).
//!
//! Implements the IW44 progressive wavelet codec used by DjVu BG44, FG44, and
//! TH44 chunks.  Each BG44 chunk may carry one or more *slices*; the ZP coder
//! state persists across all chunks so that progressive refinement works correctly.
//!
//! ## Key public types
//!
//! - `Iw44Image` — progressive decoder; call `Iw44Image::decode_chunk` for
//!   each BG44/FG44/TH44 chunk, then `Iw44Image::to_rgb` to obtain an RGB
//!   pixmap.
//! - `Iw44Error` — typed error enum (re-exported from this crate).
//!
//! ## Architecture
//!
//! YCbCr planes are kept separate (`y: Vec<i16>`, `cb: Vec<i16>`, `cr: Vec<i16>`)
//! until `to_rgb()` is called.  This allows future SIMD processing on each plane
//! independently.  No interleaved buffers exist inside this module.
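//!
//! ## Example
//!
//! A minimal decode loop, shown only as a sketch: the `new()` constructor and
//! the exact argument and return types below are illustrative assumptions, not
//! the verbatim signatures.
//!
//! ```ignore
//! use djvu_iw44::Iw44Image;
//!
//! // `chunks` yields the raw payload bytes of each BG44/FG44/TH44 chunk,
//! // in the order they appear in the DjVu file.
//! let mut image = Iw44Image::new();   // hypothetical constructor
//! for chunk in chunks {
//!     image.decode_chunk(chunk)?;     // progressive refinement, chunk by chunk
//! }
//! let pixmap = image.to_rgb()?;       // final RGB pixmap
//! ```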

#![cfg_attr(not(feature = "std"), no_std)]
#![deny(unsafe_code)]

#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
#[cfg(feature = "std")]
use std::{vec, vec::Vec};

use djvu_pixmap::Pixmap;
use djvu_zp::ZpDecoder;

/// IW44 wavelet image decoding errors.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum Iw44Error {
    /// Input ended before the IW44 stream was complete.
    #[error("IW44 stream is truncated")]
    Truncated,

    /// The IW44 stream contains invalid data.
    #[error("IW44 stream contains invalid data")]
    Invalid,

    /// A BG44/FG44/TH44 chunk is too short (fewer than 2 bytes).
    #[error("IW44 chunk is too short")]
    ChunkTooShort,

    /// The first chunk header is too short (needs at least 9 bytes).
    #[error("IW44 first chunk header too short (need ≥ 9 bytes)")]
    HeaderTooShort,

    /// Image width or height is zero.
    #[error("IW44 image has zero dimension")]
    ZeroDimension,

    /// Image dimensions exceed the safety limit.
    #[error("IW44 image dimensions too large")]
    ImageTooLarge,

    /// A subsequent chunk was encountered before the first chunk.
    #[error("IW44 subsequent chunk received before first chunk")]
    MissingFirstChunk,

    /// The subsample parameter must be >= 1.
    #[error("IW44 subsample must be >= 1")]
    InvalidSubsample,

    /// No codec has been initialized (no chunks decoded yet).
    #[error("IW44 codec not yet initialized")]
    MissingCodec,

    /// The ZP arithmetic coder stream is too short.
    #[error("IW44 ZP coder stream too short")]
    ZpTooShort,
}

// ---- Band-bucket mapping: 10 bands, each mapped to a range of buckets --------

/// `BAND_BUCKETS[band]` = `(first_bucket, last_bucket)` inclusive.
const BAND_BUCKETS: [(usize, usize); 10] = [
    (0, 0),
    (1, 1),
    (2, 2),
    (3, 3),
    (4, 7),
    (8, 11),
    (12, 15),
    (16, 31),
    (32, 47),
    (48, 63),
];
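
// Each bucket groups 16 consecutive coefficients of the zigzag-ordered 32×32
// block (see `prelim_flags_bucket`, which reads `block[bucket * 16..][..16]`).
// Band 7, for example, spans buckets 16..=31, i.e. coefficients 256..512.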

/// Initial quantization step table for the low-frequency band (band 0).
const QUANT_LO_INIT: [u32; 16] = [
    0x004000, 0x008000, 0x008000, 0x010000, 0x010000, 0x010000, 0x010000, 0x010000, 0x010000,
    0x010000, 0x010000, 0x010000, 0x020000, 0x020000, 0x020000, 0x020000,
];

/// Initial quantization step table for high-frequency bands (bands 1–9).
const QUANT_HI_INIT: [u32; 10] = [
    0, 0x020000, 0x020000, 0x040000, 0x040000, 0x040000, 0x080000, 0x040000, 0x040000, 0x080000,
];

// ---- Coefficient state flags -------------------------------------------------

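// Per-coefficient state bits, as used by the decoding passes below:
// - ZERO:   not codable in the current slice (quantization threshold out of range).
// - ACTIVE: already holds a nonzero value from an earlier slice.
// - NEW:    newly activated (coefficient, bucket, or block) in this slice.
// - UNK:    still zero; may become active in this slice.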
const ZERO: u8 = 1;
const ACTIVE: u8 = 2;
const NEW: u8 = 4;
const UNK: u8 = 8;

// ---- Zigzag scan tables ------------------------------------------------------
//
// Each coefficient index `i` (0..1024) maps to a `(row, col)` within the 32×32
// block via bit-interleaving: even bits → column, odd bits → row.
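//
// The low bits of `i` select the coarse position (bit 0 has weight 16 in the
// column, bit 8 has weight 1), so for example i = 1 → (row 0, col 16),
// i = 2 → (row 16, col 0), i = 3 → (row 16, col 16).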

const fn zigzag_row(i: usize) -> u8 {
    let b1 = ((i >> 1) & 1) as u8;
    let b3 = ((i >> 3) & 1) as u8;
    let b5 = ((i >> 5) & 1) as u8;
    let b7 = ((i >> 7) & 1) as u8;
    let b9 = ((i >> 9) & 1) as u8;
    b1 * 16 + b3 * 8 + b5 * 4 + b7 * 2 + b9
}

const fn zigzag_col(i: usize) -> u8 {
    let b0 = (i & 1) as u8;
    let b2 = ((i >> 2) & 1) as u8;
    let b4 = ((i >> 4) & 1) as u8;
    let b6 = ((i >> 6) & 1) as u8;
    let b8 = ((i >> 8) & 1) as u8;
    b0 * 16 + b2 * 8 + b4 * 4 + b6 * 2 + b8
}

/// Inverse zigzag: `ZIGZAG_INV[row * 32 + col]` is the index `i` such that
/// `zigzag_row(i) == row as u8 && zigzag_col(i) == col as u8`.
///
/// Enables row-major scatter (sequential writes to the plane) at the cost of
/// gathering block coefficients in zigzag order (2 KB block fits in L1).
static ZIGZAG_INV: [u16; 1024] = {
    let mut table = [0u16; 1024];
    let mut i = 0usize;
    while i < 1024 {
        let r = zigzag_row(i) as usize;
        let c = zigzag_col(i) as usize;
        table[r * 32 + c] = i as u16;
        i += 1;
    }
    table
};

/// Compact inverse zigzag for sub=2 (16×16 sub-block, 256 entries).
/// `ZIGZAG_INV_SUB2[row * 16 + col]` = index `i` in 0..256 such that
/// `zigzag_row(i) >> 1 == row && zigzag_col(i) >> 1 == col`.
static ZIGZAG_INV_SUB2: [u8; 256] = {
    let mut table = [0u8; 256];
    let mut i = 0usize;
    while i < 256 {
        let r = (zigzag_row(i) >> 1) as usize;
        let c = (zigzag_col(i) >> 1) as usize;
        table[r * 16 + c] = i as u8;
        i += 1;
    }
    table
};

/// Compact inverse zigzag for sub=4 (8×8 sub-block, 64 entries).
/// `ZIGZAG_INV_SUB4[row * 8 + col]` = index `i` in 0..64.
static ZIGZAG_INV_SUB4: [u8; 64] = {
    let mut table = [0u8; 64];
    let mut i = 0usize;
    while i < 64 {
        let r = (zigzag_row(i) >> 2) as usize;
        let c = (zigzag_col(i) >> 2) as usize;
        table[r * 8 + c] = i as u8;
        i += 1;
    }
    table
};

/// Compact inverse zigzag for sub=8 (4×4 sub-block, 16 entries).
/// `ZIGZAG_INV_SUB8[row * 4 + col]` = index `i` in 0..16.
static ZIGZAG_INV_SUB8: [u8; 16] = {
    let mut table = [0u8; 16];
    let mut i = 0usize;
    while i < 16 {
        let r = (zigzag_row(i) >> 3) as usize;
        let c = (zigzag_col(i) >> 3) as usize;
        table[r * 4 + c] = i as u8;
        i += 1;
    }
    table
};

// ---- Normalization -----------------------------------------------------------

/// Map a raw wavelet coefficient to a signed pixel offset in `[-128, 127]`.
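///
/// Equivalent to a rounded division by 64: e.g. `normalize(64) == 1`,
/// `normalize(31) == 0`, and `normalize(-64) == -1`, saturating at the
/// `[-128, 127]` bounds.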
#[inline]
fn normalize(val: i16) -> i32 {
    let v = ((val as i32) + 32) >> 6;
    v.clamp(-128, 127)
}

// ---- SIMD YCbCr→RGBA row conversion -----------------------------------------
//
// Processes 8 pixels per iteration using `wide::i32x8` (maps to AVX2 on x86_64,
// NEON on ARM64, or scalar on other targets — all in safe Rust).

/// Convert one row of pre-normalized YCbCr values to RGBA using SIMD.
///
/// `y_row`, `cb_row`, `cr_row` are normalized i32 values in `[-128, 127]`.
/// `out` must hold exactly `y_row.len() * 4` bytes (RGBA).
///
/// DjVu YCbCr→RGB formula (LeCun 1998):
/// ```text
/// t2    = Cr + (Cr >> 1)
/// t3    = Y  + 128 - (Cb >> 2)
/// R     = clamp(Y  + 128 + t2,      0, 255)
/// G     = clamp(t3 - (t2 >> 1),     0, 255)
/// B     = clamp(t3 + (Cb << 1),     0, 255)
/// ```
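///
/// For example, a neutral sample `(Y, Cb, Cr) = (0, 0, 0)` gives `t2 = 0`,
/// `t3 = 128`, and therefore mid-gray `(R, G, B) = (128, 128, 128)`.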
pub(crate) fn ycbcr_row_to_rgba(y_row: &[i32], cb_row: &[i32], cr_row: &[i32], out: &mut [u8]) {
    debug_assert_eq!(y_row.len(), cb_row.len());
    debug_assert_eq!(y_row.len(), cr_row.len());
    debug_assert_eq!(out.len(), y_row.len() * 4);

    let w = y_row.len();

    #[cfg(target_arch = "aarch64")]
    {
        #[allow(unsafe_code)]
        unsafe {
            ycbcr_neon(
                y_row.as_ptr(),
                cb_row.as_ptr(),
                cr_row.as_ptr(),
                out.as_mut_ptr(),
                w,
            )
        };
        return;
    }

    // Portable path: chunks_exact eliminates per-element bounds checks.
    #[allow(unreachable_code)]
    ycbcr_portable(y_row, cb_row, cr_row, out, w);
}

/// Convert raw i16 plane row data to RGBA, fusing normalize + YCbCr in one pass.
///
/// Uses `ycbcr_neon_raw` on AArch64 (avoids three intermediate i32 buffers and
/// the separate normalize loops).  Falls back to two-pass on other targets.
///
/// `y`, `cb`, `cr` must all have the same length `w`; `out` must hold `w * 4` bytes.
#[inline]
fn ycbcr_row_from_i16(y: &[i16], cb: &[i16], cr: &[i16], out: &mut [u8]) {
    let w = y.len();
    debug_assert_eq!(cb.len(), w);
    debug_assert_eq!(cr.len(), w);
    debug_assert_eq!(out.len(), w * 4);
    #[cfg(target_arch = "aarch64")]
    {
        #[allow(unsafe_code)]
        unsafe {
            ycbcr_neon_raw(y.as_ptr(), cb.as_ptr(), cr.as_ptr(), out.as_mut_ptr(), w);
        }
        return;
    }
    // Runtime AVX2 detection requires `std` (`is_x86_feature_detected!`).
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if std::is_x86_feature_detected!("avx2") {
            #[allow(unsafe_code)]
            unsafe {
                ycbcr_avx2_raw(y.as_ptr(), cb.as_ptr(), cr.as_ptr(), out.as_mut_ptr(), w);
            }
            return;
        }
    }
    // WASM simd128 is compile-time only; no runtime detection.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    {
        #[allow(unsafe_code)]
        unsafe {
            ycbcr_simd128_raw(y.as_ptr(), cb.as_ptr(), cr.as_ptr(), out.as_mut_ptr(), w);
        }
        return;
    }
    #[allow(unreachable_code)]
    {
        let mut y_norm = vec![0i32; w];
        let mut cb_norm = vec![0i32; w];
        let mut cr_norm = vec![0i32; w];
        for (col, v) in y_norm.iter_mut().enumerate() {
            *v = normalize(y[col]);
        }
        for col in 0..w {
            cb_norm[col] = normalize(cb[col]);
            cr_norm[col] = normalize(cr[col]);
        }
        ycbcr_row_to_rgba(&y_norm, &cb_norm, &cr_norm, out);
    }
}

/// Convert raw i16 plane row data to RGBA with chroma at half horizontal resolution.
///
/// `y` has length ≥ `w`; `cb_half`/`cr_half` have length ≥ `(w+1)/2`.  Each
/// chroma sample is nearest-neighbour upsampled to two adjacent output pixels.
/// Uses `ycbcr_neon_raw_half` on AArch64; two-pass fallback elsewhere.
#[inline]
fn ycbcr_row_from_i16_half(y: &[i16], cb_half: &[i16], cr_half: &[i16], out: &mut [u8], w: usize) {
    debug_assert!(y.len() >= w);
    debug_assert_eq!(out.len(), w * 4);
    #[cfg(target_arch = "aarch64")]
    {
        #[allow(unsafe_code)]
        unsafe {
            ycbcr_neon_raw_half(
                y.as_ptr(),
                cb_half.as_ptr(),
                cr_half.as_ptr(),
                out.as_mut_ptr(),
                w,
            );
        }
        return;
    }
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if std::is_x86_feature_detected!("avx2") {
            #[allow(unsafe_code)]
            unsafe {
                ycbcr_avx2_raw_half(
                    y.as_ptr(),
                    cb_half.as_ptr(),
                    cr_half.as_ptr(),
                    out.as_mut_ptr(),
                    w,
                );
            }
            return;
        }
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    {
        #[allow(unsafe_code)]
        unsafe {
            ycbcr_simd128_raw_half(
                y.as_ptr(),
                cb_half.as_ptr(),
                cr_half.as_ptr(),
                out.as_mut_ptr(),
                w,
            );
        }
        return;
    }
    #[allow(unreachable_code)]
    {
        let mut y_norm = vec![0i32; w];
        let mut cb_norm = vec![0i32; w];
        let mut cr_norm = vec![0i32; w];
        for (col, v) in y_norm.iter_mut().enumerate() {
            *v = normalize(y[col]);
        }
        for col in 0..w {
            cb_norm[col] = normalize(cb_half[col / 2]);
            cr_norm[col] = normalize(cr_half[col / 2]);
        }
        ycbcr_row_to_rgba(&y_norm, &cb_norm, &cr_norm, out);
    }
}

/// Portable YCbCr→RGBA using chunks_exact so LLVM sees exact 8-element slices.
#[inline(always)]
fn ycbcr_portable(y_row: &[i32], cb_row: &[i32], cr_row: &[i32], out: &mut [u8], w: usize) {
    use wide::i32x8;
    let c128 = i32x8::splat(128);
    let c0 = i32x8::splat(0);
    let c255 = i32x8::splat(255);

    let full8 = w / 8;
    for (((yc, cbc), crc), outc) in y_row[..full8 * 8]
        .chunks_exact(8)
        .zip(cb_row[..full8 * 8].chunks_exact(8))
        .zip(cr_row[..full8 * 8].chunks_exact(8))
        .zip(out[..full8 * 32].chunks_exact_mut(32))
    {
        let ys = i32x8::from([yc[0], yc[1], yc[2], yc[3], yc[4], yc[5], yc[6], yc[7]]);
        let bs = i32x8::from([
            cbc[0], cbc[1], cbc[2], cbc[3], cbc[4], cbc[5], cbc[6], cbc[7],
        ]);
        let rs = i32x8::from([
            crc[0], crc[1], crc[2], crc[3], crc[4], crc[5], crc[6], crc[7],
        ]);
        let t2 = rs + (rs >> 1_i32);
        let t3 = ys + c128 - (bs >> 2_i32);
        let red = (ys + c128 + t2).max(c0).min(c255).to_array();
        let grn = (t3 - (t2 >> 1_i32)).max(c0).min(c255).to_array();
        let blu = (t3 + (bs << 1_i32)).max(c0).min(c255).to_array();
        for i in 0..8 {
            outc[i * 4] = red[i] as u8;
            outc[i * 4 + 1] = grn[i] as u8;
            outc[i * 4 + 2] = blu[i] as u8;
            outc[i * 4 + 3] = 255;
        }
    }
    for col in (full8 * 8)..w {
        let y = y_row[col];
        let b = cb_row[col];
        let r = cr_row[col];
        let t2 = r + (r >> 1);
        let t3 = y + 128 - (b >> 2);
        out[col * 4] = (y + 128 + t2).clamp(0, 255) as u8;
        out[col * 4 + 1] = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
        out[col * 4 + 2] = (t3 + (b << 1)).clamp(0, 255) as u8;
        out[col * 4 + 3] = 255;
    }
}

/// AArch64 NEON fused normalize + YCbCr→RGBA from raw i16 plane data (non-chroma-half).
///
/// Loads 8 i16 per channel, applies `normalize()` inline using `vrshrq_n_s16`
/// (rounding-shift by 6, i.e. `(v+32)>>6`) and clamps to `[-128,127]`, then
/// runs the YCbCr→RGBA formula.  Eliminates the separate normalize pass and the
/// three intermediate i32 buffers.
///
/// `cbp` and `crp` must point to `w` values each (same stride as `yp`).
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn ycbcr_neon_raw(
    yp: *const i16,
    cbp: *const i16,
    crp: *const i16,
    outp: *mut u8,
    w: usize,
) {
    use core::arch::aarch64::*;
    // After normalize+clamp all values ∈ [-128, 127].  The YCbCr arithmetic
    // intermediates all fit in i16 (proof: y128∈[0,255], t2∈[-192,190],
    // t3∈[-31,287], r16∈[-192,445], g16∈[-126,383], b16∈[-287,541]).
    // vqmovun_s16 saturates signed i16 → unsigned u8, clamping to [0,255]
    // in one instruction — no separate min/max clamp ops needed.
    let n_min = vdupq_n_s16(-128);
    let n_max = vdupq_n_s16(127);
    let c128 = vdupq_n_s16(128);
    let alpha = vdup_n_u8(255);

    let full8 = w / 8;
    for i in 0..full8 {
        let off = i * 8;
        // Load + normalize (rounded right-shift by 6) + clamp to [-128, 127] at i16
        let yc = vmaxq_s16(
            vminq_s16(vrshrq_n_s16::<6>(vld1q_s16(yp.add(off))), n_max),
            n_min,
        );
        let cbc = vmaxq_s16(
            vminq_s16(vrshrq_n_s16::<6>(vld1q_s16(cbp.add(off))), n_max),
            n_min,
        );
        let crc = vmaxq_s16(
            vminq_s16(vrshrq_n_s16::<6>(vld1q_s16(crp.add(off))), n_max),
            n_min,
        );
        // All arithmetic stays at i16 — no widening to i32 needed.
        // y128 = y + 128, range [0, 255]
        let y128 = vaddq_s16(yc, c128);
        // t2 = cr + (cr >> 1) = 1.5·cr, range [-192, 190]
        let t2 = vaddq_s16(crc, vshrq_n_s16::<1>(crc));
        // t3 = y128 - (cb >> 2), range [-31, 287]
        let t3 = vsubq_s16(y128, vshrq_n_s16::<2>(cbc));
        // R = y128 + t2, range [-192, 445]
        let r16 = vaddq_s16(y128, t2);
        // G = t3 - (t2 >> 1), range [-126, 383]
        let g16 = vsubq_s16(t3, vshrq_n_s16::<1>(t2));
        // B = t3 + 2·cb, range [-287, 541]
        let b16 = vaddq_s16(t3, vshlq_n_s16::<1>(cbc));
        // Saturating narrow signed i16 → unsigned u8 (clamps to [0, 255])
        let r8 = vqmovun_s16(r16);
        let g8 = vqmovun_s16(g16);
        let b8 = vqmovun_s16(b16);
        vst4_u8(outp.add(off * 4), uint8x8x4_t(r8, g8, b8, alpha));
    }
    // Scalar tail
    for col in (full8 * 8)..w {
        let y = normalize(*yp.add(col));
        let b = normalize(*cbp.add(col));
        let r = normalize(*crp.add(col));
        let t2 = r + (r >> 1);
        let t3 = y + 128 - (b >> 2);
        *outp.add(col * 4) = (y + 128 + t2).clamp(0, 255) as u8;
        *outp.add(col * 4 + 1) = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 2) = (t3 + (b << 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 3) = 255;
    }
}

/// AArch64 NEON fused normalize + YCbCr→RGBA from raw i16 plane data (chroma-half).
///
/// `cbp` and `crp` point to chroma planes at half the horizontal resolution.
/// Each chroma sample is nearest-neighbour upsampled to two luma columns.
/// 8 output pixels are produced per iteration, consuming 8 Y samples and 4 Cb/Cr samples.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn ycbcr_neon_raw_half(
    yp: *const i16,
    cbp: *const i16,
    crp: *const i16,
    outp: *mut u8,
    w: usize,
) {
    use core::arch::aarch64::*;
    // Same i16 arithmetic as ycbcr_neon_raw — all intermediates fit in i16.
    let n_min = vdupq_n_s16(-128);
    let n_max = vdupq_n_s16(127);
    let c128 = vdupq_n_s16(128);
    let alpha = vdup_n_u8(255);

    let full8 = w / 8;
    for i in 0..full8 {
        let off = i * 8;
        let c_off = i * 4;
        // Load + normalize Y (8 consecutive)
        let yc = vmaxq_s16(
            vminq_s16(vrshrq_n_s16::<6>(vld1q_s16(yp.add(off))), n_max),
            n_min,
        );
        // Load 4 chroma values, normalize at i16 level, then upsample 4→8 by
        // duplicating each value: [a,b,c,d] → [a,a,b,b,c,c,d,d] via vzip1q
        let cb4 = vmaxq_s16(
            vminq_s16(
                vrshrq_n_s16::<6>(vcombine_s16(vld1_s16(cbp.add(c_off)), vdup_n_s16(0))),
                n_max,
            ),
            n_min,
        );
        let cr4 = vmaxq_s16(
            vminq_s16(
                vrshrq_n_s16::<6>(vcombine_s16(vld1_s16(crp.add(c_off)), vdup_n_s16(0))),
                n_max,
            ),
            n_min,
        );
        // Upsample: interleave low 4 lanes with themselves → [a,a,b,b,c,c,d,d]
        let cbc = vzip1q_s16(cb4, cb4);
        let crc = vzip1q_s16(cr4, cr4);
        // All arithmetic at i16 level (same ranges as non-half path after upsample)
        let y128 = vaddq_s16(yc, c128);
        let t2 = vaddq_s16(crc, vshrq_n_s16::<1>(crc));
        let t3 = vsubq_s16(y128, vshrq_n_s16::<2>(cbc));
        let r16 = vaddq_s16(y128, t2);
        let g16 = vsubq_s16(t3, vshrq_n_s16::<1>(t2));
        let b16 = vaddq_s16(t3, vshlq_n_s16::<1>(cbc));
        let r8 = vqmovun_s16(r16);
        let g8 = vqmovun_s16(g16);
        let b8 = vqmovun_s16(b16);
        vst4_u8(outp.add(off * 4), uint8x8x4_t(r8, g8, b8, alpha));
    }
    // Scalar tail
    for col in (full8 * 8)..w {
        let y = normalize(*yp.add(col));
        let b = normalize(*cbp.add(col / 2));
        let r = normalize(*crp.add(col / 2));
        let t2 = r + (r >> 1);
        let t3 = y + 128 - (b >> 2);
        *outp.add(col * 4) = (y + 128 + t2).clamp(0, 255) as u8;
        *outp.add(col * 4 + 1) = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 2) = (t3 + (b << 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 3) = 255;
    }
}

/// x86_64 AVX2 fused normalize + YCbCr→RGBA from raw i16 plane data (non-chroma-half).
///
/// 16 pixels per iteration (vs NEON's 8): __m256i holds 16 i16. Pack-down to u8
/// is done via SSE `_mm_packus_epi16` on the two 128-bit halves followed by an
/// SSE byte-interleave to materialise R/G/B/A → RGBA bytes.
///
/// `cbp` and `crp` must point to `w` values each (same stride as `yp`).
#[cfg(all(target_arch = "x86_64", feature = "std"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx2")]
unsafe fn ycbcr_avx2_raw(
    yp: *const i16,
    cbp: *const i16,
    crp: *const i16,
    outp: *mut u8,
    w: usize,
) {
    use core::arch::x86_64::*;
    let n_min = _mm256_set1_epi16(-128);
    let n_max = _mm256_set1_epi16(127);
    let c128 = _mm256_set1_epi16(128);
    let one = _mm256_set1_epi16(1);

    let full16 = w / 16;
    for i in 0..full16 {
        let off = i * 16;
        // Rounding right shift by 6 + clamp to [-128, 127].
        // Equivalent to scalar `((v as i32 + 32) >> 6).clamp(-128, 127)` and to NEON
        // `vrshrq_n_s16::<6>` followed by clamp.  We compute it at i16 width without
        // overflow as `(v >> 6) + ((v as u16 >> 5) & 1)` — the bit-5 logical-shifted
        // term adds back the `+32` rounding carry and matches the wider intermediate
        // that NEON / scalar use.
        let load_norm_clamp = |p: *const i16| -> __m256i {
            let v = _mm256_loadu_si256(p as *const __m256i);
            let high = _mm256_srai_epi16::<6>(v);
            let bit5 = _mm256_and_si256(_mm256_srli_epi16::<5>(v), one);
            let n = _mm256_add_epi16(high, bit5);
            _mm256_max_epi16(_mm256_min_epi16(n, n_max), n_min)
        };
        let yc = load_norm_clamp(yp.add(off));
        let cbc = load_norm_clamp(cbp.add(off));
        let crc = load_norm_clamp(crp.add(off));

        // Same i16 arithmetic as NEON path; ranges fit in i16 → no widening.
        let y128 = _mm256_add_epi16(yc, c128);
        let t2 = _mm256_add_epi16(crc, _mm256_srai_epi16::<1>(crc));
        let t3 = _mm256_sub_epi16(y128, _mm256_srai_epi16::<2>(cbc));
        let r16 = _mm256_add_epi16(y128, t2);
        let g16 = _mm256_sub_epi16(t3, _mm256_srai_epi16::<1>(t2));
        let b16 = _mm256_add_epi16(t3, _mm256_slli_epi16::<1>(cbc));

        // Saturating narrow signed i16 → unsigned u8 in halves (clamps to [0, 255])
        let r_pack = _mm_packus_epi16(
            _mm256_castsi256_si128(r16),
            _mm256_extracti128_si256::<1>(r16),
        );
        let g_pack = _mm_packus_epi16(
            _mm256_castsi256_si128(g16),
            _mm256_extracti128_si256::<1>(g16),
        );
        let b_pack = _mm_packus_epi16(
            _mm256_castsi256_si128(b16),
            _mm256_extracti128_si256::<1>(b16),
        );
        let a_pack = _mm_set1_epi8(-1i8);

        // Interleave R/G and B/A into pairs, then unpack i16 to materialise RGBA.
        let rg_lo = _mm_unpacklo_epi8(r_pack, g_pack);
        let rg_hi = _mm_unpackhi_epi8(r_pack, g_pack);
        let ba_lo = _mm_unpacklo_epi8(b_pack, a_pack);
        let ba_hi = _mm_unpackhi_epi8(b_pack, a_pack);

        let rgba0 = _mm_unpacklo_epi16(rg_lo, ba_lo);
        let rgba1 = _mm_unpackhi_epi16(rg_lo, ba_lo);
        let rgba2 = _mm_unpacklo_epi16(rg_hi, ba_hi);
        let rgba3 = _mm_unpackhi_epi16(rg_hi, ba_hi);

        let dst = outp.add(off * 4) as *mut __m128i;
        _mm_storeu_si128(dst, rgba0);
        _mm_storeu_si128(dst.add(1), rgba1);
        _mm_storeu_si128(dst.add(2), rgba2);
        _mm_storeu_si128(dst.add(3), rgba3);
    }
    // Scalar tail
    for col in (full16 * 16)..w {
        let y = normalize(*yp.add(col));
        let b = normalize(*cbp.add(col));
        let r = normalize(*crp.add(col));
        let t2 = r + (r >> 1);
        let t3 = y + 128 - (b >> 2);
        *outp.add(col * 4) = (y + 128 + t2).clamp(0, 255) as u8;
        *outp.add(col * 4 + 1) = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 2) = (t3 + (b << 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 3) = 255;
    }
}

/// x86_64 AVX2 fused normalize + YCbCr→RGBA from raw i16 plane data (chroma-half).
///
/// 16 Y / 8 chroma per iteration. Chroma upsample uses `_mm256_permute4x64_epi64`
/// to place chromas 0-3 in the low 128-bit lane low half and chromas 4-7 in the
/// high 128-bit lane low half, then `_mm256_unpacklo_epi16(v, v)` duplicates each
/// chroma into two adjacent i16 lanes per 128-bit half.
#[cfg(all(target_arch = "x86_64", feature = "std"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx2")]
unsafe fn ycbcr_avx2_raw_half(
    yp: *const i16,
    cbp: *const i16,
    crp: *const i16,
    outp: *mut u8,
    w: usize,
) {
    use core::arch::x86_64::*;
    let n_min = _mm256_set1_epi16(-128);
    let n_max = _mm256_set1_epi16(127);
    let c128 = _mm256_set1_epi16(128);
    let one = _mm256_set1_epi16(1);

    // Overflow-safe rounding right shift by 6 + clamp to [-128, 127];
    // see `ycbcr_avx2_raw` for the equivalence proof.
    let norm_clamp = |v: __m256i| -> __m256i {
        let high = _mm256_srai_epi16::<6>(v);
        let bit5 = _mm256_and_si256(_mm256_srli_epi16::<5>(v), one);
        let n = _mm256_add_epi16(high, bit5);
        _mm256_max_epi16(_mm256_min_epi16(n, n_max), n_min)
    };

    let full16 = w / 16;
    for i in 0..full16 {
        let off = i * 16;
        let c_off = i * 8;

        // Load + normalize 16 Y samples
        let yv = _mm256_loadu_si256(yp.add(off) as *const __m256i);
        let yc = norm_clamp(yv);

        // Load 8 chroma i16 (one __m128i), upsample to 16 by duplicating each.
        let upsample = |p: *const i16| -> __m256i {
            let v8 = _mm_loadu_si128(p as *const __m128i);
            // Place i16s 0-3 into i64-lane 0 (already there), i16s 4-7 into i64-lane 2.
            // permute4x64 mask 0b00_01_00_00: out0←src0, out1←src0, out2←src1, out3←src0.
            let spread = _mm256_permute4x64_epi64::<0b00_01_00_00>(_mm256_castsi128_si256(v8));
            // Per-128-bit-lane interleave with itself: duplicates each i16 lane.
            _mm256_unpacklo_epi16(spread, spread)
        };
        let cbc = norm_clamp(upsample(cbp.add(c_off)));
        let crc = norm_clamp(upsample(crp.add(c_off)));

        let y128 = _mm256_add_epi16(yc, c128);
        let t2 = _mm256_add_epi16(crc, _mm256_srai_epi16::<1>(crc));
        let t3 = _mm256_sub_epi16(y128, _mm256_srai_epi16::<2>(cbc));
        let r16 = _mm256_add_epi16(y128, t2);
        let g16 = _mm256_sub_epi16(t3, _mm256_srai_epi16::<1>(t2));
        let b16 = _mm256_add_epi16(t3, _mm256_slli_epi16::<1>(cbc));

        let r_pack = _mm_packus_epi16(
            _mm256_castsi256_si128(r16),
            _mm256_extracti128_si256::<1>(r16),
        );
        let g_pack = _mm_packus_epi16(
            _mm256_castsi256_si128(g16),
            _mm256_extracti128_si256::<1>(g16),
        );
        let b_pack = _mm_packus_epi16(
            _mm256_castsi256_si128(b16),
            _mm256_extracti128_si256::<1>(b16),
        );
        let a_pack = _mm_set1_epi8(-1i8);

        let rg_lo = _mm_unpacklo_epi8(r_pack, g_pack);
        let rg_hi = _mm_unpackhi_epi8(r_pack, g_pack);
        let ba_lo = _mm_unpacklo_epi8(b_pack, a_pack);
        let ba_hi = _mm_unpackhi_epi8(b_pack, a_pack);

        let rgba0 = _mm_unpacklo_epi16(rg_lo, ba_lo);
        let rgba1 = _mm_unpackhi_epi16(rg_lo, ba_lo);
        let rgba2 = _mm_unpacklo_epi16(rg_hi, ba_hi);
        let rgba3 = _mm_unpackhi_epi16(rg_hi, ba_hi);

        let dst = outp.add(off * 4) as *mut __m128i;
        _mm_storeu_si128(dst, rgba0);
        _mm_storeu_si128(dst.add(1), rgba1);
        _mm_storeu_si128(dst.add(2), rgba2);
        _mm_storeu_si128(dst.add(3), rgba3);
    }
    // Scalar tail
    for col in (full16 * 16)..w {
        let y = normalize(*yp.add(col));
        let b = normalize(*cbp.add(col / 2));
        let r = normalize(*crp.add(col / 2));
        let t2 = r + (r >> 1);
        let t3 = y + 128 - (b >> 2);
        *outp.add(col * 4) = (y + 128 + t2).clamp(0, 255) as u8;
        *outp.add(col * 4 + 1) = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 2) = (t3 + (b << 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 3) = 255;
    }
}

/// WASM simd128 fused normalize + YCbCr→RGBA from raw i16 plane data (non-chroma-half).
///
/// 8 pixels per iteration, mirroring the AArch64 NEON kernel byte-for-byte.
/// `v128` is 128 bits → 8×i16, same width as NEON's `int16x8_t`. Saturating
/// signed-i16 → unsigned-u8 narrow is one instruction (`u8x16_narrow_i16x8`),
/// equivalent to NEON `vqmovun_s16`.
///
/// RGBA byte-interleave is materialised via two `i8x16_shuffle` calls
/// (constant-mask shuffle, 16 lanes each, picking from {r/g pack, b/alpha pack}).
/// WASM has no `vst4`-equivalent; the shuffle pair is the simd128 idiom.
///
/// `cbp` and `crp` must point to `w` values each (same stride as `yp`).
#[cfg(target_arch = "wasm32")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn, dead_code)]
#[target_feature(enable = "simd128")]
unsafe fn ycbcr_simd128_raw(
    yp: *const i16,
    cbp: *const i16,
    crp: *const i16,
    outp: *mut u8,
    w: usize,
) {
    use core::arch::wasm32::*;
    let n_min = i16x8_splat(-128);
    let n_max = i16x8_splat(127);
    let c128 = i16x8_splat(128);
    let one = i16x8_splat(1);
    // Saturating-narrow input ≥ 255 → 255, so any sentinel ≥ 255 produces the
    // alpha byte without a separate splat-store path.
    let alpha_src = i16x8_splat(255);

    let full8 = w / 8;
    for i in 0..full8 {
        let off = i * 8;
        // Rounding right shift by 6 + clamp to [-128, 127].
        // Same overflow-safe form as the AVX2 path: `(v >> 6) + ((v as u16 >> 5) & 1)`.
        // Avoids the i16-overflow that would happen with `(v + 32) >> 6` for v near
        // `i16::MAX` and matches the wider intermediate that NEON `vrshrq_n_s16` uses.
        let load_norm_clamp = |p: *const i16| -> v128 {
            let v = v128_load(p as *const v128);
            let high = i16x8_shr(v, 6);
            let bit5 = v128_and(u16x8_shr(v, 5), one);
            let n = i16x8_add(high, bit5);
            i16x8_max(i16x8_min(n, n_max), n_min)
        };
        let yc = load_norm_clamp(yp.add(off));
        let cbc = load_norm_clamp(cbp.add(off));
        let crc = load_norm_clamp(crp.add(off));

        // Same i16 arithmetic as NEON / AVX2 — all intermediates fit in i16.
        let y128 = i16x8_add(yc, c128);
        let t2 = i16x8_add(crc, i16x8_shr(crc, 1));
        let t3 = i16x8_sub(y128, i16x8_shr(cbc, 2));
        let r16 = i16x8_add(y128, t2);
        let g16 = i16x8_sub(t3, i16x8_shr(t2, 1));
        let b16 = i16x8_add(t3, i16x8_shl(cbc, 1));

        // Saturating signed→unsigned narrow: i16x8 → u8x16 (clamps to [0, 255]).
        // Pack two i16x8 vectors into one u8x16 in a single op — exactly NEON's
        // `vqmovun_s16` semantics, just in the wider 16-lane form.
        let v_rg = u8x16_narrow_i16x8(r16, g16);
        let v_ba = u8x16_narrow_i16x8(b16, alpha_src);

        // Interleave to RGBA: pixel n = (r_n, g_n, b_n, a_n).
        // v_rg lanes: 0..7 = r, 8..15 = g. v_ba lanes: 0..7 = b, 8..15 = 255.
        // Constant byte-shuffle picks {r_n=v_rg[n], g_n=v_rg[n+8], b_n=v_ba[n+0], a_n=v_ba[n+8]}.
        let out0 =
            i8x16_shuffle::<0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27>(v_rg, v_ba);
        let out1 =
            i8x16_shuffle::<4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31>(v_rg, v_ba);

        v128_store(outp.add(off * 4) as *mut v128, out0);
        v128_store(outp.add(off * 4 + 16) as *mut v128, out1);
    }
    // Scalar tail
    for col in (full8 * 8)..w {
        let y = normalize(*yp.add(col));
        let b = normalize(*cbp.add(col));
        let r = normalize(*crp.add(col));
        let t2 = r + (r >> 1);
        let t3 = y + 128 - (b >> 2);
        *outp.add(col * 4) = (y + 128 + t2).clamp(0, 255) as u8;
        *outp.add(col * 4 + 1) = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 2) = (t3 + (b << 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 3) = 255;
    }
}

/// WASM simd128 fused normalize + YCbCr→RGBA, chroma-half variant.
///
/// 8 luma + 4 chroma per iteration. Chroma is loaded as 8 bytes via
/// `v128_load64_zero` (low half = 4 i16, high half = 0), normalized at
/// i16 width across all 8 lanes (high lanes normalize to 0, unused), and
/// nearest-neighbour upsampled to 8 lanes via a constant byte shuffle that
/// duplicates each of the low 4 i16 lanes (`[a,b,c,d,_,_,_,_]` → `[a,a,b,b,c,c,d,d]`).
#[cfg(target_arch = "wasm32")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn, dead_code)]
#[target_feature(enable = "simd128")]
unsafe fn ycbcr_simd128_raw_half(
    yp: *const i16,
    cbp: *const i16,
    crp: *const i16,
    outp: *mut u8,
    w: usize,
) {
    use core::arch::wasm32::*;
    let n_min = i16x8_splat(-128);
    let n_max = i16x8_splat(127);
    let c128 = i16x8_splat(128);
    let one = i16x8_splat(1);
    let alpha_src = i16x8_splat(255);

    let full8 = w / 8;
    for i in 0..full8 {
        let off = i * 8;
        let c_off = i * 4;
        // Y: full 8-lane load + normalize (same as non-half path).
        let load_norm_clamp = |p: *const i16| -> v128 {
            let v = v128_load(p as *const v128);
            let high = i16x8_shr(v, 6);
            let bit5 = v128_and(u16x8_shr(v, 5), one);
            let n = i16x8_add(high, bit5);
            i16x8_max(i16x8_min(n, n_max), n_min)
        };
        let yc = load_norm_clamp(yp.add(off));

        // Chroma: load 4 i16 = 8 bytes into low half of v128, zero upper half.
        // Normalize on the full vector (upper 4 lanes normalize to 0, harmless).
        let load_norm_chroma_4 = |p: *const i16| -> v128 {
            let v = v128_load64_zero(p as *const u64);
            let high = i16x8_shr(v, 6);
            let bit5 = v128_and(u16x8_shr(v, 5), one);
            let n = i16x8_add(high, bit5);
            i16x8_max(i16x8_min(n, n_max), n_min)
        };
        let cb4 = load_norm_chroma_4(cbp.add(c_off));
        let cr4 = load_norm_chroma_4(crp.add(c_off));

        // Upsample each i16 lane into a pair (`zip-low` of self+self).
        // Byte-level shuffle: bytes 0,1 → 0,1,2,3 ; 2,3 → 4,5,6,7 ; etc.
        let cbc = i8x16_shuffle::<0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7>(cb4, cb4);
        let crc = i8x16_shuffle::<0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7>(cr4, cr4);

        let y128 = i16x8_add(yc, c128);
        let t2 = i16x8_add(crc, i16x8_shr(crc, 1));
        let t3 = i16x8_sub(y128, i16x8_shr(cbc, 2));
        let r16 = i16x8_add(y128, t2);
        let g16 = i16x8_sub(t3, i16x8_shr(t2, 1));
        let b16 = i16x8_add(t3, i16x8_shl(cbc, 1));

        let v_rg = u8x16_narrow_i16x8(r16, g16);
        let v_ba = u8x16_narrow_i16x8(b16, alpha_src);
        let out0 =
            i8x16_shuffle::<0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27>(v_rg, v_ba);
        let out1 =
            i8x16_shuffle::<4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31>(v_rg, v_ba);
        v128_store(outp.add(off * 4) as *mut v128, out0);
        v128_store(outp.add(off * 4 + 16) as *mut v128, out1);
    }
    for col in (full8 * 8)..w {
        let y = normalize(*yp.add(col));
        let b = normalize(*cbp.add(col / 2));
        let r = normalize(*crp.add(col / 2));
        let t2 = r + (r >> 1);
        let t3 = y + 128 - (b >> 2);
        *outp.add(col * 4) = (y + 128 + t2).clamp(0, 255) as u8;
        *outp.add(col * 4 + 1) = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 2) = (t3 + (b << 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 3) = 255;
    }
}

/// AArch64 NEON: 6× vld1q_s32 + SIMD arithmetic + vst4_u8 per 8 pixels.
/// Replaces 80+ bounds-check branches per 8 pixels in the LLVM-generated portable code.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn ycbcr_neon(yp: *const i32, cbp: *const i32, crp: *const i32, outp: *mut u8, w: usize) {
    use core::arch::aarch64::*;
    let c128 = vdupq_n_s32(128);
    let c0 = vdupq_n_s32(0);
    let c255 = vdupq_n_s32(255);
    let alpha = vdup_n_u8(255);

    let full8 = w / 8;
    for i in 0..full8 {
        let off = i * 8;
        // Load 8 × i32 from each channel (2 × vld1q_s32 = one cache line per channel)
        let y_lo = vld1q_s32(yp.add(off));
        let y_hi = vld1q_s32(yp.add(off + 4));
        let cb_lo = vld1q_s32(cbp.add(off));
        let cb_hi = vld1q_s32(cbp.add(off + 4));
        let cr_lo = vld1q_s32(crp.add(off));
        let cr_hi = vld1q_s32(crp.add(off + 4));

        // t2 = cr + (cr >> 1)
        let t2_lo = vaddq_s32(cr_lo, vshrq_n_s32::<1>(cr_lo));
        let t2_hi = vaddq_s32(cr_hi, vshrq_n_s32::<1>(cr_hi));
        // t3 = y + 128 - (cb >> 2)
        let t3_lo = vsubq_s32(vaddq_s32(y_lo, c128), vshrq_n_s32::<2>(cb_lo));
        let t3_hi = vsubq_s32(vaddq_s32(y_hi, c128), vshrq_n_s32::<2>(cb_hi));

        // red = clamp(y + 128 + t2)
        let r_lo = vminq_s32(vmaxq_s32(vaddq_s32(vaddq_s32(y_lo, c128), t2_lo), c0), c255);
        let r_hi = vminq_s32(vmaxq_s32(vaddq_s32(vaddq_s32(y_hi, c128), t2_hi), c0), c255);
        // green = clamp(t3 - (t2 >> 1))
        let g_lo = vminq_s32(
            vmaxq_s32(vsubq_s32(t3_lo, vshrq_n_s32::<1>(t2_lo)), c0),
            c255,
        );
        let g_hi = vminq_s32(
            vmaxq_s32(vsubq_s32(t3_hi, vshrq_n_s32::<1>(t2_hi)), c0),
            c255,
        );
        // blue = clamp(t3 + (cb << 1))
        let b_lo = vminq_s32(
            vmaxq_s32(vaddq_s32(t3_lo, vshlq_n_s32::<1>(cb_lo)), c0),
            c255,
        );
        let b_hi = vminq_s32(
            vmaxq_s32(vaddq_s32(t3_hi, vshlq_n_s32::<1>(cb_hi)), c0),
            c255,
        );

        // Narrow i32×4 → i16×4 → u8×8 for each channel
        let r8 = vqmovun_s16(vcombine_s16(vmovn_s32(r_lo), vmovn_s32(r_hi)));
        let g8 = vqmovun_s16(vcombine_s16(vmovn_s32(g_lo), vmovn_s32(g_hi)));
        let b8 = vqmovun_s16(vcombine_s16(vmovn_s32(b_lo), vmovn_s32(b_hi)));

        // Store 8 RGBA pixels (32 bytes) interleaved via vst4_u8
        vst4_u8(outp.add(off * 4), uint8x8x4_t(r8, g8, b8, alpha));
    }

    // Scalar tail
    for col in (full8 * 8)..w {
        let y = *yp.add(col);
        let b = *cbp.add(col);
        let r = *crp.add(col);
        let t2 = r + (r >> 1);
        let t3 = y + 128 - (b >> 2);
        *outp.add(col * 4) = (y + 128 + t2).clamp(0, 255) as u8;
        *outp.add(col * 4 + 1) = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 2) = (t3 + (b << 1)).clamp(0, 255) as u8;
        *outp.add(col * 4 + 3) = 255;
    }
}
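
/// The AVX2 and WASM kernels normalize at i16 width with the bit trick
/// `(v >> 6) + ((v >> 5) & 1)` instead of the widened `(v + 32) >> 6`; this
/// sanity-check sketch asserts that the trick (plus clamp) agrees with the
/// scalar `normalize` for every possible coefficient value.
#[cfg(test)]
mod normalize_rounding_tests {
    use super::normalize;

    #[test]
    fn i16_rounding_trick_matches_scalar_normalize() {
        for v in i16::MIN..=i16::MAX {
            // Same arithmetic the SIMD paths perform lane-wise at i16 width.
            let trick = ((v >> 6) + ((v >> 5) & 1)).clamp(-128, 127) as i32;
            assert_eq!(trick, normalize(v), "v = {v}");
        }
    }
}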

// ---- Per-channel wavelet decoder --------------------------------------------

/// State for a single YCbCr plane wavelet decoder.
///
/// Holds 32×32 block coefficients and the ZP context tables that persist
/// across progressive slices.
#[derive(Clone, Debug)]
struct PlaneDecoder {
    width: usize,
    height: usize,
    block_cols: usize,
    /// Row-major array of 32×32 blocks; each block holds 1024 i16 coefficients
    /// in zigzag-scan order.
    blocks: Vec<[i16; 1024]>,
    quant_lo: [u32; 16],
    quant_hi: [u32; 10],
    /// Current band index (0..10, wraps around).
    curband: usize,
    // ZP context bytes — persistent across slices and chunks.
    ctx_decode_bucket: [u8; 1],
    ctx_decode_coef: [u8; 80],
    ctx_activate_coef: [u8; 16],
    ctx_increase_coef: [u8; 1],
    // Per-block temporary decode state (re-used each block, not persisted).
    coeffstate: [[u8; 16]; 16],
    bucketstate: [u8; 16],
    bbstate: u8,
}

/// Map 16 i16 coefficients at `block[base..]` to UNK/ACTIVE flags, store in `bucket`,
/// and return the OR of all flag bytes (bstatetmp).
///
/// Dispatches to NEON on aarch64, AVX2 on x86_64 when available, else scalar.
#[allow(unsafe_code)]
#[inline(always)]
fn prelim_flags_bucket(block: &[i16; 1024], base: usize, bucket: &mut [u8; 16]) -> u8 {
    #[cfg(target_arch = "aarch64")]
    // SAFETY: NEON is mandatory on aarch64; `base + 16 <= 1024` is guaranteed by
    // BAND_BUCKETS (max bucket index 63, so base = 63 * 16 = 1008, 1008 + 16 = 1024).
    return unsafe { prelim_flags_bucket_neon(block, base, bucket) };

    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if std::is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 was just feature-detected; `base + 16 <= 1024` per BAND_BUCKETS.
            return unsafe { prelim_flags_bucket_avx2(block, base, bucket) };
        }
    }

    #[cfg_attr(target_arch = "aarch64", allow(unreachable_code))]
    {
        let mut bstate = 0u8;
        for k in 0..16 {
            let f = if block[base + k] == 0 { UNK } else { ACTIVE };
            bucket[k] = f;
            bstate |= f;
        }
        bstate
    }
}

/// NEON-vectorized version of `prelim_flags_bucket` for aarch64.
///
/// Loads 16 i16 values, compares to zero with NEON, narrows to u8 flags
/// (UNK=8 for zero, ACTIVE=2 for non-zero), stores, and OR-reduces to bstatetmp.
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn prelim_flags_bucket_neon(block: &[i16; 1024], base: usize, bucket: &mut [u8; 16]) -> u8 {
    use core::arch::aarch64::*;
    let ptr = block.as_ptr().add(base);
    // Load as u16 — zero-comparison is the same for signed and unsigned 16-bit.
    let c0 = vreinterpretq_u16_s16(vld1q_s16(ptr));
    let c1 = vreinterpretq_u16_s16(vld1q_s16(ptr.add(8)));
    // nz: 0xFFFF where coef != 0, 0x0000 where coef == 0
    let zero = vdupq_n_u16(0);
    let nz0 = vmvnq_u16(vceqq_u16(c0, zero));
    let nz1 = vmvnq_u16(vceqq_u16(c1, zero));
    // result = UNK ^ ((UNK ^ ACTIVE) & nz)  ⟹  UNK(8) if zero, ACTIVE(2) if nonzero
    // UNK ^ ACTIVE = 8 ^ 2 = 10
    let xv = vdupq_n_u16(10);
    let uv = vdupq_n_u16(8);
    let r0 = veorq_u16(uv, vandq_u16(xv, nz0));
    let r1 = veorq_u16(uv, vandq_u16(xv, nz1));
    // Narrow u16 → u8 (values 2 and 8 both fit; high byte of each lane is 0)
    let out = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
    vst1q_u8(bucket.as_mut_ptr(), out);
    // Horizontal OR: fold 16 u8 lanes to 1
    let lo = vget_low_u8(out);
    let hi = vget_high_u8(out);
    let v4 = vorr_u8(lo, hi);
    let v2 = vorr_u8(v4, vext_u8::<4>(v4, v4));
    let v1 = vorr_u8(v2, vext_u8::<2>(v2, v2));
    let v0 = vorr_u8(v1, vext_u8::<1>(v1, v1));
    vget_lane_u8::<0>(v0)
}

/// AVX2-vectorized version of `prelim_flags_bucket` for x86_64.
///
/// Loads 16 i16 in one `__m256i`, compares to zero with `_mm256_cmpeq_epi16`,
/// builds UNK/ACTIVE flags via `uv ^ (xv & nz)` where UNK=8 and XV=10
/// (= UNK ^ ACTIVE), narrows to 16 u8 with `_mm_packus_epi16` (saturating but
/// values 2/8 fit), stores via `_mm_storeu_si128`, and horizontally OR-reduces
/// the 16 bytes to one byte via shift+OR.
///
/// Mirror of `prelim_flags_bucket_neon` — same operations, AVX2 lanes.
#[cfg(all(target_arch = "x86_64", feature = "std"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx2")]
unsafe fn prelim_flags_bucket_avx2(block: &[i16; 1024], base: usize, bucket: &mut [u8; 16]) -> u8 {
    use core::arch::x86_64::*;
    // Load 16 contiguous i16 (32 bytes) at block[base..base+16].
    let coefs = _mm256_loadu_si256(block.as_ptr().add(base) as *const __m256i);
    // eq: 0xFFFF where coef == 0, 0x0000 where != 0.
    let zero = _mm256_setzero_si256();
    let eq = _mm256_cmpeq_epi16(coefs, zero);
    // nz = !eq.  cmpeq(x, x) == all-ones.
    let all_ones = _mm256_cmpeq_epi16(zero, zero);
    let nz = _mm256_xor_si256(eq, all_ones);
    // result = UNK ^ ((UNK ^ ACTIVE) & nz)  ⟹  UNK(8) if zero, ACTIVE(2) if nonzero.
    let xv = _mm256_set1_epi16(10);
    let uv = _mm256_set1_epi16(8);
    let r16 = _mm256_xor_si256(uv, _mm256_and_si256(xv, nz));
    // Narrow u16 → u8: pack the two 128-bit halves.  `_mm_packus_epi16` saturates
    // to [0, 255] but our values are 2 or 8 — equivalent to truncation here.
    let r_lo = _mm256_castsi256_si128(r16);
    let r_hi = _mm256_extracti128_si256::<1>(r16);
    let packed = _mm_packus_epi16(r_lo, r_hi);
    _mm_storeu_si128(bucket.as_mut_ptr() as *mut __m128i, packed);
    // Horizontal OR of 16 u8 lanes → 1 byte via successive shift+OR.
    let or64 = _mm_or_si128(packed, _mm_unpackhi_epi64(packed, packed));
    let or32 = _mm_or_si128(or64, _mm_srli_si128::<4>(or64));
    let or16_red = _mm_or_si128(or32, _mm_srli_si128::<2>(or32));
    let or8 = _mm_or_si128(or16_red, _mm_srli_si128::<1>(or16_red));
    _mm_extract_epi8::<0>(or8) as u8
}

/// NEON-vectorized band-0 path of `preliminary_flag_computation`.
///
/// Band 0 differs from bands 1-9: only update entries where `old_flags[k] != ZERO (1)`.
/// Uses `vbslq_u8` to blend new flags (UNK/ACTIVE from coef) with old flags (keep ZERO).
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "neon")]
unsafe fn prelim_flags_band0_neon(block: &[i16; 1024], old_flags: &mut [u8; 16]) -> u8 {
    use core::arch::aarch64::*;
    // Load old coeffstate[0] (u8 flags: ZERO=1, UNK=8, ACTIVE=2).
    let old_u8 = vld1q_u8(old_flags.as_ptr());
    // should_update mask: 0xFF where old_flags[k] != ZERO(1), 0x00 where == ZERO
    let one_u8 = vdupq_n_u8(1);
    let is_zero_state = vceqq_u8(old_u8, one_u8); // 0xFF where ZERO, 0x00 elsewhere
    let should_update = vmvnq_u8(is_zero_state); // 0xFF where not-ZERO
    // Compute new flags from first 16 coefs (same as prelim_flags_bucket_neon with base=0).
    let ptr = block.as_ptr();
    let c0 = vreinterpretq_u16_s16(vld1q_s16(ptr));
    let c1 = vreinterpretq_u16_s16(vld1q_s16(ptr.add(8)));
    let zero16 = vdupq_n_u16(0);
    let nz0 = vmvnq_u16(vceqq_u16(c0, zero16));
    let nz1 = vmvnq_u16(vceqq_u16(c1, zero16));
    let xv = vdupq_n_u16(10); // UNK ^ ACTIVE = 10
    let uv = vdupq_n_u16(8); // UNK = 8
    let r0 = veorq_u16(uv, vandq_u16(xv, nz0));
    let r1 = veorq_u16(uv, vandq_u16(xv, nz1));
    let new_flags = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
    // Blend: where should_update, take new_flags; where ZERO state, keep old.
    let result = vbslq_u8(should_update, new_flags, old_u8);
    vst1q_u8(old_flags.as_mut_ptr(), result);
    // Horizontal OR of final flags for bstatetmp.
    let lo = vget_low_u8(result);
    let hi = vget_high_u8(result);
    let v4 = vorr_u8(lo, hi);
    let v2 = vorr_u8(v4, vext_u8::<4>(v4, v4));
    let v1 = vorr_u8(v2, vext_u8::<2>(v2, v2));
    let v0 = vorr_u8(v1, vext_u8::<1>(v1, v1));
    vget_lane_u8::<0>(v0)
}

/// AVX2-vectorized band-0 path of `preliminary_flag_computation` for x86_64.
///
/// Mirror of `prelim_flags_band0_neon`: only updates entries where
/// `old_flags[k] != ZERO(1)`; uses an SSE2 blend (`(new & m) | (old & ~m)`)
/// for the conditional-write step that NEON does with `vbslq_u8`.
#[cfg(all(target_arch = "x86_64", feature = "std"))]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
#[target_feature(enable = "avx2")]
unsafe fn prelim_flags_band0_avx2(block: &[i16; 1024], old_flags: &mut [u8; 16]) -> u8 {
    use core::arch::x86_64::*;
    // Load old coeffstate[0] (16 u8 flags: ZERO=1, UNK=8, ACTIVE=2).
    let old_u8 = _mm_loadu_si128(old_flags.as_ptr() as *const __m128i);
    // should_update mask: 0xFF where old_flags[k] != ZERO(1), 0x00 where == ZERO.
    let one_u8 = _mm_set1_epi8(1);
    let is_zero_state = _mm_cmpeq_epi8(old_u8, one_u8);
    let all_ones_128 = _mm_cmpeq_epi8(old_u8, old_u8);
    let should_update = _mm_xor_si128(is_zero_state, all_ones_128);

    // Compute new flags from first 16 coefs (same recipe as prelim_flags_bucket_avx2 with base=0).
    let coefs = _mm256_loadu_si256(block.as_ptr() as *const __m256i);
    let zero = _mm256_setzero_si256();
    let eq = _mm256_cmpeq_epi16(coefs, zero);
    let all_ones_256 = _mm256_cmpeq_epi16(zero, zero);
    let nz = _mm256_xor_si256(eq, all_ones_256);
    let xv = _mm256_set1_epi16(10);
    let uv = _mm256_set1_epi16(8);
    let r16 = _mm256_xor_si256(uv, _mm256_and_si256(xv, nz));
    let r_lo = _mm256_castsi256_si128(r16);
    let r_hi = _mm256_extracti128_si256::<1>(r16);
    let new_flags = _mm_packus_epi16(r_lo, r_hi);

    // Blend: (new & should_update) | (old & ~should_update).
    let blended = _mm_or_si128(
        _mm_and_si128(should_update, new_flags),
        _mm_andnot_si128(should_update, old_u8),
    );
    _mm_storeu_si128(old_flags.as_mut_ptr() as *mut __m128i, blended);

    // Horizontal OR of 16 u8 lanes → 1 byte (same reduction as the bucket path).
    let or64 = _mm_or_si128(blended, _mm_unpackhi_epi64(blended, blended));
    let or32 = _mm_or_si128(or64, _mm_srli_si128::<4>(or64));
    let or16_red = _mm_or_si128(or32, _mm_srli_si128::<2>(or32));
    let or8 = _mm_or_si128(or16_red, _mm_srli_si128::<1>(or16_red));
    _mm_extract_epi8::<0>(or8) as u8
}

/// Dispatcher for the band-0 path of `preliminary_flag_computation`.
///
/// Picks NEON on aarch64, AVX2 on x86_64 when available, scalar otherwise.
#[allow(unsafe_code)]
#[inline(always)]
fn band0_dispatch(block: &[i16; 1024], old_flags: &mut [u8; 16]) -> u8 {
    #[cfg(target_arch = "aarch64")]
    // SAFETY: NEON always available on aarch64; block[0..16] valid by construction.
    return unsafe { prelim_flags_band0_neon(block, old_flags) };

    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if std::is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 was just feature-detected; block[0..16] valid by construction.
            return unsafe { prelim_flags_band0_avx2(block, old_flags) };
        }
    }

    #[cfg_attr(target_arch = "aarch64", allow(unreachable_code))]
    {
        let mut b = 0u8;
        for k in 0..16 {
            if old_flags[k] != ZERO {
                old_flags[k] = if block[k] == 0 { UNK } else { ACTIVE };
            }
            b |= old_flags[k];
        }
        b
    }
}
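
/// Regardless of which SIMD path the dispatcher selects, the per-coefficient
/// flags must match the scalar definition: `UNK` for a zero coefficient,
/// `ACTIVE` for a nonzero one, and the returned byte is the OR of all 16 flags.
/// A small sanity-check sketch of that contract:
#[cfg(test)]
mod prelim_flags_tests {
    use super::{prelim_flags_bucket, ACTIVE, UNK};

    #[test]
    fn bucket_flags_match_scalar_definition() {
        let mut block = [0i16; 1024];
        block[32] = 5;
        block[33] = -7;
        block[47] = 1;
        let mut bucket = [0u8; 16];
        // Bucket 2 covers coefficients 32..48.
        let bstate = prelim_flags_bucket(&block, 32, &mut bucket);
        for k in 0..16 {
            let expected = if block[32 + k] == 0 { UNK } else { ACTIVE };
            assert_eq!(bucket[k], expected, "k = {k}");
        }
        assert_eq!(bstate, UNK | ACTIVE);
    }
}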

impl PlaneDecoder {
    fn new(width: usize, height: usize) -> Self {
        let block_cols = width.div_ceil(32);
        let block_rows = height.div_ceil(32);
        let block_count = block_cols * block_rows;
        PlaneDecoder {
            width,
            height,
            block_cols,
            blocks: vec![[0i16; 1024]; block_count],
            quant_lo: QUANT_LO_INIT,
            quant_hi: QUANT_HI_INIT,
            curband: 0,
            ctx_decode_bucket: [0; 1],
            ctx_decode_coef: [0; 80],
            ctx_activate_coef: [0; 16],
            ctx_increase_coef: [0; 1],
            coeffstate: [[0; 16]; 16],
            bucketstate: [0; 16],
            bbstate: 0,
        }
    }

    /// Decode one slice (one band across all blocks) from `zp`.
    fn decode_slice(&mut self, zp: &mut ZpDecoder<'_>) {
        if !self.is_null_slice() {
            for block_idx in 0..self.blocks.len() {
                self.preliminary_flag_computation(block_idx);
                if self.block_band_decoding_pass(zp) && self.bucket_decoding_pass(zp, block_idx) {
                    self.newly_active_coefficient_decoding_pass(zp, block_idx);
                }
                // Skip the inner loop entirely when no ACTIVE coefficients exist
                // (avoids function call + zp register flush for fresh/sparse blocks).
                if (self.bbstate & ACTIVE) != 0 {
                    self.previously_active_coefficient_decoding_pass(zp, block_idx);
                }
            }
        }
        self.finish_slice();
    }

    fn is_null_slice(&mut self) -> bool {
        if self.curband == 0 {
            let mut is_null = true;
            for i in 0..16 {
                let threshold = self.quant_lo[i];
                self.coeffstate[0][i] = ZERO;
                if threshold > 0 && threshold < 0x8000 {
                    self.coeffstate[0][i] = UNK;
                    is_null = false;
                }
            }
            is_null
        } else {
            let threshold = self.quant_hi[self.curband];
            !(threshold > 0 && threshold < 0x8000)
        }
    }

    fn preliminary_flag_computation(&mut self, block_idx: usize) {
        self.bbstate = 0;
        let (from, to) = BAND_BUCKETS[self.curband];

        if self.curband != 0 {
            for (boff, j) in (from..=to).enumerate() {
                let bstatetmp = prelim_flags_bucket(
                    &self.blocks[block_idx],
                    j << 4,
                    &mut self.coeffstate[boff],
                );
                self.bucketstate[boff] = bstatetmp;
                self.bbstate |= bstatetmp;
            }
        } else {
            let bstatetmp = band0_dispatch(&self.blocks[block_idx], &mut self.coeffstate[0]);
            self.bucketstate[0] = bstatetmp;
            self.bbstate |= bstatetmp;
        }
    }
1356
1357    fn block_band_decoding_pass(&mut self, zp: &mut ZpDecoder<'_>) -> bool {
1358        let (from, to) = BAND_BUCKETS[self.curband];
1359        let bcount = to - from + 1;
1360        let should_mark_new = bcount < 16
1361            || (self.bbstate & ACTIVE) != 0
1362            || ((self.bbstate & UNK) != 0 && zp.decode_bit(&mut self.ctx_decode_bucket[0]));
1363        if should_mark_new {
1364            self.bbstate |= NEW;
1365        }
1366        (self.bbstate & NEW) != 0
1367    }
1368
1369    /// Returns `true` if any bucket had its NEW flag set during this pass.
1370    fn bucket_decoding_pass(&mut self, zp: &mut ZpDecoder<'_>, block_idx: usize) -> bool {
1371        let (from, to) = BAND_BUCKETS[self.curband];
1372        let mut any_new = false;
1373        for (boff, i) in (from..=to).enumerate() {
1374            if (self.bucketstate[boff] & UNK) == 0 {
1375                continue;
1376            }
1377            let mut n: usize = 0;
1378            if self.curband != 0 {
1379                let t = 4 * i;
1380                for j in t..t + 4 {
1381                    if self.blocks[block_idx][j] != 0 {
1382                        n += 1;
1383                    }
1384                }
1385                if n == 4 {
1386                    n = 3;
1387                }
1388            }
1389            if (self.bbstate & ACTIVE) != 0 {
1390                n |= 4;
1391            }
1392            if zp.decode_bit(&mut self.ctx_decode_coef[n + self.curband * 8]) {
1393                self.bucketstate[boff] |= NEW;
1394                any_new = true;
1395            }
1396        }
1397        any_new
1398    }
1399
1400    fn newly_active_coefficient_decoding_pass(&mut self, zp: &mut ZpDecoder<'_>, block_idx: usize) {
1401        let (from, to) = BAND_BUCKETS[self.curband];
1402        let mut step = self.quant_hi[self.curband];
1403        for (boff, i) in (from..=to).enumerate() {
1404            if (self.bucketstate[boff] & NEW) != 0 {
1405                let shift: usize = if (self.bucketstate[boff] & ACTIVE) != 0 {
1406                    8
1407                } else {
1408                    0
1409                };
1410                let mut np: usize = 0;
1411                for j in 0..16 {
1412                    if (self.coeffstate[boff][j] & UNK) != 0 {
1413                        np += 1;
1414                    }
1415                }
1416                for j in 0..16 {
1417                    if (self.coeffstate[boff][j] & UNK) != 0 {
1418                        let ip = np.min(7);
1419                        if zp.decode_bit(&mut self.ctx_activate_coef[shift + ip]) {
1420                            let sign = if zp.decode_passthrough_iw44() {
1421                                -1i32
1422                            } else {
1423                                1i32
1424                            };
1425                            np = 0;
1426                            if self.curband == 0 {
1427                                step = self.quant_lo[j];
1428                            }
1429                            let s = step as i32;
1430                            let val = sign * (s + (s >> 1) - (s >> 3));
1431                            self.blocks[block_idx][(i << 4) | j] = val as i16;
1432                        }
1433                        np = np.saturating_sub(1);
1434                    }
1435                }
1436            }
1437        }
1438    }
1439
1440    /// Hot inner loop for refining already-active coefficients.
1441    ///
1442    /// Uses local copies of all ZP state fields so LLVM can keep them in
1443    /// registers for the duration of the double-loop, avoiding struct-pointer
1444    /// round-trips on every `decode_bit` / `decode_passthrough_iw44` call.
1445    #[inline(never)]
1446    fn previously_active_coefficient_decoding_pass(
1447        &mut self,
1448        zp: &mut ZpDecoder<'_>,
1449        block_idx: usize,
1450    ) {
1451        use djvu_zp::tables::{LPS_NEXT, MPS_NEXT, PROB, THRESHOLD};
1452
1453        // Extract ZP state to true stack-locals — LLVM keeps these in registers.
1454        let mut a = zp.a;
1455        let mut c = zp.c;
1456        let mut fence = zp.fence;
1457        let mut bit_buf = zp.bit_buf;
1458        let mut bit_count = zp.bit_count;
1459        let data = zp.data;
1460        let mut pos = zp.pos;
1461
1462        macro_rules! read_byte {
1463            () => {{
1464                let b = if pos < data.len() { data[pos] } else { 0xff };
1465                pos = pos.wrapping_add(1);
1466                b as u32
1467            }};
1468        }
1469        macro_rules! refill {
1470            () => {
1471                while bit_count <= 24 {
1472                    bit_buf = (bit_buf << 8) | read_byte!();
1473                    bit_count += 8;
1474                }
1475            };
1476        }
1477        macro_rules! renorm {
1478            () => {{
1479                let shift = (a as u16).leading_ones();
1480                bit_count -= shift as i32;
1481                a = (a << shift) & 0xffff;
1482                let mask = (1u32 << (shift & 31)).wrapping_sub(1);
1483                c = ((c << shift) | (bit_buf >> (bit_count as u32 & 31)) & mask) & 0xffff;
1484                if bit_count < 16 {
1485                    refill!();
1486                }
1487                fence = c.min(0x7fff);
1488            }};
1489        }
1490        // Decode one bit using an adaptive context byte.
1491        macro_rules! decode_bit_ctx {
1492            ($ctx:expr) => {{
1493                let state = ($ctx) as usize;
1494                let mps_bit = state & 1;
1495                let z = a + PROB[state] as u32;
1496                if z <= fence {
1497                    a = z;
1498                    mps_bit != 0
1499                } else {
1500                    let boundary = 0x6000u32 + ((a + z) >> 2);
1501                    let z_clamped = z.min(boundary);
1502                    if z_clamped > c {
1503                        let complement = 0x10000u32 - z_clamped;
1504                        a = (a + complement) & 0xffff;
1505                        c = (c + complement) & 0xffff;
1506                        $ctx = LPS_NEXT[state];
1507                        renorm!();
1508                        (1 - mps_bit) != 0
1509                    } else {
1510                        if a >= THRESHOLD[state] as u32 {
1511                            $ctx = MPS_NEXT[state];
1512                        }
1513                        bit_count -= 1;
1514                        a = (z_clamped << 1) & 0xffff;
1515                        c = ((c << 1) | (bit_buf >> (bit_count as u32 & 31)) & 1) & 0xffff;
1516                        if bit_count < 16 {
1517                            refill!();
1518                        }
1519                        fence = c.min(0x7fff);
1520                        mps_bit != 0
1521                    }
1522                }
1523            }};
1524        }
1525        // Decode one bit in IW44 passthrough mode (threshold = 0x8000 + 3a/8).
1526        macro_rules! decode_passthrough_iw44 {
1527            () => {{
1528                let z = (0x8000u32 + (3u32 * a) / 8) as u16;
1529                if z as u32 > c {
1530                    let complement = 0x10000u32 - z as u32;
1531                    a = (a + complement) & 0xffff;
1532                    c = (c + complement) & 0xffff;
1533                    renorm!();
1534                    true
1535                } else {
1536                    bit_count -= 1;
1537                    a = (z as u32 * 2) & 0xffff;
1538                    c = (c << 1 | (bit_buf >> (bit_count as u32 & 31)) & 1) & 0xffff;
1539                    if bit_count < 16 {
1540                        refill!();
1541                    }
1542                    fence = c.min(0x7fff);
1543                    false
1544                }
1545            }};
1546        }
1547
1548        let (from, to) = BAND_BUCKETS[self.curband];
1549        let mut step = self.quant_hi[self.curband];
1550        for (boff, i) in (from..=to).enumerate() {
1551            for j in 0..16 {
1552                if (self.coeffstate[boff][j] & ACTIVE) != 0 {
1553                    if self.curband == 0 {
1554                        step = self.quant_lo[j];
1555                    }
1556                    let coef = self.blocks[block_idx][(i << 4) | j];
1557                    let mut abs_coef = coef.unsigned_abs() as i32;
1558                    let s = step as i32;
1559                    let des = if abs_coef <= 3 * s {
1560                        let d = decode_bit_ctx!(self.ctx_increase_coef[0]);
1561                        abs_coef += s >> 2;
1562                        d
1563                    } else {
1564                        decode_passthrough_iw44!()
1565                    };
1566                    if des {
1567                        abs_coef += s >> 1;
1568                    } else {
1569                        abs_coef += -s + (s >> 1);
1570                    }
1571                    self.blocks[block_idx][(i << 4) | j] = if coef < 0 {
1572                        -abs_coef as i16
1573                    } else {
1574                        abs_coef as i16
1575                    };
1576                }
1577            }
1578        }
1579
1580        // Write back ZP state so subsequent calls see the updated arithmetic.
1581        zp.a = a;
1582        zp.c = c;
1583        zp.fence = fence;
1584        zp.bit_buf = bit_buf;
1585        zp.bit_count = bit_count;
1586        zp.pos = pos;
1587    }
1588
1589    /// Advance quantization step and band counter after one slice.
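    ///
    /// Bands are visited round-robin (0, 1, …, 9, 0, …), so a full refinement
    /// scan is 10 consecutive slices, after which every band's step has been
    /// halved once.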
1590    fn finish_slice(&mut self) {
1591        self.quant_hi[self.curband] >>= 1;
1592        if self.curband == 0 {
1593            for i in 0..16 {
1594                self.quant_lo[i] >>= 1;
1595            }
1596        }
1597        self.curband += 1;
1598        if self.curband == 10 {
1599            self.curband = 0;
1600        }
1601    }
1602
1603    /// Apply the inverse wavelet transform and return a flat `i16` plane.
1604    ///
1605    /// Row-major.  `subsample` ≥ 1 controls the resolution (1 = full, 2 = half, …).
1606    /// Stride is `width.div_ceil(32)*32` (default path) or that value / `subsample` (power-of-two fast path).
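    ///
    /// Illustrative example: a 100 × 60 plane has a 4 × 2 block grid, so the
    /// full-resolution plane is 128 wide (stride 128), while the `subsample = 2`
    /// fast path below yields a compact 64-wide plane covering the 50 × 30 target image.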
1607    fn reconstruct(&self, subsample: usize) -> FlatPlane {
1608        // ── Fast path for sub≥2: compact plane ────────────────────────────────
1609        //
1610        // For subsample=2 the wavelet only ever reads/writes (even_row, even_col)
1611        // positions — those with zigzag index i < 256 (see zigzag_row/col: both
1612        // are even iff bits 8 and 9 of i are 0).  We can therefore:
1613        //   1. Allocate a 4× smaller plane  (ceil(w/2) × ceil(h/2))
1614        //   2. Scatter only the sub_block² low-frequency coefficients per block
1615        //      (zigzag indices 0..sub_block² map to even multiples of sub)
1616        //   3. Run the full wavelet (sub=1) on the compact plane, which now
1617        //      includes the SIMD s=1 pass.
1618        //
1619        // This is equivalent to running the wavelet at sub=2 on the full plane
1620        // and sampling every other position: each compact[k][c] equals the value
1621        // that full[k·sub][c·sub] would hold after the sub=2 wavelet.
1622        //
1623        // The same logic holds for sub=4 (8×8 sub-block) and sub=8 (4×4 sub-block).
1624        if (2..=8).contains(&subsample) && subsample.is_power_of_two() {
1625            let sub = subsample;
1626
1627            // Block structure: the compact plane inherits the same block grid but
1628            // each 32×32 block contributes a (32/sub)×(32/sub) sub-block.
1629            let block_rows = self.height.div_ceil(32);
1630            let sub_block = 32 / sub; // 16 for sub=2, 8 for sub=4, 4 for sub=8
1631
1632            // Compact plane dimensions, aligned to the sub-block width.
1633            let compact_stride = self.block_cols * sub_block;
1634            let compact_rows = block_rows * sub_block;
1635            // Logical image dimensions at the target resolution.
1636            let compact_w = self.width.div_ceil(sub);
1637            let compact_h = self.height.div_ceil(sub);
1638
1639            // Safety: zigzag_row(i)/sub × zigzag_col(i)/sub for i in 0..sub_block²
1640            // is a bijection over [0..sub_block) × [0..sub_block) (bits 8/9 of i are
1641            // 0 → both zigzag values are even; dividing by sub tiles all sub_block²
1642            // positions per block → every element is written before the wavelet reads).
1643            #[allow(unsafe_code)]
1644            let mut plane = FlatPlane {
1645                data: unsafe { uninit_i16_vec(compact_stride * compact_rows) },
1646                stride: compact_stride,
1647            };
1648
1649            // Row-major scatter via compact inverse zigzag tables: write
1650            // sub_block consecutive i16 per row before advancing, maximising
1651            // write-combine efficiency (one cache line per row for sub=2).
1652            // Safety invariants for get_unchecked below:
1653            //   inv: inv_base+col = row*sub_block+col, row,col ∈ 0..sub_block → < sub_block²
1654            //        = compact_inv.len(); block[i]: compact_inv values < sub_block² ≤ 256
1655            //        < 1024 = block.len(); plane[dst_base+col]: sequential within
1656            //        (base_row+row)*compact_stride+base_col+[0,sub_block) — all in bounds.
1657            let compact_inv: &[u8] = match sub {
1658                2 => &ZIGZAG_INV_SUB2,
1659                4 => &ZIGZAG_INV_SUB4,
1660                _ => &ZIGZAG_INV_SUB8, // sub=8
1661            };
1662            #[allow(unsafe_code)]
1663            for r in 0..block_rows {
1664                for c in 0..self.block_cols {
1665                    let block = &self.blocks[r * self.block_cols + c];
1666                    let base_row = r * sub_block;
1667                    let base_col = c * sub_block;
1668                    for row in 0..sub_block {
1669                        let dst_base = (base_row + row) * compact_stride + base_col;
1670                        let inv_base = row * sub_block;
1671                        for col in 0..sub_block {
1672                            // Safety: see invariants above.
1673                            let i = unsafe { *compact_inv.get_unchecked(inv_base + col) } as usize;
1674                            unsafe {
1675                                *plane.data.get_unchecked_mut(dst_base + col) =
1676                                    *block.get_unchecked(i);
1677                            }
1678                        }
1679                    }
1680                }
1681            }
1682
1683            // Run the wavelet on the compact plane starting at scale 16/sub.
1684            // compact s=k ↔ full s=k·sub, so the coarsest valid pass is
1685            // s = 16/sub (e.g. s=8 for sub=2).  Starting at s=16 would add a
1686            // spurious pass with no coefficients and introduce rounding noise.
1687            let start_scale = 16 / sub;
1688            inverse_wavelet_transform_from(&mut plane, compact_w, compact_h, 1, start_scale);
1689            return plane;
1690        }
1691
1692        // ── Default path (sub=1, or non-power-of-two sub) ─────────────────────
1693        let full_width = self.width.div_ceil(32) * 32;
1694        let full_height = self.height.div_ceil(32) * 32;
1695        let block_rows = self.height.div_ceil(32);
1696        // Safety: ZIGZAG_ROW/COL for i in 0..1024 is a bijection over [0..32)×[0..32)
1697        // (odd-indexed bits → row, even-indexed bits → col, non-overlapping). The
1698        // scatter below writes every element before the wavelet reads any of them.
1699        #[allow(unsafe_code)]
1700        let mut plane = FlatPlane {
1701            data: unsafe { uninit_i16_vec(full_width * full_height) },
1702            stride: full_width,
1703        };
1704
1705        // Row-major scatter via ZIGZAG_INV: write 32 consecutive i16 per row
1706        // (= 1 cache line) before advancing, maximising write-combine efficiency.
1707        // block[ZIGZAG_INV[row*32+col]] is a gathered read from a 2 KB array
1708        // that fits in L1, so the scatter cost is minimal.
1709        for r in 0..block_rows {
1710            for c in 0..self.block_cols {
1711                let block = &self.blocks[r * self.block_cols + c];
1712                let row_base = r << 5;
1713                let col_base = c << 5;
1714                for row in 0..32usize {
1715                    let dst_base = (row_base + row) * full_width + col_base;
1716                    let inv_base = row * 32;
1717                    for col in 0..32usize {
1718                        let i = ZIGZAG_INV[inv_base + col] as usize;
1719                        plane.data[dst_base + col] = block[i];
1720                    }
1721                }
1722            }
1723        }
1724
1725        inverse_wavelet_transform(&mut plane, self.width, self.height, subsample);
1726        plane
1727    }
1728}
1729
1730// ---- Flat plane helper -------------------------------------------------------
1731
1732/// Allocate `n` uninitialized `i16` elements.
1733///
1734/// Uses `Vec<MaybeUninit<i16>>` (the clippy-blessed pattern) and reinterprets
1735/// as `Vec<i16>`.
1736///
1737/// # Safety
1738/// Caller must write every element before reading it.
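///
/// Rationale: the zigzag scatters in `reconstruct` write every element before the
/// wavelet reads it, so the zero-fill performed by `vec![0i16; n]` would be pure
/// overhead on large planes.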
1739#[allow(unsafe_code)]
1740unsafe fn uninit_i16_vec(n: usize) -> Vec<i16> {
1741    use core::mem::MaybeUninit;
1742    let mut v: Vec<MaybeUninit<i16>> = Vec::with_capacity(n);
1743    // Safety: MaybeUninit<i16> requires no initialization; len will equal capacity.
1744    unsafe { v.set_len(n) };
1745    let mut md = core::mem::ManuallyDrop::new(v);
1746    // Safety: MaybeUninit<i16> and i16 have identical layout; capacity unchanged.
1747    unsafe { Vec::from_raw_parts(md.as_mut_ptr().cast::<i16>(), md.len(), md.capacity()) }
1748}
1749
1750struct FlatPlane {
1751    data: Vec<i16>,
1752    stride: usize,
1753}
1754
1755// ---- Inverse Dubuc-Deslauriers-Lemire (4,4) wavelet transform ---------------
1756//
1757// Two passes per resolution level:
1758//   1. Column pass (lifting + prediction along rows of subsampled columns)
1759//   2. Row pass (lifting + prediction along columns of subsampled rows)
1760//
1761// The column pass is transposed for cache efficiency.
1762//
1763// When `s == 1` (the final, highest-resolution level) the column indices are
1764// contiguous, so we can process 8 columns per iteration using `wide::i32x8`.
1765
1766use wide::i32x8;
1767
1768/// Load 8 `i16` values at stride `s` starting at `slice[phys_off]`.
1769///
1770/// Reads `slice[phys_off + j*s]` for `j` in `0..8`; the SIMD fast paths load without
1771/// bounds checks, so callers must keep the accessed span in bounds.  For s=1 this is
1772/// identical to [`load8`].  For s=2 and s=4 the AArch64 path uses `ld2`/`ld4` to
1773/// deinterleave in a single instruction; other targets use scalar loads that LLVM may auto-vectorize.
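///
/// ```text
/// load8s(slice, off, 1)  reads  slice[off], slice[off+1], …, slice[off+7]
/// load8s(slice, off, 2)  reads  slice[off], slice[off+2], …, slice[off+14]
/// load8s(slice, off, 4)  reads  slice[off], slice[off+4], …, slice[off+28]
/// ```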
1774#[inline(always)]
1775fn load8s(slice: &[i16], phys_off: usize, s: usize) -> i32x8 {
1776    // s=1 fast path: single contiguous load + sign-extend.  Checked FIRST so that
1777    // the s=1 branch is a single cmp+b (not taken on s≠1) rather than a 5-branch
1778    // dispatch chain inside load8s_neon.
1779    if s == 1 {
1780        // x86_64 + AVX2 enabled at compile time: `vpmovsxwd ymm, [mem]` is one
1781        // instruction (movdqu + vpmovsxwd, fused on most µarchs). Compile-time
1782        // gating keeps the hot loop branch-free; runtime detection in this loop
1783        // would dominate the kernel.
1784        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
1785        {
1786            #[allow(unsafe_code)]
1787            return unsafe { load8s_s1_avx2(slice, phys_off) };
1788        }
1789        // WASM simd128 compile-time path: `i32x4.extend_low/high_i16x8_s` sign-extends
1790        // 8×i16 → 8×i32 in two 128-bit ops, avoiding 8 scalar cast+store pairs.
1791        // On WASM, `wide::i32x8` is `{a: i32x4, b: i32x4}` where each `i32x4` is
1792        // `repr(transparent)` over `v128`, so [lo, hi]: [v128; 2] transmutes cleanly.
1793        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1794        {
1795            #[allow(unsafe_code)]
1796            return unsafe { load8s_s1_simd128(slice, phys_off) };
1797        }
1798        #[allow(unsafe_code, unreachable_code)]
1799        return unsafe {
1800            // SAFETY: caller ensures phys_off+7 < slice.len().
1801            let arr: [i16; 8] = core::ptr::read(slice.as_ptr().add(phys_off) as *const [i16; 8]);
1802            i32x8::from([
1803                arr[0] as i32,
1804                arr[1] as i32,
1805                arr[2] as i32,
1806                arr[3] as i32,
1807                arr[4] as i32,
1808                arr[5] as i32,
1809                arr[6] as i32,
1810                arr[7] as i32,
1811            ])
1812        };
1813    }
1814    #[cfg(target_arch = "aarch64")]
1815    if s == 2 || s == 4 {
1816        #[allow(unsafe_code)]
1817        return unsafe { load8s_neon(slice, phys_off, s) };
1818    }
1819    i32x8::from([
1820        slice[phys_off] as i32,
1821        slice[phys_off + s] as i32,
1822        slice[phys_off + 2 * s] as i32,
1823        slice[phys_off + 3 * s] as i32,
1824        slice[phys_off + 4 * s] as i32,
1825        slice[phys_off + 5 * s] as i32,
1826        slice[phys_off + 6 * s] as i32,
1827        slice[phys_off + 7 * s] as i32,
1828    ])
1829}
1830
1831/// Store the 8 lanes of an `i32x8` (truncated to `i16`) at stride `s`, starting at `slice[phys_off]`.
1832///
1833/// Writes `slice[phys_off + j*s] = v[j] as i16` for `j` in `0..8`. Interleaved positions
1834/// (those not at multiples of `s` from `phys_off`) are left unchanged.
1835#[inline(always)]
1836fn store8s(slice: &mut [i16], phys_off: usize, s: usize, v: i32x8) {
1837    // s=1 fast path: narrow and store contiguously.  Same reasoning as load8s.
1838    if s == 1 {
1839        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
1840        {
1841            #[allow(unsafe_code)]
1842            return unsafe { store8s_s1_avx2(slice, phys_off, v) };
1843        }
1844        // WASM simd128: byte-shuffle to pack the low halfword of each i32 lane into
1845        // a contiguous i16x8.  Indices 0,1,4,5,8,9,12,13 pick bytes 0-1 of each 4-byte
1846        // i32 from the low half (lo), and indices 16,17,20,21,24,25,28,29 do the same
1847        // for the high half (hi).  This matches the truncating `as i16` semantics
1848        // (not saturating narrow) and mirrors the AVX2 byte-shuffle approach.
1849        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1850        {
1851            #[allow(unsafe_code)]
1852            return unsafe { store8s_s1_simd128(slice, phys_off, v) };
1853        }
1854        #[allow(unsafe_code, unreachable_code)]
1855        return unsafe {
1856            // SAFETY: caller ensures phys_off+7 < slice.len().
1857            let a = v.to_array();
1858            let narrow: [i16; 8] = [
1859                a[0] as i16,
1860                a[1] as i16,
1861                a[2] as i16,
1862                a[3] as i16,
1863                a[4] as i16,
1864                a[5] as i16,
1865                a[6] as i16,
1866                a[7] as i16,
1867            ];
1868            core::ptr::write(slice.as_mut_ptr().add(phys_off) as *mut [i16; 8], narrow);
1869        };
1870    }
1871    #[cfg(target_arch = "aarch64")]
1872    if s == 2 || s == 4 {
1873        #[allow(unsafe_code)]
1874        return unsafe { store8s_neon(slice, phys_off, s, v) };
1875    }
1876    let a = v.to_array();
1877    for j in 0..8 {
1878        slice[phys_off + j * s] = a[j] as i16;
1879    }
1880}
1881
1882// ---- AArch64 NEON stride load/store -----------------------------------------
1883//
1884// ld2 deinterleaves 16 consecutive i16s into two vectors (even, odd).
1885// ld4 deinterleaves 32 consecutive i16s into four vectors.
1886// After widening the target lane to i32, `lifting_even` / `predict_inner`
1887// run on i32x8 exactly as for s=1.
1888// On store, we re-interleave the updated even lane with the unchanged odd lanes.
1889
1890#[cfg(target_arch = "aarch64")]
1891// s=1 is now handled directly in load8s/store8s (single ldr/str q without dispatch).
1892// This function only needs to handle s=2 and s=4.
1893#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
1894#[target_feature(enable = "neon")]
1895unsafe fn load8s_neon(slice: &[i16], phys_off: usize, s: usize) -> i32x8 {
1896    use core::arch::aarch64::*;
1897    let ptr = slice.as_ptr().add(phys_off);
1898    let target: int16x8_t = if s == 2 {
1899        vld2q_s16(ptr).0
1900    } else {
1901        // s == 4
1902        vld4q_s16(ptr).0
1903    };
1904    // Widen i16x8 → two i32x4, then reinterpret as [i32;8] → i32x8
1905    let lo = vmovl_s16(vget_low_s16(target));
1906    let hi = vmovl_high_s16(target);
1907    let arr = core::mem::transmute::<[int32x4_t; 2], [i32; 8]>([lo, hi]);
1908    i32x8::from(arr)
1909}
1910
1911// ---- x86_64 AVX2 stride-1 load/store ---------------------------------------
1912//
1913// `vpmovsxwd ymm, [mem]` sign-extends 8×i16 → 8×i32 in one fused load+convert.
1914// Truncating narrow i32x8 → i16x8 has no native AVX2 instruction (the only
1915// pack ops saturate); we emulate it with a per-lane byte shuffle that gathers
1916// the low halfword of each i32 lane, then a 64-bit lane permute to combine
1917// the two 128-bit halves.
1918//
1919// `i32x8` ↔ `__m256i` are layout-compatible on x86_64 with AVX2 enabled
1920// (`wide` uses `__m256i` internally), and the existing `C16: i32x8 = transmute([16i32; 8])`
1921// pattern used for the rounding constants below already relies on this. Both are 32 bytes.
1922
1923#[cfg(all(target_arch = "x86_64", any(feature = "std", target_feature = "avx2")))]
1924#[allow(unsafe_code, unsafe_op_in_unsafe_fn, dead_code)]
1925#[target_feature(enable = "avx2")]
1926#[inline]
1927unsafe fn load8s_s1_avx2(slice: &[i16], phys_off: usize) -> i32x8 {
1928    use core::arch::x86_64::*;
1929    let ptr = slice.as_ptr().add(phys_off) as *const __m128i;
1930    let v16 = _mm_loadu_si128(ptr);
1931    let v32 = _mm256_cvtepi16_epi32(v16);
1932    let arr: [i32; 8] = core::mem::transmute(v32);
1933    i32x8::from(arr)
1934}
1935
1936#[cfg(all(target_arch = "x86_64", any(feature = "std", target_feature = "avx2")))]
1937#[allow(unsafe_code, unsafe_op_in_unsafe_fn, dead_code)]
1938#[target_feature(enable = "avx2")]
1939#[inline]
1940unsafe fn store8s_s1_avx2(slice: &mut [i16], phys_off: usize, v: i32x8) {
1941    use core::arch::x86_64::*;
1942    let arr: [i32; 8] = v.to_array();
1943    let v32: __m256i = core::mem::transmute(arr);
1944    // Per-lane byte shuffle: pack low halfwords of each i32 into the low 64 bits
1945    // of each 128-bit lane. _mm256_shuffle_epi8 is per-128-bit-lane, so the same
1946    // 16-byte mask applies to both halves.
1947    let shuf = _mm256_setr_epi8(
1948        0, 1, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 4, 5, 8, 9, 12, 13, -1, -1,
1949        -1, -1, -1, -1, -1, -1,
1950    );
1951    let shuffled = _mm256_shuffle_epi8(v32, shuf);
1952    // 64-bit lanes after shuffle: [lo_packed | zeros | hi_packed | zeros].
1953    // Permute to bring [lo_packed | hi_packed] into the low 128 bits.
1954    // Imm 0b00_00_10_00 = lane 0 → 0 (lo_packed), lane 1 → 2 (hi_packed).
1955    let permuted = _mm256_permute4x64_epi64::<0b00_00_10_00>(shuffled);
1956    let result = _mm256_castsi256_si128(permuted);
1957    let ptr = slice.as_mut_ptr().add(phys_off) as *mut __m128i;
1958    _mm_storeu_si128(ptr, result);
1959}
1960
1961#[cfg(target_arch = "aarch64")]
1962#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
1963#[target_feature(enable = "neon")]
1964unsafe fn store8s_neon(slice: &mut [i16], phys_off: usize, s: usize, v: i32x8) {
1965    use core::arch::aarch64::*;
1966    let ptr = slice.as_mut_ptr().add(phys_off);
1967    // Narrow v (i32x8) back to i16x8 via vmovn (truncate low 16 bits)
1968    let v_arr = core::mem::transmute::<[i32; 8], [int32x4_t; 2]>(v.to_array());
1969    let new_vals = vcombine_s16(vmovn_s32(v_arr[0]), vmovn_s32(v_arr[1]));
1970    // For s=2,4: scatter-store 8 i16s to stride-s positions.
1971    // Using 8 individual str h avoids the extra vld2/vld4 that would be needed
1972    // to preserve interleaved odd lanes before a vst2/vst4.
1973    // Each str h targets the same ~16-byte cache region (already hot from load8s).
1974    let a: [i16; 8] = core::mem::transmute(new_vals);
1975    for (j, &val) in a.iter().enumerate() {
1976        *ptr.add(j * s) = val;
1977    }
1978}
1979
1980// ---- WASM simd128 stride-1 load/store ----------------------------------------
1981//
1982// On WASM simd128, `wide::i32x8` compiles to `{a: i32x4, b: i32x4}` where each
1983// `i32x4` is `repr(transparent)` over `v128`.  The struct is `repr(C, align(32))`
1984// so it is memory-compatible with `[v128; 2]` (two consecutive 128-bit values).
1985//
1986// Load: `i32x4.extend_low_i16x8_s` / `i32x4.extend_high_i16x8_s` each produce one
1987// `v128` of 4×i32 from the low/high 4 lanes of an i16x8, sign-extending in a single
1988// WASM instruction (equivalent to `_mm256_cvtepi16_epi32` on AVX2 but in two 128-bit
1989// ops).
1990//
1991// Store: `i8x16_shuffle` with constant mask picks bytes 0,1,4,5,8,9,12,13 from the
1992// low half and 0,1,4,5,8,9,12,13 from the high half (as indices 16..31 into the
1993// second operand), packing the low 2 bytes of each 4-byte i32 lane into a contiguous
1994// 16-byte i16x8.  This is the truncating `as i16` cast (not saturating), matching
1995// the scalar fallback and the AVX2 byte-shuffle approach.
1996
1997#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1998#[allow(unsafe_code, unsafe_op_in_unsafe_fn, dead_code)]
1999#[target_feature(enable = "simd128")]
2000#[inline]
2001unsafe fn load8s_s1_simd128(slice: &[i16], phys_off: usize) -> i32x8 {
2002    use core::arch::wasm32::*;
2003    // Load 8 consecutive i16 (16 bytes) as a v128.
2004    let v16 = v128_load(slice.as_ptr().add(phys_off) as *const v128);
2005    // Sign-extend lower 4 i16 → i32x4 and upper 4 i16 → i32x4.
2006    let lo = i32x4_extend_low_i16x8(v16);
2007    let hi = i32x4_extend_high_i16x8(v16);
2008    // Transmute [v128; 2] → i32x8.  On WASM simd128, i32x8 is {a: i32x4(v128), b: i32x4(v128)}
2009    // (repr(C, align(32))), layout-compatible with two consecutive v128 values.
2010    core::mem::transmute::<[v128; 2], i32x8>([lo, hi])
2011}
2012
2013#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2014#[allow(unsafe_code, unsafe_op_in_unsafe_fn, dead_code)]
2015#[target_feature(enable = "simd128")]
2016#[inline]
2017unsafe fn store8s_s1_simd128(slice: &mut [i16], phys_off: usize, v: i32x8) {
2018    use core::arch::wasm32::*;
2019    // Transmute i32x8 → [v128; 2] (lo = lower 4 lanes, hi = upper 4 lanes).
2020    let [lo, hi]: [v128; 2] = core::mem::transmute(v);
2021    // Pack low halfwords of each i32 lane via constant byte-shuffle.
2022    // Indices 0,1,4,5,8,9,12,13 select bytes 0-1 of lanes 0-3 from `lo` (first operand).
2023    // Indices 16,17,20,21,24,25,28,29 select bytes 0-1 of lanes 0-3 from `hi` (second operand).
2024    // Result is 8 consecutive i16 values, truncating i32→i16 (low 16 bits only).
2025    let out = i8x16_shuffle::<0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29>(lo, hi);
2026    v128_store(slice.as_mut_ptr().add(phys_off) as *mut v128, out);
2027}
2028
2029/// Load 8 contiguous `i32` values from `slice[off..]` into an `i32x8`.
2030///
2031/// # Safety
2032/// Caller must ensure `off + 7 < slice.len()`.
2033#[inline(always)]
2034#[allow(unsafe_code)]
2035fn load8_i32(slice: &[i32], off: usize) -> i32x8 {
2036    // SAFETY: caller guarantees off+7 is in bounds.
2037    unsafe {
2038        i32x8::from([
2039            *slice.get_unchecked(off),
2040            *slice.get_unchecked(off + 1),
2041            *slice.get_unchecked(off + 2),
2042            *slice.get_unchecked(off + 3),
2043            *slice.get_unchecked(off + 4),
2044            *slice.get_unchecked(off + 5),
2045            *slice.get_unchecked(off + 6),
2046            *slice.get_unchecked(off + 7),
2047        ])
2048    }
2049}
2050
2051/// Store 8 values from an `i32x8` into contiguous `i32` slots at `slice[off..]`.
2052///
2053/// # Safety
2054/// Caller must ensure `off + 7 < slice.len()`.
2055#[inline(always)]
2056#[allow(unsafe_code)]
2057fn store8_i32(slice: &mut [i32], off: usize, v: i32x8) {
2058    let a = v.to_array();
2059    // SAFETY: caller guarantees off+7 is in bounds.
2060    unsafe {
2061        *slice.get_unchecked_mut(off) = a[0];
2062        *slice.get_unchecked_mut(off + 1) = a[1];
2063        *slice.get_unchecked_mut(off + 2) = a[2];
2064        *slice.get_unchecked_mut(off + 3) = a[3];
2065        *slice.get_unchecked_mut(off + 4) = a[4];
2066        *slice.get_unchecked_mut(off + 5) = a[5];
2067        *slice.get_unchecked_mut(off + 6) = a[6];
2068        *slice.get_unchecked_mut(off + 7) = a[7];
2069    }
2070}
2071
2072/// Gather one `i16` value from each of 8 rows (offsets `offs`) at column index `k`.
2073///
2074/// `offs[i]` is the start offset `row_i * stride` for row `i`.
2075///
2076/// # Safety
2077/// Caller must ensure `offs[i] + k < data.len()` for all `i in 0..8`.
2078#[inline(always)]
2079#[allow(unsafe_code)]
2080fn load_rows8(data: &[i16], offs: &[usize; 8], k: usize) -> i32x8 {
2081    // SAFETY: caller guarantees offs[i]+k is in bounds for all i.
2082    unsafe {
2083        i32x8::from([
2084            *data.get_unchecked(offs[0] + k) as i32,
2085            *data.get_unchecked(offs[1] + k) as i32,
2086            *data.get_unchecked(offs[2] + k) as i32,
2087            *data.get_unchecked(offs[3] + k) as i32,
2088            *data.get_unchecked(offs[4] + k) as i32,
2089            *data.get_unchecked(offs[5] + k) as i32,
2090            *data.get_unchecked(offs[6] + k) as i32,
2091            *data.get_unchecked(offs[7] + k) as i32,
2092        ])
2093    }
2094}
2095
2096/// Scatter one lane of `v` to each of 8 rows (offsets `offs`) at column index `k`.
2097///
2098/// # Safety
2099/// Caller must ensure `offs[i] + k < data.len()` for all `i in 0..8`.
2100#[inline(always)]
2101#[allow(unsafe_code)]
2102fn store_rows8(data: &mut [i16], offs: &[usize; 8], k: usize, v: i32x8) {
2103    let a = v.to_array();
2104    // SAFETY: caller guarantees offs[i]+k is in bounds for all i.
2105    unsafe {
2106        *data.get_unchecked_mut(offs[0] + k) = a[0] as i16;
2107        *data.get_unchecked_mut(offs[1] + k) = a[1] as i16;
2108        *data.get_unchecked_mut(offs[2] + k) = a[2] as i16;
2109        *data.get_unchecked_mut(offs[3] + k) = a[3] as i16;
2110        *data.get_unchecked_mut(offs[4] + k) = a[4] as i16;
2111        *data.get_unchecked_mut(offs[5] + k) = a[5] as i16;
2112        *data.get_unchecked_mut(offs[6] + k) = a[6] as i16;
2113        *data.get_unchecked_mut(offs[7] + k) = a[7] as i16;
2114    }
2115}
2116
2117// Compile-time rounding constants — avoids the `memcpy` call that
2118// `i32x8::splat(N)` generates on AArch64 (LLVM doesn't hoist splat to movi.4s).
2119// SAFETY: [i32; 8] and i32x8 have identical representations (8 × 4-byte i32,
2120// 32-byte size); the transmute is value-preserving.
2121#[allow(unsafe_code)]
2122const C16: i32x8 = unsafe { core::mem::transmute([16i32; 8]) };
2123#[allow(unsafe_code)]
2124const C8: i32x8 = unsafe { core::mem::transmute([8i32; 8]) };
2125#[allow(unsafe_code)]
2126const C1: i32x8 = unsafe { core::mem::transmute([1i32; 8]) };
2127
2128/// Lifting filter: `data[idx] -= ((9*(p1+n1) - (p3+n3) + 16) >> 5)`
2129#[inline(always)]
2130fn lifting_even(cur: i32x8, p1: i32x8, n1: i32x8, p3: i32x8, n3: i32x8) -> i32x8 {
2131    let a = p1 + n1;
2132    let c = p3 + n3;
2133    cur - (((a << 3) + a - c + C16) >> 5)
2134}
2135
2136/// Prediction filter (inner): `data[idx] += ((9*(p1+n1) - (p3+n3) + 8) >> 4)`
2137#[inline(always)]
2138fn predict_inner(cur: i32x8, p1: i32x8, n1: i32x8, p3: i32x8, n3: i32x8) -> i32x8 {
2139    let a = p1 + n1;
2140    cur + (((a << 3) + a - (p3 + n3) + C8) >> 4)
2141}
2142
2143/// Prediction filter (boundary): `data[idx] += ((p + n + 1) >> 1)`
2144#[inline(always)]
2145fn predict_avg(cur: i32x8, p: i32x8, n: i32x8) -> i32x8 {
2146    cur + ((p + n + C1) >> 1)
2147}
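
// Worked example (sanity check of the three filters): with p1 = n1 = 10 and p3 = n3 = 0,
//   lifting_even:  cur -= (9*(10+10) - 0 + 16) >> 5 = 196 >> 5 = 6
//   predict_inner: cur += (9*(10+10) - 0 + 8)  >> 4 = 188 >> 4 = 11
//   predict_avg:   cur += (10 + 10 + 1)        >> 1 = 10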
2148
2149/// AArch64 NEON horizontal row pass for s=1.
2150///
2151/// Processes each row independently using `vld2q_s16` to deinterleave even/odd
2152/// positions and `vextq_s16` for the 5-tap sliding-window neighbors, eliminating
2153/// the scatter loads (`8×ldrh`) used by the vertical 8-rows-at-a-time path.
2154///
2155/// # Even pass (lifting)
2156/// For each chunk of 8 even positions (`chunk*16 .. chunk*16+15`):
2157/// ```text
2158///   vld2q_s16(chunk*16)     → curr_even[0..8], curr_odd[0..8]
2159///   vld2q_s16((chunk+1)*16) → next_even (for n3)
2160///   p1 = vextq_s16(prev_odd, curr_odd, 7)
2161///   n1 = curr_odd
2162///   p3 = vextq_s16(prev_odd, curr_odd, 6)
2163///   n3 = vextq_s16(curr_odd, next_odd, 1)
2164/// ```
2165///
2166/// # Odd pass (prediction)
2167/// For each chunk of 8 inner odd positions at `3+chunk*16, 5+..., 17+chunk*16`:
2168/// ```text
2169///   pair1 = vld2q_s16(chunk*16)     → p3=.0, odds_lo=.1
2170///   pair2 = vld2q_s16((chunk+1)*16) → next_even=.0, odds_hi=.1
2171///   curr_odds = vextq_s16(odds_lo, odds_hi, 1)
2172///   p1 = vextq_s16(p3, next_even, 1)
2173///   n1 = vextq_s16(p3, next_even, 2)
2174///   n3 = vextq_s16(p3, next_even, 3)
2175/// ```
2176///
2177/// # Safety
2178/// `data[row_off .. row_off+width]` must be valid. `width >= 1`.
2179#[cfg(target_arch = "aarch64")]
2180#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
2181#[target_feature(enable = "neon")]
2182unsafe fn row_pass_neon_s1_row(data: &mut [i16], row_off: usize, width: usize) {
2183    use core::arch::aarch64::*;
2184
2185    let kmax = width - 1;
2186    let border = kmax.saturating_sub(3);
2187    let ptr = data.as_mut_ptr().add(row_off);
2188
2189    // Number of NEON even chunks: need next chunk fully in bounds for n3.
2190    // Condition: (chunk+1)*16+15 < width  →  chunk < (width-31)/16.
2191    let even_chunks = if width >= 32 { (width - 31) / 16 } else { 0 };
2192
2193    // ── Even pass (lifting) ────────────────────────────────────────────────────
2194
2195    let mut prev_odd = vdupq_n_s16(0i16);
2196
2197    for chunk in 0..even_chunks {
2198        let curr_pair = vld2q_s16(ptr.add(chunk * 16) as *const i16);
2199        let next_pair = vld2q_s16(ptr.add((chunk + 1) * 16) as *const i16);
2200        let curr_even = curr_pair.0;
2201        let curr_odd = curr_pair.1;
2202        let next_odd = next_pair.1;
2203
2204        let p1 = vextq_s16::<7>(prev_odd, curr_odd);
2205        let n1 = curr_odd;
2206        let p3 = vextq_s16::<6>(prev_odd, curr_odd);
2207        let n3 = vextq_s16::<1>(curr_odd, next_odd);
2208
2209        // cur -= ((9*(p1+n1) - (p3+n3) + 16) >> 5)
2210        macro_rules! lift {
2211            ($ce:expr, $p1:expr, $n1:expr, $p3:expr, $n3:expr) => {{
2212                let a = vaddq_s32($p1, $n1);
2213                let c = vaddq_s32($p3, $n3);
2214                let nine_a = vaddq_s32(vshlq_n_s32::<3>(a), a);
2215                let delta = vshrq_n_s32::<5>(vsubq_s32(vaddq_s32(nine_a, vdupq_n_s32(16i32)), c));
2216                vsubq_s32($ce, delta)
2217            }};
2218        }
2219
2220        let new_lo = lift!(
2221            vmovl_s16(vget_low_s16(curr_even)),
2222            vmovl_s16(vget_low_s16(p1)),
2223            vmovl_s16(vget_low_s16(n1)),
2224            vmovl_s16(vget_low_s16(p3)),
2225            vmovl_s16(vget_low_s16(n3))
2226        );
2227        let new_hi = lift!(
2228            vmovl_high_s16(curr_even),
2229            vmovl_high_s16(p1),
2230            vmovl_high_s16(n1),
2231            vmovl_high_s16(p3),
2232            vmovl_high_s16(n3)
2233        );
2234        let new_evens = vcombine_s16(vmovn_s32(new_lo), vmovn_s32(new_hi));
2235
2236        vst2q_s16(ptr.add(chunk * 16), int16x8x2_t(new_evens, curr_odd));
2237
2238        prev_odd = curr_odd;
2239    }
2240
2241    // Scalar even tail: k = even_chunks*16, +2, ... <= kmax.
2242    // State just before the first advance: prev1=prev_odd[6], next1=prev_odd[7], next3=data[k+1].
2243    {
2244        let k_start = even_chunks * 16;
2245        let mut prev1 = if even_chunks > 0 {
2246            vgetq_lane_s16::<6>(prev_odd) as i32
2247        } else {
2248            0
2249        };
2250        let mut next1 = if even_chunks > 0 {
2251            vgetq_lane_s16::<7>(prev_odd) as i32
2252        } else {
2253            0
2254        };
2255        let mut next3 = if k_start < kmax {
2256            *data.get_unchecked(row_off + k_start + 1) as i32
2257        } else {
2258            0
2259        };
2260        let mut k = k_start;
2261        while k <= kmax {
2262            let prev3 = prev1;
2263            prev1 = next1;
2264            next1 = next3;
2265            next3 = if k + 3 <= kmax {
2266                *data.get_unchecked(row_off + k + 3) as i32
2267            } else {
2268                0
2269            };
2270            let a = prev1 + next1;
2271            let c = prev3 + next3;
2272            let idx = row_off + k;
2273            *data.get_unchecked_mut(idx) =
2274                (*data.get_unchecked(idx) as i32 - (((a << 3) + a - c + 16) >> 5)) as i16;
2275            k += 2;
2276        }
2277    }
2278
2279    // ── Odd pass (prediction) ──────────────────────────────────────────────────
2280
2281    if kmax < 1 {
2282        return;
2283    }
2284
2285    // k=1: always predict_avg (or +=prev if k==kmax)
2286    {
2287        let p1 = *data.get_unchecked(row_off) as i32;
2288        let idx1 = row_off + 1;
2289        if 1 < kmax {
2290            let n1 = *data.get_unchecked(row_off + 2) as i32;
2291            *data.get_unchecked_mut(idx1) =
2292                (*data.get_unchecked(idx1) as i32 + ((p1 + n1 + 1) >> 1)) as i16;
2293        } else {
2294            *data.get_unchecked_mut(idx1) = (*data.get_unchecked(idx1) as i32 + p1) as i16;
2295        }
2296    }
2297
2298    // NEON inner odd chunks: predict_inner for k=3,5,...,17+chunk*16.
2299    // Safety: need (chunk+1)*16+15 < width AND 17+chunk*16 <= border (= kmax-3).
2300    // Combined: chunk < (width-31)/16 (same as even_chunks).
2301    // Inner check: 17+chunk*16 <= kmax-3  →  chunk <= (kmax-20)/16.
2302    let odd_chunks = if kmax >= 20 {
2303        even_chunks.min((kmax - 20) / 16 + 1)
2304    } else {
2305        0
2306    };
2307
2308    for chunk in 0..odd_chunks {
2309        // pair1: evens[chunk*8..+7] in .0, odds[chunk*8..+7] in .1
2310        let pair1 = vld2q_s16(ptr.add(chunk * 16) as *const i16);
2311        // pair2: evens[(chunk+1)*8..+7] in .0, odds[(chunk+1)*8..+7] in .1
2312        let pair2 = vld2q_s16(ptr.add((chunk + 1) * 16) as *const i16);
2313
2314        // 8 inner odds at physical positions 3+chunk*16, 5+..., 17+chunk*16
2315        let curr_odds = vextq_s16::<1>(pair1.1, pair2.1);
2316
2317        // Even neighbors for predict_inner:
2318        // p3[i] = even at k_odd-3 = chunk*16+2i → pair1.0[i]
2319        // p1[i] = even at k_odd-1 = chunk*16+2i+2 → vextq(pair1.0, pair2.0, 1)[i]
2320        // n1[i] = even at k_odd+1 = chunk*16+2i+4 → vextq(pair1.0, pair2.0, 2)[i]
2321        // n3[i] = even at k_odd+3 = chunk*16+2i+6 → vextq(pair1.0, pair2.0, 3)[i]
2322        let p3_e = pair1.0;
2323        let p1_e = vextq_s16::<1>(pair1.0, pair2.0); // also = unchanged evens for store
2324        let n1_e = vextq_s16::<2>(pair1.0, pair2.0);
2325        let n3_e = vextq_s16::<3>(pair1.0, pair2.0);
2326
2327        // cur += ((9*(p1+n1) - (p3+n3) + 8) >> 4)
2328        macro_rules! predict {
2329            ($co:expr, $p1:expr, $n1:expr, $p3:expr, $n3:expr) => {{
2330                let a = vaddq_s32($p1, $n1);
2331                let c = vaddq_s32($p3, $n3);
2332                let nine_a = vaddq_s32(vshlq_n_s32::<3>(a), a);
2333                let delta = vshrq_n_s32::<4>(vsubq_s32(vaddq_s32(nine_a, vdupq_n_s32(8i32)), c));
2334                vaddq_s32($co, delta)
2335            }};
2336        }
2337
2338        let new_lo = predict!(
2339            vmovl_s16(vget_low_s16(curr_odds)),
2340            vmovl_s16(vget_low_s16(p1_e)),
2341            vmovl_s16(vget_low_s16(n1_e)),
2342            vmovl_s16(vget_low_s16(p3_e)),
2343            vmovl_s16(vget_low_s16(n3_e))
2344        );
2345        let new_hi = predict!(
2346            vmovl_high_s16(curr_odds),
2347            vmovl_high_s16(p1_e),
2348            vmovl_high_s16(n1_e),
2349            vmovl_high_s16(p3_e),
2350            vmovl_high_s16(n3_e)
2351        );
2352        let new_odds = vcombine_s16(vmovn_s32(new_lo), vmovn_s32(new_hi));
2353
2354        // Store: evens at chunk*16+2,+4,...,+16 unchanged (= p1_e), odds updated.
2355        vst2q_s16(ptr.add(chunk * 16 + 2), int16x8x2_t(p1_e, new_odds));
2356    }
2357
2358    // Scalar odd tail: k = 3+odd_chunks*16, ..., kmax (inner then boundary).
2359    // State before the advance at k_scalar: prev1=data[k-3], next1=data[k-1], next3=data[k+1].
2360    if kmax >= 3 {
2361        let k_scalar = 3 + odd_chunks * 16;
2362        let mut prev1 = *data.get_unchecked(row_off + k_scalar - 3) as i32;
2363        let mut next1 = *data.get_unchecked(row_off + k_scalar - 1) as i32;
2364        let mut next3 = if k_scalar < kmax {
2365            *data.get_unchecked(row_off + k_scalar + 1) as i32
2366        } else {
2367            0
2368        };
2369        let mut k = k_scalar;
2370        while k <= kmax {
2371            let prev3 = prev1;
2372            prev1 = next1;
2373            next1 = next3;
2374            next3 = if k + 3 <= kmax {
2375                *data.get_unchecked(row_off + k + 3) as i32
2376            } else {
2377                0
2378            };
2379            let idx = row_off + k;
2380            if k <= border {
2381                let a = prev1 + next1;
2382                let c = prev3 + next3;
2383                *data.get_unchecked_mut(idx) =
2384                    (*data.get_unchecked(idx) as i32 + (((a << 3) + a - c + 8) >> 4)) as i16;
2385            } else if k < kmax {
2386                *data.get_unchecked_mut(idx) =
2387                    (*data.get_unchecked(idx) as i32 + ((prev1 + next1 + 1) >> 1)) as i16;
2388            } else {
2389                *data.get_unchecked_mut(idx) = (*data.get_unchecked(idx) as i32 + prev1) as i16;
2390            }
2391            k += 2;
2392        }
2393    }
2394}
2395
2396/// Apply the row-direction wavelet pass for one resolution level.
2397///
2398/// When `use_simd` is `true` and `s == 1` (`sd == 0`), on AArch64 the
2399/// horizontal NEON path (`row_pass_neon_s1_row`) is used for each row,
2400/// processing 8 even/odd positions at a time with `vld2q_s16` instead of
2401/// scatter loads. For `s > 1` and non-AArch64, the vertical 8-rows-at-a-time
2402/// `i32x8` path is used. The remaining rows (and all rows when `use_simd` is
2403/// false) use the scalar path.
2404///
2405/// `s` — step between active samples (power of two); `sd = log2(s)`.
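///
/// For example, at `s = 2` on a 128-wide plane only rows 0, 2, 4, … are touched,
/// and within each such row the active columns are 0, 2, …, 126 (`kmax = 63`).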
2406pub(crate) fn row_pass_inner(
2407    data: &mut [i16],
2408    width: usize,
2409    height: usize,
2410    stride: usize,
2411    s: usize,
2412    sd: usize,
2413    use_simd: bool,
2414) {
2415    // AArch64 horizontal NEON path: at s=1, process each row using vld2q_s16
2416    // (sequential deinterleave) instead of scatter loads across 8 rows.
2417    #[cfg(target_arch = "aarch64")]
2418    if use_simd && s == 1 {
2419        for row in (0..height).step_by(s) {
2420            #[allow(unsafe_code)]
2421            unsafe {
2422                row_pass_neon_s1_row(data, row * stride, width);
2423            }
2424        }
2425        return;
2426    }
2427
2428    let kmax = (width - 1) >> sd;
2429    let border = kmax.saturating_sub(3);
2430
2431    // ── SIMD path: 8 active rows at a time ───────────────────────────────────
2432    //
2433    // At s=1 the 8 rows are consecutive (o[i] = (row_base + i) * stride).
2434    // At s=2 they are spaced by 2  (o[i] = (row_base + i*2) * stride), etc.
2435    // Column accesses use `k << sd` so the logical k loop is unchanged.
2436    let simd_active = if use_simd { height / s / 8 * 8 } else { 0 };
2437    let simd_rows = simd_active * s;
2438
2439    for group in 0..simd_active / 8 {
2440        let row_base = group * 8 * s;
2441        let o: [usize; 8] = core::array::from_fn(|i| (row_base + i * s) * stride);
2442
2443        // — Lifting (even k) ——————————————————————————————————————————————————
2444        let mut prev1v = i32x8::splat(0);
2445        let mut next1v = i32x8::splat(0);
2446        let mut next3v = if kmax >= 1 {
2447            load_rows8(data, &o, 1 << sd)
2448        } else {
2449            i32x8::splat(0)
2450        };
2451        let mut prev3v: i32x8;
2452        let mut k = 0usize;
2453        while k <= kmax {
2454            prev3v = prev1v;
2455            prev1v = next1v;
2456            next1v = next3v;
2457            next3v = if k + 3 <= kmax {
2458                load_rows8(data, &o, (k + 3) << sd)
2459            } else {
2460                i32x8::splat(0)
2461            };
2462            let cur = load_rows8(data, &o, k << sd);
2463            store_rows8(
2464                data,
2465                &o,
2466                k << sd,
2467                lifting_even(cur, prev1v, next1v, prev3v, next3v),
2468            );
2469            k += 2;
2470        }
2471
2472        // — Prediction (odd k) ————————————————————————————————————————————————
2473        if kmax >= 1 {
2474            let mut k = 1usize;
2475            prev1v = load_rows8(data, &o, (k - 1) << sd);
2476            if k < kmax {
2477                next1v = load_rows8(data, &o, (k + 1) << sd);
2478                let cur = load_rows8(data, &o, k << sd);
2479                store_rows8(data, &o, k << sd, predict_avg(cur, prev1v, next1v));
2480            } else {
2481                // k == kmax: boundary — only one odd sample, += prev
2482                let cur = load_rows8(data, &o, k << sd);
2483                store_rows8(data, &o, k << sd, cur + prev1v);
2484                next1v = i32x8::splat(0);
2485            }
2486
2487            next3v = if border >= 3 {
2488                load_rows8(data, &o, (k + 3) << sd)
2489            } else {
2490                i32x8::splat(0)
2491            };
2492
2493            k = 3;
2494            while k <= border {
2495                prev3v = prev1v;
2496                prev1v = next1v;
2497                next1v = next3v;
2498                next3v = load_rows8(data, &o, (k + 3) << sd);
2499                let cur = load_rows8(data, &o, k << sd);
2500                store_rows8(
2501                    data,
2502                    &o,
2503                    k << sd,
2504                    predict_inner(cur, prev1v, next1v, prev3v, next3v),
2505                );
2506                k += 2;
2507            }
2508
2509            while k <= kmax {
2510                prev1v = next1v;
2511                next1v = next3v;
2512                next3v = i32x8::splat(0);
2513                let cur = load_rows8(data, &o, k << sd);
2514                if k < kmax {
2515                    store_rows8(data, &o, k << sd, predict_avg(cur, prev1v, next1v));
2516                } else {
2517                    store_rows8(data, &o, k << sd, cur + prev1v);
2518                }
2519                k += 2;
2520            }
2521        }
2522    }
2523
2524    // ── Scalar path: remaining rows ───────────────────────────────────────────
2525    let scalar_start = simd_rows;
2526    for row in (scalar_start..height).step_by(s) {
2527        let off = row * stride;
2528
2529        // Lifting (even samples)
2530        let mut prev1: i32 = 0;
2531        let mut next1: i32 = 0;
2532        let mut next3: i32 = if kmax >= 1 {
2533            data[off + (1 << sd)] as i32
2534        } else {
2535            0
2536        };
2537        let mut prev3: i32;
2538        let mut k = 0usize;
2539        while k <= kmax {
2540            prev3 = prev1;
2541            prev1 = next1;
2542            next1 = next3;
2543            next3 = if k + 3 <= kmax {
2544                data[off + ((k + 3) << sd)] as i32
2545            } else {
2546                0
2547            };
2548            let a = prev1 + next1;
2549            let c = prev3 + next3;
2550            let idx = off + (k << sd);
2551            data[idx] = (data[idx] as i32 - (((a << 3) + a - c + 16) >> 5)) as i16;
2552            k += 2;
2553        }
2554
2555        // Prediction (odd samples)
2556        if kmax >= 1 {
2557            let mut k = 1usize;
2558            prev1 = data[off + ((k - 1) << sd)] as i32;
2559            if k < kmax {
2560                next1 = data[off + ((k + 1) << sd)] as i32;
2561                let idx = off + (k << sd);
2562                data[idx] = (data[idx] as i32 + ((prev1 + next1 + 1) >> 1)) as i16;
2563            } else {
2564                let idx = off + (k << sd);
2565                data[idx] = (data[idx] as i32 + prev1) as i16;
2566            }
2567
2568            next3 = if border >= 3 {
2569                data[off + ((k + 3) << sd)] as i32
2570            } else {
2571                0
2572            };
2573
2574            k = 3;
2575            while k <= border {
2576                prev3 = prev1;
2577                prev1 = next1;
2578                next1 = next3;
2579                next3 = data[off + ((k + 3) << sd)] as i32;
2580                let a = prev1 + next1;
2581                let idx = off + (k << sd);
2582                data[idx] = (data[idx] as i32 + (((a << 3) + a - (prev3 + next3) + 8) >> 4)) as i16;
2583                k += 2;
2584            }
2585
2586            while k <= kmax {
2587                prev1 = next1;
2588                next1 = next3;
2589                next3 = 0;
2590                let idx = off + (k << sd);
2591                if k < kmax {
2592                    data[idx] = (data[idx] as i32 + ((prev1 + next1 + 1) >> 1)) as i16;
2593                } else {
2594                    data[idx] = (data[idx] as i32 + prev1) as i16;
2595                }
2596                k += 2;
2597            }
2598        }
2599    }
2600}
2601
2602fn inverse_wavelet_transform(plane: &mut FlatPlane, width: usize, height: usize, subsample: usize) {
2603    inverse_wavelet_transform_from(plane, width, height, subsample, 16);
2604}
2605
2606/// Like `inverse_wavelet_transform` but begins at `start_scale` instead of 16.
2607///
2608/// Use `start_scale = 16 / sub` when operating on a compact plane produced by
2609/// subsampling the coefficient scatter by factor `sub`.  For example, the sub=2
2610/// compact plane only contains coefficients up to scale 8, so the s=16 pass
2611/// would be purely spurious.
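///
/// For example, `start_scale = 8` with `subsample = 1` runs passes at
/// s = 8, 4, 2, 1; the default entry point (`inverse_wavelet_transform`)
/// uses `start_scale = 16` and adds the s = 16 pass in front.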
2612fn inverse_wavelet_transform_from(
2613    plane: &mut FlatPlane,
2614    width: usize,
2615    height: usize,
2616    subsample: usize,
2617    start_scale: usize,
2618) {
2619    let stride = plane.stride;
2620    let data = plane.data.as_mut_slice();
2621    let mut s = start_scale;
2622    let mut s_degree: u32 = start_scale.trailing_zeros();
2623
2624    let mut st0 = vec![0i32; width];
2625    let mut st1 = vec![0i32; width];
2626    let mut st2 = vec![0i32; width];
2627
2628    while s >= subsample {
2629        let sd = s_degree as usize;
2630
2631        // Column pass SIMD: enabled for s=1,2,4 using stride-aware load8s/store8s.
2632        // For s=2 the load uses vld2q_s16 (deinterleave even/odd), for s=4 vld4q_s16.
2633        // The scalar else-branches below are now only reached for s>4 (s=8, s=16).
2634        let use_simd = s <= 4;
2635
2636        // ── Column pass (transposed) ──────────────────────────────────────────
2637        {
2638            let kmax = (height - 1) >> sd;
2639            let border = kmax.saturating_sub(3);
2640            let num_cols = width.div_ceil(s);
2641            let simd_cols = if use_simd { num_cols / 8 * 8 } else { 0 };
2642
2643            // Lifting (even samples)
2644            for v in &mut st0[..num_cols] {
2645                *v = 0;
2646            }
2647            for v in &mut st1[..num_cols] {
2648                *v = 0;
2649            }
2650            if kmax >= 1 {
2651                let off = (1 << sd) * stride;
2652                if use_simd {
2653                    for ci in (0..simd_cols).step_by(8) {
2654                        store8_i32(&mut st2, ci, load8s(data, off + ci * s, s));
2655                    }
2656                    for ci in simd_cols..num_cols {
2657                        st2[ci] = data[off + ci * s] as i32;
2658                    }
2659                } else {
2660                    for (ci, col) in (0..width).step_by(s).enumerate() {
2661                        st2[ci] = data[off + col] as i32;
2662                    }
2663                }
2664            } else {
2665                for v in &mut st2[..num_cols] {
2666                    *v = 0;
2667                }
2668            }
2669
2670            // Split even pass into: main (k+3 <= kmax → n3 always in-bounds) and
2671            // tail (k+3 > kmax → n3 = 0), mirroring the odd pass structure.
2672            // This hoists the `has_n3` branch out of the ci inner loop so that
2673            // the hot path (≥97% of k-iterations) has no runtime conditional.
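            // Per column the update below is the IW44 lifting step
            //   even[k] -= (9*(odd[k-1] + odd[k+1]) - (odd[k-3] + odd[k+3]) + 16) >> 5,
            // with out-of-range neighbours treated as 0.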
2674            let mut k = 0usize;
2675            // Main: n3 always available
2676            while k + 3 <= kmax {
2677                let k_off = (k << sd) * stride;
2678                let n3_off = ((k + 3) << sd) * stride;
2679                if use_simd {
2680                    let mut ci = 0usize;
2681                    while ci < simd_cols {
2682                        let vp3 = load8_i32(&st0, ci);
2683                        let vp1 = load8_i32(&st1, ci);
2684                        let vn1 = load8_i32(&st2, ci);
2685                        let vn3 = load8s(data, n3_off + ci * s, s);
2686                        let cur = load8s(data, k_off + ci * s, s);
2687                        store8s(
2688                            data,
2689                            k_off + ci * s,
2690                            s,
2691                            lifting_even(cur, vp1, vn1, vp3, vn3),
2692                        );
2693                        store8_i32(&mut st0, ci, vp1);
2694                        store8_i32(&mut st1, ci, vn1);
2695                        store8_i32(&mut st2, ci, vn3);
2696                        ci += 8;
2697                    }
2698                    while ci < num_cols {
2699                        let p3 = st0[ci];
2700                        let p1 = st1[ci];
2701                        let n1 = st2[ci];
2702                        let n3 = data[n3_off + ci * s] as i32;
2703                        let a = p1 + n1;
2704                        let idx = k_off + ci * s;
2705                        data[idx] =
2706                            (data[idx] as i32 - (((a << 3) + a - (p3 + n3) + 16) >> 5)) as i16;
2707                        st0[ci] = p1;
2708                        st1[ci] = n1;
2709                        st2[ci] = n3;
2710                        ci += 1;
2711                    }
2712                } else {
2713                    for (ci, col) in (0..width).step_by(s).enumerate() {
2714                        let p3 = st0[ci];
2715                        let p1 = st1[ci];
2716                        let n1 = st2[ci];
2717                        let n3 = data[n3_off + col] as i32;
2718                        let a = p1 + n1;
2719                        let c = p3 + n3;
2720                        let idx = k_off + col;
2721                        data[idx] = (data[idx] as i32 - (((a << 3) + a - c + 16) >> 5)) as i16;
2722                        st0[ci] = p1;
2723                        st1[ci] = n1;
2724                        st2[ci] = n3;
2725                    }
2726                }
2727                k += 2;
2728            }
2729            // Tail: k+3 > kmax → n3 = 0
2730            while k <= kmax {
2731                let k_off = (k << sd) * stride;
2732                if use_simd {
2733                    let zero8 = i32x8::splat(0);
2734                    let mut ci = 0usize;
2735                    while ci < simd_cols {
2736                        let vp3 = load8_i32(&st0, ci);
2737                        let vp1 = load8_i32(&st1, ci);
2738                        let vn1 = load8_i32(&st2, ci);
2739                        let cur = load8s(data, k_off + ci * s, s);
2740                        store8s(
2741                            data,
2742                            k_off + ci * s,
2743                            s,
2744                            lifting_even(cur, vp1, vn1, vp3, zero8),
2745                        );
2746                        store8_i32(&mut st0, ci, vp1);
2747                        store8_i32(&mut st1, ci, vn1);
2748                        store8_i32(&mut st2, ci, zero8);
2749                        ci += 8;
2750                    }
2751                    while ci < num_cols {
2752                        let p3 = st0[ci];
2753                        let p1 = st1[ci];
2754                        let n1 = st2[ci];
2755                        let a = p1 + n1;
2756                        let idx = k_off + ci * s;
2757                        data[idx] = (data[idx] as i32 - (((a << 3) + a - p3 + 16) >> 5)) as i16;
2758                        st0[ci] = p1;
2759                        st1[ci] = n1;
2760                        st2[ci] = 0;
2761                        ci += 1;
2762                    }
2763                } else {
2764                    for (ci, col) in (0..width).step_by(s).enumerate() {
2765                        let p3 = st0[ci];
2766                        let p1 = st1[ci];
2767                        let n1 = st2[ci];
2768                        let a = p1 + n1;
2769                        let idx = k_off + col;
2770                        data[idx] = (data[idx] as i32 - (((a << 3) + a - p3 + 16) >> 5)) as i16;
2771                        st0[ci] = p1;
2772                        st1[ci] = n1;
2773                        st2[ci] = 0;
2774                    }
2775                }
2776                k += 2;
2777            }
2778
2779            // Prediction (odd samples)
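            // Interior odd samples get the 4-tap prediction
            //   odd[k] += (9*(even[k-1] + even[k+1]) - (even[k-3] + even[k+3]) + 8) >> 4;
            // samples near the borders fall back to the (k-1, k+1) average, or to the
            // k-1 neighbour alone at the very end.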
2780            if kmax >= 1 {
2781                // k = 1
2782                let km1_off = 0;
2783                let k_off = (1 << sd) * stride;
2784
2785                if 2 <= kmax {
2786                    let kp1_off = (2 << sd) * stride;
2787                    if use_simd {
2788                        let mut ci = 0usize;
2789                        while ci < simd_cols {
2790                            let vp = load8s(data, km1_off + ci * s, s);
2791                            let vn = load8s(data, kp1_off + ci * s, s);
2792                            let cur = load8s(data, k_off + ci * s, s);
2793                            store8s(data, k_off + ci * s, s, predict_avg(cur, vp, vn));
2794                            store8_i32(&mut st0, ci, vp);
2795                            store8_i32(&mut st1, ci, vn);
2796                            ci += 8;
2797                        }
2798                        while ci < num_cols {
2799                            let p = data[km1_off + ci * s] as i32;
2800                            let n = data[kp1_off + ci * s] as i32;
2801                            let idx = k_off + ci * s;
2802                            data[idx] = (data[idx] as i32 + ((p + n + 1) >> 1)) as i16;
2803                            st0[ci] = p;
2804                            st1[ci] = n;
2805                            ci += 1;
2806                        }
2807                    } else {
2808                        for (ci, col) in (0..width).step_by(s).enumerate() {
2809                            let p = data[km1_off + col] as i32;
2810                            let n = data[kp1_off + col] as i32;
2811                            let idx = k_off + col;
2812                            data[idx] = (data[idx] as i32 + ((p + n + 1) >> 1)) as i16;
2813                            st0[ci] = p;
2814                            st1[ci] = n;
2815                        }
2816                    }
2817                } else if use_simd {
2818                    let mut ci = 0usize;
2819                    while ci < simd_cols {
2820                        let vp = load8s(data, km1_off + ci * s, s);
2821                        let cur = load8s(data, k_off + ci * s, s);
2822                        store8s(data, k_off + ci * s, s, cur + vp);
2823                        store8_i32(&mut st0, ci, vp);
2824                        ci += 8;
2825                    }
2826                    for v in &mut st1[..num_cols] {
2827                        *v = 0;
2828                    }
2829                    while ci < num_cols {
2830                        let p = data[km1_off + ci * s] as i32;
2831                        let idx = k_off + ci * s;
2832                        data[idx] = (data[idx] as i32 + p) as i16;
2833                        st0[ci] = p;
2834                        st1[ci] = 0;
2835                        ci += 1;
2836                    }
2837                } else {
2838                    for (ci, col) in (0..width).step_by(s).enumerate() {
2839                        let p = data[km1_off + col] as i32;
2840                        let idx = k_off + col;
2841                        data[idx] = (data[idx] as i32 + p) as i16;
2842                        st0[ci] = p;
2843                        st1[ci] = 0;
2844                    }
2845                }
2846
                // Preload next1 (row 4) for k = 3 whenever that row exists.
2847                if kmax >= 4 {
2848                    let off = (4 << sd) * stride;
2849                    if use_simd {
2850                        let mut ci = 0usize;
2851                        while ci < simd_cols {
2852                            store8_i32(&mut st2, ci, load8s(data, off + ci * s, s));
2853                            ci += 8;
2854                        }
2855                        while ci < num_cols {
2856                            st2[ci] = data[off + ci * s] as i32;
2857                            ci += 1;
2858                        }
2859                    } else {
2860                        for (ci, col) in (0..width).step_by(s).enumerate() {
2861                            st2[ci] = data[off + col] as i32;
2862                        }
2863                    }
2864                }
2865
2866                // k = 3, 5, ..., border
2867                let mut k = 3usize;
2868                while k <= border {
2869                    let k_off = (k << sd) * stride;
2870                    let n3_off = ((k + 3) << sd) * stride;
2871
2872                    if use_simd {
2873                        let mut ci = 0usize;
2874                        while ci < simd_cols {
2875                            let vp3 = load8_i32(&st0, ci);
2876                            let vp1 = load8_i32(&st1, ci);
2877                            let vn1 = load8_i32(&st2, ci);
2878                            let vn3 = load8s(data, n3_off + ci * s, s);
2879                            let cur = load8s(data, k_off + ci * s, s);
2880                            store8s(
2881                                data,
2882                                k_off + ci * s,
2883                                s,
2884                                predict_inner(cur, vp1, vn1, vp3, vn3),
2885                            );
2886                            store8_i32(&mut st0, ci, vp1);
2887                            store8_i32(&mut st1, ci, vn1);
2888                            store8_i32(&mut st2, ci, vn3);
2889                            ci += 8;
2890                        }
2891                        while ci < num_cols {
2892                            let p3 = st0[ci];
2893                            let p1 = st1[ci];
2894                            let n1 = st2[ci];
2895                            let n3 = data[n3_off + ci * s] as i32;
2896                            let a = p1 + n1;
2897                            let idx = k_off + ci * s;
2898                            data[idx] =
2899                                (data[idx] as i32 + (((a << 3) + a - (p3 + n3) + 8) >> 4)) as i16;
2900                            st0[ci] = p1;
2901                            st1[ci] = n1;
2902                            st2[ci] = n3;
2903                            ci += 1;
2904                        }
2905                    } else {
2906                        for (ci, col) in (0..width).step_by(s).enumerate() {
2907                            let p3 = st0[ci];
2908                            let p1 = st1[ci];
2909                            let n1 = st2[ci];
2910                            let n3 = data[n3_off + col] as i32;
2911
2912                            let a = p1 + n1;
2913                            let idx = k_off + col;
2914                            data[idx] =
2915                                (data[idx] as i32 + (((a << 3) + a - (p3 + n3) + 8) >> 4)) as i16;
2916
2917                            st0[ci] = p1;
2918                            st1[ci] = n1;
2919                            st2[ci] = n3;
2920                        }
2921                    }
2922                    k += 2;
2923                }
2924
2925                // tail
2926                while k <= kmax {
2927                    let k_off = (k << sd) * stride;
2928
2929                    if k < kmax {
2930                        if use_simd {
2931                            let mut ci = 0usize;
2932                            while ci < simd_cols {
2933                                let vp = load8_i32(&st1, ci);
2934                                let vn = load8_i32(&st2, ci);
2935                                let cur = load8s(data, k_off + ci * s, s);
2936                                store8s(data, k_off + ci * s, s, predict_avg(cur, vp, vn));
2937                                store8_i32(&mut st1, ci, vn);
2938                                store8_i32(&mut st2, ci, i32x8::splat(0));
2939                                ci += 8;
2940                            }
2941                            while ci < num_cols {
2942                                let p = st1[ci];
2943                                let n = st2[ci];
2944                                let idx = k_off + ci * s;
2945                                data[idx] = (data[idx] as i32 + ((p + n + 1) >> 1)) as i16;
2946                                st1[ci] = n;
2947                                st2[ci] = 0;
2948                                ci += 1;
2949                            }
2950                        } else {
2951                            for (ci, col) in (0..width).step_by(s).enumerate() {
2952                                let p = st1[ci];
2953                                let n = st2[ci];
2954                                let idx = k_off + col;
2955                                data[idx] = (data[idx] as i32 + ((p + n + 1) >> 1)) as i16;
2956                                st1[ci] = n;
2957                                st2[ci] = 0;
2958                            }
2959                        }
2960                    } else if use_simd {
2961                        let mut ci = 0usize;
2962                        while ci < simd_cols {
2963                            let vp = load8_i32(&st1, ci);
2964                            let cur = load8s(data, k_off + ci * s, s);
2965                            store8s(data, k_off + ci * s, s, cur + vp);
2966                            store8_i32(&mut st1, ci, load8_i32(&st2, ci));
2967                            store8_i32(&mut st2, ci, i32x8::splat(0));
2968                            ci += 8;
2969                        }
2970                        while ci < num_cols {
2971                            let p = st1[ci];
2972                            let idx = k_off + ci * s;
2973                            data[idx] = (data[idx] as i32 + p) as i16;
2974                            st1[ci] = st2[ci];
2975                            st2[ci] = 0;
2976                            ci += 1;
2977                        }
2978                    } else {
2979                        for (ci, col) in (0..width).step_by(s).enumerate() {
2980                            let p = st1[ci];
2981                            let idx = k_off + col;
2982                            data[idx] = (data[idx] as i32 + p) as i16;
2983                            st1[ci] = st2[ci];
2984                            st2[ci] = 0;
2985                        }
2986                    }
2987                    k += 2;
2988                }
2989            }
2990        }
2991
2992        // ── Row pass ─────────────────────────────────────────────────────────
2993        // Row pass SIMD works for any s — always enable it.
2994        row_pass_inner(data, width, height, stride, s, sd, true);
2995
2996        s >>= 1;
2997        s_degree = s_degree.saturating_sub(1);
2998    }
2999}
3000
3001// ---- Public API -------------------------------------------------------------
3002
3003/// Progressive IW44 wavelet image decoder.
3004///
3005/// Holds three independent planar decoders (Y, Cb, Cr) whose ZP context tables
3006/// persist across chunks, enabling progressive refinement.
3007///
3008/// ## Usage
3009///
3010/// ```no_run
3011/// use djvu_iw44::Iw44Image;
3012///
3013/// let chunk_data: &[u8] = &[]; // BG44 chunk bytes from the DjVu file
3014/// let mut img = Iw44Image::new();
3015/// // Feed each BG44 chunk in document order:
3016/// img.decode_chunk(chunk_data)?;
3017/// // Convert to an RGB pixmap once all desired chunks are decoded:
3018/// let pixmap = img.to_rgb()?;
3019/// # Ok::<(), djvu_iw44::Iw44Error>(())
3020/// ```
3021#[derive(Clone, Debug)]
3022pub struct Iw44Image {
3023    /// Luma plane width in pixels (before subsampling).
3024    pub width: u32,
3025    /// Luma plane height in pixels (before subsampling).
3026    pub height: u32,
3027    /// `true` for color (YCbCr) images, `false` for grayscale.
3028    is_color: bool,
3029    /// Number of Y slices decoded before chroma decoding starts.
3030    delay: u8,
3031    /// `true` if chroma planes are stored at half resolution.
3032    chroma_half: bool,
3033    /// Luma plane decoder.
3034    y: Option<PlaneDecoder>,
3035    /// Blue-difference chroma plane decoder (color images only).
3036    cb: Option<PlaneDecoder>,
3037    /// Red-difference chroma plane decoder (color images only).
3038    cr: Option<PlaneDecoder>,
3039    /// Total slices decoded so far (used to implement the color-delay counter).
3040    cslice: usize,
3041}
3042
3043impl Default for Iw44Image {
3044    fn default() -> Self {
3045        Self::new()
3046    }
3047}
3048
3049impl Iw44Image {
3050    /// Create a new, empty decoder.
3051    pub fn new() -> Self {
3052        Iw44Image {
3053            width: 0,
3054            height: 0,
3055            is_color: false,
3056            delay: 0,
3057            chroma_half: false,
3058            y: None,
3059            cb: None,
3060            cr: None,
3061            cslice: 0,
3062        }
3063    }
3064
3065    /// Returns the (width, height) of the Cb chroma plane as allocated.
3066    ///
3067    /// When `chroma_half=true` this should be `(ceil(w/2), ceil(h/2))`.
3068    /// Returns `None` if no color chunks have been decoded yet.
3069    #[cfg(test)]
3070    pub fn chroma_plane_dims(&self) -> Option<(usize, usize)> {
3071        self.cb.as_ref().map(|p| (p.width, p.height))
3072    }
3073
3074    /// Returns `true` if the image is a color (YCbCr) image.
3075    #[cfg(test)]
3076    pub fn is_color(&self) -> bool {
3077        self.is_color
3078    }
3079
3080    /// Returns `true` if chroma planes are stored at half resolution.
3081    #[cfg(test)]
3082    pub fn chroma_half(&self) -> bool {
3083        self.chroma_half
3084    }
3085
3086    /// Decode one BG44/FG44/TH44 chunk.
3087    ///
3088    /// Call this once for each chunk in document order.  The ZP coder state
3089    /// is maintained internally so progressive refinement works automatically.
3090    ///
3091    /// ## Chunk format
3092    ///
3093    /// - First chunk (`serial == 0`): 9-byte header then ZP-coded payload.
3094    /// - Subsequent chunks: 2-byte header (`serial`, `slices`) then ZP payload.
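    ///
    /// First-chunk header layout (multi-byte fields big-endian), as parsed below:
    ///
    /// - byte 0: `serial` (0 for the first chunk)
    /// - byte 1: `slices` carried by this chunk
    /// - byte 2: major version; bit 7 set means grayscale (no chroma planes)
    /// - byte 3: minor version
    /// - bytes 4–5: image width
    /// - bytes 6–7: image height
    /// - byte 8: delay byte; bits 0–6 give the number of Y-only slices before
    ///   chroma decoding starts, and bit 7 clear selects half-resolution chroma
    ///   (both honoured only when the minor version is ≥ 2)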
3095    pub fn decode_chunk(&mut self, data: &[u8]) -> Result<(), Iw44Error> {
3096        if data.len() < 2 {
3097            return Err(Iw44Error::ChunkTooShort);
3098        }
3099        let serial = data[0];
3100        let slices = data[1];
3101        let payload_start;
3102
3103        if serial == 0 {
3104            // First chunk — parse the 9-byte image header.
3105            if data.len() < 9 {
3106                return Err(Iw44Error::HeaderTooShort);
3107            }
3108            let majver = data[2];
3109            let minor = data[3];
3110            let is_grayscale = (majver >> 7) != 0;
3111            let w = u16::from_be_bytes([data[4], data[5]]);
3112            let h = u16::from_be_bytes([data[6], data[7]]);
3113            let delay_byte = data[8];
3114            let delay = if minor >= 2 { delay_byte & 127 } else { 0 };
3115            let chroma_half = minor >= 2 && (delay_byte & 0x80) == 0;
3116
3117            if w == 0 || h == 0 {
3118                return Err(Iw44Error::ZeroDimension);
3119            }
3120            // Prevent OOM / slow decode on malformed input.
3121            // 64 MP allows real scanned documents (e.g. 6780×9148 ≈ 62 MP at 600 dpi)
3122            // while bounding worst-case fuzz decode to ~3 s (vs 12 s at 256 MP).
3123            let pixels = w as u64 * h as u64;
3124            if pixels > 64 * 1024 * 1024 {
3125                return Err(Iw44Error::ImageTooLarge);
3126            }
3127
3128            self.width = w as u32;
3129            self.height = h as u32;
3130            self.is_color = !is_grayscale;
3131            self.delay = delay;
3132            self.chroma_half = self.is_color && chroma_half;
3133            self.cslice = 0;
3134            self.y = Some(PlaneDecoder::new(w as usize, h as usize));
3135            if self.is_color {
3136                let (cw, ch) = if self.chroma_half {
3137                    ((w as usize).div_ceil(2), (h as usize).div_ceil(2))
3138                } else {
3139                    (w as usize, h as usize)
3140                };
3141                self.cb = Some(PlaneDecoder::new(cw, ch));
3142                self.cr = Some(PlaneDecoder::new(cw, ch));
3143            }
3144            payload_start = 9;
3145        } else {
3146            if self.y.is_none() {
3147                return Err(Iw44Error::MissingFirstChunk);
3148            }
3149            payload_start = 2;
3150        }
3151
3152        let zp_data = &data[payload_start..];
3153        let mut zp = ZpDecoder::new(zp_data).map_err(|_| Iw44Error::ZpTooShort)?;
3154
3155        for _ in 0..slices {
3156            self.cslice += 1;
3157            if let Some(ref mut y) = self.y {
3158                y.decode_slice(&mut zp);
3159            }
3160            if self.is_color && self.cslice > self.delay as usize {
3161                if let Some(ref mut cb) = self.cb {
3162                    cb.decode_slice(&mut zp);
3163                }
3164                if let Some(ref mut cr) = self.cr {
3165                    cr.decode_slice(&mut zp);
3166                }
3167            }
3168            // Once all real input bytes are consumed the ZP coder returns
3169            // 0xFF indefinitely, producing deterministic but meaningless
3170            // bits. Remaining slices carry no new information, so stop early
3171            // to bound decode time on crafted inputs.
3172            if zp.is_exhausted() {
3173                break;
3174            }
3175        }
3176
3177        Ok(())
3178    }
3179
3180    /// Convert the decoded image to an RGB [`Pixmap`].
3181    ///
3182    /// This is the **only** place where the separate Y, Cb, Cr planes are
3183    /// interleaved into RGB pixels.  DjVu images are stored bottom-to-top;
3184    /// this method flips the output to top-to-bottom.
3185    ///
3186    /// Equivalent to `to_rgb_subsample(1)`.
3187    pub fn to_rgb(&self) -> Result<Pixmap, Iw44Error> {
3188        self.to_rgb_subsample(1)
3189    }
3190
3191    /// Convert to an RGB [`Pixmap`] at reduced resolution.
3192    ///
3193    /// `subsample` must be ≥ 1.  A value of 1 gives full resolution; 2 gives
3194    /// half resolution in each dimension, etc.
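    ///
    /// For example (illustrative only; a real caller decodes chunks first):
    ///
    /// ```no_run
    /// # use djvu_iw44::Iw44Image;
    /// # let img = Iw44Image::new();
    /// // Half-resolution preview: a 192×256 image yields a 96×128 pixmap.
    /// let preview = img.to_rgb_subsample(2)?;
    /// # Ok::<(), djvu_iw44::Iw44Error>(())
    /// ```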
3195    pub fn to_rgb_subsample(&self, subsample: u32) -> Result<Pixmap, Iw44Error> {
3196        if subsample == 0 {
3197            return Err(Iw44Error::InvalidSubsample);
3198        }
3199        let y_dec = self.y.as_ref().ok_or(Iw44Error::MissingCodec)?;
3200        let sub = subsample as usize;
3201        let w = (self.width as usize).div_ceil(sub) as u32;
3202        let h = (self.height as usize).div_ceil(sub) as u32;
3203
3204        if self.is_color {
3205            // When chroma_half=true the chroma planes are stored at half luma
3206            // resolution.  Divide the subsample factor by 2 (minimum 1) so that
3207            // reconstruct() operates at the correct scale relative to the smaller
3208            // plane.
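            // e.g. sub = 4 with half-resolution chroma gives chroma_sub = 2; sub = 1 stays 1.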
3209            let chroma_sub = if self.chroma_half {
3210                sub.div_ceil(2)
3211            } else {
3212                sub
3213            };
3214            let cb_dec = self.cb.as_ref().ok_or(Iw44Error::MissingCodec)?;
3215            let cr_dec = self.cr.as_ref().ok_or(Iw44Error::MissingCodec)?;
3216
3217            // Reconstruct Y, Cb and Cr planes.  With the `parallel` feature the
3218            // three independent inverse-wavelet-transforms run concurrently on
3219            // separate rayon threads, cutting the reconstruction wall-time from
3220            // Y+Cb+Cr sequential to max(Y, Cb, Cr) — roughly 1.5–2× faster on
3221            // large pages where Y dominates.
3222            #[cfg(feature = "parallel")]
3223            let (y_plane, cb_plane, cr_plane) = {
3224                let (y, (cb, cr)) = rayon::join(
3225                    || y_dec.reconstruct(sub),
3226                    || {
3227                        rayon::join(
3228                            || cb_dec.reconstruct(chroma_sub),
3229                            || cr_dec.reconstruct(chroma_sub),
3230                        )
3231                    },
3232                );
3233                (y, cb, cr)
3234            };
3235            #[cfg(not(feature = "parallel"))]
3236            let (y_plane, cb_plane, cr_plane) = (
3237                y_dec.reconstruct(sub),
3238                cb_dec.reconstruct(chroma_sub),
3239                cr_dec.reconstruct(chroma_sub),
3240            );
3241
3242            let pw = w as usize;
3243            let ph = h as usize;
3244            let mut pm = Pixmap::new(w, h, 0, 0, 0, 255);
3245
3246            // Fast path: sub=1 (most common — full-resolution render).
3247            // Pre-normalize Y/Cb/Cr into flat row buffers and apply the
3248            // YCbCr→RGBA formula 8 pixels at a time with SIMD.
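            // Scalar equivalent per pixel (y, b = Cb, r = Cr, all signed):
            //   t2 = r + (r >> 1);  t3 = y + 128 - (b >> 2);
            //   R = y + 128 + t2,  G = t3 - (t2 >> 1),  B = t3 + (b << 1), each clamped to [0, 255].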
3249            if sub == 1 {
3250                #[cfg(feature = "parallel")]
3251                {
3252                    use rayon::prelude::*;
3253                    let chroma_half = self.chroma_half;
3254                    pm.data
3255                        .par_chunks_mut(pw * 4)
3256                        .enumerate()
3257                        .for_each(|(out_row, row_data)| {
3258                            let row = ph - 1 - out_row; // DjVu rows are bottom-to-top
3259                            let y_off = row * y_plane.stride;
3260                            if chroma_half {
3261                                let c_row = row / 2;
3262                                let cb_off = c_row * cb_plane.stride;
3263                                let cr_off = c_row * cr_plane.stride;
3264                                ycbcr_row_from_i16_half(
3265                                    &y_plane.data[y_off..y_off + pw],
3266                                    &cb_plane.data[cb_off..],
3267                                    &cr_plane.data[cr_off..],
3268                                    row_data,
3269                                    pw,
3270                                );
3271                            } else {
3272                                let c_off = row * cb_plane.stride;
3273                                ycbcr_row_from_i16(
3274                                    &y_plane.data[y_off..y_off + pw],
3275                                    &cb_plane.data[c_off..c_off + pw],
3276                                    &cr_plane.data[c_off..c_off + pw],
3277                                    row_data,
3278                                );
3279                            }
3280                        });
3281                }
3282                #[cfg(not(feature = "parallel"))]
3283                {
3284                    for row in 0..ph {
3285                        let out_row = ph - 1 - row; // DjVu rows are bottom-to-top
3286                        let y_off = row * y_plane.stride;
3287                        let row_start = out_row * pw * 4;
3288
3289                        if self.chroma_half {
3290                            let c_row = row / 2;
3291                            let cb_off = c_row * cb_plane.stride;
3292                            let cr_off = c_row * cr_plane.stride;
3293                            ycbcr_row_from_i16_half(
3294                                &y_plane.data[y_off..y_off + pw],
3295                                &cb_plane.data[cb_off..],
3296                                &cr_plane.data[cr_off..],
3297                                &mut pm.data[row_start..row_start + pw * 4],
3298                                pw,
3299                            );
3300                        } else {
3301                            let c_off = row * cb_plane.stride;
3302                            ycbcr_row_from_i16(
3303                                &y_plane.data[y_off..y_off + pw],
3304                                &cb_plane.data[c_off..c_off + pw],
3305                                &cr_plane.data[c_off..c_off + pw],
3306                                &mut pm.data[row_start..row_start + pw * 4],
3307                            );
3308                        }
3309                    }
3310                }
3311                return Ok(pm);
3312            }
3313
3314            // Compact path: sub ≥ 2 with power-of-two subsample.
3315            //
3316            // `reconstruct(sub)` now returns a plane that is already at the
3317            // target resolution (ceil(w/sub) × ceil(h/sub)), so we access it
3318            // with sub=1 indexing.  Chroma planes are at the same output size
3319            // (the chroma_half factor is absorbed into chroma_sub), so no
3320            // chroma_half division is needed here.
3321            //
3322            // Uses SIMD via `ycbcr_row_to_rgba` (same as the sub=1 fast path).
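            // e.g. boy.djvu (192×256) at sub=2 reconstructs 96×128 planes that this
            // loop walks with unit row/column steps.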
3323            if (2..=8).contains(&sub) && sub.is_power_of_two() {
3324                for row in 0..ph {
3325                    let out_row = ph - 1 - row; // DjVu rows are bottom-to-top
3326                    let y_off = row * y_plane.stride;
3327                    let c_off = row * cb_plane.stride;
3328                    let row_start = out_row * pw * 4;
3329                    ycbcr_row_from_i16(
3330                        &y_plane.data[y_off..y_off + pw],
3331                        &cb_plane.data[c_off..c_off + pw],
3332                        &cr_plane.data[c_off..c_off + pw],
3333                        &mut pm.data[row_start..row_start + pw * 4],
3334                    );
3335                }
3336                return Ok(pm);
3337            }
3338
3339            // Fallback scalar path for non-power-of-two or large sub values.
3340            for row in 0..h {
3341                let out_row = h - 1 - row;
3342                for col in 0..w {
3343                    let src_row = row as usize * sub;
3344                    let src_col = col as usize * sub;
3345                    let y_idx = src_row * y_plane.stride + src_col;
3346                    let chroma_row = if self.chroma_half {
3347                        src_row / 2
3348                    } else {
3349                        src_row
3350                    };
3351                    let chroma_col = if self.chroma_half {
3352                        src_col / 2
3353                    } else {
3354                        src_col
3355                    };
3356                    let c_idx = chroma_row * cb_plane.stride + chroma_col;
3357
3358                    let y = normalize(y_plane.data[y_idx]);
3359                    let b = normalize(cb_plane.data[c_idx]);
3360                    let r = normalize(cr_plane.data[c_idx]);
3361
3362                    let t2 = r + (r >> 1);
3363                    let t3 = y + 128 - (b >> 2);
3364
3365                    let red = (y + 128 + t2).clamp(0, 255) as u8;
3366                    let green = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
3367                    let blue = (t3 + (b << 1)).clamp(0, 255) as u8;
3368                    pm.set_rgb(col, out_row, red, green, blue);
3369                }
3370            }
3371            Ok(pm)
3372        } else {
3373            // Grayscale: only the Y plane is needed.
3374            // For power-of-two sub in 2..=8 the reconstructed plane is compact (already
3375            // at output resolution); otherwise it is full-resolution.  Index accordingly.
3376            let y_plane = y_dec.reconstruct(sub);
3377            let is_compact = (2..=8).contains(&sub) && sub.is_power_of_two();
3378            let mut pm = Pixmap::new(w, h, 0, 0, 0, 255);
3379            for row in 0..h {
3380                let out_row = h - 1 - row;
3381                for col in 0..w {
3382                    let (src_row, src_col) = if is_compact {
3383                        (row as usize, col as usize)
3384                    } else {
3385                        (row as usize * sub, col as usize * sub)
3386                    };
3387                    let idx = src_row * y_plane.stride + src_col;
3388                    let val = normalize(y_plane.data[idx]);
3389                    // Grayscale: DjVu luma is signed; −128 maps to white (255) and +127 to black (0).
3390                    let gray = (127 - val) as u8;
3391                    pm.set_rgb(col, out_row, gray, gray, gray);
3392                }
3393            }
3394            Ok(pm)
3395        }
3396    }
3397}
3398
3399// ---- Tests ------------------------------------------------------------------
3400
3401#[cfg(test)]
3402mod tests {
3403    use super::*;
3404
3405    fn assets_path() -> std::path::PathBuf {
3406        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
3407            .join("../../references/djvujs/library/assets")
3408    }
3409
3410    fn golden_path() -> std::path::PathBuf {
3411        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../tests/golden/iw44")
3412    }
3413
3414    /// Extract all BG44 chunk payloads from the first DJVU form in the file.
3415    fn extract_bg44_chunks(file: &djvu_iff::DjvuFile) -> Vec<&[u8]> {
3416        fn collect(chunk: &djvu_iff::Chunk) -> Option<Vec<&[u8]>> {
3417            match chunk {
3418                djvu_iff::Chunk::Form {
3419                    secondary_id,
3420                    children,
3421                    ..
3422                } => {
3423                    if secondary_id == b"DJVU" {
3424                        let v = children
3425                            .iter()
3426                            .filter_map(|c| match c {
3427                                djvu_iff::Chunk::Leaf {
3428                                    id: [b'B', b'G', b'4', b'4'],
3429                                    data,
3430                                } => Some(data.as_slice()),
3431                                _ => None,
3432                            })
3433                            .collect::<Vec<_>>();
3434                        return Some(v);
3435                    }
3436                    for c in children {
3437                        if let Some(v) = collect(c) {
3438                            return Some(v);
3439                        }
3440                    }
3441                    None
3442                }
3443                _ => None,
3444            }
3445        }
3446        collect(&file.root).unwrap_or_default()
3447    }
3448
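    /// Locate the start of the pixel data in a binary PPM produced by `Pixmap::to_ppm`.
    ///
    /// Assumes the simple three-line `P6` header (magic, dimensions, maxval), so the
    /// data starts right after the third newline.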
3449    fn find_ppm_data_start(ppm: &[u8]) -> usize {
3450        let mut newlines = 0;
3451        for (i, &b) in ppm.iter().enumerate() {
3452            if b == b'\n' {
3453                newlines += 1;
3454                if newlines == 3 {
3455                    return i + 1;
3456                }
3457            }
3458        }
3459        0
3460    }
3461
3462    /// Compare `actual_ppm` against a golden file, creating it on first run.
3463    ///
3464    /// If the file doesn't exist it is written (first-time generation).
3465    /// On subsequent runs an exact byte-for-byte comparison is enforced so that
3466    /// any accidental change to the pixel output is caught immediately.
3467    fn assert_or_create_golden(actual_ppm: &[u8], golden_file: &str) {
3468        let path = golden_path().join(golden_file);
3469        if !path.exists() {
3470            std::fs::write(&path, actual_ppm)
3471                .unwrap_or_else(|e| panic!("failed to write golden {golden_file}: {e}"));
3472            return; // golden created — test passes on first run
3473        }
3474        assert_ppm_match(actual_ppm, golden_file);
3475    }
3476
3477    fn assert_ppm_match(actual_ppm: &[u8], golden_file: &str) {
3478        let expected_ppm = std::fs::read(golden_path().join(golden_file))
3479            .unwrap_or_else(|_| panic!("golden file not found: {}", golden_file));
3480        assert_eq!(
3481            actual_ppm.len(),
3482            expected_ppm.len(),
3483            "PPM size mismatch for {}: got {} expected {}",
3484            golden_file,
3485            actual_ppm.len(),
3486            expected_ppm.len()
3487        );
3488        if actual_ppm != expected_ppm {
3489            let header_end = find_ppm_data_start(actual_ppm);
3490            let actual_pixels = &actual_ppm[header_end..];
3491            let expected_pixels = &expected_ppm[header_end..];
3492            let total_pixels = actual_pixels.len() / 3;
3493            let diff_pixels = actual_pixels
3494                .chunks(3)
3495                .zip(expected_pixels.chunks(3))
3496                .filter(|(a, b)| a != b)
3497                .count();
3498            panic!(
3499                "{} pixel mismatch: {}/{} pixels differ ({:.1}%)",
3500                golden_file,
3501                diff_pixels,
3502                total_pixels,
3503                diff_pixels as f64 / total_pixels as f64 * 100.0
3504            );
3505        }
3506    }
3507
3508    // ---- TDD: failing tests first -------------------------------------------
3509
3510    /// Decode must fail gracefully on empty input.
3511    #[test]
3512    fn iw44_new_rejects_empty_chunk() {
3513        let mut img = Iw44Image::new();
3514        assert!(matches!(
3515            img.decode_chunk(&[]),
3516            Err(Iw44Error::ChunkTooShort)
3517        ));
3518    }
3519
3520    /// Decode must fail gracefully on a truncated first-chunk header.
3521    #[test]
3522    fn iw44_new_rejects_truncated_header() {
3523        let mut img = Iw44Image::new();
3524        // serial=0 but only 5 bytes (need ≥ 9)
3525        assert!(matches!(
3526            img.decode_chunk(&[0x00, 0x01, 0x00, 0x02, 0x00]),
3527            Err(Iw44Error::HeaderTooShort)
3528        ));
3529    }
3530
3531    /// Zero-dimension image must be rejected.
3532    #[test]
3533    fn iw44_new_rejects_zero_dimension() {
3534        let mut img = Iw44Image::new();
3535        // serial=0, slices=1, majver=0, minor=2, w=0, h=100, delay=0
3536        let header = [0x00u8, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x64, 0x00];
3537        assert!(matches!(
3538            img.decode_chunk(&header),
3539            Err(Iw44Error::ZeroDimension)
3540        ));
3541    }
3542
3543    /// Subsequent chunk before first chunk must be rejected.
3544    #[test]
3545    fn iw44_new_rejects_subsequent_before_first() {
3546        let mut img = Iw44Image::new();
3547        // serial != 0
3548        assert!(matches!(
3549            img.decode_chunk(&[0x01, 0x01]),
3550            Err(Iw44Error::MissingFirstChunk)
3551        ));
3552    }
3553
3554    /// `to_rgb()` on an uninitialised decoder must return an error.
3555    #[test]
3556    fn iw44_new_to_rgb_without_data_returns_error() {
3557        let img = Iw44Image::new();
3558        assert!(matches!(img.to_rgb(), Err(Iw44Error::MissingCodec)));
3559    }
3560
3561    /// `to_rgb_subsample(0)` must be rejected.
3562    #[test]
3563    fn iw44_new_subsample_zero_rejected() {
3564        let img = Iw44Image::new();
3565        assert!(matches!(
3566            img.to_rgb_subsample(0),
3567            Err(Iw44Error::InvalidSubsample)
3568        ));
3569    }
3570
3571    // ---- Pixel-exact golden tests -------------------------------------------
3572
3573    #[test]
3574    fn iw44_new_decode_boy_bg() {
3575        let data = std::fs::read(assets_path().join("boy.djvu")).expect("boy.djvu not found");
3576        let file = djvu_iff::parse(&data).expect("failed to parse boy.djvu");
3577        let chunks = extract_bg44_chunks(&file);
3578        assert_eq!(chunks.len(), 1, "expected 1 BG44 chunk in boy.djvu");
3579
3580        let mut img = Iw44Image::new();
3581        for c in &chunks {
3582            img.decode_chunk(c).expect("decode_chunk failed");
3583        }
3584        assert_eq!(img.width, 192);
3585        assert_eq!(img.height, 256);
3586
3587        let pm = img.to_rgb().expect("to_rgb failed");
3588        assert_ppm_match(&pm.to_ppm(), "boy_bg.ppm");
3589    }
3590
3591    #[test]
3592    fn iw44_new_decode_chicken_bg() {
3593        let data =
3594            std::fs::read(assets_path().join("chicken.djvu")).expect("chicken.djvu not found");
3595        let file = djvu_iff::parse(&data).expect("failed to parse chicken.djvu");
3596        let chunks = extract_bg44_chunks(&file);
3597        assert_eq!(chunks.len(), 3, "expected 3 BG44 chunks in chicken.djvu");
3598
3599        let mut img = Iw44Image::new();
3600        for c in &chunks {
3601            img.decode_chunk(c).expect("decode_chunk failed");
3602        }
3603        assert_eq!(img.width, 181);
3604        assert_eq!(img.height, 240);
3605
3606        let pm = img.to_rgb().expect("to_rgb failed");
3607        assert_ppm_match(&pm.to_ppm(), "chicken_bg.ppm");
3608    }
3609
3610    /// `to_rgb_subsample(2)` on boy.djvu must produce a pixel-exact result.
3611    ///
3612    /// This golden test guards against any regression in the compact-plane sub=2
3613    /// optimization path.  On first run the golden file is created from the
3614    /// current (correct) output; subsequent runs compare against it.
3615    #[test]
3616    fn iw44_new_decode_boy_sub2() {
3617        let data = std::fs::read(assets_path().join("boy.djvu")).expect("boy.djvu not found");
3618        let file = djvu_iff::parse(&data).expect("failed to parse boy.djvu");
3619        let chunks = extract_bg44_chunks(&file);
3620
3621        let mut img = Iw44Image::new();
3622        for c in &chunks {
3623            img.decode_chunk(c).expect("decode_chunk failed");
3624        }
3625        assert_eq!(img.width, 192);
3626        assert_eq!(img.height, 256);
3627
3628        let pm = img.to_rgb_subsample(2).expect("to_rgb_subsample(2) failed");
3629        assert_eq!(pm.width, 96, "sub=2 width must be ceil(192/2)");
3630        assert_eq!(pm.height, 128, "sub=2 height must be ceil(256/2)");
3631
3632        assert_or_create_golden(&pm.to_ppm(), "boy_bg_sub2.ppm");
3633    }
3634
3635    /// `to_rgb_subsample(2)` on big-scanned-page.djvu (color IW44).
3636    ///
3637    /// Exercises the compact-plane path on a large color document.
3638    #[test]
3639    fn iw44_new_decode_big_scanned_sub2() {
3640        let data = std::fs::read(assets_path().join("big-scanned-page.djvu"))
3641            .expect("big-scanned-page.djvu not found");
3642        let file = djvu_iff::parse(&data).expect("failed to parse big-scanned-page.djvu");
3643        let chunks = extract_bg44_chunks(&file);
3644
3645        let mut img = Iw44Image::new();
3646        for c in &chunks {
3647            img.decode_chunk(c).expect("decode_chunk failed");
3648        }
3649        assert_eq!(img.width, 6780);
3650        assert_eq!(img.height, 9148);
3651
3652        let pm = img.to_rgb_subsample(2).expect("to_rgb_subsample(2) failed");
3653        assert_eq!(pm.width, 3390, "sub=2 width must be ceil(6780/2)");
3654        assert_eq!(pm.height, 4574, "sub=2 height must be ceil(9148/2)");
3655
3656        assert_or_create_golden(&pm.to_ppm(), "big_scanned_sub2.ppm");
3657    }
3658
3659    #[test]
3660    fn iw44_new_decode_big_scanned_sub4() {
3661        let data = std::fs::read(assets_path().join("big-scanned-page.djvu"))
3662            .expect("big-scanned-page.djvu not found");
3663        let file = djvu_iff::parse(&data).expect("failed to parse big-scanned-page.djvu");
3664        let chunks = extract_bg44_chunks(&file);
3665        assert_eq!(chunks.len(), 4, "expected 4 BG44 chunks");
3666
3667        let mut img = Iw44Image::new();
3668        for c in &chunks {
3669            img.decode_chunk(c).expect("decode_chunk failed");
3670        }
3671        assert_eq!(img.width, 6780);
3672        assert_eq!(img.height, 9148);
3673
3674        let pm = img.to_rgb_subsample(4).expect("to_rgb_subsample failed");
3675        assert_ppm_match(&pm.to_ppm(), "big_scanned_sub4.ppm");
3676    }
3677
3678    /// Progressive decode: feeding all chunks at once and feeding them one-by-one
3679    /// must produce identical results.
3680    #[test]
3681    fn iw44_new_progressive_matches_full_decode_chicken() {
3682        let data =
3683            std::fs::read(assets_path().join("chicken.djvu")).expect("chicken.djvu not found");
3684        let file = djvu_iff::parse(&data).expect("failed to parse");
3685        let chunks = extract_bg44_chunks(&file);
3686        assert!(
3687            chunks.len() > 1,
3688            "need multiple chunks for progressive test"
3689        );
3690
3691        // Full decode (all chunks at once via repeated decode_chunk calls)
3692        let mut full = Iw44Image::new();
3693        for c in &chunks {
3694            full.decode_chunk(c).expect("full decode failed");
3695        }
3696        let full_pm = full.to_rgb().expect("full to_rgb failed");
3697
3698        // Progressive decode — same result since ZP state persists
3699        let mut prog = Iw44Image::new();
3700        for c in chunks.iter().take(1) {
3701            prog.decode_chunk(c).expect("progressive decode failed");
3702        }
3703        for c in chunks.iter().skip(1) {
3704            prog.decode_chunk(c).expect("progressive decode failed");
3705        }
3706        let prog_pm = prog.to_rgb().expect("progressive to_rgb failed");
3707
3708        assert_eq!(
3709            full_pm.data, prog_pm.data,
3710            "progressive and full decode must produce identical pixels"
3711        );
3712    }
3713
3714    // ── chroma_half allocation test ──────────────────────────────────────────
3715
3716    /// When `chroma_half=true`, chroma planes must be allocated at half
3717    /// resolution (ceil(w/2) × ceil(h/2)), not at full luma resolution.
3718    ///
3719    /// carte.djvu is a color image with chroma_half=true (w=1400, h=852).
3720    #[test]
3721    fn chroma_half_allocates_half_size_plane() {
3722        let data = std::fs::read(assets_path().join("carte.djvu")).expect("carte.djvu not found");
3723        let file = djvu_iff::parse(&data).expect("iff parse");
3724        let chunks = extract_bg44_chunks(&file);
3725        assert!(!chunks.is_empty(), "carte.djvu must have BG44 chunks");
3726
3727        let mut img = Iw44Image::new();
3728        img.decode_chunk(chunks[0]).expect("decode_chunk");
3729
3730        assert!(img.is_color(), "carte.djvu must be a color image");
3731        assert!(img.chroma_half(), "carte.djvu must have chroma_half=true");
3732        let (cw, ch) = img
3733            .chroma_plane_dims()
3734            .expect("chroma plane must be allocated after first color chunk");
3735        let lw = img.width as usize;
3736        let lh = img.height as usize;
3737        let expected_w = lw.div_ceil(2);
3738        let expected_h = lh.div_ceil(2);
3739        assert_eq!(
3740            cw, expected_w,
3741            "chroma plane width must be ceil(luma_w/2)={expected_w}, got {cw}"
3742        );
3743        assert_eq!(
3744            ch, expected_h,
3745            "chroma plane height must be ceil(luma_h/2)={expected_h}, got {ch}"
3746        );
3747    }
3748
3749    /// Decode carte.djvu (chroma_half=true color image) fully and compare
3750    /// pixel output to the golden reference, ensuring the half-plane allocation
3751    /// does not corrupt the decoded image.
3752    #[test]
3753    fn iw44_new_decode_carte_bg_chroma_half() {
3754        let data = std::fs::read(assets_path().join("carte.djvu")).expect("carte.djvu not found");
3755        let file = djvu_iff::parse(&data).expect("iff parse");
3756        let chunks = extract_bg44_chunks(&file);
3757
3758        let mut img = Iw44Image::new();
3759        for c in &chunks {
3760            img.decode_chunk(c).expect("decode_chunk failed");
3761        }
3762        assert_eq!(img.width, 1400);
3763        assert_eq!(img.height, 852);
3764
3765        let pm = img.to_rgb().expect("to_rgb failed");
3766        assert_ppm_match(&pm.to_ppm(), "carte_bg.ppm");
3767    }
3768
3769    // ── Error path tests ────────────────────────────────────────────────────
3770
3771    #[test]
3772    fn test_decode_empty_chunk() {
3773        let mut img = Iw44Image::new();
3774        let result = img.decode_chunk(&[]);
3775        assert!(result.is_err());
3776    }
3777
3778    #[test]
3779    fn test_decode_truncated_header() {
3780        let mut img = Iw44Image::new();
3781        // Only 2 bytes — not enough for a header
3782        let result = img.decode_chunk(&[0x00, 0x01]);
3783        assert!(result.is_err());
3784    }
3785
3786    #[test]
3787    fn test_to_rgb_before_decode() {
3788        let img = Iw44Image::new();
3789        // No chunks decoded yet — should fail
3790        let result = img.to_rgb();
3791        assert!(result.is_err());
3792    }
3793
3794    #[test]
3795    fn test_to_rgb_subsample_zero() {
3796        let img = Iw44Image::new();
3797        let result = img.to_rgb_subsample(0);
3798        assert!(result.is_err());
3799    }
3800
3801    // ---- SIMD YCbCr→RGBA tests -----------------------------------------------
3802
3803    /// `ycbcr_row_to_rgba` matches the scalar formula on synthetic data.
3804    #[test]
3805    fn simd_ycbcr_row_matches_scalar() {
3806        // n = 20 covers two full 8-wide SIMD chunks plus a 4-element scalar tail.
3807        let n = 20usize;
3808        let ys: Vec<i32> = (0..n).map(|i| (i as i32 * 7) % 200 - 100).collect();
3809        let bs: Vec<i32> = (0..n).map(|i| (i as i32 * 13) % 200 - 100).collect();
3810        let rs: Vec<i32> = (0..n).map(|i| (i as i32 * 17) % 200 - 100).collect();
3811
3812        // Scalar reference
3813        let mut expected = vec![0u8; n * 4];
3814        for col in 0..n {
3815            let y = ys[col];
3816            let b = bs[col];
3817            let r = rs[col];
3818            let t2 = r + (r >> 1);
3819            let t3 = y + 128 - (b >> 2);
3820            expected[col * 4] = (y + 128 + t2).clamp(0, 255) as u8;
3821            expected[col * 4 + 1] = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
3822            expected[col * 4 + 2] = (t3 + (b << 1)).clamp(0, 255) as u8;
3823            expected[col * 4 + 3] = 255;
3824        }
3825
3826        // SIMD result
3827        let mut actual = vec![0u8; n * 4];
3828        super::ycbcr_row_to_rgba(&ys, &bs, &rs, &mut actual);
3829
3830        assert_eq!(
3831            expected, actual,
3832            "SIMD must produce identical output to scalar"
3833        );
3834    }
3835
3836    /// `ycbcr_row_to_rgba` handles extreme values (clamping at 0 and 255).
3837    #[test]
3838    fn simd_ycbcr_row_clamps_correctly() {
3839        let n = 8usize;
3840        // Use values that will clamp to 0 and 255 in each channel.
3841        let ys: Vec<i32> = vec![127, -128, 127, -128, 0, 0, 0, 0];
3842        let bs: Vec<i32> = vec![-128, 127, -128, 127, 0, 0, 0, 0];
3843        let rs: Vec<i32> = vec![127, -128, -128, 127, 0, 0, 0, 0];
3844
3845        let mut simd_out = vec![0u8; n * 4];
3846        super::ycbcr_row_to_rgba(&ys, &bs, &rs, &mut simd_out);
3847
3848        // Output bytes are u8 (in range by construction); assert that alpha is always 255.
3849        for chunk in simd_out.chunks_exact(4) {
3850            assert_eq!(chunk[3], 255, "alpha must always be 255");
3851        }
3852    }
3853
3854    /// Rendering boy.djvu at sub=1 (SIMD fast path) and sub=2 (compact-plane path)
3855    /// must produce pixmaps with the expected dimensions.
3856    ///
3857    /// Pixel correctness of the sub=1 path is covered by `iw44_new_decode_boy_bg`.
3858    #[test]
3859    fn simd_render_matches_subsampled_render_dimensions() {
3860        let data = std::fs::read(assets_path().join("boy.djvu")).expect("boy.djvu not found");
3861        let file = djvu_iff::parse(&data).expect("parse failed");
3862        let chunks = extract_bg44_chunks(&file);
3863
3864        let mut img = Iw44Image::new();
3865        for c in &chunks {
3866            img.decode_chunk(c).expect("decode_chunk failed");
3867        }
3868
3869        // Full-resolution render uses SIMD path (sub=1).
3870        let full = img.to_rgb().expect("to_rgb failed");
3871        // sub=2 uses the compact-plane path; just check the dimensions are halved.
3872        let half = img.to_rgb_subsample(2).expect("subsample(2) failed");
3873
3874        assert_eq!(full.width, img.width);
3875        assert_eq!(full.height, img.height);
3876        assert_eq!(half.width, img.width.div_ceil(2));
3877        assert_eq!(half.height, img.height.div_ceil(2));
3878        // SIMD path must still pass the existing golden test (done in iw44_new_decode_boy_bg).
3879    }
3880
3881    /// SIMD row pass (8 rows at a time) produces identical results to the scalar
3882    /// path on a synthetic 32×16 plane with a deterministic non-trivial pattern.
3883    ///
3884    /// Both paths are exercised by calling `row_pass_inner` with `use_simd=false`
3885    /// (all scalar) and `use_simd=true` (SIMD + scalar tail) on identical copies
3886    /// of the same data.
3887    #[test]
3888    fn simd_row_pass_matches_scalar() {
3889        let width = 32usize;
3890        let height = 16usize;
3891        let stride = width;
3892        let n = stride * height;
3893
3894        // Deterministic non-trivial pattern: values in [-255, 255].
3895        let initial: Vec<i16> = (0..n).map(|i| ((i * 7 + 13) % 511) as i16 - 255).collect();
3896
3897        let mut scalar_data = initial.clone();
3898        // s=1, sd=0, use_simd=false → pure scalar
3899        super::row_pass_inner(&mut scalar_data, width, height, stride, 1, 0, false);
3900
3901        let mut simd_data = initial.clone();
3902        // s=1, sd=0, use_simd=true → SIMD for rows 0..15, scalar tail for remainder
3903        super::row_pass_inner(&mut simd_data, width, height, stride, 1, 0, true);
3904
3905        assert_eq!(
3906            scalar_data, simd_data,
3907            "SIMD row pass must produce identical output to scalar"
3908        );
3909    }

    /// Same as `simd_row_pass_matches_scalar` but for s=2 (sd=1).
    ///
    /// Active rows are every other row; active columns are every other column.
    /// The generalised SIMD path (8 active rows at a time with stride s) must
    /// produce the same result as the pure scalar path.
    #[test]
    fn simd_row_pass_s2_matches_scalar() {
        let width = 64usize;
        let height = 32usize;
        let stride = width;
        let n = stride * height;
        let s = 2usize;
        let sd = 1usize;

        let initial: Vec<i16> = (0..n).map(|i| ((i * 7 + 13) % 511) as i16 - 255).collect();

        let mut scalar_data = initial.clone();
        super::row_pass_inner(&mut scalar_data, width, height, stride, s, sd, false);

        let mut simd_data = initial.clone();
        super::row_pass_inner(&mut simd_data, width, height, stride, s, sd, true);

        assert_eq!(
            scalar_data, simd_data,
            "SIMD row pass (s=2) must produce identical output to scalar"
        );
    }

    /// Reference scalar implementation of the fused-normalize YCbCr→RGBA path.
    /// Mirrors the SIMD raw paths (`ycbcr_avx2_raw` and friends) byte-for-byte:
    /// same formula, same clamps.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    fn ycbcr_raw_scalar(y: &[i16], cb: &[i16], cr: &[i16], out: &mut [u8]) {
        let w = y.len();
        for col in 0..w {
            let yn = super::normalize(y[col]);
            let bn = super::normalize(cb[col]);
            let rn = super::normalize(cr[col]);
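            // YCbCr -> RGB arithmetic used below:
            //   t2 = Cr + Cr/2        (about 1.5 * Cr)
            //   t3 = Y + 128 - Cb/4
            // giving R = Y + 128 + t2, G = t3 - t2/2, B = t3 + 2*Cb.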
            let t2 = rn + (rn >> 1);
            let t3 = yn + 128 - (bn >> 2);
            out[col * 4] = (yn + 128 + t2).clamp(0, 255) as u8;
            out[col * 4 + 1] = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
            out[col * 4 + 2] = (t3 + (bn << 1)).clamp(0, 255) as u8;
            out[col * 4 + 3] = 255;
        }
    }

    /// Half-resolution variant of `ycbcr_raw_scalar`: the chroma planes hold one
    /// sample per pair of output columns, so each Cb/Cr value is reused for two
    /// adjacent pixels (`cb[col / 2]`, `cr[col / 2]`).
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    fn ycbcr_raw_half_scalar(y: &[i16], cb: &[i16], cr: &[i16], out: &mut [u8]) {
        let w = y.len();
        for col in 0..w {
            let yn = super::normalize(y[col]);
            let bn = super::normalize(cb[col / 2]);
            let rn = super::normalize(cr[col / 2]);
            let t2 = rn + (rn >> 1);
            let t3 = yn + 128 - (bn >> 2);
            out[col * 4] = (yn + 128 + t2).clamp(0, 255) as u8;
            out[col * 4 + 1] = (t3 - (t2 >> 1)).clamp(0, 255) as u8;
            out[col * 4 + 2] = (t3 + (bn << 1)).clamp(0, 255) as u8;
            out[col * 4 + 3] = 255;
        }
    }

    /// AVX2 fused-normalize YCbCr→RGBA must agree byte-for-byte with the scalar
    /// reference across the full i16 input range and a spread of width residues
    /// mod 16 (covers main loop + scalar tail).
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn ycbcr_avx2_raw_matches_scalar() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: AVX2 not available on this host");
            return;
        }
        // Range chosen to exercise normalize + clamp + every arithmetic branch.
        let raw_vals: [i16; 8] = [-32768, -8192, -64, -1, 0, 63, 8191, 32767];
        for &width in &[1usize, 7, 16, 17, 31, 32, 33, 47, 48, 64, 100] {
            let n = width;
            let make_seq = |seed: usize| -> Vec<i16> {
                (0..n)
                    .map(|i| raw_vals[(i + seed) % raw_vals.len()])
                    .collect()
            };
            let y = make_seq(0);
            let cb = make_seq(3);
            let cr = make_seq(5);

            let mut got = vec![0u8; n * 4];
            #[allow(unsafe_code)]
            unsafe {
                super::ycbcr_avx2_raw(y.as_ptr(), cb.as_ptr(), cr.as_ptr(), got.as_mut_ptr(), n);
            }

            let mut want = vec![0u8; n * 4];
            ycbcr_raw_scalar(&y, &cb, &cr, &mut want);

            assert_eq!(got, want, "AVX2 raw mismatch at width {}", width);
        }
    }

    /// AVX2 stride-1 load must sign-extend the full i16 range bit-exactly into
    /// an i32x8, matching the scalar `load8s` with stride 1.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn load8s_s1_avx2_matches_scalar() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: AVX2 not available on this host");
            return;
        }
        let raw_vals: [i16; 8] = [-32768, -8192, -64, -1, 0, 63, 8191, 32767];
        let n = 64;
        let buf: Vec<i16> = (0..n).map(|i| raw_vals[i % raw_vals.len()]).collect();
        for phys_off in 0..(n - 8) {
            #[allow(unsafe_code)]
            let got = unsafe { super::load8s_s1_avx2(&buf, phys_off) };
            let want = super::load8s(&buf, phys_off, 1);
            assert_eq!(
                got.to_array(),
                want.to_array(),
                "AVX2 load8s_s1 mismatch at phys_off {}",
                phys_off
            );
        }
    }
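
    /// Companion check (a minimal sketch): the scalar `load8s` with stride 1 is the
    /// reference that the SIMD loads are compared against, and the property relied on
    /// is assumed to be plain sign extension, i.e. each lane equals
    /// `buf[phys_off + k] as i32`.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn load8s_scalar_stride1_sign_extends() {
        let raw_vals: [i16; 8] = [-32768, -8192, -64, -1, 0, 63, 8191, 32767];
        let buf: Vec<i16> = (0..64usize).map(|i| raw_vals[i % raw_vals.len()]).collect();
        for phys_off in 0..(buf.len() - 8) {
            let lanes = super::load8s(&buf, phys_off, 1).to_array();
            for k in 0..8 {
                assert_eq!(
                    lanes[k],
                    buf[phys_off + k] as i32,
                    "lane {k} at phys_off {phys_off}"
                );
            }
        }
    }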

    /// AVX2 stride-1 store must truncate i32→i16 (drop upper 16 bits, no
    /// saturation) matching the scalar `as i16` cast for every input.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn store8s_s1_avx2_matches_scalar() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: AVX2 not available on this host");
            return;
        }
        // Inputs that exercise truncation: values that don't fit in i16,
        // negative values, and boundaries.
        let raw_vals: [i32; 8] = [i32::MIN, -100_000, -32768, -1, 0, 32767, 100_000, i32::MAX];
        for offset in 0..8usize {
            let mut input = [0i32; 8];
            for j in 0..8 {
                input[j] = raw_vals[(j + offset) % 8];
            }
            let v = wide::i32x8::from(input);

            // AVX2 store into a buffer pre-filled with sentinel values so any
            // write outside the eight target lanes is detected.
            let mut buf_avx2 = vec![0xABCDu16 as i16; 32];
            #[allow(unsafe_code)]
            unsafe {
                super::store8s_s1_avx2(&mut buf_avx2, 8, v);
            }
            // Scalar reference: plain `as i16` truncation of each lane, written
            // into an identically pre-filled buffer.
            let mut buf_scalar = vec![0xABCDu16 as i16; 32];
            for j in 0..8 {
                buf_scalar[8 + j] = input[j] as i16;
            }
            assert_eq!(buf_avx2, buf_scalar, "AVX2 store8s_s1 mismatch");
        }
    }
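
    /// Companion check (a worked example of the cast semantics): `as i16` on an i32
    /// keeps only the low 16 bits with no saturation, which is exactly the behaviour
    /// the store tests compare against.
    #[test]
    fn as_i16_truncates_without_saturation() {
        assert_eq!(100_000i32 as i16, -31072); // 0x0001_86A0 -> 0x86A0 -> -31072
        assert_eq!(-100_000i32 as i16, 31072); // 0xFFFE_7960 -> 0x7960 -> 31072
        assert_eq!(i32::MIN as i16, 0); // low 16 bits of 0x8000_0000 are zero
        assert_eq!(i32::MAX as i16, -1); // low 16 bits of 0x7FFF_FFFF are all ones
        assert_eq!(32_767i32 as i16, 32767); // in-range values pass through unchanged
    }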

    /// AVX2 `prelim_flags_bucket_avx2` must produce identical bucket bytes
    /// and bstatetmp to the scalar fallback for any 16-coef input.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn prelim_flags_bucket_avx2_matches_scalar() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: AVX2 not available on this host");
            return;
        }
        // Inputs that exercise the all-zero, all-nonzero, mixed, and edge-value cases.
        let test_vectors: &[[i16; 16]] = &[
            [0; 16],
            [
                1, 0, -1, 0, 100, 0, -200, 0, 0, 1234, 0, -1234, 0, 32767, -32768, 0,
            ],
            [1; 16],
            [-1; 16],
            [
                32767, -32768, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 32767, -32768, 1, -1,
            ],
        ];
        for &coefs in test_vectors {
            let mut block = [0i16; 1024];
            // Place the 16 coefs at base = 0 and at several non-zero bases to test the offset.
            for &base in &[0usize, 16, 32, 1008] {
                block[base..base + 16].copy_from_slice(&coefs);

                let mut bucket_avx2 = [0u8; 16];
                #[allow(unsafe_code)]
                let bstate_avx2 =
                    unsafe { super::prelim_flags_bucket_avx2(&block, base, &mut bucket_avx2) };

                let mut bucket_scalar = [0u8; 16];
                let mut bstate_scalar = 0u8;
                for k in 0..16 {
                    let f = if block[base + k] == 0 {
                        super::UNK
                    } else {
                        super::ACTIVE
                    };
                    bucket_scalar[k] = f;
                    bstate_scalar |= f;
                }

                assert_eq!(
                    bucket_avx2, bucket_scalar,
                    "bucket mismatch at base={base} coefs={coefs:?}"
                );
                assert_eq!(
                    bstate_avx2, bstate_scalar,
                    "bstatetmp mismatch at base={base}"
                );
            }
        }
    }

    /// AVX2 `prelim_flags_band0_avx2` must mirror the scalar band-0 update:
    /// only entries with `old_flags[k] != ZERO` are rewritten; the remaining
    /// entries are left unchanged (so a ZERO-state lane is preserved across the call).
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn prelim_flags_band0_avx2_matches_scalar() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: AVX2 not available on this host");
            return;
        }
        // Old-flag patterns covering the three states and mixed.
        let old_patterns: &[[u8; 16]] = &[
            [super::ZERO; 16],
            [super::UNK; 16],
            [super::ACTIVE; 16],
            [
                super::ZERO,
                super::UNK,
                super::ACTIVE,
                super::ZERO,
                super::UNK,
                super::ACTIVE,
                super::ZERO,
                super::UNK,
                super::ACTIVE,
                super::ZERO,
                super::UNK,
                super::ACTIVE,
                super::ZERO,
                super::UNK,
                super::ACTIVE,
                super::ZERO,
            ],
        ];
        let coef_patterns: &[[i16; 16]] = &[
            [0; 16],
            [1; 16],
            [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
            [
                -32768, 0, 32767, 0, 100, 0, -100, 0, 0, 1, 0, -1, 0, 5, 0, -5,
            ],
        ];

        for &old in old_patterns {
            for &coefs in coef_patterns {
                let mut block = [0i16; 1024];
                block[..16].copy_from_slice(&coefs);

                let mut flags_avx2 = old;
                #[allow(unsafe_code)]
                let bstate_avx2 =
                    unsafe { super::prelim_flags_band0_avx2(&block, &mut flags_avx2) };

                let mut flags_scalar = old;
                let mut bstate_scalar = 0u8;
                for k in 0..16 {
                    if flags_scalar[k] != super::ZERO {
                        flags_scalar[k] = if block[k] == 0 {
                            super::UNK
                        } else {
                            super::ACTIVE
                        };
                    }
                    bstate_scalar |= flags_scalar[k];
                }

                assert_eq!(
                    flags_avx2, flags_scalar,
                    "flags mismatch old={old:?} coefs={coefs:?}"
                );
                assert_eq!(bstate_avx2, bstate_scalar, "bstatetmp mismatch");
            }
        }
    }

    /// AVX2 half-resolution fused-normalize YCbCr→RGBA must agree byte-for-byte
    /// with the half-resolution scalar reference across widths that cover both
    /// the main loop and the scalar tail.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn ycbcr_avx2_raw_half_matches_scalar() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: AVX2 not available on this host");
            return;
        }
        let raw_vals: [i16; 8] = [-32768, -8192, -64, -1, 0, 63, 8191, 32767];
        for &width in &[2usize, 8, 16, 18, 30, 32, 34, 48, 64, 96] {
            let n = width;
            let half = n.div_ceil(2);
            let make_seq = |seed: usize, len: usize| -> Vec<i16> {
                (0..len)
                    .map(|i| raw_vals[(i + seed) % raw_vals.len()])
                    .collect()
            };
            let y = make_seq(0, n);
            let cb_half = make_seq(3, half);
            let cr_half = make_seq(5, half);

            let mut got = vec![0u8; n * 4];
            #[allow(unsafe_code)]
            unsafe {
                super::ycbcr_avx2_raw_half(
                    y.as_ptr(),
                    cb_half.as_ptr(),
                    cr_half.as_ptr(),
                    got.as_mut_ptr(),
                    n,
                );
            }

            let mut want = vec![0u8; n * 4];
            ycbcr_raw_half_scalar(&y, &cb_half, &cr_half, &mut want);

            assert_eq!(got, want, "AVX2 raw_half mismatch at width {}", width);
        }
    }
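
    /// Companion check (follows directly from the two scalar helpers above): feeding the
    /// full-resolution reference with each chroma sample duplicated must reproduce the
    /// half-resolution reference exactly, since `ycbcr_raw_half_scalar` reads `cb[col / 2]`.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn ycbcr_raw_half_scalar_matches_expanded_full_scalar() {
        let raw_vals: [i16; 8] = [-32768, -8192, -64, -1, 0, 63, 8191, 32767];
        for &width in &[1usize, 2, 7, 16, 17, 33] {
            let y: Vec<i16> = (0..width).map(|i| raw_vals[i % raw_vals.len()]).collect();
            let half = width.div_ceil(2);
            let cb_half: Vec<i16> = (0..half).map(|i| raw_vals[(i + 3) % raw_vals.len()]).collect();
            let cr_half: Vec<i16> = (0..half).map(|i| raw_vals[(i + 5) % raw_vals.len()]).collect();
            // Duplicate each chroma sample so it covers two output columns.
            let cb_full: Vec<i16> = (0..width).map(|col| cb_half[col / 2]).collect();
            let cr_full: Vec<i16> = (0..width).map(|col| cr_half[col / 2]).collect();

            let mut via_half = vec![0u8; width * 4];
            ycbcr_raw_half_scalar(&y, &cb_half, &cr_half, &mut via_half);

            let mut via_full = vec![0u8; width * 4];
            ycbcr_raw_scalar(&y, &cb_full, &cr_full, &mut via_full);

            assert_eq!(via_half, via_full, "scalar half/full mismatch at width {width}");
        }
    }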

    /// WASM simd128 stride-1 load must sign-extend i16→i32 correctly.
    ///
    /// Mirrors `load8s_s1_avx2_matches_scalar` but for the simd128 path.
    /// Runs only when compiled for wasm32 with +simd128; host tests use the
    /// AVX2 or scalar path instead.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    #[test]
    fn load8s_s1_simd128_matches_scalar() {
        let raw_vals: [i16; 8] = [-32768, -8192, -64, -1, 0, 63, 8191, 32767];
        let n = 64;
        let buf: alloc::vec::Vec<i16> = (0..n).map(|i| raw_vals[i % raw_vals.len()]).collect();
        for phys_off in 0..(n - 8) {
            #[allow(unsafe_code)]
            let got = unsafe { super::load8s_s1_simd128(&buf, phys_off) };
            let want = super::load8s(&buf, phys_off, 1);
            assert_eq!(
                got.to_array(),
                want.to_array(),
                "simd128 load8s_s1 mismatch at phys_off {}",
                phys_off
            );
        }
    }

    /// WASM simd128 stride-1 store must truncate i32→i16 (drop upper 16 bits, no
    /// saturation) matching the scalar `as i16` cast for every input.
    ///
    /// Mirrors `store8s_s1_avx2_matches_scalar`.
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    #[test]
    fn store8s_s1_simd128_matches_scalar() {
        let raw_vals: [i32; 8] = [i32::MIN, -100_000, -32768, -1, 0, 32767, 100_000, i32::MAX];
        for offset in 0..8usize {
            let mut input = [0i32; 8];
            for j in 0..8 {
                input[j] = raw_vals[(j + offset) % 8];
            }
            let v = wide::i32x8::from(input);

            let mut buf_simd128 = alloc::vec![0xABCDu16 as i16; 32];
            #[allow(unsafe_code)]
            unsafe {
                super::store8s_s1_simd128(&mut buf_simd128, 8, v);
            }
            let mut buf_scalar = alloc::vec![0xABCDu16 as i16; 32];
            for j in 0..8 {
                buf_scalar[8 + j] = input[j] as i16;
            }
            assert_eq!(buf_simd128, buf_scalar, "simd128 store8s_s1 mismatch");
        }
    }
}