use super::bitwriter::BitWriter;
use super::marker_write::{
write_cdt, write_cwd, write_eoc, write_pih, write_slh, write_soc, write_wgt, CdtComponent,
PihFields,
};
use super::markers::PROFILE_MAIN;
use super::vlc_encode::encode_subband;
use super::wavelet::forward_wavelet_2d;
use super::{JxsError, JxsResult};
/// Colour space recorded in a [`JpegXsEncoderConfig`].
///
/// `JpegXsEncoderConfig::new` picks a default from the component count
/// (1 → `Grey`, 3 → `Yuv`, anything else → `Rgb`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JxsColorSpace {
    /// Default for single-component configurations.
    Grey,
    /// Fallback for component counts other than 1 or 3.
    Rgb,
    /// Default for three-component configurations.
    Yuv,
}
/// The only subband weight the encoder accepts; any other value in
/// `JpegXsEncoderConfig::weights` is rejected as unsupported.
pub const JXS_UNIT_WEIGHT: u16 = 0x0001;
/// Parameters describing the frame to encode and how to encode it.
///
/// Build one with [`JpegXsEncoderConfig::new`] and adjust fields as needed;
/// `JpegXsEncoder::new` validates the result.
#[derive(Clone, Debug)]
pub struct JpegXsEncoderConfig {
    /// Frame width in samples; must be 1..=65535.
    pub width: u32,
    /// Frame height in lines; must be 1..=65535.
    pub height: u32,
    /// Bits per sample; must be 1..=16.
    pub bit_depth: u8,
    /// Number of component planes; `encode` expects exactly this many planes.
    pub components: u8,
    /// Colour space recorded in the configuration; the encoder does not write it
    /// into the codestream.
    pub color_space: JxsColorSpace,
    /// Horizontal wavelet decomposition levels; only `1` is supported.
    pub wavelet_levels_h: u8,
    /// Vertical wavelet decomposition levels; only `1` is supported.
    pub wavelet_levels_v: u8,
    /// Slice height in lines; `0` stands for the whole frame and values above the
    /// frame height are clamped to it.
    pub slice_height: u32,
    /// Optional subband weights in LL, HL, LH, HH order. Leave empty to omit the WGT
    /// marker; otherwise exactly four entries, each equal to [`JXS_UNIT_WEIGHT`].
    pub weights: Vec<u16>,
}
impl JpegXsEncoderConfig {
    /// Creates a configuration with one decomposition level in each direction, a
    /// slice as tall as the frame and no weight table.
    ///
    /// The colour space is derived from `components`: one component gives
    /// [`JxsColorSpace::Grey`], three give [`JxsColorSpace::Yuv`], and every other
    /// count falls back to [`JxsColorSpace::Rgb`].
    pub fn new(width: u32, height: u32, bit_depth: u8, components: u8) -> Self {
        let color_space = if components == 1 {
            JxsColorSpace::Grey
        } else if components == 3 {
            JxsColorSpace::Yuv
        } else {
            JxsColorSpace::Rgb
        };

        Self {
            color_space,
            width,
            height,
            bit_depth,
            components,
            slice_height: height,
            wavelet_levels_h: 1,
            wavelet_levels_v: 1,
            weights: Vec::new(),
        }
    }
}
/// JPEG XS encoder bound to a configuration that has passed validation.
///
/// Obtain one through [`JpegXsEncoder::new`], which rejects configurations the
/// encoder cannot honour.
#[derive(Clone, Debug)]
pub struct JpegXsEncoder {
    /// Validated configuration, exposed read-only through `config()`.
    config: JpegXsEncoderConfig,
}
impl JpegXsEncoder {
    /// Validates `config` and builds an encoder for it.
    ///
    /// All checks happen here, once, so that [`JpegXsEncoder::encode`] only ever sees a
    /// configuration inside the subset this encoder implements: one component, one
    /// slice spanning the frame, one decomposition level and unit weights.
    ///
    /// # Errors
    ///
    /// * [`JxsError::InvalidHeader`] if either frame dimension is zero or exceeds
    ///   65535, the bit depth is outside 1-16, no components are requested, or a
    ///   non-empty weight table does not have exactly 4 entries.
    /// * [`JxsError::Unsupported`] if more than one component is requested, a non-zero
    ///   slice height is smaller than the frame height, the decomposition is not
    ///   single-level in both directions, or any weight differs from
    ///   [`JXS_UNIT_WEIGHT`].
    pub fn new(config: JpegXsEncoderConfig) -> JxsResult<Self> {
        if config.width == 0 || config.height == 0 {
            return Err(JxsError::InvalidHeader(format!(
                "encoder: frame size {}x{} must be non-zero",
                config.width, config.height
            )));
        }
        // `PihFields` carries the frame dimensions as `u16`.
        if config.width > u32::from(u16::MAX) || config.height > u32::from(u16::MAX) {
            return Err(JxsError::InvalidHeader(format!(
                "encoder: frame size {}x{} exceeds 65535",
                config.width, config.height
            )));
        }
        if config.bit_depth == 0 || config.bit_depth > 16 {
            return Err(JxsError::InvalidHeader(format!(
                "encoder: bit depth {} must be 1-16",
                config.bit_depth
            )));
        }
        if config.components == 0 {
            return Err(JxsError::InvalidHeader(
                "encoder: at least one component required".to_string(),
            ));
        }
        // `encode` transforms and entropy-codes only the first plane. Accepting more
        // components would produce PIH/CDT headers announcing components whose sample
        // data never appears in the codestream.
        if config.components > 1 {
            return Err(JxsError::Unsupported(format!(
                "encoder: only single-component encoding supported (got {} components)",
                config.components
            )));
        }
        // `encode` writes exactly one slice header followed by one slice holding the
        // whole frame, so a shorter slice height would advertise slices that are never
        // written. `0` means "whole frame" and larger values are clamped; both are fine.
        if config.slice_height != 0 && config.slice_height < config.height {
            return Err(JxsError::Unsupported(format!(
                "encoder: slice height {} is below frame height {} (single slice only)",
                config.slice_height, config.height
            )));
        }
        if config.wavelet_levels_h != 1 || config.wavelet_levels_v != 1 {
            return Err(JxsError::Unsupported(format!(
                "encoder: only single-level decomposition supported (got {}x{} levels)",
                config.wavelet_levels_h, config.wavelet_levels_v
            )));
        }
        if !config.weights.is_empty() {
            if config.weights.len() != 4 {
                return Err(JxsError::InvalidHeader(format!(
                    "encoder: weights must have 4 entries (LL,HL,LH,HH), got {}",
                    config.weights.len()
                )));
            }
            if config.weights.iter().any(|&w| w != JXS_UNIT_WEIGHT) {
                return Err(JxsError::Unsupported(
                    "encoder: non-unit quantisation weights are not supported (decoder has no \
                     dequantisation stage); use unit weights for lossless encoding"
                        .to_string(),
                ));
            }
        }
        Ok(Self { config })
    }

