use super::entropy::decode_slice_subbands;
use super::markers::{parse_headers, SOC};
use super::nlt::{apply_nlt_reverse, parse_nlt_payload, NltParams};
use super::wavelet::inverse_53_2d;
use super::{JxsError, JxsResult};
/// Output of [`JpegXsDecoder::decode`]: the frame parameters taken from the
/// picture header plus one plane of samples per component.
///
/// `PartialEq`/`Eq` are derived so decoded images can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedImage {
    /// Frame width as read from the picture header.
    pub width: u32,
    /// Frame height as read from the picture header.
    pub height: u32,
    /// Number of components; the decoder emits one entry in `samples` per component.
    pub num_components: u8,
    /// Bit depth as read from the picture header.
    pub bit_depth: u8,
    /// One plane per component. The decoder fills each plane with
    /// `width * height` samples.
    pub samples: Vec<Vec<u16>>,
}
/// Stateless entry point for decoding; all functionality lives in associated
/// functions on this type.
///
/// `Default` is implemented by hand elsewhere in this file, so it is
/// deliberately not derived here.
#[derive(Debug, Clone, Copy)]
pub struct JpegXsDecoder;
impl JpegXsDecoder {
    /// Creates a decoder value. The type carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Returns `true` when `data` starts with the two bytes `0xFF 0x10`.
    ///
    /// Inputs shorter than two bytes are never accepted.
    pub fn is_jpegxs(data: &[u8]) -> bool {
        matches!(data, [0xFF, 0x10, ..])
    }

