Skip to main content

hayro_jpeg2000/j2c/
decode.rs

1//! Decoding JPEG2000 code streams.
2//!
3//! This is the "core" module of the crate that orchestrates all
4//! stages in such a way that a given codestream is decoded into its
5//! component channels.
6
7use alloc::boxed::Box;
8use alloc::vec::Vec;
9
10use super::bitplane::{BitPlaneDecodeBuffers, BitPlaneDecodeContext};
11use super::build::{CodeBlock, Decomposition, Layer, Precinct, Segment, SubBand, SubBandType};
12use super::codestream::{ComponentInfo, Header, ProgressionOrder, QuantizationStyle};
13use super::idwt::IDWTOutput;
14use super::progression::{
15    IteratorInput, ProgressionData, component_position_resolution_layer_progression,
16    layer_resolution_component_position_progression,
17    position_component_resolution_layer_progression,
18    resolution_layer_component_position_progression,
19    resolution_position_component_layer_progression,
20};
21use super::tag_tree::TagNode;
22use super::tile::{ComponentTile, ResolutionTile, Tile};
23use super::{ComponentData, bitplane, build, idwt, mct, segment, tile};
24use crate::error::{DecodingError, Result, TileError, bail};
25use crate::j2c::segment::MAX_BITPLANE_COUNT;
26use crate::math::SimdBuffer;
27use crate::reader::BitReader;
28use core::ops::{DerefMut, Range};
29
30pub(crate) fn decode<'a>(
31    data: &'a [u8],
32    header: &'a Header<'a>,
33    ctx: &mut DecoderContext<'a>,
34) -> Result<()> {
35    let mut reader = BitReader::new(data);
36    let tiles = tile::parse(&mut reader, header)?;
37
38    if tiles.is_empty() {
39        bail!(TileError::Invalid);
40    }
41
42    ctx.reset(header, &tiles[0]);
43    let (tile_ctx, storage) = (&mut ctx.tile_decode_context, &mut ctx.storage);
44
45    for tile in &tiles {
46        ltrace!(
47            "tile {} rect [{},{} {}x{}]",
48            tile.idx,
49            tile.rect.x0,
50            tile.rect.y0,
51            tile.rect.width(),
52            tile.rect.height(),
53        );
54
55        let iter_input = IteratorInput::new(tile);
56
57        let progression_iterator: Box<dyn Iterator<Item = ProgressionData>> =
58            match tile.progression_order {
59                ProgressionOrder::LayerResolutionComponentPosition => {
60                    Box::new(layer_resolution_component_position_progression(iter_input))
61                }
62                ProgressionOrder::ResolutionLayerComponentPosition => {
63                    Box::new(resolution_layer_component_position_progression(iter_input))
64                }
65                ProgressionOrder::ResolutionPositionComponentLayer => Box::new(
66                    resolution_position_component_layer_progression(iter_input)
67                        .ok_or(DecodingError::InvalidProgressionIterator)?,
68                ),
69                ProgressionOrder::PositionComponentResolutionLayer => Box::new(
70                    position_component_resolution_layer_progression(iter_input)
71                        .ok_or(DecodingError::InvalidProgressionIterator)?,
72                ),
73                ProgressionOrder::ComponentPositionResolutionLayer => Box::new(
74                    component_position_resolution_layer_progression(iter_input)
75                        .ok_or(DecodingError::InvalidProgressionIterator)?,
76                ),
77            };
78
79        decode_tile(tile, header, progression_iterator, tile_ctx, storage)?;
80    }
81
82    // Note that this assumes that either all tiles have MCT or none of them.
83    // In theory, only some could have it... But hopefully no such cursed
84    // images exist!
85    if tiles[0].mct {
86        mct::apply_inverse(tile_ctx, &tiles[0].component_infos, header)?;
87        apply_sign_shift(tile_ctx, &header.component_infos);
88    }
89
90    Ok(())
91}
92
93/// A decoder context for decoding JPEG2000 images.
94#[derive(Default)]
95pub struct DecoderContext<'a> {
96    pub(crate) tile_decode_context: TileDecodeContext,
97    storage: DecompositionStorage<'a>,
98}
99
100impl DecoderContext<'_> {
101    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) {
102        self.tile_decode_context.reset(header, initial_tile);
103        self.storage.reset();
104    }
105}
106
107fn decode_tile<'a, 'b>(
108    tile: &'b Tile<'a>,
109    header: &Header<'_>,
110    progression_iterator: Box<dyn Iterator<Item = ProgressionData> + '_>,
111    tile_ctx: &mut TileDecodeContext,
112    storage: &mut DecompositionStorage<'a>,
113) -> Result<()> {
114    storage.reset();
115
116    // This is the method that orchestrates all steps.
117
118    // First, we build the decompositions, including their sub-bands, precincts
119    // and code blocks.
120    build::build(tile, storage)?;
121    // Next, we parse the layers/segments for each code block.
122    segment::parse(tile, progression_iterator, header, storage)?;
123    // We then decode the bitplanes of each code block, yielding the
124    // (possibly dequantized) coefficients of each code block.
125    decode_component_tile_bit_planes(tile, tile_ctx, storage, header)?;
126
127    // Unlike before, we interleave the apply_idwt and store stages
128    // for each component tile so we can reuse allocations better.
129    for (idx, component_info) in header.component_infos.iter().enumerate() {
130        // Next, we apply the inverse discrete wavelet transform.
131        idwt::apply(
132            storage,
133            tile_ctx,
134            idx,
135            header,
136            component_info.wavelet_transform(),
137        );
138        // Finally, we store the raw samples for the tile area in the correct
139        // location. Note that in case we have MCT, we are not applying it yet.
140        // It will be applied in the very end once all tiles have been processed.
141        // The reason we do this is that applying MCT requires access to the
142        // data from _all_ components. If we didn't defer this until the end
143        // we would have to collect the IDWT outputs of all components before
144        // applying it. By not applying MCT here, we can get away with doing
145        // IDWT and store on a per-component basis. Thus, we only need to
146        // store one IDWT output at a time, allowing for better reuse of
147        // allocations.
148        store(tile, header, tile_ctx, component_info, idx);
149    }
150
151    Ok(())
152}
153
154/// All decompositions for a single tile.
155#[derive(Clone)]
156pub(crate) struct TileDecompositions {
157    pub(crate) first_ll_sub_band: usize,
158    pub(crate) decompositions: Range<usize>,
159}
160
161impl TileDecompositions {
162    pub(crate) fn sub_band_iter(
163        &self,
164        resolution: u8,
165        decompositions: &[Decomposition],
166    ) -> SubBandIter {
167        let indices = if resolution == 0 {
168            [
169                self.first_ll_sub_band,
170                self.first_ll_sub_band,
171                self.first_ll_sub_band,
172            ]
173        } else {
174            decompositions[self.decompositions.clone()][resolution as usize - 1].sub_bands
175        };
176
177        SubBandIter {
178            next_idx: 0,
179            indices,
180            resolution,
181        }
182    }
183}
184
/// Iterator over the sub-band indices of one resolution level.
///
/// For resolution level 0 only the first stored index is yielded; for every
/// other level all three stored indices are yielded in order.
#[derive(Clone)]
pub(crate) struct SubBandIter {
    /// The resolution level this iterator was created for.
    resolution: u8,
    /// Position of the next index to yield.
    next_idx: usize,
    /// The sub-band indices to iterate over.
    indices: [usize; 3],
}

