Skip to main content

j2k_native/j2c/
decode.rs

1//! Decoding JPEG2000 code streams.
2//!
3//! This is the "core" module of the crate that orchestrates all
4//! stages in such a way that a given codestream is decoded into its
5//! component channels.
6
7use alloc::boxed::Box;
8use alloc::vec;
9use alloc::vec::Vec;
10
11use super::bitplane::{BitPlaneDecodeBuffers, BitPlaneDecodeContext};
12use super::build::{CodeBlock, Decomposition, Layer, Precinct, Segment, SubBand, SubBandType};
13use super::codestream::{ComponentInfo, Header, QuantizationStyle};
14use super::ht_block_decode::{self, HtBlockDecodeContext};
15use super::idwt::IDWTOutput;
16use super::progression::{progression_iterator, ProgressionData};
17use super::roi::RoiPlan;
18use super::tag_tree::TagNode;
19use super::tile::{ComponentTile, ResolutionTile, Tile};
20use super::{bitplane, build, idwt, mct, segment, tile, ComponentData};
21use crate::error::{bail, ColorError, DecodingError, Result, TileError};
22use crate::j2c::segment::MAX_BITPLANE_COUNT;
23use crate::math::SimdBuffer;
24use crate::profile;
25use crate::reader::BitReader;
26use crate::{
27    add_roi_shift_to_bitplanes, apply_roi_maxshift_inverse_i32, checked_decode_byte_len3,
28    checked_decode_sample_count, decode_j2k_code_block_scalar, HtCodeBlockBatchJob,
29    HtCodeBlockDecodeJob, HtCodeBlockDecoder, HtOwnedCodeBlockBatchJob, HtOwnedSubBandPlan,
30    HtSubBandDecodeJob, J2kCodeBlockBatchJob, J2kCodeBlockDecodeJob, J2kCodeBlockSegment,
31    J2kCodeBlockStyle, J2kDirectBandId, J2kDirectColorPlan, J2kDirectGrayscalePlan,
32    J2kDirectGrayscaleStep, J2kDirectIdwtStep, J2kDirectStoreStep, J2kOwnedCodeBlockBatchJob,
33    J2kOwnedSubBandPlan, J2kRect, J2kStoreComponentJob, J2kSubBandDecodeJob, J2kSubBandType,
34    J2kWaveletTransform,
35};
36#[cfg(feature = "parallel")]
37use crate::{decode_ht_code_block_scalar_with_workspace, HtCodeBlockDecodeWorkspace};
38use core::mem::size_of;
39use core::ops::{DerefMut, Range};
40
41pub(crate) fn decode<'a>(
42    data: &'a [u8],
43    header: &Header<'a>,
44    ctx: &mut DecoderContext<'a>,
45    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
46) -> Result<()> {
47    let mut reader = BitReader::new(data);
48    let profile_enabled = profile::profile_stages_enabled();
49    let total_start = profile::profile_now(profile_enabled);
50    let mut profile_timings = DecodeProfileTimings::default();
51    let stage_start = profile::profile_now(profile_enabled);
52    let tiles = tile::parse(&mut reader, header)?;
53    profile_timings.parse_tiles_us += profile::elapsed_us(stage_start);
54
55    if tiles.is_empty() {
56        bail!(TileError::Invalid);
57    }
58
59    ctx.reset(header, &tiles[0])?;
60    let cpu_decode_parallelism = ctx.cpu_decode_parallelism;
61    let (tile_ctx, storage) = (&mut ctx.tile_decode_context, &mut ctx.storage);
62
63    for tile in &tiles {
64        ltrace!(
65            "tile {} rect [{},{} {}x{}]",
66            tile.idx,
67            tile.rect.x0,
68            tile.rect.y0,
69            tile.rect.width(),
70            tile.rect.height(),
71        );
72
73        decode_tile(
74            tile,
75            header,
76            progression_iterator(tile)?,
77            tile_ctx,
78            storage,
79            ht_decoder,
80            cpu_decode_parallelism,
81            profile_enabled,
82            &mut profile_timings,
83        )?;
84    }
85
86    // Note that this assumes that either all tiles have MCT or none of them.
87    // In theory, only some could have it... But hopefully no such cursed
88    // images exist!
89    if tiles[0].mct {
90        let stage_start = profile::profile_now(profile_enabled);
91        mct::apply_inverse(tile_ctx, &tiles[0].component_infos, header, ht_decoder)?;
92        apply_sign_shift(tile_ctx, &header.component_infos);
93        profile_timings.mct_us += profile::elapsed_us(stage_start);
94    }
95
96    if profile_enabled {
97        emit_decode_profile_row(tile_ctx, &profile_timings, total_start);
98    }
99
100    Ok(())
101}
102
103pub(crate) fn build_direct_grayscale_plan<'a>(
104    data: &'a [u8],
105    header: &Header<'a>,
106    ctx: &mut DecoderContext<'a>,
107) -> Result<J2kDirectGrayscalePlan> {
108    let mut reader = BitReader::new(data);
109    let tiles = tile::parse(&mut reader, header)?;
110
111    if tiles.len() != 1 {
112        bail!(DecodingError::UnsupportedFeature(
113            "direct grayscale plan only supports single-tile codestreams"
114        ));
115    }
116
117    let tile = &tiles[0];
118    if tile.component_infos.len() != 1 {
119        bail!(DecodingError::UnsupportedFeature(
120            "direct grayscale plan only supports single-component codestreams"
121        ));
122    }
123    ctx.tile_decode_context.channel_data.clear();
124    ctx.storage.reset();
125
126    build::build(tile, &mut ctx.storage)?;
127    if let Some(output_region) = ctx.tile_decode_context.output_region {
128        ctx.storage.roi_plan = RoiPlan::build(tile, header, &ctx.storage, output_region);
129    }
130
131    segment::parse(tile, progression_iterator(tile)?, header, &mut ctx.storage)?;
132
133    let component_info = &tile.component_infos[0];
134    build_component_plan_from_storage(
135        tile,
136        header,
137        &ctx.storage,
138        0,
139        (1_u32 << (component_info.size_info.precision - 1)) as f32,
140    )
141}
142
143pub(crate) fn build_direct_color_plan<'a>(
144    data: &'a [u8],
145    header: &Header<'a>,
146    ctx: &mut DecoderContext<'a>,
147) -> Result<J2kDirectColorPlan> {
148    let mut reader = BitReader::new(data);
149    let tiles = tile::parse(&mut reader, header)?;
150
151    if tiles.len() != 1 {
152        bail!(DecodingError::UnsupportedFeature(
153            "direct color plan only supports single-tile codestreams"
154        ));
155    }
156
157    let tile = &tiles[0];
158    if tile.component_infos.len() != 3 {
159        bail!(DecodingError::UnsupportedFeature(
160            "direct color plan only supports three-component RGB codestreams"
161        ));
162    }
163    let transform = tile.component_infos[0].wavelet_transform();
164    if tile.mct
165        && (transform != tile.component_infos[1].wavelet_transform()
166            || transform != tile.component_infos[2].wavelet_transform())
167    {
168        bail!(ColorError::Mct);
169    }
170
171    ctx.tile_decode_context.channel_data.clear();
172    ctx.storage.reset();
173
174    build::build(tile, &mut ctx.storage)?;
175    if let Some(output_region) = ctx.tile_decode_context.output_region {
176        ctx.storage.roi_plan = RoiPlan::build(tile, header, &ctx.storage, output_region);
177    }
178
179    segment::parse(tile, progression_iterator(tile)?, header, &mut ctx.storage)?;
180
181    let mut bit_depths = [0_u8; 3];
182    let mut component_plans = Vec::with_capacity(3);
183    for (component_idx, bit_depth) in bit_depths.iter_mut().enumerate() {
184        let component_info = &tile.component_infos[component_idx];
185        *bit_depth = component_info.size_info.precision;
186        let addend = if tile.mct {
187            0.0
188        } else {
189            (1_u32 << (component_info.size_info.precision - 1)) as f32
190        };
191        component_plans.push(build_component_plan_from_storage(
192            tile,
193            header,
194            &ctx.storage,
195            component_idx,
196            addend,
197        )?);
198    }
199
200    Ok(J2kDirectColorPlan {
201        dimensions: (
202            header.size_data.image_width(),
203            header.size_data.image_height(),
204        ),
205        bit_depths,
206        mct: tile.mct,
207        transform: J2kWaveletTransform::from(transform),
208        component_plans,
209    })
210}
211
212fn build_component_plan_from_storage(
213    tile: &Tile<'_>,
214    header: &Header<'_>,
215    storage: &DecompositionStorage<'_>,
216    component_idx: usize,
217    store_addend: f32,
218) -> Result<J2kDirectGrayscalePlan> {
219    let component_info =
220        tile.component_infos
221            .get(component_idx)
222            .ok_or(DecodingError::UnsupportedFeature(
223                "direct component plan index is out of range",
224            ))?;
225    if component_info.size_info.horizontal_resolution != 1
226        || component_info.size_info.vertical_resolution != 1
227    {
228        bail!(DecodingError::UnsupportedFeature(
229            "direct component plan only supports unit-sampled components"
230        ));
231    }
232
233    let tile_decompositions =
234        storage
235            .tile_decompositions
236            .get(component_idx)
237            .ok_or(DecodingError::UnsupportedFeature(
238                "direct component decomposition index is out of range",
239            ))?;
240    let decompositions = &storage.decompositions[tile_decompositions.decompositions.clone()];
241    let active_decomposition_count = decompositions
242        .len()
243        .saturating_sub(header.skipped_resolution_levels as usize);
244    let sub_band_step_count = (0..component_info.num_resolution_levels()
245        - header.skipped_resolution_levels)
246        .map(|resolution| {
247            tile_decompositions
248                .sub_band_iter(resolution, &storage.decompositions)
249                .count()
250        })
251        .sum::<usize>();
252    let mut steps =
253        Vec::with_capacity(sub_band_step_count + active_decomposition_count.saturating_add(1));
254    let mut next_band_id: J2kDirectBandId = 0;
255    let mut sub_band_ids = vec![None; storage.sub_bands.len()];
256
257    for resolution in 0..component_info.num_resolution_levels() - header.skipped_resolution_levels {
258        let sub_band_iter = tile_decompositions.sub_band_iter(resolution, &storage.decompositions);
259        for sub_band_idx in sub_band_iter {
260            if let Some(step) = build_grayscale_sub_band_step(
261                &storage.sub_bands[sub_band_idx],
262                sub_band_idx,
263                next_band_id,
264                resolution,
265                component_info,
266                storage,
267                header,
268            )? {
269                sub_band_ids[sub_band_idx] = Some(next_band_id);
270                next_band_id = next_band_id
271                    .checked_add(1)
272                    .ok_or(DecodingError::CodeBlockDecodeFailure)?;
273                steps.push(step);
274            }
275        }
276    }
277
278    let mut current_ll_rect = storage.sub_bands[tile_decompositions.first_ll_sub_band].rect;
279    let mut current_ll_band_id = sub_band_ids[tile_decompositions.first_ll_sub_band]
280        .ok_or(DecodingError::CodeBlockDecodeFailure)?;
281    let decompositions = &decompositions[..active_decomposition_count];
282    for decomposition in decompositions {
283        let hl = &storage.sub_bands[decomposition.sub_bands[0]];
284        let lh = &storage.sub_bands[decomposition.sub_bands[1]];
285        let hh = &storage.sub_bands[decomposition.sub_bands[2]];
286        let output_band_id = next_band_id;
287        next_band_id = next_band_id
288            .checked_add(1)
289            .ok_or(DecodingError::CodeBlockDecodeFailure)?;
290        steps.push(J2kDirectGrayscaleStep::Idwt(J2kDirectIdwtStep {
291            output_band_id,
292            rect: J2kRect::from(decomposition.rect),
293            transform: J2kWaveletTransform::from(component_info.wavelet_transform()),
294            ll_band_id: current_ll_band_id,
295            ll: J2kRect::from(current_ll_rect),
296            hl_band_id: sub_band_ids[decomposition.sub_bands[0]]
297                .ok_or(DecodingError::CodeBlockDecodeFailure)?,
298            hl: J2kRect::from(hl.rect),
299            lh_band_id: sub_band_ids[decomposition.sub_bands[1]]
300                .ok_or(DecodingError::CodeBlockDecodeFailure)?,
301            lh: J2kRect::from(lh.rect),
302            hh_band_id: sub_band_ids[decomposition.sub_bands[2]]
303                .ok_or(DecodingError::CodeBlockDecodeFailure)?,
304            hh: J2kRect::from(hh.rect),
305        }));
306        current_ll_rect = decomposition.rect;
307        current_ll_band_id = output_band_id;
308    }
309
310    let component_tile = ComponentTile::new(tile, component_info);
311    let resolution_tile = ResolutionTile::new(
312        component_tile,
313        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
314    );
315    let image_x_offset = header.size_data.image_area_x_offset;
316    let image_y_offset = header.size_data.image_area_y_offset;
317    let source_x = image_x_offset.saturating_sub(current_ll_rect.x0);
318    let source_y = image_y_offset.saturating_sub(current_ll_rect.y0);
319    let copy_width = resolution_tile
320        .rect
321        .width()
322        .min(current_ll_rect.width().saturating_sub(source_x));
323    let copy_height = resolution_tile
324        .rect
325        .height()
326        .min(current_ll_rect.height().saturating_sub(source_y));
327    let output_x = resolution_tile.rect.x0.saturating_sub(image_x_offset);
328    let output_y = resolution_tile.rect.y0.saturating_sub(image_y_offset);
329    steps.push(J2kDirectGrayscaleStep::Store(J2kDirectStoreStep {
330        input_band_id: current_ll_band_id,
331        input_rect: J2kRect::from(current_ll_rect),
332        source_x,
333        source_y,
334        copy_width,
335        copy_height,
336        output_width: header.size_data.image_width(),
337        output_height: header.size_data.image_height(),
338        output_x,
339        output_y,
340        addend: store_addend,
341    }));
342
343    Ok(J2kDirectGrayscalePlan {
344        dimensions: (
345            header.size_data.image_width(),
346            header.size_data.image_height(),
347        ),
348        bit_depth: component_info.size_info.precision,
349        steps,
350    })
351}
352
353fn build_grayscale_sub_band_step(
354    sub_band: &SubBand,
355    sub_band_idx: usize,
356    band_id: J2kDirectBandId,
357    resolution: u8,
358    component_info: &ComponentInfo,
359    storage: &DecompositionStorage<'_>,
360    header: &Header<'_>,
361) -> Result<Option<J2kDirectGrayscaleStep>> {
362    let dequantization_step = {
363        if component_info.quantization_info.quantization_style == QuantizationStyle::NoQuantization
364        {
365            1.0
366        } else {
367            let (exponent, mantissa) =
368                component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
369
370            let r_b = {
371                let log_gain = match sub_band.sub_band_type {
372                    SubBandType::LowLow => 0,
373                    SubBandType::LowHigh => 1,
374                    SubBandType::HighLow => 1,
375                    SubBandType::HighHigh => 2,
376                };
377
378                component_info.size_info.precision as u16 + log_gain
379            };
380
381            crate::math::pow2i(r_b as i32 - exponent as i32) * (1.0 + (mantissa as f32) / 2048.0)
382        }
383    };
384
385    let num_bitplanes = {
386        let (exponent, _) = component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
387        let num_bitplanes = (component_info.quantization_info.guard_bits as u16)
388            .checked_add(exponent)
389            .and_then(|x| x.checked_sub(1))
390            .ok_or(DecodingError::InvalidBitplaneCount)?;
391
392        if num_bitplanes > MAX_BITPLANE_COUNT as u16 {
393            bail!(DecodingError::TooManyBitplanes);
394        }
395
396        num_bitplanes as u8
397    };
398
399    if component_info
400        .coding_style
401        .parameters
402        .code_block_style
403        .uses_high_throughput_block_coding()
404    {
405        let coded_bitplanes =
406            add_roi_shift_to_bitplanes(num_bitplanes, component_info.roi_shift, 31)?;
407        let stripe_causal = component_info
408            .coding_style
409            .parameters
410            .code_block_style
411            .vertically_causal_context;
412        let mut jobs = Vec::with_capacity(direct_sub_band_job_capacity(sub_band, storage));
413        for precinct in sub_band
414            .precincts
415            .clone()
416            .map(|idx| &storage.precincts[idx])
417        {
418            for code_block in precinct
419                .code_blocks
420                .clone()
421                .map(|idx| &storage.code_blocks[idx])
422            {
423                if !code_block_required_by_index(storage, sub_band_idx, code_block) {
424                    continue;
425                }
426                let actual_bitplanes = if header.strict {
427                    coded_bitplanes
428                        .checked_sub(code_block.missing_bit_planes)
429                        .ok_or(DecodingError::InvalidBitplaneCount)?
430                } else {
431                    coded_bitplanes.saturating_sub(code_block.missing_bit_planes)
432                };
433                let max_coding_passes = if actual_bitplanes == 0 {
434                    0
435                } else {
436                    1 + 3 * (actual_bitplanes - 1)
437                };
438                if code_block.number_of_coding_passes > max_coding_passes && header.strict {
439                    bail!(DecodingError::TooManyCodingPasses);
440                }
441                if code_block.number_of_coding_passes == 0 || actual_bitplanes == 0 {
442                    continue;
443                }
444
445                let combined = ht_block_decode::collect_code_block_data(code_block, storage)?;
446                jobs.push(HtOwnedCodeBlockBatchJob {
447                    output_x: code_block.rect.x0 - sub_band.rect.x0,
448                    output_y: code_block.rect.y0 - sub_band.rect.y0,
449                    data: combined.data,
450                    cleanup_length: combined.cleanup_length,
451                    refinement_length: combined.refinement_length,
452                    width: code_block.rect.width(),
453                    height: code_block.rect.height(),
454                    output_stride: sub_band.rect.width() as usize,
455                    missing_bit_planes: code_block.missing_bit_planes,
456                    number_of_coding_passes: code_block.number_of_coding_passes,
457                    num_bitplanes,
458                    roi_shift: component_info.roi_shift,
459                    stripe_causal,
460                    strict: header.strict,
461                    dequantization_step,
462                });
463            }
464        }
465
466        return Ok(Some(J2kDirectGrayscaleStep::HtSubBand(
467            HtOwnedSubBandPlan {
468                band_id,
469                rect: J2kRect::from(sub_band.rect),
470                width: sub_band.rect.width(),
471                height: sub_band.rect.height(),
472                jobs,
473            },
474        )));
475    }
476
477    let classic_job_sub_band_type = match sub_band.sub_band_type {
478        SubBandType::LowLow => J2kSubBandType::LowLow,
479        SubBandType::HighLow => J2kSubBandType::HighLow,
480        SubBandType::LowHigh => J2kSubBandType::LowHigh,
481        SubBandType::HighHigh => J2kSubBandType::HighHigh,
482    };
483    let classic_job_style = J2kCodeBlockStyle {
484        selective_arithmetic_coding_bypass: component_info
485            .coding_style
486            .parameters
487            .code_block_style
488            .selective_arithmetic_coding_bypass,
489        reset_context_probabilities: component_info
490            .coding_style
491            .parameters
492            .code_block_style
493            .reset_context_probabilities,
494        termination_on_each_pass: component_info
495            .coding_style
496            .parameters
497            .code_block_style
498            .termination_on_each_pass,
499        vertically_causal_context: component_info
500            .coding_style
501            .parameters
502            .code_block_style
503            .vertically_causal_context,
504        segmentation_symbols: component_info
505            .coding_style
506            .parameters
507            .code_block_style
508            .segmentation_symbols,
509    };
510
511    let mut jobs = Vec::with_capacity(direct_sub_band_job_capacity(sub_band, storage));
512    for precinct in sub_band
513        .precincts
514        .clone()
515        .map(|idx| &storage.precincts[idx])
516    {
517        for code_block in precinct
518            .code_blocks
519            .clone()
520            .map(|idx| &storage.code_blocks[idx])
521        {
522            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
523                continue;
524            }
525            let (combined_data, segments) = collect_classic_code_block_data(
526                code_block,
527                &component_info.coding_style.parameters.code_block_style,
528                storage,
529            )?;
530            jobs.push(J2kOwnedCodeBlockBatchJob {
531                output_x: code_block.rect.x0 - sub_band.rect.x0,
532                output_y: code_block.rect.y0 - sub_band.rect.y0,
533                data: combined_data,
534                segments,
535                width: code_block.rect.width(),
536                height: code_block.rect.height(),
537                output_stride: sub_band.rect.width() as usize,
538                missing_bit_planes: code_block.missing_bit_planes,
539                number_of_coding_passes: code_block.number_of_coding_passes,
540                total_bitplanes: num_bitplanes,
541                roi_shift: component_info.roi_shift,
542                sub_band_type: classic_job_sub_band_type,
543                style: classic_job_style,
544                strict: header.strict,
545                dequantization_step,
546            });
547        }
548    }
549
550    Ok(Some(J2kDirectGrayscaleStep::ClassicSubBand(
551        J2kOwnedSubBandPlan {
552            band_id,
553            rect: J2kRect::from(sub_band.rect),
554            width: sub_band.rect.width(),
555            height: sub_band.rect.height(),
556            jobs,
557        },
558    )))
559}
560
561fn direct_sub_band_job_capacity(sub_band: &SubBand, storage: &DecompositionStorage<'_>) -> usize {
562    sub_band
563        .precincts
564        .clone()
565        .map(|idx| storage.precincts[idx].code_blocks.len())
566        .sum()
567}
568
569fn collect_classic_code_block_data(
570    code_block: &CodeBlock,
571    style: &super::codestream::CodeBlockStyle,
572    storage: &DecompositionStorage<'_>,
573) -> Result<(Vec<u8>, Vec<J2kCodeBlockSegment>)> {
574    let mut combined_data = Vec::new();
575    let mut collected_segments = Vec::new();
576    let mut last_segment_idx = 0u8;
577    let mut segment_start_offset = 0usize;
578    let mut segment_start_coding_pass = 0u8;
579    let mut coding_passes = 0u8;
580    let is_normal_mode =
581        !style.selective_arithmetic_coding_bypass && !style.termination_on_each_pass;
582
583    for layer in &storage.layers[code_block.layers.start..code_block.layers.end] {
584        let Some(range) = layer.segments.clone() else {
585            continue;
586        };
587
588        for segment in &storage.segments[range] {
589            if segment.idx != last_segment_idx {
590                if segment.idx != last_segment_idx + 1 {
591                    bail!(DecodingError::CodeBlockDecodeFailure);
592                }
593                if coding_passes > segment_start_coding_pass
594                    || combined_data.len() > segment_start_offset
595                {
596                    let data_offset = u32::try_from(segment_start_offset)
597                        .map_err(|_| DecodingError::CodeBlockDecodeFailure)?;
598                    let data_length = u32::try_from(combined_data.len() - segment_start_offset)
599                        .map_err(|_| DecodingError::CodeBlockDecodeFailure)?;
600                    let use_arithmetic = if style.selective_arithmetic_coding_bypass {
601                        if segment_start_coding_pass <= 9 {
602                            true
603                        } else {
604                            segment_start_coding_pass.is_multiple_of(3)
605                        }
606                    } else {
607                        true
608                    };
609                    collected_segments.push(J2kCodeBlockSegment {
610                        data_offset,
611                        data_length,
612                        start_coding_pass: segment_start_coding_pass,
613                        end_coding_pass: coding_passes,
614                        use_arithmetic,
615                    });
616                }
617                segment_start_offset = combined_data.len();
618                segment_start_coding_pass = coding_passes;
619                last_segment_idx += 1;
620            }
621
622            combined_data.extend_from_slice(segment.data);
623            coding_passes = coding_passes.saturating_add(segment.coding_pases);
624        }
625    }
626
627    if coding_passes > segment_start_coding_pass || combined_data.len() > segment_start_offset {
628        let data_offset = u32::try_from(segment_start_offset)
629            .map_err(|_| DecodingError::CodeBlockDecodeFailure)?;
630        let data_length = u32::try_from(combined_data.len().saturating_sub(segment_start_offset))
631            .map_err(|_| DecodingError::CodeBlockDecodeFailure)?;
632        let use_arithmetic = if style.selective_arithmetic_coding_bypass {
633            if segment_start_coding_pass <= 9 {
634                true
635            } else {
636                segment_start_coding_pass.is_multiple_of(3)
637            }
638        } else {
639            true
640        };
641        collected_segments.push(J2kCodeBlockSegment {
642            data_offset,
643            data_length,
644            start_coding_pass: segment_start_coding_pass,
645            end_coding_pass: coding_passes,
646            use_arithmetic,
647        });
648    }
649
650    if is_normal_mode {
651        collected_segments.clear();
652        collected_segments.push(J2kCodeBlockSegment {
653            data_offset: 0,
654            data_length: u32::try_from(combined_data.len())
655                .map_err(|_| DecodingError::CodeBlockDecodeFailure)?,
656            start_coding_pass: 0,
657            end_coding_pass: coding_passes,
658            use_arithmetic: true,
659        });
660    }
661
662    if coding_passes != code_block.number_of_coding_passes {
663        bail!(DecodingError::CodeBlockDecodeFailure);
664    }
665
666    Ok((combined_data, collected_segments))
667}
668
/// A rectangular region described by an origin (`x`, `y`) and a size
/// (`width`, `height`).
///
/// NOTE(review): presumably expressed in output image pixel coordinates;
/// confirm against the callers that set the output region.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct OutputRegion {
    pub(crate) x: u32,
    pub(crate) y: u32,
    pub(crate) width: u32,
    pub(crate) height: u32,
}

