jxl_jbr/
lib.rs

//! This crate reconstructs the original JPEG bitstream from JPEG XL images that carry JPEG bitstream reconstruction data.
2
3use std::collections::{HashMap, HashSet};
4use std::io::Write;
5
6use brotli_decompressor::DecompressorWriter;
7use jxl_bitstream::{Bitstream, U};
8use jxl_frame::Frame;
9use jxl_oxide_common::Bundle;
10
11use crate::huffman::HuffmanCode;
12
13mod bit_writer;
14mod error;
15mod huffman;
16mod reconstruct;
17
18pub use error::Error;
19pub use reconstruct::JpegBitstreamReconstructor;
20
21use error::Result;
22
/// Signature that prefixes the payload of ICC profile APP markers (`ty == 1`).
const HEADER_ICC: &[u8] = b"ICC_PROFILE\0";
/// Signature that prefixes the payload of Exif APP markers (`ty == 2`).
const HEADER_EXIF: &[u8] = b"Exif\0\0";
/// Signature that prefixes the payload of XMP APP markers (`ty == 3`).
const HEADER_XMP: &[u8] = b"http://ns.adobe.com/xap/1.0/\0";
26
27/// JPEG bitstream reconstruction data.
28pub struct JpegBitstreamData {
29    header: Box<JpegBitstreamHeader>,
30    data_stream: Box<DecompressorWriter<Vec<u8>>>,
31}
32
33impl std::fmt::Debug for JpegBitstreamData {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("JpegBitstreamData")
36            .field("header", &self.header)
37            .finish_non_exhaustive()
38    }
39}
40
41impl JpegBitstreamData {
42    /// Decodes the JPEG bitstream reconstruction data from the buffer.
43    ///
44    /// Returns `Ok(Some(_))` if header is decoded successfully, and `Ok(None)` if the input is insufficient.
45    ///
46    /// # Errors
47    /// This function returns an error if the header is invalid, or data section could not be decompressed.
48    pub fn try_parse(data: &[u8]) -> Result<Option<Self>> {
49        let mut bitstream = Bitstream::new(data);
50        let header = match JpegBitstreamHeader::parse(&mut bitstream, ()) {
51            Ok(header) => Box::new(header),
52            Err(e) if e.unexpected_eof() => return Ok(None),
53            Err(e) => return Err(e.into()),
54        };
55        bitstream.zero_pad_to_byte()?;
56
57        let bytes_read = bitstream.num_read_bits() / 8;
58        let compressed_data = &data[bytes_read..];
59        let mut data_stream = Box::new(DecompressorWriter::new(Vec::new(), 4096));
60        data_stream
61            .write_all(compressed_data)
62            .map_err(Error::Brotli)?;
63
64        Ok(Some(Self {
65            header,
66            data_stream,
67        }))
68    }
69
70    /// Feeds more bytes to be decompressed.
71    pub fn feed_bytes(&mut self, data: &[u8]) -> Result<()> {
72        self.data_stream.write_all(data).map_err(Error::Brotli)
73    }
74
75    /// Finalizes the stream and checks for potential errors.
76    pub fn finalize(&mut self) -> Result<()> {
77        self.data_stream.flush().map_err(Error::Brotli)?;
78
79        let decompressed_len = self.data_stream.get_ref().len();
80        if decompressed_len != self.header.expected_data_len() {
81            tracing::error!(
82                decompressed_len,
83                expected = self.header.expected_data_len(),
84                "Data section length of jbrd box doesn't match expected length"
85            );
86            return Err(Error::InvalidData);
87        }
88
89        Ok(())
90    }
91
92    /// Creates a reconstruction context with given JPEG XL frame and metadata.
93    ///
94    /// `icc_profile`, `exif` or `xmp` can be empty if no corresponding metadata was found.
95    pub fn reconstruct<'jbrd, 'frame, 'meta>(
96        &'jbrd self,
97        frame: &'frame Frame,
98        icc_profile: &'meta [u8],
99        exif: &'meta [u8],
100        xmp: &'meta [u8],
101        pool: &jxl_threadpool::JxlThreadPool,
102    ) -> Result<JpegBitstreamReconstructor<'jbrd, 'frame, 'meta>> {
103        let Self {
104            ref header,
105            ref data_stream,
106        } = *self;
107        JpegBitstreamReconstructor::new(
108            header,
109            data_stream.get_ref(),
110            frame,
111            icc_profile,
112            exif,
113            xmp,
114            pool,
115        )
116    }
117
118    pub fn header(&self) -> &JpegBitstreamHeader {
119        &self.header
120    }
121}
122
123#[derive(Debug)]
124pub struct JpegBitstreamHeader {
125    is_gray: bool,
126    markers: Vec<u8>,
127    app_markers: Vec<AppMarker>,
128    com_lengths: Vec<u32>,
129    quant_tables: Vec<QuantTable>,
130    components: Vec<Component>,
131    huffman_codes: Vec<HuffmanCode>,
132    scan_info: Vec<ScanInfo>,
133    restart_interval: u32,
134    scan_more_info: Vec<ScanMoreInfo>,
135    intermarker_lengths: Vec<u32>,
136    tail_data_length: u32,
137    padding_bits: Option<Padding>,
138}
139
140impl Bundle for JpegBitstreamHeader {
141    type Error = jxl_bitstream::Error;
142
143    fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
144        let is_gray = bitstream.read_bool()?;
145
146        let mut markers = Vec::new();
147        let mut num_app_markers = 0usize;
148        let mut num_com_markers = 0usize;
149        let mut num_scans = 0usize;
150        let mut num_intermarkers = 0usize;
151        let mut has_dri = false;
152        while markers.last() != Some(&0xd9) {
153            let marker_bits = bitstream.read_bits(6)? as u8 + 0xc0;
154            match marker_bits {
155                0xe0..=0xef => num_app_markers += 1,
156                0xfe => num_com_markers += 1,
157                0xda => num_scans += 1,
158                0xff => num_intermarkers += 1,
159                0xdd => has_dri = true,
160                _ => {}
161            }
162            markers.push(marker_bits);
163        }
164
165        let app_markers = (0..num_app_markers)
166            .map(|_| AppMarker::parse(bitstream, ()))
167            .collect::<Result<_, _>>()?;
168        let com_lengths = (0..num_com_markers)
169            .map(|_| bitstream.read_bits(16).map(|x| x + 1))
170            .collect::<Result<_, _>>()?;
171
172        let num_quant_tables = bitstream.read_bits(2)? + 1;
173        let quant_tables = (0..num_quant_tables)
174            .map(|_| QuantTable::parse(bitstream, ()))
175            .collect::<Result<_, _>>()?;
176
177        let comp_type = bitstream.read_bits(2)?;
178        let component_ids = match comp_type {
179            0 => vec![1u8],
180            1 => vec![1u8, 2, 3],
181            2 => vec![b'R', b'G', b'B'],
182            3 => {
183                let num_comp = bitstream.read_bits(2)? as u8 + 1;
184                (0..num_comp)
185                    .map(|_| bitstream.read_bits(8).map(|x| x as u8))
186                    .collect::<Result<_, _>>()?
187            }
188            _ => unreachable!(),
189        };
190        let components = component_ids
191            .into_iter()
192            .map(|id| -> Result<_, Self::Error> {
193                let q_idx = bitstream.read_bits(2)? as u8;
194                Ok(Component { id, q_idx })
195            })
196            .collect::<Result<_, _>>()?;
197
198        let num_huff = bitstream.read_u32(4, 2 + U(3), 10 + U(4), 26 + U(6))?;
199        let huffman_codes = (0..num_huff)
200            .map(|_| HuffmanCode::parse(bitstream, ()))
201            .collect::<Result<_, _>>()?;
202
203        let scan_info = (0..num_scans)
204            .map(|_| ScanInfo::parse(bitstream, ()))
205            .collect::<Result<_, _>>()?;
206        let restart_interval = if has_dri { bitstream.read_bits(16)? } else { 0 };
207        let scan_more_info = (0..num_scans)
208            .map(|_| ScanMoreInfo::parse(bitstream, ()))
209            .collect::<Result<_, _>>()?;
210
211        let intermarker_lengths = (0..num_intermarkers)
212            .map(|_| bitstream.read_bits(16))
213            .collect::<Result<_, _>>()?;
214
215        let tail_data_length = bitstream.read_u32(0, 1 + U(8), 257 + U(16), 65793 + U(22))?;
216
217        let has_padding = bitstream.read_bool()?;
218        let padding_bits = has_padding
219            .then(|| Padding::parse(bitstream, ()))
220            .transpose()?;
221
222        Ok(Self {
223            is_gray,
224            markers,
225            app_markers,
226            com_lengths,
227            quant_tables,
228            components,
229            huffman_codes,
230            scan_info,
231            restart_interval,
232            scan_more_info,
233            intermarker_lengths,
234            tail_data_length,
235            padding_bits,
236        })
237    }
238}
239
240impl JpegBitstreamHeader {
241    fn app_data_len(&self) -> usize {
242        self.app_markers
243            .iter()
244            .filter_map(|marker| (marker.ty == 0).then_some(marker.length as usize))
245            .sum::<usize>()
246    }
247
248    fn com_data_len(&self) -> usize {
249        self.com_lengths.iter().map(|&x| x as usize).sum::<usize>()
250    }
251
252    fn intermarker_data_len(&self) -> usize {
253        self.intermarker_lengths
254            .iter()
255            .map(|&x| x as usize)
256            .sum::<usize>()
257    }
258
259    fn expected_data_len(&self) -> usize {
260        self.app_data_len()
261            + self.com_data_len()
262            + self.intermarker_data_len()
263            + self.tail_data_length as usize
264    }
265
266    pub fn expected_icc_len(&self) -> usize {
267        self.app_markers
268            .iter()
269            .filter(|am| am.ty == 1)
270            .map(|am| am.length as usize - 5 - HEADER_ICC.len())
271            .sum::<usize>()
272    }
273
274    pub fn expected_exif_len(&self) -> usize {
275        self.app_markers
276            .iter()
277            .find(|am| am.ty == 2)
278            .map(|am| am.length as usize - 3 - HEADER_EXIF.len())
279            .unwrap_or(0)
280    }
281
282    pub fn expected_xmp_len(&self) -> usize {
283        self.app_markers
284            .iter()
285            .find(|am| am.ty == 3)
286            .map(|am| am.length as usize - 3 - HEADER_XMP.len())
287            .unwrap_or(0)
288    }
289}
290
291#[derive(Debug)]
292struct AppMarker {
293    ty: u32,
294    length: u32,
295}
296
297impl Bundle for AppMarker {
298    type Error = jxl_bitstream::Error;
299
300    fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
301        Ok(Self {
302            ty: bitstream.read_u32(0, 1, 2 + U(1), 4 + U(2))?,
303            length: bitstream.read_bits(16)? + 1,
304        })
305    }
306}
307
308#[derive(Debug)]
309struct QuantTable {
310    precision: u8,
311    index: u8,
312    is_last: bool,
313}
314
315impl Bundle for QuantTable {
316    type Error = jxl_bitstream::Error;
317
318    fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
319        Ok(Self {
320            precision: bitstream.read_bits(1)? as u8,
321            index: bitstream.read_bits(2)? as u8,
322            is_last: bitstream.read_bool()?,
323        })
324    }
325}
326
/// A frame component entry.
#[derive(Debug)]
struct Component {
    /// Component identifier, e.g. `1`/`2`/`3`, `b'R'`/`b'G'`/`b'B'`, or an explicit 8-bit ID.
    id: u8,
    /// Quantization table index (2 bits as read).
    q_idx: u8,
}
332
333#[derive(Debug)]
334struct ScanInfo {
335    ss: u8,
336    se: u8,
337    al: u8,
338    ah: u8,
339    component_info: Vec<ScanComponentInfo>,
340    #[allow(unused)]
341    last_needed_pass: u8,
342}
343
344impl Bundle for ScanInfo {
345    type Error = jxl_bitstream::Error;
346
347    fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
348        let num_comps = bitstream.read_bits(2)? as u8 + 1;
349        let ss = bitstream.read_bits(6)? as u8;
350        let se = bitstream.read_bits(6)? as u8;
351        let al = bitstream.read_bits(4)? as u8;
352        let ah = bitstream.read_bits(4)? as u8;
353        let component_info = (0..num_comps)
354            .map(|_| ScanComponentInfo::parse(bitstream, ()))
355            .collect::<Result<_, _>>()?;
356        let last_needed_pass = bitstream.read_u32(0, 1, 2, 3 + U(3))? as u8;
357        Ok(Self {
358            ss,
359            se,
360            ah,
361            al,
362            component_info,
363            last_needed_pass,
364        })
365    }
366}
367
368impl ScanInfo {
369    fn num_comps(&self) -> u8 {
370        self.component_info.len() as u8
371    }
372}
373
374#[derive(Debug)]
375struct ScanComponentInfo {
376    comp_idx: u8,
377    ac_tbl_idx: u8,
378    dc_tbl_idx: u8,
379}
380
381impl Bundle for ScanComponentInfo {
382    type Error = jxl_bitstream::Error;
383
384    fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
385        Ok(Self {
386            comp_idx: bitstream.read_bits(2)? as u8,
387            ac_tbl_idx: bitstream.read_bits(2)? as u8,
388            dc_tbl_idx: bitstream.read_bits(2)? as u8,
389        })
390    }
391}
392
393#[derive(Debug)]
394struct ScanMoreInfo {
395    reset_points: HashSet<u32>,
396    extra_zero_runs: HashMap<u32, u32>,
397}
398
399impl Bundle for ScanMoreInfo {
400    type Error = jxl_bitstream::Error;
401
402    fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
403        let num_reset_points = bitstream.read_u32(0, 1 + U(2), 4 + U(4), 20 + U(16))?;
404        let mut last_block_idx: Option<u32> = None;
405        let reset_points = (0..num_reset_points)
406            .map(|_| -> Result<_, Self::Error> {
407                let diff = bitstream.read_u32(0, 1 + U(3), 9 + U(5), 41 + U(28))?;
408                let block_idx = if let Some(last_block_idx) = last_block_idx {
409                    last_block_idx.saturating_add(diff + 1)
410                } else {
411                    diff
412                };
413                if block_idx > (3 << 26) {
414                    tracing::error!(value = block_idx, "reset_points too large");
415                    return Err(jxl_bitstream::Error::ValidationFailed(
416                        "reset_points too large",
417                    ));
418                }
419                last_block_idx = Some(block_idx);
420                Ok(block_idx)
421            })
422            .collect::<Result<_, _>>()?;
423
424        let num_extra_zero_runs = bitstream.read_u32(0, 1 + U(2), 4 + U(4), 20 + U(16))?;
425        let mut last_block_idx: Option<u32> = None;
426        let extra_zero_runs = (0..num_extra_zero_runs)
427            .map(|_| -> Result<_, jxl_bitstream::Error> {
428                let ExtraZeroRun {
429                    num_runs,
430                    run_length,
431                } = ExtraZeroRun::parse(bitstream, ())?;
432                let block_idx = if let Some(last_block_idx) = last_block_idx {
433                    last_block_idx.saturating_add(run_length + 1)
434                } else {
435                    run_length
436                };
437                if block_idx > (3 << 26) {
438                    tracing::error!(block_idx, "extra_zero_runs.block_idx too large");
439                    return Err(jxl_bitstream::Error::ValidationFailed(
440                        "extra_zero_runs.block_idx too large",
441                    ));
442                }
443                last_block_idx = Some(block_idx);
444                Ok((block_idx, num_runs))
445            })
446            .collect::<Result<_, _>>()?;
447
448        Ok(Self {
449            reset_points,
450            extra_zero_runs,
451        })
452    }
453}
454
455#[derive(Debug)]
456struct ExtraZeroRun {
457    num_runs: u32,
458    run_length: u32,
459}
460
461impl Bundle for ExtraZeroRun {
462    type Error = jxl_bitstream::Error;
463
464    fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
465        Ok(Self {
466            num_runs: bitstream.read_u32(1, 2 + U(2), 5 + U(4), 20 + U(8))?,
467            run_length: bitstream.read_u32(0, 1 + U(3), 9 + U(5), 41 + U(28))?,
468        })
469    }
470}
471
472#[derive(Debug)]
473struct Padding {
474    bits: Vec<u8>,
475}
476
477impl Bundle for Padding {
478    type Error = jxl_bitstream::Error;
479
480    fn parse(bitstream: &mut Bitstream, _: ()) -> Result<Self, Self::Error> {
481        let num_bits = bitstream.read_bits(24)?;
482        let full_bytes = num_bits / 8;
483        let extra_bits = num_bits % 8;
484        let mut bits = Vec::with_capacity(full_bytes as usize + (extra_bits != 0) as usize);
485        for _ in 0..full_bytes {
486            bits.push(bitstream.read_bits(8)? as u8);
487        }
488        bits.push(bitstream.read_bits(extra_bits as usize)? as u8);
489
490        Ok(Self { bits })
491    }
492}