Skip to main content

j2k_native/j2c/
decode.rs

1//! Decoding JPEG2000 code streams.
2//!
3//! This is the "core" module of the crate that orchestrates all
4//! stages in such a way that a given codestream is decoded into its
5//! component channels.
6
7use alloc::boxed::Box;
8use alloc::vec;
9use alloc::vec::Vec;
10
11use super::bitplane::{BitPlaneDecodeBuffers, BitPlaneDecodeContext};
12use super::build::{CodeBlock, Decomposition, Layer, Precinct, Segment, SubBand, SubBandType};
13use super::codestream::{ComponentInfo, Header, QuantizationStyle, WaveletTransform};
14use super::ht_block_decode::{self, HtBlockDecodeContext};
15use super::idwt::IDWTOutput;
16use super::progression::{progression_iterator, ProgressionData};
17use super::roi::RoiPlan;
18use super::tag_tree::TagNode;
19use super::tile::{ComponentTile, ResolutionTile, Tile};
20use super::{bitplane, build, idwt, mct, segment, tile, ComponentData};
21use crate::error::{bail, ColorError, DecodingError, Result, TileError};
22use crate::j2c::segment::MAX_BITPLANE_COUNT;
23use crate::math::SimdBuffer;
24use crate::profile;
25use crate::reader::BitReader;
26use crate::{
27    add_roi_shift_to_bitplanes, apply_roi_maxshift_inverse_i32, apply_roi_maxshift_inverse_i64,
28    checked_decode_byte_len3, checked_decode_sample_count, decode_j2k_code_block_scalar,
29    HtCodeBlockBatchJob, HtCodeBlockDecodeJob, HtCodeBlockDecoder, HtOwnedCodeBlockBatchJob,
30    HtOwnedSubBandPlan, HtSubBandDecodeJob, J2kCodeBlockBatchJob, J2kCodeBlockDecodeJob,
31    J2kCodeBlockSegment, J2kCodeBlockStyle, J2kDirectBandId, J2kDirectColorPlan,
32    J2kDirectGrayscalePlan, J2kDirectGrayscaleStep, J2kDirectIdwtStep, J2kDirectStoreStep,
33    J2kOwnedCodeBlockBatchJob, J2kOwnedSubBandPlan, J2kRect, J2kStoreComponentJob,
34    J2kSubBandDecodeJob, J2kSubBandType, J2kWaveletTransform,
35};
36#[cfg(feature = "parallel")]
37use crate::{decode_ht_code_block_scalar_with_workspace, HtCodeBlockDecodeWorkspace};
38use core::mem::size_of;
39use core::ops::{DerefMut, Range};
40
41pub(crate) fn decode<'a>(
42    data: &'a [u8],
43    header: &Header<'a>,
44    ctx: &mut DecoderContext<'a>,
45    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
46) -> Result<()> {
47    let mut reader = BitReader::new(data);
48    let profile_enabled = profile::profile_stages_enabled();
49    let total_start = profile::profile_now(profile_enabled);
50    let mut profile_timings = DecodeProfileTimings::default();
51    let stage_start = profile::profile_now(profile_enabled);
52    let tiles = tile::parse(&mut reader, header)?;
53    profile_timings.parse_tiles_us += profile::elapsed_us(stage_start);
54
55    if tiles.is_empty() {
56        bail!(TileError::Invalid);
57    }
58
59    ctx.reset(header, &tiles[0])?;
60    let cpu_decode_parallelism = ctx.cpu_decode_parallelism;
61    let (tile_ctx, storage) = (&mut ctx.tile_decode_context, &mut ctx.storage);
62
63    for tile in &tiles {
64        ltrace!(
65            "tile {} rect [{},{} {}x{}]",
66            tile.idx,
67            tile.rect.x0,
68            tile.rect.y0,
69            tile.rect.width(),
70            tile.rect.height(),
71        );
72
73        decode_tile(
74            tile,
75            header,
76            progression_iterator(tile)?,
77            tile_ctx,
78            storage,
79            ht_decoder,
80            cpu_decode_parallelism,
81            profile_enabled,
82            &mut profile_timings,
83        )?;
84    }
85
86    // Note that this assumes that either all tiles have MCT or none of them.
87    // In theory, only some could have it... But hopefully no such cursed
88    // images exist!
89    if tiles[0].mct {
90        let stage_start = profile::profile_now(profile_enabled);
91        mct::apply_inverse(tile_ctx, &tiles[0].component_infos, header, ht_decoder)?;
92        apply_sign_shift(tile_ctx, &header.component_infos);
93        profile_timings.mct_us += profile::elapsed_us(stage_start);
94    }
95
96    if profile_enabled {
97        emit_decode_profile_row(tile_ctx, &profile_timings, total_start);
98    }
99
100    Ok(())
101}
102
103pub(crate) fn build_direct_grayscale_plan<'a>(
104    data: &'a [u8],
105    header: &Header<'a>,
106    ctx: &mut DecoderContext<'a>,
107) -> Result<J2kDirectGrayscalePlan> {
108    let mut reader = BitReader::new(data);
109    let tiles = tile::parse(&mut reader, header)?;
110
111    if tiles.len() != 1 {
112        bail!(DecodingError::UnsupportedFeature(
113            "direct grayscale plan only supports single-tile codestreams"
114        ));
115    }
116
117    let tile = &tiles[0];
118    if tile.component_infos.len() != 1 {
119        bail!(DecodingError::UnsupportedFeature(
120            "direct grayscale plan only supports single-component codestreams"
121        ));
122    }
123    ctx.tile_decode_context.channel_data.clear();
124    ctx.storage.reset();
125
126    build::build(tile, &mut ctx.storage)?;
127    if let Some(output_region) = ctx.tile_decode_context.output_region {
128        ctx.storage.roi_plan = RoiPlan::build(tile, header, &ctx.storage, output_region);
129    }
130
131    segment::parse(tile, progression_iterator(tile)?, header, &mut ctx.storage)?;
132
133    let component_info = &tile.component_infos[0];
134    build_component_plan_from_storage(
135        tile,
136        header,
137        &ctx.storage,
138        0,
139        component_unsigned_level_shift(component_info),
140    )
141}
142
143pub(crate) fn build_direct_color_plan<'a>(
144    data: &'a [u8],
145    header: &Header<'a>,
146    ctx: &mut DecoderContext<'a>,
147) -> Result<J2kDirectColorPlan> {
148    let mut reader = BitReader::new(data);
149    let tiles = tile::parse(&mut reader, header)?;
150
151    if tiles.len() != 1 {
152        bail!(DecodingError::UnsupportedFeature(
153            "direct color plan only supports single-tile codestreams"
154        ));
155    }
156
157    let tile = &tiles[0];
158    if tile.component_infos.len() != 3 {
159        bail!(DecodingError::UnsupportedFeature(
160            "direct color plan only supports three-component RGB codestreams"
161        ));
162    }
163    let transform = tile.component_infos[0].wavelet_transform();
164    if tile.mct
165        && (transform != tile.component_infos[1].wavelet_transform()
166            || transform != tile.component_infos[2].wavelet_transform())
167    {
168        bail!(ColorError::Mct);
169    }
170
171    ctx.tile_decode_context.channel_data.clear();
172    ctx.storage.reset();
173
174    build::build(tile, &mut ctx.storage)?;
175    if let Some(output_region) = ctx.tile_decode_context.output_region {
176        ctx.storage.roi_plan = RoiPlan::build(tile, header, &ctx.storage, output_region);
177    }
178
179    segment::parse(tile, progression_iterator(tile)?, header, &mut ctx.storage)?;
180
181    let mut bit_depths = [0_u8; 3];
182    let mut component_plans = Vec::with_capacity(3);
183    for (component_idx, bit_depth) in bit_depths.iter_mut().enumerate() {
184        let component_info = &tile.component_infos[component_idx];
185        *bit_depth = component_info.size_info.precision;
186        let addend = if tile.mct {
187            0.0
188        } else {
189            component_unsigned_level_shift(component_info)
190        };
191        component_plans.push(build_component_plan_from_storage(
192            tile,
193            header,
194            &ctx.storage,
195            component_idx,
196            addend,
197        )?);
198    }
199
200    Ok(J2kDirectColorPlan {
201        dimensions: (
202            header.size_data.image_width(),
203            header.size_data.image_height(),
204        ),
205        bit_depths,
206        mct: tile.mct,
207        transform: J2kWaveletTransform::from(transform),
208        component_plans,
209    })
210}
211
212fn build_component_plan_from_storage(
213    tile: &Tile<'_>,
214    header: &Header<'_>,
215    storage: &DecompositionStorage<'_>,
216    component_idx: usize,
217    store_addend: f32,
218) -> Result<J2kDirectGrayscalePlan> {
219    let component_info =
220        tile.component_infos
221            .get(component_idx)
222            .ok_or(DecodingError::UnsupportedFeature(
223                "direct component plan index is out of range",
224            ))?;
225    if component_info.size_info.horizontal_resolution != 1
226        || component_info.size_info.vertical_resolution != 1
227    {
228        bail!(DecodingError::UnsupportedFeature(
229            "direct component plan only supports unit-sampled components"
230        ));
231    }
232
233    let tile_decompositions =
234        storage
235            .tile_decompositions
236            .get(component_idx)
237            .ok_or(DecodingError::UnsupportedFeature(
238                "direct component decomposition index is out of range",
239            ))?;
240    let decompositions = &storage.decompositions[tile_decompositions.decompositions.clone()];
241    let active_decomposition_count = decompositions
242        .len()
243        .saturating_sub(header.skipped_resolution_levels as usize);
244    let sub_band_step_count = (0..component_info.num_resolution_levels()
245        - header.skipped_resolution_levels)
246        .map(|resolution| {
247            tile_decompositions
248                .sub_band_iter(resolution, &storage.decompositions)
249                .count()
250        })
251        .sum::<usize>();
252    let mut steps =
253        Vec::with_capacity(sub_band_step_count + active_decomposition_count.saturating_add(1));
254    let mut next_band_id: J2kDirectBandId = 0;
255    let mut sub_band_ids = vec![None; storage.sub_bands.len()];
256
257    for resolution in 0..component_info.num_resolution_levels() - header.skipped_resolution_levels {
258        let sub_band_iter = tile_decompositions.sub_band_iter(resolution, &storage.decompositions);
259        for sub_band_idx in sub_band_iter {
260            if let Some(step) = build_grayscale_sub_band_step(
261                &storage.sub_bands[sub_band_idx],
262                sub_band_idx,
263                next_band_id,
264                resolution,
265                component_info,
266                storage,
267                header,
268            )? {
269                sub_band_ids[sub_band_idx] = Some(next_band_id);
270                next_band_id = next_band_id
271                    .checked_add(1)
272                    .ok_or(DecodingError::CodeBlockDecodeFailure)?;
273                steps.push(step);
274            }
275        }
276    }
277
278    let mut current_ll_rect = storage.sub_bands[tile_decompositions.first_ll_sub_band].rect;
279    let mut current_ll_band_id = sub_band_ids[tile_decompositions.first_ll_sub_band]
280        .ok_or(DecodingError::CodeBlockDecodeFailure)?;
281    let decompositions = &decompositions[..active_decomposition_count];
282    for decomposition in decompositions {
283        let hl = &storage.sub_bands[decomposition.sub_bands[0]];
284        let lh = &storage.sub_bands[decomposition.sub_bands[1]];
285        let hh = &storage.sub_bands[decomposition.sub_bands[2]];
286        let output_band_id = next_band_id;
287        next_band_id = next_band_id
288            .checked_add(1)
289            .ok_or(DecodingError::CodeBlockDecodeFailure)?;
290        steps.push(J2kDirectGrayscaleStep::Idwt(J2kDirectIdwtStep {
291            output_band_id,
292            rect: J2kRect::from(decomposition.rect),
293            transform: J2kWaveletTransform::from(component_info.wavelet_transform()),
294            ll_band_id: current_ll_band_id,
295            ll: J2kRect::from(current_ll_rect),
296            hl_band_id: sub_band_ids[decomposition.sub_bands[0]]
297                .ok_or(DecodingError::CodeBlockDecodeFailure)?,
298            hl: J2kRect::from(hl.rect),
299            lh_band_id: sub_band_ids[decomposition.sub_bands[1]]
300                .ok_or(DecodingError::CodeBlockDecodeFailure)?,
301            lh: J2kRect::from(lh.rect),
302            hh_band_id: sub_band_ids[decomposition.sub_bands[2]]
303                .ok_or(DecodingError::CodeBlockDecodeFailure)?,
304            hh: J2kRect::from(hh.rect),
305        }));
306        current_ll_rect = decomposition.rect;
307        current_ll_band_id = output_band_id;
308    }
309
310    let component_tile = ComponentTile::new(tile, component_info);
311    let resolution_tile = ResolutionTile::new(
312        component_tile,
313        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
314    );
315    let image_x_offset = header.size_data.image_area_x_offset;
316    let image_y_offset = header.size_data.image_area_y_offset;
317    let source_x = image_x_offset.saturating_sub(current_ll_rect.x0);
318    let source_y = image_y_offset.saturating_sub(current_ll_rect.y0);
319    let copy_width = resolution_tile
320        .rect
321        .width()
322        .min(current_ll_rect.width().saturating_sub(source_x));
323    let copy_height = resolution_tile
324        .rect
325        .height()
326        .min(current_ll_rect.height().saturating_sub(source_y));
327    let output_x = resolution_tile.rect.x0.saturating_sub(image_x_offset);
328    let output_y = resolution_tile.rect.y0.saturating_sub(image_y_offset);
329    steps.push(J2kDirectGrayscaleStep::Store(J2kDirectStoreStep {
330        input_band_id: current_ll_band_id,
331        input_rect: J2kRect::from(current_ll_rect),
332        source_x,
333        source_y,
334        copy_width,
335        copy_height,
336        output_width: header.size_data.image_width(),
337        output_height: header.size_data.image_height(),
338        output_x,
339        output_y,
340        addend: store_addend,
341    }));
342
343    Ok(J2kDirectGrayscalePlan {
344        dimensions: (
345            header.size_data.image_width(),
346            header.size_data.image_height(),
347        ),
348        bit_depth: component_info.size_info.precision,
349        steps,
350    })
351}
352
353fn build_grayscale_sub_band_step(
354    sub_band: &SubBand,
355    sub_band_idx: usize,
356    band_id: J2kDirectBandId,
357    resolution: u8,
358    component_info: &ComponentInfo,
359    storage: &DecompositionStorage<'_>,
360    header: &Header<'_>,
361) -> Result<Option<J2kDirectGrayscaleStep>> {
362    let dequantization_step = {
363        if component_info.quantization_info.quantization_style == QuantizationStyle::NoQuantization
364        {
365            1.0
366        } else {
367            let (exponent, mantissa) =
368                component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
369
370            let r_b = {
371                let log_gain = match sub_band.sub_band_type {
372                    SubBandType::LowLow => 0,
373                    SubBandType::LowHigh => 1,
374                    SubBandType::HighLow => 1,
375                    SubBandType::HighHigh => 2,
376                };
377
378                component_info.size_info.precision as u16 + log_gain
379            };
380
381            crate::math::pow2i(r_b as i32 - exponent as i32) * (1.0 + (mantissa as f32) / 2048.0)
382        }
383    };
384
385    let num_bitplanes = {
386        let (exponent, _) = component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
387        let num_bitplanes = (component_info.quantization_info.guard_bits as u16)
388            .checked_add(exponent)
389            .and_then(|x| x.checked_sub(1))
390            .ok_or(DecodingError::InvalidBitplaneCount)?;
391
392        if num_bitplanes > MAX_BITPLANE_COUNT as u16 {
393            bail!(DecodingError::TooManyBitplanes);
394        }
395
396        num_bitplanes as u8
397    };
398
399    if component_info
400        .coding_style
401        .parameters
402        .code_block_style
403        .uses_high_throughput_block_coding()
404    {
405        let coded_bitplanes =
406            add_roi_shift_to_bitplanes(num_bitplanes, component_info.roi_shift, 31)?;
407        let stripe_causal = component_info
408            .coding_style
409            .parameters
410            .code_block_style
411            .vertically_causal_context;
412        let mut jobs = Vec::with_capacity(direct_sub_band_job_capacity(sub_band, storage));
413        for precinct in sub_band
414            .precincts
415            .clone()
416            .map(|idx| &storage.precincts[idx])
417        {
418            for code_block in precinct
419                .code_blocks
420                .clone()
421                .map(|idx| &storage.code_blocks[idx])
422            {
423                if !code_block_required_by_index(storage, sub_band_idx, code_block) {
424                    continue;
425                }
426                let actual_bitplanes = if header.strict {
427                    coded_bitplanes
428                        .checked_sub(code_block.missing_bit_planes)
429                        .ok_or(DecodingError::InvalidBitplaneCount)?
430                } else {
431                    coded_bitplanes.saturating_sub(code_block.missing_bit_planes)
432                };
433                let max_coding_passes = if actual_bitplanes == 0 {
434                    0
435                } else {
436                    1 + 3 * (actual_bitplanes - 1)
437                };
438                if code_block.number_of_coding_passes > max_coding_passes && header.strict {
439                    bail!(DecodingError::TooManyCodingPasses);
440                }
441                if code_block.number_of_coding_passes == 0 || actual_bitplanes == 0 {
442                    continue;
443                }
444
445                let combined = ht_block_decode::collect_code_block_data(code_block, storage)?;
446                jobs.push(HtOwnedCodeBlockBatchJob {
447                    output_x: code_block.rect.x0 - sub_band.rect.x0,
448                    output_y: code_block.rect.y0 - sub_band.rect.y0,
449                    data: combined.data,
450                    cleanup_length: combined.cleanup_length,
451                    refinement_length: combined.refinement_length,
452                    width: code_block.rect.width(),
453                    height: code_block.rect.height(),
454                    output_stride: sub_band.rect.width() as usize,
455                    missing_bit_planes: code_block.missing_bit_planes,
456                    number_of_coding_passes: code_block.number_of_coding_passes,
457                    num_bitplanes,
458                    roi_shift: component_info.roi_shift,
459                    stripe_causal,
460                    strict: header.strict,
461                    dequantization_step,
462                });
463            }
464        }
465
466        return Ok(Some(J2kDirectGrayscaleStep::HtSubBand(
467            HtOwnedSubBandPlan {
468                band_id,
469                rect: J2kRect::from(sub_band.rect),
470                width: sub_band.rect.width(),
471                height: sub_band.rect.height(),
472                jobs,
473            },
474        )));
475    }
476
477    let classic_job_sub_band_type = match sub_band.sub_band_type {
478        SubBandType::LowLow => J2kSubBandType::LowLow,
479        SubBandType::HighLow => J2kSubBandType::HighLow,
480        SubBandType::LowHigh => J2kSubBandType::LowHigh,
481        SubBandType::HighHigh => J2kSubBandType::HighHigh,
482    };
483    let classic_job_style = J2kCodeBlockStyle {
484        selective_arithmetic_coding_bypass: component_info
485            .coding_style
486            .parameters
487            .code_block_style
488            .selective_arithmetic_coding_bypass,
489        reset_context_probabilities: component_info
490            .coding_style
491            .parameters
492            .code_block_style
493            .reset_context_probabilities,
494        termination_on_each_pass: component_info
495            .coding_style
496            .parameters
497            .code_block_style
498            .termination_on_each_pass,
499        vertically_causal_context: component_info
500            .coding_style
501            .parameters
502            .code_block_style
503            .vertically_causal_context,
504        segmentation_symbols: component_info
505            .coding_style
506            .parameters
507            .code_block_style
508            .segmentation_symbols,
509    };
510
511    let mut jobs = Vec::with_capacity(direct_sub_band_job_capacity(sub_band, storage));
512    for precinct in sub_band
513        .precincts
514        .clone()
515        .map(|idx| &storage.precincts[idx])
516    {
517        for code_block in precinct
518            .code_blocks
519            .clone()
520            .map(|idx| &storage.code_blocks[idx])
521        {
522            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
523                continue;
524            }
525            let (combined_data, segments) = collect_classic_code_block_data(
526                code_block,
527                &component_info.coding_style.parameters.code_block_style,
528                storage,
529            )?;
530            jobs.push(J2kOwnedCodeBlockBatchJob {
531                output_x: code_block.rect.x0 - sub_band.rect.x0,
532                output_y: code_block.rect.y0 - sub_band.rect.y0,
533                data: combined_data,
534                segments,
535                width: code_block.rect.width(),
536                height: code_block.rect.height(),
537                output_stride: sub_band.rect.width() as usize,
538                missing_bit_planes: code_block.missing_bit_planes,
539                number_of_coding_passes: code_block.number_of_coding_passes,
540                total_bitplanes: num_bitplanes,
541                roi_shift: component_info.roi_shift,
542                sub_band_type: classic_job_sub_band_type,
543                style: classic_job_style,
544                strict: header.strict,
545                dequantization_step,
546            });
547        }
548    }
549
550    Ok(Some(J2kDirectGrayscaleStep::ClassicSubBand(
551        J2kOwnedSubBandPlan {
552            band_id,
553            rect: J2kRect::from(sub_band.rect),
554            width: sub_band.rect.width(),
555            height: sub_band.rect.height(),
556            jobs,
557        },
558    )))
559}
560
561fn direct_sub_band_job_capacity(sub_band: &SubBand, storage: &DecompositionStorage<'_>) -> usize {
562    sub_band
563        .precincts
564        .clone()
565        .map(|idx| storage.precincts[idx].code_blocks.len())
566        .sum()
567}
568
569fn collect_classic_code_block_data(
570    code_block: &CodeBlock,
571    style: &super::codestream::CodeBlockStyle,
572    storage: &DecompositionStorage<'_>,
573) -> Result<(Vec<u8>, Vec<J2kCodeBlockSegment>)> {
574    let mut combined_data = Vec::new();
575    let mut collected_segments = Vec::new();
576    let mut last_segment_idx = 0u8;
577    let mut segment_start_offset = 0usize;
578    let mut segment_start_coding_pass = 0u8;
579    let mut coding_passes = 0u8;
580    let is_normal_mode =
581        !style.selective_arithmetic_coding_bypass && !style.termination_on_each_pass;
582
583    for layer in &storage.layers[code_block.layers.start..code_block.layers.end] {
584        let Some(range) = layer.segments.clone() else {
585            continue;
586        };
587
588        for segment in &storage.segments[range] {
589            if segment.idx != last_segment_idx {
590                if segment.idx != last_segment_idx + 1 {
591                    bail!(DecodingError::CodeBlockDecodeFailure);
592                }
593                if coding_passes > segment_start_coding_pass
594                    || combined_data.len() > segment_start_offset
595                {
596                    let data_offset = u32::try_from(segment_start_offset)
597                        .map_err(|_| DecodingError::CodeBlockDecodeFailure)?;
598                    let data_length = u32::try_from(combined_data.len() - segment_start_offset)
599                        .map_err(|_| DecodingError::CodeBlockDecodeFailure)?;
600                    let use_arithmetic = if style.selective_arithmetic_coding_bypass {
601                        if segment_start_coding_pass <= 9 {
602                            true
603                        } else {
604                            segment_start_coding_pass.is_multiple_of(3)
605                        }
606                    } else {
607                        true
608                    };
609                    collected_segments.push(J2kCodeBlockSegment {
610                        data_offset,
611                        data_length,
612                        start_coding_pass: segment_start_coding_pass,
613                        end_coding_pass: coding_passes,
614                        use_arithmetic,
615                    });
616                }
617                segment_start_offset = combined_data.len();
618                segment_start_coding_pass = coding_passes;
619                last_segment_idx += 1;
620            }
621
622            combined_data.extend_from_slice(segment.data);
623            coding_passes = coding_passes.saturating_add(segment.coding_pases);
624        }
625    }
626
627    if coding_passes > segment_start_coding_pass || combined_data.len() > segment_start_offset {
628        let data_offset = u32::try_from(segment_start_offset)
629            .map_err(|_| DecodingError::CodeBlockDecodeFailure)?;
630        let data_length = u32::try_from(combined_data.len().saturating_sub(segment_start_offset))
631            .map_err(|_| DecodingError::CodeBlockDecodeFailure)?;
632        let use_arithmetic = if style.selective_arithmetic_coding_bypass {
633            if segment_start_coding_pass <= 9 {
634                true
635            } else {
636                segment_start_coding_pass.is_multiple_of(3)
637            }
638        } else {
639            true
640        };
641        collected_segments.push(J2kCodeBlockSegment {
642            data_offset,
643            data_length,
644            start_coding_pass: segment_start_coding_pass,
645            end_coding_pass: coding_passes,
646            use_arithmetic,
647        });
648    }
649
650    if is_normal_mode {
651        collected_segments.clear();
652        collected_segments.push(J2kCodeBlockSegment {
653            data_offset: 0,
654            data_length: u32::try_from(combined_data.len())
655                .map_err(|_| DecodingError::CodeBlockDecodeFailure)?,
656            start_coding_pass: 0,
657            end_coding_pass: coding_passes,
658            use_arithmetic: true,
659        });
660    }
661
662    if coding_passes != code_block.number_of_coding_passes {
663        bail!(DecodingError::CodeBlockDecodeFailure);
664    }
665
666    Ok((combined_data, collected_segments))
667}
668
/// A rectangular output region described by an origin and an extent.
///
/// NOTE(review): the coordinate space (presumably image samples) is not
/// visible from this definition; confirm against the callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct OutputRegion {
    pub(crate) x: u32,
    pub(crate) y: u32,
    pub(crate) width: u32,
    pub(crate) height: u32,
}

