pub mod bitstream;
pub mod color;
pub mod convert;
pub mod huffman;
pub mod idct;
pub mod markers;
pub mod mcu;
pub mod types;
pub mod upsample;
use crate::error::CodecError;
use crate::exif::{apply_exif_u8, read_exif_orientation, rotated_dims};
use crate::options::{DecodeOptions, ImageInfo};
use crate::pixel::ImagePixel;
use edgefirst_tensor::{PixelFormat, Tensor, TensorTrait};
/// Reusable scratch storage for [`decode_jpeg_into`].
///
/// Keeping one instance alive across calls lets the decoder reuse its buffers
/// instead of reallocating them for every image.
pub struct JpegDecoderState {
    /// MCU decoding scratch; created lazily on the first decode and reused
    /// (via `ensure_capacity`) on later ones.
    mcu_scratch: Option<mcu::McuScratch>,
    /// Tightly packed `u8` image buffer used when the decoded pixels cannot be
    /// written straight into the destination tensor (EXIF reorientation or a
    /// non-`u8` element type).
    exif_scratch: Vec<u8>,
    /// Auxiliary buffer handed to `apply_exif_u8` while reorienting.
    rot_scratch: Vec<u8>,
}

impl JpegDecoderState {
    /// Creates an empty state; no buffers are allocated until the first decode.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            mcu_scratch: None,
            exif_scratch: Vec::new(),
            rot_scratch: Vec::new(),
        }
    }
}

