1use jxl_bitstream::Bitstream;
2use jxl_grid::{AllocTracker, MutableSubgrid, SharedSubgrid};
3use jxl_modular::{ChannelShift, Sample};
4
5use crate::{BlockInfo, HfBlockContext, HfPass, Result};
6
7#[derive(Debug)]
9pub struct HfCoeffParams<'a, 'b, S: Sample> {
10 pub num_hf_presets: u32,
11 pub hf_block_ctx: &'a HfBlockContext,
12 pub block_info: SharedSubgrid<'a, BlockInfo>,
13 pub jpeg_upsampling: [u32; 3],
14 pub lf_quant: Option<[SharedSubgrid<'a, S>; 3]>,
15 pub hf_pass: &'a HfPass,
16 pub coeff_shift: u32,
17 pub tracker: Option<&'b AllocTracker>,
18}
19
20pub fn write_hf_coeff<S: Sample>(
22 bitstream: &mut Bitstream,
23 params: HfCoeffParams<S>,
24 hf_coeff_output: &mut [MutableSubgrid<i32>; 3],
25) -> Result<()> {
26 const COEFF_FREQ_CONTEXT: [u32; 63] = [
27 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19,
28 20, 20, 21, 21, 22, 22, 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, 27,
29 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30,
30 ];
31 const COEFF_NUM_NONZERO_CONTEXT: [u32; 63] = [
32 0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, 152, 152, 152, 152, 152, 152, 152, 152,
33 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206,
34 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206,
35 206, 206, 206, 206, 206, 206, 206,
36 ];
37
38 let HfCoeffParams {
39 num_hf_presets,
40 hf_block_ctx,
41 block_info,
42 jpeg_upsampling,
43 lf_quant,
44 hf_pass,
45 coeff_shift,
46 tracker,
47 } = params;
48 let mut dist = hf_pass.clone_decoder();
49
50 let HfBlockContext {
51 qf_thresholds,
52 lf_thresholds,
53 block_ctx_map,
54 num_block_clusters,
55 } = hf_block_ctx;
56 let lf_idx_mul =
57 (lf_thresholds[0].len() + 1) * (lf_thresholds[1].len() + 1) * (lf_thresholds[2].len() + 1);
58 let hf_idx_mul = qf_thresholds.len() + 1;
59 let upsampling_shifts: [_; 3] =
60 std::array::from_fn(|idx| ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx));
61 let hshifts = upsampling_shifts.map(|shift| shift.hshift());
62 let vshifts = upsampling_shifts.map(|shift| shift.vshift());
63
64 let hfp_bits = num_hf_presets.next_power_of_two().trailing_zeros();
65 let hfp = bitstream.read_bits(hfp_bits as usize)?;
66 if hfp >= num_hf_presets {
67 tracing::error!(hfp, num_hf_presets, "selected HF preset out of bounds");
68 return Err(
69 jxl_bitstream::Error::ValidationFailed("selected HF preset out of bounds").into(),
70 );
71 }
72
73 let ctx_size = 495 * *num_block_clusters;
74 let cluster_map = dist.cluster_map()[(ctx_size * hfp) as usize..][..ctx_size as usize].to_vec();
75
76 dist.begin(bitstream)?;
77
78 let width = block_info.width();
79 let height = block_info.height();
80 let non_zeros_grid_lengths =
81 upsampling_shifts.map(|shift| shift.shift_size((width as u32, height as u32)).0 as usize);
82
83 let _non_zeros_grid_handle = tracker
84 .map(|tracker| {
85 let len =
86 non_zeros_grid_lengths[0] + non_zeros_grid_lengths[1] + non_zeros_grid_lengths[2];
87 tracker.alloc::<u32>(len)
88 })
89 .transpose()?;
90 let mut non_zeros_grid_row = [
91 vec![0u32; non_zeros_grid_lengths[0]],
92 vec![0u32; non_zeros_grid_lengths[1]],
93 vec![0u32; non_zeros_grid_lengths[2]],
94 ];
95
96 for y in 0..height {
97 for x in 0..width {
98 let BlockInfo::Data {
99 dct_select,
100 hf_mul: qf,
101 } = block_info.get(x, y)
102 else {
103 continue;
104 };
105 let (w8, h8) = dct_select.dct_select_size();
106 let num_blocks = w8 * h8; let num_blocks_log = num_blocks.trailing_zeros();
108 let order_id = dct_select.order_id();
109
110 let lf_idx = if let Some(lf_quant) = &lf_quant {
111 let mut idx = 0usize;
112 for c in [0, 2, 1] {
113 let lf_thresholds = &lf_thresholds[c];
114 idx *= lf_thresholds.len() + 1;
115
116 let x = x >> hshifts[c];
117 let y = y >> vshifts[c];
118 let q = lf_quant[c].get(x, y);
119 for &threshold in lf_thresholds {
120 if q.to_i32() > threshold {
121 idx += 1;
122 }
123 }
124 }
125 idx
126 } else {
127 0
128 };
129
130 let hf_idx = {
131 let mut idx = 0usize;
132 for &threshold in qf_thresholds {
133 if qf > threshold as i32 {
134 idx += 1;
135 }
136 }
137 idx
138 };
139
140 for c in 0..3 {
141 let ch_idx = c * 13 + order_id as usize;
142 let c = [1, 0, 2][c]; let hshift = hshifts[c];
145 let vshift = vshifts[c];
146 let sx = x >> hshift;
147 let sy = y >> vshift;
148 if hshift != 0 || vshift != 0 {
149 if sx << hshift != x || sy << vshift != y {
150 continue;
151 }
152 if !matches!(block_info.get(sx, sy), BlockInfo::Data { .. }) {
153 continue;
154 }
155 }
156
157 let idx = (ch_idx * hf_idx_mul + hf_idx) * lf_idx_mul + lf_idx;
158 let block_ctx = block_ctx_map[idx] as u32;
159 let non_zeros_ctx = {
160 let predicted = if sy == 0 {
161 if sx == 0 {
162 32
163 } else {
164 non_zeros_grid_row[c][sx - 1]
165 }
166 } else if sx == 0 {
167 non_zeros_grid_row[c][sx]
168 } else {
169 (non_zeros_grid_row[c][sx] + non_zeros_grid_row[c][sx - 1] + 1) >> 1
170 };
171 debug_assert!(predicted < 64);
172
173 let idx = if predicted >= 8 {
174 4 + predicted / 2
175 } else {
176 predicted
177 };
178 block_ctx + idx * num_block_clusters
179 };
180
181 let mut non_zeros = dist.read_varint_with_multiplier_clustered(
182 bitstream,
183 cluster_map[non_zeros_ctx as usize],
184 0,
185 )?;
186 if non_zeros > (63 << num_blocks_log) {
187 tracing::error!(non_zeros, num_blocks, "non_zeros too large");
188 return Err(
189 jxl_bitstream::Error::ValidationFailed("non_zeros too large").into(),
190 );
191 }
192
193 let non_zeros_val = (non_zeros + num_blocks - 1) >> num_blocks_log;
194 for dx in 0..w8 as usize {
195 non_zeros_grid_row[c][sx + dx] = non_zeros_val;
196 }
197 if non_zeros == 0 {
198 continue;
199 }
200
201 let coeff_grid = &mut hf_coeff_output[c];
202 let mut is_prev_coeff_nonzero = (non_zeros <= num_blocks * 4) as u32;
203 let order = hf_pass.order(order_id as usize, c);
204
205 let coeff_ctx_base = block_ctx * 458 + 37 * num_block_clusters;
206 let cluster_map = &cluster_map[coeff_ctx_base as usize..][..458];
207 for (idx, &coeff_coord) in order[num_blocks as usize..].iter().enumerate() {
208 let coeff_ctx = {
209 let non_zeros = (non_zeros - 1) >> num_blocks_log;
210 let idx = idx >> num_blocks_log;
211 (COEFF_NUM_NONZERO_CONTEXT[non_zeros as usize] + COEFF_FREQ_CONTEXT[idx])
212 * 2
213 + is_prev_coeff_nonzero
214 };
215 let cluster = *cluster_map.get(coeff_ctx as usize).ok_or_else(|| {
216 tracing::error!("too many zeros in varblock HF coefficient");
217 jxl_bitstream::Error::ValidationFailed(
218 "too many zeros in varblock HF coefficient",
219 )
220 })?;
221 let ucoeff =
222 dist.read_varint_with_multiplier_clustered(bitstream, cluster, 0)?;
223 if ucoeff == 0 {
224 is_prev_coeff_nonzero = 0;
225 continue;
226 }
227
228 let coeff = jxl_bitstream::unpack_signed(ucoeff) << coeff_shift;
229 let (mut dx, mut dy) = coeff_coord;
230 if dct_select.need_transpose() {
231 std::mem::swap(&mut dx, &mut dy);
232 }
233 let x = sx * 8 + dx as usize;
234 let y = sy * 8 + dy as usize;
235
236 *coeff_grid.get_mut(x, y) += coeff;
237
238 is_prev_coeff_nonzero = 1;
239 non_zeros -= 1;
240
241 if non_zeros == 0 {
242 break;
243 }
244 }
245 }
246 }
247 }
248
249 dist.finalize()?;
250
251 Ok(())
252}