impl OutputRegion {
    /// Creates a region from an `(x, y, width, height)` tuple.
    pub(crate) fn from_tuple((x, y, width, height): (u32, u32, u32, u32)) -> Self {
        Self {
            x,
            y,
            width,
            height,
        }
    }

    /// Returns the region's size as `(width, height)`.
    fn dimensions(self) -> (u32, u32) {
        let Self { width, height, .. } = self;
        (width, height)
    }
}
692
693#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
694pub(crate) struct DecodeDebugCounters {
695    pub(crate) decoded_code_blocks: usize,
696    pub(crate) skipped_code_blocks: usize,
697    pub(crate) idwt_output_samples: usize,
698    pub(crate) ht_phase_stats: ht_block_decode::HtBlockDecodeStats,
699}
700
/// Policy controlling how much CPU parallelism the native JPEG 2000 decoder
/// may use.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CpuDecodeParallelism {
    /// A single tile decode may parallelize its code-block decoding
    /// internally. This is the default policy.
    #[default]
    Auto,
    /// Code-block decoding stays serial, for callers that already run tile
    /// decodes in parallel themselves.
    Serial,
}
710
711/// A decoder context for decoding JPEG2000 images.
712pub struct DecoderContext<'a> {
713    pub(crate) tile_decode_context: TileDecodeContext,
714    pub(crate) storage: DecompositionStorage<'a>,
715    cpu_decode_parallelism: CpuDecodeParallelism,
716}
717
718impl Default for DecoderContext<'_> {
719    fn default() -> Self {
720        Self {
721            tile_decode_context: TileDecodeContext::default(),
722            storage: DecompositionStorage::default(),
723            cpu_decode_parallelism: CpuDecodeParallelism::Auto,
724        }
725    }
726}
727
728impl DecoderContext<'_> {
729    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) -> Result<()> {
730        self.tile_decode_context.reset(header, initial_tile)?;
731        self.storage.reset();
732        Ok(())
733    }
734
735    pub(crate) fn set_output_region(&mut self, output_region: Option<(u32, u32, u32, u32)>) {
736        self.tile_decode_context.output_region = output_region.map(OutputRegion::from_tuple);
737    }
738
739    /// Return the native CPU decode parallelism policy.
740    pub fn cpu_decode_parallelism(&self) -> CpuDecodeParallelism {
741        self.cpu_decode_parallelism
742    }
743
744    /// Set the native CPU decode parallelism policy.
745    pub fn set_cpu_decode_parallelism(&mut self, parallelism: CpuDecodeParallelism) {
746        self.cpu_decode_parallelism = parallelism;
747    }
748}
749
750fn decode_tile<'a, 'b>(
751    tile: &'b Tile<'a>,
752    header: &Header<'_>,
753    progression_iterator: Box<dyn Iterator<Item = ProgressionData> + '_>,
754    tile_ctx: &mut TileDecodeContext,
755    storage: &mut DecompositionStorage<'a>,
756    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
757    cpu_decode_parallelism: CpuDecodeParallelism,
758    profile_enabled: bool,
759    profile_timings: &mut DecodeProfileTimings,
760) -> Result<()> {
761    storage.reset();
762
763    // This is the method that orchestrates all steps.
764
765    // First, we build the decompositions, including their sub-bands, precincts
766    // and code blocks.
767    let stage_start = profile::profile_now(profile_enabled);
768    build::build(tile, storage)?;
769    if let Some(output_region) = tile_ctx.output_region {
770        storage.roi_plan = RoiPlan::build(tile, header, storage, output_region);
771        if storage.roi_plan.is_some() {
772            storage.coefficients.fill(0.0);
773        }
774    }
775    profile_timings.build_us += profile::elapsed_us(stage_start);
776    // Next, we parse the layers/segments for each code block.
777    let stage_start = profile::profile_now(profile_enabled);
778    segment::parse(tile, progression_iterator, header, storage)?;
779    profile_timings.segment_us += profile::elapsed_us(stage_start);
780    // We then decode the bitplanes of each code block, yielding the
781    // (possibly dequantized) coefficients of each code block.
782    let stage_start = profile::profile_now(profile_enabled);
783    decode_component_tile_bit_planes(
784        tile,
785        tile_ctx,
786        storage,
787        header,
788        ht_decoder,
789        cpu_decode_parallelism,
790        profile_enabled,
791    )?;
792    profile_timings.codeblock_us += profile::elapsed_us(stage_start);
793
794    // Unlike before, we interleave the apply_idwt and store stages
795    // for each component tile so we can reuse allocations better.
796    for (idx, component_info) in header.component_infos.iter().enumerate() {
797        // Next, we apply the inverse discrete wavelet transform.
798        let stage_start = profile::profile_now(profile_enabled);
799        idwt::apply(
800            storage,
801            tile_ctx,
802            idx,
803            header,
804            component_info.wavelet_transform(),
805            ht_decoder,
806        )?;
807        profile_timings.idwt_us += profile::elapsed_us(stage_start);
808        // Finally, we store the raw samples for the tile area in the correct
809        // location. Note that in case we have MCT, we are not applying it yet.
810        // It will be applied in the very end once all tiles have been processed.
811        // The reason we do this is that applying MCT requires access to the
812        // data from _all_ components. If we didn't defer this until the end
813        // we would have to collect the IDWT outputs of all components before
814        // applying it. By not applying MCT here, we can get away with doing
815        // IDWT and store on a per-component basis. Thus, we only need to
816        // store one IDWT output at a time, allowing for better reuse of
817        // allocations.
818        let stage_start = profile::profile_now(profile_enabled);
819        store(tile, header, tile_ctx, component_info, idx, ht_decoder)?;
820        profile_timings.store_us += profile::elapsed_us(stage_start);
821    }
822
823    Ok(())
824}
825
/// Per-stage timing totals collected while decoding.
///
/// The decode stages add to these fields with `profile::elapsed_us`, and
/// `emit_decode_profile_row` reports every field under its own name. The
/// `_us` suffix suggests microseconds — confirm against the `profile` module.
#[derive(Default)]
struct DecodeProfileTimings {
    /// Time attributed to tile parsing (accumulated outside the code shown
    /// here; verify against the caller).
    parse_tiles_us: u128,
    /// Time spent building decompositions, sub-bands, precincts and code
    /// blocks, including ROI plan construction.
    build_us: u128,
    /// Time spent parsing the layers/segments of each code block.
    segment_us: u128,
    /// Time spent decoding the bitplanes of the code blocks.
    codeblock_us: u128,
    /// Time spent in the inverse discrete wavelet transform.
    idwt_us: u128,
    /// Time spent storing the decoded samples of each component tile.
    store_us: u128,
    /// Time attributed to the multi-component transform (accumulated outside
    /// the code shown here; verify against the caller).
    mct_us: u128,
}
836
837#[cold]
838#[inline(never)]
839fn emit_decode_profile_row(
840    tile_ctx: &TileDecodeContext,
841    profile_timings: &DecodeProfileTimings,
842    total_start: Option<profile::ProfileInstant>,
843) {
844    profile::emit_profile_row(
845        "decode",
846        "cpu",
847        &[
848            ("parse_tiles_us", profile_timings.parse_tiles_us),
849            ("build_us", profile_timings.build_us),
850            ("segment_us", profile_timings.segment_us),
851            ("codeblock_us", profile_timings.codeblock_us),
852            ("ht_blocks", tile_ctx.debug_counters.ht_phase_stats.blocks),
853            (
854                "ht_refinement_blocks",
855                tile_ctx.debug_counters.ht_phase_stats.refinement_blocks,
856            ),
857            (
858                "ht_cleanup_bytes",
859                tile_ctx.debug_counters.ht_phase_stats.cleanup_bytes,
860            ),
861            (
862                "ht_refinement_bytes",
863                tile_ctx.debug_counters.ht_phase_stats.refinement_bytes,
864            ),
865            (
866                "ht_cleanup_us",
867                tile_ctx.debug_counters.ht_phase_stats.ht_cleanup_us,
868            ),
869            (
870                "ht_mag_sgn_us",
871                tile_ctx.debug_counters.ht_phase_stats.ht_mag_sgn_us,
872            ),
873            (
874                "ht_sigma_us",
875                tile_ctx.debug_counters.ht_phase_stats.ht_sigma_us,
876            ),
877            (
878                "ht_sigprop_us",
879                tile_ctx.debug_counters.ht_phase_stats.ht_sigprop_us,
880            ),
881            (
882                "ht_magref_us",
883                tile_ctx.debug_counters.ht_phase_stats.ht_magref_us,
884            ),
885            ("idwt_us", profile_timings.idwt_us),
886            ("store_us", profile_timings.store_us),
887            ("mct_us", profile_timings.mct_us),
888            ("total_us", profile::elapsed_us(total_start)),
889        ],
890    );
891}
892
893/// All decompositions for a single tile.
894#[derive(Clone)]
895pub(crate) struct TileDecompositions {
896    pub(crate) first_ll_sub_band: usize,
897    pub(crate) decompositions: Range<usize>,
898}
899
900impl TileDecompositions {
901    pub(crate) fn sub_band_iter(
902        &self,
903        resolution: u8,
904        decompositions: &[Decomposition],
905    ) -> SubBandIter {
906        let indices = if resolution == 0 {
907            [
908                self.first_ll_sub_band,
909                self.first_ll_sub_band,
910                self.first_ll_sub_band,
911            ]
912        } else {
913            decompositions[self.decompositions.clone()][resolution as usize - 1].sub_bands
914        };
915
916        SubBandIter {
917            next_idx: 0,
918            indices,
919            resolution,
920        }
921    }
922}
923
/// Iterator over the sub-band indices of one resolution level.
///
/// Resolution 0 yields a single index (`indices[0]`); every other resolution
/// yields all three entries of `indices` in order.
#[derive(Clone)]
pub(crate) struct SubBandIter {
    /// Resolution level the indices belong to.
    resolution: u8,
    /// Position of the next entry to hand out.
    next_idx: usize,
    /// Candidate sub-band indices.
    indices: [usize; 3],
}

