use std::collections::HashSet;
use std::io::Cursor;
use std::sync::{Arc, OnceLock};
use crate::logging::{img_debug, img_error, img_info};
use crate::Error;
/// Decoded pixel storage: four interleaved samples (R, G, B, A) per pixel.
///
/// Buffers are reference-counted, so cloning a decoded image does not copy
/// the pixel data.
#[derive(Debug, Clone)]
pub enum Pixels {
    /// 8 bits per channel.
    Rgba8(Arc<[u8]>),
    /// 16 bits per channel.
    Rgba16(Arc<[u16]>),
}

/// A decoded image: its dimensions plus the interleaved RGBA pixel buffer.
#[derive(Debug, Clone)]
pub struct RawImage {
    /// Width in pixels.
    pub width: u32,
    /// Height in pixels.
    pub height: u32,
    /// Interleaved RGBA samples for every pixel.
    pub pixels: Pixels,
}

impl RawImage {
    /// Returns `true` if at least one pixel is not fully opaque.
    ///
    /// Alpha is the fourth sample of each pixel. A trailing partial pixel
    /// (buffer length not a multiple of four) is ignored.
    pub fn has_transparency(&self) -> bool {
        match &self.pixels {
            Pixels::Rgba8(samples) => samples
                .iter()
                .skip(3)
                .step_by(4)
                .any(|&alpha| alpha != u8::MAX),
            Pixels::Rgba16(samples) => samples
                .iter()
                .skip(3)
                .step_by(4)
                .any(|&alpha| alpha != u16::MAX),
        }
    }
}
/// Major brands (the four bytes right after `ftyp`) that route an input to
/// the HEIF decoder. The set includes `avif`.
///
/// Built lazily on first use and shared for the rest of the process.
fn heif_brands() -> &'static HashSet<[u8; 4]> {
    static BRANDS: OnceLock<HashSet<[u8; 4]>> = OnceLock::new();
    BRANDS.get_or_init(|| {
        HashSet::from([
            *b"heic", *b"heis", *b"hevc", *b"hevx", *b"heim", *b"heix", *b"mif1", *b"msf1",
            *b"avif",
        ])
    })
}