    /// Encodes one frame into a complete JPEG XS codestream.
    ///
    /// `planes` must hold exactly `components` planes of `width * height` samples each.
    /// The first plane goes through one level of the 2-D forward wavelet, and its LL,
    /// HL, LH and HH subbands are entropy-coded, in that order, into a single slice.
    /// The codestream is laid out as SOC, PIH, CDT, WGT (only when weights are
    /// configured), CWD, the SLH for slice 0, the slice data, and finally EOC.
    ///
    /// # Errors
    ///
    /// Returns [`JxsError::InvalidHeader`] if `planes` is empty, has the wrong number
    /// of planes or contains a plane of the wrong length. Errors from the wavelet
    /// transform and the PIH/CDT/WGT marker writers are propagated unchanged.
    pub fn encode(&self, planes: &[Vec<i32>]) -> JxsResult<Vec<u8>> {
        let cfg = &self.config;
        let width = cfg.width as usize;
        let height = cfg.height as usize;
        let expected = width * height;

        if planes.is_empty() {
            return Err(JxsError::InvalidHeader(
                "encode: no component planes provided".to_string(),
            ));
        }
        if planes.len() != cfg.components as usize {
            return Err(JxsError::InvalidHeader(format!(
                "encode: expected {} planes, got {}",
                cfg.components,
                planes.len()
            )));
        }
        for (i, plane) in planes.iter().enumerate() {
            if plane.len() != expected {
                return Err(JxsError::InvalidHeader(format!(
                    "encode: plane {i} has {} samples, expected {expected}",
                    plane.len()
                )));
            }
        }

        // `new` only admits single-component configurations, so the slice carries the
        // subbands of the first (and only) plane.
        let (ll, hl, lh, hh) = forward_wavelet_2d(&planes[0], width, height)?;
        let mut slice_writer = BitWriter::new();
        encode_subband(&mut slice_writer, &ll);
        encode_subband(&mut slice_writer, &hl);
        encode_subband(&mut slice_writer, &lh);
        encode_subband(&mut slice_writer, &hh);
        let slice_bytes = slice_writer.finish();

        // `new` rejected dimensions above 65535, so these narrowing casts are lossless.
        let frame_width = cfg.width as u16;
        let frame_height = cfg.height as u16;

        // 64 bytes is a rough headroom for the marker segments around the slice data.
        let mut buf = Vec::with_capacity(64 + slice_bytes.len());
        write_soc(&mut buf);
        let pih = PihFields {
            // NOTE(review): written as 0 and never patched afterwards; confirm the
            // decoder does not rely on the actual codestream length here.
            codestream_len: 0,
            profile: PROFILE_MAIN,
            level: 0,
            width: frame_width,
            height: frame_height,
            codegroup_width: frame_width,
            // A single slice covers the whole frame (enforced in `new`).
            slice_height: frame_height,
            num_components: cfg.components,
            ganging: 0,
            bit_depth: cfg.bit_depth,
            bw_ext: 0,
            fq: 0,
            bitrate: 0,
            fsl: 0,
            ppoc: 0,
            cpih: 0,
        };
        write_pih(&mut buf, &pih)?;

        // Every component shares the frame bit depth and has no subsampling.
        let cdt: Vec<CdtComponent> = (0..cfg.components)
            .map(|_| CdtComponent {
                bit_depth: cfg.bit_depth,
                sx: 1,
                sy: 1,
            })
            .collect();
        write_cdt(&mut buf, &cdt)?;

        if !cfg.weights.is_empty() {
            write_wgt(&mut buf, &cfg.weights)?;
        }
        write_cwd(&mut buf, &[]);
        write_slh(&mut buf, 0);
        buf.extend_from_slice(&slice_bytes);
        write_eoc(&mut buf);
        Ok(buf)
    }

