1use std::io::{self, Read};
8
9use super::bitreader::BitReader;
10use super::modular::{ModularDecoder, ModularTransform};
11use super::types::{
12 JxlAnimation, JxlColorSpace, JxlFrame, JxlHeader, JXL_CODESTREAM_SIGNATURE,
13 JXL_CONTAINER_SIGNATURE,
14};
15use crate::container::isobmff::BoxIter;
16use crate::error::{CodecError, CodecResult};
17
/// Reader that first replays the bytes already peeked from the input (held in a
/// `Cursor`) and then continues reading from the underlying reader `R`.
type PeekedReader<R> = std::io::Chain<std::io::Cursor<Vec<u8>>, R>;
21
/// A decoded still image, as returned by `JxlDecoder::decode`.
#[derive(Clone, Debug)]
pub struct DecodedImage {
    /// Image width, taken from the codestream size header.
    pub width: u32,
    /// Image height, taken from the codestream size header.
    pub height: u32,
    /// Number of interleaved channels per pixel (color channels, plus one if alpha is present).
    pub channels: u8,
    /// Bits per sample, as signalled in the image metadata.
    pub bit_depth: u8,
    /// Pixel-interleaved sample bytes (every channel of pixel 0, then pixel 1, ...);
    /// samples wider than one byte are stored little-endian by the decoder.
    pub data: Vec<u8>,
    /// Color space signalled in the image metadata.
    pub color_space: JxlColorSpace,
}
40
41impl DecodedImage {
42 pub fn sample_count(&self) -> usize {
44 self.width as usize * self.height as usize * self.channels as usize
45 }
46
47 pub fn data_size(&self) -> usize {
49 let bytes_per_sample = if self.bit_depth > 8 { 2 } else { 1 };
50 self.sample_count() * bytes_per_sample
51 }
52}
53
/// Stateless JPEG-XL decoder.
///
/// It holds no data, so it derives `Clone`/`Copy` and can be freely duplicated. It derives
/// `Debug` so it can be embedded in other `Debug` types.
#[derive(Debug, Clone, Copy)]
pub struct JxlDecoder;
59
60impl JxlDecoder {
61 pub fn new() -> Self {
63 Self
64 }
65
66 pub fn is_jxl(data: &[u8]) -> bool {
71 Self::is_codestream(data) || Self::is_container(data)
72 }
73
74 pub fn is_codestream(data: &[u8]) -> bool {
76 data.len() >= 2
77 && data[0] == JXL_CODESTREAM_SIGNATURE[0]
78 && data[1] == JXL_CODESTREAM_SIGNATURE[1]
79 }
80
81 pub fn is_container(data: &[u8]) -> bool {
83 data.len() >= 12 && data[..12] == JXL_CONTAINER_SIGNATURE
84 }
85
86 pub fn decode(&self, data: &[u8]) -> CodecResult<DecodedImage> {
96 let codestream = self.extract_codestream(data)?;
97 let mut reader = BitReader::new(&codestream);
98
99 let _ = reader.read_bits(16)?;
101
102 let (width, height) = self.parse_size_header(&mut reader)?;
104
105 let header = self.parse_image_metadata(&mut reader, width, height)?;
107 header.validate()?;
108
109 let channels_data = self.decode_modular(&mut reader, &header)?;
111
112 let pixel_data = self.channels_to_interleaved(&channels_data, &header)?;
114
115 Ok(DecodedImage {
116 width: header.width,
117 height: header.height,
118 channels: header.num_channels,
119 bit_depth: header.bits_per_sample,
120 data: pixel_data,
121 color_space: header.color_space,
122 })
123 }
124
125 pub fn read_header(&self, data: &[u8]) -> CodecResult<JxlHeader> {
131 let codestream = self.extract_codestream(data)?;
132 let mut reader = BitReader::new(&codestream);
133
134 let _ = reader.read_bits(16)?;
136
137 let (width, height) = self.parse_size_header(&mut reader)?;
138 let header = self.parse_image_metadata(&mut reader, width, height)?;
139 header.validate()?;
140 Ok(header)
141 }
142
143 fn extract_codestream<'a>(&self, data: &'a [u8]) -> CodecResult<&'a [u8]> {
148 if Self::is_codestream(data) {
149 return Ok(data);
150 }
151 if Self::is_container(data) {
152 return self.find_jxlc_box(data);
154 }
155 Err(CodecError::InvalidBitstream(
156 "Not a valid JPEG-XL file: invalid signature".into(),
157 ))
158 }
159
160 fn find_jxlc_box<'a>(&self, data: &'a [u8]) -> CodecResult<&'a [u8]> {
162 let mut offset = 0;
163 while offset + 8 <= data.len() {
164 let box_size = u32::from_be_bytes([
165 data[offset],
166 data[offset + 1],
167 data[offset + 2],
168 data[offset + 3],
169 ]) as usize;
170
171 let box_type = &data[offset + 4..offset + 8];
172
173 if box_size < 8 {
174 break;
175 }
176
177 if box_type == b"jxlc" {
178 let content_start = offset + 8;
179 let content_end = offset + box_size;
180 if content_end <= data.len() {
181 return Ok(&data[content_start..content_end]);
182 }
183 return Err(CodecError::InvalidBitstream(
184 "jxlc box extends past end of file".into(),
185 ));
186 }
187
188 offset += box_size;
189 }
190
191 Err(CodecError::InvalidBitstream(
192 "No jxlc (codestream) box found in container".into(),
193 ))
194 }
195
196 fn parse_size_header(&self, reader: &mut BitReader) -> CodecResult<(u32, u32)> {
203 let small = reader.read_bool()?;
204
205 if small {
206 let height_div8 = reader.read_bits(5)? + 1;
207 let width_div8 = reader.read_bits(5)?;
208 let width_div8 = if width_div8 == 0 {
210 height_div8
211 } else {
212 width_div8
213 };
214 Ok((width_div8 * 8, height_div8 * 8))
215 } else {
216 let height = self.read_size_u32(reader)?;
218 let width = self.read_size_u32(reader)?;
219 Ok((width, height))
220 }
221 }
222
223 fn read_size_u32(&self, reader: &mut BitReader) -> CodecResult<u32> {
227 let selector = reader.read_bits(2)?;
228 match selector {
229 0 => Ok(1),
230 1 => {
231 let extra = reader.read_bits(9)?;
232 Ok(1 + extra)
233 }
234 2 => {
235 let extra = reader.read_bits(13)?;
236 Ok(1 + extra)
237 }
238 3 => {
239 let extra = reader.read_bits(18)?;
240 Ok(1 + extra)
241 }
242 _ => Err(CodecError::InvalidBitstream("Invalid size selector".into())),
243 }
244 }
245
246 fn parse_image_metadata(
255 &self,
256 reader: &mut BitReader,
257 width: u32,
258 height: u32,
259 ) -> CodecResult<JxlHeader> {
260 let all_default = reader.read_bool()?;
262
263 if all_default {
264 return Ok(JxlHeader {
265 width,
266 height,
267 bits_per_sample: 8,
268 num_channels: 3,
269 is_float: false,
270 has_alpha: false,
271 color_space: JxlColorSpace::Srgb,
272 orientation: 1,
273 animation: None,
274 });
275 }
276
277 let has_extra_fields = reader.read_bool()?;
279 let orientation = if has_extra_fields {
280 reader.read_bits(3)? as u8 + 1
281 } else {
282 1
283 };
284
285 let float_flag = reader.read_bool()?;
287 let bits_per_sample = if float_flag {
288 let _exp_bits = reader.read_bits(4)?;
290 let mantissa_bits = reader.read_bits(4)? + 1;
291 (mantissa_bits + 1) as u8 } else {
293 let depth_selector = reader.read_bits(2)?;
294 match depth_selector {
295 0 => 8,
296 1 => 10,
297 2 => 12,
298 3 => {
299 let custom = reader.read_bits(6)?;
300 (custom + 1) as u8
301 }
302 _ => 8,
303 }
304 };
305
306 let color_space_selector = reader.read_bits(2)?;
308 let color_space = match color_space_selector {
309 0 => JxlColorSpace::Srgb,
310 1 => JxlColorSpace::LinearSrgb,
311 2 => JxlColorSpace::Gray,
312 3 => JxlColorSpace::Xyb,
313 _ => JxlColorSpace::Srgb,
314 };
315
316 let num_color_channels = if color_space == JxlColorSpace::Gray {
317 1u8
318 } else {
319 3u8
320 };
321
322 let has_alpha = reader.read_bool()?;
324 let num_channels = if has_alpha {
325 num_color_channels + 1
326 } else {
327 num_color_channels
328 };
329
330 let has_animation = reader.read_bool()?;
332 let animation = if has_animation {
333 Some(Self::parse_animation_header(reader)?)
334 } else {
335 None
336 };
337
338 Ok(JxlHeader {
339 width,
340 height,
341 bits_per_sample,
342 num_channels,
343 is_float: float_flag,
344 has_alpha,
345 color_space,
346 orientation,
347 animation,
348 })
349 }
350
351 fn parse_animation_header(reader: &mut BitReader) -> CodecResult<JxlAnimation> {
353 let tps_numerator = reader.read_bits(32)?;
354 let tps_denominator = reader.read_bits(32)?;
355 let num_loops = reader.read_bits(32)?;
356 let have_timecodes = reader.read_bool()?;
357
358 if tps_numerator == 0 {
359 return Err(CodecError::InvalidBitstream(
360 "Animation tps_numerator must be non-zero".into(),
361 ));
362 }
363 if tps_denominator == 0 {
364 return Err(CodecError::InvalidBitstream(
365 "Animation tps_denominator must be non-zero".into(),
366 ));
367 }
368
369 Ok(JxlAnimation {
370 tps_numerator,
371 tps_denominator,
372 num_loops,
373 have_timecodes,
374 })
375 }
376
377 fn parse_frame_header(reader: &mut BitReader) -> CodecResult<(u32, bool)> {
381 let duration_ticks = reader.read_bits(32)?;
382 let is_last = reader.read_bool()?;
383 Ok((duration_ticks, is_last))
384 }
385
386 fn decode_modular(
388 &self,
389 reader: &mut BitReader,
390 header: &JxlHeader,
391 ) -> CodecResult<Vec<Vec<i32>>> {
392 reader.align_to_byte();
393
394 let remaining_bits = reader.remaining_bits();
396 if remaining_bits == 0 {
397 return Err(CodecError::InvalidBitstream(
398 "No image data after header".into(),
399 ));
400 }
401
402 let remaining_bytes = (remaining_bits + 7) / 8;
404 let mut data = Vec::with_capacity(remaining_bytes);
405 for _ in 0..remaining_bytes {
406 match reader.read_u8(8) {
407 Ok(byte) => data.push(byte),
408 Err(_) => break,
409 }
410 }
411
412 let mut decoder = ModularDecoder::new();
413
414 if header.color_channels() >= 3 {
416 decoder.add_transform(ModularTransform::Rct {
417 begin_channel: 0,
418 rct_type: 0,
419 });
420 }
421
422 decoder.decode_image(
423 &data,
424 header.width,
425 header.height,
426 header.num_channels as u32,
427 header.bits_per_sample,
428 )
429 }
430
431 fn channels_to_interleaved(
433 &self,
434 channels: &[Vec<i32>],
435 header: &JxlHeader,
436 ) -> CodecResult<Vec<u8>> {
437 let pixel_count = header.width as usize * header.height as usize;
438 let num_channels = header.num_channels as usize;
439 let bytes_per_sample = header.bytes_per_sample();
440
441 if channels.len() != num_channels {
442 return Err(CodecError::Internal(format!(
443 "Expected {} channels, got {}",
444 num_channels,
445 channels.len()
446 )));
447 }
448
449 let total_bytes = pixel_count * num_channels * bytes_per_sample;
450 let mut output = Vec::with_capacity(total_bytes);
451
452 for i in 0..pixel_count {
453 for ch in 0..num_channels {
454 let value = channels[ch][i];
455
456 match bytes_per_sample {
457 1 => {
458 let clamped = value.clamp(0, 255) as u8;
460 output.push(clamped);
461 }
462 2 => {
463 let clamped = value.clamp(0, 65535) as u16;
465 output.push(clamped as u8);
466 output.push((clamped >> 8) as u8);
467 }
468 _ => {
469 let bytes = (value as u32).to_le_bytes();
471 output.extend_from_slice(&bytes);
472 }
473 }
474 }
475 }
476
477 Ok(output)
478 }
479
480 pub fn decode_animated(&self, data: &[u8]) -> CodecResult<Vec<JxlFrame>> {
489 let codestream = self.extract_codestream(data)?;
490 let mut reader = BitReader::new(codestream);
491
492 let _ = reader.read_bits(16)?;
494
495 let (width, height) = self.parse_size_header(&mut reader)?;
497
498 let header = self.parse_image_metadata(&mut reader, width, height)?;
500 header.validate()?;
501
502 if header.animation.is_none() {
503 let channels_data = self.decode_modular(&mut reader, &header)?;
505 let pixel_data = self.channels_to_interleaved(&channels_data, &header)?;
506
507 return Ok(vec![JxlFrame {
508 data: pixel_data,
509 width: header.width,
510 height: header.height,
511 channels: header.num_channels,
512 bit_depth: header.bits_per_sample,
513 duration_ticks: 0,
514 is_last: true,
515 color_space: header.color_space,
516 }]);
517 }
518
519 let mut frames = Vec::new();
521
522 loop {
523 if reader.remaining_bits() < 33 {
525 break;
527 }
528
529 let (duration_ticks, is_last) = Self::parse_frame_header(&mut reader)?;
530
531 reader.align_to_byte();
533
534 if reader.remaining_bits() < 32 {
536 return Err(CodecError::InvalidBitstream(
537 "Unexpected end of animated codestream before frame data length".into(),
538 ));
539 }
540 let data_len = reader.read_bits(32)? as usize;
541
542 if reader.remaining_bits() < data_len * 8 {
544 return Err(CodecError::InvalidBitstream(format!(
545 "Animated frame data truncated: expected {data_len} bytes, \
546 have {} bits remaining",
547 reader.remaining_bits()
548 )));
549 }
550
551 let mut frame_data_bytes = Vec::with_capacity(data_len);
552 for _ in 0..data_len {
553 frame_data_bytes.push(reader.read_u8(8)?);
554 }
555
556 let channels_data = self.decode_frame_modular(&frame_data_bytes, &header)?;
558 let pixel_data = self.channels_to_interleaved(&channels_data, &header)?;
559
560 frames.push(JxlFrame {
561 data: pixel_data,
562 width: header.width,
563 height: header.height,
564 channels: header.num_channels,
565 bit_depth: header.bits_per_sample,
566 duration_ticks,
567 is_last,
568 color_space: header.color_space,
569 });
570
571 if is_last {
572 break;
573 }
574 }
575
576 if frames.is_empty() {
577 return Err(CodecError::InvalidBitstream(
578 "Animated codestream contains no frames".into(),
579 ));
580 }
581
582 Ok(frames)
583 }
584
585 fn decode_frame_modular(&self, data: &[u8], header: &JxlHeader) -> CodecResult<Vec<Vec<i32>>> {
587 let mut decoder = ModularDecoder::new();
588
589 if header.color_channels() >= 3 {
591 decoder.add_transform(ModularTransform::Rct {
592 begin_channel: 0,
593 rct_type: 0,
594 });
595 }
596
597 decoder.decode_image(
598 data,
599 header.width,
600 header.height,
601 header.num_channels as u32,
602 header.bits_per_sample,
603 )
604 }
605
606 pub fn is_animated(&self, data: &[u8]) -> CodecResult<bool> {
612 let header = self.read_header(data)?;
613 Ok(header.animation.is_some())
614 }
615
616 pub fn read_animation_header(&self, data: &[u8]) -> CodecResult<Option<JxlAnimation>> {
624 let header = self.read_header(data)?;
625 Ok(header.animation)
626 }
627}
628
629impl Default for JxlDecoder {
630 fn default() -> Self {
631 Self::new()
632 }
633}
634
/// Input layout detected by `JxlStreamingDecoder::new` from the first bytes of the stream.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum JxlFormat {
    /// Input is consumed incrementally as a sequence of ISOBMFF boxes via `BoxIter`.
    Isobmff,
    /// Input is read to the end up front and decoded in one pass with
    /// `JxlDecoder::decode_animated`.
    Native,
}
647
/// Iterator over decoded JPEG-XL frames read from a [`Read`] source.
///
/// Built with `JxlStreamingDecoder::new`; each call to `next()` yields a
/// `CodecResult<JxlFrame>`.
pub struct JxlStreamingDecoder<R: Read> {
    /// Input layout detected from the first bytes of the stream.
    format: JxlFormat,
    /// Box iterator over the peeked prefix plus the rest of the input. `Some` only for
    /// ISOBMFF input; reset to `None` once the final `jxlp` box has been consumed.
    box_iter: Option<BoxIter<PeekedReader<R>>>,
    /// Codestream bytes gathered from `jxlp` payloads (each minus its 4-byte index).
    codestream_buf: Vec<u8>,
    /// Already-decoded frames waiting to be returned by `next()`.
    pending_frames: std::vec::IntoIter<JxlFrame>,
    /// Set once iteration has finished or a fatal error has been returned.
    done: bool,
}
683
684impl<R: Read> JxlStreamingDecoder<R> {
685 pub fn new(mut reader: R) -> CodecResult<Self> {
695 let mut peek = [0u8; 12];
697 let n = reader.read(&mut peek)?;
698 let peek_bytes = peek[..n].to_vec();
699
700 let format = if n >= 12 && peek_bytes[4..8] == *b"ftyp" && peek_bytes[8..12] == *b"jxl " {
702 JxlFormat::Isobmff
703 } else {
704 JxlFormat::Native
705 };
706
707 let mut chained: PeekedReader<R> = io::Cursor::new(peek_bytes).chain(reader);
709
710 match format {
711 JxlFormat::Isobmff => Ok(Self {
712 format,
713 box_iter: Some(BoxIter::new(chained)),
714 codestream_buf: Vec::new(),
715 pending_frames: Vec::new().into_iter(),
716 done: false,
717 }),
718 JxlFormat::Native => {
719 let mut all_bytes = Vec::new();
721 chained
722 .read_to_end(&mut all_bytes)
723 .map_err(CodecError::Io)?;
724
725 let frames = JxlDecoder::new().decode_animated(&all_bytes)?;
726 Ok(Self {
727 format,
728 box_iter: None,
729 codestream_buf: Vec::new(),
730 pending_frames: frames.into_iter(),
731 done: false,
732 })
733 }
734 }
735 }
736}
737
738impl<R: Read> Iterator for JxlStreamingDecoder<R> {
739 type Item = CodecResult<JxlFrame>;
740
741 fn next(&mut self) -> Option<Self::Item> {
742 if self.done {
743 return None;
744 }
745
746 if let Some(frame) = self.pending_frames.next() {
748 return Some(Ok(frame));
749 }
750
751 match self.format {
752 JxlFormat::Native => {
756 self.done = true;
757 None
758 }
759
760 JxlFormat::Isobmff => {
763 let box_iter = match self.box_iter.as_mut() {
764 Some(bi) => bi,
765 None => {
766 self.done = true;
767 return None;
768 }
769 };
770
771 loop {
772 match box_iter.next() {
773 None => {
775 self.done = true;
776 if !self.codestream_buf.is_empty() {
777 let buf = std::mem::take(&mut self.codestream_buf);
778 return Some(Self::flush_codestream(buf, &mut self.pending_frames));
779 }
780 return None;
781 }
782
783 Some(Err(e)) => {
785 self.done = true;
786 return Some(Err(CodecError::Io(e)));
787 }
788
789 Some(Ok((fourcc, payload))) => {
791 if fourcc != *b"jxlp" {
792 continue;
794 }
795
796 if payload.len() < 4 {
798 self.done = true;
799 return Some(Err(CodecError::InvalidBitstream(
800 "jxlp box payload too short (< 4 bytes)".into(),
801 )));
802 }
803
804 let mut idx_buf = [0u8; 4];
805 idx_buf.copy_from_slice(&payload[0..4]);
806 let is_last = (u32::from_be_bytes(idx_buf) & 0x8000_0000) != 0;
807
808 self.codestream_buf.extend_from_slice(&payload[4..]);
809
810 if is_last {
811 let buf = std::mem::take(&mut self.codestream_buf);
812 self.box_iter = None;
813 return Some(Self::flush_codestream(buf, &mut self.pending_frames));
814 }
815 }
817 }
818 }
819 }
820 }
821 }
822}
823
824impl<R: Read> JxlStreamingDecoder<R> {
825 fn flush_codestream(
828 buf: Vec<u8>,
829 pending: &mut std::vec::IntoIter<JxlFrame>,
830 ) -> CodecResult<JxlFrame> {
831 let mut frames = JxlDecoder::new().decode_animated(&buf)?;
832 if frames.is_empty() {
833 return Err(CodecError::InvalidBitstream(
834 "jxlp codestream contained no frames".into(),
835 ));
836 }
837 let first = frames.remove(0);
838 *pending = frames.into_iter();
839 Ok(first)
840 }
841}
842
#[cfg(test)]
mod tests {
    use super::*;