impl Default for JpegDecoderState {
    /// Equivalent to [`JpegDecoderState::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Checks that `fmt` is an output format the JPEG decoder can produce.
///
/// Returns `fmt` unchanged for RGB, RGBA, Grey, BGRA and NV12, and
/// `CodecError::UnsupportedFormat(fmt)` for anything else.
fn validate_output_format(fmt: PixelFormat) -> crate::Result<PixelFormat> {
    let supported = matches!(
        fmt,
        PixelFormat::Rgb
            | PixelFormat::Rgba
            | PixelFormat::Grey
            | PixelFormat::Bgra
            | PixelFormat::Nv12
    );
    if supported {
        Ok(fmt)
    } else {
        Err(CodecError::UnsupportedFormat(fmt))
    }
}
pub fn peek_jpeg_info(data: &[u8], opts: &DecodeOptions) -> crate::Result<ImageInfo> {
let headers = markers::parse_markers(data)?;
let hdr = &headers.header;
let img_w = hdr.width as usize;
let img_h = hdr.height as usize;
let dest_fmt = opts.format.unwrap_or(PixelFormat::Rgb);
let output_fmt = validate_output_format(dest_fmt)?;
if hdr.components.len() == 1 && output_fmt == PixelFormat::Nv12 {
return Err(CodecError::InvalidData(
"cannot decode greyscale JPEG to NV12".into(),
));
}
if output_fmt == PixelFormat::Nv12 && (!img_w.is_multiple_of(2) || !img_h.is_multiple_of(2)) {
return Err(CodecError::InvalidData(format!(
"NV12 requires even dimensions; got {img_w}×{img_h}"
)));
}
let (rotation_deg, _flip_h) = if opts.apply_exif && output_fmt != PixelFormat::Nv12 {
headers
.exif_data
.as_deref()
.map(read_exif_orientation)
.unwrap_or((0, false))
} else {
(0, false)
};
let (final_w, final_h) = rotated_dims(img_w, img_h, rotation_deg);
let channels = output_fmt.channels();
Ok(ImageInfo {
width: final_w,
height: final_h,
format: output_fmt,
row_stride: final_w * channels,
})
}
pub fn decode_jpeg_into<T: ImagePixel>(
data: &[u8],
dst: &mut Tensor<T>,
opts: &DecodeOptions,
state: &mut JpegDecoderState,
) -> crate::Result<ImageInfo> {
let headers = markers::parse_markers(data)?;
let hdr = &headers.header;
let mut img_w = hdr.width as usize;
let mut img_h = hdr.height as usize;
let dest_fmt = opts.format.unwrap_or(PixelFormat::Rgb);
let output_fmt = validate_output_format(dest_fmt)?;
let output_fmt = if hdr.components.len() == 1 && output_fmt == PixelFormat::Nv12 {
return Err(CodecError::InvalidData(
"cannot decode greyscale JPEG to NV12".into(),
));
} else if hdr.components.len() == 1 && output_fmt != PixelFormat::Grey {
output_fmt
} else {
output_fmt
};
if output_fmt == PixelFormat::Nv12 && (!img_w.is_multiple_of(2) || !img_h.is_multiple_of(2)) {
return Err(CodecError::InvalidData(format!(
"NV12 requires even dimensions; got {img_w}×{img_h}"
)));
}
let (rotation_deg, flip_h) = if opts.apply_exif && output_fmt != PixelFormat::Nv12 {
headers
.exif_data
.as_deref()
.map(read_exif_orientation)
.unwrap_or((0, false))
} else {
(0, false)
};
let (final_w, final_h) = rotated_dims(img_w, img_h, rotation_deg);
let tensor_w = dst
.width()
.unwrap_or_else(|| dst.shape().get(1).copied().unwrap_or(0));
let tensor_h = dst
.height()
.unwrap_or_else(|| dst.shape().first().copied().unwrap_or(0));
if final_w > tensor_w || final_h > tensor_h {
return Err(CodecError::InsufficientCapacity {
image: (final_w, final_h),
tensor: (tensor_w, tensor_h),
});
}
let channels = output_fmt.channels();
let elem_size = std::mem::size_of::<T>();
let dst_stride = dst
.effective_row_stride()
.unwrap_or(tensor_w * channels * elem_size);
match &mut state.mcu_scratch {
Some(scratch) => scratch.ensure_capacity(&headers),
None => state.mcu_scratch = Some(mcu::McuScratch::new(&headers)),
}
let mcu_scratch = state.mcu_scratch.as_mut().unwrap();
let mut map = dst.map()?;
let dst_bytes: &mut [T] = &mut map;
if T::dtype() == edgefirst_tensor::DType::U8 {
let dst_u8: &mut [u8] = unsafe {
std::slice::from_raw_parts_mut(dst_bytes.as_mut_ptr() as *mut u8, dst_bytes.len())
};
if flip_h || rotation_deg != 0 {
let native_stride = img_w * channels;
state.exif_scratch.resize(native_stride * img_h, 0);
mcu::decode_image(
data,
&headers,
mcu_scratch,
&mut state.exif_scratch,
native_stride,
output_fmt,
)?;
apply_exif_u8(
&mut state.exif_scratch,
native_stride,
&mut img_w,
&mut img_h,
channels,
rotation_deg,
flip_h,
&mut state.rot_scratch,
);
let final_native_stride = img_w * channels;
for y in 0..img_h {
let src_off = y * final_native_stride;
let dst_off = y * dst_stride;
dst_u8[dst_off..dst_off + final_native_stride]
.copy_from_slice(&state.exif_scratch[src_off..src_off + final_native_stride]);
}
} else {
mcu::decode_image(data, &headers, mcu_scratch, dst_u8, dst_stride, output_fmt)?;
}
} else {
let temp_stride = img_w * channels;
let temp_size = temp_stride * img_h;
state.exif_scratch.resize(temp_size, 0);
mcu::decode_image(
data,
&headers,
mcu_scratch,
&mut state.exif_scratch,
temp_stride,
output_fmt,
)?;
if flip_h || rotation_deg != 0 {
apply_exif_u8(
&mut state.exif_scratch,
temp_stride,
&mut img_w,
&mut img_h,
channels,
rotation_deg,
flip_h,
&mut state.rot_scratch,
);
}
let src_stride = img_w * channels;
let dst_stride_elems = dst_stride / elem_size;
if T::dtype() == edgefirst_tensor::DType::I8 {
let dst_u8: &mut [u8] = unsafe {
std::slice::from_raw_parts_mut(dst_bytes.as_mut_ptr() as *mut u8, dst_bytes.len())
};
for y in 0..img_h {
let s = y * src_stride;
let d = y * dst_stride;
dst_u8[d..d + src_stride].copy_from_slice(&state.exif_scratch[s..s + src_stride]);
for b in &mut dst_u8[d..d + src_stride] {
*b ^= 0x80;
}
}
} else if T::dtype() == edgefirst_tensor::DType::F32 {
let dst_f32: &mut [f32] = unsafe {
std::slice::from_raw_parts_mut(dst_bytes.as_mut_ptr() as *mut f32, dst_bytes.len())
};
for y in 0..img_h {
let s = y * src_stride;
let d = y * dst_stride_elems;
convert::convert_u8_to_f32(
&state.exif_scratch[s..s + src_stride],
&mut dst_f32[d..d + src_stride],
);
}
} else if T::dtype() == edgefirst_tensor::DType::U16 {
let dst_u16: &mut [u16] = unsafe {
std::slice::from_raw_parts_mut(dst_bytes.as_mut_ptr() as *mut u16, dst_bytes.len())
};
for y in 0..img_h {
let s = y * src_stride;
let d = y * dst_stride_elems;
convert::convert_u8_to_u16(
&state.exif_scratch[s..s + src_stride],
&mut dst_u16[d..d + src_stride],
);
}
} else if T::dtype() == edgefirst_tensor::DType::I16 {
let dst_i16: &mut [i16] = unsafe {
std::slice::from_raw_parts_mut(dst_bytes.as_mut_ptr() as *mut i16, dst_bytes.len())
};
for y in 0..img_h {
let s = y * src_stride;
let d = y * dst_stride_elems;
convert::convert_u8_to_i16(
&state.exif_scratch[s..s + src_stride],
&mut dst_i16[d..d + src_stride],
);
}
} else {
for y in 0..img_h {
let s = y * src_stride;
let d = y * dst_stride_elems;
for x in 0..src_stride {
dst_bytes[d + x] = T::from_u8(state.exif_scratch[s + x]);
}
}
}
}
Ok(ImageInfo {
width: img_w,
height: img_h,
format: output_fmt,
row_stride: dst_stride,
})
}