use std::io::{self, Read};
use super::bitreader::BitReader;
use super::modular::{ModularDecoder, ModularTransform};
use super::types::{
JxlAnimation, JxlColorSpace, JxlFrame, JxlHeader, JXL_CODESTREAM_SIGNATURE,
JXL_CONTAINER_SIGNATURE,
};
use crate::container::isobmff::BoxIter;
use crate::error::{CodecError, CodecResult};
/// Reader that first replays bytes already consumed while sniffing the input
/// (held in the `Cursor`) and then continues with the underlying reader `R`.
type PeekedReader<R> = io::Chain<io::Cursor<Vec<u8>>, R>;
/// A fully decoded still image whose samples are stored pixel-interleaved.
#[derive(Clone, Debug)]
pub struct DecodedImage {
/// Image width in pixels.
pub width: u32,
/// Image height in pixels.
pub height: u32,
/// Channels per pixel: colour channels plus one more when alpha is present.
pub channels: u8,
/// Bits per sample, copied from the image header.
pub bit_depth: u8,
/// Interleaved sample bytes; multi-byte samples are written little-endian.
pub data: Vec<u8>,
/// Colour space declared in the image header.
pub color_space: JxlColorSpace,
}
impl DecodedImage {
    /// Total number of samples in the image: `width * height * channels`.
    pub fn sample_count(&self) -> usize {
        self.width as usize * self.height as usize * self.channels as usize
    }