    /// Returns the configuration this encoder was validated against.
    pub fn config(&self) -> &JpegXsEncoderConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jpegxs::decoder::JpegXsDecoder;

    #[test]
    fn config_new_defaults() {
        let cfg = JpegXsEncoderConfig::new(16, 8, 10, 3);
        assert_eq!(cfg.color_space, JxsColorSpace::Yuv);
        assert_eq!(cfg.slice_height, 8);
        assert_eq!((cfg.wavelet_levels_h, cfg.wavelet_levels_v), (1, 1));
        assert!(cfg.weights.is_empty());
        assert_eq!(
            JpegXsEncoderConfig::new(4, 4, 8, 1).color_space,
            JxsColorSpace::Grey
        );
        assert_eq!(
            JpegXsEncoderConfig::new(4, 4, 8, 4).color_space,
            JxsColorSpace::Rgb
        );
    }

    #[test]
    fn new_rejects_zero_dimensions() {
        let cfg = JpegXsEncoderConfig::new(0, 8, 8, 1);
        assert!(JpegXsEncoder::new(cfg).is_err());
    }

    #[test]
    fn new_rejects_oversized_dimensions() {
        let cfg = JpegXsEncoderConfig::new(65_536, 8, 8, 1);
        assert!(matches!(
            JpegXsEncoder::new(cfg),
            Err(JxsError::InvalidHeader(_))
        ));
        let cfg = JpegXsEncoderConfig::new(8, 65_536, 8, 1);
        assert!(matches!(
            JpegXsEncoder::new(cfg),
            Err(JxsError::InvalidHeader(_))
        ));
    }

    #[test]
    fn new_accepts_max_dimensions() {
        let cfg = JpegXsEncoderConfig::new(65_535, 65_535, 8, 1);
        assert!(JpegXsEncoder::new(cfg).is_ok());
    }

    #[test]
    fn new_rejects_bad_bit_depth() {
        for depth in [0u8, 17] {
            let cfg = JpegXsEncoderConfig::new(8, 8, depth, 1);
            assert!(
                JpegXsEncoder::new(cfg).is_err(),
                "bit depth {depth} was accepted"
            );
        }
    }

    #[test]
    fn new_rejects_zero_components() {
        let cfg = JpegXsEncoderConfig::new(8, 8, 8, 0);
        assert!(matches!(
            JpegXsEncoder::new(cfg),
            Err(JxsError::InvalidHeader(_))
        ));
    }

    #[test]
    fn new_rejects_multilevel() {
        let mut cfg = JpegXsEncoderConfig::new(8, 8, 8, 1);
        cfg.wavelet_levels_h = 2;
        assert!(JpegXsEncoder::new(cfg).is_err());

        let mut cfg = JpegXsEncoderConfig::new(8, 8, 8, 1);
        cfg.wavelet_levels_v = 2;
        assert!(JpegXsEncoder::new(cfg).is_err());
    }

    #[test]
    fn new_rejects_wrong_weight_count() {
        let mut cfg = JpegXsEncoderConfig::new(8, 8, 8, 1);
        cfg.weights = vec![1, 1, 1];
        assert!(matches!(
            JpegXsEncoder::new(cfg),
            Err(JxsError::InvalidHeader(_))
        ));
    }