impl Iterator for SubBandIter {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Resolution 0 exposes only the first slot; the others expose all.
        let available = if self.resolution == 0 {
            1
        } else {
            self.indices.len()
        };

        let cursor = self.next_idx;
        self.next_idx += 1;

        (cursor < available).then(|| self.indices[cursor])
    }
}
952
953/// A buffer so that we can reuse allocations for layers/code blocks/etc.
954/// across different tiles.
955#[derive(Default)]
956pub(crate) struct DecompositionStorage<'a> {
957    pub(crate) segments: Vec<Segment<'a>>,
958    pub(crate) layers: Vec<Layer>,
959    pub(crate) code_blocks: Vec<CodeBlock>,
960    pub(crate) precincts: Vec<Precinct>,
961    pub(crate) tag_tree_nodes: Vec<TagNode>,
962    pub(crate) coefficients: Vec<f32>,
963    pub(crate) sub_bands: Vec<SubBand>,
964    pub(crate) decompositions: Vec<Decomposition>,
965    pub(crate) tile_decompositions: Vec<TileDecompositions>,
966    pub(crate) roi_plan: Option<RoiPlan>,
967}
968
969impl DecompositionStorage<'_> {
970    pub(crate) fn reset(&mut self) {
971        self.segments.clear();
972        self.layers.clear();
973        self.code_blocks.clear();
974        // No need to clear the coefficients, as they will be resized
975        // and then overridden.
976        // self.coefficients.clear();
977        self.precincts.clear();
978        self.sub_bands.clear();
979        self.decompositions.clear();
980        self.tile_decompositions.clear();
981        self.tag_tree_nodes.clear();
982        self.roi_plan = None;
983    }
984}
985
986/// A reusable context used during the decoding of a single tile.
987///
988/// Some of the fields are temporary in nature and reset after moving on to the
989/// next tile, some contain global state.
990#[derive(Default)]
991pub(crate) struct TileDecodeContext {
992    /// A reusable buffer for the IDWT output.
993    pub(crate) idwt_output: IDWTOutput,
994    /// A scratch buffer used during IDWT.
995    pub(crate) idwt_scratch_buffer: Vec<f32>,
996    /// A reusable context for decoding code blocks.
997    pub(crate) bit_plane_decode_context: BitPlaneDecodeContext,
998    /// Reusable buffers for decoding bitplanes.
999    pub(crate) bit_plane_decode_buffers: BitPlaneDecodeBuffers,
1000    /// A reusable context for decoding HTJ2K code blocks.
1001    pub(crate) ht_block_decode_context: HtBlockDecodeContext,
1002    /// The raw, decoded samples for each channel.
1003    pub(crate) channel_data: Vec<ComponentData>,
1004    /// Optional output window for region-local decode storage.
1005    pub(crate) output_region: Option<OutputRegion>,
1006    /// Debug counters for tests and ROI instrumentation.
1007    pub(crate) debug_counters: DecodeDebugCounters,
1008}
1009
1010impl TileDecodeContext {
1011    /// Reset the context for processing a new image.
1012    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) -> Result<()> {
1013        // Bitplane decode context and buffers will be reset in the
1014        // corresponding methods. IDWT output and scratch buffer will be
1015        // overridden on demand, so those don't need to be reset either.
1016        self.channel_data.clear();
1017        self.debug_counters = DecodeDebugCounters::default();
1018
1019        let (output_width, output_height) =
1020            self.output_region.map(OutputRegion::dimensions).unwrap_or((
1021                header.size_data.image_width(),
1022                header.size_data.image_height(),
1023            ));
1024
1025        let sample_count = checked_decode_sample_count(output_width, output_height)?;
1026        checked_decode_byte_len3(
1027            sample_count,
1028            initial_tile.component_infos.len(),
1029            size_of::<f32>(),
1030        )?;
1031
1032        // Allocate per component here; the surrounding context reuses the
1033        // higher-level vectors while `SimdBuffer` owns its initialized storage.
1034        for info in &initial_tile.component_infos {
1035            self.channel_data.push(ComponentData {
1036                container: SimdBuffer::zeros(sample_count),
1037                bit_depth: info.size_info.precision,
1038            });
1039        }
1040        Ok(())
1041    }
1042}
1043
1044pub(crate) fn decode_component_tile_bit_planes<'a>(
1045    tile: &Tile<'a>,
1046    tile_ctx: &mut TileDecodeContext,
1047    storage: &mut DecompositionStorage<'a>,
1048    header: &Header<'_>,
1049    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
1050    cpu_decode_parallelism: CpuDecodeParallelism,
1051    profile_enabled: bool,
1052) -> Result<()> {
1053    for (tile_decompositions_idx, component_info) in tile.component_infos.iter().enumerate() {
1054        // Only decode the resolution levels we actually care about.
1055        for resolution in
1056            0..component_info.num_resolution_levels() - header.skipped_resolution_levels
1057        {
1058            let tile_composition = &storage.tile_decompositions[tile_decompositions_idx];
1059            let sub_band_iter = tile_composition.sub_band_iter(resolution, &storage.decompositions);
1060
1061            for sub_band_idx in sub_band_iter {
1062                decode_sub_band_bitplanes(
1063                    sub_band_idx,
1064                    resolution,
1065                    component_info,
1066                    tile_ctx,
1067                    storage,
1068                    header,
1069                    ht_decoder,
1070                    cpu_decode_parallelism,
1071                    profile_enabled,
1072                )?;
1073            }
1074        }
1075    }
1076
1077    Ok(())
1078}
1079
/// Decodes the required code blocks of one sub-band into the sub-band's slice
/// of `storage.coefficients`, scaling each sample by the dequantization step.
///
/// The work is dispatched as follows:
///
/// 1. Sub-bands that use high-throughput block coding are delegated to
///    `decode_sub_band_ht_blocks`.
/// 2. If an external decoder is supplied through `ht_decoder`, the classic
///    blocks are offered to it as one batch, then block by block; a block the
///    decoder does not handle is decoded by `decode_j2k_code_block_scalar`.
/// 3. Otherwise, if `should_decode_classic_sub_band_in_parallel` agrees and
///    the `parallel` feature is enabled, the blocks are decoded in parallel
///    and copied into the sub-band afterwards.
/// 4. Otherwise every block is decoded in place with `bitplane::decode`.
///
/// # Errors
///
/// Returns an error if the exponent/mantissa lookup fails, if the bitplane
/// count is invalid or exceeds `MAX_BITPLANE_COUNT`, or if any of the decode
/// helpers reports a failure.
fn decode_sub_band_bitplanes(
    sub_band_idx: usize,
    resolution: u8,
    component_info: &ComponentInfo,
    tile_ctx: &mut TileDecodeContext,
    storage: &mut DecompositionStorage<'_>,
    header: &Header<'_>,
    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
    cpu_decode_parallelism: CpuDecodeParallelism,
    profile_enabled: bool,
) -> Result<()> {
    // Work on a copy so that `storage` can be borrowed mutably further down.
    let sub_band = storage.sub_bands[sub_band_idx].clone();

    // Factor applied to every decoded coefficient; 1.0 when the component is
    // not quantized.
    let dequantization_step = {
        if component_info.quantization_info.quantization_style == QuantizationStyle::NoQuantization
        {
            1.0
        } else {
            let (exponent, mantissa) =
                component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;

            // Component precision plus a per-sub-band-type gain term.
            let r_b = {
                let log_gain = match sub_band.sub_band_type {
                    SubBandType::LowLow => 0,
                    SubBandType::LowHigh => 1,
                    SubBandType::HighLow => 1,
                    SubBandType::HighHigh => 2,
                };

                component_info.size_info.precision as u16 + log_gain
            };

            crate::math::pow2i(r_b as i32 - exponent as i32) * (1.0 + (mantissa as f32) / 2048.0)
        }
    };

    let num_bitplanes = {
        let (exponent, _) = component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
        // Equation (E-2): guard bits + exponent - 1, with overflow and
        // underflow reported as an invalid bitplane count.
        let num_bitplanes = (component_info.quantization_info.guard_bits as u16)
            .checked_add(exponent)
            .and_then(|x| x.checked_sub(1))
            .ok_or(DecodingError::InvalidBitplaneCount)?;

        if num_bitplanes > MAX_BITPLANE_COUNT as u16 {
            bail!(DecodingError::TooManyBitplanes);
        }

        num_bitplanes as u8
    };

    // High-throughput code blocks take a dedicated path.
    if component_info
        .coding_style
        .parameters
        .code_block_style
        .uses_high_throughput_block_coding()
    {
        decode_sub_band_ht_blocks(
            sub_band_idx,
            &sub_band,
            component_info,
            tile_ctx,
            storage,
            header,
            ht_decoder,
            cpu_decode_parallelism,
            num_bitplanes,
            dequantization_step,
            profile_enabled,
        )?;
        return Ok(());
    }

    // Bitplane count including the ROI shift; only consumed by the in-place
    // scalar path at the end, but computed here so an invalid combination is
    // rejected on every classic path.
    let coded_bitplanes =
        add_roi_shift_to_bitplanes(num_bitplanes, component_info.roi_shift, MAX_BITPLANE_COUNT)?;

    // Translate the internal sub-band type and code block style into the
    // types used by the decode jobs.
    let classic_job_sub_band_type = match sub_band.sub_band_type {
        SubBandType::LowLow => J2kSubBandType::LowLow,
        SubBandType::HighLow => J2kSubBandType::HighLow,
        SubBandType::LowHigh => J2kSubBandType::LowHigh,
        SubBandType::HighHigh => J2kSubBandType::HighHigh,
    };
    let classic_job_style = J2kCodeBlockStyle {
        selective_arithmetic_coding_bypass: component_info
            .coding_style
            .parameters
            .code_block_style
            .selective_arithmetic_coding_bypass,
        reset_context_probabilities: component_info
            .coding_style
            .parameters
            .code_block_style
            .reset_context_probabilities,
        termination_on_each_pass: component_info
            .coding_style
            .parameters
            .code_block_style
            .termination_on_each_pass,
        vertically_causal_context: component_info
            .coding_style
            .parameters
            .code_block_style
            .vertically_causal_context,
        segmentation_symbols: component_info
            .coding_style
            .parameters
            .code_block_style
            .segmentation_symbols,
    };

    // Path 2: an external decoder was supplied.
    if let Some(ht_decoder) = ht_decoder.as_deref_mut() {
        let pending_blocks =
            collect_pending_classic_blocks(sub_band_idx, &sub_band, component_info, storage)?;

        let batch_jobs: Vec<_> = pending_blocks
            .iter()
            .map(|pending| J2kCodeBlockBatchJob {
                output_x: pending.output_x,
                output_y: pending.output_y,
                code_block: J2kCodeBlockDecodeJob {
                    data: &pending.combined_data,
                    segments: &pending.segments,
                    width: pending.width,
                    height: pending.height,
                    output_stride: sub_band.rect.width() as usize,
                    missing_bit_planes: pending.missing_bit_planes,
                    number_of_coding_passes: pending.number_of_coding_passes,
                    total_bitplanes: num_bitplanes,
                    roi_shift: component_info.roi_shift,
                    sub_band_type: classic_job_sub_band_type,
                    style: classic_job_style,
                    strict: header.strict,
                    dequantization_step,
                },
            })
            .collect();

        let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
        // First offer the whole sub-band; `true` means it was handled.
        if ht_decoder.decode_j2k_sub_band(
            J2kSubBandDecodeJob {
                width: sub_band.rect.width(),
                height: sub_band.rect.height(),
                jobs: &batch_jobs,
            },
            base_store,
        )? {
            tile_ctx.debug_counters.decoded_code_blocks += batch_jobs.len();
            return Ok(());
        }

        // The batch was not handled: go block by block.
        let output_stride = sub_band.rect.width() as usize;
        for job in batch_jobs {
            tile_ctx.debug_counters.decoded_code_blocks += 1;
            let base_idx = (job.output_y * sub_band.rect.width()) as usize + job.output_x as usize;
            // Span from the block's first sample to the end of its last row:
            // `height - 1` full strides plus one block width.
            let output_len = if job.code_block.height == 0 {
                0
            } else {
                output_stride
                    .checked_mul(job.code_block.height as usize - 1)
                    .and_then(|prefix| prefix.checked_add(job.code_block.width as usize))
                    .ok_or(DecodingError::CodeBlockDecodeFailure)?
            };
            let output_slice = &mut base_store[base_idx..base_idx + output_len];
            if ht_decoder.decode_j2k_code_block(job.code_block, output_slice)? {
                continue;
            }
            // Not handled by the external decoder either: decode it here.
            decode_j2k_code_block_scalar(job.code_block, output_slice)?;
        }

        return Ok(());
    }

    // Path 3: parallel decode. Without the `parallel` feature the predicate
    // is always false and the block below is compiled out.
    let code_block_count = count_classic_code_blocks(sub_band_idx, &sub_band, storage);
    if should_decode_classic_sub_band_in_parallel(cpu_decode_parallelism, code_block_count) {
        #[cfg(feature = "parallel")]
        {
            let pending_blocks =
                collect_pending_classic_blocks(sub_band_idx, &sub_band, component_info, storage)?;
            let decoded_blocks = decode_classic_sub_band_blocks_parallel(
                &pending_blocks,
                classic_job_sub_band_type,
                classic_job_style,
                header.strict,
                num_bitplanes,
                component_info.roi_shift,
                dequantization_step,
            )?;
            tile_ctx.debug_counters.decoded_code_blocks += decoded_blocks.len();
            copy_decoded_classic_blocks_to_sub_band(&decoded_blocks, &sub_band, storage)?;
            return Ok(());
        }
    }

    // Path 4: decode every block in place, one after the other.
    for precinct in sub_band
        .precincts
        .clone()
        .map(|idx| &storage.precincts[idx])
    {
        for code_block in precinct
            .code_blocks
            .clone()
            .map(|idx| &storage.code_blocks[idx])
        {
            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
                tile_ctx.debug_counters.skipped_code_blocks += 1;
                continue;
            }
            tile_ctx.debug_counters.decoded_code_blocks += 1;
            // Position of the block relative to the sub-band origin.
            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
            let output_stride = sub_band.rect.width() as usize;
            let base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;

            bitplane::decode(
                code_block,
                sub_band.sub_band_type,
                coded_bitplanes,
                &component_info.coding_style.parameters.code_block_style,
                tile_ctx,
                storage,
                header.strict,
            )?;

            let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
            let mut base_idx = base_idx;

            // Copy the decoded rows into the sub-band, undoing the ROI shift
            // and applying the dequantization step on the way.
            for coefficients in tile_ctx.bit_plane_decode_context.coefficient_rows() {
                let out_row = &mut base_store[base_idx..];

                // `zip` stops at the end of the decoded row, so only that
                // many samples of `out_row` are written.
                for (output, coefficient) in out_row.iter_mut().zip(coefficients.iter().copied()) {
                    let coefficient =
                        apply_roi_maxshift_inverse_i32(coefficient.get(), component_info.roi_shift);
                    *output = coefficient as f32;
                    *output *= dequantization_step;
                }

                base_idx += output_stride;
            }
        }
    }

    Ok(())
}
1323
/// A high-throughput code block whose coded data has been collected and that
/// is ready to be decoded independently of the storage it came from.
struct PendingHtBlock {
    /// Coded data gathered by `ht_block_decode::collect_code_block_data`.
    combined: ht_block_decode::CombinedCodeBlockData,
    /// Horizontal offset of the block from the sub-band's left edge.
    output_x: u32,
    /// Vertical offset of the block from the sub-band's top edge.
    output_y: u32,
    /// Width of the code block rectangle.
    width: u32,
    /// Height of the code block rectangle.
    height: u32,
    /// Copied from the code block's `missing_bit_planes`.
    missing_bit_planes: u8,
    /// Copied from the code block's `number_of_coding_passes`.
    number_of_coding_passes: u8,
}
1333
/// A classic (non-HT) code block whose coded data has been collected and that
/// is ready to be decoded independently of the storage it came from.
struct PendingClassicBlock {
    /// Coded bytes gathered by `collect_classic_code_block_data`.
    combined_data: Vec<u8>,
    /// Segment descriptions that accompany `combined_data`.
    segments: Vec<J2kCodeBlockSegment>,
    /// Horizontal offset of the block from the sub-band's left edge.
    output_x: u32,
    /// Vertical offset of the block from the sub-band's top edge.
    output_y: u32,
    /// Width of the code block rectangle.
    width: u32,
    /// Height of the code block rectangle.
    height: u32,
    /// Copied from the code block's `missing_bit_planes`.
    missing_bit_planes: u8,
    /// Copied from the code block's `number_of_coding_passes`.
    number_of_coding_passes: u8,
}
1344
/// Result of decoding one classic code block in the parallel path.
#[cfg(feature = "parallel")]
struct DecodedClassicBlock {
    /// Horizontal offset of the block from the sub-band's left edge.
    output_x: u32,
    /// Vertical offset of the block from the sub-band's top edge.
    output_y: u32,
    /// Width of the block; also the row stride of `coefficients`.
    width: u32,
    /// Height of the block.
    height: u32,
    /// Decoded samples, stored row by row (`width * height` entries).
    coefficients: Vec<f32>,
}
1353
/// Result of decoding one high-throughput code block in the parallel path.
#[cfg(feature = "parallel")]
struct DecodedHtBlock {
    /// Horizontal offset of the block from the sub-band's left edge.
    output_x: u32,
    /// Vertical offset of the block from the sub-band's top edge.
    output_y: u32,
    /// Width of the block; also the row stride of `coefficients`.
    width: u32,
    /// Height of the block.
    height: u32,
    /// Decoded samples, stored row by row (`width * height` entries).
    coefficients: Vec<f32>,
}
1362
1363fn count_classic_code_blocks(
1364    sub_band_idx: usize,
1365    sub_band: &SubBand,
1366    storage: &DecompositionStorage<'_>,
1367) -> usize {
1368    sub_band
1369        .precincts
1370        .clone()
1371        .map(|idx| &storage.precincts[idx])
1372        .map(|precinct| {
1373            precinct
1374                .code_blocks
1375                .clone()
1376                .filter(|idx| {
1377                    let code_block = &storage.code_blocks[*idx];
1378                    code_block_required_by_index(storage, sub_band_idx, code_block)
1379                })
1380                .count()
1381        })
1382        .sum()
1383}
1384
1385fn code_block_required_by_index(
1386    storage: &DecompositionStorage<'_>,
1387    sub_band_idx: usize,
1388    code_block: &CodeBlock,
1389) -> bool {
1390    storage
1391        .roi_plan
1392        .as_ref()
1393        .is_none_or(|plan| plan.code_block_required(sub_band_idx, code_block.rect))
1394}
1395
1396fn collect_pending_classic_blocks(
1397    sub_band_idx: usize,
1398    sub_band: &SubBand,
1399    component_info: &ComponentInfo,
1400    storage: &DecompositionStorage<'_>,
1401) -> Result<Vec<PendingClassicBlock>> {
1402    let mut pending_blocks =
1403        Vec::with_capacity(count_classic_code_blocks(sub_band_idx, sub_band, storage));
1404    for precinct in sub_band
1405        .precincts
1406        .clone()
1407        .map(|idx| &storage.precincts[idx])
1408    {
1409        for code_block in precinct
1410            .code_blocks
1411            .clone()
1412            .map(|idx| &storage.code_blocks[idx])
1413        {
1414            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
1415                continue;
1416            }
1417            let (combined_data, segments) = collect_classic_code_block_data(
1418                code_block,
1419                &component_info.coding_style.parameters.code_block_style,
1420                storage,
1421            )?;
1422            pending_blocks.push(PendingClassicBlock {
1423                combined_data,
1424                segments,
1425                output_x: code_block.rect.x0 - sub_band.rect.x0,
1426                output_y: code_block.rect.y0 - sub_band.rect.y0,
1427                width: code_block.rect.width(),
1428                height: code_block.rect.height(),
1429                missing_bit_planes: code_block.missing_bit_planes,
1430                number_of_coding_passes: code_block.number_of_coding_passes,
1431            });
1432        }
1433    }
1434    Ok(pending_blocks)
1435}
1436
1437fn count_ht_code_blocks(
1438    sub_band_idx: usize,
1439    sub_band: &SubBand,
1440    storage: &DecompositionStorage<'_>,
1441) -> usize {
1442    sub_band
1443        .precincts
1444        .clone()
1445        .map(|idx| &storage.precincts[idx])
1446        .map(|precinct| {
1447            precinct
1448                .code_blocks
1449                .clone()
1450                .filter(|idx| {
1451                    let code_block = &storage.code_blocks[*idx];
1452                    code_block_required_by_index(storage, sub_band_idx, code_block)
1453                        && code_block.number_of_coding_passes > 0
1454                })
1455                .count()
1456        })
1457        .sum()
1458}
1459
/// Collects the coded data of every high-throughput code block of `sub_band`
/// that actually needs decoding.
///
/// Blocks are skipped when they are not required, carry no coding passes, or
/// have no bitplanes left after subtracting their missing bitplanes.
///
/// # Errors
///
/// Returns an error if the ROI-shifted bitplane count is rejected, if data
/// collection fails, or — in strict mode only — if a block claims more
/// missing bitplanes than are coded or more coding passes than its bitplanes
/// allow.
#[cfg(feature = "parallel")]
fn collect_pending_ht_blocks(
    sub_band_idx: usize,
    sub_band: &SubBand,
    storage: &DecompositionStorage<'_>,
    header: &Header<'_>,
    num_bitplanes: u8,
    roi_shift: u8,
) -> Result<Vec<PendingHtBlock>> {
    let coded_bitplanes = add_roi_shift_to_bitplanes(num_bitplanes, roi_shift, 31)?;
    // Upper bound: blocks without remaining bitplanes are dropped below.
    let capacity = count_ht_code_blocks(sub_band_idx, sub_band, storage);
    let mut pending_blocks = Vec::with_capacity(capacity);

    let block_indices = sub_band
        .precincts
        .clone()
        .flat_map(|precinct_idx| storage.precincts[precinct_idx].code_blocks.clone());

    for block_idx in block_indices {
        let code_block = &storage.code_blocks[block_idx];
        if !code_block_required_by_index(storage, sub_band_idx, code_block) {
            continue;
        }

        // Strict mode rejects an underflow; lenient mode clamps it to zero.
        let remaining = coded_bitplanes.checked_sub(code_block.missing_bit_planes);
        let actual_bitplanes = if header.strict {
            remaining.ok_or(DecodingError::InvalidBitplaneCount)?
        } else {
            remaining.unwrap_or(0)
        };

        // One pass for the first bitplane, three for every further one.
        let max_coding_passes = actual_bitplanes
            .checked_sub(1)
            .map_or(0, |further_planes| 1 + 3 * further_planes);
        if header.strict && code_block.number_of_coding_passes > max_coding_passes {
            bail!(DecodingError::TooManyCodingPasses);
        }

        let nothing_to_decode = code_block.number_of_coding_passes == 0 || actual_bitplanes == 0;
        if nothing_to_decode {
            continue;
        }

        let combined = ht_block_decode::collect_code_block_data(code_block, storage)?;
        pending_blocks.push(PendingHtBlock {
            combined,
            output_x: code_block.rect.x0 - sub_band.rect.x0,
            output_y: code_block.rect.y0 - sub_band.rect.y0,
            width: code_block.rect.width(),
            height: code_block.rect.height(),
            missing_bit_planes: code_block.missing_bit_planes,
            number_of_coding_passes: code_block.number_of_coding_passes,
        });
    }

    Ok(pending_blocks)
}
1517
1518pub(crate) fn should_decode_classic_sub_band_in_parallel(
1519    parallelism: CpuDecodeParallelism,
1520    code_block_count: usize,
1521) -> bool {
1522    cfg!(feature = "parallel") && parallelism == CpuDecodeParallelism::Auto && code_block_count >= 4
1523}
1524
1525pub(crate) fn should_decode_ht_sub_band_in_parallel(
1526    parallelism: CpuDecodeParallelism,
1527    code_block_count: usize,
1528) -> bool {
1529    cfg!(feature = "parallel") && parallelism == CpuDecodeParallelism::Auto && code_block_count >= 4
1530}
1531
/// Decodes the given classic code blocks in parallel.
///
/// Every block is decoded into its own tightly packed buffer (row stride
/// equal to the block width). The results are returned in the order of
/// `pending_blocks`.
///
/// # Errors
///
/// All blocks are processed; if any of them failed, the error of the first
/// failing block (in input order) is returned.
#[cfg(feature = "parallel")]
fn decode_classic_sub_band_blocks_parallel(
    pending_blocks: &[PendingClassicBlock],
    sub_band_type: J2kSubBandType,
    style: J2kCodeBlockStyle,
    strict: bool,
    total_bitplanes: u8,
    roi_shift: u8,
    dequantization_step: f32,
) -> Result<Vec<DecodedClassicBlock>> {
    use rayon::prelude::*;

    let decode_block = |pending: &PendingClassicBlock| -> Result<DecodedClassicBlock> {
        let row_len = pending.width as usize;
        let sample_count = row_len
            .checked_mul(pending.height as usize)
            .ok_or(DecodingError::CodeBlockDecodeFailure)?;
        let mut coefficients = vec![0.0; sample_count];

        let job = J2kCodeBlockDecodeJob {
            data: &pending.combined_data,
            segments: &pending.segments,
            width: pending.width,
            height: pending.height,
            output_stride: row_len,
            missing_bit_planes: pending.missing_bit_planes,
            number_of_coding_passes: pending.number_of_coding_passes,
            total_bitplanes,
            roi_shift,
            sub_band_type,
            style,
            strict,
            dequantization_step,
        };
        decode_j2k_code_block_scalar(job, &mut coefficients)?;

        Ok(DecodedClassicBlock {
            output_x: pending.output_x,
            output_y: pending.output_y,
            width: pending.width,
            height: pending.height,
            coefficients,
        })
    };

    // Gather every outcome first, then fold them sequentially so that the
    // reported error does not depend on thread scheduling.
    let outcomes: Vec<Result<DecodedClassicBlock>> =
        pending_blocks.par_iter().map(decode_block).collect();
    outcomes.into_iter().collect()
}
1582
/// Copies independently decoded classic code blocks into the coefficient
/// slice of `sub_band`, row by row.
///
/// # Errors
///
/// Returns `DecodingError::CodeBlockDecodeFailure` if a block does not fit
/// inside the sub-band rectangle or if an offset computation overflows.
#[cfg(feature = "parallel")]
fn copy_decoded_classic_blocks_to_sub_band(
    decoded_blocks: &[DecodedClassicBlock],
    sub_band: &SubBand,
    storage: &mut DecompositionStorage<'_>,
) -> Result<()> {
    let band_width = sub_band.rect.width();
    let band_height = sub_band.rect.height();
    let row_stride = band_width as usize;
    let destination = &mut storage.coefficients[sub_band.coefficients.clone()];

    for block in decoded_blocks {
        // The block's far edges must stay within the sub-band.
        let fits_horizontally = block
            .output_x
            .checked_add(block.width)
            .is_some_and(|x1| x1 <= band_width);
        let fits_vertically = block
            .output_y
            .checked_add(block.height)
            .is_some_and(|y1| y1 <= band_height);
        if !(fits_horizontally && fits_vertically) {
            bail!(DecodingError::CodeBlockDecodeFailure);
        }

        let row_len = block.width as usize;
        let first_row = block.output_y as usize;
        let first_column = block.output_x as usize;

        for row in 0..block.height as usize {
            let dst_start = (first_row + row)
                .checked_mul(row_stride)
                .and_then(|offset| offset.checked_add(first_column))
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let dst_end = dst_start
                .checked_add(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let src_start = row
                .checked_mul(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            let src_end = src_start
                .checked_add(row_len)
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;

            let dst_row = &mut destination[dst_start..dst_end];
            dst_row.copy_from_slice(&block.coefficients[src_start..src_end]);
        }
    }

    Ok(())
}
1623
/// Decodes the given high-throughput code blocks in parallel.
///
/// Every block is decoded with a fresh workspace into its own tightly packed
/// buffer (row stride equal to the block width). The results are returned in
/// the order of `pending_blocks`.
///
/// # Errors
///
/// All blocks are processed; if any of them failed, the error of the first
/// failing block (in input order) is returned.
#[cfg(feature = "parallel")]
fn decode_ht_sub_band_blocks_parallel(
    pending_blocks: &[PendingHtBlock],
    strict: bool,
    num_bitplanes: u8,
    roi_shift: u8,
    stripe_causal: bool,
    dequantization_step: f32,
) -> Result<Vec<DecodedHtBlock>> {
    use rayon::prelude::*;

    let decode_block = |pending: &PendingHtBlock| -> Result<DecodedHtBlock> {
        let row_len = pending.width as usize;
        let sample_count = row_len
            .checked_mul(pending.height as usize)
            .ok_or(DecodingError::CodeBlockDecodeFailure)?;
        let mut coefficients = vec![0.0; sample_count];
        let mut workspace = HtCodeBlockDecodeWorkspace::default();

        let job = HtCodeBlockDecodeJob {
            data: &pending.combined.data,
            cleanup_length: pending.combined.cleanup_length,
            refinement_length: pending.combined.refinement_length,
            width: pending.width,
            height: pending.height,
            output_stride: row_len,
            missing_bit_planes: pending.missing_bit_planes,
            number_of_coding_passes: pending.number_of_coding_passes,
            num_bitplanes,
            roi_shift,
            stripe_causal,
            strict,
            dequantization_step,
        };
        decode_ht_code_block_scalar_with_workspace(job, &mut coefficients, &mut workspace)?;

        Ok(DecodedHtBlock {
            output_x: pending.output_x,
            output_y: pending.output_y,
            width: pending.width,
            height: pending.height,
            coefficients,
        })
    };

    // Gather every outcome first, then fold them sequentially so that the
    // reported error does not depend on thread scheduling.
    let outcomes: Vec<Result<DecodedHtBlock>> =
        pending_blocks.par_iter().map(decode_block).collect();
    outcomes.into_iter().collect()
}
1675
/// Copies tightly packed decoded HT blocks into the sub-band's slice of the
/// shared coefficient storage.
///
/// Every block is placed at (`output_x`, `output_y`) relative to the sub-band
/// origin; destination rows are `sub_band.rect.width()` samples apart while
/// source rows are `block.width` samples apart.
///
/// # Errors
///
/// Returns `DecodingError::CodeBlockDecodeFailure` when a block does not fit
/// inside the sub-band rectangle or when an offset computation overflows.
#[cfg(feature = "parallel")]
fn copy_decoded_ht_blocks_to_sub_band(
    decoded_blocks: &[DecodedHtBlock],
    sub_band: &SubBand,
    storage: &mut DecompositionStorage<'_>,
) -> Result<()> {
    let band_width = sub_band.rect.width();
    let band_height = sub_band.rect.height();
    let band_stride = band_width as usize;
    let destination = &mut storage.coefficients[sub_band.coefficients.clone()];

    for block in decoded_blocks {
        let fits_horizontally = block
            .output_x
            .checked_add(block.width)
            .is_some_and(|x1| x1 <= band_width);
        let fits_vertically = block
            .output_y
            .checked_add(block.height)
            .is_some_and(|y1| y1 <= band_height);
        if !(fits_horizontally && fits_vertically) {
            bail!(DecodingError::CodeBlockDecodeFailure);
        }

        let row_len = block.width as usize;
        let first_row = block.output_y as usize;
        let first_column = block.output_x as usize;

        for row in 0..block.height as usize {
            // Destination span of this row inside the sub-band.
            let dst = (first_row + row)
                .checked_mul(band_stride)
                .and_then(|offset| offset.checked_add(first_column))
                .and_then(|start| Some(start..start.checked_add(row_len)?))
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;
            // Source span of this row inside the block's packed buffer.
            let src = row
                .checked_mul(row_len)
                .and_then(|start| Some(start..start.checked_add(row_len)?))
                .ok_or(DecodingError::CodeBlockDecodeFailure)?;

            destination[dst].copy_from_slice(&block.coefficients[src]);
        }
    }

    Ok(())
}
1716
1717fn decode_sub_band_ht_blocks(
1718    sub_band_idx: usize,
1719    sub_band: &SubBand,
1720    component_info: &ComponentInfo,
1721    tile_ctx: &mut TileDecodeContext,
1722    storage: &mut DecompositionStorage<'_>,
1723    header: &Header<'_>,
1724    ht_decoder: &mut Option<&mut dyn HtCodeBlockDecoder>,
1725    cpu_decode_parallelism: CpuDecodeParallelism,
1726    num_bitplanes: u8,
1727    dequantization_step: f32,
1728    profile_enabled: bool,
1729) -> Result<()> {
1730    let coded_bitplanes = add_roi_shift_to_bitplanes(num_bitplanes, component_info.roi_shift, 31)?;
1731    let stripe_causal = component_info
1732        .coding_style
1733        .parameters
1734        .code_block_style
1735        .vertically_causal_context;
1736
1737    if let Some(ht_decoder) = ht_decoder.as_deref_mut() {
1738        let mut pending_blocks = Vec::new();
1739        for precinct in sub_band
1740            .precincts
1741            .clone()
1742            .map(|idx| &storage.precincts[idx])
1743        {
1744            for code_block in precinct
1745                .code_blocks
1746                .clone()
1747                .map(|idx| &storage.code_blocks[idx])
1748            {
1749                if !code_block_required_by_index(storage, sub_band_idx, code_block) {
1750                    continue;
1751                }
1752                let actual_bitplanes = if header.strict {
1753                    coded_bitplanes
1754                        .checked_sub(code_block.missing_bit_planes)
1755                        .ok_or(DecodingError::InvalidBitplaneCount)?
1756                } else {
1757                    coded_bitplanes.saturating_sub(code_block.missing_bit_planes)
1758                };
1759                let max_coding_passes = if actual_bitplanes == 0 {
1760                    0
1761                } else {
1762                    1 + 3 * (actual_bitplanes - 1)
1763                };
1764                if code_block.number_of_coding_passes > max_coding_passes && header.strict {
1765                    bail!(DecodingError::TooManyCodingPasses);
1766                }
1767                if code_block.number_of_coding_passes == 0 || actual_bitplanes == 0 {
1768                    continue;
1769                }
1770
1771                pending_blocks.push(PendingHtBlock {
1772                    combined: ht_block_decode::collect_code_block_data(code_block, storage)?,
1773                    output_x: code_block.rect.x0 - sub_band.rect.x0,
1774                    output_y: code_block.rect.y0 - sub_band.rect.y0,
1775                    width: code_block.rect.width(),
1776                    height: code_block.rect.height(),
1777                    missing_bit_planes: code_block.missing_bit_planes,
1778                    number_of_coding_passes: code_block.number_of_coding_passes,
1779                });
1780            }
1781        }
1782
1783        let batch_jobs: Vec<_> = pending_blocks
1784            .iter()
1785            .map(|pending| HtCodeBlockBatchJob {
1786                output_x: pending.output_x,
1787                output_y: pending.output_y,
1788                code_block: HtCodeBlockDecodeJob {
1789                    data: &pending.combined.data,
1790                    cleanup_length: pending.combined.cleanup_length,
1791                    refinement_length: pending.combined.refinement_length,
1792                    width: pending.width,
1793                    height: pending.height,
1794                    output_stride: sub_band.rect.width() as usize,
1795                    missing_bit_planes: pending.missing_bit_planes,
1796                    number_of_coding_passes: pending.number_of_coding_passes,
1797                    num_bitplanes,
1798                    roi_shift: component_info.roi_shift,
1799                    stripe_causal,
1800                    strict: header.strict,
1801                    dequantization_step,
1802                },
1803            })
1804            .collect();
1805
1806        let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
1807        if ht_decoder.decode_sub_band(
1808            HtSubBandDecodeJob {
1809                width: sub_band.rect.width(),
1810                height: sub_band.rect.height(),
1811                jobs: &batch_jobs,
1812            },
1813            base_store,
1814        )? {
1815            tile_ctx.debug_counters.decoded_code_blocks += batch_jobs.len();
1816            return Ok(());
1817        }
1818
1819        let output_stride = sub_band.rect.width() as usize;
1820        for job in batch_jobs {
1821            tile_ctx.debug_counters.decoded_code_blocks += 1;
1822            let base_idx = (job.output_y * sub_band.rect.width()) as usize + job.output_x as usize;
1823            let output_len = if job.code_block.height == 0 {
1824                0
1825            } else {
1826                output_stride * (job.code_block.height as usize - 1) + job.code_block.width as usize
1827            };
1828            ht_decoder.decode_code_block(
1829                job.code_block,
1830                &mut base_store[base_idx..base_idx + output_len],
1831            )?;
1832        }
1833
1834        return Ok(());
1835    }
1836
1837    let code_block_count = count_ht_code_blocks(sub_band_idx, sub_band, storage);
1838    if !profile_enabled
1839        && should_decode_ht_sub_band_in_parallel(cpu_decode_parallelism, code_block_count)
1840    {
1841        #[cfg(feature = "parallel")]
1842        {
1843            let pending_blocks = collect_pending_ht_blocks(
1844                sub_band_idx,
1845                sub_band,
1846                storage,
1847                header,
1848                num_bitplanes,
1849                component_info.roi_shift,
1850            )?;
1851            let decoded_blocks = decode_ht_sub_band_blocks_parallel(
1852                &pending_blocks,
1853                header.strict,
1854                num_bitplanes,
1855                component_info.roi_shift,
1856                stripe_causal,
1857                dequantization_step,
1858            )?;
1859            tile_ctx.debug_counters.decoded_code_blocks += decoded_blocks.len();
1860            copy_decoded_ht_blocks_to_sub_band(&decoded_blocks, sub_band, storage)?;
1861            return Ok(());
1862        }
1863    }
1864
1865    for precinct in sub_band
1866        .precincts
1867        .clone()
1868        .map(|idx| &storage.precincts[idx])
1869    {
1870        for code_block in precinct
1871            .code_blocks
1872            .clone()
1873            .map(|idx| &storage.code_blocks[idx])
1874        {
1875            if !code_block_required_by_index(storage, sub_band_idx, code_block) {
1876                tile_ctx.debug_counters.skipped_code_blocks += 1;
1877                continue;
1878            }
1879            tile_ctx.debug_counters.decoded_code_blocks += 1;
1880            ht_block_decode::decode_with_stats(
1881                code_block,
1882                coded_bitplanes,
1883                stripe_causal,
1884                &mut tile_ctx.ht_block_decode_context,
1885                storage,
1886                header.strict,
1887                Some(&mut tile_ctx.debug_counters.ht_phase_stats),
1888                profile_enabled,
1889            )?;
1890
1891            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
1892            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
1893            let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
1894            let mut base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
1895            let output_stride = sub_band.rect.width() as usize;
1896
1897            for coefficients in tile_ctx.ht_block_decode_context.coefficient_rows() {
1898                let out_row = &mut base_store[base_idx..];
1899
1900                for (output, coefficient) in out_row.iter_mut().zip(coefficients.iter().copied()) {
1901                    let coefficient =
1902                        ht_block_decode::coefficient_to_i32(coefficient, coded_bitplanes);
1903                    let coefficient =
1904                        apply_roi_maxshift_inverse_i32(coefficient, component_info.roi_shift);
1905                    *output = coefficient as f32;
1906                    *output *= dequantization_step;
1907                }
1908
1909                base_idx += output_stride;
1910            }
1911        }
1912    }
1913
1914    Ok(())
1915}
1916
1917fn apply_sign_shift(tile_ctx: &mut TileDecodeContext, component_infos: &[ComponentInfo]) {
1918    for (channel_data, component_info) in
1919        tile_ctx.channel_data.iter_mut().zip(component_infos.iter())
1920    {
1921        let addend = (1_u32 << (component_info.size_info.precision - 1)) as f32;
1922        for sample in channel_data.container.deref_mut() {
1923            *sample += addend;
1924        }
1925    }
1926}
1927
1928fn store<'a>(
1929    tile: &'a Tile<'a>,
1930    header: &Header<'_>,
1931    tile_ctx: &mut TileDecodeContext,
1932    component_info: &ComponentInfo,
1933    component_idx: usize,
1934    backend: &mut Option<&mut dyn HtCodeBlockDecoder>,
1935) -> Result<()> {
1936    let channel_data = &mut tile_ctx.channel_data[component_idx];
1937    let idwt_output = &mut tile_ctx.idwt_output;
1938
1939    let component_tile = ComponentTile::new(tile, component_info);
1940    let resolution_tile = ResolutionTile::new(
1941        component_tile,
1942        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
1943    );
1944
1945    let sign_shift = if tile.mct {
1946        0.0
1947    } else {
1948        (1_u32 << (component_info.size_info.precision - 1)) as f32
1949    };
1950
1951    let (scale_x, scale_y) = (
1952        component_info.size_info.horizontal_resolution,
1953        component_info.size_info.vertical_resolution,
1954    );
1955
1956    let (image_x_offset, image_y_offset) = (
1957        header.size_data.image_area_x_offset,
1958        header.size_data.image_area_y_offset,
1959    );
1960
1961    if let Some(output_region) = tile_ctx.output_region {
1962        store_region(
1963            tile,
1964            header,
1965            tile_ctx,
1966            component_info,
1967            component_idx,
1968            output_region,
1969            backend,
1970            sign_shift,
1971        )?;
1972        return Ok(());
1973    }
1974
1975    if scale_x == 1 && scale_y == 1 {
1976        let source_x = image_x_offset.saturating_sub(idwt_output.rect.x0);
1977        let source_y = image_y_offset.saturating_sub(idwt_output.rect.y0);
1978        let copy_width = resolution_tile
1979            .rect
1980            .width()
1981            .min(idwt_output.rect.width().saturating_sub(source_x));
1982        let copy_height = resolution_tile
1983            .rect
1984            .height()
1985            .min(idwt_output.rect.height().saturating_sub(source_y));
1986        let output_x = resolution_tile.rect.x0.saturating_sub(image_x_offset);
1987        let output_y = resolution_tile.rect.y0.saturating_sub(image_y_offset);
1988
1989        let handled = if let Some(backend) = backend.as_deref_mut() {
1990            copy_width > 0
1991                && copy_height > 0
1992                && backend.decode_store_component(J2kStoreComponentJob {
1993                    input: &idwt_output.coefficients,
1994                    input_width: idwt_output.rect.width(),
1995                    source_x,
1996                    source_y,
1997                    copy_width,
1998                    copy_height,
1999                    output: &mut channel_data.container,
2000                    output_width: header.size_data.image_width(),
2001                    output_x,
2002                    output_y,
2003                    addend: sign_shift,
2004                })?
2005        } else {
2006            false
2007        };
2008
2009        if handled {
2010            return Ok(());
2011        }
2012
2013        // If no sub-sampling, use a fast path where we copy rows of coefficients
2014        // at once.
2015
2016        // The rect of the IDWT output corresponds to the rect of the highest
2017        // decomposition level of the tile, which is usually not 1:1 aligned
2018        // with the actual tile rectangle. We also need to account for the
2019        // offset of the reference grid.
2020
2021        let skip_x = image_x_offset.saturating_sub(idwt_output.rect.x0);
2022        let skip_y = image_y_offset.saturating_sub(idwt_output.rect.y0);
2023
2024        if sign_shift != 0.0 {
2025            for sample in idwt_output.coefficients.iter_mut() {
2026                *sample += sign_shift;
2027            }
2028        }
2029
2030        let input_row_iter = idwt_output
2031            .coefficients
2032            .chunks_exact(idwt_output.rect.width() as usize)
2033            .skip(skip_y as usize)
2034            .take(idwt_output.rect.height() as usize);
2035
2036        let output_row_iter = channel_data
2037            .container
2038            .chunks_exact_mut(header.size_data.image_width() as usize)
2039            .skip(resolution_tile.rect.y0.saturating_sub(image_y_offset) as usize);
2040
2041        for (input_row, output_row) in input_row_iter.zip(output_row_iter) {
2042            let input_row = &input_row[skip_x as usize..];
2043            let output_row = &mut output_row
2044                [resolution_tile.rect.x0.saturating_sub(image_x_offset) as usize..]
2045                [..input_row.len()];
2046
2047            output_row.copy_from_slice(input_row);
2048        }
2049    } else {
2050        if sign_shift != 0.0 {
2051            for sample in idwt_output.coefficients.iter_mut() {
2052                *sample += sign_shift;
2053            }
2054        }
2055        let image_width = header.size_data.image_width();
2056        let image_height = header.size_data.image_height();
2057
2058        let x_shrink_factor = header.size_data.x_shrink_factor;
2059        let y_shrink_factor = header.size_data.y_shrink_factor;
2060
2061        let x_offset = header
2062            .size_data
2063            .image_area_x_offset
2064            .div_ceil(x_shrink_factor);
2065        let y_offset = header
2066            .size_data
2067            .image_area_y_offset
2068            .div_ceil(y_shrink_factor);
2069
2070        // Otherwise, copy sample by sample.
2071        for y in resolution_tile.rect.y0..resolution_tile.rect.y1 {
2072            let relative_y = (y - component_tile.rect.y0) as usize;
2073            let reference_grid_y = (scale_y as u32 * y) / y_shrink_factor;
2074
2075            for x in resolution_tile.rect.x0..resolution_tile.rect.x1 {
2076                let relative_x = (x - component_tile.rect.x0) as usize;
2077                let reference_grid_x = (scale_x as u32 * x) / x_shrink_factor;
2078
2079                let sample = idwt_output.coefficients
2080                    [relative_y * idwt_output.rect.width() as usize + relative_x];
2081
2082                for x_position in u32::max(reference_grid_x, x_offset)
2083                    ..u32::min(reference_grid_x + scale_x as u32, image_width + x_offset)
2084                {
2085                    for y_position in u32::max(reference_grid_y, y_offset)
2086                        ..u32::min(reference_grid_y + scale_y as u32, image_height + y_offset)
2087                    {
2088                        let pos = (y_position - y_offset) as usize * image_width as usize
2089                            + (x_position - x_offset) as usize;
2090
2091                        channel_data.container[pos] = sample;
2092                    }
2093                }
2094            }
2095        }
2096    }
2097
2098    Ok(())
2099}
2100
2101fn store_region<'a>(
2102    tile: &'a Tile<'a>,
2103    header: &Header<'_>,
2104    tile_ctx: &mut TileDecodeContext,
2105    component_info: &ComponentInfo,
2106    component_idx: usize,
2107    output_region: OutputRegion,
2108    backend: &mut Option<&mut dyn HtCodeBlockDecoder>,
2109    sign_shift: f32,
2110) -> Result<()> {
2111    let channel_data = &mut tile_ctx.channel_data[component_idx];
2112    let idwt_output = &mut tile_ctx.idwt_output;
2113
2114    let component_tile = ComponentTile::new(tile, component_info);
2115    let resolution_tile = ResolutionTile::new(
2116        component_tile,
2117        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
2118    );
2119
2120    let (scale_x, scale_y) = (
2121        component_info.size_info.horizontal_resolution,
2122        component_info.size_info.vertical_resolution,
2123    );
2124    let image_width = header.size_data.image_width();
2125    let image_height = header.size_data.image_height();
2126    let x_shrink_factor = header.size_data.x_shrink_factor;
2127    let y_shrink_factor = header.size_data.y_shrink_factor;
2128    let x_offset = header
2129        .size_data
2130        .image_area_x_offset
2131        .div_ceil(x_shrink_factor);
2132    let y_offset = header
2133        .size_data
2134        .image_area_y_offset
2135        .div_ceil(y_shrink_factor);
2136    let region_x1 = output_region.x + output_region.width;
2137    let region_y1 = output_region.y + output_region.height;
2138    let output_width = output_region.width as usize;
2139
2140    if scale_x == 1 && scale_y == 1 {
2141        let region_rect_x0 = output_region.x + x_offset;
2142        let region_rect_y0 = output_region.y + y_offset;
2143        let region_rect_x1 = region_x1 + x_offset;
2144        let region_rect_y1 = region_y1 + y_offset;
2145        let copy_x0 = idwt_output
2146            .rect
2147            .x0
2148            .max(resolution_tile.rect.x0)
2149            .max(region_rect_x0);
2150        let copy_y0 = idwt_output
2151            .rect
2152            .y0
2153            .max(resolution_tile.rect.y0)
2154            .max(region_rect_y0);
2155        let copy_x1 = idwt_output
2156            .rect
2157            .x1
2158            .min(resolution_tile.rect.x1)
2159            .min(region_rect_x1);
2160        let copy_y1 = idwt_output
2161            .rect
2162            .y1
2163            .min(resolution_tile.rect.y1)
2164            .min(region_rect_y1);
2165
2166        let handled = if let Some(backend) = backend.as_deref_mut() {
2167            copy_x0 < copy_x1
2168                && copy_y0 < copy_y1
2169                && backend.decode_store_component(J2kStoreComponentJob {
2170                    input: &idwt_output.coefficients,
2171                    input_width: idwt_output.rect.width(),
2172                    source_x: copy_x0 - idwt_output.rect.x0,
2173                    source_y: copy_y0 - idwt_output.rect.y0,
2174                    copy_width: copy_x1 - copy_x0,
2175                    copy_height: copy_y1 - copy_y0,
2176                    output: &mut channel_data.container,
2177                    output_width: output_region.width,
2178                    output_x: copy_x0 - region_rect_x0,
2179                    output_y: copy_y0 - region_rect_y0,
2180                    addend: sign_shift,
2181                })?
2182        } else {
2183            false
2184        };
2185
2186        if handled {
2187            return Ok(());
2188        }
2189
2190        if sign_shift != 0.0 {
2191            for sample in idwt_output.coefficients.iter_mut() {
2192                *sample += sign_shift;
2193            }
2194        }
2195
2196        if copy_x0 < copy_x1 && copy_y0 < copy_y1 {
2197            let input_width = idwt_output.rect.width() as usize;
2198            let copy_width = (copy_x1 - copy_x0) as usize;
2199            for y in copy_y0..copy_y1 {
2200                let src_start = (y - idwt_output.rect.y0) as usize * input_width
2201                    + (copy_x0 - idwt_output.rect.x0) as usize;
2202                let dst_start = (y - region_rect_y0) as usize * output_width
2203                    + (copy_x0 - region_rect_x0) as usize;
2204                channel_data.container[dst_start..dst_start + copy_width]
2205                    .copy_from_slice(&idwt_output.coefficients[src_start..src_start + copy_width]);
2206            }
2207        }
2208
2209        return Ok(());
2210    }
2211
2212    if sign_shift != 0.0 {
2213        for sample in idwt_output.coefficients.iter_mut() {
2214            *sample += sign_shift;
2215        }
2216    }
2217
2218    for y in resolution_tile.rect.y0..resolution_tile.rect.y1 {
2219        let relative_y = (y - component_tile.rect.y0) as usize;
2220        let reference_grid_y = (scale_y as u32 * y) / y_shrink_factor;
2221
2222        for x in resolution_tile.rect.x0..resolution_tile.rect.x1 {
2223            let relative_x = (x - component_tile.rect.x0) as usize;
2224            let reference_grid_x = (scale_x as u32 * x) / x_shrink_factor;
2225
2226            let sample = idwt_output.coefficients
2227                [relative_y * idwt_output.rect.width() as usize + relative_x];
2228
2229            for x_position in u32::max(reference_grid_x, x_offset)
2230                ..u32::min(reference_grid_x + scale_x as u32, image_width + x_offset)
2231            {
2232                let image_x = x_position - x_offset;
2233                if image_x < output_region.x || image_x >= region_x1 {
2234                    continue;
2235                }
2236
2237                for y_position in u32::max(reference_grid_y, y_offset)
2238                    ..u32::min(reference_grid_y + scale_y as u32, image_height + y_offset)
2239                {
2240                    let image_y = y_position - y_offset;
2241                    if image_y < output_region.y || image_y >= region_y1 {
2242                        continue;
2243                    }
2244
2245                    let pos = (image_y - output_region.y) as usize * output_width
2246                        + (image_x - output_region.x) as usize;
2247                    channel_data.container[pos] = sample;
2248                }
2249            }
2250        }
2251    }
2252
2253    Ok(())
2254}
2255
#[cfg(test)]
mod tests {
    use super::{collect_classic_code_block_data, CodeBlock, DecompositionStorage, Layer, Segment};
    use crate::error::DecodingError;
    use crate::j2c::codestream::CodeBlockStyle;
    use crate::j2c::rect::IntRect;
    use alloc::vec;

