use crate::ImageInfo;
use crate::PixelFormat;
use crate::ImageData;
/// A borrowed view of a [`tch::Tensor`] interpreted as an image.
///
/// Constructed by [`TensorAsImage::as_image`] once the tensor shape has been checked
/// against the requested format.
pub struct TensorImage<'a> {
    /// Image metadata (pixel format, width and height) derived from the tensor shape.
    info: ImageInfo,
    /// When `true`, `data()` permutes the tensor axes `[1, 2, 0]`, i.e. from
    /// `(channels, height, width)` to `(height, width, channels)`, before copying it.
    planar: bool,
    /// The tensor that holds the pixel values.
    tensor: &'a tch::Tensor,
}
/// Describes how the dimensions of a tensor map onto an image.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TensorPixelFormat {
    /// Channel-first `(channels, height, width)` layout with a known pixel format.
    /// Single-channel formats also accept a 2-D `(height, width)` tensor.
    Planar(PixelFormat),
    /// Channel-last `(height, width, channels)` layout with a known pixel format.
    /// Single-channel formats also accept a 2-D `(height, width)` tensor.
    Interlaced(PixelFormat),
    /// Infer both layout and pixel format from the tensor shape; the [`ColorFormat`]
    /// picks the channel order for 3- and 4-channel tensors.
    Guess(ColorFormat),
}
/// Channel order assumed when guessing the pixel format of a 3- or 4-channel tensor.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ColorFormat {
    /// Red, green, blue order (RGBA when a fourth channel is present).
    Rgb,
    /// Blue, green, red order (BGRA when a fourth channel is present).
    Bgr,
}
#[allow(clippy::needless_lifetimes)]
pub trait TensorAsImage {
fn as_image<'a>(&'a self, pixel_format: TensorPixelFormat) -> Result<TensorImage<'a>, String>;
fn as_interlaced<'a>(&'a self, pixel_format: PixelFormat) -> Result<TensorImage<'a>, String> {
self.as_image(TensorPixelFormat::Interlaced(pixel_format))
}
fn as_planar<'a>(&'a self, pixel_format: PixelFormat) -> Result<TensorImage<'a>, String> {
self.as_image(TensorPixelFormat::Planar(pixel_format))
}
fn as_image_guess<'a>(&'a self, color_format: ColorFormat) -> Result<TensorImage<'a>, String> {
self.as_image(TensorPixelFormat::Guess(color_format))
}
fn as_image_guess_rgb<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_image_guess(ColorFormat::Rgb)
}
fn as_image_guess_bgr<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_image_guess(ColorFormat::Bgr)
}
fn as_mono8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_interlaced(PixelFormat::Mono8)
}
fn as_interlaced_rgb8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_interlaced(PixelFormat::Rgb8)
}
fn as_interlaced_rgba8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_interlaced(PixelFormat::Rgba8)
}
fn as_interlaced_bgr8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_interlaced(PixelFormat::Bgr8)
}
fn as_interlaced_bgra8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_interlaced(PixelFormat::Bgra8)
}
fn as_planar_rgb8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_planar(PixelFormat::Rgb8)
}
fn as_planar_rgba8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_planar(PixelFormat::Rgba8)
}
fn as_planar_bgr8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_planar(PixelFormat::Bgr8)
}
fn as_planar_bgra8<'a>(&'a self) -> Result<TensorImage<'a>, String> {
self.as_planar(PixelFormat::Bgra8)
}
}
/// Image views over `tch` tensors.
impl TensorAsImage for tch::Tensor {
    /// Validate the tensor shape against `pixel_format` and wrap `self` in a [`TensorImage`].
    ///
    /// Known formats are checked by `tensor_info`; `Guess` is resolved by
    /// `guess_tensor_info`. Any shape error from those helpers is returned unchanged.
    fn as_image(&self, pixel_format: TensorPixelFormat) -> Result<TensorImage<'_>, String> {
        let layout = match pixel_format {
            TensorPixelFormat::Planar(format) => tensor_info(self, format, true),
            TensorPixelFormat::Interlaced(format) => tensor_info(self, format, false),
            TensorPixelFormat::Guess(order) => guess_tensor_info(self, order),
        };
        let (planar, info) = layout?;
        Ok(TensorImage { tensor: self, info, planar })
    }
}
/// Reading pixel data and metadata out of a borrowed tensor view.
impl ImageData for TensorImage<'_> {
    /// Copy the tensor's elements into a byte buffer.
    ///
    /// A planar view is first permuted with axes `[1, 2, 0]`, i.e. from
    /// `(channels, height, width)` to `(height, width, channels)`. Any other view is
    /// copied in its existing element order.
    fn data(self) -> Box<[u8]> {
        let bytes = if self.planar {
            let interlaced = self.tensor.permute(&[1, 2, 0]);
            Vec::<u8>::from(interlaced)
        } else {
            Vec::<u8>::from(self.tensor)
        };
        bytes.into_boxed_slice()
    }

    /// Return a copy of the metadata computed when the view was created; never fails.
    fn info(&self) -> Result<ImageInfo, String> {
        let info = self.info.clone();
        Ok(info)
    }
}
/// Lets the fallible result of a [`TensorAsImage`] conversion be used directly as image data.
impl ImageData for Result<TensorImage<'_>, String> {
    /// Read the pixel data out of the `Ok` view.
    ///
    /// # Panics
    /// Panics if `self` is `Err`.
    fn data(self) -> Box<[u8]> {
        let image = self.expect("ImageData::data called on an Err variant");
        image.data()
    }

    /// Return the view's metadata, or a copy of the conversion error message.
    fn info(&self) -> Result<ImageInfo, String> {
        match self {
            Ok(image) => image.info(),
            Err(message) => Err(message.clone()),
        }
    }
}
/// Validate the shape of `tensor` against a known `pixel_format` and build its [`ImageInfo`].
///
/// When `planar` is `true`, a 3-D tensor must have shape `(channels, height, width)`.
/// Otherwise it must have shape `(height, width, channels)`. For single-channel formats
/// a 2-D tensor of shape `(height, width)` is accepted in either mode.
///
/// Returns `(planar, info)`. The flag is stored in [`TensorImage`] and decides whether
/// `data()` permutes the tensor to `(height, width, channels)` before copying it.
///
/// # Errors
/// Returns a message if the channel axis does not equal `pixel_format.channels()`, or if
/// the tensor's number of dimensions cannot represent `pixel_format`.
fn tensor_info(
    tensor: &tch::Tensor,
    pixel_format: PixelFormat,
    planar: bool,
) -> Result<(bool, ImageInfo), String> {
    let expected_channels = pixel_format.channels();
    let dimensions = tensor.dim();
    if dimensions == 3 {
        let shape = tensor.size3().expect("tensor was checked to have exactly 3 dimensions");
        let (channels, height, width) = if planar { shape } else { (shape.2, shape.0, shape.1) };
        if channels != i64::from(expected_channels) {
            return Err(if planar {
                format!("expected shape ({}, height, width), found {:?}", expected_channels, shape)
            } else {
                format!("expected shape (height, width, {}), found {:?}", expected_channels, shape)
            });
        }
        // Propagate the requested layout: a planar tensor must be permuted to interlaced
        // order by `data()`. Reporting `false` here would hand back planar bytes unchanged.
        Ok((planar, ImageInfo::new(pixel_format, width as usize, height as usize)))
    } else if dimensions == 2 && expected_channels == 1 {
        // A 2-D single-channel tensor has the same element order in both layouts, and
        // `data()` permutes three axes, which a 2-D tensor does not have. So it is
        // always reported as interlaced.
        let (height, width) = tensor.size2().expect("tensor was checked to have exactly 2 dimensions");
        Ok((false, ImageInfo::new(pixel_format, width as usize, height as usize)))
    } else {
        Err(format!("wrong number of dimensions ({}) for format ({:?})", dimensions, pixel_format))
    }
}
/// Infer the image layout and pixel format of `tensor` from its shape alone.
///
/// A 2-D tensor `(height, width)` is always mono8. For a 3-D tensor the following
/// interpretations are tried in order, and the first one that fits wins:
///
/// 1. `(height, width, 1)`: interlaced mono8
/// 2. `(1, height, width)`: mono8, reported as interlaced (a single channel needs no permute)
/// 3. `(height, width, 3)`: interlaced RGB/BGR
/// 4. `(3, height, width)`: planar RGB/BGR
/// 5. `(height, width, 4)`: interlaced RGBA/BGRA
/// 6. `(4, height, width)`: planar RGBA/BGRA
///
/// `color_format` selects the channel order for the 3- and 4-channel cases.
/// Returns `(planar, info)`.
///
/// # Errors
/// Returns a message for any other shape, or for tensors that are neither 2-D nor 3-D.
fn guess_tensor_info(
    tensor: &tch::Tensor,
    color_format: ColorFormat,
) -> Result<(bool, ImageInfo), String> {
    match tensor.dim() {
        2 => {
            let (height, width) = tensor.size2().unwrap();
            Ok((false, ImageInfo::mono8(width as usize, height as usize)))
        }
        3 => {
            let shape = tensor.size3().unwrap();
            let (first, middle, last) = (shape.0 as usize, shape.1 as usize, shape.2 as usize);
            let three_channels = |width: usize, height: usize| match color_format {
                ColorFormat::Rgb => ImageInfo::rgb8(width, height),
                ColorFormat::Bgr => ImageInfo::bgr8(width, height),
            };
            let four_channels = |width: usize, height: usize| match color_format {
                ColorFormat::Rgb => ImageInfo::rgba8(width, height),
                ColorFormat::Bgr => ImageInfo::bgra8(width, height),
            };
            if last == 1 {
                Ok((false, ImageInfo::mono8(middle, first)))
            } else if first == 1 {
                Ok((false, ImageInfo::mono8(last, middle)))
            } else if last == 3 {
                Ok((false, three_channels(middle, first)))
            } else if first == 3 {
                Ok((true, three_channels(last, middle)))
            } else if last == 4 {
                Ok((false, four_channels(middle, first)))
            } else if first == 4 {
                Ok((true, four_channels(last, middle)))
            } else {
                Err(format!("unable to guess pixel format for tensor with shape {:?}, expected (height, width) or (height, width, channels) or (channels, height, width) where channels is either 1, 3 or 4", shape))
            }
        }
        dimensions => Err(format!("unable to guess pixel format for tensor with {} dimensions, expected 2 or 3 dimensions", dimensions)),
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use assert2::assert;

    // Named so it does not shadow the `guess_tensor_info` function imported from `super`.
    #[test]
    fn guess_tensor_info_from_shape() {
        let data = tch::Tensor::of_slice(&(0..120).collect::<Vec<u8>>());
        assert!(data.reshape(&[12, 10, 1]).as_image_guess_bgr().info() == Ok(ImageInfo::mono8(10, 12)));
        assert!(data.reshape(&[1, 12, 10]).as_image_guess_bgr().info() == Ok(ImageInfo::mono8(10, 12)));
        assert!(data.reshape(&[12, 10]).as_image_guess_bgr().info() == Ok(ImageInfo::mono8(10, 12)));
        assert!(data.reshape(&[8, 5, 3]).as_image_guess_rgb().info() == Ok(ImageInfo::rgb8(5, 8)));
        assert!(data.reshape(&[8, 5, 3]).as_image_guess_bgr().info() == Ok(ImageInfo::bgr8(5, 8)));
        assert!(data.reshape(&[5, 6, 4]).as_image_guess_rgb().info() == Ok(ImageInfo::rgba8(6, 5)));
        assert!(data.reshape(&[5, 6, 4]).as_image_guess_bgr().info() == Ok(ImageInfo::bgra8(6, 5)));
        assert!(data.reshape(&[3, 8, 5]).as_image_guess_rgb().info() == Ok(ImageInfo::rgb8(5, 8)));
        assert!(data.reshape(&[3, 8, 5]).as_image_guess_bgr().info() == Ok(ImageInfo::bgr8(5, 8)));
        assert!(data.reshape(&[4, 5, 6]).as_image_guess_rgb().info() == Ok(ImageInfo::rgba8(6, 5)));
        assert!(data.reshape(&[4, 5, 6]).as_image_guess_bgr().info() == Ok(ImageInfo::bgra8(6, 5)));
        assert!(let Err(_) = data.reshape(&[120]).as_image_guess_rgb().info());
        assert!(let Err(_) = data.reshape(&[2, 10, 6]).as_image_guess_rgb().info());
        assert!(let Err(_) = data.reshape(&[6, 10, 2]).as_image_guess_rgb().info());
        assert!(let Err(_) = data.reshape(&[8, 5, 3, 1]).as_image_guess_rgb().info());
        assert!(let Err(_) = data.reshape(&[4, 5, 6, 1]).as_image_guess_rgb().info());
    }

    #[test]
    fn tensor_info_interlaced_with_known_format() {
        let data = tch::Tensor::of_slice(&(0..60).collect::<Vec<u8>>());
        assert!(data.reshape(&[12, 5, 1]).as_mono8().info() == Ok(ImageInfo::mono8(5, 12)));
        assert!(data.reshape(&[12, 5]).as_mono8().info() == Ok(ImageInfo::mono8(5, 12)));
        assert!(let Err(_) = data.reshape(&[12, 5, 1, 1]).as_mono8().info());
        assert!(let Err(_) = data.reshape(&[6, 5, 2]).as_mono8().info());
        assert!(let Err(_) = data.reshape(&[3, 5, 4]).as_mono8().info());
        assert!(let Err(_) = data.reshape(&[4, 5, 3]).as_mono8().info());
        assert!(let Err(_) = data.reshape(&[60]).as_mono8().info());

        assert!(data.reshape(&[4, 5, 3]).as_interlaced_rgb8().info() == Ok(ImageInfo::rgb8(5, 4)));
        assert!(data.reshape(&[4, 5, 3]).as_interlaced_bgr8().info() == Ok(ImageInfo::bgr8(5, 4)));
        // Each rejection is checked for both channel orders (previously the same line was repeated).
        assert!(let Err(_) = data.reshape(&[4, 5, 3, 1]).as_interlaced_rgb8().info());
        assert!(let Err(_) = data.reshape(&[4, 5, 3, 1]).as_interlaced_bgr8().info());
        assert!(let Err(_) = data.reshape(&[3, 5, 4]).as_interlaced_rgb8().info());
        assert!(let Err(_) = data.reshape(&[3, 5, 4]).as_interlaced_bgr8().info());
        assert!(let Err(_) = data.reshape(&[15, 4]).as_interlaced_rgb8().info());
        assert!(let Err(_) = data.reshape(&[15, 4]).as_interlaced_bgr8().info());

        assert!(data.reshape(&[3, 5, 4]).as_interlaced_rgba8().info() == Ok(ImageInfo::rgba8(5, 3)));
        assert!(data.reshape(&[3, 5, 4]).as_interlaced_bgra8().info() == Ok(ImageInfo::bgra8(5, 3)));
        assert!(let Err(_) = data.reshape(&[3, 5, 4, 1]).as_interlaced_rgba8().info());
        assert!(let Err(_) = data.reshape(&[3, 5, 4, 1]).as_interlaced_bgra8().info());
        assert!(let Err(_) = data.reshape(&[4, 5, 3]).as_interlaced_rgba8().info());
        assert!(let Err(_) = data.reshape(&[4, 5, 3]).as_interlaced_bgra8().info());
        assert!(let Err(_) = data.reshape(&[15, 4]).as_interlaced_rgba8().info());
        assert!(let Err(_) = data.reshape(&[15, 4]).as_interlaced_bgra8().info());
    }

    #[test]
    fn tensor_info_planar_with_known_format() {
        let data = tch::Tensor::of_slice(&(0..60).collect::<Vec<u8>>());
        assert!(data.reshape(&[3, 4, 5]).as_planar_rgb8().info() == Ok(ImageInfo::rgb8(5, 4)));
        assert!(data.reshape(&[3, 4, 5]).as_planar_bgr8().info() == Ok(ImageInfo::bgr8(5, 4)));
        // Each rejection is checked for both channel orders (previously the same line was repeated).
        assert!(let Err(_) = data.reshape(&[4, 5, 3, 1]).as_planar_rgb8().info());
        assert!(let Err(_) = data.reshape(&[4, 5, 3, 1]).as_planar_bgr8().info());
        assert!(let Err(_) = data.reshape(&[4, 5, 3]).as_planar_rgb8().info());
        assert!(let Err(_) = data.reshape(&[4, 5, 3]).as_planar_bgr8().info());
        assert!(let Err(_) = data.reshape(&[15, 4]).as_planar_rgb8().info());
        assert!(let Err(_) = data.reshape(&[15, 4]).as_planar_bgr8().info());

        assert!(data.reshape(&[4, 3, 5]).as_planar_rgba8().info() == Ok(ImageInfo::rgba8(5, 3)));
        assert!(data.reshape(&[4, 3, 5]).as_planar_bgra8().info() == Ok(ImageInfo::bgra8(5, 3)));
        assert!(let Err(_) = data.reshape(&[3, 5, 4, 1]).as_planar_rgba8().info());
        assert!(let Err(_) = data.reshape(&[3, 5, 4, 1]).as_planar_bgra8().info());
        assert!(let Err(_) = data.reshape(&[3, 5, 4]).as_planar_rgba8().info());
        assert!(let Err(_) = data.reshape(&[3, 5, 4]).as_planar_bgra8().info());
        assert!(let Err(_) = data.reshape(&[15, 4]).as_planar_rgba8().info());
        assert!(let Err(_) = data.reshape(&[15, 4]).as_planar_bgra8().info());
    }

    #[test]
    fn interlaced_data_is_copied_in_order() {
        let data = tch::Tensor::of_slice(&(0..60).collect::<Vec<u8>>());
        let image = data.reshape(&[4, 5, 3]);
        let bytes = image.as_interlaced_rgb8().data();
        assert!(bytes.into_vec() == (0..60).collect::<Vec<u8>>());
    }

    #[test]
    fn guessed_planar_data_is_interlaced() {
        let data = tch::Tensor::of_slice(&(0..60).collect::<Vec<u8>>());
        let planar = data.reshape(&[3, 4, 5]);
        let bytes = planar.as_image_guess_rgb().data();
        // Planar element (c, y, x) holds c * 20 + y * 5 + x; after interlacing it must
        // sit at index (y * 5 + x) * 3 + c.
        let expected: Vec<u8> = (0..4u8)
            .flat_map(|y| (0..5u8).flat_map(move |x| (0..3u8).map(move |c| c * 20 + y * 5 + x)))
            .collect();
        assert!(bytes.into_vec() == expected);
    }
}