/// Returns `true` when `data` opens with an ISO-BMFF `ftyp` box whose major
/// brand is listed in [`heif_brands`].
///
/// Checked layout: bytes 4..8 must read `ftyp`, and bytes 8..12 hold the
/// major brand. Bytes 0..4 (the box size) are not inspected. Anything shorter
/// than 12 bytes is never classified as HEIF.
fn is_heif_ftyp(data: &[u8]) -> bool {
    match data {
        [_, _, _, _, b'f', b't', b'y', b'p', b0, b1, b2, b3, ..] => {
            heif_brands().contains(&[*b0, *b1, *b2, *b3])
        }
        _ => false,
    }
}
/// Decodes an encoded image into RGBA pixels.
///
/// HEIF-family containers (recognised by their `ftyp` major brand) go to the
/// HEIF decoder. Everything else goes through the `image` crate. `max_pixels`
/// is forwarded unchanged to whichever decoder handles the input.
///
/// # Errors
///
/// Returns whatever error the selected decoder produces.
pub fn decode(data: &[u8], max_pixels: u64) -> Result<RawImage, Error> {
    img_debug!("decode: {} bytes, max_pixels={}", data.len(), max_pixels);
    if !is_heif_ftyp(data) {
        return decode_via_image_crate(data, max_pixels);
    }
    img_info!("decode: detected HEIC/HEIF container (ftyp magic)");
    decode_heif(data, max_pixels)
}
/// Decodes JPEG, PNG or WebP input through the `image` crate.
///
/// The image header is read first. An input whose `width × height` exceeds
/// `max_pixels` is therefore rejected with [`Error::InputTooLarge`] before
/// any pixel data is decoded. The full decode then runs under an allocation
/// cap derived from `max_pixels` as a second line of defence.
///
/// Sources with 16-bit channels (grey, grey + alpha, RGB, RGBA) become
/// [`Pixels::Rgba16`] so their precision survives. Everything else is
/// converted to [`Pixels::Rgba8`].
///
/// # Errors
///
/// - [`Error::Decode`] if the format cannot be detected, or if the header or
///   pixel data fails to decode.
/// - [`Error::UnsupportedFormat`] for a recognised format other than JPEG,
///   PNG or WebP.
/// - [`Error::InputTooLarge`] if the image has more than `max_pixels` pixels.
fn decode_via_image_crate(data: &[u8], max_pixels: u64) -> Result<RawImage, Error> {
    let mut reader = image::ImageReader::new(Cursor::new(data))
        .with_guessed_format()
        .map_err(|e| Error::Decode(e.to_string()))?;
    let supported = match reader.format() {
        Some(
            fmt @ (image::ImageFormat::Jpeg | image::ImageFormat::Png | image::ImageFormat::WebP),
        ) => {
            img_debug!("decode: detected format {:?}", reader.format());
            fmt
        }
        Some(other) => {
            img_error!("decode: unsupported format {:?}", other);
            return Err(Error::UnsupportedFormat(format!("{other:?}")));
        }
        None => {
            img_error!("decode: could not detect image format");
            return Err(Error::Decode("unrecognised image format".into()));
        }
    };

    // Header-only pass. Checking the budget here, rather than after decoding,
    // avoids decoding a frame that would be thrown away. It also makes an
    // oversized input surface as `InputTooLarge`, not as an allocation-limit
    // `Decode` error from the capped decode below.
    let (header_width, header_height) = image::ImageReader::with_format(Cursor::new(data), supported)
        .into_dimensions()
        .map_err(|e| {
            img_error!("decode: image decode error: {}", e);
            Error::Decode(e.to_string())
        })?;
    ensure_within_pixel_budget(header_width, header_height, max_pixels)?;

    // NOTE(review): 8 bytes/pixel presumably covers an RGBA16 frame; the extra
    // 64 MiB is headroom for decoder internals.
    let alloc_cap = max_pixels
        .saturating_mul(8)
        .saturating_add(64 * 1024 * 1024);
    let mut limits = image::Limits::default();
    limits.max_alloc = Some(alloc_cap);
    reader.limits(limits);
    let img = reader.decode().map_err(|e| {
        img_error!("decode: image decode error: {}", e);
        Error::Decode(e.to_string())
    })?;

    let (width, height) = (img.width(), img.height());
    let pixel_count = u64::from(width) * u64::from(height);
    img_debug!(
        "decode: raw dimensions {}×{} ({} Mpx), colour type={:?}",
        width,
        height,
        pixel_count / 1_000_000,
        img.color()
    );
    // The decoded size should equal the header size. Re-check anyway, because
    // the buffer handed to callers is what actually has to fit the budget.
    ensure_within_pixel_budget(width, height, max_pixels)?;

    let pixels = match img.color() {
        // Every 16-bit layout keeps full precision; `into_rgba16` performs the
        // grey / missing-alpha expansion.
        image::ColorType::L16
        | image::ColorType::La16
        | image::ColorType::Rgb16
        | image::ColorType::Rgba16 => {
            img_info!("decode: 16-bit PNG detected — preserving full precision for 10-bit AVIF");
            Pixels::Rgba16(Arc::from(img.into_rgba16().into_raw()))
        }
        _ => {
            img_debug!("decode: converting to RGBA8");
            Pixels::Rgba8(Arc::from(img.into_rgba8().into_raw()))
        }
    };
    img_info!("decode: {}×{} decoded OK", width, height);
    Ok(RawImage {
        width,
        height,
        pixels,
    })
}