    #[test]
    fn new_rejects_nonunit_weights() {
        let mut cfg = JpegXsEncoderConfig::new(8, 8, 8, 1);
        cfg.weights = vec![2, 2, 2, 2];
        assert!(matches!(
            JpegXsEncoder::new(cfg),
            Err(JxsError::Unsupported(_))
        ));
    }

    #[test]
    fn new_accepts_unit_weights() {
        let mut cfg = JpegXsEncoderConfig::new(8, 8, 8, 1);
        cfg.weights = vec![1, 1, 1, 1];
        assert!(JpegXsEncoder::new(cfg).is_ok());
    }

    #[test]
    fn encode_rejects_wrong_plane_count() {
        let cfg = JpegXsEncoderConfig::new(4, 4, 8, 1);
        let enc = JpegXsEncoder::new(cfg).unwrap();
        let planes = vec![vec![0i32; 16], vec![0i32; 16]];
        assert!(enc.encode(&planes).is_err());
    }

    #[test]
    fn encode_rejects_wrong_plane_length() {
        let cfg = JpegXsEncoderConfig::new(4, 4, 8, 1);
        let enc = JpegXsEncoder::new(cfg).unwrap();
        let planes = vec![vec![0i32; 10]];
        assert!(enc.encode(&planes).is_err());
    }

    #[test]
    fn encode_produces_valid_soc_eoc() {
        let cfg = JpegXsEncoderConfig::new(4, 4, 8, 1);
        let enc = JpegXsEncoder::new(cfg).unwrap();
        let planes = vec![vec![5i32; 16]];
        let stream = enc.encode(&planes).unwrap();
        assert_eq!(&stream[0..2], &[0xFF, 0x10]);
        assert_eq!(&stream[stream.len() - 2..], &[0xFF, 0x11]);
        assert!(JpegXsDecoder::is_jpegxs(&stream));
    }

    #[test]
    fn roundtrip_small_gradient_lossless() {
        let (w, h) = (8u32, 8u32);
        let cfg = JpegXsEncoderConfig::new(w, h, 8, 1);
        let enc = JpegXsEncoder::new(cfg).unwrap();
        let plane: Vec<i32> = (0..(w * h) as usize)
            .map(|i| (i % w as usize) as i32)
            .collect();
        let stream = enc.encode(std::slice::from_ref(&plane)).unwrap();
        let img = JpegXsDecoder::decode(&stream).unwrap();
        let decoded: Vec<i32> = img.samples[0].iter().map(|&v| v as i32).collect();
        assert_eq!(decoded, plane, "lossless gradient round-trip failed");
    }

    #[test]
    fn roundtrip_constant_lossless() {
        let cfg = JpegXsEncoderConfig::new(16, 16, 8, 1);
        let enc = JpegXsEncoder::new(cfg).unwrap();
        let plane = vec![200i32; 256];
        let stream = enc.encode(std::slice::from_ref(&plane)).unwrap();
        let img = JpegXsDecoder::decode(&stream).unwrap();
        let decoded: Vec<i32> = img.samples[0].iter().map(|&v| v as i32).collect();
        assert_eq!(decoded, plane);
    }

    #[test]
    fn roundtrip_with_unit_weights_marker_lossless() {
        let mut cfg = JpegXsEncoderConfig::new(8, 8, 8, 1);
        cfg.weights = vec![1, 1, 1, 1];
        let enc = JpegXsEncoder::new(cfg).unwrap();
        let plane: Vec<i32> = (0..64).map(|i| (i % 8) as i32).collect();
        let stream = enc.encode(std::slice::from_ref(&plane)).unwrap();
        assert!(
            stream.windows(2).any(|w| w == [0xFF, 0x14]),
            "WGT marker missing"
        );
        let img = JpegXsDecoder::decode(&stream).unwrap();
        let decoded: Vec<i32> = img.samples[0].iter().map(|&v| v as i32).collect();
        assert_eq!(decoded, plane);
    }

    #[test]
    fn roundtrip_10bit_gradient_lossless() {
        let (w, h) = (16u32, 8u32);
        let cfg = JpegXsEncoderConfig::new(w, h, 10, 1);
        let enc = JpegXsEncoder::new(cfg).unwrap();
        let plane: Vec<i32> = (0..(w * h) as usize)
            .map(|i| ((i * 8) % 1024) as i32)
            .collect();
        let stream = enc.encode(std::slice::from_ref(&plane)).unwrap();
        let img = JpegXsDecoder::decode(&stream).unwrap();
        let decoded: Vec<i32> = img.samples[0].iter().map(|&v| v as i32).collect();
        assert_eq!(decoded, plane);
    }
}