1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
use jxl_bitstream::Bitstream;
use jxl_grid::{AllocTracker, CutGrid};
use jxl_modular::{image::TransformedModularSubimage, ChannelShift, MaConfig, Sample};
use jxl_threadpool::JxlThreadPool;
use jxl_vardct::{write_hf_coeff, HfCoeffParams};

use super::{HfGlobal, LfGlobalVarDct, LfGroup};
use crate::{FrameHeader, Result};

/// Inputs needed by [`decode_pass_group`] to decode one group of one progressive pass.
///
/// The VarDCT part is decoded only when `vardct` is `Some` and `lf_group` carries HF metadata.
/// The modular part is decoded only when `modular` is `Some`.
#[derive(Debug)]
pub struct PassGroupParams<'frame, 'buf, 'g, 'tracker, S: Sample> {
    /// Header of the frame being decoded.
    /// It supplies the group layout, the per-pass coefficient shifts and JPEG upsampling.
    pub frame_header: &'frame FrameHeader,
    /// LF group that covers this pass group.
    /// Its HF metadata and LF coefficients are used for VarDCT decoding.
    pub lf_group: &'frame LfGroup<S>,
    /// Index of the progressive pass being decoded.
    pub pass_idx: u32,
    /// Index of this group within the frame, in row-major order (`groups_per_row` per row).
    pub group_idx: u32,
    /// Global MA tree configuration, if any, forwarded to modular decoding.
    pub global_ma_config: Option<&'frame MaConfig>,
    /// Modular subimage for this group; when `Some`, it is decoded after the VarDCT part.
    pub modular: Option<TransformedModularSubimage<'g, S>>,
    /// VarDCT-specific inputs and output buffers; `None` skips HF coefficient decoding.
    pub vardct: Option<PassGroupParamsVardct<'frame, 'buf, 'g>>,
    /// When `true`, an unexpected end of stream while reading HF coefficients is tolerated.
    /// Decoding then stops early with `Ok(())`.
    pub allow_partial: bool,
    /// Optional allocation tracker forwarded to the coefficient and modular decoders.
    pub tracker: Option<&'tracker AllocTracker>,
    /// Thread pool used when finishing the modular subimage.
    pub pool: &'frame JxlThreadPool,
}

/// VarDCT-only inputs for [`decode_pass_group`].
#[derive(Debug)]
pub struct PassGroupParamsVardct<'frame, 'buf, 'g> {
    /// Global LF VarDCT data; its `hf_block_ctx` is passed to HF coefficient decoding.
    pub lf_vardct: &'frame LfGlobalVarDct,
    /// Global HF data, which provides the per-pass HF state (`hf_passes`) and `num_hf_presets`.
    pub hf_global: &'frame HfGlobal,
    /// Output grids, one per channel, handed to `write_hf_coeff` to receive the HF coefficients.
    pub hf_coeff_output: &'buf mut [CutGrid<'g, f32>; 3],
}