    /// Expected size in bytes of [`DecodedImage::data`].
    ///
    /// The decoder emits samples as 1, 2 or 4 bytes each. This uses 1 byte for
    /// depths up to 8 bits, 2 bytes for 9–16 bits and 4 bytes above 16 bits
    /// (the header allows custom integer depths up to 64 bits). Previously every
    /// depth above 8 bits was counted as 2 bytes.
    ///
    /// NOTE(review): assumes `JxlHeader::bytes_per_sample` uses the same 8/16-bit
    /// thresholds — confirm against `types.rs`.
    pub fn data_size(&self) -> usize {
        let bytes_per_sample = match self.bit_depth {
            0..=8 => 1,
            9..=16 => 2,
            _ => 4,
        };
        self.sample_count() * bytes_per_sample
    }
}
/// Stateless JPEG XL decoder (a unit struct); each call decodes from the byte
/// slice it is given.
pub struct JxlDecoder;
impl JxlDecoder {
pub fn new() -> Self {
Self
}
/// Returns `true` when `data` starts like either a container file or a bare codestream.
pub fn is_jxl(data: &[u8]) -> bool {
    let container = Self::is_container(data);
    container || Self::is_codestream(data)
}
/// Returns `true` when the first two bytes of `data` match the bare-codestream
/// signature. Inputs shorter than two bytes are never codestreams.
pub fn is_codestream(data: &[u8]) -> bool {
    match data {
        [first, second, ..] => {
            *first == JXL_CODESTREAM_SIGNATURE[0] && *second == JXL_CODESTREAM_SIGNATURE[1]
        }
        _ => false,
    }
}
/// Returns `true` when `data` begins with the full 12-byte container signature.
pub fn is_container(data: &[u8]) -> bool {
    matches!(data.get(..12), Some(head) if *head == JXL_CONTAINER_SIGNATURE)
}
/// Decodes a still image from a bare codestream or a container holding a `jxlc` box.
///
/// The header is parsed and validated, the modular payload is decoded into
/// per-channel planes, and the planes are interleaved into a byte buffer.
///
/// # Errors
///
/// Invalid signature, missing `jxlc` box, malformed or invalid header, and any
/// error raised while decoding or interleaving the pixel data.
pub fn decode(&self, data: &[u8]) -> CodecResult<DecodedImage> {
    let codestream = self.extract_codestream(data)?;
    let mut bits = BitReader::new(codestream);
    // Skip the 16-bit signature at the start of the codestream.
    bits.read_bits(16)?;

    let (width, height) = self.parse_size_header(&mut bits)?;
    let header = self.parse_image_metadata(&mut bits, width, height)?;
    header.validate()?;

    let planes = self.decode_modular(&mut bits, &header)?;
    let interleaved = self.channels_to_interleaved(&planes, &header)?;

    Ok(DecodedImage {
        data: interleaved,
        width: header.width,
        height: header.height,
        channels: header.num_channels,
        bit_depth: header.bits_per_sample,
        color_space: header.color_space,
    })
}
/// Parses and validates only the image header, without decoding any pixels.
///
/// # Errors
///
/// Invalid signature, missing `jxlc` box, or a malformed/invalid header.
pub fn read_header(&self, data: &[u8]) -> CodecResult<JxlHeader> {
    let mut bits = BitReader::new(self.extract_codestream(data)?);
    // The two signature bytes precede the size header.
    bits.read_bits(16)?;
    let (width, height) = self.parse_size_header(&mut bits)?;
    let header = self.parse_image_metadata(&mut bits, width, height)?;
    header.validate()?;
    Ok(header)
}
/// Returns the codestream bytes: `data` itself for a bare codestream, or the
/// payload of the `jxlc` box for a container.
///
/// # Errors
///
/// `InvalidBitstream` when neither signature matches, or when the container
/// has no usable `jxlc` box.
fn extract_codestream<'a>(&self, data: &'a [u8]) -> CodecResult<&'a [u8]> {
    if Self::is_codestream(data) {
        Ok(data)
    } else if Self::is_container(data) {
        self.find_jxlc_box(data)
    } else {
        Err(CodecError::InvalidBitstream(
            "Not a valid JPEG-XL file: invalid signature".into(),
        ))
    }
}
/// Walks the top-level ISOBMFF boxes of `data` and returns the payload of the
/// first `jxlc` (full codestream) box.
///
/// Box headers follow ISO/IEC 14496-12:
/// * a 32-bit big-endian size followed by the 4-byte type;
/// * size `0` means the box runs to the end of the file;
/// * size `1` means a 64-bit big-endian "largesize" follows the type, giving a
///   16-byte header.
///
/// Sizes too small to hold their own header stop the walk. Box ends are
/// computed with checked arithmetic so oversized lengths cannot wrap `usize`
/// (relevant on 32-bit targets).
///
/// # Errors
///
/// `InvalidBitstream` if the `jxlc` box extends past the end of `data`, or if
/// no `jxlc` box is found.
fn find_jxlc_box<'a>(&self, data: &'a [u8]) -> CodecResult<&'a [u8]> {
    let mut offset = 0usize;
    while offset + 8 <= data.len() {
        let size_field = u32::from_be_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ]);
        let box_type = &data[offset + 4..offset + 8];

        // Header length and absolute end of the box. An end of `None` means the
        // declared size cannot be represented as a `usize` position.
        let (header_len, box_end) = match size_field {
            // Size 0: the box extends to the end of the file.
            0 => (8usize, Some(data.len())),
            // Size 1: a 64-bit largesize follows the type field.
            1 => {
                if offset + 16 > data.len() {
                    break;
                }
                let mut large = [0u8; 8];
                large.copy_from_slice(&data[offset + 8..offset + 16]);
                let large_size = u64::from_be_bytes(large);
                if large_size < 16 {
                    break;
                }
                let end = usize::try_from(large_size)
                    .ok()
                    .and_then(|size| offset.checked_add(size));
                (16, end)
            }
            // Sizes 2..=7 cannot even hold the 8-byte header: malformed.
            2..=7 => break,
            size => (
                8,
                usize::try_from(size)
                    .ok()
                    .and_then(|size| offset.checked_add(size)),
            ),
        };

        if box_type == b"jxlc" {
            return match box_end {
                Some(end) if end <= data.len() => Ok(&data[offset + header_len..end]),
                _ => Err(CodecError::InvalidBitstream(
                    "jxlc box extends past end of file".into(),
                )),
            };
        }

        match box_end {
            Some(end) => offset = end,
            None => break,
        }
    }
    Err(CodecError::InvalidBitstream(
        "No jxlc (codestream) box found in container".into(),
    ))
}
/// Reads the size header and returns `(width, height)` in pixels.
///
/// Layout:
/// * a leading flag bit set selects the compact form: 5 bits of
///   `height / 8 - 1`, then 5 bits of `width / 8` (0 means "same as height");
/// * otherwise height and then width are each read with `read_size_u32`.
fn parse_size_header(&self, reader: &mut BitReader) -> CodecResult<(u32, u32)> {
    if !reader.read_bool()? {
        // Explicit form: height is coded before width.
        let height = self.read_size_u32(reader)?;
        let width = self.read_size_u32(reader)?;
        return Ok((width, height));
    }
    let rows_div8 = reader.read_bits(5)? + 1;
    let cols_div8 = match reader.read_bits(5)? {
        0 => rows_div8,
        coded => coded,
    };
    Ok((cols_div8 * 8, rows_div8 * 8))
}
/// Reads one dimension coded as a 2-bit selector plus optional extra bits.
///
/// Selector 0 yields 1. Selectors 1, 2 and 3 yield `1 + n`, where `n` is read
/// with 9, 13 or 18 bits respectively.
fn read_size_u32(&self, reader: &mut BitReader) -> CodecResult<u32> {
    let extra_bits = match reader.read_bits(2)? {
        0 => return Ok(1),
        1 => 9,
        2 => 13,
        3 => 18,
        _ => return Err(CodecError::InvalidBitstream("Invalid size selector".into())),
    };
    let extra = reader.read_bits(extra_bits)?;
    Ok(1 + extra)
}
/// Reads the image-metadata fields that follow the size header.
///
/// Bit order:
/// 1. `all_default` flag: when set, returns 8-bit, 3-channel sRGB with no
///    alpha, orientation 1 and no animation;
/// 2. extra-fields flag, followed by a 3-bit `orientation - 1` when set;
/// 3. float flag, then either 4 exponent bits (ignored) and 4 bits of
///    `mantissa - 1` (depth = mantissa + 1), or a 2-bit depth selector
///    (8/10/12 bits, or 3 = 6 custom bits of `depth - 1`);
/// 4. 2-bit colour space (sRGB, linear sRGB, gray, XYB);
/// 5. alpha flag;
/// 6. animation flag, followed by the animation header when set.
fn parse_image_metadata(
    &self,
    reader: &mut BitReader,
    width: u32,
    height: u32,
) -> CodecResult<JxlHeader> {
    if reader.read_bool()? {
        return Ok(JxlHeader {
            width,
            height,
            bits_per_sample: 8,
            num_channels: 3,
            is_float: false,
            has_alpha: false,
            color_space: JxlColorSpace::Srgb,
            orientation: 1,
            animation: None,
        });
    }

    let orientation: u8 = match reader.read_bool()? {
        true => reader.read_bits(3)? as u8 + 1,
        false => 1,
    };

    let is_float = reader.read_bool()?;
    let bits_per_sample: u8 = if is_float {
        // The exponent width is present in the stream but not used here.
        let _exponent_bits = reader.read_bits(4)?;
        let mantissa = reader.read_bits(4)? + 1;
        (mantissa + 1) as u8
    } else {
        match reader.read_bits(2)? {
            0 => 8,
            1 => 10,
            2 => 12,
            3 => (reader.read_bits(6)? + 1) as u8,
            _ => 8,
        }
    };

    let color_space = match reader.read_bits(2)? {
        1 => JxlColorSpace::LinearSrgb,
        2 => JxlColorSpace::Gray,
        3 => JxlColorSpace::Xyb,
        _ => JxlColorSpace::Srgb,
    };
    let color_channels: u8 = if color_space == JxlColorSpace::Gray { 1 } else { 3 };

    let has_alpha = reader.read_bool()?;
    let num_channels = color_channels + u8::from(has_alpha);

    let animation = if reader.read_bool()? {
        Some(Self::parse_animation_header(reader)?)
    } else {
        None
    };

    Ok(JxlHeader {
        width,
        height,
        bits_per_sample,
        num_channels,
        is_float,
        has_alpha,
        color_space,
        orientation,
        animation,
    })
}
/// Reads the animation header.
///
/// Fields are read in this order: 32-bit tick-rate numerator, 32-bit tick-rate
/// denominator, 32-bit loop count, then the `have_timecodes` flag.
///
/// # Errors
///
/// `InvalidBitstream` if either tick-rate component is zero.
fn parse_animation_header(reader: &mut BitReader) -> CodecResult<JxlAnimation> {
    // Struct-literal fields are evaluated in the order written, which is the
    // bitstream order.
    let animation = JxlAnimation {
        tps_numerator: reader.read_bits(32)?,
        tps_denominator: reader.read_bits(32)?,
        num_loops: reader.read_bits(32)?,
        have_timecodes: reader.read_bool()?,
    };
    if animation.tps_numerator == 0 {
        return Err(CodecError::InvalidBitstream(
            "Animation tps_numerator must be non-zero".into(),
        ));
    }
    if animation.tps_denominator == 0 {
        return Err(CodecError::InvalidBitstream(
            "Animation tps_denominator must be non-zero".into(),
        ));
    }
    Ok(animation)
}
/// Reads a frame header: a 32-bit duration in ticks, then the `is_last` flag.
fn parse_frame_header(reader: &mut BitReader) -> CodecResult<(u32, bool)> {
    let ticks = reader.read_bits(32)?;
    let last = reader.read_bool()?;
    Ok((ticks, last))
}
/// Decodes the still-image payload that follows the header.
///
/// The reader is aligned to the next byte boundary and all remaining bytes are
/// passed to `decode_frame_modular`. The animated path uses the same helper, so
/// both paths configure the modular decoder identically.
///
/// # Errors
///
/// `InvalidBitstream` if no bits remain after the header, plus any error from
/// the modular decoder.
fn decode_modular(
    &self,
    reader: &mut BitReader,
    header: &JxlHeader,
) -> CodecResult<Vec<Vec<i32>>> {
    reader.align_to_byte();
    let remaining_bits = reader.remaining_bits();
    if remaining_bits == 0 {
        return Err(CodecError::InvalidBitstream(
            "No image data after header".into(),
        ));
    }
    let remaining_bytes = (remaining_bits + 7) / 8;
    let mut data = Vec::with_capacity(remaining_bytes);
    for _ in 0..remaining_bytes {
        // Best effort: a byte the reader cannot supply (presumably a trailing
        // partial byte) ends the copy instead of failing the decode.
        match reader.read_u8(8) {
            Ok(byte) => data.push(byte),
            Err(_) => break,
        }
    }
    self.decode_frame_modular(&data, header)
}
/// Converts planar channel data into a pixel-interleaved byte buffer.
///
/// Each sample is written using `header.bytes_per_sample()`:
/// * 1 byte: clamped to `0..=255`;
/// * 2 bytes: clamped to `0..=65535`, little-endian;
/// * otherwise: the raw `i32` bits as a little-endian `u32`.
///
/// # Errors
///
/// `CodecError::Internal` if the number of planes differs from
/// `header.num_channels`, or if any plane holds fewer than
/// `width * height` samples. Short planes would otherwise panic on indexing,
/// and the planes come from decoding untrusted input.
fn channels_to_interleaved(
    &self,
    channels: &[Vec<i32>],
    header: &JxlHeader,
) -> CodecResult<Vec<u8>> {
    let pixel_count = header.width as usize * header.height as usize;
    let num_channels = header.num_channels as usize;
    let bytes_per_sample = header.bytes_per_sample();
    if channels.len() != num_channels {
        return Err(CodecError::Internal(format!(
            "Expected {} channels, got {}",
            num_channels,
            channels.len()
        )));
    }
    if let Some((index, plane)) = channels
        .iter()
        .enumerate()
        .find(|(_, plane)| plane.len() < pixel_count)
    {
        return Err(CodecError::Internal(format!(
            "Channel {} has {} samples, expected {}",
            index,
            plane.len(),
            pixel_count
        )));
    }
    let mut output = Vec::with_capacity(pixel_count * num_channels * bytes_per_sample);
    for i in 0..pixel_count {
        for plane in channels {
            let value = plane[i];
            match bytes_per_sample {
                1 => output.push(value.clamp(0, 255) as u8),
                2 => output.extend_from_slice(&(value.clamp(0, 65535) as u16).to_le_bytes()),
                _ => output.extend_from_slice(&(value as u32).to_le_bytes()),
            }
        }
    }
    Ok(output)
}
/// Decodes every frame of a possibly animated image.
///
/// A non-animated image yields exactly one frame with `duration_ticks == 0`
/// and `is_last == true`. For animated images, the data after the header is a
/// sequence of frame records. Each record consists of:
/// * a frame header (32-bit duration in ticks plus a 1-bit `is_last` flag);
/// * padding to the next byte boundary;
/// * a 32-bit byte length followed by that many bytes of modular data.
///
/// Decoding stops after a frame flagged `is_last`, or when fewer than 33 bits
/// remain (not enough for another frame header).
///
/// # Errors
///
/// Returns an error for any of the following:
/// * invalid signature or header;
/// * truncated frame records;
/// * frame decode failures;
/// * an animated stream containing no frames.
pub fn decode_animated(&self, data: &[u8]) -> CodecResult<Vec<JxlFrame>> {
    let codestream = self.extract_codestream(data)?;
    let mut reader = BitReader::new(codestream);
    // Skip the 16-bit codestream signature.
    let _ = reader.read_bits(16)?;
    let (width, height) = self.parse_size_header(&mut reader)?;
    let header = self.parse_image_metadata(&mut reader, width, height)?;
    header.validate()?;

    if header.animation.is_none() {
        let channels_data = self.decode_modular(&mut reader, &header)?;
        let pixel_data = self.channels_to_interleaved(&channels_data, &header)?;
        return Ok(vec![JxlFrame {
            data: pixel_data,
            width: header.width,
            height: header.height,
            channels: header.num_channels,
            bit_depth: header.bits_per_sample,
            duration_ticks: 0,
            is_last: true,
            color_space: header.color_space,
        }]);
    }

    let mut frames = Vec::new();
    // 33 bits = 32-bit duration + 1-bit is_last flag of the next frame header.
    while reader.remaining_bits() >= 33 {
        let (duration_ticks, is_last) = Self::parse_frame_header(&mut reader)?;
        reader.align_to_byte();
        if reader.remaining_bits() < 32 {
            return Err(CodecError::InvalidBitstream(
                "Unexpected end of animated codestream before frame data length".into(),
            ));
        }
        let data_len = reader.read_bits(32)? as usize;
        let remaining_bits = reader.remaining_bits();
        // Compare whole bytes rather than computing `data_len * 8`, which can
        // overflow `usize` on 32-bit targets for a hostile length. That would
        // bypass this check and trigger a huge allocation below.
        if remaining_bits / 8 < data_len {
            return Err(CodecError::InvalidBitstream(format!(
                "Animated frame data truncated: expected {data_len} bytes, \
                 have {remaining_bits} bits remaining"
            )));
        }
        let mut frame_data_bytes = Vec::with_capacity(data_len);
        for _ in 0..data_len {
            frame_data_bytes.push(reader.read_u8(8)?);
        }
        let channels_data = self.decode_frame_modular(&frame_data_bytes, &header)?;
        let pixel_data = self.channels_to_interleaved(&channels_data, &header)?;
        frames.push(JxlFrame {
            data: pixel_data,
            width: header.width,
            height: header.height,
            channels: header.num_channels,
            bit_depth: header.bits_per_sample,
            duration_ticks,
            is_last,
            color_space: header.color_space,
        });
        if is_last {
            break;
        }
    }

    if frames.is_empty() {
        return Err(CodecError::InvalidBitstream(
            "Animated codestream contains no frames".into(),
        ));
    }
    Ok(frames)
}
/// Decodes one modular-coded payload into per-channel sample planes.
///
/// When the header reports at least three colour channels, an `Rct`
/// transform starting at channel 0 (type 0) is registered before decoding.
fn decode_frame_modular(&self, data: &[u8], header: &JxlHeader) -> CodecResult<Vec<Vec<i32>>> {
    let mut modular = ModularDecoder::new();
    let needs_rct = header.color_channels() >= 3;
    if needs_rct {
        modular.add_transform(ModularTransform::Rct {
            begin_channel: 0,
            rct_type: 0,
        });
    }
    let (w, h) = (header.width, header.height);
    modular.decode_image(
        data,
        w,
        h,
        u32::from(header.num_channels),
        header.bits_per_sample,
    )
}
/// Returns whether the image header declares an animation.
///
/// # Errors
///
/// Same as [`JxlDecoder::read_header`].
pub fn is_animated(&self, data: &[u8]) -> CodecResult<bool> {
    self.read_animation_header(data)
        .map(|animation| animation.is_some())
}
/// Returns the animation header, or `None` for still images.
///
/// # Errors
///
/// Same as [`JxlDecoder::read_header`].
pub fn read_animation_header(&self, data: &[u8]) -> CodecResult<Option<JxlAnimation>> {
    self.read_header(data).map(|header| header.animation)
}
}
impl Default for JxlDecoder {
fn default() -> Self {
Self::new()
}
}
/// Input layout chosen by `JxlStreamingDecoder::new` after sniffing the leading bytes.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum JxlFormat {
/// ISOBMFF box container; codestream boxes are pulled lazily as frames are requested.
Isobmff,
/// Anything else; the input is read fully and decoded eagerly when the decoder is created.
Native,
}
/// Iterator that yields decoded `JxlFrame`s from any `Read` source.
pub struct JxlStreamingDecoder<R: Read> {
/// Input layout detected when the decoder was created.
format: JxlFormat,
/// Box walker over the container input; `None` for native input or once the box stream is finished.
box_iter: Option<BoxIter<PeekedReader<R>>>,
/// Codestream bytes collected from container boxes that have not been decoded yet.
codestream_buf: Vec<u8>,
/// Frames already decoded but not yet returned by the iterator.
pending_frames: std::vec::IntoIter<JxlFrame>,
/// Once set, `next` returns `None` unconditionally.
done: bool,
}
impl<R: Read> JxlStreamingDecoder<R> {
    /// Creates a streaming decoder that reads from `reader`.
    ///
    /// Up to 12 leading bytes are sniffed to choose the input format:
    /// * an ISOBMFF container is walked box by box as frames are requested.
    ///   This covers input starting with the 12-byte JPEG XL signature box, and
    ///   input that opens directly with an `ftyp` box branded `jxl `.
    /// * anything else is treated as a bare codestream, which is read to the
    ///   end and decoded here.
    ///
    /// The sniffed bytes are chained back in front of `reader`, so downstream
    /// parsing sees the input from its first byte.
    ///
    /// # Errors
    ///
    /// `CodecError::Io` if reading fails, and any decode error for a bare
    /// codestream.
    pub fn new(mut reader: R) -> CodecResult<Self> {
        let mut peek = [0u8; 12];
        let filled = Self::read_prefix(&mut reader, &mut peek)?;
        let prefix = &peek[..filled];
        let bare_ftyp = filled >= 12 && prefix[4..8] == *b"ftyp" && prefix[8..12] == *b"jxl ";
        let format = if JxlDecoder::is_container(prefix) || bare_ftyp {
            JxlFormat::Isobmff
        } else {
            JxlFormat::Native
        };
        let mut chained: PeekedReader<R> = io::Cursor::new(prefix.to_vec()).chain(reader);
        match format {
            JxlFormat::Isobmff => Ok(Self {
                format,
                box_iter: Some(BoxIter::new(chained)),
                codestream_buf: Vec::new(),
                pending_frames: Vec::new().into_iter(),
                done: false,
            }),
            JxlFormat::Native => {
                let mut all_bytes = Vec::new();
                chained
                    .read_to_end(&mut all_bytes)
                    .map_err(CodecError::Io)?;
                let frames = JxlDecoder::new().decode_animated(&all_bytes)?;
                Ok(Self {
                    format,
                    box_iter: None,
                    codestream_buf: Vec::new(),
                    pending_frames: frames.into_iter(),
                    done: false,
                })
            }
        }
    }

