Skip to main content

j2k_transcode/
accelerator.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Optional acceleration hooks for coefficient-domain transform stages.
4//!
//! These hooks are intentionally narrow: accelerated backends may replace the
//! direct DCT-grid to one-level wavelet projection (and, for 9/7, optionally
//! the downstream HTJ2K quantization and code-block encoding), while the
//! scalar path remains the default oracle and fallback.
8
9use core::fmt;
10
11use crate::dct53_2d::Dwt53TwoDimensional;
12use crate::dct97_2d::Dwt97TwoDimensional;
13use crate::dct_grid::validate_dct_block_grid;
14use crate::reversible53::{
15    reversible_lift_53_high_at, reversible_lift_53_i32, reversible_lift_53_low_at,
16};
17pub use j2k::adapter::encode_stage::{
18    EncodedHtJ2kCodeBlock, IrreversibleQuantizationSubbandScales, J2kSubBandType,
19    PreencodedHtj2k97CodeBlock, PreencodedHtj2k97CompactCodeBlock,
20    PreencodedHtj2k97CompactComponent, PreencodedHtj2k97CompactImage,
21    PreencodedHtj2k97CompactResolution, PreencodedHtj2k97CompactSubband,
22    PreencodedHtj2k97Component, PreencodedHtj2k97Resolution, PreencodedHtj2k97Subband,
23    PrequantizedHtj2k97CodeBlock, PrequantizedHtj2k97Component, PrequantizedHtj2k97Image,
24    PrequantizedHtj2k97Resolution, PrequantizedHtj2k97Subband,
25};
26use j2k_jpeg::transcode::idct_islow_block;
27use rayon::prelude::*;
28
/// Reason reported when a reversible 5/3 job's DCT block grid does not pass
/// geometry validation.
const REVERSIBLE_DWT53_UNSUPPORTED_GRID: &str =
    "reversible DCT 5/3 job has unsupported grid geometry";
