1use std::collections::{HashMap, HashSet};
4use std::io::Write;
5
6use brotli_decompressor::DecompressorWriter;
7use jxl_bitstream::{Bitstream, U};
8use jxl_frame::Frame;
9use jxl_oxide_common::Bundle;
10
11use crate::huffman::HuffmanCode;
12
13mod bit_writer;
14mod error;
15mod huffman;
16mod reconstruct;
17
18pub use error::Error;
19pub use reconstruct::JpegBitstreamReconstructor;
20
21use error::Result;
22
/// Byte signature associated with ICC profile APP markers (type 1).
///
/// Its length is subtracted from the marker length in `expected_icc_len`.
// NOTE(review): presumably the identifier that opens each ICC segment payload — confirm against
// the reconstruction code.
const HEADER_ICC: &[u8] = b"ICC_PROFILE\0";
/// Byte signature associated with the Exif APP marker (type 2).
///
/// Its length is subtracted from the marker length in `expected_exif_len`.
const HEADER_EXIF: &[u8] = b"Exif\0\0";
/// Byte signature associated with the XMP APP marker (type 3).
///
/// Its length is subtracted from the marker length in `expected_xmp_len`.
const HEADER_XMP: &[u8] = b"http://ns.adobe.com/xap/1.0/\0";
26
/// JPEG bitstream reconstruction data: a parsed header plus the Brotli-compressed data section
/// that follows it in the input.
pub struct JpegBitstreamData {
    /// Header decoded from the bit-packed prefix of the input.
    header: Box<JpegBitstreamHeader>,
    /// Brotli decompressor that writes into an in-memory `Vec<u8>`; compressed bytes are fed in
    /// through `try_parse` and `feed_bytes`, and the decompressed bytes are read back with
    /// `get_ref()`.
    data_stream: Box<DecompressorWriter<Vec<u8>>>,
}
32
33impl std::fmt::Debug for JpegBitstreamData {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("JpegBitstreamData")
36 .field("header", &self.header)
37 .finish_non_exhaustive()
38 }
39}
40
41impl JpegBitstreamData {
42 pub fn try_parse(data: &[u8]) -> Result<Option<Self>> {
49 let mut bitstream = Bitstream::new(data);
50 let header = match JpegBitstreamHeader::parse(&mut bitstream, ()) {
51 Ok(header) => Box::new(header),
52 Err(e) if e.unexpected_eof() => return Ok(None),
53 Err(e) => return Err(e.into()),
54 };
55 bitstream.zero_pad_to_byte()?;
56
57 let bytes_read = bitstream.num_read_bits() / 8;
58 let compressed_data = &data[bytes_read..];
59 let mut data_stream = Box::new(DecompressorWriter::new(Vec::new(), 4096));
60 data_stream
61 .write_all(compressed_data)
62 .map_err(Error::Brotli)?;
63
64 Ok(Some(Self {
65 header,
66 data_stream,
67 }))
68 }
69
70 pub fn feed_bytes(&mut self, data: &[u8]) -> Result<()> {
72 self.data_stream.write_all(data).map_err(Error::Brotli)
73 }
74
75 pub fn finalize(&mut self) -> Result<()> {
77 self.data_stream.flush().map_err(Error::Brotli)?;
78
79 let decompressed_len = self.data_stream.get_ref().len();
80 if decompressed_len != self.header.expected_data_len() {
81 tracing::error!(
82 decompressed_len,
83 expected = self.header.expected_data_len(),
84 "Data section length of jbrd box doesn't match expected length"
85 );
86 return Err(Error::InvalidData);
87 }
88
89 Ok(())
90 }
91
92 pub fn reconstruct<'jbrd, 'frame, 'meta>(
96 &'jbrd self,
97 frame: &'frame Frame,
98 icc_profile: &'meta [u8],
99 exif: &'meta [u8],
100 xmp: &'meta [u8],
101 pool: &jxl_threadpool::JxlThreadPool,
102 ) -> Result<JpegBitstreamReconstructor<'jbrd, 'frame, 'meta>> {
103 let Self {
104 ref header,
105 ref data_stream,
106 } = *self;
107 JpegBitstreamReconstructor::new(
108 header,
109 data_stream.get_ref(),
110 frame,
111 icc_profile,
112 exif,
113 xmp,
114 pool,
115 )
116 }
117
118 pub fn header(&self) -> &JpegBitstreamHeader {
119 &self.header
120 }
121}
122
/// Header of the JPEG bitstream reconstruction data, decoded by its `Bundle::parse`
/// implementation.
#[derive(Debug)]
pub struct JpegBitstreamHeader {
    /// First bit of the header.
    // NOTE(review): presumably marks a grayscale JPEG — confirm against the reconstructor.
    is_gray: bool,
    /// Marker codes in stream order. Each is stored as 6 bits offset by `0xc0`; the list ends
    /// with `0xd9`.
    markers: Vec<u8>,
    /// One entry per marker in `0xe0..=0xef`.
    app_markers: Vec<AppMarker>,
    /// One length per `0xfe` marker, stored as a 16-bit value plus one.
    com_lengths: Vec<u32>,
    /// Between one and four quantization table descriptors.
    quant_tables: Vec<QuantTable>,
    /// Component ids paired with a 2-bit quantization table index each.
    components: Vec<Component>,
    /// Huffman code descriptors.
    huffman_codes: Vec<HuffmanCode>,
    /// One entry per `0xda` marker.
    scan_info: Vec<ScanInfo>,
    /// 16-bit value read only when a `0xdd` marker is present; zero otherwise.
    restart_interval: u32,
    /// Additional per-scan tables, one entry per `0xda` marker.
    scan_more_info: Vec<ScanMoreInfo>,
    /// One 16-bit length per `0xff` entry in `markers`.
    intermarker_lengths: Vec<u32>,
    /// Length that is added to the expected size of the data section.
    tail_data_length: u32,
    /// Padding bits, present only when the header's padding flag is set.
    padding_bits: Option<Padding>,
}
139
140impl Bundle for JpegBitstreamHeader {
141 type Error = jxl_bitstream::Error;
142
143 fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
144 let is_gray = bitstream.read_bool()?;
145
146 let mut markers = Vec::new();
147 let mut num_app_markers = 0usize;
148 let mut num_com_markers = 0usize;
149 let mut num_scans = 0usize;
150 let mut num_intermarkers = 0usize;
151 let mut has_dri = false;
152 while markers.last() != Some(&0xd9) {
153 let marker_bits = bitstream.read_bits(6)? as u8 + 0xc0;
154 match marker_bits {
155 0xe0..=0xef => num_app_markers += 1,
156 0xfe => num_com_markers += 1,
157 0xda => num_scans += 1,
158 0xff => num_intermarkers += 1,
159 0xdd => has_dri = true,
160 _ => {}
161 }
162 markers.push(marker_bits);
163 }
164
165 let app_markers = (0..num_app_markers)
166 .map(|_| AppMarker::parse(bitstream, ()))
167 .collect::<Result<_, _>>()?;
168 let com_lengths = (0..num_com_markers)
169 .map(|_| bitstream.read_bits(16).map(|x| x + 1))
170 .collect::<Result<_, _>>()?;
171
172 let num_quant_tables = bitstream.read_bits(2)? + 1;
173 let quant_tables = (0..num_quant_tables)
174 .map(|_| QuantTable::parse(bitstream, ()))
175 .collect::<Result<_, _>>()?;
176
177 let comp_type = bitstream.read_bits(2)?;
178 let component_ids = match comp_type {
179 0 => vec![1u8],
180 1 => vec![1u8, 2, 3],
181 2 => vec![b'R', b'G', b'B'],
182 3 => {
183 let num_comp = bitstream.read_bits(2)? as u8 + 1;
184 (0..num_comp)
185 .map(|_| bitstream.read_bits(8).map(|x| x as u8))
186 .collect::<Result<_, _>>()?
187 }
188 _ => unreachable!(),
189 };
190 let components = component_ids
191 .into_iter()
192 .map(|id| -> Result<_, Self::Error> {
193 let q_idx = bitstream.read_bits(2)? as u8;
194 Ok(Component { id, q_idx })
195 })
196 .collect::<Result<_, _>>()?;
197
198 let num_huff = bitstream.read_u32(4, 2 + U(3), 10 + U(4), 26 + U(6))?;
199 let huffman_codes = (0..num_huff)
200 .map(|_| HuffmanCode::parse(bitstream, ()))
201 .collect::<Result<_, _>>()?;
202
203 let scan_info = (0..num_scans)
204 .map(|_| ScanInfo::parse(bitstream, ()))
205 .collect::<Result<_, _>>()?;
206 let restart_interval = if has_dri { bitstream.read_bits(16)? } else { 0 };
207 let scan_more_info = (0..num_scans)
208 .map(|_| ScanMoreInfo::parse(bitstream, ()))
209 .collect::<Result<_, _>>()?;
210
211 let intermarker_lengths = (0..num_intermarkers)
212 .map(|_| bitstream.read_bits(16))
213 .collect::<Result<_, _>>()?;
214
215 let tail_data_length = bitstream.read_u32(0, 1 + U(8), 257 + U(16), 65793 + U(22))?;
216
217 let has_padding = bitstream.read_bool()?;
218 let padding_bits = has_padding
219 .then(|| Padding::parse(bitstream, ()))
220 .transpose()?;
221
222 Ok(Self {
223 is_gray,
224 markers,
225 app_markers,
226 com_lengths,
227 quant_tables,
228 components,
229 huffman_codes,
230 scan_info,
231 restart_interval,
232 scan_more_info,
233 intermarker_lengths,
234 tail_data_length,
235 padding_bits,
236 })
237 }
238}
239
240impl JpegBitstreamHeader {
241 fn app_data_len(&self) -> usize {
242 self.app_markers
243 .iter()
244 .filter_map(|marker| (marker.ty == 0).then_some(marker.length as usize))
245 .sum::<usize>()
246 }
247
248 fn com_data_len(&self) -> usize {
249 self.com_lengths.iter().map(|&x| x as usize).sum::<usize>()
250 }
251
252 fn intermarker_data_len(&self) -> usize {
253 self.intermarker_lengths
254 .iter()
255 .map(|&x| x as usize)
256 .sum::<usize>()
257 }
258
259 fn expected_data_len(&self) -> usize {
260 self.app_data_len()
261 + self.com_data_len()
262 + self.intermarker_data_len()
263 + self.tail_data_length as usize
264 }
265
266 pub fn expected_icc_len(&self) -> usize {
267 self.app_markers
268 .iter()
269 .filter(|am| am.ty == 1)
270 .map(|am| am.length as usize - 5 - HEADER_ICC.len())
271 .sum::<usize>()
272 }
273
274 pub fn expected_exif_len(&self) -> usize {
275 self.app_markers
276 .iter()
277 .find(|am| am.ty == 2)
278 .map(|am| am.length as usize - 3 - HEADER_EXIF.len())
279 .unwrap_or(0)
280 }
281
282 pub fn expected_xmp_len(&self) -> usize {
283 self.app_markers
284 .iter()
285 .find(|am| am.ty == 3)
286 .map(|am| am.length as usize - 3 - HEADER_XMP.len())
287 .unwrap_or(0)
288 }
289}
290
291#[derive(Debug)]
292struct AppMarker {
293 ty: u32,
294 length: u32,
295}
296
297impl Bundle for AppMarker {
298 type Error = jxl_bitstream::Error;
299
300 fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
301 Ok(Self {
302 ty: bitstream.read_u32(0, 1, 2 + U(1), 4 + U(2))?,
303 length: bitstream.read_bits(16)? + 1,
304 })
305 }
306}
307
308#[derive(Debug)]
309struct QuantTable {
310 precision: u8,
311 index: u8,
312 is_last: bool,
313}
314
315impl Bundle for QuantTable {
316 type Error = jxl_bitstream::Error;
317
318 fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
319 Ok(Self {
320 precision: bitstream.read_bits(1)? as u8,
321 index: bitstream.read_bits(2)? as u8,
322 is_last: bitstream.read_bool()?,
323 })
324 }
325}
326
/// One image component listed in the header.
#[derive(Debug)]
struct Component {
    /// Component id: either taken from a predefined id set or read explicitly as 8 bits.
    id: u8,
    /// 2-bit index read for this component.
    // NOTE(review): presumably selects an entry of `quant_tables` — confirm against the
    // reconstruction code.
    q_idx: u8,
}
332
333#[derive(Debug)]
334struct ScanInfo {
335 ss: u8,
336 se: u8,
337 al: u8,
338 ah: u8,
339 component_info: Vec<ScanComponentInfo>,
340 #[allow(unused)]
341 last_needed_pass: u8,
342}
343
344impl Bundle for ScanInfo {
345 type Error = jxl_bitstream::Error;
346
347 fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
348 let num_comps = bitstream.read_bits(2)? as u8 + 1;
349 let ss = bitstream.read_bits(6)? as u8;
350 let se = bitstream.read_bits(6)? as u8;
351 let al = bitstream.read_bits(4)? as u8;
352 let ah = bitstream.read_bits(4)? as u8;
353 let component_info = (0..num_comps)
354 .map(|_| ScanComponentInfo::parse(bitstream, ()))
355 .collect::<Result<_, _>>()?;
356 let last_needed_pass = bitstream.read_u32(0, 1, 2, 3 + U(3))? as u8;
357 Ok(Self {
358 ss,
359 se,
360 ah,
361 al,
362 component_info,
363 last_needed_pass,
364 })
365 }
366}
367
368impl ScanInfo {
369 fn num_comps(&self) -> u8 {
370 self.component_info.len() as u8
371 }
372}
373
374#[derive(Debug)]
375struct ScanComponentInfo {
376 comp_idx: u8,
377 ac_tbl_idx: u8,
378 dc_tbl_idx: u8,
379}
380
381impl Bundle for ScanComponentInfo {
382 type Error = jxl_bitstream::Error;
383
384 fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
385 Ok(Self {
386 comp_idx: bitstream.read_bits(2)? as u8,
387 ac_tbl_idx: bitstream.read_bits(2)? as u8,
388 dc_tbl_idx: bitstream.read_bits(2)? as u8,
389 })
390 }
391}
392
393#[derive(Debug)]
394struct ScanMoreInfo {
395 reset_points: HashSet<u32>,
396 extra_zero_runs: HashMap<u32, u32>,
397}
398
399impl Bundle for ScanMoreInfo {
400 type Error = jxl_bitstream::Error;
401
402 fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
403 let num_reset_points = bitstream.read_u32(0, 1 + U(2), 4 + U(4), 20 + U(16))?;
404 let mut last_block_idx: Option<u32> = None;
405 let reset_points = (0..num_reset_points)
406 .map(|_| -> Result<_, Self::Error> {
407 let diff = bitstream.read_u32(0, 1 + U(3), 9 + U(5), 41 + U(28))?;
408 let block_idx = if let Some(last_block_idx) = last_block_idx {
409 last_block_idx.saturating_add(diff + 1)
410 } else {
411 diff
412 };
413 if block_idx > (3 << 26) {
414 tracing::error!(value = block_idx, "reset_points too large");
415 return Err(jxl_bitstream::Error::ValidationFailed(
416 "reset_points too large",
417 ));
418 }
419 last_block_idx = Some(block_idx);
420 Ok(block_idx)
421 })
422 .collect::<Result<_, _>>()?;
423
424 let num_extra_zero_runs = bitstream.read_u32(0, 1 + U(2), 4 + U(4), 20 + U(16))?;
425 let mut last_block_idx: Option<u32> = None;
426 let extra_zero_runs = (0..num_extra_zero_runs)
427 .map(|_| -> Result<_, jxl_bitstream::Error> {
428 let ExtraZeroRun {
429 num_runs,
430 run_length,
431 } = ExtraZeroRun::parse(bitstream, ())?;
432 let block_idx = if let Some(last_block_idx) = last_block_idx {
433 last_block_idx.saturating_add(run_length + 1)
434 } else {
435 run_length
436 };
437 if block_idx > (3 << 26) {
438 tracing::error!(block_idx, "extra_zero_runs.block_idx too large");
439 return Err(jxl_bitstream::Error::ValidationFailed(
440 "extra_zero_runs.block_idx too large",
441 ));
442 }
443 last_block_idx = Some(block_idx);
444 Ok((block_idx, num_runs))
445 })
446 .collect::<Result<_, _>>()?;
447
448 Ok(Self {
449 reset_points,
450 extra_zero_runs,
451 })
452 }
453}
454
455#[derive(Debug)]
456struct ExtraZeroRun {
457 num_runs: u32,
458 run_length: u32,
459}
460
461impl Bundle for ExtraZeroRun {
462 type Error = jxl_bitstream::Error;
463
464 fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
465 Ok(Self {
466 num_runs: bitstream.read_u32(1, 2 + U(2), 5 + U(4), 20 + U(8))?,
467 run_length: bitstream.read_u32(0, 1 + U(3), 9 + U(5), 41 + U(28))?,
468 })
469 }
470}
471
472#[derive(Debug)]
473struct Padding {
474 bits: Vec<u8>,
475}
476
477impl Bundle for Padding {
478 type Error = jxl_bitstream::Error;
479
480 fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
481 let num_bits = bitstream.read_bits(24)?;
482 let full_bytes = num_bits / 8;
483 let extra_bits = num_bits % 8;
484 let mut bits = Vec::with_capacity(full_bytes as usize + (extra_bits != 0) as usize);
485 for _ in 0..full_bytes {
486 bits.push(bitstream.read_bits(8)? as u8);
487 }
488 bits.push(bitstream.read_bits(extra_bits as usize)? as u8);
489
490 Ok(Self { bits })
491 }
492}