impl Iterator for SubBandIter {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Resolution 0 consists of a single sub-band, all others of three.
        let available = if self.resolution == 0 {
            1
        } else {
            self.indices.len()
        };

        let current = self.next_idx;
        self.next_idx += 1;

        (current < available).then(|| self.indices[current])
    }
}
213
214/// A buffer so that we can reuse allocations for layers/code blocks/etc.
215/// across different tiles.
216#[derive(Default)]
217pub(crate) struct DecompositionStorage<'a> {
218    pub(crate) segments: Vec<Segment<'a>>,
219    pub(crate) layers: Vec<Layer>,
220    pub(crate) code_blocks: Vec<CodeBlock>,
221    pub(crate) precincts: Vec<Precinct>,
222    pub(crate) tag_tree_nodes: Vec<TagNode>,
223    pub(crate) coefficients: Vec<f32>,
224    pub(crate) sub_bands: Vec<SubBand>,
225    pub(crate) decompositions: Vec<Decomposition>,
226    pub(crate) tile_decompositions: Vec<TileDecompositions>,
227}
228
229impl DecompositionStorage<'_> {
230    fn reset(&mut self) {
231        self.segments.clear();
232        self.layers.clear();
233        self.code_blocks.clear();
234        // No need to clear the coefficients, as they will be resized
235        // and then overridden.
236        // self.coefficients.clear();
237        self.precincts.clear();
238        self.sub_bands.clear();
239        self.decompositions.clear();
240        self.tile_decompositions.clear();
241        self.tag_tree_nodes.clear();
242    }
243}
244
245/// A reusable context used during the decoding of a single tile.
246///
247/// Some of the fields are temporary in nature and reset after moving on to the
248/// next tile, some contain global state.
249#[derive(Default)]
250pub(crate) struct TileDecodeContext {
251    /// A reusable buffer for the IDWT output.
252    pub(crate) idwt_output: IDWTOutput,
253    /// A scratch buffer used during IDWT.
254    pub(crate) idwt_scratch_buffer: Vec<f32>,
255    /// A reusable context for decoding code blocks.
256    pub(crate) bit_plane_decode_context: BitPlaneDecodeContext,
257    /// Reusable buffers for decoding bitplanes.
258    pub(crate) bit_plane_decode_buffers: BitPlaneDecodeBuffers,
259    /// The raw, decoded samples for each channel.
260    pub(crate) channel_data: Vec<ComponentData>,
261}
262
263impl TileDecodeContext {
264    /// Reset the context for processing a new image.
265    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) {
266        // Bitplane decode context and buffers will be reset in the
267        // corresponding methods. IDWT output and scratch buffer will be
268        // overridden on demand, so those don't need to be reset either.
269        self.channel_data.clear();
270
271        // TODO: SIMD Buffers should be reused across runs!
272        for info in &initial_tile.component_infos {
273            self.channel_data.push(ComponentData {
274                container: SimdBuffer::zeros(
275                    header.size_data.image_width() as usize
276                        * header.size_data.image_height() as usize,
277                ),
278                bit_depth: info.size_info.precision,
279            });
280        }
281    }
282}
283
284fn decode_component_tile_bit_planes<'a>(
285    tile: &Tile<'a>,
286    tile_ctx: &mut TileDecodeContext,
287    storage: &mut DecompositionStorage<'a>,
288    header: &Header<'_>,
289) -> Result<()> {
290    for (tile_decompositions_idx, component_info) in tile.component_infos.iter().enumerate() {
291        // Only decode the resolution levels we actually care about.
292        for resolution in
293            0..component_info.num_resolution_levels() - header.skipped_resolution_levels
294        {
295            let tile_composition = &storage.tile_decompositions[tile_decompositions_idx];
296            let sub_band_iter = tile_composition.sub_band_iter(resolution, &storage.decompositions);
297
298            for sub_band_idx in sub_band_iter {
299                decode_sub_band_bitplanes(
300                    sub_band_idx,
301                    resolution,
302                    component_info,
303                    tile_ctx,
304                    storage,
305                    header,
306                )?;
307            }
308        }
309    }
310
311    Ok(())
312}
313
314fn decode_sub_band_bitplanes(
315    sub_band_idx: usize,
316    resolution: u8,
317    component_info: &ComponentInfo,
318    tile_ctx: &mut TileDecodeContext,
319    storage: &mut DecompositionStorage<'_>,
320    header: &Header<'_>,
321) -> Result<()> {
322    let sub_band = &storage.sub_bands[sub_band_idx];
323
324    let dequantization_step = {
325        if component_info.quantization_info.quantization_style == QuantizationStyle::NoQuantization
326        {
327            1.0
328        } else {
329            let (exponent, mantissa) =
330                component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
331
332            let r_b = {
333                let log_gain = match sub_band.sub_band_type {
334                    SubBandType::LowLow => 0,
335                    SubBandType::LowHigh => 1,
336                    SubBandType::HighLow => 1,
337                    SubBandType::HighHigh => 2,
338                };
339
340                component_info.size_info.precision as u16 + log_gain
341            };
342
343            crate::math::pow2i(r_b as i32 - exponent as i32) * (1.0 + (mantissa as f32) / 2048.0)
344        }
345    };
346
347    let num_bitplanes = {
348        let (exponent, _) = component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
349        // Equation (E-2)
350        let num_bitplanes = (component_info.quantization_info.guard_bits as u16)
351            .checked_add(exponent)
352            .and_then(|x| x.checked_sub(1))
353            .ok_or(DecodingError::InvalidBitplaneCount)?;
354
355        if num_bitplanes > MAX_BITPLANE_COUNT as u16 {
356            bail!(DecodingError::TooManyBitplanes);
357        }
358
359        num_bitplanes as u8
360    };
361
362    for precinct in sub_band
363        .precincts
364        .clone()
365        .map(|idx| &storage.precincts[idx])
366    {
367        for code_block in precinct
368            .code_blocks
369            .clone()
370            .map(|idx| &storage.code_blocks[idx])
371        {
372            bitplane::decode(
373                code_block,
374                sub_band.sub_band_type,
375                num_bitplanes,
376                &component_info.coding_style.parameters.code_block_style,
377                tile_ctx,
378                storage,
379                header.strict,
380            )?;
381
382            // Turn the signs and magnitudes into singular coefficients and
383            // copy them into the sub-band.
384
385            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
386            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
387
388            let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
389            let mut base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
390
391            for coefficients in tile_ctx.bit_plane_decode_context.coefficient_rows() {
392                let out_row = &mut base_store[base_idx..];
393
394                for (output, coefficient) in out_row.iter_mut().zip(coefficients.iter().copied()) {
395                    *output = coefficient.get() as f32;
396                    *output *= dequantization_step;
397                }
398
399                base_idx += sub_band.rect.width() as usize;
400            }
401        }
402    }
403
404    Ok(())
405}
406
407fn apply_sign_shift(tile_ctx: &mut TileDecodeContext, component_infos: &[ComponentInfo]) {
408    for (channel_data, component_info) in
409        tile_ctx.channel_data.iter_mut().zip(component_infos.iter())
410    {
411        for sample in channel_data.container.deref_mut() {
412            *sample += (1_u32 << (component_info.size_info.precision - 1)) as f32;
413        }
414    }
415}
416
417fn store<'a>(
418    tile: &'a Tile<'a>,
419    header: &Header<'_>,
420    tile_ctx: &mut TileDecodeContext,
421    component_info: &ComponentInfo,
422    component_idx: usize,
423) {
424    let channel_data = &mut tile_ctx.channel_data[component_idx];
425    let idwt_output = &mut tile_ctx.idwt_output;
426
427    let component_tile = ComponentTile::new(tile, component_info);
428    let resolution_tile = ResolutionTile::new(
429        component_tile,
430        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
431    );
432
433    // If we have MCT, the sign shift needs to be applied after the
434    // MCT transform. We take care of that in the main decode method.
435    // Otherwise, we might as well just apply it now.
436    if !tile.mct {
437        for sample in idwt_output.coefficients.iter_mut() {
438            *sample += (1_u32 << (component_info.size_info.precision - 1)) as f32;
439        }
440    }
441
442    let (scale_x, scale_y) = (
443        component_info.size_info.horizontal_resolution,
444        component_info.size_info.vertical_resolution,
445    );
446
447    let (image_x_offset, image_y_offset) = (
448        header.size_data.image_area_x_offset,
449        header.size_data.image_area_y_offset,
450    );
451
452    if scale_x == 1 && scale_y == 1 {
453        // If no sub-sampling, use a fast path where we copy rows of coefficients
454        // at once.
455
456        // The rect of the IDWT output corresponds to the rect of the highest
457        // decomposition level of the tile, which is usually not 1:1 aligned
458        // with the actual tile rectangle. We also need to account for the
459        // offset of the reference grid.
460
461        let skip_x = image_x_offset.saturating_sub(idwt_output.rect.x0);
462        let skip_y = image_y_offset.saturating_sub(idwt_output.rect.y0);
463
464        let input_row_iter = idwt_output
465            .coefficients
466            .chunks_exact(idwt_output.rect.width() as usize)
467            .skip(skip_y as usize)
468            .take(idwt_output.rect.height() as usize);
469
470        let output_row_iter = channel_data
471            .container
472            .chunks_exact_mut(header.size_data.image_width() as usize)
473            .skip(resolution_tile.rect.y0.saturating_sub(image_y_offset) as usize);
474
475        for (input_row, output_row) in input_row_iter.zip(output_row_iter) {
476            let input_row = &input_row[skip_x as usize..];
477            let output_row = &mut output_row
478                [resolution_tile.rect.x0.saturating_sub(image_x_offset) as usize..]
479                [..input_row.len()];
480
481            output_row.copy_from_slice(input_row);
482        }
483    } else {
484        let image_width = header.size_data.image_width();
485        let image_height = header.size_data.image_height();
486
487        let x_shrink_factor = header.size_data.x_shrink_factor;
488        let y_shrink_factor = header.size_data.y_shrink_factor;
489
490        let x_offset = header
491            .size_data
492            .image_area_x_offset
493            .div_ceil(x_shrink_factor);
494        let y_offset = header
495            .size_data
496            .image_area_y_offset
497            .div_ceil(y_shrink_factor);
498
499        // Otherwise, copy sample by sample.
500        for y in resolution_tile.rect.y0..resolution_tile.rect.y1 {
501            let relative_y = (y - component_tile.rect.y0) as usize;
502            let reference_grid_y = (scale_y as u32 * y) / y_shrink_factor;
503
504            for x in resolution_tile.rect.x0..resolution_tile.rect.x1 {
505                let relative_x = (x - component_tile.rect.x0) as usize;
506                let reference_grid_x = (scale_x as u32 * x) / x_shrink_factor;
507
508                let sample = idwt_output.coefficients
509                    [relative_y * idwt_output.rect.width() as usize + relative_x];
510
511                for x_position in u32::max(reference_grid_x, x_offset)
512                    ..u32::min(reference_grid_x + scale_x as u32, image_width + x_offset)
513                {
514                    for y_position in u32::max(reference_grid_y, y_offset)
515                        ..u32::min(reference_grid_y + scale_y as u32, image_height + y_offset)
516                    {
517                        let pos = (y_position - y_offset) as usize * image_width as usize
518                            + (x_position - x_offset) as usize;
519
520                        channel_data.container[pos] = sample;
521                    }
522                }
523            }
524        }
525    }
526}