    /// Code block style with only `termination_on_each_pass` enabled.
    fn classic_test_style() -> CodeBlockStyle {
        CodeBlockStyle {
            termination_on_each_pass: true,
            selective_arithmetic_coding_bypass: false,
            reset_context_probabilities: false,
            vertically_causal_context: false,
            segmentation_symbols: false,
            high_throughput_block_coding: false,
        }
    }

    /// A 1x1 code block with three coding passes whose data lives in layer 0.
    fn classic_test_code_block() -> CodeBlock {
        CodeBlock {
            rect: IntRect::from_xywh(0, 0, 1, 1),
            x_idx: 0,
            y_idx: 0,
            layers: 0..1,
            has_been_included: true,
            missing_bit_planes: 0,
            number_of_coding_passes: 3,
            l_block: 3,
            non_empty_layer_count: 1,
        }
    }

    #[test]
    fn collect_classic_code_block_data_preserves_zero_length_segments() {
        let mut storage = DecompositionStorage::default();
        storage.layers.push(Layer {
            segments: Some(0..3),
        });
        // The middle segment carries a coding pass but no bytes.
        for segment in [
            Segment {
                idx: 0,
                coding_pases: 1,
                data_length: 1,
                data: &[0xAA],
            },
            Segment {
                idx: 1,
                coding_pases: 1,
                data_length: 0,
                data: &[],
            },
            Segment {
                idx: 2,
                coding_pases: 1,
                data_length: 1,
                data: &[0xBB],
            },
        ] {
            storage.segments.push(segment);
        }

        let code_block = classic_test_code_block();
        let style = classic_test_style();
        let (combined_data, segments) =
            collect_classic_code_block_data(&code_block, &style, &storage)
                .expect("collect classic segments");

        assert_eq!(combined_data, vec![0xAA, 0xBB]);
        assert_eq!(segments.len(), 3);

        // (data_offset, data_length, start_coding_pass, end_coding_pass)
        let expected = [(0, 1, 0, 1), (1, 0, 1, 2), (1, 1, 2, 3)];
        for (segment, (offset, length, first_pass, end_pass)) in segments.iter().zip(expected) {
            assert_eq!(segment.data_offset, offset);
            assert_eq!(segment.data_length, length);
            assert_eq!(segment.start_coding_pass, first_pass);
            assert_eq!(segment.end_coding_pass, end_pass);
        }
    }