    /// Reads into `buf` until it is full or `reader` reaches end of input,
    /// retrying on `ErrorKind::Interrupted`. A single `read` call may return
    /// fewer bytes than requested, so one call is not enough to sniff a header.
    /// Returns the number of bytes read.
    fn read_prefix(reader: &mut R, buf: &mut [u8]) -> CodecResult<usize> {
        let mut filled = 0;
        while filled < buf.len() {
            match reader.read(&mut buf[filled..]) {
                Ok(0) => break,
                Ok(n) => filled += n,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return Err(CodecError::Io(e)),
            }
        }
        Ok(filled)
    }
}

impl<R: Read> Iterator for JxlStreamingDecoder<R> {
    type Item = CodecResult<JxlFrame>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        // Frames left over from the most recently decoded codestream go first.
        if let Some(frame) = self.pending_frames.next() {
            return Some(Ok(frame));
        }
        match self.format {
            JxlFormat::Native => {
                self.done = true;
                None
            }
            JxlFormat::Isobmff => self.next_from_boxes(),
        }
    }
}

impl<R: Read> JxlStreamingDecoder<R> {
    /// Pulls boxes until a complete codestream is available, decodes it, and
    /// returns its first frame. Any further frames are left in `pending_frames`.
    ///
    /// Box handling:
    /// * a `jxlc` box carries the whole codestream;
    /// * `jxlp` boxes carry consecutive pieces, each prefixed by a 4-byte
    ///   big-endian index whose top bit marks the final piece;
    /// * all other boxes are skipped.
    ///
    /// If the box stream ends with buffered pieces but no final marker, those
    /// pieces are decoded as they are.
    fn next_from_boxes(&mut self) -> Option<CodecResult<JxlFrame>> {
        loop {
            let next_box = match self.box_iter.as_mut() {
                Some(boxes) => boxes.next(),
                None => {
                    self.done = true;
                    return None;
                }
            };
            match next_box {
                None => {
                    // Only drop the box walker here. Setting `done` now would
                    // discard the frames `flush_codestream` leaves in
                    // `pending_frames`.
                    self.box_iter = None;
                    if self.codestream_buf.is_empty() {
                        self.done = true;
                        return None;
                    }
                    return Some(self.flush_buffered());
                }
                Some(Err(e)) => {
                    self.done = true;
                    return Some(Err(CodecError::Io(e)));
                }
                Some(Ok((fourcc, payload))) => {
                    if fourcc == *b"jxlc" {
                        self.codestream_buf.extend_from_slice(&payload[..]);
                        self.box_iter = None;
                        return Some(self.flush_buffered());
                    }
                    if fourcc != *b"jxlp" {
                        continue;
                    }
                    if payload.len() < 4 {
                        self.done = true;
                        return Some(Err(CodecError::InvalidBitstream(
                            "jxlp box payload too short (< 4 bytes)".into(),
                        )));
                    }
                    let mut idx_buf = [0u8; 4];
                    idx_buf.copy_from_slice(&payload[0..4]);
                    let is_last = (u32::from_be_bytes(idx_buf) & 0x8000_0000) != 0;
                    self.codestream_buf.extend_from_slice(&payload[4..]);
                    if is_last {
                        self.box_iter = None;
                        return Some(self.flush_buffered());
                    }
                }
            }
        }
    }