    /// Decodes a codestream into a [`DecodedImage`].
    ///
    /// Processing steps, in order:
    /// 1. Reject input that does not pass [`is_jpegxs`](Self::is_jpegxs).
    /// 2. Parse the headers and, when present, the NLT payload.
    /// 3. If at least one slice is listed, decode the subbands of the *first*
    ///    slice, run the inverse 5/3 transform and copy the result into every
    ///    component plane. When there is no slice the planes stay zero.
    /// 4. Apply the reverse NLT to each plane and clamp every sample to
    ///    `0..=2^bit_depth - 1` (capped at `u16::MAX`).
    ///
    /// NOTE(review): only `headers.slices[0]` is consumed and the same
    /// reconstruction is shared by all components — confirm this is the
    /// intended scope of the decoder.
    ///
    /// # Errors
    ///
    /// * `JxsError::InvalidMarker` when the input does not start with the
    ///   expected two-byte prefix; `got` holds the first two bytes, or `0`
    ///   when fewer than two bytes were supplied.
    /// * Any error returned by `parse_headers`, `parse_nlt_payload`,
    ///   `inverse_53_2d` or `apply_nlt_reverse`.
    /// * Any error from `decode_slice_subbands` other than `Unsupported` and
    ///   `TruncatedStream`; those two are absorbed by substituting all-zero
    ///   subbands.
    ///
    /// # Panics
    ///
    /// `copy_from_slice` panics if the buffer returned by `inverse_53_2d`
    /// does not hold exactly `width * height` values.
    pub fn decode(data: &[u8]) -> JxsResult<DecodedImage> {
        if !Self::is_jpegxs(data) {
            // Report what was actually found so callers can diagnose the input.
            let got = match data {
                [hi, lo, ..] => u16::from_be_bytes([*hi, *lo]),
                _ => 0,
            };
            return Err(JxsError::InvalidMarker { expected: SOC, got });
        }

        let (headers, _header_end) = parse_headers(data)?;
        let pih = &headers.pih;
        let frame_w = pih.width as usize;
        let frame_h = pih.height as usize;
        let nc = pih.num_components as usize;
        let bit_depth = pih.bit_depth;

        let nlt_params = match headers.nlt_payload {
            Some(ref payload) => parse_nlt_payload(payload)?,
            None => NltParams::none(),
        };

        let mut output_planes: Vec<Vec<i32>> = vec![vec![0i32; frame_w * frame_h]; nc];

        if let Some(slice) = headers.slices.first() {
            // The offset and length come from the stream, so clamp both ends of
            // the range: a wrapping sum or an offset past the end of `data`
            // must yield an empty slice instead of a slicing panic.
            let slice_end = slice
                .data_offset
                .saturating_add(slice.data_len)
                .min(data.len());
            let slice_start = slice.data_offset.min(slice_end);
            let slice_bytes = &data[slice_start..slice_end];

            let (ll_sb, hl_sb, lh_sb, hh_sb) =
                match decode_slice_subbands(slice_bytes, frame_w, frame_h) {
                    Ok(subbands) => subbands,
                    Err(JxsError::Unsupported(_)) | Err(JxsError::TruncatedStream { .. }) => {
                        // Best effort: fall back to all-zero coefficients.
                        // `n - n / 2` equals `ceil(n / 2)` without the
                        // overflow risk of `(n + 1) / 2`.
                        let n_high_w = frame_w / 2;
                        let n_low_w = frame_w - n_high_w;
                        let n_high_h = frame_h / 2;
                        let n_low_h = frame_h - n_high_h;
                        (
                            super::entropy::SubbandCoeffs::zeros(n_low_w, n_low_h),
                            super::entropy::SubbandCoeffs::zeros(n_high_w, n_low_h),
                            super::entropy::SubbandCoeffs::zeros(n_low_w, n_high_h),
                            super::entropy::SubbandCoeffs::zeros(n_high_w, n_high_h),
                        )
                    }
                    Err(e) => return Err(e),
                };

            let reconstructed = inverse_53_2d(
                &ll_sb.coeffs,
                &hl_sb.coeffs,
                &lh_sb.coeffs,
                &hh_sb.coeffs,
                frame_w,
                frame_h,
            )?;

            for plane in output_planes.iter_mut() {
                plane.copy_from_slice(&reconstructed);
            }
        }

        for plane in output_planes.iter_mut() {
            apply_nlt_reverse(plane, &nlt_params, bit_depth)?;
        }

        // Samples are stored as `u16`, so the ceiling never exceeds `u16::MAX`.
        // This also avoids the shift overflow for depths of 32 and above and
        // the wrapping `as u16` cast for depths between 17 and 31.
        let max_val = if bit_depth >= 16 {
            i32::from(u16::MAX)
        } else {
            (1i32 << bit_depth) - 1
        };

        let samples: Vec<Vec<u16>> = output_planes
            .into_iter()
            .map(|plane| {
                plane
                    .into_iter()
                    .map(|s| s.clamp(0, max_val) as u16)
                    .collect()
            })
            .collect();

        Ok(DecodedImage {
            width: pih.width,
            height: pih.height,
            num_components: pih.num_components,
            bit_depth,
            samples,
        })
    }
}
impl Default for JpegXsDecoder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jpegxs::markers::build_test_codestream;

    // --- signature detection -------------------------------------------------

    /// A buffer beginning with `FF 10` is recognised.
    #[test]
    fn is_jpegxs_soc_prefix() {
        let stream = [0xFF, 0x10, 0x00, 0x00];
        assert!(JpegXsDecoder::is_jpegxs(&stream));
    }

    /// A buffer beginning with `FF D8` is not recognised.
    #[test]
    fn is_jpegxs_rejects_jpeg() {
        let stream = [0xFF, 0xD8, 0xFF, 0xE0];
        assert!(!JpegXsDecoder::is_jpegxs(&stream));
    }

    /// An empty buffer is not recognised.
    #[test]
    fn is_jpegxs_rejects_empty() {
        let stream: [u8; 0] = [];
        assert!(!JpegXsDecoder::is_jpegxs(&stream));
    }

    // --- decode --------------------------------------------------------------

    /// Decoding the 8x8 single-component test stream yields one all-zero plane.
    #[test]
    fn decode_headers_only_no_slices() {
        let stream = build_test_codestream(8, 8, 8, 1, 8);
        let img = JpegXsDecoder::decode(&stream).expect("decode");

        assert_eq!(img.width, 8);
        assert_eq!(img.height, 8);
        assert_eq!(img.num_components, 1);
        assert_eq!(img.bit_depth, 8);
        assert_eq!(img.samples.len(), 1);

        let plane = &img.samples[0];
        assert_eq!(plane.len(), 64);
        assert!(plane.iter().all(|&v| v == 0));
    }

    /// Empty input is an error.
    #[test]
    fn decode_rejects_empty_data() {
        assert!(JpegXsDecoder::decode(&[]).is_err());
    }

    /// The two-byte prefix alone, with nothing after it, is an error.
    #[test]
    fn decode_rejects_truncated_soc_only() {
        assert!(JpegXsDecoder::decode(&[0xFF, 0x10]).is_err());
    }

    /// A stream with a different prefix fails with `InvalidMarker` carrying
    /// both the expected marker and the one actually found.
    #[test]
    fn decode_rejects_non_jxs_stream() {
        let outcome = JpegXsDecoder::decode(&[0xFF, 0xD8, 0xFF, 0xE0]);
        assert!(outcome.is_err());
        match outcome {
            Err(JxsError::InvalidMarker { expected, got }) => {
                assert_eq!(expected, 0xFF10);
                assert_eq!(got, 0xFFD8);
            }
            _ => panic!("expected InvalidMarker"),
        }
    }

    /// Every component plane of the 16x16 three-component test stream holds
    /// 16 * 16 samples.
    #[test]
    fn decoded_image_has_correct_sample_count() {
        let stream = build_test_codestream(16, 16, 16, 3, 8);
        let img = JpegXsDecoder::decode(&stream).expect("decode");

        assert_eq!(img.width, 16);
        assert_eq!(img.height, 16);
        assert_eq!(img.num_components, 3);
        assert_eq!(img.samples.len(), 3);
        img.samples
            .iter()
            .for_each(|plane| assert_eq!(plane.len(), 16 * 16));
    }

    /// No decoded sample exceeds the 10-bit maximum.
    #[test]
    fn decoded_image_sample_values_within_bit_depth() {
        let stream = build_test_codestream(4, 4, 4, 1, 10);
        let img = JpegXsDecoder::decode(&stream).expect("decode");

        let max_val: u16 = (1 << 10) - 1;
        for s in img.samples[0].iter().copied() {
            assert!(s <= max_val, "sample {s} exceeds 10-bit max {max_val}");
        }
    }
}