31
/// Input for the direct DCT-grid to one-level reversible integer 5/3
/// projection.
///
/// The job only borrows its coefficient blocks and is `Copy`, so it can be
/// handed to a backend by value.
#[derive(Debug, Clone, Copy)]
pub struct DctGridToReversibleDwt53Job<'a> {
    /// Dequantized 8x8 DCT coefficient blocks, each in natural (not zigzag)
    /// order.
    pub dequantized_blocks: &'a [[i16; 64]],
    /// How many DCT block columns `dequantized_blocks` spans.
    pub block_cols: usize,
    /// How many DCT block rows `dequantized_blocks` spans.
    pub block_rows: usize,
    /// Logical width of the component, in samples.
    pub width: usize,
    /// Logical height of the component, in samples.
    pub height: usize,
}
46
/// Subbands produced by one separable level of the reversible integer 5/3
/// transform.
///
/// Every band is stored row-major: `ll` is `low_width x low_height`, `hl` is
/// `high_width x low_height`, `lh` is `low_width x high_height`, and `hh` is
/// `high_width x high_height`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReversibleDwt53FirstLevel {
    /// Horizontally low-pass, vertically low-pass band.
    pub ll: Vec<i32>,
    /// Horizontally high-pass, vertically low-pass band.
    pub hl: Vec<i32>,
    /// Horizontally low-pass, vertically high-pass band.
    pub lh: Vec<i32>,
    /// Horizontally high-pass, vertically high-pass band.
    pub hh: Vec<i32>,
    /// Row length of the horizontally low-pass bands (`ll`, `lh`).
    pub low_width: usize,
    /// Row count of the vertically low-pass bands (`ll`, `hl`).
    pub low_height: usize,
    /// Row length of the horizontally high-pass bands (`hl`, `hh`).
    pub high_width: usize,
    /// Row count of the vertically high-pass bands (`lh`, `hh`).
    pub high_height: usize,
}
67
/// Input for the direct DCT-grid to one-level floating-point 5/3
/// projection.
#[derive(Debug, Clone, Copy)]
pub struct DctGridToDwt53Job<'a> {
    /// Dequantized 8x8 DCT blocks in natural order, as `f64` coefficients.
    pub blocks: &'a [[[f64; 8]; 8]],
    /// How many DCT block columns `blocks` spans.
    pub block_cols: usize,
    /// How many DCT block rows `blocks` spans.
    pub block_rows: usize,
    /// Logical width of the component, in samples.
    pub width: usize,
    /// Logical height of the component, in samples.
    pub height: usize,
}
82
/// Input for the direct DCT-grid to one-level floating-point 9/7
/// transform.
#[derive(Debug, Clone, Copy)]
pub struct DctGridToDwt97Job<'a> {
    /// Dequantized 8x8 DCT blocks in natural order, as `f64` coefficients.
    pub blocks: &'a [[[f64; 8]; 8]],
    /// How many DCT block columns `blocks` spans.
    pub block_cols: usize,
    /// How many DCT block rows `blocks` spans.
    pub block_rows: usize,
    /// Logical width of the component, in samples.
    pub width: usize,
    /// Logical height of the component, in samples.
    pub height: usize,
}
97
/// Input for the direct DCT-grid to one-level 9/7 transform whose output is
/// quantized straight into HTJ2K code-block layout.
#[derive(Debug, Clone, Copy)]
pub struct DctGridToHtj2k97CodeBlockJob<'a> {
    /// Dequantized 8x8 DCT blocks in natural order, as `f64` coefficients.
    pub blocks: &'a [[[f64; 8]; 8]],
    /// How many DCT block columns `blocks` spans.
    pub block_cols: usize,
    /// How many DCT block rows `blocks` spans.
    pub block_rows: usize,
    /// Logical width of the component, in samples.
    pub width: usize,
    /// Logical height of the component, in samples.
    pub height: usize,
    /// Horizontal sampling factor written to SIZ (`XRsiz`).
    pub x_rsiz: u8,
    /// Vertical sampling factor written to SIZ (`YRsiz`).
    pub y_rsiz: u8,
}
116
/// Input for the one-level 9/7 HTJ2K code-block path that starts from the
/// dequantized `i16` DCT grid.
///
/// Intended for accelerators that consume the JPEG coefficient extraction
/// output as-is and have no use for the generic `f64` block representation.
#[derive(Debug, Clone, Copy)]
pub struct DctGridI16ToHtj2k97CodeBlockJob<'a> {
    /// Dequantized 8x8 DCT coefficient blocks, each in natural (not zigzag)
    /// order.
    pub dequantized_blocks: &'a [[i16; 64]],
    /// How many DCT block columns `dequantized_blocks` spans.
    pub block_cols: usize,
    /// How many DCT block rows `dequantized_blocks` spans.
    pub block_rows: usize,
    /// Logical width of the component, in samples.
    pub width: usize,
    /// Logical height of the component, in samples.
    pub height: usize,
    /// Horizontal sampling factor written to SIZ (`XRsiz`).
    pub x_rsiz: u8,
    /// Vertical sampling factor written to SIZ (`YRsiz`).
    pub y_rsiz: u8,
}
138
139/// One same-geometry i16 DCT-grid HTJ2K preencode batch.
140#[derive(Debug, Clone, Copy)]
141pub struct DctGridI16ToHtj2k97CodeBlockBatch<'a, 'j> {
142    /// Jobs in this same-geometry batch.
143    pub jobs: &'j [DctGridI16ToHtj2k97CodeBlockJob<'a>],
144}
145
146/// Compact preencoded HTJ2K components backed by one payload buffer.
147#[derive(Debug, Clone)]
148pub struct PreencodedHtj2k97CompactBatch {
149    /// Contiguous encoded code-block payload bytes for every component.
150    pub payload: Vec<u8>,
151    /// Compact components in the same order as the submitted jobs.
152    pub components: Vec<PreencodedHtj2k97CompactComponent>,
153}
154
155/// Compact preencoded HTJ2K grouped-batch output backed by one payload buffer.
156#[derive(Debug, Clone)]
157pub struct PreencodedHtj2k97CompactBatchGroups {
158    /// Contiguous encoded code-block payload bytes for every returned group.
159    pub payload: Vec<u8>,
160    /// Compact components grouped in the same order as submitted batches.
161    pub groups: Vec<Vec<PreencodedHtj2k97CompactComponent>>,
162}
163
164/// Encode parameters needed to quantize 9/7 output directly into HTJ2K
165/// code-block coefficient layout.
166#[derive(Debug, Clone, Copy, PartialEq)]
167pub struct Htj2k97CodeBlockOptions {
168    /// Component precision in bits.
169    pub bit_depth: u8,
170    /// JPEG 2000 guard bits used for QCD and code-block bitplane counts.
171    pub guard_bits: u8,
172    /// Code-block width exponent minus two.
173    pub code_block_width_exp: u8,
174    /// Code-block height exponent minus two.
175    pub code_block_height_exp: u8,
176    /// Multiplier applied to irreversible 9/7 scalar quantization step sizes.
177    pub irreversible_quantization_scale: f32,
178    /// Per-subband multipliers applied on top of
179    /// [`irreversible_quantization_scale`](Self::irreversible_quantization_scale).
180    pub irreversible_quantization_subband_scales: IrreversibleQuantizationSubbandScales,
181}
182
/// Per-stage timing and transfer counters a backend reports for one
/// same-geometry 9/7 batch.
///
/// Fields ending in `_us` are durations in microseconds; every field
/// defaults to zero.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Dwt97BatchStageTimings {
    /// Microseconds spent on host packing, buffer allocation, and upload.
    pub pack_upload_us: u128,
    /// Logical host-to-device transfers counted in [`Self::pack_upload_us`].
    pub pack_upload_transfers: usize,
    /// Host-to-device bytes counted in [`Self::pack_upload_us`].
    pub pack_upload_bytes: u64,
    /// How many resident JPEG DCT-grid descriptors were validated for the
    /// batch.
    pub resident_dct_handoff_count: usize,
    /// Microseconds spent in the combined IDCT and horizontal 9/7 row-lift
    /// stage.
    pub idct_row_lift_us: u128,
    /// Microseconds spent in the vertical 9/7 column-lift stage.
    pub column_lift_us: u128,
    /// How many resident DWT subband descriptors were validated for the
    /// batch.
    pub resident_dwt_handoff_count: usize,
    /// Microseconds spent quantizing 9/7 bands into HTJ2K code-block layout.
    pub quantize_codeblock_us: u128,
    /// Microseconds spent HT-encoding resident code-block coefficients.
    pub ht_encode_us: u128,
    /// Microseconds spent in the resident HT cleanup-pass encode kernel.
    pub ht_kernel_us: u128,
    /// Microseconds spent reading the resident HT status buffer back to the
    /// host.
    pub ht_status_readback_us: u128,
    /// Logical device-to-host status readbacks counted in
    /// [`Self::ht_status_readback_us`].
    pub ht_status_readback_transfers: usize,
    /// Device-to-host status bytes counted in [`Self::ht_status_readback_us`].
    pub ht_status_readback_bytes: u64,
    /// Microseconds spent in the resident HT encoded-byte compaction kernel.
    pub ht_compact_us: u128,
    /// Microseconds spent reading compacted HT encoded bytes back to the
    /// host.
    pub ht_output_readback_us: u128,
    /// Logical device-to-host output readbacks counted in
    /// [`Self::ht_output_readback_us`].
    pub ht_output_readback_transfers: usize,
    /// Device-to-host output bytes counted in [`Self::ht_output_readback_us`].
    pub ht_output_readback_bytes: u64,
    /// How many HT code-block encode kernel dispatches the batch issued.
    pub ht_codeblock_dispatches: usize,
    /// Microseconds spent reading device band buffers (e.g. Metal) back and
    /// unpacking them into host outputs.
    pub readback_us: u128,
    /// Logical device-to-host transfers counted in [`Self::readback_us`].
    pub readback_transfers: usize,
    /// Device-to-host bytes counted in [`Self::readback_us`].
    pub readback_bytes: u64,
}
229
/// Failure reported by an accelerated transcode stage backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TranscodeStageError {
    /// The backend cannot take this job: its shape, options, or runtime
    /// environment fall outside what the backend implements.
    Unsupported(&'static str),
    /// The backend accepted the stage but failed while running it.
    Backend(String),
    /// The device or runtime behind this accelerator cannot be used.
    DeviceUnavailable,
}