impl OutputRegion {
    /// Builds a region from an `(x, y, width, height)` tuple.
    pub(crate) fn from_tuple((x, y, width, height): (u32, u32, u32, u32)) -> Self {
        Self {
            x,
            y,
            width,
            height,
        }
    }

    /// Returns the extent of the region as `(width, height)`.
    fn dimensions(self) -> (u32, u32) {
        let Self { width, height, .. } = self;
        (width, height)
    }
}
692
693#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
694pub(crate) struct DecodeDebugCounters {
695    pub(crate) decoded_code_blocks: usize,
696    pub(crate) skipped_code_blocks: usize,
697    pub(crate) idwt_output_samples: usize,
698    pub(crate) ht_phase_stats: ht_block_decode::HtBlockDecodeStats,
699}
700
/// Policy controlling how much CPU parallelism the native JPEG 2000 decoder
/// may use.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CpuDecodeParallelism {
    /// Let the decode of a single tile parallelize over code blocks
    /// internally. This is the default.
    #[default]
    Auto,
    /// Decode code blocks serially, for callers that already run tiles in
    /// parallel themselves.
    Serial,
}
710
711/// A decoder context for decoding JPEG2000 images.
712pub struct DecoderContext<'a> {
713    pub(crate) tile_decode_context: TileDecodeContext,
714    pub(crate) storage: DecompositionStorage<'a>,
715    cpu_decode_parallelism: CpuDecodeParallelism,
716}
717
718impl Default for DecoderContext<'_> {
719    fn default() -> Self {
720        Self {
721            tile_decode_context: TileDecodeContext::default(),
722            storage: DecompositionStorage::default(),
723            cpu_decode_parallelism: CpuDecodeParallelism::Auto,
724        }
725    }
726}
727
728impl DecoderContext<'_> {
729    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) -> Result<()> {
730        self.tile_decode_context.reset(header, initial_tile)?;
731        self.storage.reset();
732        Ok(())
733    }
734
735    pub(crate) fn set_output_region(&mut self, output_region: Option<(u32, u32, u32, u32)>) {
736        self.tile_decode_context.output_region = output_region.map(OutputRegion::from_tuple);
737    }
738
739    /// Return the native CPU decode parallelism policy.
740    pub fn cpu_decode_parallelism(&self) -> CpuDecodeParallelism {
741        self.cpu_decode_parallelism
742    }
743
744    /// Set the native CPU decode parallelism policy.
745    pub fn set_cpu_decode_parallelism(&mut self, parallelism: CpuDecodeParallelism) {
746        self.cpu_decode_parallelism = parallelism;
747    }
748}
749
750fn decode_tile<'a, 'b>(
751    tile: &'b Tile<'a>,
752    header: &Header<'_>,
753    progression_iterator: Box<dyn Iterator<Item = ProgressionData> + '_>,
754    tile_ctx: &mut TileDecodeContext,
755    storage: &mut DecompositionStorage<'a>,
756    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
757    cpu_decode_parallelism: CpuDecodeParallelism,
758    profile_enabled: bool,
759    profile_timings: &mut DecodeProfileTimings,
760) -> Result<()> {
761    storage.reset();
762    storage.exact_integer_decode = tile_requires_exact_integer_decode(tile);
763    if storage.exact_integer_decode {
764        validate_exact_integer_decode_tile(tile)?;
765        if tile_ctx.output_region.is_some() {
766            bail!(DecodingError::UnsupportedFeature(
767                "25-38 bit region decode requires exact integer region IDWT support"
768            ));
769        }
770    }
771
772    // This is the method that orchestrates all steps.
773
774    // First, we build the decompositions, including their sub-bands, precincts
775    // and code blocks.
776    let stage_start = profile::profile_now(profile_enabled);
777    build::build(tile, storage)?;
778    if let Some(output_region) = tile_ctx.output_region {
779        storage.roi_plan = RoiPlan::build(tile, header, storage, output_region);
780        if storage.roi_plan.is_some() {
781            storage.coefficients.fill(0.0);
782            if storage.exact_integer_decode {
783                storage.coefficients_i64.fill(0);
784            }
785        }
786    }
787    profile_timings.build_us += profile::elapsed_us(stage_start);
788    // Next, we parse the layers/segments for each code block.
789    let stage_start = profile::profile_now(profile_enabled);
790    segment::parse(tile, progression_iterator, header, storage)?;
791    profile_timings.segment_us += profile::elapsed_us(stage_start);
792    // We then decode the bitplanes of each code block, yielding the
793    // (possibly dequantized) coefficients of each code block.
794    let stage_start = profile::profile_now(profile_enabled);
795    decode_component_tile_bit_planes(
796        tile,
797        tile_ctx,
798        storage,
799        header,
800        ht_decoder,
801        cpu_decode_parallelism,
802        profile_enabled,
803    )?;
804    profile_timings.codeblock_us += profile::elapsed_us(stage_start);
805
806    // Unlike before, we interleave the apply_idwt and store stages
807    // for each component tile so we can reuse allocations better.
808    for (idx, component_info) in header.component_infos.iter().enumerate() {
809        // Next, we apply the inverse discrete wavelet transform.
810        let stage_start = profile::profile_now(profile_enabled);
811        idwt::apply(
812            storage,
813            tile_ctx,
814            idx,
815            header,
816            component_info.wavelet_transform(),
817            ht_decoder,
818        )?;
819        profile_timings.idwt_us += profile::elapsed_us(stage_start);
820        // Finally, we store the raw samples for the tile area in the correct
821        // location. Note that in case we have MCT, we are not applying it yet.
822        // It will be applied in the very end once all tiles have been processed.
823        // The reason we do this is that applying MCT requires access to the
824        // data from _all_ components. If we didn't defer this until the end
825        // we would have to collect the IDWT outputs of all components before
826        // applying it. By not applying MCT here, we can get away with doing
827        // IDWT and store on a per-component basis. Thus, we only need to
828        // store one IDWT output at a time, allowing for better reuse of
829        // allocations.
830        let stage_start = profile::profile_now(profile_enabled);
831        store(tile, header, tile_ctx, component_info, idx, ht_decoder)?;
832        profile_timings.store_us += profile::elapsed_us(stage_start);
833    }
834
835    Ok(())
836}
837
838fn tile_requires_exact_integer_decode(tile: &Tile<'_>) -> bool {
839    tile.component_infos
840        .iter()
841        .any(ComponentInfo::requires_exact_integer_decode)
842}
843
844fn validate_exact_integer_decode_tile(tile: &Tile<'_>) -> Result<()> {
845    for component in &tile.component_infos {
846        if component.size_info.precision > 38 {
847            bail!(DecodingError::UnsupportedFeature(
848                "JPEG 2000 Part 1 component precision is limited to 38 bits"
849            ));
850        }
851        if component.wavelet_transform() != WaveletTransform::Reversible53 {
852            bail!(DecodingError::UnsupportedFeature(
853                "25-38 bit decode currently requires reversible 5/3 coding"
854            ));
855        }
856        if component.quantization_info.quantization_style != QuantizationStyle::NoQuantization {
857            bail!(DecodingError::UnsupportedFeature(
858                "25-38 bit decode currently requires reversible no-quantization coding"
859            ));
860        }
861    }
862    Ok(())
863}
864
/// Accumulated wall-clock time, in microseconds, spent in each decode stage.
///
/// All counters start at zero (`Default`) and are added to as the stages
/// run; they are reported by `emit_decode_profile_row`.
#[derive(Default)]
struct DecodeProfileTimings {
    /// Time attributed to parsing tiles.
    parse_tiles_us: u128,
    /// Time spent building decompositions (and the ROI plan, if any).
    build_us: u128,
    /// Time spent parsing layers/segments.
    segment_us: u128,
    /// Time spent decoding code-block bit planes.
    codeblock_us: u128,
    /// Time spent in the inverse discrete wavelet transform.
    idwt_us: u128,
    /// Time spent storing decoded samples.
    store_us: u128,
    /// Time attributed to the multi-component transform.
    mct_us: u128,
}
875
876#[cold]
877#[inline(never)]
878fn emit_decode_profile_row(
879    tile_ctx: &TileDecodeContext,
880    profile_timings: &DecodeProfileTimings,
881    total_start: Option<profile::ProfileInstant>,
882) {
883    profile::emit_profile_row(
884        "decode",
885        "cpu",
886        &[
887            ("parse_tiles_us", profile_timings.parse_tiles_us),
888            ("build_us", profile_timings.build_us),
889            ("segment_us", profile_timings.segment_us),
890            ("codeblock_us", profile_timings.codeblock_us),
891            ("ht_blocks", tile_ctx.debug_counters.ht_phase_stats.blocks),
892            (
893                "ht_refinement_blocks",
894                tile_ctx.debug_counters.ht_phase_stats.refinement_blocks,
895            ),
896            (
897                "ht_cleanup_bytes",
898                tile_ctx.debug_counters.ht_phase_stats.cleanup_bytes,
899            ),
900            (
901                "ht_refinement_bytes",
902                tile_ctx.debug_counters.ht_phase_stats.refinement_bytes,
903            ),
904            (
905                "ht_cleanup_us",
906                tile_ctx.debug_counters.ht_phase_stats.ht_cleanup_us,
907            ),
908            (
909                "ht_mag_sgn_us",
910                tile_ctx.debug_counters.ht_phase_stats.ht_mag_sgn_us,
911            ),
912            (
913                "ht_sigma_us",
914                tile_ctx.debug_counters.ht_phase_stats.ht_sigma_us,
915            ),
916            (
917                "ht_sigprop_us",
918                tile_ctx.debug_counters.ht_phase_stats.ht_sigprop_us,
919            ),
920            (
921                "ht_magref_us",
922                tile_ctx.debug_counters.ht_phase_stats.ht_magref_us,
923            ),
924            ("idwt_us", profile_timings.idwt_us),
925            ("store_us", profile_timings.store_us),
926            ("mct_us", profile_timings.mct_us),
927            ("total_us", profile::elapsed_us(total_start)),
928        ],
929    );
930}
931
932/// All decompositions for a single tile.
933#[derive(Clone)]
934pub(crate) struct TileDecompositions {
935    pub(crate) first_ll_sub_band: usize,
936    pub(crate) decompositions: Range<usize>,
937}
938
939impl TileDecompositions {
940    pub(crate) fn sub_band_iter(
941        &self,
942        resolution: u8,
943        decompositions: &[Decomposition],
944    ) -> SubBandIter {
945        let indices = if resolution == 0 {
946            [
947                self.first_ll_sub_band,
948                self.first_ll_sub_band,
949                self.first_ll_sub_band,
950            ]
951        } else {
952            decompositions[self.decompositions.clone()][resolution as usize - 1].sub_bands
953        };
954
955        SubBandIter {
956            next_idx: 0,
957            indices,
958            resolution,
959        }
960    }
961}
962
/// Iterator over the sub-band indices of one resolution level.
///
/// Resolution 0 yields a single index (`indices[0]`); every other
/// resolution yields all three entries of `indices` in order.
#[derive(Clone)]
pub(crate) struct SubBandIter {
    /// Resolution level this iterator was created for.
    resolution: u8,
    /// Position of the next candidate entry in `indices`.
    next_idx: usize,
    /// Sub-band indices; only the first one is meaningful for resolution 0.
    indices: [usize; 3],
}

