jxl_vardct/
hf_coeff.rs

1use jxl_bitstream::Bitstream;
2use jxl_grid::{AllocTracker, MutableSubgrid, SharedSubgrid};
3use jxl_modular::{ChannelShift, Sample};
4
5use crate::{BlockInfo, HfBlockContext, HfPass, Result};
6
/// Parameters for decoding `HfCoeff`.
///
/// The struct is consumed by [`write_hf_coeff`], which destructures it and uses every field.
#[derive(Debug)]
pub struct HfCoeffParams<'a, 'b, S: Sample> {
    /// Number of HF presets. The preset selector read from the bitstream must be strictly less
    /// than this value, otherwise decoding fails.
    pub num_hf_presets: u32,
    /// Block context data: quantization-factor thresholds, LF thresholds, the block context map
    /// and the number of block clusters.
    pub hf_block_ctx: &'a HfBlockContext,
    /// Grid of block information. Decoding iterates over its full width and height and only
    /// processes cells that are `BlockInfo::Data`.
    pub block_info: SharedSubgrid<'a, BlockInfo>,
    /// Per-channel JPEG upsampling values, converted to channel shifts with
    /// `ChannelShift::from_jpeg_upsampling`.
    pub jpeg_upsampling: [u32; 3],
    /// Quantized LF samples for the three channels, compared against the LF thresholds to pick a
    /// block context. When `None`, the LF part of the context index is zero.
    pub lf_quant: Option<[SharedSubgrid<'a, S>; 3]>,
    /// HF pass data; provides the entropy decoder and the coefficient orders.
    pub hf_pass: &'a HfPass,
    /// Left shift applied to every decoded (signed) coefficient before it is accumulated.
    pub coeff_shift: u32,
    /// Optional allocation tracker used to account for the temporary non-zeros buffer.
    pub tracker: Option<&'b AllocTracker>,
}
19
20/// Decode and write HF coefficients from the bitstream.
21pub fn write_hf_coeff<S: Sample>(
22    bitstream: &mut Bitstream,
23    params: HfCoeffParams<S>,
24    hf_coeff_output: &mut [MutableSubgrid<i32>; 3],
25) -> Result<()> {
26    const COEFF_FREQ_CONTEXT: [u32; 63] = [
27        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19,
28        20, 20, 21, 21, 22, 22, 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, 27,
29        27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30,
30    ];
31    const COEFF_NUM_NONZERO_CONTEXT: [u32; 63] = [
32        0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, 152, 152, 152, 152, 152, 152, 152, 152,
33        180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206,
34        206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206,
35        206, 206, 206, 206, 206, 206, 206,
36    ];
37
38    let HfCoeffParams {
39        num_hf_presets,
40        hf_block_ctx,
41        block_info,
42        jpeg_upsampling,
43        lf_quant,
44        hf_pass,
45        coeff_shift,
46        tracker,
47    } = params;
48    let mut dist = hf_pass.clone_decoder();
49
50    let HfBlockContext {
51        qf_thresholds,
52        lf_thresholds,
53        block_ctx_map,
54        num_block_clusters,
55    } = hf_block_ctx;
56    let lf_idx_mul =
57        (lf_thresholds[0].len() + 1) * (lf_thresholds[1].len() + 1) * (lf_thresholds[2].len() + 1);
58    let hf_idx_mul = qf_thresholds.len() + 1;
59    let upsampling_shifts: [_; 3] =
60        std::array::from_fn(|idx| ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx));
61    let hshifts = upsampling_shifts.map(|shift| shift.hshift());
62    let vshifts = upsampling_shifts.map(|shift| shift.vshift());
63
64    let hfp_bits = num_hf_presets.next_power_of_two().trailing_zeros();
65    let hfp = bitstream.read_bits(hfp_bits as usize)?;
66    if hfp >= num_hf_presets {
67        tracing::error!(hfp, num_hf_presets, "selected HF preset out of bounds");
68        return Err(
69            jxl_bitstream::Error::ValidationFailed("selected HF preset out of bounds").into(),
70        );
71    }
72
73    let ctx_size = 495 * *num_block_clusters;
74    let cluster_map = dist.cluster_map()[(ctx_size * hfp) as usize..][..ctx_size as usize].to_vec();
75
76    dist.begin(bitstream)?;
77
78    let width = block_info.width();
79    let height = block_info.height();
80    let non_zeros_grid_lengths =
81        upsampling_shifts.map(|shift| shift.shift_size((width as u32, height as u32)).0 as usize);
82
83    let _non_zeros_grid_handle = tracker
84        .map(|tracker| {
85            let len =
86                non_zeros_grid_lengths[0] + non_zeros_grid_lengths[1] + non_zeros_grid_lengths[2];
87            tracker.alloc::<u32>(len)
88        })
89        .transpose()?;
90    let mut non_zeros_grid_row = [
91        vec![0u32; non_zeros_grid_lengths[0]],
92        vec![0u32; non_zeros_grid_lengths[1]],
93        vec![0u32; non_zeros_grid_lengths[2]],
94    ];
95
96    for y in 0..height {
97        for x in 0..width {
98            let BlockInfo::Data {
99                dct_select,
100                hf_mul: qf,
101            } = block_info.get(x, y)
102            else {
103                continue;
104            };
105            let (w8, h8) = dct_select.dct_select_size();
106            let num_blocks = w8 * h8; // power of 2
107            let num_blocks_log = num_blocks.trailing_zeros();
108            let order_id = dct_select.order_id();
109
110            let lf_idx = if let Some(lf_quant) = &lf_quant {
111                let mut idx = 0usize;
112                for c in [0, 2, 1] {
113                    let lf_thresholds = &lf_thresholds[c];
114                    idx *= lf_thresholds.len() + 1;
115
116                    let x = x >> hshifts[c];
117                    let y = y >> vshifts[c];
118                    let q = lf_quant[c].get(x, y);
119                    for &threshold in lf_thresholds {
120                        if q.to_i32() > threshold {
121                            idx += 1;
122                        }
123                    }
124                }
125                idx
126            } else {
127                0
128            };
129
130            let hf_idx = {
131                let mut idx = 0usize;
132                for &threshold in qf_thresholds {
133                    if qf > threshold as i32 {
134                        idx += 1;
135                    }
136                }
137                idx
138            };
139
140            for c in 0..3 {
141                let ch_idx = c * 13 + order_id as usize;
142                let c = [1, 0, 2][c]; // y, x, b
143
144                let hshift = hshifts[c];
145                let vshift = vshifts[c];
146                let sx = x >> hshift;
147                let sy = y >> vshift;
148                if hshift != 0 || vshift != 0 {
149                    if sx << hshift != x || sy << vshift != y {
150                        continue;
151                    }
152                    if !matches!(block_info.get(sx, sy), BlockInfo::Data { .. }) {
153                        continue;
154                    }
155                }
156
157                let idx = (ch_idx * hf_idx_mul + hf_idx) * lf_idx_mul + lf_idx;
158                let block_ctx = block_ctx_map[idx] as u32;
159                let non_zeros_ctx = {
160                    let predicted = if sy == 0 {
161                        if sx == 0 {
162                            32
163                        } else {
164                            non_zeros_grid_row[c][sx - 1]
165                        }
166                    } else if sx == 0 {
167                        non_zeros_grid_row[c][sx]
168                    } else {
169                        (non_zeros_grid_row[c][sx] + non_zeros_grid_row[c][sx - 1] + 1) >> 1
170                    };
171                    debug_assert!(predicted < 64);
172
173                    let idx = if predicted >= 8 {
174                        4 + predicted / 2
175                    } else {
176                        predicted
177                    };
178                    block_ctx + idx * num_block_clusters
179                };
180
181                let mut non_zeros = dist.read_varint_with_multiplier_clustered(
182                    bitstream,
183                    cluster_map[non_zeros_ctx as usize],
184                    0,
185                )?;
186                if non_zeros > (63 << num_blocks_log) {
187                    tracing::error!(non_zeros, num_blocks, "non_zeros too large");
188                    return Err(
189                        jxl_bitstream::Error::ValidationFailed("non_zeros too large").into(),
190                    );
191                }
192
193                let non_zeros_val = (non_zeros + num_blocks - 1) >> num_blocks_log;
194                for dx in 0..w8 as usize {
195                    non_zeros_grid_row[c][sx + dx] = non_zeros_val;
196                }
197                if non_zeros == 0 {
198                    continue;
199                }
200
201                let coeff_grid = &mut hf_coeff_output[c];
202                let mut is_prev_coeff_nonzero = (non_zeros <= num_blocks * 4) as u32;
203                let order = hf_pass.order(order_id as usize, c);
204
205                let coeff_ctx_base = block_ctx * 458 + 37 * num_block_clusters;
206                let cluster_map = &cluster_map[coeff_ctx_base as usize..][..458];
207                for (idx, &coeff_coord) in order[num_blocks as usize..].iter().enumerate() {
208                    let coeff_ctx = {
209                        let non_zeros = (non_zeros - 1) >> num_blocks_log;
210                        let idx = idx >> num_blocks_log;
211                        (COEFF_NUM_NONZERO_CONTEXT[non_zeros as usize] + COEFF_FREQ_CONTEXT[idx])
212                            * 2
213                            + is_prev_coeff_nonzero
214                    };
215                    let cluster = *cluster_map.get(coeff_ctx as usize).ok_or_else(|| {
216                        tracing::error!("too many zeros in varblock HF coefficient");
217                        jxl_bitstream::Error::ValidationFailed(
218                            "too many zeros in varblock HF coefficient",
219                        )
220                    })?;
221                    let ucoeff =
222                        dist.read_varint_with_multiplier_clustered(bitstream, cluster, 0)?;
223                    if ucoeff == 0 {
224                        is_prev_coeff_nonzero = 0;
225                        continue;
226                    }
227
228                    let coeff = jxl_bitstream::unpack_signed(ucoeff) << coeff_shift;
229                    let (mut dx, mut dy) = coeff_coord;
230                    if dct_select.need_transpose() {
231                        std::mem::swap(&mut dx, &mut dy);
232                    }
233                    let x = sx * 8 + dx as usize;
234                    let y = sy * 8 + dy as usize;
235
236                    *coeff_grid.get_mut(x, y) += coeff;
237
238                    is_prev_coeff_nonzero = 1;
239                    non_zeros -= 1;
240
241                    if non_zeros == 0 {
242                        break;
243                    }
244                }
245            }
246        }
247    }
248
249    dist.finalize()?;
250
251    Ok(())
252}