Skip to main content

hayro_jpeg2000/j2c/
decode.rs

1//! Decoding JPEG2000 code streams.
2//!
3//! This is the "core" module of the crate that orchestrates all
4//! stages in such a way that a given codestream is decoded into its
5//! component channels.
6
7use alloc::boxed::Box;
8use alloc::vec::Vec;
9
10use super::bitplane::{BitPlaneDecodeBuffers, BitPlaneDecodeContext};
11use super::build::{CodeBlock, Decomposition, Layer, Precinct, Segment, SubBand, SubBandType};
12use super::codestream::{ComponentInfo, Header, ProgressionOrder, QuantizationStyle};
13use super::idwt::IDWTOutput;
14use super::progression::{
15    IteratorInput, ProgressionData, component_position_resolution_layer_progression,
16    layer_resolution_component_position_progression,
17    position_component_resolution_layer_progression,
18    resolution_layer_component_position_progression,
19    resolution_position_component_layer_progression,
20};
21use super::tag_tree::TagNode;
22use super::tile::{ComponentTile, ResolutionTile, Tile};
23use super::{ComponentData, bitplane, build, idwt, mct, segment, tile};
24use crate::error::{DecodingError, Result, TileError, bail};
25use crate::j2c::segment::MAX_BITPLANE_COUNT;
26use crate::math::SimdBuffer;
27use crate::reader::BitReader;
28use core::ops::Range;
29
30pub(crate) fn decode<'a>(
31    data: &'a [u8],
32    header: &'a Header<'a>,
33    ctx: &mut DecoderContext<'a>,
34) -> Result<()> {
35    let mut reader = BitReader::new(data);
36    let tiles = tile::parse(&mut reader, header)?;
37
38    if tiles.is_empty() {
39        bail!(TileError::Invalid);
40    }
41
42    ctx.reset(header, &tiles[0]);
43
44    for tile in &tiles {
45        trace!(
46            "tile {} rect [{},{} {}x{}]",
47            tile.idx,
48            tile.rect.x0,
49            tile.rect.y0,
50            tile.rect.width(),
51            tile.rect.height(),
52        );
53
54        let iter_input = IteratorInput::new(tile);
55
56        let progression_iterator: Box<dyn Iterator<Item = ProgressionData>> =
57            match tile.progression_order {
58                ProgressionOrder::LayerResolutionComponentPosition => {
59                    Box::new(layer_resolution_component_position_progression(iter_input))
60                }
61                ProgressionOrder::ResolutionLayerComponentPosition => {
62                    Box::new(resolution_layer_component_position_progression(iter_input))
63                }
64                ProgressionOrder::ResolutionPositionComponentLayer => Box::new(
65                    resolution_position_component_layer_progression(iter_input)
66                        .ok_or(DecodingError::InvalidProgressionIterator)?,
67                ),
68                ProgressionOrder::PositionComponentResolutionLayer => Box::new(
69                    position_component_resolution_layer_progression(iter_input)
70                        .ok_or(DecodingError::InvalidProgressionIterator)?,
71                ),
72                ProgressionOrder::ComponentPositionResolutionLayer => Box::new(
73                    component_position_resolution_layer_progression(iter_input)
74                        .ok_or(DecodingError::InvalidProgressionIterator)?,
75                ),
76            };
77
78        decode_tile(
79            tile,
80            header,
81            progression_iterator,
82            &mut ctx.tile_decode_context,
83            &mut ctx.channel_data,
84            &mut ctx.storage,
85        )?;
86    }
87
88    // Note that this assumes that either all tiles have MCT or none of them.
89    // In theory, only some could have it... But hopefully no such cursed
90    // images exist!
91    if tiles[0].mct {
92        mct::apply_inverse(&mut ctx.channel_data, &tiles[0].component_infos, header)?;
93    }
94
95    apply_sign_shift(&mut ctx.channel_data, &header.component_infos);
96
97    Ok(())
98}
99
100/// A decoder context for decoding JPEG2000 images.
101#[derive(Default)]
102pub struct DecoderContext<'a> {
103    tile_decode_context: TileDecodeContext,
104    /// The raw, decoded samples for each channel.
105    pub(crate) channel_data: Vec<ComponentData>,
106    storage: DecompositionStorage<'a>,
107}
108
109impl DecoderContext<'_> {
110    fn reset(&mut self, header: &Header<'_>, initial_tile: &Tile<'_>) {
111        self.tile_decode_context.reset();
112        self.storage.reset();
113
114        self.channel_data.clear();
115        // TODO: SIMD Buffers should be reused across runs!
116        for info in &initial_tile.component_infos {
117            self.channel_data.push(ComponentData {
118                container: SimdBuffer::zeros(
119                    header.size_data.image_width() as usize
120                        * header.size_data.image_height() as usize,
121                ),
122                bit_depth: info.size_info.precision,
123            });
124        }
125    }
126}
127
128fn decode_tile<'a, 'b>(
129    tile: &'b Tile<'a>,
130    header: &Header<'_>,
131    progression_iterator: Box<dyn Iterator<Item = ProgressionData> + '_>,
132    tile_ctx: &mut TileDecodeContext,
133    channel_data: &mut [ComponentData],
134    storage: &mut DecompositionStorage<'a>,
135) -> Result<()> {
136    storage.reset();
137
138    // This is the method that orchestrates all steps.
139
140    // First, we build the decompositions, including their sub-bands, precincts
141    // and code blocks.
142    build::build(tile, storage)?;
143    // Next, we parse the layers/segments for each code block.
144    segment::parse(tile, progression_iterator, header, storage)?;
145    // We then decode the bitplanes of each code block, yielding the
146    // (possibly dequantized) coefficients of each code block.
147    decode_component_tile_bit_planes(tile, tile_ctx, storage, header)?;
148
149    // Unlike before, we interleave the apply_idwt and store stages
150    // for each component tile so we can reuse allocations better.
151    for (idx, component_info) in header.component_infos.iter().enumerate() {
152        // Next, we apply the inverse discrete wavelet transform.
153        idwt::apply(
154            storage,
155            tile_ctx,
156            idx,
157            header,
158            component_info.wavelet_transform(),
159        );
160        // Finally, we store the raw samples for the tile area in the correct
161        // location. Note that in case we have MCT, we are not applying it yet.
162        // It will be applied in the very end once all tiles have been processed.
163        // The reason we do this is that applying MCT requires access to the
164        // data from _all_ components. If we didn't defer this until the end
165        // we would have to collect the IDWT outputs of all components before
166        // applying it. By not applying MCT here, we can get away with doing
167        // IDWT and store on a per-component basis. Thus, we only need to
168        // store one IDWT output at a time, allowing for better reuse of
169        // allocations.
170        store(
171            tile,
172            header,
173            tile_ctx,
174            &mut channel_data[idx],
175            component_info,
176        );
177    }
178
179    Ok(())
180}
181
182/// All decompositions for a single tile.
183#[derive(Clone)]
184pub(crate) struct TileDecompositions {
185    pub(crate) first_ll_sub_band: usize,
186    pub(crate) decompositions: Range<usize>,
187}
188
189impl TileDecompositions {
190    pub(crate) fn sub_band_iter(
191        &self,
192        resolution: u8,
193        decompositions: &[Decomposition],
194    ) -> SubBandIter {
195        let indices = if resolution == 0 {
196            [
197                self.first_ll_sub_band,
198                self.first_ll_sub_band,
199                self.first_ll_sub_band,
200            ]
201        } else {
202            decompositions[self.decompositions.clone()][resolution as usize - 1].sub_bands
203        };
204
205        SubBandIter {
206            next_idx: 0,
207            indices,
208            resolution,
209        }
210    }
211}
212
/// Iterator over the sub-band indices of a single resolution level.
///
/// At resolution 0 it yields only `indices[0]`. At any other resolution it
/// yields all three entries of `indices` in order.
#[derive(Clone)]
pub(crate) struct SubBandIter {
    resolution: u8,
    next_idx: usize,
    indices: [usize; 3],
}

