Skip to main content

dicom_toolkit_jpeg2000/j2c/
decode.rs

1//! Decoding JPEG2000 code streams.
2//!
3//! This is the "core" module of the crate that orchestrates all
4//! stages in such a way that a given codestream is decoded into its
5//! component channels.
6
7use alloc::boxed::Box;
8use alloc::vec::Vec;
9
10use super::bitplane::{BitPlaneDecodeBuffers, BitPlaneDecodeContext};
11use super::build::{CodeBlock, Decomposition, Layer, Precinct, Segment, SubBand, SubBandType};
12use super::codestream::{ComponentInfo, Header, ProgressionOrder, QuantizationStyle};
13use super::ht_block_decode::{self, HtBlockDecodeContext};
14use super::idwt::IDWTOutput;
15use super::progression::{
16    component_position_resolution_layer_progression,
17    layer_resolution_component_position_progression,
18    position_component_resolution_layer_progression,
19    resolution_layer_component_position_progression,
20    resolution_position_component_layer_progression, IteratorInput, ProgressionData,
21};
22use super::tag_tree::TagNode;
23use super::tile::{ComponentTile, ResolutionTile, Tile};
24use super::{bitplane, build, idwt, mct, segment, tile, ComponentData};
25use crate::error::{bail, DecodingError, Result, TileError};
26use crate::j2c::segment::MAX_BITPLANE_COUNT;
27use crate::math::SimdBuffer;
28use crate::reader::BitReader;
29use core::ops::{DerefMut, Range};
30
31pub(crate) fn decode<'a>(
32    data: &'a [u8],
33    header: &'a Header<'a>,
34    ctx: &mut DecoderContext<'a>,
35) -> Result<()> {
36    let mut reader = BitReader::new(data);
37    let tiles = tile::parse(&mut reader, header)?;
38
39    if tiles.is_empty() {
40        bail!(TileError::Invalid);
41    }
42
43    ctx.reset(header, &tiles[0]);
44    let (tile_ctx, storage) = (&mut ctx.tile_decode_context, &mut ctx.storage);
45
46    for tile in &tiles {
47        ltrace!(
48            "tile {} rect [{},{} {}x{}]",
49            tile.idx,
50            tile.rect.x0,
51            tile.rect.y0,
52            tile.rect.width(),
53            tile.rect.height(),
54        );
55
56        let iter_input = IteratorInput::new(tile);
57
58        let progression_iterator: Box<dyn Iterator<Item = ProgressionData>> =
59            match tile.progression_order {
60                ProgressionOrder::LayerResolutionComponentPosition => {
61                    Box::new(layer_resolution_component_position_progression(iter_input))
62                }
63                ProgressionOrder::ResolutionLayerComponentPosition => {
64                    Box::new(resolution_layer_component_position_progression(iter_input))
65                }
66                ProgressionOrder::ResolutionPositionComponentLayer => Box::new(
67                    resolution_position_component_layer_progression(iter_input)
68                        .ok_or(DecodingError::InvalidProgressionIterator)?,
69                ),
70                ProgressionOrder::PositionComponentResolutionLayer => Box::new(
71                    position_component_resolution_layer_progression(iter_input)
72                        .ok_or(DecodingError::InvalidProgressionIterator)?,
73                ),
74                ProgressionOrder::ComponentPositionResolutionLayer => Box::new(
75                    component_position_resolution_layer_progression(iter_input)
76                        .ok_or(DecodingError::InvalidProgressionIterator)?,
77                ),
78            };
79
80        decode_tile(tile, header, progression_iterator, tile_ctx, storage)?;
81    }
82
83    // Note that this assumes that either all tiles have MCT or none of them.
84    // In theory, only some could have it... But hopefully no such cursed
85    // images exist!
86    if tiles[0].mct {
87        mct::apply_inverse(tile_ctx, &tiles[0].component_infos, header)?;
88        apply_sign_shift(tile_ctx, &header.component_infos);
89    }
90
91    Ok(())
92}
93
94/// A decoder context for decoding JPEG2000 images.
95#[derive(Default)]
96pub struct DecoderContext<'a> {
97    pub(crate) tile_decode_context: TileDecodeContext,
98    storage: DecompositionStorage<'a>,
99}
100
101impl DecoderContext<'_> {
102    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) {
103        self.tile_decode_context.reset(header, initial_tile);
104        self.storage.reset();
105    }
106}
107
108fn decode_tile<'a, 'b>(
109    tile: &'b Tile<'a>,
110    header: &Header<'_>,
111    progression_iterator: Box<dyn Iterator<Item = ProgressionData> + '_>,
112    tile_ctx: &mut TileDecodeContext,
113    storage: &mut DecompositionStorage<'a>,
114) -> Result<()> {
115    storage.reset();
116
117    // This is the method that orchestrates all steps.
118
119    // First, we build the decompositions, including their sub-bands, precincts
120    // and code blocks.
121    build::build(tile, storage)?;
122    // Next, we parse the layers/segments for each code block.
123    segment::parse(tile, progression_iterator, header, storage)?;
124    // We then decode the bitplanes of each code block, yielding the
125    // (possibly dequantized) coefficients of each code block.
126    decode_component_tile_bit_planes(tile, tile_ctx, storage, header)?;
127
128    // Unlike before, we interleave the apply_idwt and store stages
129    // for each component tile so we can reuse allocations better.
130    for (idx, component_info) in header.component_infos.iter().enumerate() {
131        // Next, we apply the inverse discrete wavelet transform.
132        idwt::apply(
133            storage,
134            tile_ctx,
135            idx,
136            header,
137            component_info.wavelet_transform(),
138        );
139        // Finally, we store the raw samples for the tile area in the correct
140        // location. Note that in case we have MCT, we are not applying it yet.
141        // It will be applied in the very end once all tiles have been processed.
142        // The reason we do this is that applying MCT requires access to the
143        // data from _all_ components. If we didn't defer this until the end
144        // we would have to collect the IDWT outputs of all components before
145        // applying it. By not applying MCT here, we can get away with doing
146        // IDWT and store on a per-component basis. Thus, we only need to
147        // store one IDWT output at a time, allowing for better reuse of
148        // allocations.
149        store(tile, header, tile_ctx, component_info, idx);
150    }
151
152    Ok(())
153}
154
155/// All decompositions for a single tile.
156#[derive(Clone)]
157pub(crate) struct TileDecompositions {
158    pub(crate) first_ll_sub_band: usize,
159    pub(crate) decompositions: Range<usize>,
160}
161
162impl TileDecompositions {
163    pub(crate) fn sub_band_iter(
164        &self,
165        resolution: u8,
166        decompositions: &[Decomposition],
167    ) -> SubBandIter {
168        let indices = if resolution == 0 {
169            [
170                self.first_ll_sub_band,
171                self.first_ll_sub_band,
172                self.first_ll_sub_band,
173            ]
174        } else {
175            decompositions[self.decompositions.clone()][resolution as usize - 1].sub_bands
176        };
177
178        SubBandIter {
179            next_idx: 0,
180            indices,
181            resolution,
182        }
183    }
184}
185
/// Iterator over the sub-band indices of one resolution level.
///
/// At resolution 0 only `indices[0]` (the LL band) is produced. At any other
/// resolution all three entries are produced, in order.
#[derive(Clone)]
pub(crate) struct SubBandIter {
    resolution: u8,
    next_idx: usize,
    indices: [usize; 3],
}

