Skip to main content

j2k_transcode/
htj2k97_codeblock_oracle.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Shared scalar oracle: float 9/7 bands into prequantized HTJ2K code-blocks.
4//!
5//! This module uses the native encoder's public irreversible 9/7 quantization
6//! step helper plus the native code-block layout rules, so both GPU backends can
7//! compare their fused code-block kernels against one authoritative CPU
8//! reference instead of each re-deriving the math.
9//!
10//! The re-derivation is anchored to native truth by a codestream pin test (see
11//! the module tests): encoding the oracle's prequantized output reproduces the
12//! native precomputed-DWT codestream byte-for-byte.
13
14use crate::accelerator::Htj2k97CodeBlockOptions;
15use crate::dct97_2d::Dwt97TwoDimensional;
16use j2k::adapter::encode_stage::{
17    J2kSubBandType, PrequantizedHtj2k97CodeBlock, PrequantizedHtj2k97Component,
18    PrequantizedHtj2k97Resolution, PrequantizedHtj2k97Subband,
19};
20use j2k_native::irreversible_quantization_step_for_subband;
21
22/// Quantize one level of float 9/7 bands into a prequantized HTJ2K component.
23///
24/// Resolution nesting matches the native encoder for a single decomposition
25/// level: resolution 0 holds `[LL]`, resolution 1 holds `[HL, LH, HH]`.
26#[must_use]
27pub fn prequantized_component_from_dwt97(
28    dwt: &Dwt97TwoDimensional<f64>,
29    options: Htj2k97CodeBlockOptions,
30    x_rsiz: u8,
31    y_rsiz: u8,
32) -> PrequantizedHtj2k97Component {
33    PrequantizedHtj2k97Component {
34        x_rsiz,
35        y_rsiz,
36        resolutions: vec![
37            PrequantizedHtj2k97Resolution {
38                subbands: vec![quantize_codeblock_subband(
39                    &dwt.ll,
40                    dwt.low_width,
41                    dwt.low_height,
42                    J2kSubBandType::LowLow,
43                    options,
44                )],
45            },
46            PrequantizedHtj2k97Resolution {
47                subbands: vec![
48                    quantize_codeblock_subband(
49                        &dwt.hl,
50                        dwt.high_width,
51                        dwt.low_height,
52                        J2kSubBandType::HighLow,
53                        options,
54                    ),
55                    quantize_codeblock_subband(
56                        &dwt.lh,
57                        dwt.low_width,
58                        dwt.high_height,
59                        J2kSubBandType::LowHigh,
60                        options,
61                    ),
62                    quantize_codeblock_subband(
63                        &dwt.hh,
64                        dwt.high_width,
65                        dwt.high_height,
66                        J2kSubBandType::HighHigh,
67                        options,
68                    ),
69                ],
70            },
71        ],
72    }
73}
74
75/// Quantize a single float subband and slice it into code-block-major layout.
76///
77/// Code-blocks are emitted outer `cby`, inner `cbx`; each block's coefficients
78/// are row-major, matching the native encoder's `copy_code_block_coefficients`.
79#[must_use]
80pub fn quantize_codeblock_subband(
81    coefficients: &[f64],
82    width: usize,
83    height: usize,
84    sub_band_type: J2kSubBandType,
85    options: Htj2k97CodeBlockOptions,
86) -> PrequantizedHtj2k97Subband {
87    let quantized = quantize_subband_coefficients(coefficients, sub_band_type, options);
88    let cb_width = htj2k97_code_block_dim(options.code_block_width_exp);
89    let cb_height = htj2k97_code_block_dim(options.code_block_height_exp);
90    let num_cbs_x = width.div_ceil(cb_width);
91    let num_cbs_y = height.div_ceil(cb_height);
92    let mut code_blocks = Vec::with_capacity(num_cbs_x * num_cbs_y);
93
94    for cby in 0..num_cbs_y {
95        for cbx in 0..num_cbs_x {
96            let x0 = cbx * cb_width;
97            let y0 = cby * cb_height;
98            let block_width = (width - x0).min(cb_width);
99            let block_height = (height - y0).min(cb_height);
100            let mut block_coefficients = Vec::with_capacity(block_width * block_height);
101            for y in 0..block_height {
102                let row_start = (y0 + y) * width + x0;
103                block_coefficients
104                    .extend_from_slice(&quantized[row_start..row_start + block_width]);
105            }
106            code_blocks.push(PrequantizedHtj2k97CodeBlock {
107                coefficients: block_coefficients,
108                width: block_width as u32,
109                height: block_height as u32,
110            });
111        }
112    }
113
114    PrequantizedHtj2k97Subband {
115        sub_band_type,
116        num_cbs_x: num_cbs_x as u32,
117        num_cbs_y: num_cbs_y as u32,
118        total_bitplanes: htj2k97_subband_total_bitplanes(options, sub_band_type),
119        code_blocks,
120    }
121}
122
123/// Deadzone quantization step size `Δ` for a subband.
124///
125/// `Δ = 2^(range_bits − exponent) · (1 + mantissa/2048)`, with
126/// `range_bits = bit_depth + {LL:0, HL:1, LH:1, HH:2}` and the shared
127/// `(exponent, mantissa)` derived by this module's quantizer.
128#[must_use]
129pub fn htj2k97_subband_delta(
130    options: Htj2k97CodeBlockOptions,
131    sub_band_type: J2kSubBandType,
132) -> f64 {
133    let log_gain = match sub_band_type {
134        J2kSubBandType::LowLow => 0,
135        J2kSubBandType::HighLow | J2kSubBandType::LowHigh => 1,
136        J2kSubBandType::HighHigh => 2,
137    };
138    let range_bits = i32::from(options.bit_depth) + log_gain;
139    let (exponent, mantissa) = htj2k97_step(options, sub_band_type);
140    pow2i_f64(range_bits - i32::from(exponent)) * (1.0 + f64::from(mantissa) / 2048.0)
141}
142
143/// Total declared bitplanes for every code-block in a subband.
144///
145/// `saturating(guard_bits + exponent - 1)`. The exponent is derived from the
146/// effective global plus per-subband quantization profile, so callers must pass
147/// the actual subband kind.
148#[must_use]
149pub fn htj2k97_subband_total_bitplanes(
150    options: Htj2k97CodeBlockOptions,
151    sub_band_type: J2kSubBandType,
152) -> u8 {
153    let (exponent, _) = htj2k97_step(options, sub_band_type);
154    options
155        .guard_bits
156        .saturating_add(exponent)
157        .saturating_sub(1)
158}
159
160/// Validate 9/7 code-block options against the numeric limits both GPU
161/// backends must agree on, returning the decoded `(cb_width, cb_height)`.
162///
163/// One shared implementation keeps Metal and CUDA from drifting: the same
164/// options must be accepted or rejected identically by every backend. Errors
165/// are backend-neutral static strings for the caller's unsupported-job error.
166///
167/// # Errors
168/// Rejects zero/oversized bit depths and guard bits, non-finite or
169/// non-positive quantization scales, code-block dimensions beyond the HTJ2K
170/// limits (sides ≤ 1024, area ≤ 4096), and subband deltas or total bitplane
171/// counts outside the supported range.
172pub fn validate_htj2k97_codeblock_options(
173    options: Htj2k97CodeBlockOptions,
174) -> Result<(usize, usize), &'static str> {
175    if options.bit_depth == 0
176        || options.bit_depth > 30
177        || options.guard_bits > 30
178        || !options.irreversible_quantization_scale.is_finite()
179        || options.irreversible_quantization_scale <= 0.0
180    {
181        return Err("9/7 code-block options are outside supported numeric range");
182    }
183    let subband_scales = options.irreversible_quantization_subband_scales;
184    if [
185        subband_scales.low_low,
186        subband_scales.high_low,
187        subband_scales.low_high,
188        subband_scales.high_high,
189    ]
190    .iter()
191    .any(|scale| !scale.is_finite() || *scale <= 0.0)
192    {
193        return Err("9/7 code-block quantization options are outside supported range");
194    }
195
196    let cb_width = checked_code_block_dim(options.code_block_width_exp)?;
197    let cb_height = checked_code_block_dim(options.code_block_height_exp)?;
198    if cb_width > 1024
199        || cb_height > 1024
200        || cb_width
201            .checked_mul(cb_height)
202            .is_none_or(|area| area > 4096)
203    {
204        return Err("9/7 code-block dimensions exceed HTJ2K limits");
205    }
206
207    for subband in [
208        J2kSubBandType::LowLow,
209        J2kSubBandType::HighLow,
210        J2kSubBandType::LowHigh,
211        J2kSubBandType::HighHigh,
212    ] {
213        let delta = htj2k97_subband_delta(options, subband);
214        if !delta.is_finite()
215            || delta <= 0.0
216            || htj2k97_subband_total_bitplanes(options, subband) > 30
217        {
218            return Err("9/7 code-block quantization options are outside supported range");
219        }
220    }
221
222    Ok((cb_width, cb_height))
223}
224
/// Decode a code-block dimension stored as `exponent − 2` into `2^(value + 2)`.
///
/// Fails when the shift amount is at least the bit width of `usize`.
fn checked_code_block_dim(exp_minus_two: u8) -> Result<usize, &'static str> {
    let shift = u32::from(exp_minus_two) + 2;
    match 1usize.checked_shl(shift) {
        Some(dimension) => Ok(dimension),
        None => Err("9/7 code-block dimension exponent is unsupported"),
    }
}
230
231fn quantize_subband_coefficients(
232    coefficients: &[f64],
233    sub_band_type: J2kSubBandType,
234    options: Htj2k97CodeBlockOptions,
235) -> Vec<i32> {
236    let delta = htj2k97_subband_delta(options, sub_band_type);
237    let inv_delta = 1.0 / delta;
238
239    coefficients
240        .iter()
241        .map(|&coefficient| {
242            // Deadzone quantization: q = sign(c) · floor(|c| · (1/Δ)), sign(0) = +1.
243            let sign = if coefficient < 0.0 { -1 } else { 1 };
244            sign * (coefficient.abs() * inv_delta).floor() as i32
245        })
246        .collect()
247}
248
249/// Shared `(exponent, mantissa)` for the irreversible 9/7 quantizer.
250fn htj2k97_step(options: Htj2k97CodeBlockOptions, sub_band_type: J2kSubBandType) -> (u8, u16) {
251    let step = irreversible_quantization_step_for_subband(
252        options.bit_depth,
253        options.guard_bits,
254        options.irreversible_quantization_scale,
255        options.irreversible_quantization_subband_scales,
256        sub_band_type,
257    );
258    (step.exponent, step.mantissa)
259}
260
/// `2^exp` as an `f64`; `exp` may be negative.
fn pow2i_f64(exp: i32) -> f64 {
    f64::powi(2.0, exp)
}
264
/// Decode a code-block dimension stored as `exponent − 2` into `2^(value + 2)`.
///
/// Never fails: a shift that does not fit in `usize` yields `usize::MAX`.
fn htj2k97_code_block_dim(exp_minus_two: u8) -> usize {
    let shift = u32::from(exp_minus_two) + 2;
    if shift < usize::BITS {
        1usize << shift
    } else {
        usize::MAX
    }
}
270
271#[cfg(test)]
272mod tests {
273    use super::*;
274    use j2k::adapter::encode_stage::{
275        IrreversibleQuantizationSubbandScales, PrequantizedHtj2k97Image,
276    };
277    use j2k_native::{
278        encode_precomputed_htj2k_97, encode_prequantized_htj2k_97, EncodeOptions,
279        J2kForwardDwt97Level, J2kForwardDwt97Output, PrecomputedHtj2k97Component,
280        PrecomputedHtj2k97Image,
281    };
282
283    // Boundary-free coefficients on a 0.25 grid: exact in both f32 and f64, and
284    // every product with the scale-1.0 inverse deltas (4, 2, 1) lands on an exact
285    // integer/half-integer. So the f64 oracle and native's f32 quantizer agree
286    // bit-for-bit here and the codestream pin is exact, not merely close.
287    fn sample_band(len: usize, offset: f64) -> Vec<f64> {
288        (0..len)
289            .map(|idx| ((idx % 17) as f64 - 8.0) * 0.5 + offset)
290            .collect()
291    }
292
293    #[test]
294    fn oracle_prequantized_component_matches_native_precomputed_codestream() {
295        let width = 17u32;
296        let height = 13u32;
297        let low_width = width.div_ceil(2) as usize;
298        let low_height = height.div_ceil(2) as usize;
299        let high_width = (width / 2) as usize;
300        let high_height = (height / 2) as usize;
301
302        let ll = sample_band(low_width * low_height, 0.25);
303        let hl = sample_band(high_width * low_height, -0.75);
304        let lh = sample_band(low_width * high_height, 1.25);
305        let hh = sample_band(high_width * high_height, -1.5);
306
307        let options = EncodeOptions {
308            num_decomposition_levels: 1,
309            reversible: false,
310            guard_bits: 2,
311            code_block_width_exp: 2,
312            code_block_height_exp: 2,
313            ..EncodeOptions::default()
314        };
315
316        // Native precomputed-DWT path quantizes the f32 bands internally.
317        let precomputed_image = PrecomputedHtj2k97Image {
318            width,
319            height,
320            bit_depth: 8,
321            signed: false,
322            components: vec![PrecomputedHtj2k97Component {
323                x_rsiz: 1,
324                y_rsiz: 1,
325                dwt: J2kForwardDwt97Output {
326                    ll: ll.iter().map(|&v| v as f32).collect(),
327                    ll_width: low_width as u32,
328                    ll_height: low_height as u32,
329                    levels: vec![J2kForwardDwt97Level {
330                        hl: hl.iter().map(|&v| v as f32).collect(),
331                        lh: lh.iter().map(|&v| v as f32).collect(),
332                        hh: hh.iter().map(|&v| v as f32).collect(),
333                        width,
334                        height,
335                        low_width: low_width as u32,
336                        low_height: low_height as u32,
337                        high_width: high_width as u32,
338                        high_height: high_height as u32,
339                    }],
340                },
341            }],
342        };
343
344        // Oracle prequantized path (f64) over the same bands.
345        let dwt = Dwt97TwoDimensional {
346            ll,
347            hl,
348            lh,
349            hh,
350            low_width,
351            low_height,
352            high_width,
353            high_height,
354        };
355        let codeblock_options = Htj2k97CodeBlockOptions {
356            bit_depth: 8,
357            guard_bits: 2,
358            code_block_width_exp: 2,
359            code_block_height_exp: 2,
360            irreversible_quantization_scale: 1.0,
361            irreversible_quantization_subband_scales:
362                IrreversibleQuantizationSubbandScales::default(),
363        };
364        let component = prequantized_component_from_dwt97(&dwt, codeblock_options, 1, 1);
365        let prequantized_image = PrequantizedHtj2k97Image {
366            width,
367            height,
368            bit_depth: 8,
369            signed: false,
370            components: vec![component],
371        };
372
373        let expected = encode_precomputed_htj2k_97(&precomputed_image, &options)
374            .expect("native precomputed 9/7 encode");
375        let native_prequantized_image = native_prequantized_image(prequantized_image);
376        let actual = encode_prequantized_htj2k_97(&native_prequantized_image, &options)
377            .expect("oracle prequantized 9/7 encode");
378
379        assert_eq!(
380            actual, expected,
381            "oracle prequantized component must reproduce the native precomputed-DWT codestream"
382        );
383    }
384
385    #[test]
386    fn shared_validator_accepts_standard_options_and_returns_dims() {
387        let options = Htj2k97CodeBlockOptions {
388            bit_depth: 8,
389            guard_bits: 2,
390            code_block_width_exp: 4,
391            code_block_height_exp: 4,
392            irreversible_quantization_scale: 1.0,
393            irreversible_quantization_subband_scales:
394                IrreversibleQuantizationSubbandScales::default(),
395        };
396        assert_eq!(validate_htj2k97_codeblock_options(options), Ok((64, 64)));
397    }
398
399    #[test]
400    fn shared_validator_rejects_out_of_spec_options_on_every_backend() {
401        let valid = Htj2k97CodeBlockOptions {
402            bit_depth: 8,
403            guard_bits: 2,
404            code_block_width_exp: 4,
405            code_block_height_exp: 4,
406            irreversible_quantization_scale: 1.0,
407            irreversible_quantization_subband_scales:
408                IrreversibleQuantizationSubbandScales::default(),
409        };
410
411        // Each case was accepted by the old Metal-only validator.
412        let oversized_bit_depth = Htj2k97CodeBlockOptions {
413            bit_depth: 31,
414            ..valid
415        };
416        let oversized_guard_bits = Htj2k97CodeBlockOptions {
417            guard_bits: 31,
418            ..valid
419        };
420        // 1024x1024: each side passes the per-side cap, area breaks the
421        // HTJ2K 4096 limit.
422        let oversized_area = Htj2k97CodeBlockOptions {
423            code_block_width_exp: 8,
424            code_block_height_exp: 8,
425            ..valid
426        };
427        for options in [oversized_bit_depth, oversized_guard_bits, oversized_area] {
428            assert!(
429                validate_htj2k97_codeblock_options(options).is_err(),
430                "options must be rejected: {options:?}"
431            );
432        }
433
434        // guard_bits == 0 stays accepted (the old Metal validator rejected it,
435        // CUDA and the native encoder accept it).
436        let zero_guard_bits = Htj2k97CodeBlockOptions {
437            guard_bits: 0,
438            ..valid
439        };
440        assert!(validate_htj2k97_codeblock_options(zero_guard_bits).is_ok());
441    }
442
443    #[test]
444    fn oracle_subband_profile_changes_only_selected_delta_and_bitplanes() {
445        let mut options = Htj2k97CodeBlockOptions {
446            bit_depth: 8,
447            guard_bits: 2,
448            code_block_width_exp: 2,
449            code_block_height_exp: 2,
450            irreversible_quantization_scale: 1.9,
451            irreversible_quantization_subband_scales:
452                IrreversibleQuantizationSubbandScales::default(),
453        };
454        let high_low_delta = htj2k97_subband_delta(options, J2kSubBandType::HighLow);
455        let high_high_delta = htj2k97_subband_delta(options, J2kSubBandType::HighHigh);
456        let default_hh_bitplanes =
457            htj2k97_subband_total_bitplanes(options, J2kSubBandType::HighHigh);
458
459        options.irreversible_quantization_subband_scales.high_high = 1.5;
460
461        assert_eq!(
462            htj2k97_subband_delta(options, J2kSubBandType::HighLow).to_bits(),
463            high_low_delta.to_bits()
464        );
465        assert!(htj2k97_subband_delta(options, J2kSubBandType::HighHigh) > high_high_delta);
466        assert_ne!(
467            htj2k97_subband_total_bitplanes(options, J2kSubBandType::HighHigh),
468            default_hh_bitplanes
469        );
470    }
471
472    fn native_prequantized_image(
473        image: PrequantizedHtj2k97Image,
474    ) -> j2k_native::PrequantizedHtj2k97Image {
475        j2k_native::PrequantizedHtj2k97Image {
476            width: image.width,
477            height: image.height,
478            bit_depth: image.bit_depth,
479            signed: image.signed,
480            components: image
481                .components
482                .into_iter()
483                .map(|component| j2k_native::PrequantizedHtj2k97Component {
484                    x_rsiz: component.x_rsiz,
485                    y_rsiz: component.y_rsiz,
486                    resolutions: component
487                        .resolutions
488                        .into_iter()
489                        .map(|resolution| j2k_native::PrequantizedHtj2k97Resolution {
490                            subbands: resolution
491                                .subbands
492                                .into_iter()
493                                .map(|subband| j2k_native::PrequantizedHtj2k97Subband {
494                                    sub_band_type: subband.sub_band_type,
495                                    num_cbs_x: subband.num_cbs_x,
496                                    num_cbs_y: subband.num_cbs_y,
497                                    total_bitplanes: subband.total_bitplanes,
498                                    code_blocks: subband
499                                        .code_blocks
500                                        .into_iter()
501                                        .map(|block| j2k_native::PrequantizedHtj2k97CodeBlock {
502                                            coefficients: block.coefficients,
503                                            width: block.width,
504                                            height: block.height,
505                                        })
506                                        .collect(),
507                                })
508                                .collect(),
509                        })
510                        .collect(),
511                })
512                .collect(),
513        }
514    }
515}