impl Iterator for SubBandIter {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let available = match self.resolution {
            0 => 1,
            _ => self.indices.len(),
        };

        let position = self.next_idx;
        // The cursor advances on every call, including after exhaustion.
        self.next_idx += 1;

        if position < available {
            Some(self.indices[position])
        } else {
            None
        }
    }
}
241
242/// A buffer so that we can reuse allocations for layers/code blocks/etc.
243/// across different tiles.
244#[derive(Default)]
245pub(crate) struct DecompositionStorage<'a> {
246    pub(crate) segments: Vec<Segment<'a>>,
247    pub(crate) layers: Vec<Layer>,
248    pub(crate) code_blocks: Vec<CodeBlock>,
249    pub(crate) precincts: Vec<Precinct>,
250    pub(crate) tag_tree_nodes: Vec<TagNode>,
251    pub(crate) coefficients: Vec<f32>,
252    pub(crate) sub_bands: Vec<SubBand>,
253    pub(crate) decompositions: Vec<Decomposition>,
254    pub(crate) tile_decompositions: Vec<TileDecompositions>,
255}
256
257impl DecompositionStorage<'_> {
258    fn reset(&mut self) {
259        self.segments.clear();
260        self.layers.clear();
261        self.code_blocks.clear();
262        // No need to clear the coefficients, as they will be resized
263        // and then overridden.
264        // self.coefficients.clear();
265        self.precincts.clear();
266        self.sub_bands.clear();
267        self.decompositions.clear();
268        self.tile_decompositions.clear();
269        self.tag_tree_nodes.clear();
270    }
271}
272
273/// A reusable context used during the decoding of a single tile.
274///
275/// Some of the fields are temporary in nature and reset after moving on to the
276/// next tile, some contain global state.
277#[derive(Default)]
278pub(crate) struct TileDecodeContext {
279    /// A reusable buffer for the IDWT output.
280    pub(crate) idwt_output: IDWTOutput,
281    /// A scratch buffer used during IDWT.
282    pub(crate) idwt_scratch_buffer: Vec<f32>,
283    /// A reusable context for decoding code blocks.
284    pub(crate) bit_plane_decode_context: BitPlaneDecodeContext,
285    /// Reusable buffers for decoding bitplanes.
286    pub(crate) bit_plane_decode_buffers: BitPlaneDecodeBuffers,
287}
288
289impl TileDecodeContext {
290    fn reset(&mut self) {
291        // This method doesn't do anything, just keeping it there in case
292        // it's needed in the future.
293        // Bitplane decode context and buffers will be reset in the
294        // corresponding methods. IDWT output and scratch buffer will be
295        // overridden on demand, so those don't need to be reset either.
296    }
297}
298
299fn decode_component_tile_bit_planes<'a>(
300    tile: &Tile<'a>,
301    tile_ctx: &mut TileDecodeContext,
302    storage: &mut DecompositionStorage<'a>,
303    header: &Header<'_>,
304) -> Result<()> {
305    for (tile_decompositions_idx, component_info) in tile.component_infos.iter().enumerate() {
306        // Only decode the resolution levels we actually care about.
307        for resolution in
308            0..component_info.num_resolution_levels() - header.skipped_resolution_levels
309        {
310            let tile_composition = &storage.tile_decompositions[tile_decompositions_idx];
311            let sub_band_iter = tile_composition.sub_band_iter(resolution, &storage.decompositions);
312
313            for sub_band_idx in sub_band_iter {
314                decode_sub_band_bitplanes(
315                    sub_band_idx,
316                    resolution,
317                    component_info,
318                    tile_ctx,
319                    storage,
320                    header,
321                )?;
322            }
323        }
324    }
325
326    Ok(())
327}
328
329fn decode_sub_band_bitplanes(
330    sub_band_idx: usize,
331    resolution: u8,
332    component_info: &ComponentInfo,
333    tile_ctx: &mut TileDecodeContext,
334    storage: &mut DecompositionStorage<'_>,
335    header: &Header<'_>,
336) -> Result<()> {
337    let sub_band = &storage.sub_bands[sub_band_idx];
338
339    let dequantization_step = {
340        if component_info.quantization_info.quantization_style == QuantizationStyle::NoQuantization
341        {
342            1.0
343        } else {
344            let (exponent, mantissa) =
345                component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
346
347            let r_b = {
348                let log_gain = match sub_band.sub_band_type {
349                    SubBandType::LowLow => 0,
350                    SubBandType::LowHigh => 1,
351                    SubBandType::HighLow => 1,
352                    SubBandType::HighHigh => 2,
353                };
354
355                component_info.size_info.precision as u16 + log_gain
356            };
357
358            crate::math::pow2i(r_b as i32 - exponent as i32) * (1.0 + (mantissa as f32) / 2048.0)
359        }
360    };
361
362    let num_bitplanes = {
363        let (exponent, _) = component_info.exponent_mantissa(sub_band.sub_band_type, resolution)?;
364        // Equation (E-2)
365        let num_bitplanes = (component_info.quantization_info.guard_bits as u16)
366            .checked_add(exponent)
367            .and_then(|x| x.checked_sub(1))
368            .ok_or(DecodingError::InvalidBitplaneCount)?;
369
370        if num_bitplanes > MAX_BITPLANE_COUNT as u16 {
371            bail!(DecodingError::TooManyBitplanes);
372        }
373
374        num_bitplanes as u8
375    };
376
377    for precinct in sub_band
378        .precincts
379        .clone()
380        .map(|idx| &storage.precincts[idx])
381    {
382        for code_block in precinct
383            .code_blocks
384            .clone()
385            .map(|idx| &storage.code_blocks[idx])
386        {
387            bitplane::decode(
388                code_block,
389                sub_band.sub_band_type,
390                num_bitplanes,
391                &component_info.coding_style.parameters.code_block_style,
392                tile_ctx,
393                storage,
394                header.strict,
395            )?;
396
397            // Turn the signs and magnitudes into singular coefficients and
398            // copy them into the sub-band.
399
400            let x_offset = code_block.rect.x0 - sub_band.rect.x0;
401            let y_offset = code_block.rect.y0 - sub_band.rect.y0;
402
403            let base_store = &mut storage.coefficients[sub_band.coefficients.clone()];
404            let mut base_idx = (y_offset * sub_band.rect.width()) as usize + x_offset as usize;
405
406            for coefficients in tile_ctx.bit_plane_decode_context.coefficient_rows() {
407                let out_row = &mut base_store[base_idx..];
408
409                for (output, coefficient) in out_row.iter_mut().zip(coefficients.iter().copied()) {
410                    *output = coefficient.get() as f32;
411                    *output *= dequantization_step;
412                }
413
414                base_idx += sub_band.rect.width() as usize;
415            }
416        }
417    }
418
419    Ok(())
420}
421
422fn apply_sign_shift(channel_data: &mut [ComponentData], component_infos: &[ComponentInfo]) {
423    use crate::math::{Level, dispatch, f32x8};
424
425    for (channel, component_info) in channel_data.iter_mut().zip(component_infos.iter()) {
426        let offset = (1_u32 << (component_info.size_info.precision - 1)) as f32;
427        dispatch!(Level::new(), simd => {
428            let offset_v = f32x8::splat(simd, offset);
429            for chunk in channel.container.chunks_exact_mut(8) {
430                let v = f32x8::from_slice(simd, chunk);
431                (v + offset_v).store(chunk);
432            }
433        });
434    }
435}
436
437fn store<'a>(
438    tile: &'a Tile<'a>,
439    header: &Header<'_>,
440    tile_ctx: &mut TileDecodeContext,
441    channel_data: &mut ComponentData,
442    component_info: &ComponentInfo,
443) {
444    let idwt_output = &mut tile_ctx.idwt_output;
445
446    let component_tile = ComponentTile::new(tile, component_info);
447    let resolution_tile = ResolutionTile::new(
448        component_tile,
449        component_info.num_resolution_levels() - 1 - header.skipped_resolution_levels,
450    );
451
452    let (scale_x, scale_y) = (
453        component_info.size_info.horizontal_resolution,
454        component_info.size_info.vertical_resolution,
455    );
456
457    let (image_x_offset, image_y_offset) = (
458        header.size_data.image_area_x_offset,
459        header.size_data.image_area_y_offset,
460    );
461
462    if scale_x == 1 && scale_y == 1 {
463        // If no sub-sampling, use a fast path where we copy rows of coefficients
464        // at once.
465
466        // The rect of the IDWT output corresponds to the rect of the highest
467        // decomposition level of the tile, which is usually not 1:1 aligned
468        // with the actual tile rectangle. We also need to account for the
469        // offset of the reference grid.
470
471        let skip_x = image_x_offset.saturating_sub(idwt_output.rect.x0);
472        let skip_y = image_y_offset.saturating_sub(idwt_output.rect.y0);
473
474        let input_row_iter = idwt_output
475            .coefficients
476            .chunks_exact(idwt_output.rect.width() as usize)
477            .skip(skip_y as usize)
478            .take(idwt_output.rect.height() as usize);
479
480        let output_row_iter = channel_data
481            .container
482            .chunks_exact_mut(header.size_data.image_width() as usize)
483            .skip(resolution_tile.rect.y0.saturating_sub(image_y_offset) as usize);
484
485        for (input_row, output_row) in input_row_iter.zip(output_row_iter) {
486            let input_row = &input_row[skip_x as usize..];
487            let output_row = &mut output_row
488                [resolution_tile.rect.x0.saturating_sub(image_x_offset) as usize..]
489                [..input_row.len()];
490
491            output_row.copy_from_slice(input_row);
492        }
493    } else {
494        let image_width = header.size_data.image_width();
495        let image_height = header.size_data.image_height();
496
497        let x_shrink_factor = header.size_data.x_shrink_factor;
498        let y_shrink_factor = header.size_data.y_shrink_factor;
499
500        let x_offset = header
501            .size_data
502            .image_area_x_offset
503            .div_ceil(x_shrink_factor);
504        let y_offset = header
505            .size_data
506            .image_area_y_offset
507            .div_ceil(y_shrink_factor);
508
509        // Otherwise, copy sample by sample.
510        for y in resolution_tile.rect.y0..resolution_tile.rect.y1 {
511            let relative_y = (y - component_tile.rect.y0) as usize;
512            let reference_grid_y = (scale_y as u32 * y) / y_shrink_factor;
513
514            for x in resolution_tile.rect.x0..resolution_tile.rect.x1 {
515                let relative_x = (x - component_tile.rect.x0) as usize;
516                let reference_grid_x = (scale_x as u32 * x) / x_shrink_factor;
517
518                let sample = idwt_output.coefficients
519                    [relative_y * idwt_output.rect.width() as usize + relative_x];
520
521                for x_position in u32::max(reference_grid_x, x_offset)
522                    ..u32::min(reference_grid_x + scale_x as u32, image_width + x_offset)
523                {
524                    for y_position in u32::max(reference_grid_y, y_offset)
525                        ..u32::min(reference_grid_y + scale_y as u32, image_height + y_offset)
526                    {
527                        let pos = (y_position - y_offset) as usize * image_width as usize
528                            + (x_position - x_offset) as usize;
529
530                        channel_data.container[pos] = sample;
531                    }
532                }
533            }
534        }
535    }
536}