impl Iterator for SubBandIter {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Resolution 0 holds a single (LL) band, the others hold three.
        let band_count = if self.resolution == 0 {
            1
        } else {
            self.indices.len()
        };

        let position = self.next_idx;
        self.next_idx += 1;

        (position < band_count).then(|| self.indices[position])
    }
}
214
215/// A buffer so that we can reuse allocations for layers/code blocks/etc.
216/// across different tiles.
217#[derive(Default)]
218pub(crate) struct DecompositionStorage<'a> {
219    pub(crate) segments: Vec<Segment<'a>>,
220    pub(crate) layers: Vec<Layer>,
221    pub(crate) code_blocks: Vec<CodeBlock>,
222    pub(crate) precincts: Vec<Precinct>,
223    pub(crate) tag_tree_nodes: Vec<TagNode>,
224    pub(crate) coefficients: Vec<f32>,
225    pub(crate) sub_bands: Vec<SubBand>,
226    pub(crate) decompositions: Vec<Decomposition>,
227    pub(crate) tile_decompositions: Vec<TileDecompositions>,
228}
229
230impl DecompositionStorage<'_> {
231    fn reset(&mut self) {
232        self.segments.clear();
233        self.layers.clear();
234        self.code_blocks.clear();
235        // No need to clear the coefficients, as they will be resized
236        // and then overridden.
237        // self.coefficients.clear();
238        self.precincts.clear();
239        self.sub_bands.clear();
240        self.decompositions.clear();
241        self.tile_decompositions.clear();
242        self.tag_tree_nodes.clear();
243    }
244}
245
246/// A reusable context used during the decoding of a single tile.
247///
248/// Some of the fields are temporary in nature and reset after moving on to the
249/// next tile, some contain global state.
250#[derive(Default)]
251pub(crate) struct TileDecodeContext {
252    /// A reusable buffer for the IDWT output.
253    pub(crate) idwt_output: IDWTOutput,
254    /// A scratch buffer used during IDWT.
255    pub(crate) idwt_scratch_buffer: Vec<f32>,
256    /// A reusable context for decoding code blocks.
257    pub(crate) bit_plane_decode_context: BitPlaneDecodeContext,
258    /// Reusable buffers for decoding bitplanes.
259    pub(crate) bit_plane_decode_buffers: BitPlaneDecodeBuffers,
260    /// A reusable context for decoding HTJ2K code blocks.
261    pub(crate) ht_block_decode_context: HtBlockDecodeContext,
262    /// The raw, decoded samples for each channel.
263    pub(crate) channel_data: Vec<ComponentData>,
264}
265
266impl TileDecodeContext {
267    /// Reset the context for processing a new image.
268    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) {
269        // Bitplane decode context and buffers will be reset in the
270        // corresponding methods. IDWT output and scratch buffer will be
271        // overridden on demand, so those don't need to be reset either.
272        self.channel_data.clear();
273
274        // TODO: SIMD Buffers should be reused across runs!
275        for info in &initial_tile.component_infos {
276            self.channel_data.push(ComponentData {
277                container: SimdBuffer::zeros(
278                    header.size_data.image_width() as usize
279                        * header.size_data.image_height() as usize,
280                ),
281                bit_depth: info.size_info.precision,
282            });
283        }
284    }
285}
286
287fn decode_component_tile_bit_planes<'a>(
288    tile: &Tile<'a>,
289    tile_ctx: &mut TileDecodeContext,
290    storage: &mut DecompositionStorage<'a>,
291    header: &Header<'_>,
292) -> Result<()> {
293    for (tile_decompositions_idx, component_info) in tile.component_infos.iter().enumerate() {
294        // Only decode the resolution levels we actually care about.
295        for resolution in
296            0..component_info.num_resolution_levels() - header.skipped_resolution_levels
297        {
298            let tile_composition = &storage.tile_decompositions[tile_decompositions_idx];
299            let sub_band_iter = tile_composition.sub_band_iter(resolution, &storage.decompositions);
300
301            for sub_band_idx in sub_band_iter {
302                decode_sub_band_bitplanes(
303                    sub_band_idx,
304                    resolution,
305                    component_info,
306                    tile_ctx,
307                    storage,
308                    header,
309                )?;
310            }
311        }
312    }
313
314    Ok(())
315}
316
317fn decode_sub_band_bitplanes(
318    sub_band_idx: usize,
319    resolution: u8,
320    component_info: &ComponentInfo,
321    tile_ctx: &mut TileDecodeContext,
322    storage: &mut DecompositionStorage<'_>,
323    header: &Header<'_>,
324) -> Result<()> {
325    let sub_band = &storage.sub_bands[sub_band_idx];
326
327    let dequantization_step = {
328        if component_info.quantization_info.quantization_style == QuantizationStyle::NoQuantization
329        {
330            1.0
331        } else {
332            let (exponent, mantissa) =
333                component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
334
335            let r_b = {
336                let log_gain = match sub_band.sub_band_type {
337                    SubBandType::LowLow => 0,
338                    SubBandType::LowHigh => 1,
339                    SubBandType::HighLow => 1,
340                    SubBandType::HighHigh => 2,
341                };
342
343                component_info.size_info.precision as u16 + log_gain
344            };
345
346            crate::math::pow2i(r_b as i32 - exponent as i32) * (1.0 + (mantissa as f32) / 2048.0)
347        }
348    };
349
350    let num_bitplanes = {
351        let (exponent, _) = component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
352        // Equation (E-2)
353        let num_bitplanes = (component_info.quantization_info.guard_bits as u16)
354            .checked_add(exponent)
355            .and_then(|x| x.checked_sub(1))
356            .ok_or(DecodingError::InvalidBitplaneCount)?;
357
358        if num_bitplanes > MAX_BITPLANE_COUNT as u16 {
359            bail!(DecodingError::TooManyBitplanes);
360        }
361
362        num_bitplanes as u8
363    };
364
365    for precinct in sub_band
366        .precincts
367        .clone()
368        .map(|idx| &storage.precincts[idx])
369    {
370        for code_block in precinct
371            .code_blocks
372            .clone()
373            .map(|idx| &storage.code_blocks[idx])
374        {
375            // Turn the signs and magnitudes into singular coefficients and
376            // copy them into the sub-band.
377
378            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
379            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
380
381            if component_info
382                .coding_style
383                .parameters
384                .code_block_style
385                .uses_high_throughput_block_coding()
386            {
387                ht_block_decode::decode(
388                    code_block,
389                    num_bitplanes,
390                    component_info
391                        .coding_style
392                        .parameters
393                        .code_block_style
394                        .vertically_causal_context,
395                    &mut tile_ctx.ht_block_decode_context,
396                    storage,
397                    header.strict,
398                )?;
399
400                let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
401                let mut base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
402
403                for coefficients in tile_ctx.ht_block_decode_context.coefficient_rows() {
404                    let out_row = &mut base_store[base_idx..];
405
406                    for (output, coefficient) in
407                        out_row.iter_mut().zip(coefficients.iter().copied())
408                    {
409                        *output =
410                            ht_block_decode::coefficient_to_i32(coefficient, num_bitplanes) as f32;
411                        *output *= dequantization_step;
412                    }
413
414                    base_idx += sub_band.rect.width() as usize;
415                }
416            } else {
417                bitplane::decode(
418                    code_block,
419                    sub_band.sub_band_type,
420                    num_bitplanes,
421                    &component_info.coding_style.parameters.code_block_style,
422                    tile_ctx,
423                    storage,
424                    header.strict,
425                )?;
426
427                let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
428                let mut base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
429
430                for coefficients in tile_ctx.bit_plane_decode_context.coefficient_rows() {
431                    let out_row = &mut base_store[base_idx..];
432
433                    for (output, coefficient) in
434                        out_row.iter_mut().zip(coefficients.iter().copied())
435                    {
436                        *output = coefficient.get() as f32;
437                        *output *= dequantization_step;
438                    }
439
440                    base_idx += sub_band.rect.width() as usize;
441                }
442            }
443        }
444    }
445
446    Ok(())
447}
448
449fn apply_sign_shift(tile_ctx: &mut TileDecodeContext, component_infos: &[ComponentInfo]) {
450    for (channel_data, component_info) in
451        tile_ctx.channel_data.iter_mut().zip(component_infos.iter())
452    {
453        for sample in channel_data.container.deref_mut() {
454            *sample += (1_u32 << (component_info.size_info.precision - 1)) as f32;
455        }
456    }
457}
458
459fn store<'a>(
460    tile: &'a Tile<'a>,
461    header: &Header<'_>,
462    tile_ctx: &mut TileDecodeContext,
463    component_info: &ComponentInfo,
464    component_idx: usize,
465) {
466    let channel_data = &mut tile_ctx.channel_data[component_idx];
467    let idwt_output = &mut tile_ctx.idwt_output;
468
469    let component_tile = ComponentTile::new(tile, component_info);
470    let resolution_tile = ResolutionTile::new(
471        component_tile,
472        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
473    );
474
475    // If we have MCT, the sign shift needs to be applied after the
476    // MCT transform. We take care of that in the main decode method.
477    // Otherwise, we might as well just apply it now.
478    if !tile.mct {
479        for sample in idwt_output.coefficients.iter_mut() {
480            *sample += (1_u32 << (component_info.size_info.precision - 1)) as f32;
481        }
482    }
483
484    let (scale_x, scale_y) = (
485        component_info.size_info.horizontal_resolution,
486        component_info.size_info.vertical_resolution,
487    );
488
489    let (image_x_offset, image_y_offset) = (
490        header.size_data.image_area_x_offset,
491        header.size_data.image_area_y_offset,
492    );
493
494    if scale_x == 1 && scale_y == 1 {
495        // If no sub-sampling, use a fast path where we copy rows of coefficients
496        // at once.
497
498        // The rect of the IDWT output corresponds to the rect of the highest
499        // decomposition level of the tile, which is usually not 1:1 aligned
500        // with the actual tile rectangle. We also need to account for the
501        // offset of the reference grid.
502
503        let skip_x = image_x_offset.saturating_sub(idwt_output.rect.x0);
504        let skip_y = image_y_offset.saturating_sub(idwt_output.rect.y0);
505
506        let input_row_iter = idwt_output
507            .coefficients
508            .chunks_exact(idwt_output.rect.width() as usize)
509            .skip(skip_y as usize)
510            .take(idwt_output.rect.height() as usize);
511
512        let output_row_iter = channel_data
513            .container
514            .chunks_exact_mut(header.size_data.image_width() as usize)
515            .skip(resolution_tile.rect.y0.saturating_sub(image_y_offset) as usize);
516
517        for (input_row, output_row) in input_row_iter.zip(output_row_iter) {
518            let input_row = &input_row[skip_x as usize..];
519            let output_row = &mut output_row
520                [resolution_tile.rect.x0.saturating_sub(image_x_offset) as usize..]
521                [..input_row.len()];
522
523            output_row.copy_from_slice(input_row);
524        }
525    } else {
526        let image_width = header.size_data.image_width();
527        let image_height = header.size_data.image_height();
528
529        let x_shrink_factor = header.size_data.x_shrink_factor;
530        let y_shrink_factor = header.size_data.y_shrink_factor;
531
532        let x_offset = header
533            .size_data
534            .image_area_x_offset
535            .div_ceil(x_shrink_factor);
536        let y_offset = header
537            .size_data
538            .image_area_y_offset
539            .div_ceil(y_shrink_factor);
540
541        // Otherwise, copy sample by sample.
542        for y in resolution_tile.rect.y0..resolution_tile.rect.y1 {
543            let relative_y = (y - component_tile.rect.y0) as usize;
544            let reference_grid_y = (scale_y as u32 * y) / y_shrink_factor;
545
546            for x in resolution_tile.rect.x0..resolution_tile.rect.x1 {
547                let relative_x = (x - component_tile.rect.x0) as usize;
548                let reference_grid_x = (scale_x as u32 * x) / x_shrink_factor;
549
550                let sample = idwt_output.coefficients
551                    [relative_y * idwt_output.rect.width() as usize + relative_x];
552
553                for x_position in u32::max(reference_grid_x, x_offset)
554                    ..u32::min(reference_grid_x + scale_x as u32, image_width + x_offset)
555                {
556                    for y_position in u32::max(reference_grid_y, y_offset)
557                        ..u32::min(reference_grid_y + scale_y as u32, image_height + y_offset)
558                    {
559                        let pos = (y_position - y_offset) as usize * image_width as usize
560                            + (x_position - x_offset) as usize;
561
562                        channel_data.container[pos] = sample;
563                    }
564                }
565            }
566        }
567    }
568}