/// Decodes the data of a single pass group from `bitstream`.
///
/// Two parts are read, in this order:
///
/// 1. VarDCT HF coefficients. They are read only when `params.vardct` is `Some` *and* the LF
///    group carries HF metadata, and are written into `hf_coeff_output` via [`write_hf_coeff`].
/// 2. Modular image data, only when `params.modular` is `Some`
///    (see [`decode_pass_group_modular`]).
///
/// If `allow_partial` is set and the HF coefficient stream ends unexpectedly, the condition is
/// logged and `Ok(())` is returned right away, so the modular part is skipped.
///
/// # Errors
///
/// Propagates errors from HF coefficient decoding, except the tolerated partial case above.
/// Also propagates errors from modular decoding.
pub fn decode_pass_group<S: Sample>(
    bitstream: &mut Bitstream,
    params: PassGroupParams<S>,
) -> Result<()> {
    let PassGroupParams {
        frame_header,
        lf_group,
        pass_idx,
        group_idx,
        global_ma_config,
        modular,
        vardct,
        allow_partial,
        tracker,
        pool,
    } = params;

    // HF coefficients need both the caller's VarDCT inputs and the LF group's HF metadata.
    let vardct_inputs = vardct.zip(lf_group.hf_meta.as_ref());
    if let Some((
        PassGroupParamsVardct {
            lf_vardct,
            hf_global,
            hf_coeff_output,
        },
        hf_meta,
    )) = vardct_inputs
    {
        let pass = pass_idx as usize;
        let hf_pass = &hf_global.hf_passes[pass];
        // Passes without an explicit shift entry use a shift of zero.
        let coeff_shift = frame_header.passes.shift.get(pass).copied().unwrap_or(0);

        // Position of this group inside its LF group; the `% 8` treats an LF group as
        // spanning 8 groups along each axis.
        let groups_per_row = frame_header.groups_per_row();
        let lf_col = ((group_idx % groups_per_row) % 8) as usize;
        let lf_row = ((group_idx / groups_per_row) % 8) as usize;
        let group_dim_blocks = (frame_header.group_dim() / 8) as usize;

        // Block-unit rectangle of this group, clipped to the LF group's block grid.
        let lf_block_info = &hf_meta.block_info;
        let block_left = lf_col * group_dim_blocks;
        let block_top = lf_row * group_dim_blocks;
        let block_width = group_dim_blocks.min(lf_block_info.width() - block_left);
        let block_height = group_dim_blocks.min(lf_block_info.height() - block_top);
        let block_info = lf_block_info.subgrid(
            block_left..block_left + block_width,
            block_top..block_top + block_height,
        );

        let jpeg_upsampling = frame_header.jpeg_upsampling;
        // The matching LF quantization rectangle for each channel, scaled by that channel's
        // JPEG upsampling shift.
        let lf_quant: Option<[_; 3]> = lf_group.lf_coeff.as_ref().map(|lf_coeff| {
            // Maps output channel `idx` to the index of its channel in the lf_quant image.
            const LF_QUANT_CHANNEL: [usize; 3] = [1, 0, 2];
            let channels = lf_coeff.lf_quant.image().unwrap().image_channels();
            std::array::from_fn(|idx| {
                let shift = ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx);
                let (width, height) =
                    shift.shift_size((block_width as u32, block_height as u32));
                let left = block_left >> shift.hshift();
                let top = block_top >> shift.vshift();
                channels[LF_QUANT_CHANNEL[idx]].subgrid(
                    left..left + width as usize,
                    top..top + height as usize,
                )
            })
        });

        let hf_params = HfCoeffParams {
            num_hf_presets: hf_global.num_hf_presets,
            hf_block_ctx: &lf_vardct.hf_block_ctx,
            block_info,
            jpeg_upsampling,
            lf_quant,
            hf_pass,
            coeff_shift,
            tracker,
        };

        if let Err(e) = write_hf_coeff(bitstream, hf_params, hf_coeff_output) {
            if e.unexpected_eof() && allow_partial {
                tracing::debug!("Partially decoded HfCoeff");
                return Ok(());
            }
            return Err(e.into());
        }
    }

    match modular {
        Some(modular) => decode_pass_group_modular(
            bitstream,
            frame_header,
            global_ma_config,
            pass_idx,
            group_idx,
            modular,
            allow_partial,
            tracker,
            pool,
        ),
        None => Ok(()),
    }
}

/// Decodes the modular data of a single pass group into `modular`.
///
/// An empty subimage is skipped without touching `bitstream`. Otherwise the subimage is
/// prepared in two steps: `recursive` (with the optional global MA config), then
/// `prepare_subimage`. It is then decoded from the stream index belonging to
/// (`pass_idx`, `group_idx`) and finalized with `finish` on `pool`.
///
/// # Errors
///
/// Returns any error from preparing or decoding the subimage. `allow_partial` is forwarded
/// unchanged to the subimage decoder.
// The argument list mirrors the fields unpacked from `PassGroupParams`, hence the lint allow.
#[allow(clippy::too_many_arguments)]
pub fn decode_pass_group_modular<S: Sample>(
    bitstream: &mut Bitstream,
    frame_header: &FrameHeader,
    global_ma_config: Option<&MaConfig>,
    pass_idx: u32,
    group_idx: u32,
    modular: TransformedModularSubimage<S>,
    allow_partial: bool,
    tracker: Option<&AllocTracker>,
    pool: &JxlThreadPool,
) -> Result<()> {
    if !modular.is_empty() {
        let mut image = modular.recursive(bitstream, global_ma_config, tracker)?;
        let mut subimage = image.prepare_subimage()?;

        // Streams preceding the pass groups occupy `1 + 3 * num_lf_groups + 17` indices
        // (NOTE(review): presumably the global stream, three per-LF-group streams and 17
        // dequantization-table streams per the JPEG XL spec — confirm). After that, each pass
        // contributes `num_groups` consecutive streams.
        let first_pass_group_stream = 1 + 3 * frame_header.num_lf_groups() + 17;
        let stream_index =
            first_pass_group_stream + pass_idx * frame_header.num_groups() + group_idx;

        subimage.decode(bitstream, stream_index, allow_partial)?;
        subimage.finish(pool);
    }
    Ok(())
}