1use std::{mem::MaybeUninit, ptr::null};
21
22#[allow(clippy::wildcard_imports)]
23use jpegxl_sys::{
24 common::types::{JxlDataType, JxlPixelFormat},
25 decode::*,
26 metadata::codestream_header::{JxlBasicInfo, JxlOrientation},
27};
28
29use crate::{
30 common::{Endianness, PixelType},
31 errors::{check_dec_status, DecodeError},
32 memory::MemoryManager,
33 parallel::ParallelRunner,
34 utils::check_valid_signature,
35};
36
37mod result;
38pub use result::*;
39
40pub type BasicInfo = JxlBasicInfo;
42pub type ProgressiveDetail = JxlProgressiveDetail;
44pub type Orientation = JxlOrientation;
46
47#[derive(Clone, Copy, Debug)]
49pub struct PixelFormat {
50 pub num_channels: u32,
60 pub endianness: Endianness,
66 pub align: usize,
71}
72
73impl Default for PixelFormat {
74 fn default() -> Self {
75 Self {
76 num_channels: 0,
77 endianness: Endianness::Native,
78 align: 0,
79 }
80 }
81}
82
83#[derive(Builder)]
85#[builder(build_fn(skip, error = "None"))]
86#[builder(setter(strip_option))]
87pub struct JxlDecoder<'pr, 'mm> {
88 #[builder(setter(skip))]
90 ptr: *mut jpegxl_sys::decode::JxlDecoder,
91
92 pub pixel_format: Option<PixelFormat>,
94
95 pub skip_reorientation: Option<bool>,
101 pub unpremul_alpha: Option<bool>,
107 pub render_spotcolors: Option<bool>,
114 pub coalescing: Option<bool>,
121 pub desired_intensity_target: Option<f32>,
128 pub decompress: Option<bool>,
133
134 pub progressive_detail: Option<JxlProgressiveDetail>,
139
140 pub icc_profile: bool,
145
146 pub init_jpeg_buffer: usize,
151
152 pub parallel_runner: Option<&'pr dyn ParallelRunner>,
154
155 pub memory_manager: Option<&'mm dyn MemoryManager>,
157}
158
159impl<'pr, 'mm> JxlDecoderBuilder<'pr, 'mm> {
160 pub fn build(&mut self) -> Result<JxlDecoder<'pr, 'mm>, DecodeError> {
165 let mm = self.memory_manager.flatten();
166 let dec = unsafe {
167 mm.map_or_else(
168 || JxlDecoderCreate(null()),
169 |mm| JxlDecoderCreate(&mm.manager()),
170 )
171 };
172
173 if dec.is_null() {
174 return Err(DecodeError::CannotCreateDecoder);
175 }
176
177 Ok(JxlDecoder {
178 ptr: dec,
179 pixel_format: self.pixel_format.flatten(),
180 skip_reorientation: self.skip_reorientation.flatten(),
181 unpremul_alpha: self.unpremul_alpha.flatten(),
182 render_spotcolors: self.render_spotcolors.flatten(),
183 coalescing: self.coalescing.flatten(),
184 desired_intensity_target: self.desired_intensity_target.flatten(),
185 decompress: self.decompress.flatten(),
186 progressive_detail: self.progressive_detail.flatten(),
187 icc_profile: self.icc_profile.unwrap_or_default(),
188 init_jpeg_buffer: self.init_jpeg_buffer.unwrap_or(512 * 1024),
189 parallel_runner: self.parallel_runner.flatten(),
190 memory_manager: mm,
191 })
192 }
193}
194
195impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {
196 pub(crate) fn decode_internal(
197 &self,
198 data: &[u8],
199 data_type: Option<JxlDataType>,
200 with_icc_profile: bool,
201 mut reconstruct_jpeg_buffer: Option<&mut Vec<u8>>,
202 format: *mut JxlPixelFormat,
203 pixels: &mut Vec<u8>,
204 ) -> Result<Metadata, DecodeError> {
205 let Some(sig) = check_valid_signature(data) else {
206 return Err(DecodeError::InvalidInput);
207 };
208 if !sig {
209 return Err(DecodeError::InvalidInput);
210 }
211
212 let mut basic_info = MaybeUninit::uninit();
213 let mut icc = if with_icc_profile { Some(vec![]) } else { None };
214
215 self.setup_decoder(with_icc_profile, reconstruct_jpeg_buffer.is_some())?;
216
217 let next_in = data.as_ptr();
218 let avail_in = std::mem::size_of_val(data) as _;
219
220 check_dec_status(unsafe { JxlDecoderSetInput(self.ptr, next_in, avail_in) })?;
221 unsafe { JxlDecoderCloseInput(self.ptr) };
222
223 let mut status;
224 loop {
225 use JxlDecoderStatus as s;
226
227 status = unsafe { JxlDecoderProcessInput(self.ptr) };
228
229 match status {
230 s::NeedMoreInput | s::Error => return Err(DecodeError::GenericError),
231
232 s::BasicInfo => {
234 check_dec_status(unsafe {
235 JxlDecoderGetBasicInfo(self.ptr, basic_info.as_mut_ptr())
236 })?;
237
238 if let Some(pr) = self.parallel_runner {
239 pr.callback_basic_info(unsafe { &*basic_info.as_ptr() });
240 }
241 }
242
243 s::ColorEncoding => {
245 self.get_icc_profile(unsafe { icc.as_mut().unwrap_unchecked() })?;
246 }
247
248 s::JPEGReconstruction => {
250 let buf = unsafe { reconstruct_jpeg_buffer.as_mut().unwrap_unchecked() };
253 buf.resize(self.init_jpeg_buffer, 0);
254 check_dec_status(unsafe {
255 JxlDecoderSetJPEGBuffer(self.ptr, buf.as_mut_ptr(), buf.len())
256 })?;
257 }
258
259 s::JPEGNeedMoreOutput => {
261 let buf = unsafe { reconstruct_jpeg_buffer.as_mut().unwrap_unchecked() };
264 let need_to_write = unsafe { JxlDecoderReleaseJPEGBuffer(self.ptr) };
265
266 buf.resize(buf.len() + need_to_write, 0);
267 check_dec_status(unsafe {
268 JxlDecoderSetJPEGBuffer(self.ptr, buf.as_mut_ptr(), buf.len())
269 })?;
270 }
271
272 s::NeedImageOutBuffer => {
274 self.output(unsafe { &*basic_info.as_ptr() }, data_type, format, pixels)?;
275 }
276
277 s::FullImage => continue,
278 s::Success => {
279 if let Some(buf) = reconstruct_jpeg_buffer.as_mut() {
280 let remaining = unsafe { JxlDecoderReleaseJPEGBuffer(self.ptr) };
281
282 buf.truncate(buf.len() - remaining);
283 buf.shrink_to_fit();
284 }
285
286 unsafe { JxlDecoderReset(self.ptr) };
287
288 let info = unsafe { basic_info.assume_init() };
289 return Ok(Metadata {
290 width: info.xsize,
291 height: info.ysize,
292 intensity_target: info.intensity_target,
293 min_nits: info.min_nits,
294 orientation: info.orientation,
295 num_color_channels: info.num_color_channels,
296 has_alpha_channel: info.alpha_bits > 0,
297 intrinsic_width: info.intrinsic_xsize,
298 intrinsic_height: info.intrinsic_ysize,
299 icc_profile: icc,
300 });
301 }
302 s::NeedPreviewOutBuffer => todo!(),
303 s::BoxNeedMoreOutput => todo!(),
304 s::PreviewImage => todo!(),
305 s::Frame => todo!(),
306 s::Box => todo!(),
307 s::BoxComplete => todo!(),
308 s::FrameProgression => todo!(),
309 }
310 }
311 }
312
313 fn setup_decoder(&self, icc: bool, reconstruct_jpeg: bool) -> Result<(), DecodeError> {
314 if let Some(runner) = self.parallel_runner {
315 check_dec_status(unsafe {
316 JxlDecoderSetParallelRunner(self.ptr, runner.runner(), runner.as_opaque_ptr())
317 })?;
318 }
319
320 let events = {
321 use JxlDecoderStatus::{BasicInfo, ColorEncoding, FullImage, JPEGReconstruction};
322
323 let mut events = BasicInfo as i32 | FullImage as i32;
324 if icc {
325 events |= ColorEncoding as i32;
326 }
327 if reconstruct_jpeg {
328 events |= JPEGReconstruction as i32;
329 }
330
331 events
332 };
333 check_dec_status(unsafe { JxlDecoderSubscribeEvents(self.ptr, events) })?;
334
335 if let Some(val) = self.skip_reorientation {
336 check_dec_status(unsafe { JxlDecoderSetKeepOrientation(self.ptr, val.into()) })?;
337 }
338 if let Some(val) = self.unpremul_alpha {
339 check_dec_status(unsafe { JxlDecoderSetUnpremultiplyAlpha(self.ptr, val.into()) })?;
340 }
341 if let Some(val) = self.render_spotcolors {
342 check_dec_status(unsafe { JxlDecoderSetRenderSpotcolors(self.ptr, val.into()) })?;
343 }
344 if let Some(val) = self.coalescing {
345 check_dec_status(unsafe { JxlDecoderSetCoalescing(self.ptr, val.into()) })?;
346 }
347 if let Some(val) = self.desired_intensity_target {
348 check_dec_status(unsafe { JxlDecoderSetDesiredIntensityTarget(self.ptr, val) })?;
349 }
350
351 Ok(())
352 }
353
354 fn get_icc_profile(&self, icc_profile: &mut Vec<u8>) -> Result<(), DecodeError> {
355 let mut icc_size = 0;
356 check_dec_status(unsafe {
357 JxlDecoderGetICCProfileSize(self.ptr, JxlColorProfileTarget::Data, &mut icc_size)
358 })?;
359 icc_profile.resize(icc_size, 0);
360
361 check_dec_status(unsafe {
362 JxlDecoderGetColorAsICCProfile(
363 self.ptr,
364 JxlColorProfileTarget::Data,
365 icc_profile.as_mut_ptr(),
366 icc_size,
367 )
368 })?;
369
370 Ok(())
371 }
372
373 fn output(
374 &self,
375 info: &BasicInfo,
376 data_type: Option<JxlDataType>,
377 format: *mut JxlPixelFormat,
378 pixels: &mut Vec<u8>,
379 ) -> Result<(), DecodeError> {
380 let data_type = match data_type {
381 Some(v) => v,
382 None => match (info.bits_per_sample, info.exponent_bits_per_sample) {
383 (x, 0) if x <= 8 => JxlDataType::Uint8,
384 (x, 0) if x <= 16 => JxlDataType::Uint16,
385 (16, _) => JxlDataType::Float16,
386 (32, _) => JxlDataType::Float,
387 (x, _) => return Err(DecodeError::UnsupportedBitWidth(x)),
388 },
389 };
390
391 let f = self.pixel_format.unwrap_or_default();
392 let pixel_format = JxlPixelFormat {
393 num_channels: if f.num_channels == 0 {
394 info.num_color_channels + u32::from(info.alpha_bits > 0)
395 } else {
396 f.num_channels
397 },
398 data_type,
399 endianness: f.endianness,
400 align: f.align,
401 };
402
403 let mut size = 0;
404 check_dec_status(unsafe {
405 JxlDecoderImageOutBufferSize(self.ptr, &pixel_format, &mut size)
406 })?;
407 pixels.resize(size, 0);
408
409 check_dec_status(unsafe {
410 JxlDecoderSetImageOutBuffer(self.ptr, &pixel_format, pixels.as_mut_ptr().cast(), size)
411 })?;
412
413 unsafe { *format = pixel_format };
414 Ok(())
415 }
416
417 pub fn decode(&self, data: &[u8]) -> Result<(Metadata, Pixels), DecodeError> {
422 let mut buffer = vec![];
423 let mut pixel_format = MaybeUninit::uninit();
424 let metadata = self.decode_internal(
425 data,
426 None,
427 self.icc_profile,
428 None,
429 pixel_format.as_mut_ptr(),
430 &mut buffer,
431 )?;
432 Ok((
433 metadata,
434 Pixels::new(buffer, unsafe { &pixel_format.assume_init() }),
435 ))
436 }
437
438 pub fn decode_with<T: PixelType>(
443 &self,
444 data: &[u8],
445 ) -> Result<(Metadata, Vec<T>), DecodeError> {
446 let mut buffer = vec![];
447 let mut pixel_format = MaybeUninit::uninit();
448 let metadata = self.decode_internal(
449 data,
450 Some(T::pixel_type()),
451 self.icc_profile,
452 None,
453 pixel_format.as_mut_ptr(),
454 &mut buffer,
455 )?;
456
457 let buf = unsafe {
459 let pixel_format = pixel_format.assume_init();
460 debug_assert!(T::pixel_type() == pixel_format.data_type);
461 T::convert(&buffer, &pixel_format)
462 };
463
464 Ok((metadata, buf))
465 }
466
467 pub fn reconstruct(&self, data: &[u8]) -> Result<(Metadata, Data), DecodeError> {
475 let mut buffer = vec![];
476 let mut pixel_format = MaybeUninit::uninit();
477 let mut jpeg_buf = vec![];
478 let metadata = self.decode_internal(
479 data,
480 None,
481 self.icc_profile,
482 Some(&mut jpeg_buf),
483 pixel_format.as_mut_ptr(),
484 &mut buffer,
485 )?;
486
487 Ok((
488 metadata,
489 if jpeg_buf.is_empty() {
490 Data::Pixels(Pixels::new(buffer, unsafe { &pixel_format.assume_init() }))
491 } else {
492 Data::Jpeg(jpeg_buf)
493 },
494 ))
495 }
496}
497
498impl<'prl, 'mm> Drop for JxlDecoder<'prl, 'mm> {
499 fn drop(&mut self) {
500 unsafe { JxlDecoderDestroy(self.ptr) };
501 }
502}
503
504#[must_use]
506pub fn decoder_builder<'prl, 'mm>() -> JxlDecoderBuilder<'prl, 'mm> {
507 JxlDecoderBuilder::default()
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the derived `Clone`/`Debug` impls and the builder's `Clone` are usable.
    #[test]
    #[allow(clippy::clone_on_copy)]
    fn test_derive() {
        let format = PixelFormat::default();
        let e = format.clone();
        println!("{e:?}");

        let _ = decoder_builder().clone();
    }
}
522}