use crate::error::CodecError;
use crate::options::{DecodeOptions, ImageInfo};
use crate::pixel::ImagePixel;
use edgefirst_tensor::{Tensor, TensorDyn};
use std::io::Read;
/// Reusable JPEG/PNG decoder.
///
/// The format is chosen from the input's leading magic bytes. Per-codec state
/// and scratch buffers live on the decoder so that they can be reused across
/// successive decode calls instead of being created per call.
pub struct ImageDecoder {
    /// State handed to the JPEG codec on every JPEG decode.
    pub(crate) jpeg_state: crate::jpeg::JpegDecoderState,
    /// General-purpose scratch buffer handed to the PNG codec.
    pub(crate) scratch: Vec<u8>,
    /// Second scratch buffer handed to the PNG codec
    /// (name suggests it is used for rotation/orientation — TODO confirm).
    pub(crate) png_rot_scratch: Vec<u8>,
    /// Holds the bytes slurped from a reader by the `decode_from_reader*`
    /// methods; cleared (not freed) at the start of each such call.
    pub(crate) input_buffer: Vec<u8>,
}
impl ImageDecoder {
    /// Creates a decoder with fresh JPEG state and empty scratch buffers.
    pub fn new() -> Self {
        Self {
            jpeg_state: crate::jpeg::JpegDecoderState::new(),
            scratch: Vec::new(),
            png_rot_scratch: Vec::new(),
            input_buffer: Vec::new(),
        }
    }

    /// Decodes an in-memory JPEG or PNG image into `dst`.
    ///
    /// The format is detected from the leading magic bytes of `data`.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::InvalidData`] when `data` starts with neither a
    /// JPEG nor a PNG signature, or whatever error the selected codec reports.
    pub fn decode_into<T: ImagePixel>(
        &mut self,
        data: &[u8],
        dst: &mut Tensor<T>,
        opts: &DecodeOptions,
    ) -> crate::Result<ImageInfo> {
        // Delegate to the shared dispatcher so the format-sniffing logic and
        // its error message exist in exactly one place.
        decode_into_inner(
            &mut self.jpeg_state,
            &mut self.scratch,
            &mut self.png_rot_scratch,
            data,
            dst,
            opts,
        )
    }

    /// Decodes an in-memory image into a dynamically typed tensor.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::UnsupportedDtype`] when `dst` holds an element
    /// type other than `u8`, `i8`, `u16`, `i16` or `f32`. It otherwise fails
    /// like [`ImageDecoder::decode_into`].
    pub fn decode_into_dyn(
        &mut self,
        data: &[u8],
        dst: &mut TensorDyn,
        opts: &DecodeOptions,
    ) -> crate::Result<ImageInfo> {
        decode_into_inner_dyn(
            &mut self.jpeg_state,
            &mut self.scratch,
            &mut self.png_rot_scratch,
            data,
            dst,
            opts,
        )
    }

    /// Reads `reader` to EOF and decodes the collected bytes into `dst`.
    ///
    /// The bytes are buffered in the decoder's reusable input buffer, whose
    /// previous contents are discarded.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from `reader`, then fails like
    /// [`ImageDecoder::decode_into`].
    pub fn decode_from_reader<T: ImagePixel, R: Read>(
        &mut self,
        reader: R,
        dst: &mut Tensor<T>,
        opts: &DecodeOptions,
    ) -> crate::Result<ImageInfo> {
        self.fill_input_buffer(reader)?;
        // Borrow the fields individually: the input buffer is read while the
        // codec state and scratch buffers are mutated.
        decode_into_inner(
            &mut self.jpeg_state,
            &mut self.scratch,
            &mut self.png_rot_scratch,
            &self.input_buffer,
            dst,
            opts,
        )
    }

    /// Reads `reader` to EOF and decodes the bytes into a dynamically typed
    /// tensor.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from `reader`, then fails like
    /// [`ImageDecoder::decode_into_dyn`].
    pub fn decode_from_reader_dyn<R: Read>(
        &mut self,
        reader: R,
        dst: &mut TensorDyn,
        opts: &DecodeOptions,
    ) -> crate::Result<ImageInfo> {
        self.fill_input_buffer(reader)?;
        decode_into_inner_dyn(
            &mut self.jpeg_state,
            &mut self.scratch,
            &mut self.png_rot_scratch,
            &self.input_buffer,
            dst,
            opts,
        )
    }