impl fmt::Display for TranscodeStageError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant renders as a bare message with no prefix.
        let message = match self {
            Self::Unsupported(reason) => *reason,
            Self::Backend(reason) => reason.as_str(),
            Self::DeviceUnavailable => "accelerator device is unavailable",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TranscodeStageError {}

impl From<&'static str> for TranscodeStageError {
    /// Treat a bare static reason as an "unsupported" rejection, so `?` can
    /// lift `&'static str` validation errors into this type.
    fn from(reason: &'static str) -> Self {
        TranscodeStageError::Unsupported(reason)
    }
}
259
260/// Optional backend for SIMD, GPU, or other accelerated transform stages.
261pub trait DctToWaveletStageAccelerator {
262    /// Whether this accelerator wants same-geometry 9/7 batch jobs offered.
263    ///
264    /// The default is false so CPU-only fallback paths do not pay the memory
265    /// cost of materializing batch-owned float DCT blocks before immediately
266    /// falling back.
267    fn supports_dwt97_batch(&self) -> bool {
268        false
269    }
270
271    /// Whether this accelerator wants same-geometry 9/7 batches offered as
272    /// prequantized HTJ2K code-block jobs before the float-band hook.
273    fn supports_htj2k97_codeblock_batch(&self) -> bool {
274        false
275    }
276
277    /// Whether this accelerator wants same-geometry 9/7 preencoded HTJ2K
278    /// batches offered with dequantized i16 DCT blocks before materializing the
279    /// generic f64 block representation.
280    fn supports_htj2k97_i16_preencoded_batch(&self) -> bool {
281        false
282    }
283
284    /// Whether this accelerator wants the compact i16 preencoded HTJ2K batch
285    /// hook offered before the owned preencoded hook.
286    fn supports_htj2k97_compact_preencoded_batch(&self) -> bool {
287        self.supports_htj2k97_i16_preencoded_batch()
288    }
289
290    /// Optionally compute the direct DCT-grid to one-level reversible integer
291    /// 5/3 projection.
292    ///
293    /// Return `Ok(Some(output))` when the backend handled the job bit-exactly
294    /// relative to j2k's scalar integer oracle. Return `Ok(None)` to use
295    /// the scalar fallback.
296    fn dct_grid_to_reversible_dwt53(
297        &mut self,
298        _job: DctGridToReversibleDwt53Job<'_>,
299    ) -> Result<Option<ReversibleDwt53FirstLevel>, TranscodeStageError> {
300        Ok(None)
301    }
302
303    /// Optionally compute a same-geometry batch of direct DCT-grid to
304    /// one-level reversible integer 5/3 projections.
305    ///
306    /// Backends should return outputs in the same order as `jobs`. Return
307    /// `Ok(None)` to use the scalar per-component fallback.
308    fn dct_grid_to_reversible_dwt53_batch(
309        &mut self,
310        _jobs: &[DctGridToReversibleDwt53Job<'_>],
311    ) -> Result<Option<Vec<ReversibleDwt53FirstLevel>>, TranscodeStageError> {
312        Ok(None)
313    }
314
315    /// Optionally compute the direct DCT-grid to one-level 5/3 projection.
316    ///
317    /// Return `Ok(Some(output))` when the backend handled the job. Return
318    /// `Ok(None)` to use the scalar fallback.
319    fn dct_grid_to_dwt53(
320        &mut self,
321        _job: DctGridToDwt53Job<'_>,
322    ) -> Result<Option<Dwt53TwoDimensional<f64>>, TranscodeStageError> {
323        Ok(None)
324    }
325
326    /// Optionally compute the direct DCT-grid to one-level 9/7 transform.
327    ///
328    /// Return `Ok(Some(output))` when the backend handled the job. Return
329    /// `Ok(None)` to use the scalar fallback.
330    fn dct_grid_to_dwt97(
331        &mut self,
332        _job: DctGridToDwt97Job<'_>,
333    ) -> Result<Option<Dwt97TwoDimensional<f64>>, TranscodeStageError> {
334        Ok(None)
335    }
336
337    /// Optionally compute a same-geometry batch of direct DCT-grid to
338    /// one-level 9/7 transforms.
339    ///
340    /// Backends should return outputs in the same order as `jobs`. Return
341    /// `Ok(None)` to use the scalar per-component fallback.
342    fn dct_grid_to_dwt97_batch(
343        &mut self,
344        _jobs: &[DctGridToDwt97Job<'_>],
345    ) -> Result<Option<Vec<Dwt97TwoDimensional<f64>>>, TranscodeStageError> {
346        Ok(None)
347    }
348
349    /// Optionally compute same-geometry DCT-grid 9/7 jobs directly into
350    /// prequantized HTJ2K code-block components.
351    ///
352    /// Backends should return one component per input job in the same order as
353    /// `jobs`. Return `Ok(None)` to use the float-band path.
354    fn dct_grid_to_htj2k97_codeblock_batch(
355        &mut self,
356        _jobs: &[DctGridToHtj2k97CodeBlockJob<'_>],
357        _options: Htj2k97CodeBlockOptions,
358    ) -> Result<Option<Vec<PrequantizedHtj2k97Component>>, TranscodeStageError> {
359        Ok(None)
360    }
361
362    /// Optionally compute same-geometry DCT-grid 9/7 jobs directly into
363    /// preencoded HTJ2K code-block payloads.
364    ///
365    /// Backends should return one component per input job in the same order as
366    /// `jobs`. Return `Ok(None)` to use the prequantized or float-band path.
367    fn dct_grid_to_htj2k97_preencoded_batch(
368        &mut self,
369        _jobs: &[DctGridToHtj2k97CodeBlockJob<'_>],
370        _options: Htj2k97CodeBlockOptions,
371    ) -> Result<Option<Vec<PreencodedHtj2k97Component>>, TranscodeStageError> {
372        Ok(None)
373    }
374
375    /// Optionally compute same-geometry dequantized i16 DCT-grid 9/7 jobs
376    /// directly into preencoded HTJ2K code-block payloads.
377    ///
378    /// Backends should return one component per input job in the same order as
379    /// `jobs`. Return `Ok(None)` to use the generic f64 preencoded path.
380    fn dct_grid_i16_to_htj2k97_preencoded_batch(
381        &mut self,
382        _jobs: &[DctGridI16ToHtj2k97CodeBlockJob<'_>],
383        _options: Htj2k97CodeBlockOptions,
384    ) -> Result<Option<Vec<PreencodedHtj2k97Component>>, TranscodeStageError> {
385        Ok(None)
386    }
387
388    /// Optionally compute same-geometry dequantized i16 DCT-grid 9/7 jobs into
389    /// compact preencoded HTJ2K code-block payloads.
390    ///
391    /// Backends should return one component per input job in the same order as
392    /// `jobs`, with all component ranges pointing into the returned payload.
393    /// Return `Ok(None)` to use the owned preencoded path.
394    fn dct_grid_i16_to_htj2k97_compact_preencoded_batch(
395        &mut self,
396        _jobs: &[DctGridI16ToHtj2k97CodeBlockJob<'_>],
397        _options: Htj2k97CodeBlockOptions,
398    ) -> Result<Option<PreencodedHtj2k97CompactBatch>, TranscodeStageError> {
399        Ok(None)
400    }
401
402    /// Optionally compute multiple same-geometry dequantized i16 DCT-grid
403    /// batches directly into preencoded HTJ2K code-block payloads.
404    ///
405    /// Each input batch is internally same-geometry, but different batches may
406    /// have different component dimensions. Backends should return one output
407    /// vector per input batch, in order. Return `Ok(None)` to use the per-group
408    /// fallback hooks.
409    fn dct_grid_i16_to_htj2k97_preencoded_batch_groups(
410        &mut self,
411        _groups: &[DctGridI16ToHtj2k97CodeBlockBatch<'_, '_>],
412        _options: Htj2k97CodeBlockOptions,
413    ) -> Result<Option<Vec<Vec<PreencodedHtj2k97Component>>>, TranscodeStageError> {
414        Ok(None)
415    }
416
417    /// Optionally compute multiple same-geometry dequantized i16 DCT-grid 9/7
418    /// batches into compact preencoded HTJ2K code-block payloads.
419    ///
420    /// Each returned item corresponds to one input batch and contains one
421    /// component per job in that batch. Return `Ok(None)` to use the owned
422    /// preencoded grouped hook.
423    fn dct_grid_i16_to_htj2k97_compact_preencoded_batch_groups(
424        &mut self,
425        _groups: &[DctGridI16ToHtj2k97CodeBlockBatch<'_, '_>],
426        _options: Htj2k97CodeBlockOptions,
427    ) -> Result<Option<PreencodedHtj2k97CompactBatchGroups>, TranscodeStageError> {
428        Ok(None)
429    }
430
431    /// Return backend stage timings for the most recent 9/7 batch dispatch.
432    fn last_dwt97_batch_stage_timings(&self) -> Option<Dwt97BatchStageTimings> {
433        None
434    }
435}
436
437/// Accelerator that always uses the scalar CPU fallback.
438#[derive(Debug, Default, Clone, Copy)]
439pub struct CpuOnlyDctToWaveletStageAccelerator;
440
441impl DctToWaveletStageAccelerator for CpuOnlyDctToWaveletStageAccelerator {}
442
/// CPU accelerator for the exact reversible integer 5/3 first level,
/// parallelized with Rayon.
///
/// It keeps j2k's scalar ISLOW IDCT as the oracle. Every 8x8 block is
/// decoded with `j2k-jpeg`, level-shifted to signed component samples, and
/// then run through reversible integer 5/3 lifting. The accelerator also
/// counts how many single jobs and batches it was offered and how many it
/// completed.
#[derive(Debug, Default, Clone)]
pub struct RayonReversibleDwt53Accelerator {
    attempts: usize,
    dispatches: usize,
    batch_attempts: usize,
    batch_dispatches: usize,
}

impl RayonReversibleDwt53Accelerator {
    /// How many single reversible 5/3 jobs were offered to this accelerator.
    #[must_use]
    pub const fn reversible_dwt53_attempts(&self) -> usize {
        self.attempts
    }

    /// How many single reversible 5/3 jobs this accelerator completed
    /// without error.
    #[must_use]
    pub const fn reversible_dwt53_dispatches(&self) -> usize {
        self.dispatches
    }

    /// How many reversible 5/3 batches were offered to this accelerator.
    #[must_use]
    pub const fn reversible_dwt53_batch_attempts(&self) -> usize {
        self.batch_attempts
    }

    /// How many reversible 5/3 batches this accelerator completed without
    /// error.
    #[must_use]
    pub const fn reversible_dwt53_batch_dispatches(&self) -> usize {
        self.batch_dispatches
    }
}
481
482impl DctToWaveletStageAccelerator for RayonReversibleDwt53Accelerator {
483    fn dct_grid_to_reversible_dwt53(
484        &mut self,
485        job: DctGridToReversibleDwt53Job<'_>,
486    ) -> Result<Option<ReversibleDwt53FirstLevel>, TranscodeStageError> {
487        self.attempts = self.attempts.saturating_add(1);
488        let output = reversible_dwt53_first_level_rayon(job)?;
489        self.dispatches = self.dispatches.saturating_add(1);
490        Ok(Some(output))
491    }
492
493    fn dct_grid_to_reversible_dwt53_batch(
494        &mut self,
495        jobs: &[DctGridToReversibleDwt53Job<'_>],
496    ) -> Result<Option<Vec<ReversibleDwt53FirstLevel>>, TranscodeStageError> {
497        self.batch_attempts = self.batch_attempts.saturating_add(1);
498        let mut output = Vec::with_capacity(jobs.len());
499        for job in jobs {
500            output.push(reversible_dwt53_first_level_rayon(*job)?);
501        }
502        self.batch_dispatches = self.batch_dispatches.saturating_add(1);
503        Ok(Some(output))
504    }
505}
506
507/// Decode the job's dequantized DCT blocks into j2k's signed integer
508/// component sample blocks.
509///
510/// This is public so hybrid GPU backends can keep JPEG parsing and exact IDCT
511/// on CPU while offloading the reversible 5/3 projection.
512pub fn idct_blocks_to_signed_samples_rayon(blocks: &[[i16; 64]]) -> Vec<[i32; 64]> {
513    blocks
514        .par_iter()
515        .map(|block| {
516            let decoded = idct_islow_block(block);
517            decoded.map(|sample| i32::from(sample) - 128)
518        })
519        .collect()
520}
521
522/// Compute one exact reversible integer 5/3 level from already decoded
523/// block-local signed samples.
524pub fn reversible_dwt53_first_level_from_block_samples(
525    block_samples: &[[i32; 64]],
526    block_cols: usize,
527    block_rows: usize,
528    width: usize,
529    height: usize,
530) -> Result<ReversibleDwt53FirstLevel, &'static str> {
531    validate_reversible_grid(block_samples.len(), block_cols, block_rows, width, height)?;
532
533    let low_width = width.div_ceil(2);
534    let low_height = height.div_ceil(2);
535    let high_width = width / 2;
536    let high_height = height / 2;
537
538    let low_rows: Vec<(Vec<i32>, Vec<i32>)> = (0..low_height)
539        .into_par_iter()
540        .map(|output_y| {
541            let mut row = Vec::with_capacity(width);
542            for x in 0..width {
543                row.push(vertical_low_53_i32_at(
544                    block_samples,
545                    block_cols,
546                    width,
547                    height,
548                    x,
549                    output_y,
550                ));
551            }
552            reversible_lift_53_i32(&mut row);
553            (
554                row.iter().step_by(2).copied().collect(),
555                row.iter().skip(1).step_by(2).copied().collect(),
556            )
557        })
558        .collect();
559    let high_rows: Vec<(Vec<i32>, Vec<i32>)> = (0..high_height)
560        .into_par_iter()
561        .map(|output_y| {
562            let mut row = Vec::with_capacity(width);
563            for x in 0..width {
564                row.push(vertical_high_53_i32_at(
565                    block_samples,
566                    block_cols,
567                    width,
568                    height,
569                    x,
570                    output_y,
571                ));
572            }
573            reversible_lift_53_i32(&mut row);
574            (
575                row.iter().step_by(2).copied().collect(),
576                row.iter().skip(1).step_by(2).copied().collect(),
577            )
578        })
579        .collect();
580
581    let mut ll = Vec::with_capacity(low_width * low_height);
582    let mut hl = Vec::with_capacity(high_width * low_height);
583    for (low, high) in low_rows {
584        ll.extend(low);
585        hl.extend(high);
586    }
587
588    let mut lh = Vec::with_capacity(low_width * high_height);
589    let mut hh = Vec::with_capacity(high_width * high_height);
590    for (low, high) in high_rows {
591        lh.extend(low);
592        hh.extend(high);
593    }
594
595    Ok(ReversibleDwt53FirstLevel {
596        ll,
597        hl,
598        lh,
599        hh,
600        low_width,
601        low_height,
602        high_width,
603        high_height,
604    })
605}
606
607fn reversible_dwt53_first_level_rayon(
608    job: DctGridToReversibleDwt53Job<'_>,
609) -> Result<ReversibleDwt53FirstLevel, &'static str> {
610    validate_reversible_grid(
611        job.dequantized_blocks.len(),
612        job.block_cols,
613        job.block_rows,
614        job.width,
615        job.height,
616    )?;
617    let block_samples = idct_blocks_to_signed_samples_rayon(job.dequantized_blocks);
618    reversible_dwt53_first_level_from_block_samples(
619        &block_samples,
620        job.block_cols,
621        job.block_rows,
622        job.width,
623        job.height,
624    )
625}
626
627fn validate_reversible_grid(
628    block_count: usize,
629    block_cols: usize,
630    block_rows: usize,
631    width: usize,
632    height: usize,
633) -> Result<(), &'static str> {
634    validate_dct_block_grid(block_count, block_cols, block_rows, width, height)
635        .map_err(|_| REVERSIBLE_DWT53_UNSUPPORTED_GRID)
636}
637
638fn vertical_low_53_i32_at(
639    block_samples: &[[i32; 64]],
640    block_cols: usize,
641    width: usize,
642    height: usize,
643    x: usize,
644    low_idx: usize,
645) -> i32 {
646    reversible_lift_53_low_at(height, low_idx, |y| {
647        component_sample_i32(block_samples, block_cols, width, height, x, y)
648    })
649}
650
651fn vertical_high_53_i32_at(
652    block_samples: &[[i32; 64]],
653    block_cols: usize,
654    width: usize,
655    height: usize,
656    x: usize,
657    high_idx: usize,
658) -> i32 {
659    reversible_lift_53_high_at(height, high_idx, |y| {
660        component_sample_i32(block_samples, block_cols, width, height, x, y)
661    })
662}
663
/// Read the signed sample at component coordinate `(x, y)` from the
/// block-major sample grid.
///
/// Blocks are ordered `block_y * block_cols + block_x`, and samples inside a
/// block are row-major over its 8x8 footprint. `width` and `height` are only
/// used by the debug-build bounds assertions.
fn component_sample_i32(
    block_samples: &[[i32; 64]],
    block_cols: usize,
    width: usize,
    height: usize,
    x: usize,
    y: usize,
) -> i32 {
    const EDGE: usize = 8;
    debug_assert!(x < width);
    debug_assert!(y < height);
    let block = &block_samples[(y / EDGE) * block_cols + x / EDGE];
    block[(y % EDGE) * EDGE + x % EDGE]
}
680
681#[cfg(test)]
682mod ground_truth_tests {
683    //! Independent ground truth for the reversible integer 5/3.
684    //!
685    //! The CUDA 5/3 kernel is parity-tested against the lifting in this module,
686    //! so a boundary/indexing/band-split bug here would be faithfully copied by
    //! the kernel and pass parity. Validate the lifting against the canonical
    //! JPEG2000 reversible 5/3 forward filter (ISO/IEC 15444-1 Annex F.4.8.1,
    //! the analysis counterpart of the F.3.8.1 synthesis filter), evaluated per
    //! output index from a whole-sample-symmetrically extended signal — a
    //! structurally different implementation from the in-place two-pass loops.
691
692    use super::{
693        reversible_dwt53_first_level_from_block_samples, reversible_lift_53_i32,
694        ReversibleDwt53FirstLevel,
695    };
696
697    fn floor2(a: i32, b: i32) -> i32 {
698        a.div_euclid(b)
699    }
700
701    /// Whole-sample symmetric reflection (mirror about 0 and `n - 1`, endpoints
702    /// not repeated) — the boundary extension the lifting realizes at the edges.
703    fn ws_reflect(i: isize, n: usize) -> usize {
704        if n == 1 {
705            return 0;
706        }
707        let n = isize::try_from(n).unwrap();
708        let period = 2 * (n - 1);
709        let mut k = i.rem_euclid(period);
710        if k >= n {
711            k = period - k;
712        }
713        usize::try_from(k).unwrap()
714    }
715
716    /// Canonical forward 5/3: `(low, high)` where `low[m]` is the even/approx
717    /// coefficient and `high[m]` the odd/detail coefficient. Every index is read
718    /// through whole-sample symmetric extension of the original signal, so the
719    /// detail-boundary behavior follows automatically (no special cases).
720    fn ref_53_forward(signal: &[i32]) -> (Vec<i32>, Vec<i32>) {
721        let n = signal.len();
722        if n < 2 {
723            return (signal.to_vec(), Vec::new());
724        }
725        let sig = |i: isize| signal[ws_reflect(i, n)];
726        let detail = |m: isize| {
727            let c = 2 * m + 1;
728            sig(c) - floor2(sig(c - 1) + sig(c + 1), 2)
729        };
730        let low: Vec<i32> = (0..n.div_ceil(2))
731            .map(|m| {
732                let mi = isize::try_from(m).unwrap();
733                sig(2 * mi) + floor2(detail(mi - 1) + detail(mi) + 2, 4)
734            })
735            .collect();
736        let high: Vec<i32> = (0..n / 2)
737            .map(|m| detail(isize::try_from(m).unwrap()))
738            .collect();
739        (low, high)
740    }
741
742    /// Separable 2D reference matching the oracle's vertical-then-horizontal
743    /// order (integer floor lifting is NOT order-independent, so order matters).
744    fn ref_53_2d(plane: &[i32], width: usize, height: usize) -> ReversibleDwt53FirstLevel {
745        let low_width = width.div_ceil(2);
746        let high_width = width / 2;
747        let low_height = height.div_ceil(2);
748        let high_height = height / 2;
749
750        let mut v_low = vec![0i32; width * low_height];
751        let mut v_high = vec![0i32; width * high_height];
752        for x in 0..width {
753            let column: Vec<i32> = (0..height).map(|y| plane[y * width + x]).collect();
754            let (lo, hi) = ref_53_forward(&column);
755            for (oy, &value) in lo.iter().enumerate() {
756                v_low[oy * width + x] = value;
757            }
758            for (oy, &value) in hi.iter().enumerate() {
759                v_high[oy * width + x] = value;
760            }
761        }
762
763        let horizontal = |source: &[i32], rows: usize| -> (Vec<i32>, Vec<i32>) {
764            let mut low = vec![0i32; low_width * rows];
765            let mut high = vec![0i32; high_width * rows];
766            for oy in 0..rows {
767                let (lo, hi) = ref_53_forward(&source[oy * width..oy * width + width]);
768                low[oy * low_width..oy * low_width + low_width].copy_from_slice(&lo);
769                high[oy * high_width..oy * high_width + high_width].copy_from_slice(&hi);
770            }
771            (low, high)
772        };
773
774        let (ll, hl) = horizontal(&v_low, low_height);
775        let (lh, hh) = horizontal(&v_high, high_height);
776
777        ReversibleDwt53FirstLevel {
778            ll,
779            hl,
780            lh,
781            hh,
782            low_width,
783            low_height,
784            high_width,
785            high_height,
786        }
787    }
788
789    /// Pack a flat `width x height` sample plane into the block-major
790    /// `[[i32; 64]]` layout `reversible_dwt53_first_level_from_block_samples`
791    /// consumes (local index `(y % 8) * 8 + (x % 8)`).
792    fn pack_plane(plane: &[i32], width: usize, height: usize) -> (Vec<[i32; 64]>, usize, usize) {
793        let block_cols = width.div_ceil(8);
794        let block_rows = height.div_ceil(8);
795        let mut blocks = vec![[0i32; 64]; block_cols * block_rows];
796        for y in 0..height {
797            for x in 0..width {
798                let block = (y / 8) * block_cols + (x / 8);
799                blocks[block][(y % 8) * 8 + (x % 8)] = plane[y * width + x];
800            }
801        }
802        (blocks, block_cols, block_rows)
803    }
804
805    fn next_sample(state: &mut u64) -> i32 {
806        *state = state
807            .wrapping_mul(6_364_136_223_846_793_005)
808            .wrapping_add(1_442_695_040_888_963_407);
809        ((*state >> 40) & 0x1ff) as i32 - 256
810    }
811
812    #[test]
813    fn reversible_lift_53_matches_canonical_formula_1d() {
814        let mut state = 0x0a11_ce5e_ed00_d001u64;
815        for n in [2usize, 3, 4, 5, 8, 9, 12, 15, 16, 23, 32, 33, 64, 65] {
816            let signal: Vec<i32> = (0..n).map(|_| next_sample(&mut state)).collect();
817            let mut lifted = signal.clone();
818            reversible_lift_53_i32(&mut lifted);
819            let lifted_low: Vec<i32> = lifted.iter().step_by(2).copied().collect();
820            let lifted_high: Vec<i32> = lifted.iter().skip(1).step_by(2).copied().collect();
821            let (low, high) = ref_53_forward(&signal);
822            assert_eq!(lifted_low, low, "low band mismatch for n={n}");
823            assert_eq!(lifted_high, high, "high band mismatch for n={n}");
824        }
825    }
826
827    #[test]
828    fn reversible_lift_53_shared_helper_matches_canonical_formula_1d() {
829        let mut state = 0x5a53_5a53_5a53_5a53u64;
830        for n in [2usize, 3, 4, 5, 8, 9, 16, 17, 31, 32, 65] {
831            let signal: Vec<i32> = (0..n).map(|_| next_sample(&mut state)).collect();
832            let mut lifted = signal.clone();
833            crate::reversible53::reversible_lift_53_i32(&mut lifted);
834            let lifted_low: Vec<i32> = lifted.iter().step_by(2).copied().collect();
835            let lifted_high: Vec<i32> = lifted.iter().skip(1).step_by(2).copied().collect();
836            let (low, high) = ref_53_forward(&signal);
837            assert_eq!(lifted_low, low, "low band mismatch for n={n}");
838            assert_eq!(lifted_high, high, "high band mismatch for n={n}");
839        }
840    }
841
842    #[test]
843    fn reversible_dwt53_2d_matches_canonical_separable() {
844        let mut state = 0xfeed_5eed_d00d_face_u64;
845        for (width, height) in [
846            (8usize, 8usize),
847            (16, 16),
848            (24, 16),
849            (15, 13),
850            (16, 23),
851            (9, 7),
852            (32, 32),
853        ] {
854            let plane: Vec<i32> = (0..width * height)
855                .map(|_| next_sample(&mut state))
856                .collect();
857            let (blocks, block_cols, block_rows) = pack_plane(&plane, width, height);
858            let got = reversible_dwt53_first_level_from_block_samples(
859                &blocks, block_cols, block_rows, width, height,
860            )
861            .expect("oracle accepts the packed grid");
862            let want = ref_53_2d(&plane, width, height);
863            assert_eq!(
864                (
865                    got.low_width,
866                    got.low_height,
867                    got.high_width,
868                    got.high_height
869                ),
870                (
871                    want.low_width,
872                    want.low_height,
873                    want.high_width,
874                    want.high_height
875                ),
876                "band dimensions for {width}x{height}"
877            );
878            assert_eq!(got.ll, want.ll, "LL mismatch for {width}x{height}");
879            assert_eq!(got.hl, want.hl, "HL mismatch for {width}x{height}");
880            assert_eq!(got.lh, want.lh, "LH mismatch for {width}x{height}");
881            assert_eq!(got.hh, want.hh, "HH mismatch for {width}x{height}");
882        }
883    }
884
885    #[test]
886    fn reversible_lift_53_kills_dc_and_linear_detail() {
887        // Constant -> low = constant, detail exactly zero.
888        let mut constant = vec![7i32; 32];
889        reversible_lift_53_i32(&mut constant);
890        assert!(
891            constant.iter().skip(1).step_by(2).all(|&v| v == 0),
892            "constant produced nonzero detail"
893        );
894        assert!(
895            constant.iter().step_by(2).all(|&v| v == 7),
896            "constant low band drifted from 7"
897        );
898
899        // Linear ramp -> interior detail exactly zero (two vanishing moments).
900        let ramp: Vec<i32> = (0..40_i32).map(|k| 3 * k - 5).collect();
901        let mut lifted = ramp;
902        reversible_lift_53_i32(&mut lifted);
903        let detail: Vec<i32> = lifted.iter().skip(1).step_by(2).copied().collect();
904        for &value in &detail[1..detail.len() - 1] {
905            assert_eq!(value, 0, "linear ramp produced interior detail {value}");
906        }
907    }
908
909    #[test]
910    fn reversible_dwt53_2d_separates_horizontal_and_vertical_detail() {
911        // Varies only along x -> no vertical detail (LH and HH vanish).
912        let (width, height) = (16usize, 16usize);
913        let varies_in_x: Vec<i32> = (0..width * height)
914            .map(|i| 3 * i32::try_from(i % width).unwrap() - 7)
915            .collect();
916        let (blocks, bc, br) = pack_plane(&varies_in_x, width, height);
917        let t = reversible_dwt53_first_level_from_block_samples(&blocks, bc, br, width, height)
918            .expect("oracle accepts grid");
919        assert!(
920            t.lh.iter().all(|&v| v == 0),
921            "x-only plane produced LH detail"
922        );
923        assert!(
924            t.hh.iter().all(|&v| v == 0),
925            "x-only plane produced HH detail"
926        );
927
928        // Varies only along y -> no horizontal detail (HL and HH vanish).
929        let varies_in_y: Vec<i32> = (0..width * height)
930            .map(|i| 3 * i32::try_from(i / width).unwrap() - 7)
931            .collect();
932        let (blocks, bc, br) = pack_plane(&varies_in_y, width, height);
933        let t = reversible_dwt53_first_level_from_block_samples(&blocks, bc, br, width, height)
934            .expect("oracle accepts grid");
935        assert!(
936            t.hl.iter().all(|&v| v == 0),
937            "y-only plane produced HL detail"
938        );
939        assert!(
940            t.hh.iter().all(|&v| v == 0),
941            "y-only plane produced HH detail"
942        );
943    }
944}