/// Fails with [`Error::InputTooLarge`] when a `width × height` image would
/// exceed `max_pixels`. The product is computed in `u64`, so it cannot overflow.
fn ensure_within_pixel_budget(width: u32, height: u32, max_pixels: u64) -> Result<(), Error> {
    let pixel_count = u64::from(width) * u64::from(height);
    if pixel_count <= max_pixels {
        return Ok(());
    }
    img_error!(
        "decode: image {}×{} ({} px) exceeds max_pixels={}",
        width,
        height,
        pixel_count,
        max_pixels
    );
    Err(Error::InputTooLarge {
        width,
        height,
        max_pixels,
    })
}
fn decode_heif(data: &[u8], max_pixels: u64) -> Result<RawImage, Error> {
#[cfg(feature = "heic-experimental")]
return decode_heif_impl(data, max_pixels);
#[cfg(not(feature = "heic-experimental"))]
{
let _ = (data, max_pixels); img_error!("decode: HEIC/HEIF input but `heic-experimental` feature is not enabled");
Err(Error::UnsupportedFormat(
"HEIC/HEIF (enable the `heic-experimental` Cargo feature and \
ensure `libheif` is installed on the system)"
.into(),
))
}
}
/// Decodes the primary image of a HEIF/HEIC container with `libheif`.
///
/// The container is parsed and the primary image's dimensions are checked
/// against `max_pixels` before any pixel data is decoded. Pixels are
/// requested as interleaved RGBA and returned as [`Pixels::Rgba8`], with any
/// per-row padding stripped by `heif_interleaved_to_rgba_pixels`.
///
/// # Errors
///
/// - [`Error::Decode`] if the container, primary image handle, pixel decode
///   or interleaved plane is unusable.
/// - [`Error::InputTooLarge`] if `width × height` exceeds `max_pixels`.
#[cfg(feature = "heic-experimental")]
fn decode_heif_impl(data: &[u8], max_pixels: u64) -> Result<RawImage, Error> {
    use libheif_rs::{ColorSpace, HeifContext, LibHeif, RgbChroma};

    // No leading underscore: this value is used below to decode pixels.
    let lib_heif = LibHeif::new();

    let context = match HeifContext::read_from_bytes(data) {
        Ok(context) => context,
        Err(err) => {
            img_error!("decode_heif: context parse error: {}", err);
            return Err(Error::Decode(format!("HEIF context: {err}")));
        }
    };
    let handle = match context.primary_image_handle() {
        Ok(handle) => handle,
        Err(err) => {
            img_error!("decode_heif: could not get primary image handle: {}", err);
            return Err(Error::Decode(format!("HEIF primary image: {err}")));
        }
    };

    let (width, height) = (handle.width(), handle.height());
    let pixel_count = u64::from(width) * u64::from(height);
    img_debug!(
        "decode_heif: {}×{} ({} Mpx)",
        width,
        height,
        pixel_count / 1_000_000
    );
    if pixel_count > max_pixels {
        img_error!(
            "decode_heif: {}×{} ({} px) exceeds max_pixels={}",
            width,
            height,
            pixel_count,
            max_pixels
        );
        return Err(Error::InputTooLarge {
            width,
            height,
            max_pixels,
        });
    }

    let decoded = match lib_heif.decode(&handle, ColorSpace::Rgb(RgbChroma::Rgba), None) {
        Ok(decoded) => decoded,
        Err(err) => {
            img_error!("decode_heif: pixel decode error: {}", err);
            return Err(Error::Decode(format!("HEIF decode: {err}")));
        }
    };

    let planes = decoded.planes();
    let Some(plane) = planes.interleaved else {
        img_error!("decode_heif: no interleaved RGBA plane in decoded image");
        return Err(Error::Decode("HEIF image has no interleaved RGBA plane".into()));
    };

    let rgba = match heif_interleaved_to_rgba_pixels(plane.data, width, height, plane.stride) {
        Ok(rgba) => rgba,
        Err(err) => {
            img_error!("decode_heif: malformed interleaved plane: {}", err);
            return Err(err);
        }
    };

    img_info!("decode_heif: {}×{} decoded OK", width, height);
    Ok(RawImage {
        width,
        height,
        pixels: Pixels::Rgba8(Arc::from(rgba)),
    })
}
/// Copies an interleaved RGBA8 plane into a tightly packed buffer.
///
/// `data` holds `height` rows placed `stride` bytes apart. Each row begins
/// with `width * 4` bytes of RGBA. Bytes between the end of a row's pixels
/// and the start of the next row (padding) are dropped, as is anything beyond
/// `stride * height`.
///
/// # Errors
///
/// Returns [`Error::Decode`] when:
/// - the row size or total plane size overflows `usize`;
/// - `stride` is smaller than one row of pixels;
/// - `data` is shorter than `stride * height` bytes.
#[cfg(any(feature = "heic-experimental", test))]
fn heif_interleaved_to_rgba_pixels(
    data: &[u8],
    width: u32,
    height: u32,
    stride: usize,
) -> Result<Vec<u8>, Error> {
    // Checked so that a huge width cannot silently wrap on 32-bit targets.
    let row_bytes = usize::try_from(width)
        .ok()
        .and_then(|w| w.checked_mul(4))
        .ok_or_else(|| Error::Decode("HEIF interleaved row size overflow".into()))?;
    let rows = usize::try_from(height)
        .map_err(|_| Error::Decode("HEIF interleaved row count overflow".into()))?;
    if stride < row_bytes {
        return Err(Error::Decode(format!(
            "HEIF interleaved stride {stride} is smaller than row size {row_bytes}"
        )));
    }
    let expected_len = stride
        .checked_mul(rows)
        .ok_or_else(|| Error::Decode("HEIF interleaved plane size overflow".into()))?;
    if data.len() < expected_len {
        return Err(Error::Decode(format!(
            "HEIF interleaved plane too short: got {} bytes, expected at least {} \
             for {} rows with stride {}",
            data.len(),
            expected_len,
            rows,
            stride
        )));
    }
    let plane = &data[..expected_len];
    if stride == row_bytes {
        return Ok(plane.to_vec());
    }
    img_debug!(
        "decode_heif: row stride {} != expected {} — stripping per-row padding",
        stride,
        row_bytes
    );
    // Invariants at this point:
    // - stride > row_bytes, so stride >= 1 and `chunks_exact` cannot panic;
    // - `plane` is exactly `rows * stride` bytes, so it yields exactly `rows` chunks;
    // - row_bytes * rows < stride * rows = expected_len, so the capacity cannot overflow.
    let mut pixels = Vec::with_capacity(row_bytes * rows);
    for row in plane.chunks_exact(stride) {
        pixels.extend_from_slice(&row[..row_bytes]);
    }
    Ok(pixels)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes an RGBA8 buffer as PNG so the `image`-crate path can be
    /// exercised end to end through [`decode`].
    fn encode_png_rgba8(width: u32, height: u32, raw: Vec<u8>) -> Vec<u8> {
        let img = image::RgbaImage::from_raw(width, height, raw)
            .expect("raw buffer length must equal width * height * 4");
        let mut out = Cursor::new(Vec::new());
        img.write_to(&mut out, image::ImageFormat::Png)
            .expect("in-memory PNG encoding should not fail");
        out.into_inner()
    }

    #[test]
    fn heif_stride_smaller_than_row_is_decode_error() {
        // width 3 => 12-byte rows, but the stride is only 8.
        let err = heif_interleaved_to_rgba_pixels(&[0; 8], 3, 1, 8).unwrap_err();
        assert!(matches!(err, Error::Decode(_)));
    }

    #[test]
    fn heif_short_plane_is_decode_error() {
        // width 1 => 4-byte rows, so stride 6 is valid. The only defect is
        // that 2 rows * 6 bytes = 12 bytes are required but only 11 exist.
        let err = heif_interleaved_to_rgba_pixels(&[0; 11], 1, 2, 6).unwrap_err();
        assert!(matches!(err, Error::Decode(_)));
    }

    #[test]
    fn heif_tight_layout_is_copied_and_trailing_bytes_ignored() {
        let out = heif_interleaved_to_rgba_pixels(&[1, 2, 3, 4, 5, 6, 7, 8, 9], 1, 2, 4).unwrap();
        assert_eq!(out, vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn heif_valid_padding_layout_is_compacted() {
        let src = vec![
            1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 10, 11, 12, 13, 14, 15, 16, 17, 8, 8,
        ];
        let out = heif_interleaved_to_rgba_pixels(&src, 2, 2, 10).unwrap();
        assert_eq!(
            out,
            vec![1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17]
        );
    }

    #[test]
    fn heif_zero_height_yields_empty_buffer() {
        let out = heif_interleaved_to_rgba_pixels(&[], 2, 0, 10).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn heif_ftyp_detection() {
        assert!(is_heif_ftyp(b"\x00\x00\x00\x18ftypheic\x00\x00\x00\x00"));
        assert!(is_heif_ftyp(b"\x00\x00\x00\x18ftypmif1"));
        assert!(!is_heif_ftyp(b"\x00\x00\x00\x18ftypisom\x00\x00\x00\x00"));
        assert!(!is_heif_ftyp(&b"\x00\x00\x00\x18ftypheic"[..11]));
        assert!(!is_heif_ftyp(b""));
    }

    #[cfg(not(feature = "heic-experimental"))]
    #[test]
    fn heif_without_feature_is_unsupported() {
        let err = decode(b"\x00\x00\x00\x18ftypheic\x00\x00\x00\x00", u64::MAX).unwrap_err();
        assert!(matches!(err, Error::UnsupportedFormat(_)));
    }

    #[test]
    fn png_rgba8_round_trips() {
        let raw = vec![255, 0, 0, 255, 0, 255, 0, 128];
        let png = encode_png_rgba8(2, 1, raw.clone());
        let img = decode(&png, 2).unwrap();
        assert_eq!((img.width, img.height), (2, 1));
        match &img.pixels {
            Pixels::Rgba8(bytes) => assert_eq!(&bytes[..], &raw[..]),
            other => panic!("expected RGBA8 pixels, got {other:?}"),
        }
        assert!(img.has_transparency());
    }

    #[test]
    fn png_over_pixel_budget_is_input_too_large() {
        let png = encode_png_rgba8(2, 2, vec![0; 16]);
        let err = decode(&png, 3).unwrap_err();
        assert!(matches!(
            err,
            Error::InputTooLarge {
                width: 2,
                height: 2,
                max_pixels: 3
            }
        ));
    }

    #[test]
    fn has_transparency_detects_alpha_in_rgba8() {
        let opaque = RawImage {
            width: 2,
            height: 1,
            pixels: Pixels::Rgba8(Arc::from(vec![255, 0, 0, 255, 0, 255, 0, 255])),
        };
        assert!(!opaque.has_transparency());
        let transparent = RawImage {
            width: 2,
            height: 1,
            pixels: Pixels::Rgba8(Arc::from(vec![255, 0, 0, 255, 0, 255, 0, 128])),
        };
        assert!(transparent.has_transparency());
        let fully_transparent = RawImage {
            width: 1,
            height: 1,
            pixels: Pixels::Rgba8(Arc::from(vec![0, 0, 0, 0])),
        };
        assert!(fully_transparent.has_transparency());
    }

    #[test]
    fn has_transparency_detects_alpha_in_rgba16() {
        let opaque = RawImage {
            width: 2,
            height: 1,
            pixels: Pixels::Rgba16(Arc::from(vec![65535, 0, 0, 65535, 0, 65535, 0, 65535])),
        };
        assert!(!opaque.has_transparency());
        let transparent = RawImage {
            width: 2,
            height: 1,
            pixels: Pixels::Rgba16(Arc::from(vec![65535, 0, 0, 65535, 0, 65535, 0, 32768])),
        };
        assert!(transparent.has_transparency());
    }
}