use crate::{DType, PixelFormat, Tensor, TensorMemory, TensorTrait};
use half::f16;
use std::fmt;
/// Type-erased tensor: wraps a typed `Tensor<T>` in one variant per supported
/// element dtype, so tensors of mixed dtypes can be stored and handled through
/// a single runtime-dispatched type.
///
/// `#[non_exhaustive]` so that new dtypes can be added without breaking
/// downstream `match`es.
#[non_exhaustive]
pub enum TensorDyn {
U8(Tensor<u8>),
I8(Tensor<i8>),
U16(Tensor<u16>),
I16(Tensor<i16>),
U32(Tensor<u32>),
I32(Tensor<i32>),
U64(Tensor<u64>),
I64(Tensor<i64>),
F16(Tensor<f16>),
F32(Tensor<f32>),
F64(Tensor<f64>),
}
// Forwards `$self.$method($arg, ...)` to the typed `Tensor<T>` held in
// whichever variant is active. Usable for any method whose signature is
// identical across all element types (quantization methods differ for float
// variants and are matched by hand in the impl below).
macro_rules! dispatch {
($self:expr, $method:ident $(, $arg:expr)*) => {
match $self {
TensorDyn::U8(t) => t.$method($($arg),*),
TensorDyn::I8(t) => t.$method($($arg),*),
TensorDyn::U16(t) => t.$method($($arg),*),
TensorDyn::I16(t) => t.$method($($arg),*),
TensorDyn::U32(t) => t.$method($($arg),*),
TensorDyn::I32(t) => t.$method($($arg),*),
TensorDyn::U64(t) => t.$method($($arg),*),
TensorDyn::I64(t) => t.$method($($arg),*),
TensorDyn::F16(t) => t.$method($($arg),*),
TensorDyn::F32(t) => t.$method($($arg),*),
TensorDyn::F64(t) => t.$method($($arg),*),
}
};
}
// Generates the three downcast accessors for one variant:
// - `$as_name`: borrow as `&Tensor<$ty>` when the variant matches, else `None`.
// - `$as_mut_name`: same, but a mutable borrow.
// - `$into_name`: consuming downcast; returns ownership back as `Err(self)`
//   on a dtype mismatch so the caller loses nothing.
macro_rules! downcast_methods {
($variant:ident, $ty:ty, $as_name:ident, $as_mut_name:ident, $into_name:ident) => {
pub fn $as_name(&self) -> Option<&Tensor<$ty>> {
match self {
Self::$variant(t) => Some(t),
_ => None,
}
}
pub fn $as_mut_name(&mut self) -> Option<&mut Tensor<$ty>> {
match self {
Self::$variant(t) => Some(t),
_ => None,
}
}
// `TensorDyn` is large, so `Err(Self)` trips `result_large_err`; returning
// ownership on mismatch is the point of this API, hence the allow.
#[allow(clippy::result_large_err)]
pub fn $into_name(self) -> Result<Tensor<$ty>, Self> {
match self {
Self::$variant(t) => Ok(t),
other => Err(other),
}
}
};
}
impl TensorDyn {
/// Element dtype of the wrapped tensor, derived from the active variant.
pub fn dtype(&self) -> DType {
match self {
Self::U8(_) => DType::U8,
Self::I8(_) => DType::I8,
Self::U16(_) => DType::U16,
Self::I16(_) => DType::I16,
Self::U32(_) => DType::U32,
Self::I32(_) => DType::I32,
Self::U64(_) => DType::U64,
Self::I64(_) => DType::I64,
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
}
}
/// Dimension sizes of the tensor; delegates to the typed tensor.
pub fn shape(&self) -> &[usize] {
dispatch!(self, shape)
}
/// Tensor name; delegates to the typed tensor.
pub fn name(&self) -> String {
dispatch!(self, name)
}
/// Pixel format, if this tensor has been tagged as an image.
pub fn format(&self) -> Option<PixelFormat> {
dispatch!(self, format)
}
/// Image width in pixels, when a pixel format is set.
pub fn width(&self) -> Option<usize> {
dispatch!(self, width)
}
/// Image height in pixels, when a pixel format is set.
pub fn height(&self) -> Option<usize> {
dispatch!(self, height)
}
/// Size of the tensor; delegates to `Tensor::size`.
/// NOTE(review): whether this counts elements or bytes is decided by the
/// typed implementation — confirm there before relying on it.
pub fn size(&self) -> usize {
dispatch!(self, size)
}
/// Kind of memory backing the tensor (e.g. plain memory vs DMA).
pub fn memory(&self) -> TensorMemory {
dispatch!(self, memory)
}
/// Reshapes in place. Per the tests below, a successful reshape also
/// clears image metadata (format, row stride, plane offset).
pub fn reshape(&mut self, shape: &[usize]) -> crate::Result<()> {
dispatch!(self, reshape, shape)
}
/// Tags the tensor as an image with `format`; fails when the current shape
/// is incompatible with the format (see the tests below).
pub fn set_format(&mut self, format: PixelFormat) -> crate::Result<()> {
dispatch!(self, set_format, format)
}
/// Builder-style variant of [`set_format`](Self::set_format).
pub fn with_format(mut self, format: PixelFormat) -> crate::Result<Self> {
self.set_format(format)?;
Ok(self)
}
/// Explicitly configured row stride in bytes, if any.
pub fn row_stride(&self) -> Option<usize> {
dispatch!(self, row_stride)
}
/// Row stride actually in effect: the explicit stride if set, otherwise
/// the packed minimum derived from width and format (see tests).
pub fn effective_row_stride(&self) -> Option<usize> {
dispatch!(self, effective_row_stride)
}
/// Sets the row stride in bytes. Per the tests, this requires a format to
/// be set and rejects strides smaller than the packed row size.
pub fn set_row_stride(&mut self, stride: usize) -> crate::Result<()> {
dispatch!(self, set_row_stride, stride)
}
/// Builder-style variant of [`set_row_stride`](Self::set_row_stride).
pub fn with_row_stride(mut self, stride: usize) -> crate::Result<Self> {
self.set_row_stride(stride)?;
Ok(self)
}
/// Byte offset of the tensor data within its backing buffer, if set.
pub fn plane_offset(&self) -> Option<usize> {
dispatch!(self, plane_offset)
}
/// Sets the plane byte offset; infallible and, per the tests, allowed even
/// without a pixel format.
pub fn set_plane_offset(&mut self, offset: usize) {
dispatch!(self, set_plane_offset, offset)
}
/// Builder-style variant of [`set_plane_offset`](Self::set_plane_offset).
pub fn with_plane_offset(mut self, offset: usize) -> Self {
self.set_plane_offset(offset);
self
}
/// Quantization parameters. Only integer dtypes carry quantization; float
/// variants always return `None`.
pub fn quantization(&self) -> Option<&crate::Quantization> {
match self {
Self::U8(t) => t.quantization(),
Self::I8(t) => t.quantization(),
Self::U16(t) => t.quantization(),
Self::I16(t) => t.quantization(),
Self::U32(t) => t.quantization(),
Self::I32(t) => t.quantization(),
Self::U64(t) => t.quantization(),
Self::I64(t) => t.quantization(),
Self::F16(_) | Self::F32(_) | Self::F64(_) => None,
}
}
/// Attaches quantization parameters; returns `QuantizationInvalid` for
/// float dtypes, which cannot be quantized.
pub fn set_quantization(&mut self, q: crate::Quantization) -> crate::Result<()> {
match self {
Self::U8(t) => t.set_quantization(q),
Self::I8(t) => t.set_quantization(q),
Self::U16(t) => t.set_quantization(q),
Self::I16(t) => t.set_quantization(q),
Self::U32(t) => t.set_quantization(q),
Self::I32(t) => t.set_quantization(q),
Self::U64(t) => t.set_quantization(q),
Self::I64(t) => t.set_quantization(q),
Self::F16(_) | Self::F32(_) | Self::F64(_) => Err(crate::Error::QuantizationInvalid {
field: "dtype_is_integer",
expected: "integer tensor dtype (u8/i8/u16/i16/u32/i32/u64/i64)".to_string(),
got: format!("{:?}", self.dtype()),
}),
}
}
/// Builder-style variant of [`set_quantization`](Self::set_quantization).
pub fn with_quantization(mut self, q: crate::Quantization) -> crate::Result<Self> {
self.set_quantization(q)?;
Ok(self)
}
/// Removes any quantization parameters; a no-op for float variants.
pub fn clear_quantization(&mut self) {
match self {
Self::U8(t) => t.clear_quantization(),
Self::I8(t) => t.clear_quantization(),
Self::U16(t) => t.clear_quantization(),
Self::I16(t) => t.clear_quantization(),
Self::U32(t) => t.clear_quantization(),
Self::I32(t) => t.clear_quantization(),
Self::U64(t) => t.clear_quantization(),
Self::I64(t) => t.clear_quantization(),
Self::F16(_) | Self::F32(_) | Self::F64(_) => {}
}
}
/// Duplicates the tensor's underlying file descriptor (Unix only).
#[cfg(unix)]
pub fn clone_fd(&self) -> crate::Result<std::os::fd::OwnedFd> {
dispatch!(self, clone_fd)
}
/// Like [`clone_fd`](Self::clone_fd) but insists the tensor is DMA-backed,
/// so the duplicated fd is a dma-buf (Linux only).
#[cfg(target_os = "linux")]
pub fn dmabuf_clone(&self) -> crate::Result<std::os::fd::OwnedFd> {
if self.memory() != TensorMemory::Dma {
return Err(crate::Error::NotImplemented(format!(
"dmabuf_clone requires DMA-backed tensor, got {:?}",
self.memory()
)));
}
self.clone_fd()
}
/// Borrows the dma-buf fd without duplicating it (Linux only); per the
/// tests, errors for non-DMA tensors.
#[cfg(target_os = "linux")]
pub fn dmabuf(&self) -> crate::Result<std::os::fd::BorrowedFd<'_>> {
dispatch!(self, dmabuf)
}
/// Whether the image data is laid out in multiple planes.
pub fn is_multiplane(&self) -> bool {
dispatch!(self, is_multiplane)
}
/// Identity token of the backing buffer, used by [`aliases`](Self::aliases).
pub fn buffer_identity(&self) -> &crate::BufferIdentity {
dispatch!(self, buffer_identity)
}
/// Best-effort check whether `self` and `other` share backing storage:
/// first by buffer identity, then — for Linux DMA tensors — by raw fd
/// equality. NOTE(review): distinct fds can still refer to the same
/// dma-buf, so `false` is not a guarantee of non-aliasing.
pub fn aliases(&self, other: &Self) -> bool {
if self.buffer_identity().id() == other.buffer_identity().id() {
return true;
}
// Tensors in different memory kinds are treated as never aliasing.
if self.memory() != other.memory() {
return false;
}
#[cfg(target_os = "linux")]
if self.memory() == TensorMemory::Dma {
use std::os::fd::AsRawFd;
if let (Ok(a), Ok(b)) = (self.dmabuf(), other.dmabuf()) {
return a.as_raw_fd() == b.as_raw_fd();
}
}
false
}
// Typed downcast accessors (as_*/as_*_mut/into_*), one trio per variant.
downcast_methods!(U8, u8, as_u8, as_u8_mut, into_u8);
downcast_methods!(I8, i8, as_i8, as_i8_mut, into_i8);
downcast_methods!(U16, u16, as_u16, as_u16_mut, into_u16);
downcast_methods!(I16, i16, as_i16, as_i16_mut, into_i16);
downcast_methods!(U32, u32, as_u32, as_u32_mut, into_u32);
downcast_methods!(I32, i32, as_i32, as_i32_mut, into_i32);
downcast_methods!(U64, u64, as_u64, as_u64_mut, into_u64);
downcast_methods!(I64, i64, as_i64, as_i64_mut, into_i64);
downcast_methods!(F16, f16, as_f16, as_f16_mut, into_f16);
downcast_methods!(F32, f32, as_f32, as_f32_mut, into_f32);
downcast_methods!(F64, f64, as_f64, as_f64_mut, into_f64);
/// Allocates a new tensor with the given shape and runtime `dtype`,
/// optionally choosing the memory kind and a name.
pub fn new(
shape: &[usize],
dtype: DType,
memory: Option<TensorMemory>,
name: Option<&str>,
) -> crate::Result<Self> {
match dtype {
DType::U8 => Tensor::<u8>::new(shape, memory, name).map(Self::U8),
DType::I8 => Tensor::<i8>::new(shape, memory, name).map(Self::I8),
DType::U16 => Tensor::<u16>::new(shape, memory, name).map(Self::U16),
DType::I16 => Tensor::<i16>::new(shape, memory, name).map(Self::I16),
DType::U32 => Tensor::<u32>::new(shape, memory, name).map(Self::U32),
DType::I32 => Tensor::<i32>::new(shape, memory, name).map(Self::I32),
DType::U64 => Tensor::<u64>::new(shape, memory, name).map(Self::U64),
DType::I64 => Tensor::<i64>::new(shape, memory, name).map(Self::I64),
DType::F16 => Tensor::<f16>::new(shape, memory, name).map(Self::F16),
DType::F32 => Tensor::<f32>::new(shape, memory, name).map(Self::F32),
DType::F64 => Tensor::<f64>::new(shape, memory, name).map(Self::F64),
}
}
/// Wraps an existing file descriptor as a tensor of the given shape and
/// runtime `dtype` (Unix only); takes ownership of `fd`.
#[cfg(unix)]
pub fn from_fd(
fd: std::os::fd::OwnedFd,
shape: &[usize],
dtype: DType,
name: Option<&str>,
) -> crate::Result<Self> {
match dtype {
DType::U8 => Tensor::<u8>::from_fd(fd, shape, name).map(Self::U8),
DType::I8 => Tensor::<i8>::from_fd(fd, shape, name).map(Self::I8),
DType::U16 => Tensor::<u16>::from_fd(fd, shape, name).map(Self::U16),
DType::I16 => Tensor::<i16>::from_fd(fd, shape, name).map(Self::I16),
DType::U32 => Tensor::<u32>::from_fd(fd, shape, name).map(Self::U32),
DType::I32 => Tensor::<i32>::from_fd(fd, shape, name).map(Self::I32),
DType::U64 => Tensor::<u64>::from_fd(fd, shape, name).map(Self::U64),
DType::I64 => Tensor::<i64>::from_fd(fd, shape, name).map(Self::I64),
DType::F16 => Tensor::<f16>::from_fd(fd, shape, name).map(Self::F16),
DType::F32 => Tensor::<f32>::from_fd(fd, shape, name).map(Self::F32),
DType::F64 => Tensor::<f64>::from_fd(fd, shape, name).map(Self::F64),
}
}
/// Allocates an image tensor of `width` x `height` in `format` with the
/// given runtime `dtype`.
pub fn image(
width: usize,
height: usize,
format: PixelFormat,
dtype: DType,
memory: Option<TensorMemory>,
) -> crate::Result<Self> {
match dtype {
DType::U8 => Tensor::<u8>::image(width, height, format, memory).map(Self::U8),
DType::I8 => Tensor::<i8>::image(width, height, format, memory).map(Self::I8),
DType::U16 => Tensor::<u16>::image(width, height, format, memory).map(Self::U16),
DType::I16 => Tensor::<i16>::image(width, height, format, memory).map(Self::I16),
DType::U32 => Tensor::<u32>::image(width, height, format, memory).map(Self::U32),
DType::I32 => Tensor::<i32>::image(width, height, format, memory).map(Self::I32),
DType::U64 => Tensor::<u64>::image(width, height, format, memory).map(Self::U64),
DType::I64 => Tensor::<i64>::image(width, height, format, memory).map(Self::I64),
DType::F16 => Tensor::<f16>::image(width, height, format, memory).map(Self::F16),
DType::F32 => Tensor::<f32>::image(width, height, format, memory).map(Self::F32),
DType::F64 => Tensor::<f64>::image(width, height, format, memory).map(Self::F64),
}
}
/// Like [`image`](Self::image) but with an explicit row stride in bytes.
pub fn image_with_stride(
width: usize,
height: usize,
format: PixelFormat,
dtype: DType,
row_stride_bytes: usize,
memory: Option<TensorMemory>,
) -> crate::Result<Self> {
match dtype {
DType::U8 => {
Tensor::<u8>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::U8)
}
DType::I8 => {
Tensor::<i8>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::I8)
}
DType::U16 => {
Tensor::<u16>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::U16)
}
DType::I16 => {
Tensor::<i16>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::I16)
}
DType::U32 => {
Tensor::<u32>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::U32)
}
DType::I32 => {
Tensor::<i32>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::I32)
}
DType::U64 => {
Tensor::<u64>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::U64)
}
DType::I64 => {
Tensor::<i64>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::I64)
}
DType::F16 => {
Tensor::<f16>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::F16)
}
DType::F32 => {
Tensor::<f32>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::F32)
}
DType::F64 => {
Tensor::<f64>::image_with_stride(width, height, format, row_stride_bytes, memory)
.map(Self::F64)
}
}
}
}
impl From<Tensor<u8>> for TensorDyn {
fn from(t: Tensor<u8>) -> Self {
Self::U8(t)
}
}
impl From<Tensor<i8>> for TensorDyn {
fn from(t: Tensor<i8>) -> Self {
Self::I8(t)
}
}
impl From<Tensor<u16>> for TensorDyn {
fn from(t: Tensor<u16>) -> Self {
Self::U16(t)
}
}
impl From<Tensor<i16>> for TensorDyn {
fn from(t: Tensor<i16>) -> Self {
Self::I16(t)
}
}
impl From<Tensor<u32>> for TensorDyn {
fn from(t: Tensor<u32>) -> Self {
Self::U32(t)
}
}
impl From<Tensor<i32>> for TensorDyn {
fn from(t: Tensor<i32>) -> Self {
Self::I32(t)
}
}
impl From<Tensor<u64>> for TensorDyn {
fn from(t: Tensor<u64>) -> Self {
Self::U64(t)
}
}
impl From<Tensor<i64>> for TensorDyn {
fn from(t: Tensor<i64>) -> Self {
Self::I64(t)
}
}
impl From<Tensor<f16>> for TensorDyn {
fn from(t: Tensor<f16>) -> Self {
Self::F16(t)
}
}
impl From<Tensor<f32>> for TensorDyn {
fn from(t: Tensor<f32>) -> Self {
Self::F32(t)
}
}
impl From<Tensor<f64>> for TensorDyn {
fn from(t: Tensor<f64>) -> Self {
Self::F64(t)
}
}
// `Debug` forwards to the active typed tensor's own `fmt` implementation.
impl fmt::Debug for TensorDyn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
dispatch!(self, fmt, f)
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `TensorDyn`: construction, downcasting, image metadata
    //! (format / row stride / plane offset), quantization-free accessors, and
    //! the DMA-related error paths.
    //
    // Fixes relative to the previous revision:
    // - `set_format_clears_row_stride` asserted the stride is *preserved*
    //   across a rejected and an idempotent `set_format`, and cleared only by
    //   `reshape`; renamed to `reshape_clears_row_stride_and_format`.
    // - Same mismatch for `set_format_clears_plane_offset`; renamed to
    //   `reshape_clears_plane_offset_and_format`.
    // - Normalized one-line `assert…; }` formatting to rustfmt style.
    use super::*;

    #[test]
    fn from_typed_tensor() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.shape(), &[10]);
    }
    #[test]
    fn downcast_ref() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert!(dyn_t.as_u8().is_some());
        assert!(dyn_t.as_i8().is_none());
    }
    #[test]
    fn downcast_into() {
        let t = Tensor::<u8>::new(&[10], None, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        let back = dyn_t.into_u8().unwrap();
        assert_eq!(back.shape(), &[10]);
    }
    #[test]
    fn image_accessors() {
        let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
        let dyn_t: TensorDyn = t.into();
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgba));
        assert_eq!(dyn_t.width(), Some(640));
        assert_eq!(dyn_t.height(), Some(480));
        assert!(!dyn_t.is_multiplane());
    }
    #[test]
    fn image_constructor() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::U8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
        assert_eq!(dyn_t.width(), Some(640));
    }
    #[test]
    fn image_constructor_i8() {
        let dyn_t = TensorDyn::image(640, 480, PixelFormat::Rgb, DType::I8, None).unwrap();
        assert_eq!(dyn_t.dtype(), DType::I8);
        assert_eq!(dyn_t.format(), Some(PixelFormat::Rgb));
    }
    #[test]
    fn set_format_packed() {
        // HWC layout: [480, 640, 3] is compatible with packed RGB.
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        assert_eq!(t.format(), None);
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }
    #[test]
    fn set_format_planar() {
        // CHW layout: [3, 480, 640] is compatible with planar RGB.
        let mut t = TensorDyn::new(&[3, 480, 640], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::PlanarRgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }
    #[test]
    fn set_format_rejects_wrong_shape() {
        // 4-channel shape cannot be tagged as 3-channel RGB.
        let mut t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None).unwrap();
        assert!(t.set_format(PixelFormat::Rgb).is_err());
    }
    #[test]
    fn with_format_builder() {
        let t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
            .unwrap()
            .with_format(PixelFormat::Rgba)
            .unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgba));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }
    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_clone_mem_tensor_fails() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert_eq!(t.memory(), TensorMemory::Mem);
        assert!(t.dmabuf_clone().is_err());
    }
    #[cfg(target_os = "linux")]
    #[test]
    fn dmabuf_mem_tensor_fails() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        assert!(t.dmabuf().is_err());
    }
    #[test]
    fn set_format_semi_planar_nv12() {
        // NV12 packs luma + interleaved chroma: total rows = height * 3 / 2.
        let mut t = TensorDyn::new(&[720, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        t.set_format(PixelFormat::Nv12).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv12));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }
    #[test]
    fn set_format_semi_planar_nv16() {
        // NV16 has full-height chroma: total rows = height * 2.
        let mut t = TensorDyn::new(&[960, 640], DType::U8, Some(TensorMemory::Mem), None).unwrap();
        t.set_format(PixelFormat::Nv16).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Nv16));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }
    #[test]
    fn with_format_rejects_wrong_shape() {
        let result = TensorDyn::new(&[480, 640, 4], DType::U8, None, None)
            .unwrap()
            .with_format(PixelFormat::Rgb);
        assert!(result.is_err());
    }
    #[test]
    fn set_format_preserved_after_rejection() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert!(t.set_format(PixelFormat::Rgba).is_err());
        // The failed call must not clobber the previously valid format.
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
    }
    #[test]
    fn set_format_idempotent() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Rgb));
        assert_eq!(t.width(), Some(640));
        assert_eq!(t.height(), Some(480));
    }
    #[test]
    fn set_row_stride_valid() {
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
        t.set_row_stride(512).unwrap();
        assert_eq!(t.row_stride(), Some(512));
        assert_eq!(t.effective_row_stride(), Some(512));
    }
    #[test]
    fn set_row_stride_equals_min() {
        // 100 px * 3 bytes = 300 is exactly the packed row size.
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        t.set_row_stride(300).unwrap();
        assert_eq!(t.row_stride(), Some(300));
    }
    #[test]
    fn set_row_stride_too_small() {
        // 100 px * 4 bytes = 400 minimum; 300 must be rejected and leave no stride.
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
        assert!(t.set_row_stride(300).is_err());
        assert_eq!(t.row_stride(), None);
    }
    #[test]
    fn set_row_stride_zero() {
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert!(t.set_row_stride(0).is_err());
    }
    #[test]
    fn set_row_stride_requires_format() {
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        assert!(t.set_row_stride(2048).is_err());
    }
    #[test]
    fn effective_row_stride_without_stride() {
        // No explicit stride: effective stride falls back to the packed width.
        let t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        assert_eq!(t.row_stride(), None);
        assert_eq!(t.effective_row_stride(), Some(300));
    }
    #[test]
    fn effective_row_stride_no_format() {
        let t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        assert_eq!(t.effective_row_stride(), None);
    }
    #[test]
    fn with_row_stride_builder() {
        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
            .unwrap()
            .with_row_stride(512)
            .unwrap();
        assert_eq!(t.row_stride(), Some(512));
        assert_eq!(t.effective_row_stride(), Some(512));
    }
    #[test]
    fn with_row_stride_rejects_small() {
        let result = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
            .unwrap()
            .with_row_stride(200);
        assert!(result.is_err());
    }
    #[test]
    fn reshape_clears_row_stride_and_format() {
        // Renamed from `set_format_clears_row_stride`: the body asserts the
        // stride is *preserved* across a rejected and an idempotent
        // `set_format`, and only `reshape` clears stride and format.
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        t.set_row_stride(2048).unwrap();
        assert_eq!(t.row_stride(), Some(2048));
        // Bgra is 4-channel and incompatible with this shape; the failure
        // must leave the stride untouched.
        let _ = t.set_format(PixelFormat::Bgra);
        assert_eq!(t.row_stride(), Some(2048));
        // Re-setting the same format keeps the stride.
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.row_stride(), Some(2048));
        // Reshape drops all image metadata.
        t.reshape(&[480 * 640 * 3]).unwrap();
        assert_eq!(t.row_stride(), None);
        assert_eq!(t.format(), None);
    }
    #[test]
    fn set_format_different_compatible_clears_stride() {
        let mut t = TensorDyn::new(&[480, 640, 4], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgba).unwrap();
        t.set_row_stride(4096).unwrap();
        assert_eq!(t.row_stride(), Some(4096));
        // Switching to a different (but shape-compatible) format resets the stride.
        t.set_format(PixelFormat::Bgra).unwrap();
        assert_eq!(t.format(), Some(PixelFormat::Bgra));
        assert_eq!(t.row_stride(), None);
    }
    #[test]
    fn set_format_same_preserves_stride() {
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        t.set_row_stride(512).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.row_stride(), Some(512));
    }
    #[test]
    fn effective_row_stride_planar() {
        // Planar formats report the per-plane row width (1 byte/px here).
        let t = TensorDyn::image(640, 480, PixelFormat::PlanarRgb, DType::U8, None).unwrap();
        assert_eq!(t.effective_row_stride(), Some(640));
    }
    #[test]
    fn effective_row_stride_nv12() {
        let t = TensorDyn::image(640, 480, PixelFormat::Nv12, DType::U8, None).unwrap();
        assert_eq!(t.effective_row_stride(), Some(640));
    }
    #[test]
    fn map_rejects_strided_tensor() {
        let mut t =
            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
        assert!(t.map().is_ok());
        t.set_row_stride(512).unwrap();
        let err = t.map();
        assert!(err.is_err());
    }
    #[test]
    fn plane_offset_default_none() {
        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
        assert_eq!(t.plane_offset(), None);
    }
    #[test]
    fn set_plane_offset_basic() {
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None).unwrap();
        t.set_plane_offset(4096);
        assert_eq!(t.plane_offset(), Some(4096));
    }
    #[test]
    fn set_plane_offset_zero() {
        // An explicit zero offset is recorded as Some(0), not None.
        let mut t = TensorDyn::image(100, 100, PixelFormat::Rgb, DType::U8, None).unwrap();
        t.set_plane_offset(0);
        assert_eq!(t.plane_offset(), Some(0));
    }
    #[test]
    fn set_plane_offset_no_format() {
        // Unlike row stride, a plane offset does not require a pixel format.
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_plane_offset(4096);
        assert_eq!(t.plane_offset(), Some(4096));
    }
    #[test]
    fn with_plane_offset_builder() {
        let t = TensorDyn::image(100, 100, PixelFormat::Rgba, DType::U8, None)
            .unwrap()
            .with_plane_offset(8192);
        assert_eq!(t.plane_offset(), Some(8192));
    }
    #[test]
    fn reshape_clears_plane_offset_and_format() {
        // Renamed from `set_format_clears_plane_offset`: the body asserts the
        // offset survives an idempotent `set_format` and is cleared only by
        // `reshape`.
        let mut t = TensorDyn::new(&[480, 640, 3], DType::U8, None, None).unwrap();
        t.set_format(PixelFormat::Rgb).unwrap();
        t.set_plane_offset(4096);
        assert_eq!(t.plane_offset(), Some(4096));
        t.set_format(PixelFormat::Rgb).unwrap();
        assert_eq!(t.plane_offset(), Some(4096));
        t.reshape(&[480 * 640 * 3]).unwrap();
        assert_eq!(t.plane_offset(), None);
        assert_eq!(t.format(), None);
    }
    #[test]
    fn map_rejects_offset_tensor() {
        let mut t =
            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
        assert!(t.map().is_ok());
        t.set_plane_offset(4096);
        assert!(t.map().is_err());
    }
    #[test]
    fn map_accepts_zero_offset_tensor() {
        let mut t =
            Tensor::<u8>::image(100, 100, PixelFormat::Rgba, Some(TensorMemory::Mem)).unwrap();
        t.set_plane_offset(0);
        assert!(t.map().is_ok());
    }
    #[test]
    fn from_planes_propagates_plane_offset() {
        let mut luma =
            Tensor::<u8>::new(&[480, 640], Some(TensorMemory::Mem), Some("luma")).unwrap();
        luma.set_plane_offset(4096);
        let chroma =
            Tensor::<u8>::new(&[240, 640], Some(TensorMemory::Mem), Some("chroma")).unwrap();
        let combined = Tensor::<u8>::from_planes(luma, chroma, PixelFormat::Nv12).unwrap();
        assert_eq!(combined.plane_offset(), Some(4096));
    }
}