    /// Sixteen bytes that begin with the JPEG-XL container signature.
    fn container_fixture() -> Vec<u8> {
        let mut bytes = vec![0u8; 16];
        bytes[..12].copy_from_slice(&JXL_CONTAINER_SIGNATURE);
        bytes
    }

    /// A 10x10, 3-channel sRGB image with the given bit depth and data length.
    fn image_fixture(bit_depth: u8, data_len: usize) -> DecodedImage {
        DecodedImage {
            width: 10,
            height: 10,
            channels: 3,
            bit_depth,
            data: vec![0u8; data_len],
            color_space: JxlColorSpace::Srgb,
        }
    }

    #[test]
    #[ignore]
    fn test_is_codestream_signature() {
        let cases: [(&[u8], bool); 4] = [
            (&[0xFF, 0x0A, 0x00][..], true),
            (&[0xFF, 0x0B, 0x00][..], false),
            (&[0xFF][..], false),
            (&[][..], false),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(JxlDecoder::is_codestream(input), expected);
        }
    }

    #[test]
    #[ignore]
    fn test_is_container_signature() {
        let container = container_fixture();
        assert!(JxlDecoder::is_container(&container));
        assert!(!JxlDecoder::is_container(&[0xFF, 0x0A]));
    }

    #[test]
    #[ignore]
    fn test_is_jxl() {
        assert!(JxlDecoder::is_jxl(&[0xFF, 0x0A]));
        assert!(JxlDecoder::is_jxl(&container_fixture()));
        assert!(!JxlDecoder::is_jxl(&[0x00, 0x00]));
    }

