//! Image-related utilities: converting tensor buffers and encoded images
//! into blobs of raw pixel data plus the metadata needed to interpret them.

use arrow::buffer::ScalarBuffer;
use smallvec::{smallvec, SmallVec};

use crate::datatypes::{Blob, ChannelDatatype, TensorBuffer, TensorData};

#[cfg(feature = "image")]
use crate::datatypes::ImageFormat;

/// The kind of image: color, depth, or segmentation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ImageKind {
    /// A normal grayscale or color image.
    Color,

    /// A depth map: each pixel encodes a distance.
    Depth,

    /// A segmentation map: each pixel encodes a class id.
    Segmentation,
}

/// Error returned when an [`image::ColorType`] cannot be mapped to one of our
/// supported channel datatypes.
#[cfg(feature = "image")]
#[derive(thiserror::Error, Clone, Debug)]
pub enum ImageConversionError {
    #[error(
        "Unsupported color type: {0:?}. We support 8-bit, 16-bit, and f32 images, and RGB, RGBA, Luminance, and Luminance-Alpha."
    )]
    UnsupportedImageColorType(image::ColorType),
}

/// Errors that can occur when loading an encoded image (e.g. PNG, JPEG, TIFF).
#[cfg(feature = "image")]
#[derive(thiserror::Error, Clone, Debug)]
pub enum ImageLoadError {
    /// The `image` crate failed to decode the image.
    #[error(transparent)]
    Image(std::sync::Arc<image::ImageError>),

    /// The `tiff` crate failed to decode the image.
    #[error(transparent)]
    Tiff(std::sync::Arc<tiff::TiffError>),

    /// Reading the source bytes failed.
    #[error("Failed to load file: {0}")]
    ReadError(std::sync::Arc<std::io::Error>),

    /// The decoded image could not be converted to one of our image types.
    #[error(transparent)]
    ImageConversionError(#[from] ImageConversionError),

    /// The given MIME type is not supported for images.
    #[error("MIME type '{0}' is not supported for images")]
    UnsupportedMimeType(String),

    /// The MIME type could not be inferred from the image contents.
    #[error("Could not detect MIME type from the image contents")]
    UnrecognizedMimeType,
}

#[cfg(feature = "image")]
impl From<image::ImageError> for ImageLoadError {
    #[inline]
    fn from(err: image::ImageError) -> Self {
        Self::Image(std::sync::Arc::new(err))
    }
}

#[cfg(feature = "image")]
impl From<tiff::TiffError> for ImageLoadError {
    #[inline]
    fn from(err: tiff::TiffError) -> Self {
        Self::Tiff(std::sync::Arc::new(err))
    }
}

#[cfg(feature = "image")]
impl From<std::io::Error> for ImageLoadError {
    #[inline]
    fn from(err: std::io::Error) -> Self {
        Self::ReadError(std::sync::Arc::new(err))
    }
}

/// Errors that can occur when constructing an image from user-provided data.
#[derive(thiserror::Error, Clone, Debug)]
pub enum ImageConstructionError<T: TryInto<TensorData>>
where
    T::Error: std::error::Error,
{
    /// The source data could not be converted to [`TensorData`].
    #[error("Could not convert source to TensorData: {0}")]
    TensorDataConversion(T::Error),

    /// The tensor's shape cannot be interpreted as an image.
    #[error("Could not create Image from TensorData with shape {0:?}")]
    BadImageShape(ScalarBuffer<u64>),

    /// Chroma downsampling only makes sense for color images.
    #[error(
        "Chroma downsampling is not supported for this image type (e.g. DepthImage or SegmentationImage)"
    )]
    ChromaDownsamplingNotSupported,
}

/// Converts a [`TensorBuffer`] into a [`Blob`] of raw bytes together with the
/// [`ChannelDatatype`] describing how those bytes should be reinterpreted.
pub fn blob_and_datatype_from_tensor(tensor_buffer: TensorBuffer) -> (Blob, ChannelDatatype) {
    match tensor_buffer {
        TensorBuffer::U8(buffer) => (Blob(buffer), ChannelDatatype::U8),
        TensorBuffer::U16(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::U16),
        TensorBuffer::U32(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::U32),
        TensorBuffer::U64(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::U64),
        TensorBuffer::I8(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::I8),
        TensorBuffer::I16(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::I16),
        TensorBuffer::I32(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::I32),
        TensorBuffer::I64(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::I64),
        TensorBuffer::F16(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::F16),
        TensorBuffer::F32(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::F32),
        TensorBuffer::F64(buffer) => (Blob(cast_to_u8(&buffer)), ChannelDatatype::F64),
    }
}
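
// A minimal sanity-check sketch of the tensor → blob conversion, assuming
// `TensorBuffer::U16` wraps a `ScalarBuffer<u16>` (as the `cast_to_u8` calls
// above imply): the bytes pass through unchanged, tagged with their datatype.
#[test]
fn test_blob_and_datatype_from_tensor() {
    let buffer = TensorBuffer::U16(vec![1_u16, 2, 3].into());
    let (blob, datatype) = blob_and_datatype_from_tensor(buffer);
    assert!(matches!(datatype, ChannelDatatype::U16));
    assert_eq!(blob.0.len(), 6); // 3 elements × 2 bytes each
}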

/// Reinterprets the bytes of a typed Arrow buffer as a byte buffer, without copying.
///
/// The element count of the result is the byte length of the input buffer.
#[inline]
pub fn cast_to_u8<T: arrow::datatypes::ArrowNativeType>(
    buffer: &arrow::buffer::ScalarBuffer<T>,
) -> ScalarBuffer<u8> {
    arrow::buffer::ScalarBuffer::new(buffer.inner().clone(), 0, buffer.inner().len())
}
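
// Quick check of `cast_to_u8`: a single `u32` reinterprets as four bytes in
// native endianness, so only the length is asserted here.
#[test]
fn test_cast_to_u8() {
    let typed: ScalarBuffer<u32> = vec![0x0102_0304_u32].into();
    let bytes = cast_to_u8(&typed);
    assert_eq!(bytes.len(), 4);
}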

/// Rust types that can be used as image channel values, together with the
/// corresponding [`ChannelDatatype`] tag.
pub trait ImageChannelType: bytemuck::Pod {
    /// The [`ChannelDatatype`] corresponding to `Self`.
    const CHANNEL_TYPE: ChannelDatatype;
}

impl ImageChannelType for u8 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::U8;
}

impl ImageChannelType for u16 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::U16;
}

impl ImageChannelType for u32 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::U32;
}

impl ImageChannelType for u64 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::U64;
}

impl ImageChannelType for i8 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::I8;
}

impl ImageChannelType for i16 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::I16;
}

impl ImageChannelType for i32 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::I32;
}

impl ImageChannelType for i64 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::I64;
}

impl ImageChannelType for half::f16 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::F16;
}

impl ImageChannelType for f32 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::F32;
}

impl ImageChannelType for f64 {
    const CHANNEL_TYPE: ChannelDatatype = ChannelDatatype::F64;
}
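
// A sketch of how `ImageChannelType` enables generic image helpers. This
// `pixels_to_blob` function is illustrative only, not part of the crate's API.
#[allow(dead_code)]
fn pixels_to_blob<T: ImageChannelType>(pixels: &[T]) -> (Blob, ChannelDatatype) {
    // `bytemuck::Pod` (a supertrait of `ImageChannelType`) makes the byte cast safe.
    let bytes: Vec<u8> = bytemuck::cast_slice(pixels).to_vec();
    (Blob(bytes.into()), T::CHANNEL_TYPE)
}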

/// Heuristically finds which dimensions of a tensor shape hold the image data,
/// i.e. height, width, and optionally channel.
///
/// Unit dimensions are stripped from the ends, and a trailing dimension of
/// size 3 or 4 is preferred as an RGB(A) channel. For example,
/// `[1, 480, 640, 3, 1]` yields `[1, 2, 3]`.
pub fn find_non_empty_dim_indices(shape: &[u64]) -> SmallVec<[usize; 4]> {
    match shape.len() {
        0 => return smallvec![],
        1 => return smallvec![0],
        2 => return smallvec![0, 1],
        _ => {}
    }

    // The span of non-unit dimensions:
    let mut non_unit_indices =
        shape
            .iter()
            .enumerate()
            .filter_map(|(ind, &dim)| if dim != 1 { Some(ind) } else { None });
    let mut min = non_unit_indices.next().unwrap_or(0);
    let mut max = non_unit_indices.last().unwrap_or(min);

    // If at most one dimension is non-unit, extend the span to the right…
    while max == min && max + 1 < shape.len() {
        max += 1;
    }

    // …then extend it to the left until it covers two dims (height, width),
    // or three if the last covered dim looks like an RGB(A) channel count.
    let target_len = match shape[max] {
        3 | 4 => 3,
        _ => 2,
    };
    while max - min + 1 < target_len && 0 < min {
        min -= 1;
    }

    (min..=max).collect()
}

#[test]
fn test_find_non_empty_dim_indices() {
    fn expect(shape: &[u64], expected: &[usize]) {
        let got = find_non_empty_dim_indices(shape);
        assert!(
            got.as_slice() == expected,
            "Input: {shape:?}, got {got:?}, expected {expected:?}"
        );
    }

    expect(&[], &[]);
    expect(&[0], &[0]);
    expect(&[1], &[0]);
    expect(&[100], &[0]);

    expect(&[480, 640], &[0, 1]);
    expect(&[480, 640, 1], &[0, 1]);
    expect(&[480, 640, 1, 1], &[0, 1]);
    expect(&[480, 640, 3], &[0, 1, 2]);

    expect(&[1, 480, 640], &[1, 2]);
    expect(&[1, 480, 640, 3, 1], &[1, 2, 3]);
    expect(&[1, 3, 480, 640, 1], &[1, 2, 3]);

    expect(&[1, 1, 480, 640], &[2, 3]);
    expect(&[1, 1, 480, 640, 1, 1], &[2, 3]);
    expect(&[1, 1, 3], &[0, 1, 2]);
    expect(&[1, 1, 3, 1], &[2, 3]);
}

/// The YUV→RGB matrix coefficients to use for color conversion.
#[derive(Clone, Copy, Debug)]
pub enum YuvMatrixCoefficients {
    /// BT.601, used for standard-definition video.
    Bt601,

    /// BT.709, used for high-definition video.
    Bt709,
}

/// Converts a single YUV pixel to RGB.
///
/// If `limited_range` is true, Y is assumed to be in `[16, 235]` and U/V in
/// `[16, 240]` ("studio swing"); otherwise all channels use the full `[0, 255]` range.
pub fn rgb_from_yuv(
    y: u8,
    u: u8,
    v: u8,
    limited_range: bool,
    coefficients: YuvMatrixCoefficients,
) -> [u8; 3] {
    let (mut y, mut u, mut v) = (y as f32, u as f32, v as f32);

    if limited_range {
        // Rescale limited range to Y in [0, 1] and U/V in [-0.5, 0.5].
        y = (y - 16.0) / 219.0;
        u = (u - 128.0) / 224.0;
        v = (v - 128.0) / 224.0;
    } else {
        // Full range: Y in [0, 1], U/V centered around zero.
        y /= 255.0;
        u = (u - 128.0) / 255.0;
        v = (v - 128.0) / 255.0;
    }

    let r;
    let g;
    let b;
    match coefficients {
        YuvMatrixCoefficients::Bt601 => {
            r = y + 1.402 * v;
            g = y - 0.344 * u - 0.714 * v;
            b = y + 1.772 * u;
        }
        YuvMatrixCoefficients::Bt709 => {
            r = y + 1.575 * v;
            g = y - 0.187 * u - 0.468 * v;
            b = y + 1.856 * u;
        }
    }

    // `as u8` saturates on out-of-range values, clamping the result to [0, 255].
    [(255.0 * r) as u8, (255.0 * g) as u8, (255.0 * b) as u8]
}
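
// Spot checks for `rgb_from_yuv` at the limited-range extremes: black (Y=16)
// and white (Y=235) with neutral chroma (U=V=128). The scaling above maps
// these exactly to 0.0 and 1.0, so the expected values are exact for both matrices.
#[test]
fn test_rgb_from_yuv_limited_range_extremes() {
    for coefficients in [YuvMatrixCoefficients::Bt601, YuvMatrixCoefficients::Bt709] {
        assert_eq!(rgb_from_yuv(16, 128, 128, true, coefficients), [0, 0, 0]);
        assert_eq!(rgb_from_yuv(235, 128, 128, true, coefficients), [255, 255, 255]);
    }
}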

/// Decodes a TIFF image from in-memory `bytes`, returning the raw pixel data
/// as a [`Blob`] together with the [`ImageFormat`] describing it.
#[cfg(feature = "image")]
pub fn blob_and_format_from_tiff(bytes: &[u8]) -> Result<(Blob, ImageFormat), ImageLoadError> {
    use tiff::decoder::DecodingResult;

    let cursor = std::io::Cursor::new(bytes);
    let mut decoder = tiff::decoder::Decoder::new(cursor)?;
    let img = decoder.read_image()?;

    // Reinterpret the decoded pixels as raw bytes and record their datatype.
    let (pixel_bytes, data_type): (&[u8], ChannelDatatype) = match &img {
        DecodingResult::U8(data) => (bytemuck::cast_slice(data), ChannelDatatype::U8),
        DecodingResult::U16(data) => (bytemuck::cast_slice(data), ChannelDatatype::U16),
        DecodingResult::U32(data) => (bytemuck::cast_slice(data), ChannelDatatype::U32),
        DecodingResult::U64(data) => (bytemuck::cast_slice(data), ChannelDatatype::U64),
        DecodingResult::F32(data) => (bytemuck::cast_slice(data), ChannelDatatype::F32),
        DecodingResult::F64(data) => (bytemuck::cast_slice(data), ChannelDatatype::F64),
        DecodingResult::I8(data) => (bytemuck::cast_slice(data), ChannelDatatype::I8),
        DecodingResult::I16(data) => (bytemuck::cast_slice(data), ChannelDatatype::I16),
        DecodingResult::I32(data) => (bytemuck::cast_slice(data), ChannelDatatype::I32),
        DecodingResult::I64(data) => (bytemuck::cast_slice(data), ChannelDatatype::I64),
    };

    let (width, height) = decoder.dimensions()?;
    let image_format = ImageFormat {
        width,
        height,
        channel_datatype: Some(data_type),
        pixel_format: None,
        color_model: None,
    };

    Ok((Blob::from(pixel_bytes), image_format))
}
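
// A hedged usage sketch: read a TIFF from disk (the path here is hypothetical)
// and decode it into a blob plus format. `std::fs::read` errors convert into
// `ImageLoadError::ReadError` via the `From<std::io::Error>` impl above.
#[cfg(feature = "image")]
#[allow(dead_code)]
fn example_load_tiff() -> Result<(Blob, ImageFormat), ImageLoadError> {
    let bytes = std::fs::read("image.tiff")?; // hypothetical path
    blob_and_format_from_tiff(&bytes)
}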