    /// Replaces the contents of `input_buffer` with everything `reader` yields.
    ///
    /// `clear` keeps the existing allocation so that repeated reader-based
    /// decodes can reuse it.
    fn fill_input_buffer<R: Read>(&mut self, mut reader: R) -> std::io::Result<()> {
        self.input_buffer.clear();
        reader.read_to_end(&mut self.input_buffer)?;
        Ok(())
    }
}
/// Sniffs the container format of `data` and decodes it into `dst`.
///
/// JPEG is tried first, then PNG. The codec state and scratch buffers are
/// passed in separately (rather than as `&mut ImageDecoder`) so that callers
/// can lend out a buffer they own, such as the decoder's input buffer, as
/// `data` at the same time.
pub(crate) fn decode_into_inner<T: ImagePixel>(
    jpeg_state: &mut crate::jpeg::JpegDecoderState,
    scratch: &mut Vec<u8>,
    png_rot_scratch: &mut Vec<u8>,
    data: &[u8],
    dst: &mut Tensor<T>,
    opts: &DecodeOptions,
) -> crate::Result<ImageInfo> {
    if is_jpeg(data) {
        return crate::jpeg::decode_jpeg_into(data, dst, opts, jpeg_state);
    }
    if is_png(data) {
        return crate::png::decode_png_into(data, dst, opts, scratch, png_rot_scratch);
    }
    // Neither signature matched.
    Err(CodecError::InvalidData(
        "unrecognized image format (expected JPEG or PNG magic bytes)".into(),
    ))
}
pub(crate) fn decode_into_inner_dyn(
jpeg_state: &mut crate::jpeg::JpegDecoderState,
scratch: &mut Vec<u8>,
png_rot_scratch: &mut Vec<u8>,
data: &[u8],
dst: &mut TensorDyn,
opts: &DecodeOptions,
) -> crate::Result<ImageInfo> {
match dst {
TensorDyn::U8(t) => decode_into_inner(jpeg_state, scratch, png_rot_scratch, data, t, opts),
TensorDyn::I8(t) => decode_into_inner(jpeg_state, scratch, png_rot_scratch, data, t, opts),
TensorDyn::U16(t) => decode_into_inner(jpeg_state, scratch, png_rot_scratch, data, t, opts),
TensorDyn::I16(t) => decode_into_inner(jpeg_state, scratch, png_rot_scratch, data, t, opts),
TensorDyn::F32(t) => decode_into_inner(jpeg_state, scratch, png_rot_scratch, data, t, opts),
other => Err(CodecError::UnsupportedDtype(other.dtype())),
}
}
impl Default for ImageDecoder {
fn default() -> Self {
Self::new()
}
}
/// Returns image metadata for `data` without decoding pixels into a tensor.
///
/// The format is detected from the leading magic bytes (JPEG checked first).
///
/// # Errors
///
/// Returns [`CodecError::InvalidData`] when `data` starts with neither a JPEG
/// nor a PNG signature, or whatever error the selected codec's peek reports.
pub fn peek_info(data: &[u8], opts: &DecodeOptions) -> crate::Result<ImageInfo> {
    match data {
        bytes if is_jpeg(bytes) => crate::jpeg::peek_jpeg_info(bytes, opts),
        bytes if is_png(bytes) => crate::png::peek_png_info(bytes, opts),
        _ => Err(CodecError::InvalidData(
            "unrecognized image format (expected JPEG or PNG magic bytes)".into(),
        )),
    }
}
/// True when `data` begins with the 3-byte JPEG prefix `FF D8 FF`
/// (SOI marker followed by the start of the next marker).
fn is_jpeg(data: &[u8]) -> bool {
    const JPEG_PREFIX: [u8; 3] = [0xFF, 0xD8, 0xFF];
    data.starts_with(&JPEG_PREFIX)
}
/// True when `data` begins with `89 'P' 'N' 'G'`.
///
/// Only the first four bytes of the 8-byte PNG signature are checked here.
/// Presumably the PNG codec validates the remainder — TODO confirm.
fn is_png(data: &[u8]) -> bool {
    data.starts_with(b"\x89PNG")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 100x100 RGB `u8` tensor in plain memory, used as a decode target.
    fn rgb_tensor() -> Tensor<u8> {
        Tensor::<u8>::image(
            100,
            100,
            edgefirst_tensor::PixelFormat::Rgb,
            Some(edgefirst_tensor::TensorMemory::Mem),
        )
        .unwrap()
    }

    #[test]
    fn magic_bytes_jpeg() {
        assert!(is_jpeg(&[0xFF, 0xD8, 0xFF, 0xE0]));
        // Exactly the minimal 3-byte signature is accepted.
        assert!(is_jpeg(&[0xFF, 0xD8, 0xFF]));
        assert!(!is_jpeg(&[0xFF, 0xD8]));
        assert!(!is_jpeg(&[0x89, 0x50, 0x4E, 0x47]));
        assert!(!is_jpeg(&[]));
        assert!(!is_jpeg(&[0xFF]));
    }

    #[test]
    fn magic_bytes_png() {
        assert!(is_png(&[0x89, 0x50, 0x4E, 0x47]));
        assert!(is_png(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]));
        // One byte short of the checked prefix.
        assert!(!is_png(&[0x89, 0x50, 0x4E]));
        assert!(!is_png(&[0xFF, 0xD8, 0xFF, 0xE0]));
        assert!(!is_png(&[]));
    }

    #[test]
    fn invalid_data() {
        let mut decoder = ImageDecoder::new();
        let mut tensor = rgb_tensor();
        let result = decoder.decode_into(b"not an image", &mut tensor, &DecodeOptions::default());
        assert!(matches!(result, Err(CodecError::InvalidData(_))));
    }

    #[test]
    fn empty_input_is_invalid() {
        let mut decoder = ImageDecoder::new();
        let mut tensor = rgb_tensor();
        let result = decoder.decode_into(&[], &mut tensor, &DecodeOptions::default());
        assert!(matches!(result, Err(CodecError::InvalidData(_))));
    }

    #[test]
    fn invalid_data_from_reader() {
        let mut decoder = ImageDecoder::new();
        let mut tensor = rgb_tensor();
        let result = decoder.decode_from_reader(
            &b"not an image"[..],
            &mut tensor,
            &DecodeOptions::default(),
        );
        assert!(matches!(result, Err(CodecError::InvalidData(_))));
    }

    #[test]
    fn peek_info_rejects_unknown_format() {
        let result = peek_info(b"not an image", &DecodeOptions::default());
        assert!(matches!(result, Err(CodecError::InvalidData(_))));
    }
}