    #[test]
    fn collect_classic_code_block_data_rejects_non_contiguous_segment_indices() {
        let mut storage = DecompositionStorage::default();
        storage.layers.push(Layer {
            segments: Some(0..2),
        });
        // Segment index 1 is missing on purpose.
        for segment in [
            Segment {
                idx: 0,
                coding_pases: 1,
                data_length: 1,
                data: &[0xAA],
            },
            Segment {
                idx: 2,
                coding_pases: 2,
                data_length: 1,
                data: &[0xBB],
            },
        ] {
            storage.segments.push(segment);
        }

        let code_block = classic_test_code_block();
        let style = classic_test_style();
        let error = collect_classic_code_block_data(&code_block, &style, &storage)
            .expect_err("non-contiguous segment indices must fail");

        assert_eq!(error, DecodingError::CodeBlockDecodeFailure.into());
    }

    #[test]
    fn auto_cpu_parallelism_enables_ht_sub_band_parallel_branch() {
        let parallel =
            super::should_decode_ht_sub_band_in_parallel(super::CpuDecodeParallelism::Auto, 16);
        assert!(parallel);
    }

    #[test]
    fn serial_cpu_parallelism_disables_ht_sub_band_parallel_branch() {
        let parallel =
            super::should_decode_ht_sub_band_in_parallel(super::CpuDecodeParallelism::Serial, 16);
        assert!(!parallel);
    }
}