    /// Decodes the buffered codestream bytes and clears the buffer.
    fn flush_buffered(&mut self) -> CodecResult<JxlFrame> {
        let buf = std::mem::take(&mut self.codestream_buf);
        Self::flush_codestream(buf, &mut self.pending_frames)
    }

    /// Decodes a complete codestream and returns its first frame. The remaining
    /// frames are stored in `pending`.
    ///
    /// # Errors
    ///
    /// Any decode error, or `InvalidBitstream` if the codestream has no frames.
    fn flush_codestream(
        buf: Vec<u8>,
        pending: &mut std::vec::IntoIter<JxlFrame>,
    ) -> CodecResult<JxlFrame> {
        let mut frames = JxlDecoder::new().decode_animated(&buf)?.into_iter();
        let first = frames.next().ok_or_else(|| {
            CodecError::InvalidBitstream("jxlp codestream contained no frames".into())
        })?;
        *pending = frames;
        Ok(first)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a 16-byte buffer that starts with the 12-byte container signature.
    fn signed_container() -> Vec<u8> {
        let mut bytes = vec![0u8; 16];
        bytes[..12].copy_from_slice(&JXL_CONTAINER_SIGNATURE);
        bytes
    }

    /// A 10x10 RGB sRGB image with a zeroed buffer of `data_len` bytes.
    fn rgb_image(bit_depth: u8, data_len: usize) -> DecodedImage {
        DecodedImage {
            width: 10,
            height: 10,
            channels: 3,
            bit_depth,
            data: vec![0u8; data_len],
            color_space: JxlColorSpace::Srgb,
        }
    }

    #[test]
    #[ignore]
    fn test_is_codestream_signature() {
        let good: &[u8] = &[0xFF, 0x0A, 0x00];
        let wrong_second_byte: &[u8] = &[0xFF, 0x0B, 0x00];
        let too_short: &[u8] = &[0xFF];
        let empty: &[u8] = &[];
        assert!(JxlDecoder::is_codestream(good));
        assert!(!JxlDecoder::is_codestream(wrong_second_byte));
        assert!(!JxlDecoder::is_codestream(too_short));
        assert!(!JxlDecoder::is_codestream(empty));
    }

    #[test]
    #[ignore]
    fn test_is_container_signature() {
        assert!(JxlDecoder::is_container(&signed_container()));
        assert!(!JxlDecoder::is_container(&[0xFF, 0x0A]));
    }

    #[test]
    #[ignore]
    fn test_is_jxl() {
        assert!(JxlDecoder::is_jxl(&[0xFF, 0x0A]));
        assert!(JxlDecoder::is_jxl(&signed_container()));
        assert!(!JxlDecoder::is_jxl(&[0x00, 0x00]));
    }

    #[test]
    #[ignore]
    fn test_extract_codestream_bare() {
        let bare = [0xFF, 0x0A, 0x01, 0x02];
        let extracted = JxlDecoder::new().extract_codestream(&bare).expect("ok");
        assert_eq!(extracted, &bare);
    }

    #[test]
    #[ignore]
    fn test_extract_codestream_invalid() {
        let outcome = JxlDecoder::new().extract_codestream(&[0x00, 0x00]);
        assert!(outcome.is_err());
    }

    #[test]
    #[ignore]
    fn test_parse_size_header_small() {
        let decoder = JxlDecoder::new();
        let mut writer = super::super::bitreader::BitWriter::new();
        writer.write_bool(true);
        writer.write_bits(2, 5);
        writer.write_bits(0, 5);
        let encoded = writer.finish();
        let mut reader = BitReader::new(&encoded);
        let (width, height) = decoder.parse_size_header(&mut reader).expect("ok");
        assert_eq!(height, 24);
        assert_eq!(width, 24);
    }

    #[test]
    #[ignore]
    fn test_read_header_invalid_data() {
        assert!(JxlDecoder::new().read_header(&[0x00]).is_err());
    }

    #[test]
    #[ignore]
    fn test_decoded_image_metrics() {
        let img = rgb_image(8, 300);
        assert_eq!(img.sample_count(), 300);
        assert_eq!(img.data_size(), 300);
    }

    #[test]
    #[ignore]
    fn test_decoded_image_16bit() {
        let img = rgb_image(16, 600);
        assert_eq!(img.sample_count(), 300);
        assert_eq!(img.data_size(), 600);
    }
}