use std::{mem::MaybeUninit, ptr::null};
#[allow(clippy::wildcard_imports)]
use jpegxl_sys::*;
use crate::{
common::{Endianness, PixelType},
errors::{check_dec_status, DecodeError},
memory::MemoryManager,
parallel::JxlParallelRunner,
utils::check_valid_signature,
};
mod result;
pub use result::*;
/// Basic image information as filled in by libjxl's `JxlDecoderGetBasicInfo`
/// (alias of `JxlBasicInfo`).
pub type BasicInfo = JxlBasicInfo;
/// Progressive-detail level (alias of libjxl's `JxlProgressiveDetail`).
pub type ProgressiveDetail = JxlProgressiveDetail;
/// Image orientation as reported in the basic info (alias of libjxl's `JxlOrientation`).
pub type Orientation = JxlOrientation;
/// Requested layout of the decoded pixel buffer.
#[derive(Clone, Copy, Debug)]
pub struct PixelFormat {
/// Number of interleaved channels per pixel. `0` lets the decoder use the image's
/// color channel count plus one if the image has an alpha channel.
pub num_channels: u32,
/// Byte order of multi-byte samples, passed straight into libjxl's `JxlPixelFormat`.
pub endianness: Endianness,
/// Requested scanline alignment in bytes.
/// NOTE(review): the decoder in this file always passes `align: 0` to libjxl, so this
/// value is currently ignored.
pub align: usize,
}
impl Default for PixelFormat {
    /// Returns a format with `num_channels` and `align` both `0` and native byte order.
    ///
    /// A channel count of `0` makes the decoder pick the image's own channel layout.
    fn default() -> Self {
        let (num_channels, align) = (0, 0);
        Self {
            align,
            endianness: Endianness::Native,
            num_channels,
        }
    }
}
/// Owns a libjxl decoder instance together with the options applied before each decode.
///
/// Construct it through the derived `JxlDecoderBuilder` (e.g. via `decoder_builder`).
/// The builder's `build` function is written by hand because it has to call
/// `JxlDecoderCreate`.
#[derive(Builder)]
#[builder(build_fn(skip))]
#[builder(setter(strip_option))]
pub struct JxlDecoder<'pr, 'mm> {
/// Raw libjxl decoder handle; created in `build`, destroyed in `Drop`.
#[builder(setter(skip))]
dec: *mut jpegxl_sys::JxlDecoder,
/// Requested output pixel layout. `None` means `PixelFormat::default()`, i.e. the channel
/// count is derived from the image. The `align` member is currently not passed to libjxl.
pub pixel_format: Option<PixelFormat>,
/// When set, forwarded to `JxlDecoderSetKeepOrientation`.
pub skip_reorientation: Option<bool>,
/// When set, forwarded to `JxlDecoderSetUnpremultiplyAlpha`.
pub unpremul_alpha: Option<bool>,
/// When set, forwarded to `JxlDecoderSetRenderSpotcolors`.
pub render_spotcolors: Option<bool>,
/// When set, forwarded to `JxlDecoderSetCoalescing`.
pub coalescing: Option<bool>,
/// When set, forwarded to `JxlDecoderSetDesiredIntensityTarget`.
pub desired_intensity_target: Option<f32>,
/// Optional decoder setting; `None` leaves libjxl's default in place.
pub decompress: Option<bool>,
/// Optional progressive-detail setting; `None` leaves libjxl's default in place.
pub progressive_detail: Option<JxlProgressiveDetail>,
/// Whether `decode`, `decode_with` and `reconstruct` also extract the ICC profile.
/// The builder defaults this to `false`.
pub icc_profile: bool,
/// Initial size in bytes of the JPEG reconstruction buffer. The builder defaults it to
/// 512 KiB (`512 * 1024`).
pub init_jpeg_buffer: usize,
/// Optional parallel runner installed on the decoder before every decode; it is also
/// handed the basic info once libjxl reports it.
pub parallel_runner: Option<&'pr dyn JxlParallelRunner>,
/// Memory manager the decoder was created with. It is never read after `build`; storing
/// the reference ties the decoder's lifetime to `'mm`.
#[allow(dead_code)]
memory_manager: Option<&'mm dyn MemoryManager>,
}
impl<'pr, 'mm> JxlDecoderBuilder<'pr, 'mm> {
    /// Creates a [`JxlDecoder`] from the options set on this builder.
    ///
    /// Options that were never set stay `None`. `icc_profile` falls back to `false` and
    /// `init_jpeg_buffer` to 512 KiB. If a memory manager was supplied, the libjxl decoder
    /// is created with it.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::CannotCreateDecoder`] when `JxlDecoderCreate` returns null.
    pub fn build(&self) -> Result<JxlDecoder<'pr, 'mm>, DecodeError> {
        let memory_manager = self.memory_manager.flatten();

        // SAFETY: `JxlDecoderCreate` accepts either null (default allocator) or a pointer to
        // a memory manager; the temporary produced by `manager()` outlives the call.
        let dec = unsafe {
            match memory_manager {
                Some(mm) => JxlDecoderCreate(&mm.manager()),
                None => JxlDecoderCreate(null()),
            }
        };
        if dec.is_null() {
            return Err(DecodeError::CannotCreateDecoder);
        }

        let icc_profile = self.icc_profile.unwrap_or_default();
        let init_jpeg_buffer = self.init_jpeg_buffer.unwrap_or(512 * 1024);

        Ok(JxlDecoder {
            dec,
            memory_manager,
            parallel_runner: self.parallel_runner.flatten(),
            pixel_format: self.pixel_format.flatten(),
            skip_reorientation: self.skip_reorientation.flatten(),
            unpremul_alpha: self.unpremul_alpha.flatten(),
            render_spotcolors: self.render_spotcolors.flatten(),
            coalescing: self.coalescing.flatten(),
            desired_intensity_target: self.desired_intensity_target.flatten(),
            decompress: self.decompress.flatten(),
            progressive_detail: self.progressive_detail.flatten(),
            icc_profile,
            init_jpeg_buffer,
        })
    }
}
impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn decode_internal(
&self,
data: &[u8],
data_type: Option<JxlDataType>,
with_icc_profile: bool,
reconstruct_jpeg_buffer: Option<&mut Vec<u8>>,
(format, pixels): (*mut JxlPixelFormat, &mut Vec<u8>),
) -> Result<Metadata, DecodeError> {
let Some(sig) = check_valid_signature(data) else { return Err(DecodeError::InvalidInput) };
if !sig {
return Err(DecodeError::InvalidInput);
}
let mut basic_info = MaybeUninit::uninit();
let mut icc = if with_icc_profile { Some(vec![]) } else { None };
self.setup_decoder(with_icc_profile, reconstruct_jpeg_buffer.is_some())?;
let next_in = data.as_ptr();
let avail_in = std::mem::size_of_val(data) as _;
check_dec_status(unsafe { JxlDecoderSetInput(self.dec, next_in, avail_in) })?;
unsafe { JxlDecoderCloseInput(self.dec) };
let mut status;
loop {
use JxlDecoderStatus::{
BasicInfo, ColorEncoding, Error, FullImage, JpegNeedMoreOutput, JpegReconstruction,
NeedImageOutBuffer, NeedMoreInput, Success,
};
status = unsafe { JxlDecoderProcessInput(self.dec) };
match status {
Error => return Err(DecodeError::GenericError),
NeedMoreInput => {
unimplemented!()
}
BasicInfo => {
check_dec_status(unsafe {
JxlDecoderGetBasicInfo(self.dec, basic_info.as_mut_ptr())
})?;
if let Some(pr) = self.parallel_runner {
pr.callback_basic_info(unsafe { &*basic_info.as_ptr() });
}
}
ColorEncoding => {
if let Some(icc) = icc.as_mut() {
self.get_icc_profile(icc)?;
}
}
JpegReconstruction => {
if let Some(&mut ref mut buf) = reconstruct_jpeg_buffer {
buf.resize(self.init_jpeg_buffer, 0);
check_dec_status(unsafe {
JxlDecoderSetJPEGBuffer(self.dec, buf.as_mut_ptr(), buf.len())
})?;
}
}
JpegNeedMoreOutput => {
if let Some(&mut ref mut buf) = reconstruct_jpeg_buffer {
let need_to_write = unsafe { JxlDecoderReleaseJPEGBuffer(self.dec) };
buf.resize(buf.len() + need_to_write, 0);
check_dec_status(unsafe {
JxlDecoderSetJPEGBuffer(self.dec, buf.as_mut_ptr(), buf.len())
})?;
}
}
NeedImageOutBuffer => {
self.output(unsafe { &*basic_info.as_ptr() }, data_type, format, pixels)?;
}
FullImage => continue,
Success => {
if let Some(&mut ref mut buf) = reconstruct_jpeg_buffer {
let remaining = unsafe { JxlDecoderReleaseJPEGBuffer(self.dec) };
buf.truncate(buf.len() - remaining);
buf.shrink_to_fit();
}
unsafe { JxlDecoderReset(self.dec) };
let info = unsafe { basic_info.assume_init() };
return Ok(Metadata {
width: info.xsize,
height: info.ysize,
intensity_target: info.intensity_target,
min_nits: info.min_nits,
orientation: info.orientation,
num_color_channels: info.num_color_channels,
has_alpha_channel: info.alpha_bits > 0,
intrinsic_width: info.intrinsic_xsize,
intrinsic_height: info.intrinsic_ysize,
icc_profile: icc,
});
}
_ => return Err(DecodeError::UnknownStatus(status)),
}
}
}
fn setup_decoder(&self, icc: bool, reconstruct_jpeg: bool) -> Result<(), DecodeError> {
if let Some(runner) = self.parallel_runner {
check_dec_status(unsafe {
JxlDecoderSetParallelRunner(self.dec, runner.runner(), runner.as_opaque_ptr())
})?;
}
let events = {
use JxlDecoderStatus::{BasicInfo, ColorEncoding, FullImage, JpegReconstruction};
let mut events = BasicInfo as i32 | FullImage as i32;
if icc {
events |= ColorEncoding as i32;
}
if reconstruct_jpeg {
events |= JpegReconstruction as i32;
}
events
};
check_dec_status(unsafe { JxlDecoderSubscribeEvents(self.dec, events) })?;
if let Some(val) = self.skip_reorientation {
check_dec_status(unsafe { JxlDecoderSetKeepOrientation(self.dec, val.into()) })?;
}
if let Some(val) = self.unpremul_alpha {
check_dec_status(unsafe { JxlDecoderSetUnpremultiplyAlpha(self.dec, val.into()) })?;
}
if let Some(val) = self.render_spotcolors {
check_dec_status(unsafe { JxlDecoderSetRenderSpotcolors(self.dec, val.into()) })?;
}
if let Some(val) = self.coalescing {
check_dec_status(unsafe { JxlDecoderSetCoalescing(self.dec, val.into()) })?;
}
if let Some(val) = self.desired_intensity_target {
check_dec_status(unsafe { JxlDecoderSetDesiredIntensityTarget(self.dec, val) })?;
}
Ok(())
}
fn get_icc_profile(&self, icc_profile: &mut Vec<u8>) -> Result<(), DecodeError> {
let mut icc_size = 0;
check_dec_status(unsafe {
JxlDecoderGetICCProfileSize(
self.dec,
null(),
JxlColorProfileTarget::Data,
&mut icc_size,
)
})?;
icc_profile.resize(icc_size, 0);
check_dec_status(unsafe {
JxlDecoderGetColorAsICCProfile(
self.dec,
null(),
JxlColorProfileTarget::Data,
icc_profile.as_mut_ptr(),
icc_size,
)
})?;
Ok(())
}
fn output(
&self,
info: &BasicInfo,
data_type: Option<JxlDataType>,
format: *mut JxlPixelFormat,
pixels: &mut Vec<u8>,
) -> Result<(), DecodeError> {
let data_type = match data_type {
Some(v) => v,
None => match info.bits_per_sample {
8 => JxlDataType::Uint8,
16 if info.exponent_bits_per_sample == 0 => JxlDataType::Uint16,
16 => JxlDataType::Float16,
32 => JxlDataType::Float,
_ => return Err(DecodeError::InvalidInput),
},
};
let f = self.pixel_format.unwrap_or_default();
let pixel_format = JxlPixelFormat {
num_channels: if f.num_channels == 0 {
info.num_color_channels + u32::from(info.alpha_bits > 0)
} else {
f.num_channels
},
data_type,
endianness: f.endianness,
align: 0, };
let mut size = 0;
check_dec_status(unsafe {
JxlDecoderImageOutBufferSize(self.dec, &pixel_format, &mut size)
})?;
pixels.resize(size, 0);
check_dec_status(unsafe {
JxlDecoderSetImageOutBuffer(self.dec, &pixel_format, pixels.as_mut_ptr().cast(), size)
})?;
unsafe { *format = pixel_format };
Ok(())
}
pub fn decode(&self, data: &[u8]) -> Result<(Metadata, Pixels), DecodeError> {
let mut buffer = vec![];
let mut pixel_format = MaybeUninit::uninit();
let metadata = self.decode_internal(
data,
None,
self.icc_profile,
None,
(pixel_format.as_mut_ptr(), &mut buffer),
)?;
Ok((
metadata,
Pixels::new(buffer, unsafe { &pixel_format.assume_init() }),
))
}
pub fn decode_with<T: PixelType>(
&self,
data: &[u8],
) -> Result<(Metadata, Vec<T>), DecodeError> {
let mut buffer = vec![];
let mut pixel_format = MaybeUninit::uninit();
let metadata = self.decode_internal(
data,
Some(T::pixel_type()),
self.icc_profile,
None,
(pixel_format.as_mut_ptr(), &mut buffer),
)?;
let buf = unsafe {
let pixel_format = pixel_format.assume_init();
debug_assert!(T::pixel_type() == pixel_format.data_type);
match T::pixel_type() {
JxlDataType::Float => std::mem::transmute(to_f32(&buffer, &pixel_format)),
JxlDataType::Uint8 => std::mem::transmute(buffer),
JxlDataType::Uint16 => std::mem::transmute(to_u16(&buffer, &pixel_format)),
JxlDataType::Float16 => std::mem::transmute(to_f16(&buffer, &pixel_format)),
}
};
Ok((metadata, buf))
}
pub fn reconstruct(&self, data: &[u8]) -> Result<(Metadata, Data), DecodeError> {
let mut buffer = vec![];
let mut pixel_format = MaybeUninit::uninit();
let mut jpeg_buf = vec![];
let metadata = self.decode_internal(
data,
None,
self.icc_profile,
Some(&mut jpeg_buf),
(pixel_format.as_mut_ptr(), &mut buffer),
)?;
Ok((
metadata,
if jpeg_buf.is_empty() {
Data::Pixels(Pixels::new(buffer, unsafe { &pixel_format.assume_init() }))
} else {
Data::Jpeg(jpeg_buf)
},
))
}
}
impl<'pr, 'mm> Drop for JxlDecoder<'pr, 'mm> {
    /// Frees the libjxl decoder instance owned by this wrapper.
    fn drop(&mut self) {
        let raw_decoder = self.dec;
        // SAFETY: `build` only creates a `JxlDecoder` from a non-null `JxlDecoderCreate`
        // result, and this is the only place the handle is destroyed.
        unsafe { JxlDecoderDestroy(raw_decoder) };
    }
}
#[must_use]
pub fn decoder_builder<'prl, 'mm>() -> JxlDecoderBuilder<'prl, 'mm> {
JxlDecoderBuilder::default()
}