impl Iterator for SubBandIter {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Number of entries this resolution exposes.
        let available = if self.resolution == 0 {
            1
        } else {
            self.indices.len()
        };

        // The cursor advances on every call, including after exhaustion.
        let cursor = self.next_idx;
        self.next_idx += 1;

        (cursor < available).then(|| self.indices[cursor])
    }
}
991
992/// A buffer so that we can reuse allocations for layers/code blocks/etc.
993/// across different tiles.
994#[derive(Default)]
995pub(crate) struct DecompositionStorage<'a> {
996    pub(crate) segments: Vec<Segment<'a>>,
997    pub(crate) layers: Vec<Layer>,
998    pub(crate) code_blocks: Vec<CodeBlock>,
999    pub(crate) precincts: Vec<Precinct>,
1000    pub(crate) tag_tree_nodes: Vec<TagNode>,
1001    pub(crate) coefficients: Vec<f32>,
1002    pub(crate) coefficients_i64: Vec<i64>,
1003    pub(crate) sub_bands: Vec<SubBand>,
1004    pub(crate) decompositions: Vec<Decomposition>,
1005    pub(crate) tile_decompositions: Vec<TileDecompositions>,
1006    pub(crate) roi_plan: Option<RoiPlan>,
1007    pub(crate) exact_integer_decode: bool,
1008}
1009
1010impl DecompositionStorage<'_> {
1011    pub(crate) fn reset(&mut self) {
1012        self.segments.clear();
1013        self.layers.clear();
1014        self.code_blocks.clear();
1015        // No need to clear the coefficients, as they will be resized
1016        // and then overridden.
1017        // self.coefficients.clear();
1018        self.precincts.clear();
1019        self.sub_bands.clear();
1020        self.decompositions.clear();
1021        self.tile_decompositions.clear();
1022        self.tag_tree_nodes.clear();
1023        self.roi_plan = None;
1024        self.exact_integer_decode = false;
1025    }
1026}
1027
1028/// A reusable context used during the decoding of a single tile.
1029///
1030/// Some of the fields are temporary in nature and reset after moving on to the
1031/// next tile, some contain global state.
1032#[derive(Default)]
1033pub(crate) struct TileDecodeContext {
1034    /// A reusable buffer for the IDWT output.
1035    pub(crate) idwt_output: IDWTOutput,
1036    /// A scratch buffer used during IDWT.
1037    pub(crate) idwt_scratch_buffer: Vec<f32>,
1038    /// A scratch buffer used during exact reversible integer IDWT.
1039    pub(crate) idwt_scratch_buffer_i64: Vec<i64>,
1040    /// A reusable context for decoding code blocks.
1041    pub(crate) bit_plane_decode_context: BitPlaneDecodeContext,
1042    /// Reusable buffers for decoding bitplanes.
1043    pub(crate) bit_plane_decode_buffers: BitPlaneDecodeBuffers,
1044    /// A reusable context for decoding HTJ2K code blocks.
1045    pub(crate) ht_block_decode_context: HtBlockDecodeContext,
1046    /// The raw, decoded samples for each channel.
1047    pub(crate) channel_data: Vec<ComponentData>,
1048    /// Optional output window for region-local decode storage.
1049    pub(crate) output_region: Option<OutputRegion>,
1050    /// Debug counters for tests and ROI instrumentation.
1051    pub(crate) debug_counters: DecodeDebugCounters,
1052}
1053
1054impl TileDecodeContext {
1055    /// Reset the context for processing a new image.
1056    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) -> Result<()> {
1057        // Bitplane decode context and buffers will be reset in the
1058        // corresponding methods. IDWT output and scratch buffer will be
1059        // overridden on demand, so those don't need to be reset either.
1060        self.channel_data.clear();
1061        self.debug_counters = DecodeDebugCounters::default();
1062
1063        let (output_width, output_height) =
1064            self.output_region.map(OutputRegion::dimensions).unwrap_or((
1065                header.size_data.image_width(),
1066                header.size_data.image_height(),
1067            ));
1068
1069        let sample_count = checked_decode_sample_count(output_width, output_height)?;
1070        checked_decode_byte_len3(
1071            sample_count,
1072            initial_tile.component_infos.len(),
1073            size_of::<f32>(),
1074        )?;
1075        let exact_integer_decode = initial_tile
1076            .component_infos
1077            .iter()
1078            .any(ComponentInfo::requires_exact_integer_decode);
1079        if exact_integer_decode {
1080            checked_decode_byte_len3(
1081                sample_count,
1082                initial_tile.component_infos.len(),
1083                size_of::<i64>(),
1084            )?;
1085        }
1086
1087        // Allocate per component here; the surrounding context reuses the
1088        // higher-level vectors while `SimdBuffer` owns its initialized storage.
1089        for info in &initial_tile.component_infos {
1090            self.channel_data.push(ComponentData {
1091                container: SimdBuffer::zeros(sample_count),
1092                integer_container: exact_integer_decode.then(|| vec![0; sample_count]),
1093                bit_depth: info.size_info.precision,
1094                signed: info.size_info.signed,
1095            });
1096        }
1097        Ok(())
1098    }
1099}
1100
1101pub(crate) fn decode_component_tile_bit_planes<'a>(
1102    tile: &Tile<'a>,
1103    tile_ctx: &mut TileDecodeContext,
1104    storage: &mut DecompositionStorage<'a>,
1105    header: &Header<'_>,
1106    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
1107    cpu_decode_parallelism: CpuDecodeParallelism,
1108    profile_enabled: bool,
1109) -> Result<()> {
1110    for (tile_decompositions_idx, component_info) in tile.component_infos.iter().enumerate() {
1111        // Only decode the resolution levels we actually care about.
1112        for resolution in
1113            0..component_info.num_resolution_levels() - header.skipped_resolution_levels
1114        {
1115            let tile_composition = &storage.tile_decompositions[tile_decompositions_idx];
1116            let sub_band_iter = tile_composition.sub_band_iter(resolution, &storage.decompositions);
1117
1118            for sub_band_idx in sub_band_iter {
1119                decode_sub_band_bitplanes(
1120                    sub_band_idx,
1121                    resolution,
1122                    component_info,
1123                    tile_ctx,
1124                    storage,
1125                    header,
1126                    ht_decoder,
1127                    cpu_decode_parallelism,
1128                    profile_enabled,
1129                )?;
1130            }
1131        }
1132    }
1133
1134    Ok(())
1135}
1136
1137fn decode_sub_band_bitplanes(
1138    sub_band_idx: usize,
1139    resolution: u8,
1140    component_info: &ComponentInfo,
1141    tile_ctx: &mut TileDecodeContext,
1142    storage: &mut DecompositionStorage<'_>,
1143    header: &Header<'_>,
1144    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
1145    cpu_decode_parallelism: CpuDecodeParallelism,
1146    profile_enabled: bool,
1147) -> Result<()> {
1148    let sub_band = storage.sub_bands[sub_band_idx].clone();
1149
1150    let dequantization_step = {
1151        if component_info.quantization_info.quantization_style == QuantizationStyle::NoQuantization
1152        {
1153            1.0
1154        } else {
1155            let (exponent, mantissa) =
1156                component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
1157
1158            let r_b = {
1159                let log_gain = match sub_band.sub_band_type {
1160                    SubBandType::LowLow => 0,
1161                    SubBandType::LowHigh => 1,
1162                    SubBandType::HighLow => 1,
1163                    SubBandType::HighHigh => 2,
1164                };
1165
1166                component_info.size_info.precision as u16 + log_gain
1167            };
1168
1169            crate::math::pow2i(r_b as i32 - exponent as i32) * (1.0 + (mantissa as f32) / 2048.0)
1170        }
1171    };
1172
1173    let num_bitplanes = {
1174        let (exponent, _) = component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
1175        // Equation (E-2)
1176        let num_bitplanes = (component_info.quantization_info.guard_bits as u16)
1177            .checked_add(exponent)
1178            .and_then(|x| x.checked_sub(1))
1179            .ok_or(DecodingError::InvalidBitplaneCount)?;
1180
1181        if num_bitplanes > MAX_BITPLANE_COUNT as u16 {
1182            bail!(DecodingError::TooManyBitplanes);
1183        }
1184
1185        num_bitplanes as u8
1186    };
1187
1188    if component_info
1189        .coding_style
1190        .parameters
1191        .code_block_style
1192        .uses_high_throughput_block_coding()
1193    {
1194        if storage.exact_integer_decode {
1195            decode_sub_band_ht_blocks_i64(
1196                sub_band_idx,
1197                &sub_band,
1198                component_info,
1199                tile_ctx,
1200                storage,
1201                header,
1202                num_bitplanes,
1203                profile_enabled,
1204            )?;
1205            return Ok(());
1206        }
1207        decode_sub_band_ht_blocks(
1208            sub_band_idx,
1209            &sub_band,
1210            component_info,
1211            tile_ctx,
1212            storage,
1213            header,
1214            ht_decoder,
1215            cpu_decode_parallelism,
1216            num_bitplanes,
1217            dequantization_step,
1218            profile_enabled,
1219        )?;
1220        return Ok(());
1221    }
1222
1223    let coded_bitplanes =
1224        add_roi_shift_to_bitplanes(num_bitplanes, component_info.roi_shift, MAX_BITPLANE_COUNT)?;
1225
1226    if storage.exact_integer_decode {
1227        decode_sub_band_classic_blocks_i64(
1228            sub_band_idx,
1229            &sub_band,
1230            component_info,
1231            tile_ctx,
1232            storage,
1233            header,
1234            coded_bitplanes,
1235        )?;
1236        return Ok(());
1237    }
1238
1239    let classic_job_sub_band_type = match sub_band.sub_band_type {
1240        SubBandType::LowLow => J2kSubBandType::LowLow,
1241        SubBandType::HighLow => J2kSubBandType::HighLow,
1242        SubBandType::LowHigh => J2kSubBandType::LowHigh,
1243        SubBandType::HighHigh => J2kSubBandType::HighHigh,
1244    };
1245    let classic_job_style = J2kCodeBlockStyle {
1246        selective_arithmetic_coding_bypass: component_info
1247            .coding_style
1248            .parameters
1249            .code_block_style
1250            .selective_arithmetic_coding_bypass,
1251        reset_context_probabilities: component_info
1252            .coding_style
1253            .parameters
1254            .code_block_style
1255            .reset_context_probabilities,
1256        termination_on_each_pass: component_info
1257            .coding_style
1258            .parameters
1259            .code_block_style
1260            .termination_on_each_pass,
1261        vertically_causal_context: component_info
1262            .coding_style
1263            .parameters
1264            .code_block_style
1265            .vertically_causal_context,
1266        segmentation_symbols: component_info
1267            .coding_style
1268            .parameters
1269            .code_block_style
1270            .segmentation_symbols,
1271    };
1272
1273    if let Some(ht_decoder) = ht_decoder.as_deref_mut() {
1274        let pending_blocks =
1275            collect_pending_classic_blocks(sub_band_idx, &sub_band, component_info, storage)?;
1276
1277        let batch_jobs: Vec<_> = pending_blocks
1278            .iter()
1279            .map(|pending| J2kCodeBlockBatchJob {
1280                output_x: pending.output_x,
1281                output_y: pending.output_y,
1282                code_block: J2kCodeBlockDecodeJob {
1283                    data: &pending.combined_data,
1284                    segments: &pending.segments,
1285                    width: pending.width,
1286                    height: pending.height,
1287                    output_stride: sub_band.rect.width() as usize,
1288                    missing_bit_planes: pending.missing_bit_planes,
1289                    number_of_coding_passes: pending.number_of_coding_passes,
1290                    total_bitplanes: num_bitplanes,
1291                    roi_shift: component_info.roi_shift,
1292                    sub_band_type: classic_job_sub_band_type,
1293                    style: classic_job_style,
1294                    strict: header.strict,
1295                    dequantization_step,
1296                },
1297            })
1298            .collect();
1299
1300        let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
1301        if ht_decoder.decode_j2k_sub_band(
1302            J2kSubBandDecodeJob {
1303                width: sub_band.rect.width(),
1304                height: sub_band.rect.height(),
1305                jobs: &batch_jobs,
1306            },
1307            base_store,
1308        )? {
1309            tile_ctx.debug_counters.decoded_code_blocks += batch_jobs.len();
1310            return Ok(());
1311        }
1312
1313        let output_stride = sub_band.rect.width() as usize;
1314        for job in batch_jobs {
1315            tile_ctx.debug_counters.decoded_code_blocks += 1;
1316            let base_idx = (job.output_y * sub_band.rect.width()) as usize + job.output_x as usize;
1317            let output_len = if job.code_block.height == 0 {
1318                0
1319            } else {
1320                output_stride
1321                    .checked_mul(job.code_block.height as usize - 1)
1322                    .and_then(|prefix| prefix.checked_add(job.code_block.width as usize))
1323                    .ok_or(DecodingError::CodeBlockDecodeFailure)?
1324            };
1325            let output_slice = &mut base_store[base_idx..base_idx + output_len];
1326            if ht_decoder.decode_j2k_code_block(job.code_block, output_slice)? {
1327                continue;
1328            }
1329            decode_j2k_code_block_scalar(job.code_block, output_slice)?;
1330        }
1331
1332        return Ok(());
1333    }
1334
1335    let code_block_count = count_classic_code_blocks(sub_band_idx, &sub_band, storage);
1336    if should_decode_classic_sub_band_in_parallel(cpu_decode_parallelism, code_block_count) {
1337        #[cfg(feature = "parallel")]
1338        {
1339            let pending_blocks =
1340                collect_pending_classic_blocks(sub_band_idx, &sub_band, component_info, storage)?;
1341            let decoded_blocks = decode_classic_sub_band_blocks_parallel(
1342                &pending_blocks,
1343                classic_job_sub_band_type,
1344                classic_job_style,
1345                header.strict,
1346                num_bitplanes,
1347                component_info.roi_shift,
1348                dequantization_step,
1349            )?;
1350            tile_ctx.debug_counters.decoded_code_blocks += decoded_blocks.len();
1351            copy_decoded_classic_blocks_to_sub_band(&decoded_blocks, &sub_band, storage)?;
1352            return Ok(());
1353        }
1354    }
1355
1356    for precinct in sub_band
1357        .precincts
1358        .clone()
1359        .map(|idx| &storage.precincts[idx])
1360    {
1361        for code_block in precinct
1362            .code_blocks
1363            .clone()
1364            .map(|idx| &storage.code_blocks[idx])
1365        {
1366            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
1367                tile_ctx.debug_counters.skipped_code_blocks += 1;
1368                continue;
1369            }
1370            tile_ctx.debug_counters.decoded_code_blocks += 1;
1371            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
1372            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
1373            let output_stride = sub_band.rect.width() as usize;
1374            let base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
1375
1376            bitplane::decode(
1377                code_block,
1378                sub_band.sub_band_type,
1379                coded_bitplanes,
1380                &component_info.coding_style.parameters.code_block_style,
1381                tile_ctx,
1382                storage,
1383                header.strict,
1384            )?;
1385
1386            let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
1387            let mut base_idx = base_idx;
1388
1389            for coefficients in tile_ctx.bit_plane_decode_context.coefficient_rows() {
1390                let out_row = &mut base_store[base_idx..];
1391
1392                for (output, coefficient) in out_row.iter_mut().zip(coefficients.iter().copied()) {
1393                    let coefficient = apply_roi_maxshift_inverse_i64(
1394                        coefficient.get_i64(),
1395                        component_info.roi_shift,
1396                    );
1397                    *output = coefficient as f32;
1398                    *output *= dequantization_step;
1399                }
1400
1401                base_idx += output_stride;
1402            }
1403        }
1404    }
1405
1406    Ok(())
1407}
1408
1409fn decode_sub_band_classic_blocks_i64(
1410    sub_band_idx: usize,
1411    sub_band: &SubBand,
1412    component_info: &ComponentInfo,
1413    tile_ctx: &mut TileDecodeContext,
1414    storage: &mut DecompositionStorage<'_>,
1415    header: &Header<'_>,
1416    coded_bitplanes: u8,
1417) -> Result<()> {
1418    for precinct in sub_band
1419        .precincts
1420        .clone()
1421        .map(|idx| &storage.precincts[idx])
1422    {
1423        for code_block in precinct
1424            .code_blocks
1425            .clone()
1426            .map(|idx| &storage.code_blocks[idx])
1427        {
1428            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
1429                tile_ctx.debug_counters.skipped_code_blocks += 1;
1430                continue;
1431            }
1432            tile_ctx.debug_counters.decoded_code_blocks += 1;
1433            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
1434            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
1435            let output_stride = sub_band.rect.width() as usize;
1436            let base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
1437
1438            bitplane::decode(
1439                code_block,
1440                sub_band.sub_band_type,
1441                coded_bitplanes,
1442                &component_info.coding_style.parameters.code_block_style,
1443                tile_ctx,
1444                storage,
1445                header.strict,
1446            )?;
1447
1448            let base_store = &mut storage.coefficients_i64[sub_band.coefficients.clone()];
1449            let mut base_idx = base_idx;
1450
1451            for coefficients in tile_ctx.bit_plane_decode_context.coefficient_rows() {
1452                let out_row = &mut base_store[base_idx..];
1453
1454                for (output, coefficient) in out_row.iter_mut().zip(coefficients.iter().copied()) {
1455                    *output = apply_roi_maxshift_inverse_i64(
1456                        coefficient.get_i64(),
1457                        component_info.roi_shift,
1458                    );
1459                }
1460
1461                base_idx += output_stride;
1462            }
1463        }
1464    }
1465
1466    Ok(())
1467}
1468
1469struct PendingHtBlock {
1470    combined: ht_block_decode::CombinedCodeBlockData,
1471    output_x: u32,
1472    output_y: u32,
1473    width: u32,
1474    height: u32,
1475    missing_bit_planes: u8,
1476    number_of_coding_passes: u8,
1477}
1478
1479struct PendingClassicBlock {
1480    combined_data: Vec<u8>,
1481    segments: Vec<J2kCodeBlockSegment>,
1482    output_x: u32,
1483    output_y: u32,
1484    width: u32,
1485    height: u32,
1486    missing_bit_planes: u8,
1487    number_of_coding_passes: u8,
1488}
1489
/// Result of decoding one classic code block on the parallel path, ready to
/// be copied into its sub-band.
#[cfg(feature = "parallel")]
struct DecodedClassicBlock {
    /// Horizontal offset of the block relative to its sub-band.
    output_x: u32,
    /// Vertical offset of the block relative to its sub-band.
    output_y: u32,
    /// Width of the code block.
    width: u32,
    /// Height of the code block.
    height: u32,
    /// Decoded coefficients of the block.
    coefficients: Vec<f32>,
}
1498
/// Result of decoding one HT code block; same shape as
/// `DecodedClassicBlock`.
// NOTE(review): presumably produced by the parallel HT decode path; its
// producer is not visible here, so confirm.
#[cfg(feature = "parallel")]
struct DecodedHtBlock {
    /// Horizontal offset of the block (presumably relative to its sub-band).
    output_x: u32,
    /// Vertical offset of the block (presumably relative to its sub-band).
    output_y: u32,
    /// Width of the code block.
    width: u32,
    /// Height of the code block.
    height: u32,
    /// Decoded coefficients of the block.
    coefficients: Vec<f32>,
}
1507
1508fn count_classic_code_blocks(
1509    sub_band_idx: usize,
1510    sub_band: &SubBand,
1511    storage: &DecompositionStorage<'_>,
1512) -> usize {
1513    sub_band
1514        .precincts
1515        .clone()
1516        .map(|idx| &storage.precincts[idx])
1517        .map(|precinct| {
1518            precinct
1519                .code_blocks
1520                .clone()
1521                .filter(|idx| {
1522                    let code_block = &storage.code_blocks[*idx];
1523                    code_block_required_by_index(storage, sub_band_idx, code_block)
1524                })
1525                .count()
1526        })
1527        .sum()
1528}
1529
1530fn code_block_required_by_index(
1531    storage: &DecompositionStorage<'_>,
1532    sub_band_idx: usize,
1533    code_block: &CodeBlock,
1534) -> bool {
1535    storage
1536        .roi_plan
1537        .as_ref()
1538        .is_none_or(|plan| plan.code_block_required(sub_band_idx, code_block.rect))
1539}
1540
1541fn collect_pending_classic_blocks(
1542    sub_band_idx: usize,
1543    sub_band: &SubBand,
1544    component_info: &ComponentInfo,
1545    storage: &DecompositionStorage<'_>,
1546) -> Result<Vec<PendingClassicBlock>> {
1547    let mut pending_blocks =
1548        Vec::with_capacity(count_classic_code_blocks(sub_band_idx, sub_band, storage));
1549    for precinct in sub_band
1550        .precincts
1551        .clone()
1552        .map(|idx| &storage.precincts[idx])
1553    {
1554        for code_block in precinct
1555            .code_blocks
1556            .clone()
1557            .map(|idx| &storage.code_blocks[idx])
1558        {
1559            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
1560                continue;
1561            }
1562            let (combined_data, segments) = collect_classic_code_block_data(
1563                code_block,
1564                &component_info.coding_style.parameters.code_block_style,
1565                storage,
1566            )?;
1567            pending_blocks.push(PendingClassicBlock {
1568                combined_data,
1569                segments,
1570                output_x: code_block.rect.x0 - sub_band.rect.x0,
1571                output_y: code_block.rect.y0 - sub_band.rect.y0,
1572                width: code_block.rect.width(),
1573                height: code_block.rect.height(),
1574                missing_bit_planes: code_block.missing_bit_planes,
1575                number_of_coding_passes: code_block.number_of_coding_passes,
1576            });
1577        }
1578    }
1579    Ok(pending_blocks)
1580}
1581
1582fn count_ht_code_blocks(
1583    sub_band_idx: usize,
1584    sub_band: &SubBand,
1585    storage: &DecompositionStorage<'_>,
1586) -> usize {
1587    sub_band
1588        .precincts
1589        .clone()
1590        .map(|idx| &storage.precincts[idx])
1591        .map(|precinct| {
1592            precinct
1593                .code_blocks
1594                .clone()
1595                .filter(|idx| {
1596                    let code_block = &storage.code_blocks[*idx];
1597                    code_block_required_by_index(storage, sub_band_idx, code_block)
1598                        && code_block.number_of_coding_passes > 0
1599                })
1600                .count()
1601        })
1602        .sum()
1603}
1604
/// Gathers the data of every HT code block of `sub_band` that actually has
/// something to decode.
///
/// A block is left out when it is not required, has no coding passes, or
/// has no bit planes left after subtracting its missing bit planes from the
/// coded bit-plane count (`num_bitplanes` plus the ROI shift, capped at 31).
///
/// # Errors
///
/// In strict mode, returns `DecodingError::InvalidBitplaneCount` when a
/// block reports more missing bit planes than are coded, and
/// `DecodingError::TooManyCodingPasses` when it has more coding passes than
/// its bit planes allow. In lenient mode both conditions are tolerated.
/// Errors from `add_roi_shift_to_bitplanes` and
/// `ht_block_decode::collect_code_block_data` are propagated.
#[cfg(feature = "parallel")]
fn collect_pending_ht_blocks(
    sub_band_idx: usize,
    sub_band: &SubBand,
    storage: &DecompositionStorage<'_>,
    header: &Header<'_>,
    num_bitplanes: u8,
    roi_shift: u8,
) -> Result<Vec<PendingHtBlock>> {
    let coded_bitplanes = add_roi_shift_to_bitplanes(num_bitplanes, roi_shift, 31)?;
    let capacity = count_ht_code_blocks(sub_band_idx, sub_band, storage);
    let mut pending_blocks = Vec::with_capacity(capacity);

    let candidate_blocks = sub_band
        .precincts
        .clone()
        .flat_map(|precinct_idx| storage.precincts[precinct_idx].code_blocks.clone())
        .map(|block_idx| &storage.code_blocks[block_idx]);

    for code_block in candidate_blocks {
        if !code_block_required_by_index(storage, sub_band_idx, code_block) {
            continue;
        }

        // Strict mode rejects an impossible bit-plane count; lenient mode
        // clamps it to zero.
        let remaining = coded_bitplanes.checked_sub(code_block.missing_bit_planes);
        let actual_bitplanes = if header.strict {
            remaining.ok_or(DecodingError::InvalidBitplaneCount)?
        } else {
            remaining.unwrap_or(0)
        };

        let max_coding_passes = match actual_bitplanes {
            0 => 0,
            planes => 1 + 3 * (planes - 1),
        };
        if header.strict && code_block.number_of_coding_passes > max_coding_passes {
            bail!(DecodingError::TooManyCodingPasses);
        }

        // Nothing to decode for this block.
        if code_block.number_of_coding_passes == 0 || actual_bitplanes == 0 {
            continue;
        }

        let block_rect = &code_block.rect;
        pending_blocks.push(PendingHtBlock {
            combined: ht_block_decode::collect_code_block_data(code_block, storage)?,
            output_x: block_rect.x0 - sub_band.rect.x0,
            output_y: block_rect.y0 - sub_band.rect.y0,
            width: block_rect.width(),
            height: block_rect.height(),
            missing_bit_planes: code_block.missing_bit_planes,
            number_of_coding_passes: code_block.number_of_coding_passes,
        });
    }

    Ok(pending_blocks)
}
1662
1663pub(crate) fn should_decode_classic_sub_band_in_parallel(
1664    parallelism: CpuDecodeParallelism,
1665    code_block_count: usize,
1666) -> bool {
1667    cfg!(feature = "parallel") && parallelism == CpuDecodeParallelism::Auto && code_block_count >= 4
1668}
1669
1670pub(crate) fn should_decode_ht_sub_band_in_parallel(
1671    parallelism: CpuDecodeParallelism,
1672    code_block_count: usize,
1673) -> bool {
1674    cfg!(feature = "parallel") && parallelism == CpuDecodeParallelism::Auto && code_block_count >= 4
1675}
1676
/// Decodes the queued classic code blocks of one sub-band on the rayon pool.
///
/// Each block is decoded into its own tightly packed buffer (row stride equal
/// to the block width) by `decode_j2k_code_block_scalar`. The per-block
/// results are first gathered into a `Vec` and only then folded into a single
/// `Result`, so every block is processed before any error is reported.
///
/// # Errors
///
/// Returns `DecodingError::CodeBlockDecodeFailure` if a block's
/// `width * height` overflows `usize`, or whatever error the block decoder
/// reports.
#[cfg(feature = "parallel")]
fn decode_classic_sub_band_blocks_parallel(
    pending_blocks: &[PendingClassicBlock],
    sub_band_type: J2kSubBandType,
    style: J2kCodeBlockStyle,
    strict: bool,
    total_bitplanes: u8,
    roi_shift: u8,
    dequantization_step: f32,
) -> Result<Vec<DecodedClassicBlock>> {
    use rayon::prelude::*;

    let decode_one = |block: &PendingClassicBlock| -> Result<DecodedClassicBlock> {
        let row_len = block.width as usize;
        let sample_count = row_len
            .checked_mul(block.height as usize)
            .ok_or(DecodingError::CodeBlockDecodeFailure)?;
        let mut coefficients = vec![0.0; sample_count];

        let job = J2kCodeBlockDecodeJob {
            data: &block.combined_data,
            segments: &block.segments,
            width: block.width,
            height: block.height,
            output_stride: row_len,
            missing_bit_planes: block.missing_bit_planes,
            number_of_coding_passes: block.number_of_coding_passes,
            total_bitplanes,
            roi_shift,
            sub_band_type,
            style,
            strict,
            dequantization_step,
        };
        decode_j2k_code_block_scalar(job, &mut coefficients)?;

        Ok(DecodedClassicBlock {
            output_x: block.output_x,
            output_y: block.output_y,
            width: block.width,
            height: block.height,
            coefficients,
        })
    };

    let outcomes: Vec<Result<DecodedClassicBlock>> =
        pending_blocks.par_iter().map(decode_one).collect();
    outcomes.into_iter().collect()
}
1727
/// Writes decoded classic code blocks into the coefficient plane of
/// `sub_band`.
///
/// Every block is stored as tightly packed rows and is copied row by row to
/// the position given by its `output_x` / `output_y` offsets.
///
/// # Errors
///
/// Returns `DecodingError::CodeBlockDecodeFailure` if a block does not fit
/// inside the sub-band rectangle or if an index computation overflows.
#[cfg(feature = "parallel")]
fn copy_decoded_classic_blocks_to_sub_band(
    decoded_blocks: &[DecodedClassicBlock],
    sub_band: &SubBand,
    storage: &mut DecompositionStorage<'_>,
) -> Result<()> {
    let stride = sub_band.rect.width() as usize;
    let destination = &mut storage.coefficients[sub_band.coefficients.clone()];

    for block in decoded_blocks {
        // The vertical check only runs once the horizontal one has passed.
        let fits_in_sub_band = block
            .output_x
            .checked_add(block.width)
            .is_some_and(|right| right <= sub_band.rect.width())
            && block
                .output_y
                .checked_add(block.height)
                .is_some_and(|bottom| bottom <= sub_band.rect.height());
        if !fits_in_sub_band {
            bail!(DecodingError::CodeBlockDecodeFailure);
        }

        let row_len = block.width as usize;
        let row_count = block.height as usize;
        for row in 0..row_count {
            let dst_start = (block.output_y as usize + row)
                .checked_mul(stride)
                .and_then(|offset| offset.checked_add(block.output_x as usize))
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let dst_end = dst_start
                .checked_add(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let src_start = row
                .checked_mul(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let src_end = src_start
                .checked_add(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;

            let source_row = &block.coefficients[src_start..src_end];
            destination[dst_start..dst_end].copy_from_slice(source_row);
        }
    }

    Ok(())
}
1768
/// Decodes the queued HT code blocks of one sub-band on the rayon pool.
///
/// Each block gets a fresh `HtCodeBlockDecodeWorkspace` and is decoded into
/// its own tightly packed buffer (row stride equal to the block width). The
/// per-block results are first gathered into a `Vec` and only then folded
/// into a single `Result`, so every block is processed before any error is
/// reported.
///
/// # Errors
///
/// Returns `DecodingError::CodeBlockDecodeFailure` if a block's
/// `width * height` overflows `usize`, or whatever error the block decoder
/// reports.
#[cfg(feature = "parallel")]
fn decode_ht_sub_band_blocks_parallel(
    pending_blocks: &[PendingHtBlock],
    strict: bool,
    num_bitplanes: u8,
    roi_shift: u8,
    stripe_causal: bool,
    dequantization_step: f32,
) -> Result<Vec<DecodedHtBlock>> {
    use rayon::prelude::*;

    let decode_one = |block: &PendingHtBlock| -> Result<DecodedHtBlock> {
        let row_len = block.width as usize;
        let sample_count = row_len
            .checked_mul(block.height as usize)
            .ok_or(DecodingError::CodeBlockDecodeFailure)?;
        let mut coefficients = vec![0.0; sample_count];
        let mut workspace = HtCodeBlockDecodeWorkspace::default();

        let job = HtCodeBlockDecodeJob {
            data: &block.combined.data,
            cleanup_length: block.combined.cleanup_length,
            refinement_length: block.combined.refinement_length,
            width: block.width,
            height: block.height,
            output_stride: row_len,
            missing_bit_planes: block.missing_bit_planes,
            number_of_coding_passes: block.number_of_coding_passes,
            num_bitplanes,
            roi_shift,
            stripe_causal,
            strict,
            dequantization_step,
        };
        decode_ht_code_block_scalar_with_workspace(job, &mut coefficients, &mut workspace)?;

        Ok(DecodedHtBlock {
            output_x: block.output_x,
            output_y: block.output_y,
            width: block.width,
            height: block.height,
            coefficients,
        })
    };

    let outcomes: Vec<Result<DecodedHtBlock>> =
        pending_blocks.par_iter().map(decode_one).collect();
    outcomes.into_iter().collect()
}
1820
/// Writes decoded HT code blocks into the coefficient plane of `sub_band`.
///
/// Every block is stored as tightly packed rows and is copied row by row to
/// the position given by its `output_x` / `output_y` offsets.
///
/// # Errors
///
/// Returns `DecodingError::CodeBlockDecodeFailure` if a block does not fit
/// inside the sub-band rectangle or if an index computation overflows.
#[cfg(feature = "parallel")]
fn copy_decoded_ht_blocks_to_sub_band(
    decoded_blocks: &[DecodedHtBlock],
    sub_band: &SubBand,
    storage: &mut DecompositionStorage<'_>,
) -> Result<()> {
    let stride = sub_band.rect.width() as usize;
    let destination = &mut storage.coefficients[sub_band.coefficients.clone()];

    for block in decoded_blocks {
        // The vertical check only runs once the horizontal one has passed.
        let fits_in_sub_band = block
            .output_x
            .checked_add(block.width)
            .is_some_and(|right| right <= sub_band.rect.width())
            && block
                .output_y
                .checked_add(block.height)
                .is_some_and(|bottom| bottom <= sub_band.rect.height());
        if !fits_in_sub_band {
            bail!(DecodingError::CodeBlockDecodeFailure);
        }

        let row_len = block.width as usize;
        let row_count = block.height as usize;
        for row in 0..row_count {
            let dst_start = (block.output_y as usize + row)
                .checked_mul(stride)
                .and_then(|offset| offset.checked_add(block.output_x as usize))
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let dst_end = dst_start
                .checked_add(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let src_start = row
                .checked_mul(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let src_end = src_start
                .checked_add(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;

            let source_row = &block.coefficients[src_start..src_end];
            destination[dst_start..dst_end].copy_from_slice(source_row);
        }
    }

    Ok(())
}
1861
1862fn decode_sub_band_ht_blocks_i64(
1863    sub_band_idx: usize,
1864    sub_band: &SubBand,
1865    component_info: &ComponentInfo,
1866    tile_ctx: &mut TileDecodeContext,
1867    storage: &mut DecompositionStorage<'_>,
1868    header: &Header<'_>,
1869    num_bitplanes: u8,
1870    profile_enabled: bool,
1871) -> Result<()> {
1872    let coded_bitplanes = add_roi_shift_to_bitplanes(num_bitplanes, component_info.roi_shift, 31)?;
1873    let stripe_causal = component_info
1874        .coding_style
1875        .parameters
1876        .code_block_style
1877        .vertically_causal_context;
1878
1879    for precinct in sub_band
1880        .precincts
1881        .clone()
1882        .map(|idx| &storage.precincts[idx])
1883    {
1884        for code_block in precinct
1885            .code_blocks
1886            .clone()
1887            .map(|idx| &storage.code_blocks[idx])
1888        {
1889            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
1890                tile_ctx.debug_counters.skipped_code_blocks += 1;
1891                continue;
1892            }
1893
1894            let actual_bitplanes = if header.strict {
1895                coded_bitplanes
1896                    .checked_sub(code_block.missing_bit_planes)
1897                    .ok_or(DecodingError::InvalidBitplaneCount)?
1898            } else {
1899                coded_bitplanes.saturating_sub(code_block.missing_bit_planes)
1900            };
1901            let max_coding_passes = if actual_bitplanes == 0 {
1902                0
1903            } else {
1904                1 + 3 * (actual_bitplanes - 1)
1905            };
1906            if code_block.number_of_coding_passes > max_coding_passes && header.strict {
1907                bail!(DecodingError::TooManyCodingPasses);
1908            }
1909            if code_block.number_of_coding_passes == 0 || actual_bitplanes == 0 {
1910                continue;
1911            }
1912
1913            tile_ctx.debug_counters.decoded_code_blocks += 1;
1914            ht_block_decode::decode_with_stats(
1915                code_block,
1916                coded_bitplanes,
1917                stripe_causal,
1918                &mut tile_ctx.ht_block_decode_context,
1919                storage,
1920                header.strict,
1921                Some(&mut tile_ctx.debug_counters.ht_phase_stats),
1922                profile_enabled,
1923            )?;
1924
1925            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
1926            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
1927            let base_store = &mut storage.coefficients_i64[sub_band.coefficients.clone()];
1928            let mut base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
1929            let output_stride = sub_band.rect.width() as usize;
1930
1931            for coefficients in tile_ctx.ht_block_decode_context.coefficient_rows() {
1932                let out_row = &mut base_store[base_idx..];
1933
1934                for (output, coefficient) in out_row.iter_mut().zip(coefficients.iter().copied()) {
1935                    let coefficient =
1936                        ht_block_decode::coefficient_to_i32(coefficient, coded_bitplanes) as i64;
1937                    *output = apply_roi_maxshift_inverse_i64(coefficient, component_info.roi_shift);
1938                }
1939
1940                base_idx += output_stride;
1941            }
1942        }
1943    }
1944
1945    Ok(())
1946}
1947
1948fn decode_sub_band_ht_blocks(
1949    sub_band_idx: usize,
1950    sub_band: &SubBand,
1951    component_info: &ComponentInfo,
1952    tile_ctx: &mut TileDecodeContext,
1953    storage: &mut DecompositionStorage<'_>,
1954    header: &Header<'_>,
1955    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
1956    cpu_decode_parallelism: CpuDecodeParallelism,
1957    num_bitplanes: u8,
1958    dequantization_step: f32,
1959    profile_enabled: bool,
1960) -> Result<()> {
1961    let coded_bitplanes = add_roi_shift_to_bitplanes(num_bitplanes, component_info.roi_shift, 31)?;
1962    let stripe_causal = component_info
1963        .coding_style
1964        .parameters
1965        .code_block_style
1966        .vertically_causal_context;
1967
1968    if let Some(ht_decoder) = ht_decoder.as_deref_mut() {
1969        let mut pending_blocks = Vec::new();
1970        for precinct in sub_band
1971            .precincts
1972            .clone()
1973            .map(|idx| &storage.precincts[idx])
1974        {
1975            for code_block in precinct
1976                .code_blocks
1977                .clone()
1978                .map(|idx| &storage.code_blocks[idx])
1979            {
1980                if !code_block_required_by_index(storage, sub_band_idx, code_block) {
1981                    continue;
1982                }
1983                let actual_bitplanes = if header.strict {
1984                    coded_bitplanes
1985                        .checked_sub(code_block.missing_bit_planes)
1986                        .ok_or(DecodingError::InvalidBitplaneCount)?
1987                } else {
1988                    coded_bitplanes.saturating_sub(code_block.missing_bit_planes)
1989                };
1990                let max_coding_passes = if actual_bitplanes == 0 {
1991                    0
1992                } else {
1993                    1 + 3 * (actual_bitplanes - 1)
1994                };
1995                if code_block.number_of_coding_passes > max_coding_passes && header.strict {
1996                    bail!(DecodingError::TooManyCodingPasses);
1997                }
1998                if code_block.number_of_coding_passes == 0 || actual_bitplanes == 0 {
1999                    continue;
2000                }
2001
2002                pending_blocks.push(PendingHtBlock {
2003                    combined: ht_block_decode::collect_code_block_data(code_block, storage)?,
2004                    output_x: code_block.rect.x0 - sub_band.rect.x0,
2005                    output_y: code_block.rect.y0 - sub_band.rect.y0,
2006                    width: code_block.rect.width(),
2007                    height: code_block.rect.height(),
2008                    missing_bit_planes: code_block.missing_bit_planes,
2009                    number_of_coding_passes: code_block.number_of_coding_passes,
2010                });
2011            }
2012        }
2013
2014        let batch_jobs: Vec<_> = pending_blocks
2015            .iter()
2016            .map(|pending| HtCodeBlockBatchJob {
2017                output_x: pending.output_x,
2018                output_y: pending.output_y,
2019                code_block: HtCodeBlockDecodeJob {
2020                    data: &pending.combined.data,
2021                    cleanup_length: pending.combined.cleanup_length,
2022                    refinement_length: pending.combined.refinement_length,
2023                    width: pending.width,
2024                    height: pending.height,
2025                    output_stride: sub_band.rect.width() as usize,
2026                    missing_bit_planes: pending.missing_bit_planes,
2027                    number_of_coding_passes: pending.number_of_coding_passes,
2028                    num_bitplanes,
2029                    roi_shift: component_info.roi_shift,
2030                    stripe_causal,
2031                    strict: header.strict,
2032                    dequantization_step,
2033                },
2034            })
2035            .collect();
2036
2037        let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
2038        if ht_decoder.decode_sub_band(
2039            HtSubBandDecodeJob {
2040                width: sub_band.rect.width(),
2041                height: sub_band.rect.height(),
2042                jobs: &batch_jobs,
2043            },
2044            base_store,
2045        )? {
2046            tile_ctx.debug_counters.decoded_code_blocks += batch_jobs.len();
2047            return Ok(());
2048        }
2049
2050        let output_stride = sub_band.rect.width() as usize;
2051        for job in batch_jobs {
2052            tile_ctx.debug_counters.decoded_code_blocks += 1;
2053            let base_idx = (job.output_y * sub_band.rect.width()) as usize + job.output_x as usize;
2054            let output_len = if job.code_block.height == 0 {
2055                0
2056            } else {
2057                output_stride * (job.code_block.height as usize - 1) + job.code_block.width as usize
2058            };
2059            ht_decoder.decode_code_block(
2060                job.code_block,
2061                &mut base_store[base_idx..base_idx + output_len],
2062            )?;
2063        }
2064
2065        return Ok(());
2066    }
2067
2068    let code_block_count = count_ht_code_blocks(sub_band_idx, sub_band, storage);
2069    if !profile_enabled
2070        && should_decode_ht_sub_band_in_parallel(cpu_decode_parallelism, code_block_count)
2071    {
2072        #[cfg(feature = "parallel")]
2073        {
2074            let pending_blocks = collect_pending_ht_blocks(
2075                sub_band_idx,
2076                sub_band,
2077                storage,
2078                header,
2079                num_bitplanes,
2080                component_info.roi_shift,
2081            )?;
2082            let decoded_blocks = decode_ht_sub_band_blocks_parallel(
2083                &pending_blocks,
2084                header.strict,
2085                num_bitplanes,
2086                component_info.roi_shift,
2087                stripe_causal,
2088                dequantization_step,
2089            )?;
2090            tile_ctx.debug_counters.decoded_code_blocks += decoded_blocks.len();
2091            copy_decoded_ht_blocks_to_sub_band(&decoded_blocks, sub_band, storage)?;
2092            return Ok(());
2093        }
2094    }
2095
2096    for precinct in sub_band
2097        .precincts
2098        .clone()
2099        .map(|idx| &storage.precincts[idx])
2100    {
2101        for code_block in precinct
2102            .code_blocks
2103            .clone()
2104            .map(|idx| &storage.code_blocks[idx])
2105        {
2106            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
2107                tile_ctx.debug_counters.skipped_code_blocks += 1;
2108                continue;
2109            }
2110            tile_ctx.debug_counters.decoded_code_blocks += 1;
2111            ht_block_decode::decode_with_stats(
2112                code_block,
2113                coded_bitplanes,
2114                stripe_causal,
2115                &mut tile_ctx.ht_block_decode_context,
2116                storage,
2117                header.strict,
2118                Some(&mut tile_ctx.debug_counters.ht_phase_stats),
2119                profile_enabled,
2120            )?;
2121
2122            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
2123            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
2124            let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
2125            let mut base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
2126            let output_stride = sub_band.rect.width() as usize;
2127
2128            for coefficients in tile_ctx.ht_block_decode_context.coefficient_rows() {
2129                let out_row = &mut base_store[base_idx..];
2130
2131                for (output, coefficient) in out_row.iter_mut().zip(coefficients.iter().copied()) {
2132                    let coefficient =
2133                        ht_block_decode::coefficient_to_i32(coefficient, coded_bitplanes);
2134                    let coefficient =
2135                        apply_roi_maxshift_inverse_i32(coefficient, component_info.roi_shift);
2136                    *output = coefficient as f32;
2137                    *output *= dequantization_step;
2138                }
2139
2140                base_idx += output_stride;
2141            }
2142        }
2143    }
2144
2145    Ok(())
2146}
2147
2148fn apply_sign_shift(tile_ctx: &mut TileDecodeContext, component_infos: &[ComponentInfo]) {
2149    for (channel_data, component_info) in
2150        tile_ctx.channel_data.iter_mut().zip(component_infos.iter())
2151    {
2152        if let Some(samples) = channel_data.integer_container.as_mut() {
2153            let addend = component_unsigned_level_shift_i64(component_info);
2154            for sample in samples {
2155                *sample += addend;
2156            }
2157        } else {
2158            let addend = component_unsigned_level_shift(component_info);
2159            for sample in channel_data.container.deref_mut() {
2160                *sample += addend;
2161            }
2162        }
2163    }
2164}
2165
2166fn store<'a>(
2167    tile: &'a Tile<'a>,
2168    header: &Header<'_>,
2169    tile_ctx: &mut TileDecodeContext,
2170    component_info: &ComponentInfo,
2171    component_idx: usize,
2172    backend: &mut Option<&mut dyn HtCodeBlockDecoder>,
2173) -> Result<()> {
2174    if tile_ctx.channel_data[component_idx]
2175        .integer_container
2176        .is_some()
2177    {
2178        return store_i64(tile, header, tile_ctx, component_info, component_idx);
2179    }
2180
2181    let channel_data = &mut tile_ctx.channel_data[component_idx];
2182    let idwt_output = &mut tile_ctx.idwt_output;
2183
2184    let component_tile = ComponentTile::new(tile, component_info);
2185    let resolution_tile = ResolutionTile::new(
2186        component_tile,
2187        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
2188    );
2189
2190    let sign_shift = if tile.mct {
2191        0.0
2192    } else {
2193        component_unsigned_level_shift(component_info)
2194    };
2195
2196    let (scale_x, scale_y) = (
2197        component_info.size_info.horizontal_resolution,
2198        component_info.size_info.vertical_resolution,
2199    );
2200
2201    let (image_x_offset, image_y_offset) = (
2202        header.size_data.image_area_x_offset,
2203        header.size_data.image_area_y_offset,
2204    );
2205
2206    if let Some(output_region) = tile_ctx.output_region {
2207        store_region(
2208            tile,
2209            header,
2210            tile_ctx,
2211            component_info,
2212            component_idx,
2213            output_region,
2214            backend,
2215            sign_shift,
2216        )?;
2217        return Ok(());
2218    }
2219
2220    if scale_x == 1 && scale_y == 1 {
2221        let source_x = image_x_offset.saturating_sub(idwt_output.rect.x0);
2222        let source_y = image_y_offset.saturating_sub(idwt_output.rect.y0);
2223        let copy_width = resolution_tile
2224            .rect
2225            .width()
2226            .min(idwt_output.rect.width().saturating_sub(source_x));
2227        let copy_height = resolution_tile
2228            .rect
2229            .height()
2230            .min(idwt_output.rect.height().saturating_sub(source_y));
2231        let output_x = resolution_tile.rect.x0.saturating_sub(image_x_offset);
2232        let output_y = resolution_tile.rect.y0.saturating_sub(image_y_offset);
2233
2234        let handled = if let Some(backend) = backend.as_deref_mut() {
2235            copy_width > 0
2236                && copy_height > 0
2237                && backend.decode_store_component(J2kStoreComponentJob {
2238                    input: &idwt_output.coefficients,
2239                    input_width: idwt_output.rect.width(),
2240                    source_x,
2241                    source_y,
2242                    copy_width,
2243                    copy_height,
2244                    output: &mut channel_data.container,
2245                    output_width: header.size_data.image_width(),
2246                    output_x,
2247                    output_y,
2248                    addend: sign_shift,
2249                })?
2250        } else {
2251            false
2252        };
2253
2254        if handled {
2255            return Ok(());
2256        }
2257
2258        // If no sub-sampling, use a fast path where we copy rows of coefficients
2259        // at once.
2260
2261        // The rect of the IDWT output corresponds to the rect of the highest
2262        // decomposition level of the tile, which is usually not 1:1 aligned
2263        // with the actual tile rectangle. We also need to account for the
2264        // offset of the reference grid.
2265
2266        let skip_x = image_x_offset.saturating_sub(idwt_output.rect.x0);
2267        let skip_y = image_y_offset.saturating_sub(idwt_output.rect.y0);
2268
2269        if sign_shift != 0.0 {
2270            for sample in idwt_output.coefficients.iter_mut() {
2271                *sample += sign_shift;
2272            }
2273        }
2274
2275        let input_row_iter = idwt_output
2276            .coefficients
2277            .chunks_exact(idwt_output.rect.width() as usize)
2278            .skip(skip_y as usize)
2279            .take(idwt_output.rect.height() as usize);
2280
2281        let output_row_iter = channel_data
2282            .container
2283            .chunks_exact_mut(header.size_data.image_width() as usize)
2284            .skip(resolution_tile.rect.y0.saturating_sub(image_y_offset) as usize);
2285
2286        for (input_row, output_row) in input_row_iter.zip(output_row_iter) {
2287            let input_row = &input_row[skip_x as usize..];
2288            let output_row = &mut output_row
2289                [resolution_tile.rect.x0.saturating_sub(image_x_offset) as usize..]
2290                [..input_row.len()];
2291
2292            output_row.copy_from_slice(input_row);
2293        }
2294    } else {
2295        if sign_shift != 0.0 {
2296            for sample in idwt_output.coefficients.iter_mut() {
2297                *sample += sign_shift;
2298            }
2299        }
2300        let image_width = header.size_data.image_width();
2301        let image_height = header.size_data.image_height();
2302
2303        let x_shrink_factor = header.size_data.x_shrink_factor;
2304        let y_shrink_factor = header.size_data.y_shrink_factor;
2305
2306        let x_offset = header
2307            .size_data
2308            .image_area_x_offset
2309            .div_ceil(x_shrink_factor);
2310        let y_offset = header
2311            .size_data
2312            .image_area_y_offset
2313            .div_ceil(y_shrink_factor);
2314
2315        // Otherwise, copy sample by sample.
2316        for y in resolution_tile.rect.y0..resolution_tile.rect.y1 {
2317            let relative_y = (y - component_tile.rect.y0) as usize;
2318            let reference_grid_y = (scale_y as u32 * y) / y_shrink_factor;
2319
2320            for x in resolution_tile.rect.x0..resolution_tile.rect.x1 {
2321                let relative_x = (x - component_tile.rect.x0) as usize;
2322                let reference_grid_x = (scale_x as u32 * x) / x_shrink_factor;
2323
2324                let sample = idwt_output.coefficients
2325                    [relative_y * idwt_output.rect.width() as usize + relative_x];
2326
2327                for x_position in u32::max(reference_grid_x, x_offset)
2328                    ..u32::min(reference_grid_x + scale_x as u32, image_width + x_offset)
2329                {
2330                    for y_position in u32::max(reference_grid_y, y_offset)
2331                        ..u32::min(reference_grid_y + scale_y as u32, image_height + y_offset)
2332                    {
2333                        let pos = (y_position - y_offset) as usize * image_width as usize
2334                            + (x_position - x_offset) as usize;
2335
2336                        channel_data.container[pos] = sample;
2337                    }
2338                }
2339            }
2340        }
2341    }
2342
2343    Ok(())
2344}
2345
2346fn store_i64<'a>(
2347    tile: &'a Tile<'a>,
2348    header: &Header<'_>,
2349    tile_ctx: &mut TileDecodeContext,
2350    component_info: &ComponentInfo,
2351    component_idx: usize,
2352) -> Result<()> {
2353    if tile_ctx.output_region.is_some() {
2354        bail!(DecodingError::UnsupportedFeature(
2355            "25-38 bit region decode requires exact integer region IDWT support"
2356        ));
2357    }
2358
2359    let channel_data = &mut tile_ctx.channel_data[component_idx];
2360    let idwt_output = &mut tile_ctx.idwt_output;
2361    let output = channel_data
2362        .integer_container
2363        .as_mut()
2364        .ok_or(DecodingError::CodeBlockDecodeFailure)?;
2365
2366    let component_tile = ComponentTile::new(tile, component_info);
2367    let resolution_tile = ResolutionTile::new(
2368        component_tile,
2369        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
2370    );
2371
2372    let sign_shift = if tile.mct {
2373        0
2374    } else {
2375        component_unsigned_level_shift_i64(component_info)
2376    };
2377
2378    let (scale_x, scale_y) = (
2379        component_info.size_info.horizontal_resolution,
2380        component_info.size_info.vertical_resolution,
2381    );
2382
2383    let (image_x_offset, image_y_offset) = (
2384        header.size_data.image_area_x_offset,
2385        header.size_data.image_area_y_offset,
2386    );
2387
2388    if scale_x == 1 && scale_y == 1 {
2389        let source_x = image_x_offset.saturating_sub(idwt_output.rect.x0);
2390        let source_y = image_y_offset.saturating_sub(idwt_output.rect.y0);
2391        let copy_width = resolution_tile
2392            .rect
2393            .width()
2394            .min(idwt_output.rect.width().saturating_sub(source_x));
2395        let copy_height = resolution_tile
2396            .rect
2397            .height()
2398            .min(idwt_output.rect.height().saturating_sub(source_y));
2399        let output_x = resolution_tile.rect.x0.saturating_sub(image_x_offset);
2400        let output_y = resolution_tile.rect.y0.saturating_sub(image_y_offset);
2401        let input_width = idwt_output.rect.width() as usize;
2402        let image_width = header.size_data.image_width() as usize;
2403        let copy_width = copy_width as usize;
2404
2405        for row in 0..copy_height as usize {
2406            let src_start = (source_y as usize + row)
2407                .checked_mul(input_width)
2408                .and_then(|offset| offset.checked_add(source_x as usize))
2409                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
2410            let dst_start = (output_y as usize + row)
2411                .checked_mul(image_width)
2412                .and_then(|offset| offset.checked_add(output_x as usize))
2413                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
2414            let src = &idwt_output.coefficients_i64[src_start..src_start + copy_width];
2415            let dst = &mut output[dst_start..dst_start + copy_width];
2416            if sign_shift == 0 {
2417                dst.copy_from_slice(src);
2418            } else {
2419                for (dst, src) in dst.iter_mut().zip(src.iter().copied()) {
2420                    *dst = src + sign_shift;
2421                }
2422            }
2423        }
2424    } else {
2425        let image_width = header.size_data.image_width();
2426        let image_height = header.size_data.image_height();
2427
2428        let x_shrink_factor = header.size_data.x_shrink_factor;
2429        let y_shrink_factor = header.size_data.y_shrink_factor;
2430
2431        let x_offset = header
2432            .size_data
2433            .image_area_x_offset
2434            .div_ceil(x_shrink_factor);
2435        let y_offset = header
2436            .size_data
2437            .image_area_y_offset
2438            .div_ceil(y_shrink_factor);
2439
2440        for y in resolution_tile.rect.y0..resolution_tile.rect.y1 {
2441            let relative_y = (y - component_tile.rect.y0) as usize;
2442            let reference_grid_y = (scale_y as u32 * y) / y_shrink_factor;
2443
2444            for x in resolution_tile.rect.x0..resolution_tile.rect.x1 {
2445                let relative_x = (x - component_tile.rect.x0) as usize;
2446                let reference_grid_x = (scale_x as u32 * x) / x_shrink_factor;
2447
2448                let sample = idwt_output.coefficients_i64
2449                    [relative_y * idwt_output.rect.width() as usize + relative_x]
2450                    + sign_shift;
2451
2452                for x_position in u32::max(reference_grid_x, x_offset)
2453                    ..u32::min(reference_grid_x + scale_x as u32, image_width + x_offset)
2454                {
2455                    for y_position in u32::max(reference_grid_y, y_offset)
2456                        ..u32::min(reference_grid_y + scale_y as u32, image_height + y_offset)
2457                    {
2458                        let pos = (y_position - y_offset) as usize * image_width as usize
2459                            + (x_position - x_offset) as usize;
2460
2461                        output[pos] = sample;
2462                    }
2463                }
2464            }
2465        }
2466    }
2467
2468    Ok(())
2469}
2470
2471fn component_unsigned_level_shift(component_info: &ComponentInfo) -> f32 {
2472    if component_info.size_info.signed {
2473        0.0
2474    } else {
2475        (1_u64 << (component_info.size_info.precision - 1)) as f32
2476    }
2477}
2478
2479fn component_unsigned_level_shift_i64(component_info: &ComponentInfo) -> i64 {
2480    if component_info.size_info.signed {
2481        0
2482    } else {
2483        1_i64 << (component_info.size_info.precision - 1)
2484    }
2485}
2486
2487fn store_region<'a>(
2488    tile: &'a Tile<'a>,
2489    header: &Header<'_>,
2490    tile_ctx: &mut TileDecodeContext,
2491    component_info: &ComponentInfo,
2492    component_idx: usize,
2493    output_region: OutputRegion,
2494    backend: &mut Option<&mut dyn HtCodeBlockDecoder>,
2495    sign_shift: f32,
2496) -> Result<()> {
2497    let channel_data = &mut tile_ctx.channel_data[component_idx];
2498    let idwt_output = &mut tile_ctx.idwt_output;
2499
2500    let component_tile = ComponentTile::new(tile, component_info);
2501    let resolution_tile = ResolutionTile::new(
2502        component_tile,
2503        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
2504    );
2505
2506    let (scale_x, scale_y) = (
2507        component_info.size_info.horizontal_resolution,
2508        component_info.size_info.vertical_resolution,
2509    );
2510    let image_width = header.size_data.image_width();
2511    let image_height = header.size_data.image_height();
2512    let x_shrink_factor = header.size_data.x_shrink_factor;
2513    let y_shrink_factor = header.size_data.y_shrink_factor;
2514    let x_offset = header
2515        .size_data
2516        .image_area_x_offset
2517        .div_ceil(x_shrink_factor);
2518    let y_offset = header
2519        .size_data
2520        .image_area_y_offset
2521        .div_ceil(y_shrink_factor);
2522    let region_x1 = output_region.x + output_region.width;
2523    let region_y1 = output_region.y + output_region.height;
2524    let output_width = output_region.width as usize;
2525
2526    if scale_x == 1 && scale_y == 1 {
2527        let region_rect_x0 = output_region.x + x_offset;
2528        let region_rect_y0 = output_region.y + y_offset;
2529        let region_rect_x1 = region_x1 + x_offset;
2530        let region_rect_y1 = region_y1 + y_offset;
2531        let copy_x0 = idwt_output
2532            .rect
2533            .x0
2534            .max(resolution_tile.rect.x0)
2535            .max(region_rect_x0);
2536        let copy_y0 = idwt_output
2537            .rect
2538            .y0
2539            .max(resolution_tile.rect.y0)
2540            .max(region_rect_y0);
2541        let copy_x1 = idwt_output
2542            .rect
2543            .x1
2544            .min(resolution_tile.rect.x1)
2545            .min(region_rect_x1);
2546        let copy_y1 = idwt_output
2547            .rect
2548            .y1
2549            .min(resolution_tile.rect.y1)
2550            .min(region_rect_y1);
2551
2552        let handled = if let Some(backend) = backend.as_deref_mut() {
2553            copy_x0 < copy_x1
2554                && copy_y0 < copy_y1
2555                && backend.decode_store_component(J2kStoreComponentJob {
2556                    input: &idwt_output.coefficients,
2557                    input_width: idwt_output.rect.width(),
2558                    source_x: copy_x0 - idwt_output.rect.x0,
2559                    source_y: copy_y0 - idwt_output.rect.y0,
2560                    copy_width: copy_x1 - copy_x0,
2561                    copy_height: copy_y1 - copy_y0,
2562                    output: &mut channel_data.container,
2563                    output_width: output_region.width,
2564                    output_x: copy_x0 - region_rect_x0,
2565                    output_y: copy_y0 - region_rect_y0,
2566                    addend: sign_shift,
2567                })?
2568        } else {
2569            false
2570        };
2571
2572        if handled {
2573            return Ok(());
2574        }
2575
2576        if sign_shift != 0.0 {
2577            for sample in idwt_output.coefficients.iter_mut() {
2578                *sample += sign_shift;
2579            }
2580        }
2581
2582        if copy_x0 < copy_x1 && copy_y0 < copy_y1 {
2583            let input_width = idwt_output.rect.width() as usize;
2584            let copy_width = (copy_x1 - copy_x0) as usize;
2585            for y in copy_y0..copy_y1 {
2586                let src_start = (y - idwt_output.rect.y0) as usize * input_width
2587                    + (copy_x0 - idwt_output.rect.x0) as usize;
2588                let dst_start = (y - region_rect_y0) as usize * output_width
2589                    + (copy_x0 - region_rect_x0) as usize;
2590                channel_data.container[dst_start..dst_start + copy_width]
2591                    .copy_from_slice(&idwt_output.coefficients[src_start..src_start + copy_width]);
2592            }
2593        }
2594
2595        return Ok(());
2596    }
2597
2598    if sign_shift != 0.0 {
2599        for sample in idwt_output.coefficients.iter_mut() {
2600            *sample += sign_shift;
2601        }
2602    }
2603
2604    for y in resolution_tile.rect.y0..resolution_tile.rect.y1 {
2605        let relative_y = (y - component_tile.rect.y0) as usize;
2606        let reference_grid_y = (scale_y as u32 * y) / y_shrink_factor;
2607
2608        for x in resolution_tile.rect.x0..resolution_tile.rect.x1 {
2609            let relative_x = (x - component_tile.rect.x0) as usize;
2610            let reference_grid_x = (scale_x as u32 * x) / x_shrink_factor;
2611
2612            let sample = idwt_output.coefficients
2613                [relative_y * idwt_output.rect.width() as usize + relative_x];
2614
2615            for x_position in u32::max(reference_grid_x, x_offset)
2616                ..u32::min(reference_grid_x + scale_x as u32, image_width + x_offset)
2617            {
2618                let image_x = x_position - x_offset;
2619                if image_x < output_region.x || image_x >= region_x1 {
2620                    continue;
2621                }
2622
2623                for y_position in u32::max(reference_grid_y, y_offset)
2624                    ..u32::min(reference_grid_y + scale_y as u32, image_height + y_offset)
2625                {
2626                    let image_y = y_position - y_offset;
2627                    if image_y < output_region.y || image_y >= region_y1 {
2628                        continue;
2629                    }
2630
2631                    let pos = (image_y - output_region.y) as usize * output_width
2632                        + (image_x - output_region.x) as usize;
2633                    channel_data.container[pos] = sample;
2634                }
2635            }
2636        }
2637    }
2638
2639    Ok(())
2640}
2641
#[cfg(test)]
mod tests {
    use super::{collect_classic_code_block_data, CodeBlock, DecompositionStorage, Layer, Segment};
    use crate::error::DecodingError;
    use crate::j2c::codestream::CodeBlockStyle;
    use crate::j2c::rect::IntRect;
    use alloc::vec;

    /// Classic (non-HT) code-block style with only per-pass termination enabled.
    fn classic_test_style() -> CodeBlockStyle {
        CodeBlockStyle {
            selective_arithmetic_coding_bypass: false,
            reset_context_probabilities: false,
            termination_on_each_pass: true,
            vertically_causal_context: false,
            segmentation_symbols: false,
            high_throughput_block_coding: false,
        }
    }

    /// A 1x1 code block that spans the single layer `0..1` and declares three
    /// coding passes.
    fn classic_test_code_block() -> CodeBlock {
        CodeBlock {
            rect: IntRect::from_xywh(0, 0, 1, 1),
            x_idx: 0,
            y_idx: 0,
            layers: 0..1,
            has_been_included: true,
            missing_bit_planes: 0,
            number_of_coding_passes: 3,
            l_block: 3,
            non_empty_layer_count: 1,
        }
    }

    #[test]
    fn collect_classic_code_block_data_preserves_zero_length_segments() {
        let mut storage = DecompositionStorage::default();
        storage.layers.push(Layer {
            segments: Some(0..3),
        });

        // The middle segment carries no data but must still be reported.
        let fixture = [
            Segment {
                idx: 0,
                coding_pases: 1,
                data_length: 1,
                data: &[0xAA],
            },
            Segment {
                idx: 1,
                coding_pases: 1,
                data_length: 0,
                data: &[],
            },
            Segment {
                idx: 2,
                coding_pases: 1,
                data_length: 1,
                data: &[0xBB],
            },
        ];
        for segment in fixture {
            storage.segments.push(segment);
        }

        let block = classic_test_code_block();
        let style = classic_test_style();
        let (combined_data, segments) = collect_classic_code_block_data(&block, &style, &storage)
            .expect("collect classic segments");

        assert_eq!(combined_data, vec![0xAA, 0xBB]);

        // (data_offset, data_length, start_coding_pass, end_coding_pass)
        let expected = [(0, 1, 0, 1), (1, 0, 1, 2), (1, 1, 2, 3)];
        assert_eq!(segments.len(), 3);
        for (index, &(offset, length, first_pass, last_pass)) in expected.iter().enumerate() {
            assert_eq!(segments[index].data_offset, offset);
            assert_eq!(segments[index].data_length, length);
            assert_eq!(segments[index].start_coding_pass, first_pass);
            assert_eq!(segments[index].end_coding_pass, last_pass);
        }
    }

    #[test]
    fn collect_classic_code_block_data_rejects_non_contiguous_segment_indices() {
        let mut storage = DecompositionStorage::default();
        storage.layers.push(Layer {
            segments: Some(0..2),
        });

        // Segment index 1 is missing on purpose: 0 is followed directly by 2.
        let fixture = [
            Segment {
                idx: 0,
                coding_pases: 1,
                data_length: 1,
                data: &[0xAA],
            },
            Segment {
                idx: 2,
                coding_pases: 2,
                data_length: 1,
                data: &[0xBB],
            },
        ];
        for segment in fixture {
            storage.segments.push(segment);
        }

        let block = classic_test_code_block();
        let style = classic_test_style();
        let error = collect_classic_code_block_data(&block, &style, &storage)
            .expect_err("non-contiguous segment indices must fail");

        assert_eq!(error, DecodingError::CodeBlockDecodeFailure.into());
    }

    #[test]
    fn auto_cpu_parallelism_enables_ht_sub_band_parallel_branch() {
        let parallel =
            super::should_decode_ht_sub_band_in_parallel(super::CpuDecodeParallelism::Auto, 16);
        assert!(parallel);
    }

    #[test]
    fn serial_cpu_parallelism_disables_ht_sub_band_parallel_branch() {
        let parallel =
            super::should_decode_ht_sub_band_in_parallel(super::CpuDecodeParallelism::Serial, 16);
        assert!(!parallel);
    }
}