Skip to main content

oximedia_codec/jpegxl/
decoder.rs

1//! JPEG-XL decoder implementation.
2//!
3//! Decodes JPEG-XL codestreams (bare and container format) into raw pixel data.
4//! Currently supports lossless Modular mode for 8-bit and 16-bit images in
5//! grayscale, RGB, and RGBA color spaces.
6
7use super::bitreader::BitReader;
8use super::modular::{ModularDecoder, ModularTransform};
9use super::types::{JxlColorSpace, JxlHeader, JXL_CODESTREAM_SIGNATURE, JXL_CONTAINER_SIGNATURE};
10use crate::error::{CodecError, CodecResult};
11
/// Decoded JPEG-XL image.
///
/// Holds the image geometry, the sample layout, and the interleaved pixel
/// bytes produced by the decoder.
#[derive(Clone, Debug)]
pub struct DecodedImage {
    /// Image width in pixels.
    pub width: u32,
    /// Image height in pixels.
    pub height: u32,
    /// Number of channels (1=gray, 3=RGB, 4=RGBA).
    /// The header parser in this file can also produce 2 (gray + alpha).
    pub channels: u8,
    /// Bits per sample (8 or 16).
    pub bit_depth: u8,
    /// Interleaved pixel data.
    /// For 8-bit: one byte per sample.
    /// For 16-bit: two bytes per sample (little-endian).
    pub data: Vec<u8>,
    /// Color space of the decoded image.
    pub color_space: JxlColorSpace,
}
30
31impl DecodedImage {
32    /// Total number of samples in the image.
33    pub fn sample_count(&self) -> usize {
34        self.width as usize * self.height as usize * self.channels as usize
35    }
36
37    /// Total size of pixel data in bytes.
38    pub fn data_size(&self) -> usize {
39        let bytes_per_sample = if self.bit_depth > 8 { 2 } else { 1 };
40        self.sample_count() * bytes_per_sample
41    }
42}
43
/// JPEG-XL decoder.
///
/// Decodes JPEG-XL files (both bare codestream and ISOBMFF container format)
/// into raw pixel data.
///
/// The decoder carries no state, so cloning it is free; `Clone` and `Debug`
/// are derived to match the other public types in this module.
#[derive(Clone, Debug)]
pub struct JxlDecoder;
49
50impl JxlDecoder {
51    /// Create a new JPEG-XL decoder.
52    pub fn new() -> Self {
53        Self
54    }
55
56    /// Check if the data starts with a valid JXL signature.
57    ///
58    /// Returns `true` for both bare codestream (0xFF 0x0A) and
59    /// container format signatures.
60    pub fn is_jxl(data: &[u8]) -> bool {
61        Self::is_codestream(data) || Self::is_container(data)
62    }
63
64    /// Check for bare codestream signature.
65    pub fn is_codestream(data: &[u8]) -> bool {
66        data.len() >= 2
67            && data[0] == JXL_CODESTREAM_SIGNATURE[0]
68            && data[1] == JXL_CODESTREAM_SIGNATURE[1]
69    }
70
71    /// Check for ISOBMFF container signature.
72    pub fn is_container(data: &[u8]) -> bool {
73        data.len() >= 12 && data[..12] == JXL_CONTAINER_SIGNATURE
74    }
75
76    /// Decode a JPEG-XL file from bytes.
77    ///
78    /// # Errors
79    ///
80    /// Returns error if:
81    /// - The data does not have a valid JXL signature
82    /// - The header is malformed
83    /// - The image data is corrupt
84    /// - Unsupported features are encountered
85    pub fn decode(&self, data: &[u8]) -> CodecResult<DecodedImage> {
86        let codestream = self.extract_codestream(data)?;
87        let mut reader = BitReader::new(&codestream);
88
89        // Skip signature (2 bytes = 16 bits)
90        let _ = reader.read_bits(16)?;
91
92        // Parse size header
93        let (width, height) = self.parse_size_header(&mut reader)?;
94
95        // Parse image metadata
96        let header = self.parse_image_metadata(&mut reader, width, height)?;
97        header.validate()?;
98
99        // Decode using modular mode
100        let channels_data = self.decode_modular(&mut reader, &header)?;
101
102        // Convert channel data to interleaved byte output
103        let pixel_data = self.channels_to_interleaved(&channels_data, &header)?;
104
105        Ok(DecodedImage {
106            width: header.width,
107            height: header.height,
108            channels: header.num_channels,
109            bit_depth: header.bits_per_sample,
110            data: pixel_data,
111            color_space: header.color_space,
112        })
113    }
114
115    /// Read only the header from a JPEG-XL file without fully decoding.
116    ///
117    /// # Errors
118    ///
119    /// Returns error if the signature or header is invalid.
120    pub fn read_header(&self, data: &[u8]) -> CodecResult<JxlHeader> {
121        let codestream = self.extract_codestream(data)?;
122        let mut reader = BitReader::new(&codestream);
123
124        // Skip signature
125        let _ = reader.read_bits(16)?;
126
127        let (width, height) = self.parse_size_header(&mut reader)?;
128        let header = self.parse_image_metadata(&mut reader, width, height)?;
129        header.validate()?;
130        Ok(header)
131    }
132
133    /// Extract the bare codestream from either format.
134    ///
135    /// If the data is a bare codestream, returns it as-is.
136    /// If it is a container, extracts the jxlc box contents.
137    fn extract_codestream<'a>(&self, data: &'a [u8]) -> CodecResult<&'a [u8]> {
138        if Self::is_codestream(data) {
139            return Ok(data);
140        }
141        if Self::is_container(data) {
142            // Parse ISOBMFF boxes to find jxlc (codestream) box
143            return self.find_jxlc_box(data);
144        }
145        Err(CodecError::InvalidBitstream(
146            "Not a valid JPEG-XL file: invalid signature".into(),
147        ))
148    }
149
150    /// Find the jxlc box in an ISOBMFF container.
151    fn find_jxlc_box<'a>(&self, data: &'a [u8]) -> CodecResult<&'a [u8]> {
152        let mut offset = 0;
153        while offset + 8 <= data.len() {
154            let box_size = u32::from_be_bytes([
155                data[offset],
156                data[offset + 1],
157                data[offset + 2],
158                data[offset + 3],
159            ]) as usize;
160
161            let box_type = &data[offset + 4..offset + 8];
162
163            if box_size < 8 {
164                break;
165            }
166
167            if box_type == b"jxlc" {
168                let content_start = offset + 8;
169                let content_end = offset + box_size;
170                if content_end <= data.len() {
171                    return Ok(&data[content_start..content_end]);
172                }
173                return Err(CodecError::InvalidBitstream(
174                    "jxlc box extends past end of file".into(),
175                ));
176            }
177
178            offset += box_size;
179        }
180
181        Err(CodecError::InvalidBitstream(
182            "No jxlc (codestream) box found in container".into(),
183        ))
184    }
185
186    /// Parse the JPEG-XL SizeHeader.
187    ///
188    /// The size header uses a compact variable-length encoding:
189    /// - 1 bit: small flag
190    /// - If small: 5 bits height_div8, 5 bits width_div8 (sizes * 8)
191    /// - If not small: read height and width using U32 encoding
192    fn parse_size_header(&self, reader: &mut BitReader) -> CodecResult<(u32, u32)> {
193        let small = reader.read_bool()?;
194
195        if small {
196            let height_div8 = reader.read_bits(5)? + 1;
197            let width_div8 = reader.read_bits(5)?;
198            // Width uses ratio based on height if not specified
199            let width_div8 = if width_div8 == 0 {
200                height_div8
201            } else {
202                width_div8
203            };
204            Ok((width_div8 * 8, height_div8 * 8))
205        } else {
206            // Full U32 encoding for height and width
207            let height = self.read_size_u32(reader)?;
208            let width = self.read_size_u32(reader)?;
209            Ok((width, height))
210        }
211    }
212
213    /// Read a size value using JPEG-XL's SizeHeader U32 encoding.
214    ///
215    /// Distribution: d0=1, d1=1+read(9), d2=1+read(13), d3=1+read(18)
216    fn read_size_u32(&self, reader: &mut BitReader) -> CodecResult<u32> {
217        let selector = reader.read_bits(2)?;
218        match selector {
219            0 => Ok(1),
220            1 => {
221                let extra = reader.read_bits(9)?;
222                Ok(1 + extra)
223            }
224            2 => {
225                let extra = reader.read_bits(13)?;
226                Ok(1 + extra)
227            }
228            3 => {
229                let extra = reader.read_bits(18)?;
230                Ok(1 + extra)
231            }
232            _ => Err(CodecError::InvalidBitstream("Invalid size selector".into())),
233        }
234    }
235
236    /// Parse the ImageMetadata section.
237    ///
238    /// This is a simplified parser that reads the essential fields:
239    /// - all_default flag
240    /// - bit_depth
241    /// - color space
242    /// - alpha flag
243    fn parse_image_metadata(
244        &self,
245        reader: &mut BitReader,
246        width: u32,
247        height: u32,
248    ) -> CodecResult<JxlHeader> {
249        // all_default flag: if true, use default 8-bit sRGB
250        let all_default = reader.read_bool()?;
251
252        if all_default {
253            return Ok(JxlHeader {
254                width,
255                height,
256                bits_per_sample: 8,
257                num_channels: 3,
258                is_float: false,
259                has_alpha: false,
260                color_space: JxlColorSpace::Srgb,
261                orientation: 1,
262            });
263        }
264
265        // Extra fields present
266        let has_extra_fields = reader.read_bool()?;
267        let orientation = if has_extra_fields {
268            reader.read_bits(3)? as u8 + 1
269        } else {
270            1
271        };
272
273        // Bit depth
274        let float_flag = reader.read_bool()?;
275        let bits_per_sample = if float_flag {
276            // Float samples: read exponent bits
277            let _exp_bits = reader.read_bits(4)?;
278            let mantissa_bits = reader.read_bits(4)? + 1;
279            (mantissa_bits + 1) as u8 // approximate total bits
280        } else {
281            let depth_selector = reader.read_bits(2)?;
282            match depth_selector {
283                0 => 8,
284                1 => 10,
285                2 => 12,
286                3 => {
287                    let custom = reader.read_bits(6)?;
288                    (custom + 1) as u8
289                }
290                _ => 8,
291            }
292        };
293
294        // Color space
295        let color_space_selector = reader.read_bits(2)?;
296        let color_space = match color_space_selector {
297            0 => JxlColorSpace::Srgb,
298            1 => JxlColorSpace::LinearSrgb,
299            2 => JxlColorSpace::Gray,
300            3 => JxlColorSpace::Xyb,
301            _ => JxlColorSpace::Srgb,
302        };
303
304        let num_color_channels = if color_space == JxlColorSpace::Gray {
305            1u8
306        } else {
307            3u8
308        };
309
310        // Alpha
311        let has_alpha = reader.read_bool()?;
312        let num_channels = if has_alpha {
313            num_color_channels + 1
314        } else {
315            num_color_channels
316        };
317
318        Ok(JxlHeader {
319            width,
320            height,
321            bits_per_sample,
322            num_channels,
323            is_float: float_flag,
324            has_alpha,
325            color_space,
326            orientation,
327        })
328    }
329
330    /// Decode the image data using the Modular sub-codec.
331    fn decode_modular(
332        &self,
333        reader: &mut BitReader,
334        header: &JxlHeader,
335    ) -> CodecResult<Vec<Vec<i32>>> {
336        reader.align_to_byte();
337
338        // Collect remaining data for the modular decoder
339        let remaining_bits = reader.remaining_bits();
340        if remaining_bits == 0 {
341            return Err(CodecError::InvalidBitstream(
342                "No image data after header".into(),
343            ));
344        }
345
346        // Read all remaining bytes into a buffer for the modular decoder
347        let remaining_bytes = (remaining_bits + 7) / 8;
348        let mut data = Vec::with_capacity(remaining_bytes);
349        for _ in 0..remaining_bytes {
350            match reader.read_u8(8) {
351                Ok(byte) => data.push(byte),
352                Err(_) => break,
353            }
354        }
355
356        let mut decoder = ModularDecoder::new();
357
358        // Add RCT transform for RGB/RGBA images (3+ color channels)
359        if header.color_channels() >= 3 {
360            decoder.add_transform(ModularTransform::Rct {
361                begin_channel: 0,
362                rct_type: 0,
363            });
364        }
365
366        decoder.decode_image(
367            &data,
368            header.width,
369            header.height,
370            header.num_channels as u32,
371            header.bits_per_sample,
372        )
373    }
374
375    /// Convert decoded channel data to interleaved byte output.
376    fn channels_to_interleaved(
377        &self,
378        channels: &[Vec<i32>],
379        header: &JxlHeader,
380    ) -> CodecResult<Vec<u8>> {
381        let pixel_count = header.width as usize * header.height as usize;
382        let num_channels = header.num_channels as usize;
383        let bytes_per_sample = header.bytes_per_sample();
384
385        if channels.len() != num_channels {
386            return Err(CodecError::Internal(format!(
387                "Expected {} channels, got {}",
388                num_channels,
389                channels.len()
390            )));
391        }
392
393        let total_bytes = pixel_count * num_channels * bytes_per_sample;
394        let mut output = Vec::with_capacity(total_bytes);
395
396        for i in 0..pixel_count {
397            for ch in 0..num_channels {
398                let value = channels[ch][i];
399
400                match bytes_per_sample {
401                    1 => {
402                        // Clamp to [0, 255]
403                        let clamped = value.clamp(0, 255) as u8;
404                        output.push(clamped);
405                    }
406                    2 => {
407                        // Clamp to [0, 65535], little-endian
408                        let clamped = value.clamp(0, 65535) as u16;
409                        output.push(clamped as u8);
410                        output.push((clamped >> 8) as u8);
411                    }
412                    _ => {
413                        // 32-bit: store as 4 bytes, little-endian
414                        let bytes = (value as u32).to_le_bytes();
415                        output.extend_from_slice(&bytes);
416                    }
417                }
418            }
419        }
420
421        Ok(output)
422    }
423}
424
425impl Default for JxlDecoder {
426    fn default() -> Self {
427        Self::new()
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every test in this module is marked `#[ignore]`; the
    // reason is not visible in this file - confirm whether they can be
    // re-enabled.

    /// Build a 16-byte buffer that begins with the container signature.
    fn container_prefix() -> Vec<u8> {
        let mut buf = vec![0u8; 16];
        buf[..12].copy_from_slice(&JXL_CONTAINER_SIGNATURE);
        buf
    }

    /// Build a 10x10, 3-channel sRGB image with the given depth and payload size.
    fn sample_image(bit_depth: u8, payload_len: usize) -> DecodedImage {
        DecodedImage {
            width: 10,
            height: 10,
            channels: 3,
            bit_depth,
            data: vec![0u8; payload_len],
            color_space: JxlColorSpace::Srgb,
        }
    }

    #[test]
    #[ignore]
    fn test_is_codestream_signature() {
        assert!(JxlDecoder::is_codestream(&[0xFF, 0x0A, 0x00]));
        assert!(!JxlDecoder::is_codestream(&[0xFF, 0x0B, 0x00]));
        assert!(!JxlDecoder::is_codestream(&[0xFF]));
        assert!(!JxlDecoder::is_codestream(&[]));
    }

    #[test]
    #[ignore]
    fn test_is_container_signature() {
        let boxed = container_prefix();
        assert!(JxlDecoder::is_container(&boxed));
        assert!(!JxlDecoder::is_container(&[0xFF, 0x0A]));
    }

    #[test]
    #[ignore]
    fn test_is_jxl() {
        assert!(JxlDecoder::is_jxl(&[0xFF, 0x0A]));
        let boxed = container_prefix();
        assert!(JxlDecoder::is_jxl(&boxed));
        assert!(!JxlDecoder::is_jxl(&[0x00, 0x00]));
    }

    #[test]
    #[ignore]
    fn test_extract_codestream_bare() {
        let bare = [0xFF, 0x0A, 0x01, 0x02];
        let extracted = JxlDecoder::new().extract_codestream(&bare).expect("ok");
        assert_eq!(extracted, &bare);
    }

    #[test]
    #[ignore]
    fn test_extract_codestream_invalid() {
        let outcome = JxlDecoder::new().extract_codestream(&[0x00, 0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    #[ignore]
    fn test_parse_size_header_small() {
        // small=1, height_div8=3 (24px), width_div8=0 (use height -> 24px)
        let mut bits = super::super::bitreader::BitWriter::new();
        bits.write_bool(true); // small = true
        bits.write_bits(2, 5); // height_div8 - 1 = 2 -> height = 3*8 = 24
        bits.write_bits(0, 5); // width_div8 = 0 -> use height_div8
        let encoded = bits.finish();

        let mut reader = BitReader::new(&encoded);
        let (width, height) = JxlDecoder::new()
            .parse_size_header(&mut reader)
            .expect("ok");
        assert_eq!(height, 24);
        assert_eq!(width, 24);
    }

    #[test]
    #[ignore]
    fn test_read_header_invalid_data() {
        let outcome = JxlDecoder::new().read_header(&[0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    #[ignore]
    fn test_decoded_image_metrics() {
        let image = sample_image(8, 300);
        assert_eq!(image.sample_count(), 300);
        assert_eq!(image.data_size(), 300);
    }

    #[test]
    #[ignore]
    fn test_decoded_image_16bit() {
        let image = sample_image(16, 600);
        assert_eq!(image.sample_count(), 300);
        assert_eq!(image.data_size(), 600);
    }
}