1use super::bitreader::BitReader;
8use super::modular::{ModularDecoder, ModularTransform};
9use super::types::{JxlColorSpace, JxlHeader, JXL_CODESTREAM_SIGNATURE, JXL_CONTAINER_SIGNATURE};
10use crate::error::{CodecError, CodecResult};
11
12#[derive(Clone, Debug)]
14pub struct DecodedImage {
15 pub width: u32,
17 pub height: u32,
19 pub channels: u8,
21 pub bit_depth: u8,
23 pub data: Vec<u8>,
27 pub color_space: JxlColorSpace,
29}
30
31impl DecodedImage {
32 pub fn sample_count(&self) -> usize {
34 self.width as usize * self.height as usize * self.channels as usize
35 }
36
37 pub fn data_size(&self) -> usize {
39 let bytes_per_sample = if self.bit_depth > 8 { 2 } else { 1 };
40 self.sample_count() * bytes_per_sample
41 }
42}
43
/// Stateless JPEG XL decoder; see the `impl` block for the decoding entry points.
#[derive(Debug, Clone, Copy)]
pub struct JxlDecoder;
49
50impl JxlDecoder {
51 pub fn new() -> Self {
53 Self
54 }
55
56 pub fn is_jxl(data: &[u8]) -> bool {
61 Self::is_codestream(data) || Self::is_container(data)
62 }
63
64 pub fn is_codestream(data: &[u8]) -> bool {
66 data.len() >= 2
67 && data[0] == JXL_CODESTREAM_SIGNATURE[0]
68 && data[1] == JXL_CODESTREAM_SIGNATURE[1]
69 }
70
71 pub fn is_container(data: &[u8]) -> bool {
73 data.len() >= 12 && data[..12] == JXL_CONTAINER_SIGNATURE
74 }
75
76 pub fn decode(&self, data: &[u8]) -> CodecResult<DecodedImage> {
86 let codestream = self.extract_codestream(data)?;
87 let mut reader = BitReader::new(&codestream);
88
89 let _ = reader.read_bits(16)?;
91
92 let (width, height) = self.parse_size_header(&mut reader)?;
94
95 let header = self.parse_image_metadata(&mut reader, width, height)?;
97 header.validate()?;
98
99 let channels_data = self.decode_modular(&mut reader, &header)?;
101
102 let pixel_data = self.channels_to_interleaved(&channels_data, &header)?;
104
105 Ok(DecodedImage {
106 width: header.width,
107 height: header.height,
108 channels: header.num_channels,
109 bit_depth: header.bits_per_sample,
110 data: pixel_data,
111 color_space: header.color_space,
112 })
113 }
114
115 pub fn read_header(&self, data: &[u8]) -> CodecResult<JxlHeader> {
121 let codestream = self.extract_codestream(data)?;
122 let mut reader = BitReader::new(&codestream);
123
124 let _ = reader.read_bits(16)?;
126
127 let (width, height) = self.parse_size_header(&mut reader)?;
128 let header = self.parse_image_metadata(&mut reader, width, height)?;
129 header.validate()?;
130 Ok(header)
131 }
132
133 fn extract_codestream<'a>(&self, data: &'a [u8]) -> CodecResult<&'a [u8]> {
138 if Self::is_codestream(data) {
139 return Ok(data);
140 }
141 if Self::is_container(data) {
142 return self.find_jxlc_box(data);
144 }
145 Err(CodecError::InvalidBitstream(
146 "Not a valid JPEG-XL file: invalid signature".into(),
147 ))
148 }
149
150 fn find_jxlc_box<'a>(&self, data: &'a [u8]) -> CodecResult<&'a [u8]> {
152 let mut offset = 0;
153 while offset + 8 <= data.len() {
154 let box_size = u32::from_be_bytes([
155 data[offset],
156 data[offset + 1],
157 data[offset + 2],
158 data[offset + 3],
159 ]) as usize;
160
161 let box_type = &data[offset + 4..offset + 8];
162
163 if box_size < 8 {
164 break;
165 }
166
167 if box_type == b"jxlc" {
168 let content_start = offset + 8;
169 let content_end = offset + box_size;
170 if content_end <= data.len() {
171 return Ok(&data[content_start..content_end]);
172 }
173 return Err(CodecError::InvalidBitstream(
174 "jxlc box extends past end of file".into(),
175 ));
176 }
177
178 offset += box_size;
179 }
180
181 Err(CodecError::InvalidBitstream(
182 "No jxlc (codestream) box found in container".into(),
183 ))
184 }
185
186 fn parse_size_header(&self, reader: &mut BitReader) -> CodecResult<(u32, u32)> {
193 let small = reader.read_bool()?;
194
195 if small {
196 let height_div8 = reader.read_bits(5)? + 1;
197 let width_div8 = reader.read_bits(5)?;
198 let width_div8 = if width_div8 == 0 {
200 height_div8
201 } else {
202 width_div8
203 };
204 Ok((width_div8 * 8, height_div8 * 8))
205 } else {
206 let height = self.read_size_u32(reader)?;
208 let width = self.read_size_u32(reader)?;
209 Ok((width, height))
210 }
211 }
212
213 fn read_size_u32(&self, reader: &mut BitReader) -> CodecResult<u32> {
217 let selector = reader.read_bits(2)?;
218 match selector {
219 0 => Ok(1),
220 1 => {
221 let extra = reader.read_bits(9)?;
222 Ok(1 + extra)
223 }
224 2 => {
225 let extra = reader.read_bits(13)?;
226 Ok(1 + extra)
227 }
228 3 => {
229 let extra = reader.read_bits(18)?;
230 Ok(1 + extra)
231 }
232 _ => Err(CodecError::InvalidBitstream("Invalid size selector".into())),
233 }
234 }
235
236 fn parse_image_metadata(
244 &self,
245 reader: &mut BitReader,
246 width: u32,
247 height: u32,
248 ) -> CodecResult<JxlHeader> {
249 let all_default = reader.read_bool()?;
251
252 if all_default {
253 return Ok(JxlHeader {
254 width,
255 height,
256 bits_per_sample: 8,
257 num_channels: 3,
258 is_float: false,
259 has_alpha: false,
260 color_space: JxlColorSpace::Srgb,
261 orientation: 1,
262 });
263 }
264
265 let has_extra_fields = reader.read_bool()?;
267 let orientation = if has_extra_fields {
268 reader.read_bits(3)? as u8 + 1
269 } else {
270 1
271 };
272
273 let float_flag = reader.read_bool()?;
275 let bits_per_sample = if float_flag {
276 let _exp_bits = reader.read_bits(4)?;
278 let mantissa_bits = reader.read_bits(4)? + 1;
279 (mantissa_bits + 1) as u8 } else {
281 let depth_selector = reader.read_bits(2)?;
282 match depth_selector {
283 0 => 8,
284 1 => 10,
285 2 => 12,
286 3 => {
287 let custom = reader.read_bits(6)?;
288 (custom + 1) as u8
289 }
290 _ => 8,
291 }
292 };
293
294 let color_space_selector = reader.read_bits(2)?;
296 let color_space = match color_space_selector {
297 0 => JxlColorSpace::Srgb,
298 1 => JxlColorSpace::LinearSrgb,
299 2 => JxlColorSpace::Gray,
300 3 => JxlColorSpace::Xyb,
301 _ => JxlColorSpace::Srgb,
302 };
303
304 let num_color_channels = if color_space == JxlColorSpace::Gray {
305 1u8
306 } else {
307 3u8
308 };
309
310 let has_alpha = reader.read_bool()?;
312 let num_channels = if has_alpha {
313 num_color_channels + 1
314 } else {
315 num_color_channels
316 };
317
318 Ok(JxlHeader {
319 width,
320 height,
321 bits_per_sample,
322 num_channels,
323 is_float: float_flag,
324 has_alpha,
325 color_space,
326 orientation,
327 })
328 }
329
330 fn decode_modular(
332 &self,
333 reader: &mut BitReader,
334 header: &JxlHeader,
335 ) -> CodecResult<Vec<Vec<i32>>> {
336 reader.align_to_byte();
337
338 let remaining_bits = reader.remaining_bits();
340 if remaining_bits == 0 {
341 return Err(CodecError::InvalidBitstream(
342 "No image data after header".into(),
343 ));
344 }
345
346 let remaining_bytes = (remaining_bits + 7) / 8;
348 let mut data = Vec::with_capacity(remaining_bytes);
349 for _ in 0..remaining_bytes {
350 match reader.read_u8(8) {
351 Ok(byte) => data.push(byte),
352 Err(_) => break,
353 }
354 }
355
356 let mut decoder = ModularDecoder::new();
357
358 if header.color_channels() >= 3 {
360 decoder.add_transform(ModularTransform::Rct {
361 begin_channel: 0,
362 rct_type: 0,
363 });
364 }
365
366 decoder.decode_image(
367 &data,
368 header.width,
369 header.height,
370 header.num_channels as u32,
371 header.bits_per_sample,
372 )
373 }
374
375 fn channels_to_interleaved(
377 &self,
378 channels: &[Vec<i32>],
379 header: &JxlHeader,
380 ) -> CodecResult<Vec<u8>> {
381 let pixel_count = header.width as usize * header.height as usize;
382 let num_channels = header.num_channels as usize;
383 let bytes_per_sample = header.bytes_per_sample();
384
385 if channels.len() != num_channels {
386 return Err(CodecError::Internal(format!(
387 "Expected {} channels, got {}",
388 num_channels,
389 channels.len()
390 )));
391 }
392
393 let total_bytes = pixel_count * num_channels * bytes_per_sample;
394 let mut output = Vec::with_capacity(total_bytes);
395
396 for i in 0..pixel_count {
397 for ch in 0..num_channels {
398 let value = channels[ch][i];
399
400 match bytes_per_sample {
401 1 => {
402 let clamped = value.clamp(0, 255) as u8;
404 output.push(clamped);
405 }
406 2 => {
407 let clamped = value.clamp(0, 65535) as u16;
409 output.push(clamped as u8);
410 output.push((clamped >> 8) as u8);
411 }
412 _ => {
413 let bytes = (value as u32).to_le_bytes();
415 output.extend_from_slice(&bytes);
416 }
417 }
418 }
419 }
420
421 Ok(output)
422 }
423}
424
425impl Default for JxlDecoder {
426 fn default() -> Self {
427 Self::new()
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every test below is `#[ignore]`d without a stated reason,
    // so a plain `cargo test` skips them (run with `cargo test -- --ignored`).
    // Confirm why they are disabled.

    /// Builds a 10x10, 3-channel sRGB image with a zeroed buffer of `data_len` bytes.
    fn rgb_10x10(bit_depth: u8, data_len: usize) -> DecodedImage {
        DecodedImage {
            width: 10,
            height: 10,
            channels: 3,
            bit_depth,
            data: vec![0u8; data_len],
            color_space: JxlColorSpace::Srgb,
        }
    }

    #[test]
    #[ignore]
    fn test_is_codestream_signature() {
        let valid: &[u8] = &[0xFF, 0x0A, 0x00];
        let wrong_second_byte: &[u8] = &[0xFF, 0x0B, 0x00];
        let too_short: &[u8] = &[0xFF];
        let empty: &[u8] = &[];

        assert!(JxlDecoder::is_codestream(valid));
        assert!(!JxlDecoder::is_codestream(wrong_second_byte));
        assert!(!JxlDecoder::is_codestream(too_short));
        assert!(!JxlDecoder::is_codestream(empty));
    }

    #[test]
    #[ignore]
    fn test_is_container_signature() {
        let mut boxed = [0u8; 16];
        boxed[..12].copy_from_slice(&JXL_CONTAINER_SIGNATURE);

        assert!(JxlDecoder::is_container(&boxed));
        assert!(!JxlDecoder::is_container(&[0xFF, 0x0A]));
    }

    #[test]
    #[ignore]
    fn test_is_jxl() {
        let mut boxed = [0u8; 16];
        boxed[..12].copy_from_slice(&JXL_CONTAINER_SIGNATURE);

        assert!(JxlDecoder::is_jxl(&[0xFF, 0x0A]));
        assert!(JxlDecoder::is_jxl(&boxed));
        assert!(!JxlDecoder::is_jxl(&[0x00, 0x00]));
    }

    #[test]
    #[ignore]
    fn test_extract_codestream_bare() {
        let bare = [0xFF, 0x0A, 0x01, 0x02];
        let extracted = JxlDecoder::new().extract_codestream(&bare).expect("ok");
        // A bare codestream is returned unchanged.
        assert_eq!(extracted, &bare);
    }

    #[test]
    #[ignore]
    fn test_extract_codestream_invalid() {
        let outcome = JxlDecoder::new().extract_codestream(&[0x00, 0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    #[ignore]
    fn test_parse_size_header_small() {
        let mut writer = super::super::bitreader::BitWriter::new();
        // Small-size flag, then height/8 - 1 = 2 and width field 0 ("same as height").
        writer.write_bool(true);
        writer.write_bits(2, 5);
        writer.write_bits(0, 5);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        let (width, height) = JxlDecoder::new()
            .parse_size_header(&mut reader)
            .expect("ok");

        assert_eq!(height, 24);
        assert_eq!(width, 24);
    }

    #[test]
    #[ignore]
    fn test_read_header_invalid_data() {
        let outcome = JxlDecoder::new().read_header(&[0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    #[ignore]
    fn test_decoded_image_metrics() {
        let image = rgb_10x10(8, 300);
        assert_eq!(image.sample_count(), 300);
        assert_eq!(image.data_size(), 300);
    }

    #[test]
    #[ignore]
    fn test_decoded_image_16bit() {
        let image = rgb_10x10(16, 600);
        assert_eq!(image.sample_count(), 300);
        assert_eq!(image.data_size(), 600);
    }
}