    #[test]
    #[ignore]
    fn test_extract_codestream_bare() {
        let bare = [0xFF, 0x0A, 0x01, 0x02];
        let extracted = JxlDecoder::new().extract_codestream(&bare).expect("ok");
        assert_eq!(extracted, &bare);
    }

    #[test]
    #[ignore]
    fn test_extract_codestream_invalid() {
        let outcome = JxlDecoder::new().extract_codestream(&[0x00, 0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    #[ignore]
    fn test_parse_size_header_small() {
        let mut writer = super::super::bitreader::BitWriter::new();
        writer.write_bool(true);
        writer.write_bits(2, 5);
        writer.write_bits(0, 5);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        let (w, h) = JxlDecoder::new()
            .parse_size_header(&mut reader)
            .expect("ok");
        assert_eq!(h, 24);
        assert_eq!(w, 24);
    }

    #[test]
    #[ignore]
    fn test_read_header_invalid_data() {
        let outcome = JxlDecoder::new().read_header(&[0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    #[ignore]
    fn test_decoded_image_metrics() {
        let img = image_fixture(8, 300);
        assert_eq!(img.sample_count(), 300);
        assert_eq!(img.data_size(), 300);
    }

    #[test]
    #[ignore]
    fn test_decoded_image_16bit() {
        let img = image_fixture(16, 600);
        assert_eq!(img.sample_count(), 300);
        assert_eq!(img.data_size(), 600);
    }
}