#[cfg(target_os = "linux")]
mod dma;
#[cfg(target_os = "linux")]
mod dmabuf;
mod error;
mod format;
mod mem;
mod pbo;
#[cfg(unix)]
mod shm;
mod tensor_dyn;
#[cfg(target_os = "linux")]
pub use crate::dma::{DmaMap, DmaTensor};
pub use crate::mem::{MemMap, MemTensor};
pub use crate::pbo::{PboMap, PboMapping, PboOps, PboTensor};
#[cfg(unix)]
pub use crate::shm::{ShmMap, ShmTensor};
pub use error::{Error, Result};
pub use format::{PixelFormat, PixelLayout};
use num_traits::Num;
use serde::{Deserialize, Serialize};
#[cfg(unix)]
use std::os::fd::OwnedFd;
use std::{
fmt,
ops::{Deref, DerefMut},
sync::{
atomic::{AtomicU64, Ordering},
Arc, Weak,
},
};
pub use tensor_dyn::TensorDyn;
/// Describes one image plane handed over by file descriptor, with optional
/// row-stride and byte-offset hints supplied by the producer.
// Public type: derive Debug (OwnedFd is Debug) so callers can log descriptors.
#[cfg(unix)]
#[derive(Debug)]
pub struct PlaneDescriptor {
    /// Owned duplicate of the producer's fd for this plane.
    fd: OwnedFd,
    /// Row stride in bytes, when the producer supplied one.
    stride: Option<usize>,
    /// Byte offset of the plane within the buffer, when supplied.
    offset: Option<usize>,
}
#[cfg(unix)]
impl PlaneDescriptor {
    /// Duplicate `fd` into an owned descriptor carrying no layout hints.
    ///
    /// # Errors
    /// Fails when the fd cannot be duplicated.
    pub fn new(fd: std::os::fd::BorrowedFd<'_>) -> Result<Self> {
        Ok(Self {
            fd: fd.try_clone_to_owned()?,
            stride: None,
            offset: None,
        })
    }

    /// Builder-style setter for the row-stride hint (bytes).
    pub fn with_stride(mut self, stride: usize) -> Self {
        self.stride = Some(stride);
        self
    }

    /// Builder-style setter for the plane byte-offset hint.
    pub fn with_offset(mut self, offset: usize) -> Self {
        self.offset = Some(offset);
        self
    }

    /// Consume the descriptor, yielding the owned fd.
    pub fn into_fd(self) -> OwnedFd {
        self.fd
    }

    /// Row-stride hint in bytes, if one was set.
    pub fn stride(&self) -> Option<usize> {
        self.stride
    }

    /// Plane byte-offset hint, if one was set.
    pub fn offset(&self) -> Option<usize> {
        self.offset
    }
}
/// Scalar element type tag with a fixed byte size.
///
/// `#[non_exhaustive]`: more dtypes may be added without a breaking change,
/// so downstream matches must keep a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
    /// 16-bit float; no native Rust primitive, element size is still 2 bytes.
    F16,
    F32,
    F64,
}
impl DType {
pub const fn size(&self) -> usize {
match self {
Self::U8 | Self::I8 => 1,
Self::U16 | Self::I16 | Self::F16 => 2,
Self::U32 | Self::I32 | Self::F32 => 4,
Self::U64 | Self::I64 | Self::F64 => 8,
}
}
pub const fn name(&self) -> &'static str {
match self {
Self::U8 => "u8",
Self::I8 => "i8",
Self::U16 => "u16",
Self::I16 => "i16",
Self::U32 => "u32",
Self::I32 => "i32",
Self::U64 => "u64",
Self::I64 => "i64",
Self::F16 => "f16",
Self::F32 => "f32",
Self::F64 => "f64",
}
}
}
impl fmt::Display for DType {
    /// Formats as the canonical lowercase name (same text as [`DType::name`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.name())
    }
}
mod sealed {
    /// Private marker trait: outside crates cannot implement it, which
    /// closes `IntegerType` to the fixed-width integer primitives below.
    pub trait Sealed {}

    macro_rules! impl_sealed {
        ($($ty:ty),* $(,)?) => {
            $(impl Sealed for $ty {})*
        };
    }
    impl_sealed!(u8, i8, u16, i16, u32, i32, u64, i64);
}
/// Sealed marker for the integer element types that may carry quantization
/// metadata (see `Tensor::set_quantization`).
pub trait IntegerType: sealed::Sealed {}

macro_rules! impl_integer_type {
    ($($ty:ty),* $(,)?) => {
        $(impl IntegerType for $ty {})*
    };
}
impl_integer_type!(u8, i8, u16, i16, u32, i32, u64, i64);
/// Quantization parameters: per-tensor or per-channel scales, optional zero
/// points, and an optional channel axis.
///
/// Invariants are enforced by the constructors and by `validate`, not by
/// serde — deserialized values should be validated before use.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Quantization {
    // Serialized form accepts either a bare number or an array.
    #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
    scale: Vec<f32>,
    // Optional; also accepts scalar-or-array. Omitted when symmetric.
    #[serde(
        default,
        deserialize_with = "deserialize_opt_scalar_or_vec_i32",
        skip_serializing_if = "Option::is_none"
    )]
    zero_point: Option<Vec<i32>>,
    // Channel axis for per-channel quantization; None for per-tensor.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    axis: Option<usize>,
}
/// Borrowed, pattern-matchable classification of a [`Quantization`]
/// (produced by `Quantization::mode`).
#[derive(Debug, Clone, Copy)]
pub enum QuantMode<'a> {
    /// One scale for the whole tensor, zero point implicitly 0.
    PerTensorSymmetric {
        scale: f32,
    },
    /// One scale and one zero point for the whole tensor.
    PerTensor {
        scale: f32,
        zero_point: i32,
    },
    /// Per-channel scales along `axis`, zero points implicitly 0.
    PerChannelSymmetric {
        scales: &'a [f32],
        axis: usize,
    },
    /// Per-channel scales and zero points along `axis`.
    PerChannel {
        scales: &'a [f32],
        zero_points: &'a [i32],
        axis: usize,
    },
}
impl Quantization {
pub fn per_tensor_symmetric(scale: f32) -> Self {
Self {
scale: vec![scale],
zero_point: None,
axis: None,
}
}
pub fn per_tensor(scale: f32, zero_point: i32) -> Self {
Self {
scale: vec![scale],
zero_point: Some(vec![zero_point]),
axis: None,
}
}
    /// Per-channel symmetric quantization: one scale per channel along `axis`.
    ///
    /// # Errors
    /// Fails when `scales` is empty. The axis bound is only checked later by
    /// `validate` against a concrete tensor shape.
    pub fn per_channel_symmetric(scales: Vec<f32>, axis: usize) -> Result<Self> {
        if scales.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: "non-empty per-channel scales".to_string(),
                got: "length 0".to_string(),
            });
        }
        Ok(Self {
            scale: scales,
            zero_point: None,
            axis: Some(axis),
        })
    }
    /// Per-channel affine quantization with matching zero points.
    ///
    /// # Errors
    /// Fails when `scales` is empty or the two vectors differ in length.
    pub fn per_channel(scales: Vec<f32>, zero_points: Vec<i32>, axis: usize) -> Result<Self> {
        if scales.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: "non-empty per-channel scales".to_string(),
                got: "length 0".to_string(),
            });
        }
        if scales.len() != zero_points.len() {
            return Err(Error::QuantizationInvalid {
                field: "zero_point.len",
                expected: format!("length matches scale ({})", scales.len()),
                got: format!("length {}", zero_points.len()),
            });
        }
        Ok(Self {
            scale: scales,
            zero_point: Some(zero_points),
            axis: Some(axis),
        })
    }
    /// Classify the parameters into a borrowed [`QuantMode`].
    ///
    /// A single scale is treated as per-tensor (any `axis` is ignored here;
    /// `validate` rejects that combination). Multiple scales require an axis.
    pub fn mode(&self) -> QuantMode<'_> {
        match (self.scale.len(), self.zero_point.as_deref(), self.axis) {
            (1, None, _) => QuantMode::PerTensorSymmetric {
                scale: self.scale[0],
            },
            (1, Some(zps), _) => QuantMode::PerTensor {
                scale: self.scale[0],
                // validate() requires exactly one zero point here, but mode()
                // may run on unvalidated data — default to 0 if it is absent.
                zero_point: zps.first().copied().unwrap_or(0),
            },
            (_, None, Some(axis)) => QuantMode::PerChannelSymmetric {
                scales: &self.scale,
                axis,
            },
            (_, Some(zps), Some(axis)) => QuantMode::PerChannel {
                scales: &self.scale,
                zero_points: zps,
                axis,
            },
            // Multiple scales without an axis: unreachable through the
            // constructors, though unvalidated deserialized data could hit
            // it — hence the release-mode fallback below.
            _ => {
                debug_assert!(
                    false,
                    "Quantization::mode: per-channel without axis is unreachable"
                );
                QuantMode::PerTensorSymmetric {
                    scale: self.scale.first().copied().unwrap_or(1.0),
                }
            }
        }
    }
pub fn is_per_tensor(&self) -> bool {
self.scale.len() == 1
}
pub fn is_per_channel(&self) -> bool {
self.scale.len() > 1
}
pub fn is_symmetric(&self) -> bool {
match &self.zero_point {
None => true,
Some(zps) => zps.iter().all(|&z| z == 0),
}
}
pub fn scale(&self) -> &[f32] {
&self.scale
}
pub fn zero_point(&self) -> Option<&[i32]> {
self.zero_point.as_deref()
}
pub fn axis(&self) -> Option<usize> {
self.axis
}
    /// Check internal consistency and compatibility with a tensor `shape`.
    ///
    /// Rules enforced: at least one scale; zero-point count matches the scale
    /// count; per-tensor must not carry an axis; per-channel requires an axis
    /// in range whose dimension equals the number of scales.
    pub(crate) fn validate(&self, shape: &[usize]) -> Result<()> {
        if self.scale.is_empty() {
            return Err(Error::QuantizationInvalid {
                field: "scale.len",
                expected: ">= 1".to_string(),
                got: "0".to_string(),
            });
        }
        if let Some(zps) = self.zero_point.as_ref() {
            let expected = if self.scale.len() == 1 {
                1
            } else {
                self.scale.len()
            };
            if zps.len() != expected {
                return Err(Error::QuantizationInvalid {
                    field: "zero_point.len",
                    expected: format!(
                        "{expected} (matching {})",
                        if self.scale.len() == 1 {
                            "per-tensor scale"
                        } else {
                            "per-channel scale.len"
                        }
                    ),
                    got: format!("length {}", zps.len()),
                });
            }
        }
        match (self.scale.len(), self.axis) {
            (1, None) => Ok(()),
            // NOTE(review): per_channel_symmetric/per_channel accept a single
            // scale together with an axis, which this arm then rejects —
            // confirm that asymmetry is intended.
            (1, Some(_)) => Err(Error::QuantizationInvalid {
                field: "per_tensor_redundant_axis",
                expected: "axis=None for per-tensor quantization".to_string(),
                got: format!("axis={:?}", self.axis),
            }),
            (_, None) => Err(Error::QuantizationInvalid {
                field: "per_channel_requires_axis",
                expected: format!(
                    "axis=Some(_) for per-channel quantization (scale.len={})",
                    self.scale.len()
                ),
                got: "axis=None".to_string(),
            }),
            (n, Some(axis)) => {
                if axis >= shape.len() {
                    return Err(Error::QuantizationInvalid {
                        field: "axis",
                        expected: format!("axis < tensor rank ({})", shape.len()),
                        got: format!("axis={axis}"),
                    });
                }
                if shape[axis] != n {
                    return Err(Error::QuantizationInvalid {
                        field: "scale.len",
                        expected: format!("length matches shape[{axis}] ({})", shape[axis]),
                        got: format!("length {n}"),
                    });
                }
                Ok(())
            }
        }
    }
}
impl From<(f32, i32)> for Quantization {
    /// Convenience: a `(scale, zero_point)` pair builds per-tensor
    /// quantization.
    fn from(pair: (f32, i32)) -> Self {
        Self::per_tensor(pair.0, pair.1)
    }
}
/// Deserialize either a bare number or an array of numbers into `Vec<f32>`.
///
/// Integer inputs are accepted and converted; a scalar becomes a one-element
/// vector.
fn deserialize_scalar_or_vec_f32<'de, D: serde::Deserializer<'de>>(
    de: D,
) -> std::result::Result<Vec<f32>, D::Error> {
    use serde::de::{self, Visitor};
    struct V;
    impl<'de> Visitor<'de> for V {
        type Value = Vec<f32>;
        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("f32 or array of f32")
        }
        // f64 -> f32 can truncate; allow the lint here too, matching the
        // sibling integer visitors below.
        #[allow(clippy::cast_possible_truncation)]
        fn visit_f64<E: de::Error>(self, v: f64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        #[allow(clippy::cast_possible_truncation)]
        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
            Ok(vec![v as f32])
        }
        fn visit_seq<A: de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> std::result::Result<Self::Value, A::Error> {
            // Don't trust size_hint blindly: cap the preallocation so a
            // corrupt or hostile hint cannot force a huge allocation.
            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1).min(4096));
            while let Some(x) = seq.next_element::<f32>()? {
                out.push(x);
            }
            Ok(out)
        }
    }
    de.deserialize_any(V)
}
/// Deserialize `null`, a bare integer, or an array of integers into
/// `Option<Vec<i32>>`.
///
/// Scalar/sequence arms are duplicated on the outer visitor because some
/// formats hand values to it directly instead of going through `visit_some`.
fn deserialize_opt_scalar_or_vec_i32<'de, D: serde::Deserializer<'de>>(
    de: D,
) -> std::result::Result<Option<Vec<i32>>, D::Error> {
    use serde::de::{self, Visitor};
    struct V;
    impl<'de> Visitor<'de> for V {
        type Value = Option<Vec<i32>>;
        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("null, i32, or array of i32")
        }
        // Absent field or explicit null -> None.
        fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
            Ok(None)
        }
        fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
            Ok(None)
        }
        // Present value: delegate scalar-or-array handling to an inner visitor.
        fn visit_some<D2: serde::Deserializer<'de>>(
            self,
            de: D2,
        ) -> std::result::Result<Self::Value, D2::Error> {
            struct Inner;
            impl<'de> Visitor<'de> for Inner {
                type Value = Vec<i32>;
                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                    f.write_str("i32 or array of i32")
                }
                #[allow(clippy::cast_possible_truncation)]
                fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
                    Ok(vec![v as i32])
                }
                #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
                fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
                    Ok(vec![v as i32])
                }
                fn visit_seq<A: de::SeqAccess<'de>>(
                    self,
                    mut seq: A,
                ) -> std::result::Result<Self::Value, A::Error> {
                    let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
                    while let Some(x) = seq.next_element::<i32>()? {
                        out.push(x);
                    }
                    Ok(out)
                }
            }
            de.deserialize_any(Inner).map(Some)
        }
        // Direct scalar/sequence arms for formats that skip visit_some.
        #[allow(clippy::cast_possible_truncation)]
        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
            Ok(Some(vec![v as i32]))
        }
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
            Ok(Some(vec![v as i32]))
        }
        fn visit_seq<A: de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> std::result::Result<Self::Value, A::Error> {
            let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(1));
            while let Some(x) = seq.next_element::<i32>()? {
                out.push(x);
            }
            Ok(Some(out))
        }
    }
    de.deserialize_option(V)
}
// Monotonic source of process-unique buffer ids (starts at 1).
static NEXT_BUFFER_ID: AtomicU64 = AtomicU64::new(1);
/// Process-unique identity of a backing buffer.
///
/// Clones share the same id and the same `Arc` guard, so a `Weak` obtained
/// from `weak()` can observe whether any clone is still alive.
#[derive(Debug, Clone)]
pub struct BufferIdentity {
    // Unique per buffer; shared by all clones of this identity.
    id: u64,
    // Liveness token backing the Weak handles handed out by weak().
    guard: Arc<()>,
}
impl BufferIdentity {
    /// Allocate a fresh identity with a process-unique id.
    pub fn new() -> Self {
        let id = NEXT_BUFFER_ID.fetch_add(1, Ordering::Relaxed);
        Self {
            id,
            guard: Arc::new(()),
        }
    }

    /// The unique id of this buffer.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// A weak handle that upgrades only while some clone of this identity
    /// is still alive.
    pub fn weak(&self) -> Weak<()> {
        Arc::downgrade(&self.guard)
    }
}
impl Default for BufferIdentity {
fn default() -> Self {
Self::new()
}
}
#[cfg(target_os = "linux")]
use nix::sys::stat::{major, minor};
/// Common interface for tensors over element type `T`, independent of the
/// backing storage (DMA, shared memory, heap, or PBO).
pub trait TensorTrait<T>: Send + Sync
where
    T: Num + Clone + fmt::Debug,
{
    /// Allocate a new tensor of `shape`, optionally named.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;
    /// Adopt an existing file descriptor as the tensor's backing buffer.
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized;
    /// Duplicate the backing buffer's file descriptor.
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd>;
    /// Which storage backend holds the data.
    fn memory(&self) -> TensorMemory;
    /// Human-readable tensor name.
    fn name(&self) -> String;
    /// Number of elements (product of the shape).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }
    /// True when the tensor holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Total element data size in bytes.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }
    /// Dimensions of the tensor.
    fn shape(&self) -> &[usize];
    /// Change the shape; implementations decide which reshapes are legal.
    fn reshape(&mut self, shape: &[usize]) -> Result<()>;
    /// Map the buffer into CPU-addressable memory.
    fn map(&self) -> Result<TensorMap<T>>;
    /// Identity of the backing buffer.
    fn buffer_identity(&self) -> &BufferIdentity;
}
/// Interface over a CPU mapping of a tensor's buffer.
pub trait TensorMapTrait<T>
where
    T: Num + Clone + fmt::Debug,
{
    /// Dimensions of the mapped tensor.
    fn shape(&self) -> &[usize];
    /// Release the mapping explicitly.
    fn unmap(&mut self);
    /// Number of elements (product of the shape).
    fn len(&self) -> usize {
        self.shape().iter().product()
    }
    /// True when the mapping holds no elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Total mapped data size in bytes.
    fn size(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }
    /// Read-only view of the mapped elements.
    fn as_slice(&self) -> &[T];
    /// Mutable view of the mapped elements.
    fn as_mut_slice(&mut self) -> &mut [T];
    /// Borrow the mapping as a dynamic-dimension `ndarray` view.
    #[cfg(feature = "ndarray")]
    fn view(&'_ self) -> Result<ndarray::ArrayView<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        Ok(ndarray::ArrayView::from_shape(
            self.shape(),
            self.as_slice(),
        )?)
    }
    /// Mutably borrow the mapping as a dynamic-dimension `ndarray` view.
    #[cfg(feature = "ndarray")]
    fn view_mut(
        &'_ mut self,
    ) -> Result<ndarray::ArrayViewMut<'_, T, ndarray::Dim<ndarray::IxDynImpl>>> {
        // Copy the shape first so the immutable borrow ends before
        // as_mut_slice takes the mutable one.
        let shape = self.shape().to_vec();
        Ok(ndarray::ArrayViewMut::from_shape(
            shape,
            self.as_mut_slice(),
        )?)
    }
}
/// Storage backend selector for tensor allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorMemory {
    /// DMA buffer (Linux only).
    #[cfg(target_os = "linux")]
    Dma,
    /// Shared memory (unix only).
    #[cfg(unix)]
    Shm,
    /// Plain heap memory.
    Mem,
    /// Pixel buffer object; created via `ImageProcessor::create_image`,
    /// never through `Tensor::new`.
    Pbo,
}
impl From<TensorMemory> for String {
    /// Lowercase backend name ("dma", "shm", "mem", "pbo").
    fn from(memory: TensorMemory) -> Self {
        let name = match memory {
            #[cfg(target_os = "linux")]
            TensorMemory::Dma => "dma",
            #[cfg(unix)]
            TensorMemory::Shm => "shm",
            TensorMemory::Mem => "mem",
            TensorMemory::Pbo => "pbo",
        };
        name.to_owned()
    }
}
impl TryFrom<&str> for TensorMemory {
type Error = Error;
fn try_from(s: &str) -> Result<Self> {
match s {
#[cfg(target_os = "linux")]
"dma" => Ok(TensorMemory::Dma),
#[cfg(unix)]
"shm" => Ok(TensorMemory::Shm),
"mem" => Ok(TensorMemory::Mem),
"pbo" => Ok(TensorMemory::Pbo),
_ => Err(Error::InvalidMemoryType(s.to_owned())),
}
}
}
/// Backend-specific tensor storage; variants mirror [`TensorMemory`].
#[derive(Debug)]
// NOTE(review): dead_code presumably covers variants unused under some
// cfg/feature combinations — confirm before removing.
#[allow(dead_code)] pub(crate) enum TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    #[cfg(target_os = "linux")]
    Dma(DmaTensor<T>),
    #[cfg(unix)]
    Shm(ShmTensor<T>),
    Mem(MemTensor<T>),
    Pbo(PboTensor<T>),
}
impl<T> TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocate storage for `shape`, honoring an explicit backend request or
    /// auto-selecting with fallback DMA -> SHM -> heap (platform permitting).
    ///
    /// Setting env var `EDGEFIRST_TENSOR_FORCE_MEM` to anything other than
    /// "0"/"false" forces plain heap storage in auto mode.
    fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        match memory {
            #[cfg(target_os = "linux")]
            Some(TensorMemory::Dma) => {
                DmaTensor::<T>::new(shape, name).map(TensorStorage::Dma)
            }
            #[cfg(unix)]
            Some(TensorMemory::Shm) => {
                ShmTensor::<T>::new(shape, name).map(TensorStorage::Shm)
            }
            Some(TensorMemory::Mem) => {
                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
            }
            // PBO buffers are created on the image-processing side, never here.
            Some(TensorMemory::Pbo) => Err(crate::error::Error::NotImplemented(
                "PboTensor cannot be created via Tensor::new() — use ImageProcessor::create_image()".to_owned(),
            )),
            None => {
                if std::env::var("EDGEFIRST_TENSOR_FORCE_MEM")
                    .is_ok_and(|x| x != "0" && x.to_lowercase() != "false")
                {
                    MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                } else {
                    #[cfg(target_os = "linux")]
                    {
                        // Prefer DMA, then SHM, then heap.
                        match DmaTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Dma(tensor)),
                            Err(_) => {
                                match ShmTensor::<T>::new(shape, name)
                                    .map(TensorStorage::Shm)
                                {
                                    Ok(tensor) => Ok(tensor),
                                    Err(_) => MemTensor::<T>::new(shape, name)
                                        .map(TensorStorage::Mem),
                                }
                            }
                        }
                    }
                    #[cfg(all(unix, not(target_os = "linux")))]
                    {
                        // No DMA off Linux: try SHM, then heap.
                        match ShmTensor::<T>::new(shape, name) {
                            Ok(tensor) => Ok(TensorStorage::Shm(tensor)),
                            Err(_) => {
                                MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                            }
                        }
                    }
                    #[cfg(not(unix))]
                    {
                        MemTensor::<T>::new(shape, name).map(TensorStorage::Mem)
                    }
                }
            }
        }
    }
    /// Allocate a DMA tensor whose buffer is `byte_size` bytes — used for
    /// strided images where the buffer exceeds the packed shape size.
    #[cfg(target_os = "linux")]
    pub(crate) fn new_dma_with_byte_size(
        shape: &[usize],
        byte_size: usize,
        name: Option<&str>,
    ) -> Result<Self> {
        DmaTensor::<T>::new_with_byte_size(shape, byte_size, name).map(TensorStorage::Dma)
    }
    /// Adopt an fd, sniffing its backing device to choose DMA vs SHM.
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        #[cfg(target_os = "linux")]
        {
            use nix::sys::stat::fstat;
            let stat = fstat(&fd)?;
            let major = major(stat.st_dev);
            let minor = minor(stat.st_dev);
            log::debug!("Creating tensor from fd: major={major}, minor={minor}");
            // Only anonymous/pseudo devices (major 0) are accepted here.
            if major != 0 {
                return Err(Error::UnknownDeviceType(major, minor));
            }
            match minor {
                // NOTE(review): minors 9/10 are assumed to identify
                // dmabuf-backed fds — confirm against the allocator in use.
                9 | 10 => {
                    DmaTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Dma)
                }
                _ => {
                    ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
                }
            }
        }
        #[cfg(all(unix, not(target_os = "linux")))]
        {
            ShmTensor::<T>::from_fd(fd, shape, name).map(TensorStorage::Shm)
        }
    }
}
/// [`TensorTrait`] for the storage enum: every method delegates to the
/// active backend variant.
impl<T> TensorTrait<T> for TensorStorage<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self> {
        // Auto-select the backend (memory = None).
        Self::new(shape, None, name)
    }
    #[cfg(unix)]
    fn from_fd(fd: OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self> {
        Self::from_fd(fd, shape, name)
    }
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<OwnedFd> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.clone_fd(),
            TensorStorage::Shm(t) => t.clone_fd(),
            TensorStorage::Mem(t) => t.clone_fd(),
            TensorStorage::Pbo(t) => t.clone_fd(),
        }
    }
    fn memory(&self) -> TensorMemory {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(_) => TensorMemory::Dma,
            #[cfg(unix)]
            TensorStorage::Shm(_) => TensorMemory::Shm,
            TensorStorage::Mem(_) => TensorMemory::Mem,
            TensorStorage::Pbo(_) => TensorMemory::Pbo,
        }
    }
    fn name(&self) -> String {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.name(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.name(),
            TensorStorage::Mem(t) => t.name(),
            TensorStorage::Pbo(t) => t.name(),
        }
    }
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.shape(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.shape(),
            TensorStorage::Mem(t) => t.shape(),
            TensorStorage::Pbo(t) => t.shape(),
        }
    }
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.reshape(shape),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.reshape(shape),
            TensorStorage::Mem(t) => t.reshape(shape),
            TensorStorage::Pbo(t) => t.reshape(shape),
        }
    }
    fn map(&self) -> Result<TensorMap<T>> {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.map(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.map(),
            TensorStorage::Mem(t) => t.map(),
            TensorStorage::Pbo(t) => t.map(),
        }
    }
    fn buffer_identity(&self) -> &BufferIdentity {
        match self {
            #[cfg(target_os = "linux")]
            TensorStorage::Dma(t) => t.buffer_identity(),
            #[cfg(unix)]
            TensorStorage::Shm(t) => t.buffer_identity(),
            TensorStorage::Mem(t) => t.buffer_identity(),
            TensorStorage::Pbo(t) => t.buffer_identity(),
        }
    }
}
/// A shaped, typed buffer with optional image metadata (pixel format, chroma
/// plane, row stride, plane offset) and optional quantization metadata.
#[derive(Debug)]
pub struct Tensor<T>
where
    T: Num + Clone + fmt::Debug + Send + Sync,
{
    pub(crate) storage: TensorStorage<T>,
    // Pixel format when this tensor is an image; None for raw tensors.
    format: Option<PixelFormat>,
    // Separate chroma plane for multi-plane semi-planar images.
    chroma: Option<Box<Tensor<T>>>,
    // Explicit row stride in bytes; None means tightly packed.
    row_stride: Option<usize>,
    // Byte offset of this plane within the backing buffer.
    plane_offset: Option<usize>,
    pub(crate) quantization: Option<Quantization>,
}
impl<T> Tensor<T>
where
T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Wrap raw storage in a `Tensor` with no image or quantization metadata.
    pub(crate) fn wrap(storage: TensorStorage<T>) -> Self {
        Self {
            storage,
            format: None,
            chroma: None,
            row_stride: None,
            plane_offset: None,
            quantization: None,
        }
    }
pub fn from_slice(values: &[T], shape: &[usize]) -> Result<Self>
where
T: Copy,
{
let expected: usize = shape.iter().product();
if values.len() != expected {
return Err(Error::InvalidShape(format!(
"from_slice: values.len()={} but shape product={expected} (shape={shape:?})",
values.len()
)));
}
let t = Self::new(shape, Some(TensorMemory::Mem), None)?;
{
let mut m = t.map()?;
m.as_mut_slice().copy_from_slice(values);
}
Ok(t)
}
#[cfg(feature = "ndarray")]
pub fn from_arrayview3(view: ndarray::ArrayView3<'_, T>) -> Result<Self>
where
T: Copy,
{
let (h, w, c) = view.dim();
let t = Self::new(&[h, w, c], Some(TensorMemory::Mem), None)?;
{
let mut m = t.map()?;
let dst = m.as_mut_slice();
if let Some(src) = view.as_slice() {
dst.copy_from_slice(src);
} else {
for (d, &s) in dst.iter_mut().zip(view.iter()) {
*d = s;
}
}
}
Ok(t)
}
    /// Allocate a tensor; with `memory = None` the backend is auto-selected
    /// (see `TensorStorage::new` for the fallback order).
    pub fn new(shape: &[usize], memory: Option<TensorMemory>, name: Option<&str>) -> Result<Self> {
        // Trace span covers the whole allocation for profiling.
        let _span = tracing::trace_span!(
            "tensor_alloc",
            ?shape,
            memory = ?memory,
            dtype = std::any::type_name::<T>(),
        )
        .entered();
        TensorStorage::new(shape, memory, name).map(Self::wrap)
    }
    /// Allocate an image tensor of `width`x`height` in `format`.
    ///
    /// Shape is derived from the format layout: packed -> [H, W, C],
    /// planar -> [C, H, W], semi-planar -> [H*k, W] with luma and chroma
    /// rows stored contiguously.
    pub fn image(
        width: usize,
        height: usize,
        format: PixelFormat,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        let shape = match format.layout() {
            PixelLayout::Packed => vec![height, width, format.channels()],
            PixelLayout::Planar => vec![format.channels(), height, width],
            PixelLayout::SemiPlanar => {
                let total_h = match format {
                    PixelFormat::Nv12 => {
                        // The 3/2 vertical expansion is exact only for even
                        // heights.
                        if !height.is_multiple_of(2) {
                            return Err(Error::InvalidArgument(format!(
                                "NV12 requires even height, got {height}"
                            )));
                        }
                        height * 3 / 2
                    }
                    PixelFormat::Nv16 => height * 2,
                    _ => {
                        return Err(Error::InvalidArgument(format!(
                            "unknown semi-planar height multiplier for {format:?}"
                        )))
                    }
                };
                vec![total_h, width]
            }
        };
        let mut t = Self::new(&shape, memory, None)?;
        t.format = Some(format);
        Ok(t)
    }
    /// Allocate a packed-layout DMA image whose rows are padded out to
    /// `row_stride_bytes` (Linux only).
    ///
    /// The logical shape remains `[H, W, C]`; the backing DMA buffer spans
    /// `row_stride_bytes * height` bytes and `row_stride` records the pitch.
    ///
    /// # Errors
    /// Not implemented off Linux or for non-packed layouts; rejects strides
    /// below the packed minimum, size overflows, and non-DMA memory requests.
    pub fn image_with_stride(
        width: usize,
        height: usize,
        format: PixelFormat,
        row_stride_bytes: usize,
        memory: Option<TensorMemory>,
    ) -> Result<Self> {
        #[cfg(not(target_os = "linux"))]
        {
            // Consume the arguments so non-Linux builds have no unused warnings.
            let _ = (width, height, format, row_stride_bytes, memory);
            Err(Error::NotImplemented(
                "image_with_stride requires DMA support (Linux only)".to_owned(),
            ))
        }
        #[cfg(target_os = "linux")]
        {
            if format.layout() != PixelLayout::Packed {
                return Err(Error::NotImplemented(format!(
                    "Tensor::image_with_stride only supports packed pixel layouts, got {format:?}"
                )));
            }
            let elem = std::mem::size_of::<T>();
            // Minimum (tightly packed) stride; checked math guards overflow.
            let min_stride = width
                .checked_mul(format.channels())
                .and_then(|p| p.checked_mul(elem))
                .ok_or_else(|| {
                    Error::InvalidArgument(format!(
                        "image_with_stride: width {width} × channels {} × sizeof::<T>={elem} \
                         overflows usize",
                        format.channels()
                    ))
                })?;
            if row_stride_bytes < min_stride {
                return Err(Error::InvalidArgument(format!(
                    "image_with_stride: row_stride {row_stride_bytes} < minimum {min_stride} \
                     ({width} px × {} ch × {elem} B)",
                    format.channels()
                )));
            }
            let total_byte_size = row_stride_bytes.checked_mul(height).ok_or_else(|| {
                Error::InvalidArgument(format!(
                    "image_with_stride: row_stride {row_stride_bytes} × height {height} overflows usize"
                ))
            })?;
            let shape = vec![height, width, format.channels()];
            let storage = match memory {
                // Default and explicit DMA both use the oversized allocation.
                Some(TensorMemory::Dma) | None => {
                    TensorStorage::<T>::new_dma_with_byte_size(&shape, total_byte_size, None)?
                }
                Some(other) => {
                    return Err(Error::NotImplemented(format!(
                        "image_with_stride: only TensorMemory::Dma is supported, got {other:?}"
                    )));
                }
            };
            let mut t = Self::wrap(storage);
            t.format = Some(format);
            t.row_stride = Some(row_stride_bytes);
            Ok(t)
        }
    }
    /// Attach a pixel format after validating it against the current shape.
    ///
    /// Switching to a different format resets stride/offset metadata (and
    /// the DMA mmap offset), since those applied to the old interpretation.
    pub fn set_format(&mut self, format: PixelFormat) -> Result<()> {
        let shape = self.shape();
        match format.layout() {
            PixelLayout::Packed => {
                if shape.len() != 3 || shape[2] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "packed format {format:?} expects [H, W, {}], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::Planar => {
                if shape.len() != 3 || shape[0] != format.channels() {
                    return Err(Error::InvalidShape(format!(
                        "planar format {format:?} expects [{}, H, W], got {shape:?}",
                        format.channels()
                    )));
                }
            }
            PixelLayout::SemiPlanar => {
                if shape.len() != 2 {
                    return Err(Error::InvalidShape(format!(
                        "semi-planar format {format:?} expects [H*k, W], got {shape:?}"
                    )));
                }
                // Contiguous semi-planar shape[0] is H*3/2 (NV12) or H*2
                // (NV16); divisibility guarantees a whole-number luma height.
                match format {
                    PixelFormat::Nv12 if !shape[0].is_multiple_of(3) => {
                        return Err(Error::InvalidShape(format!(
                            "NV12 contiguous shape[0] must be divisible by 3, got {}",
                            shape[0]
                        )));
                    }
                    PixelFormat::Nv16 if !shape[0].is_multiple_of(2) => {
                        return Err(Error::InvalidShape(format!(
                            "NV16 contiguous shape[0] must be even, got {}",
                            shape[0]
                        )));
                    }
                    _ => {}
                }
            }
        }
        if self.format != Some(format) {
            self.row_stride = None;
            self.plane_offset = None;
            #[cfg(target_os = "linux")]
            if let TensorStorage::Dma(ref mut dma) = self.storage {
                dma.mmap_offset = 0;
            }
        }
        self.format = Some(format);
        Ok(())
    }
    /// The pixel format, if this tensor is an image.
    pub fn format(&self) -> Option<PixelFormat> {
        self.format
    }
    /// Image width in pixels, derived from the shape and format layout.
    pub fn width(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[1]),
            PixelLayout::Planar => Some(shape[2]),
            PixelLayout::SemiPlanar => Some(shape[1]),
        }
    }
    /// Image height in pixels.
    ///
    /// For contiguous (single-plane) semi-planar tensors the stored shape
    /// includes the chroma rows, so the luma height is recovered from the
    /// format's vertical expansion (3/2 for NV12, 2 for NV16).
    pub fn height(&self) -> Option<usize> {
        let fmt = self.format?;
        let shape = self.shape();
        match fmt.layout() {
            PixelLayout::Packed => Some(shape[0]),
            PixelLayout::Planar => Some(shape[1]),
            PixelLayout::SemiPlanar => {
                if self.is_multiplane() {
                    // Separate chroma plane: the shape is the luma plane alone.
                    Some(shape[0])
                } else {
                    match fmt {
                        PixelFormat::Nv12 => Some(shape[0] * 2 / 3),
                        PixelFormat::Nv16 => Some(shape[0] / 2),
                        _ => None,
                    }
                }
            }
        }
    }
pub fn from_planes(luma: Tensor<T>, chroma: Tensor<T>, format: PixelFormat) -> Result<Self> {
if format.layout() != PixelLayout::SemiPlanar {
return Err(Error::InvalidArgument(format!(
"from_planes requires a semi-planar format, got {format:?}"
)));
}
if chroma.format.is_some() || chroma.chroma.is_some() {
return Err(Error::InvalidArgument(
"chroma tensor must be a raw tensor (no format or chroma metadata)".into(),
));
}
let luma_shape = luma.shape();
let chroma_shape = chroma.shape();
if luma_shape.len() != 2 || chroma_shape.len() != 2 {
return Err(Error::InvalidArgument(format!(
"from_planes expects 2D shapes, got luma={luma_shape:?} chroma={chroma_shape:?}"
)));
}
if luma_shape[1] != chroma_shape[1] {
return Err(Error::InvalidArgument(format!(
"luma width {} != chroma width {}",
luma_shape[1], chroma_shape[1]
)));
}
match format {
PixelFormat::Nv12 => {
if luma_shape[0] % 2 != 0 {
return Err(Error::InvalidArgument(format!(
"NV12 requires even luma height, got {}",
luma_shape[0]
)));
}
if chroma_shape[0] != luma_shape[0] / 2 {
return Err(Error::InvalidArgument(format!(
"NV12 chroma height {} != luma height / 2 ({})",
chroma_shape[0],
luma_shape[0] / 2
)));
}
}
PixelFormat::Nv16 => {
if chroma_shape[0] != luma_shape[0] {
return Err(Error::InvalidArgument(format!(
"NV16 chroma height {} != luma height {}",
chroma_shape[0], luma_shape[0]
)));
}
}
_ => {
return Err(Error::InvalidArgument(format!(
"from_planes only supports NV12 and NV16, got {format:?}"
)));
}
}
Ok(Tensor {
storage: luma.storage,
format: Some(format),
chroma: Some(Box::new(chroma)),
row_stride: luma.row_stride,
plane_offset: luma.plane_offset,
quantization: luma.quantization,
})
}
    /// True when this tensor carries a separate chroma plane.
    pub fn is_multiplane(&self) -> bool {
        self.chroma.is_some()
    }
    /// Borrow the chroma plane tensor, if any.
    pub fn chroma(&self) -> Option<&Tensor<T>> {
        self.chroma.as_deref()
    }
    /// Mutably borrow the chroma plane tensor, if any.
    pub fn chroma_mut(&mut self) -> Option<&mut Tensor<T>> {
        self.chroma.as_deref_mut()
    }
    /// Explicit row stride in bytes, if one was set.
    pub fn row_stride(&self) -> Option<usize> {
        self.row_stride
    }
pub fn effective_row_stride(&self) -> Option<usize> {
if let Some(s) = self.row_stride {
return Some(s);
}
let fmt = self.format?;
let w = self.width()?;
let elem = std::mem::size_of::<T>();
Some(match fmt.layout() {
PixelLayout::Packed => w * fmt.channels() * elem,
PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
})
}
    /// Set the row stride in bytes, validating it against the minimum packed
    /// stride for the current format and width.
    ///
    /// # Errors
    /// Fails without a pixel format, without a derivable width, or when
    /// `stride` is below the packed minimum.
    pub fn set_row_stride(&mut self, stride: usize) -> Result<()> {
        let fmt = self.format.ok_or_else(|| {
            Error::InvalidArgument("cannot set row_stride without a pixel format".into())
        })?;
        let w = self.width().ok_or_else(|| {
            Error::InvalidArgument("cannot determine width for row_stride validation".into())
        })?;
        let elem = std::mem::size_of::<T>();
        let min_stride = match fmt.layout() {
            PixelLayout::Packed => w * fmt.channels() * elem,
            PixelLayout::Planar | PixelLayout::SemiPlanar => w * elem,
        };
        if stride < min_stride {
            return Err(Error::InvalidArgument(format!(
                "row_stride {stride} < minimum {min_stride} for {fmt:?} at width {w}"
            )));
        }
        self.row_stride = Some(stride);
        Ok(())
    }
    /// Set the row stride with no validation; the caller guarantees it fits.
    pub fn set_row_stride_unchecked(&mut self, stride: usize) {
        self.row_stride = Some(stride);
    }
    /// Builder-style variant of [`Self::set_row_stride`].
    pub fn with_row_stride(mut self, stride: usize) -> Result<Self> {
        self.set_row_stride(stride)?;
        Ok(self)
    }
    /// Byte offset of this plane within the backing buffer, if set.
    pub fn plane_offset(&self) -> Option<usize> {
        self.plane_offset
    }
    /// Set the plane offset; for DMA storage this also shifts the mmap base.
    pub fn set_plane_offset(&mut self, offset: usize) {
        self.plane_offset = Some(offset);
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = offset;
        }
    }
    /// Builder-style variant of [`Self::set_plane_offset`].
    pub fn with_plane_offset(mut self, offset: usize) -> Self {
        self.set_plane_offset(offset);
        self
    }
    /// Downcast to the PBO backend, if that is the active storage.
    pub fn as_pbo(&self) -> Option<&PboTensor<T>> {
        match &self.storage {
            TensorStorage::Pbo(p) => Some(p),
            _ => None,
        }
    }
    /// Downcast to the DMA backend, if that is the active storage.
    #[cfg(target_os = "linux")]
    pub fn as_dma(&self) -> Option<&DmaTensor<T>> {
        match &self.storage {
            TensorStorage::Dma(d) => Some(d),
            _ => None,
        }
    }
    /// Borrow the dma-buf fd; errors for non-DMA storage.
    #[cfg(target_os = "linux")]
    pub fn dmabuf(&self) -> Result<std::os::fd::BorrowedFd<'_>> {
        use std::os::fd::AsFd;
        match &self.storage {
            TensorStorage::Dma(dma) => Ok(dma.fd.as_fd()),
            _ => Err(Error::NotImplemented(format!(
                "dmabuf requires DMA-backed tensor, got {:?}",
                self.storage.memory()
            ))),
        }
    }
pub fn from_pbo(pbo: PboTensor<T>) -> Self {
Self {
storage: TensorStorage::Pbo(pbo),
format: None,
chroma: None,
row_stride: None,
plane_offset: None,
quantization: None,
}
}
}
// Quantization metadata is restricted to sealed integer element types.
impl<T> Tensor<T>
where
    T: IntegerType + Num + Clone + fmt::Debug + Send + Sync,
{
    /// Attached quantization metadata, if any.
    pub fn quantization(&self) -> Option<&Quantization> {
        self.quantization.as_ref()
    }
    /// Attach quantization metadata after validating it against the shape.
    pub fn set_quantization(&mut self, q: Quantization) -> Result<()> {
        q.validate(self.shape())?;
        self.quantization = Some(q);
        Ok(())
    }
    /// Builder-style variant of [`Self::set_quantization`].
    pub fn with_quantization(mut self, q: Quantization) -> Result<Self> {
        self.set_quantization(q)?;
        Ok(self)
    }
    /// Remove any attached quantization metadata.
    pub fn clear_quantization(&mut self) {
        self.quantization = None;
    }
}
impl<T> TensorTrait<T> for Tensor<T>
where
T: Num + Clone + fmt::Debug + Send + Sync,
{
    /// Allocate with automatic backend selection.
    fn new(shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Self::new(shape, None, name)
    }
    /// Adopt an fd as backing storage (backend sniffed from the device).
    #[cfg(unix)]
    fn from_fd(fd: std::os::fd::OwnedFd, shape: &[usize], name: Option<&str>) -> Result<Self>
    where
        Self: Sized,
    {
        Ok(Self::wrap(TensorStorage::from_fd(fd, shape, name)?))
    }
    #[cfg(unix)]
    fn clone_fd(&self) -> Result<std::os::fd::OwnedFd> {
        self.storage.clone_fd()
    }
    fn memory(&self) -> TensorMemory {
        self.storage.memory()
    }
    fn name(&self) -> String {
        self.storage.name()
    }
    fn shape(&self) -> &[usize] {
        self.storage.shape()
    }
    /// Reshape the storage and clear all image metadata (format, stride,
    /// plane offset, DMA mmap offset), which no longer applies to the new
    /// shape. Multi-plane tensors must be decomposed first.
    fn reshape(&mut self, shape: &[usize]) -> Result<()> {
        if self.chroma.is_some() {
            return Err(Error::InvalidOperation(
                "cannot reshape a multiplane tensor — decompose planes first".into(),
            ));
        }
        self.storage.reshape(shape)?;
        self.format = None;
        self.row_stride = None;
        self.plane_offset = None;
        #[cfg(target_os = "linux")]
        if let TensorStorage::Dma(ref mut dma) = self.storage {
            dma.mmap_offset = 0;
        }
        Ok(())
    }
    /// Map the tensor into CPU memory.
    ///
    /// Strided tensors are only CPU-mappable when backed by a DMA buffer
    /// allocated here (the mapping then spans stride × height bytes);
    /// imported strided buffers are GPU-only. A nonzero plane offset also
    /// requires DMA backing.
    fn map(&self) -> Result<TensorMap<T>> {
        let _span = tracing::trace_span!(
            "tensor_map",
            memory = ?self.storage.memory(),
        )
        .entered();
        #[cfg(target_os = "linux")]
        if let Some(stride) = self.row_stride {
            if let TensorStorage::Dma(dma) = &self.storage {
                if !dma.is_imported {
                    // Height is derived from the pixel format; without one we
                    // cannot size the strided mapping.
                    let height = self.height().ok_or_else(|| {
                        Error::InvalidOperation(
                            "Tensor::map: strided DMA mapping requires a PixelFormat \
                             so height() can be derived; set a format before mapping \
                             or clear row_stride for raw tensor access"
                                .into(),
                        )
                    })?;
                    let total_bytes = stride.checked_mul(height).ok_or_else(|| {
                        Error::InvalidOperation(format!(
                            "Tensor::map: row_stride {stride} × height {height} overflows usize"
                        ))
                    })?;
                    // Respect any plane offset already applied to the mmap base.
                    let available_bytes = dma.buf_size.saturating_sub(dma.mmap_offset);
                    if total_bytes > available_bytes {
                        return Err(Error::InvalidOperation(format!(
                            "Tensor::map: strided mapping needs {total_bytes} bytes \
                             but DMA buffer only has {available_bytes} available \
                             (buf_size={}, mmap_offset={}, stride={stride}, height={height}); \
                             the row_stride was likely set larger than the original allocation",
                            dma.buf_size, dma.mmap_offset
                        )));
                    }
                    return dma.map_with_byte_size(total_bytes).map(TensorMap::Dma);
                }
            }
            return Err(Error::InvalidOperation(
                "CPU mapping of strided foreign tensors is not supported; \
                 use GPU path only"
                    .into(),
            ));
        }
        #[cfg(not(target_os = "linux"))]
        if self.row_stride.is_some() {
            return Err(Error::InvalidOperation(
                "CPU mapping of strided tensors is not supported on this \
                 platform (DMA backing is Linux-only)"
                    .into(),
            ));
        }
        if self.plane_offset.is_some_and(|o| o > 0) {
            #[cfg(target_os = "linux")]
            if !matches!(self.storage, TensorStorage::Dma(_)) {
                return Err(Error::InvalidOperation(
                    "plane offset only supported for DMA tensors".into(),
                ));
            }
            #[cfg(not(target_os = "linux"))]
            return Err(Error::InvalidOperation(
                "plane offset only supported for DMA tensors".into(),
            ));
        }
        self.storage.map()
    }
    // Delegates straight to the storage backend.
    fn buffer_identity(&self) -> &BufferIdentity {
        self.storage.buffer_identity()
    }
}
/// CPU mapping of a tensor; one variant per storage backend.
pub enum TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    #[cfg(target_os = "linux")]
    Dma(DmaMap<T>),
    #[cfg(unix)]
    Shm(ShmMap<T>),
    Mem(MemMap<T>),
    Pbo(PboMap<T>),
}
/// [`TensorMapTrait`] for the map enum: every accessor delegates to the
/// active backend's mapping.
impl<T> TensorMapTrait<T> for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    fn shape(&self) -> &[usize] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.shape(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.shape(),
            TensorMap::Mem(map) => map.shape(),
            TensorMap::Pbo(map) => map.shape(),
        }
    }
    fn unmap(&mut self) {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.unmap(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.unmap(),
            TensorMap::Mem(map) => map.unmap(),
            TensorMap::Pbo(map) => map.unmap(),
        }
    }
    fn as_slice(&self) -> &[T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_slice(),
            TensorMap::Mem(map) => map.as_slice(),
            TensorMap::Pbo(map) => map.as_slice(),
        }
    }
    fn as_mut_slice(&mut self) -> &mut [T] {
        match self {
            #[cfg(target_os = "linux")]
            TensorMap::Dma(map) => map.as_mut_slice(),
            #[cfg(unix)]
            TensorMap::Shm(map) => map.as_mut_slice(),
            TensorMap::Mem(map) => map.as_mut_slice(),
            TensorMap::Pbo(map) => map.as_mut_slice(),
        }
    }
}
impl<T> Deref for TensorMap<T>
where
    T: Num + Clone + fmt::Debug,
{
    type Target = [T];
    /// Borrows the mapped region as a flat slice; reuses the per-backend
    /// dispatch already implemented by [`TensorMapTrait::as_slice`].
    fn deref(&self) -> &[T] {
        self.as_slice()
    }
}
impl<T> DerefMut for TensorMap<T>
where
T: Num + Clone + fmt::Debug,
{
fn deref_mut(&mut self) -> &mut [T] {
match self {
#[cfg(target_os = "linux")]
TensorMap::Dma(map) => map.deref_mut(),
#[cfg(unix)]
TensorMap::Shm(map) => map.deref_mut(),
TensorMap::Mem(map) => map.deref_mut(),
TensorMap::Pbo(map) => map.deref_mut(),
}
}
}
// Cached result of the one-time DMA availability probe (Linux only).
#[cfg(target_os = "linux")]
static DMA_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
/// Returns whether DMA-backed tensors can be allocated on this system.
///
/// The first call probes by attempting a small (64-byte) DMA allocation;
/// the result is cached for the life of the process, so a DMA heap that
/// becomes available or unavailable later is not re-detected.
#[cfg(target_os = "linux")]
pub fn is_dma_available() -> bool {
*DMA_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Dma), None).is_ok())
}
/// Non-Linux stub: DMA-BUF allocation is Linux-only, so always `false`.
#[cfg(not(target_os = "linux"))]
pub fn is_dma_available() -> bool {
false
}
// Cached result of the one-time SHM availability probe (Unix only).
#[cfg(unix)]
static SHM_AVAILABLE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
/// Returns whether shared-memory-backed tensors can be allocated.
///
/// The first call probes by attempting a small (64-byte) SHM allocation;
/// the outcome is cached for the remainder of the process.
#[cfg(unix)]
pub fn is_shm_available() -> bool {
*SHM_AVAILABLE.get_or_init(|| Tensor::<u8>::new(&[64], Some(TensorMemory::Shm), None).is_ok())
}
/// Non-Unix stub: POSIX shared memory is unavailable, so always `false`.
#[cfg(not(unix))]
pub fn is_shm_available() -> bool {
false
}
#[cfg(test)]
mod dtype_tests {
    use super::*;
    /// Every dtype variant must report its exact byte width.
    #[test]
    fn dtype_size() {
        let expected = [
            (DType::U8, 1),
            (DType::I8, 1),
            (DType::U16, 2),
            (DType::I16, 2),
            (DType::U32, 4),
            (DType::I32, 4),
            (DType::U64, 8),
            (DType::I64, 8),
            (DType::F16, 2),
            (DType::F32, 4),
            (DType::F64, 8),
        ];
        for (dtype, bytes) in expected {
            assert_eq!(dtype.size(), bytes);
        }
    }
    /// Spot-check the canonical lowercase names.
    #[test]
    fn dtype_name() {
        for (dtype, name) in [
            (DType::U8, "u8"),
            (DType::F16, "f16"),
            (DType::F32, "f32"),
        ] {
            assert_eq!(dtype.name(), name);
        }
    }
    /// A dtype must survive a JSON serialize/deserialize round trip.
    #[test]
    fn dtype_serde_roundtrip() {
        let original = DType::F16;
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}
#[cfg(test)]
mod image_tests {
use super::*;
// A plain (non-image) tensor carries no pixel-format metadata at all.
#[test]
fn raw_tensor_has_no_format() {
let t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
assert!(t.format().is_none());
assert!(t.width().is_none());
assert!(t.height().is_none());
assert!(!t.is_multiplane());
assert!(t.chroma().is_none());
}
// Packed RGBA image: shape is rows x cols x channels and the format,
// width, and height metadata are populated.
#[test]
fn image_tensor_packed() {
let t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
assert_eq!(t.format(), Some(PixelFormat::Rgba));
assert_eq!(t.width(), Some(640));
assert_eq!(t.height(), Some(480));
assert_eq!(t.shape(), &[480, 640, 4]);
assert!(!t.is_multiplane());
}
// Planar RGB image: channel-major shape (plane, rows, cols).
#[test]
fn image_tensor_planar() {
let t = Tensor::<u8>::image(640, 480, PixelFormat::PlanarRgb, None).unwrap();
assert_eq!(t.format(), Some(PixelFormat::PlanarRgb));
assert_eq!(t.width(), Some(640));
assert_eq!(t.height(), Some(480));
assert_eq!(t.shape(), &[3, 480, 640]);
}
// NV12 allocated as one contiguous buffer: 480 luma rows plus 240 chroma
// rows gives the 720-row shape, and it is NOT reported as multiplane.
#[test]
fn image_tensor_semi_planar_contiguous() {
let t = Tensor::<u8>::image(640, 480, PixelFormat::Nv12, None).unwrap();
assert_eq!(t.format(), Some(PixelFormat::Nv12));
assert_eq!(t.width(), Some(640));
assert_eq!(t.height(), Some(480));
assert_eq!(t.shape(), &[720, 640]);
assert!(!t.is_multiplane());
}
// Allocates a DMA image whose row stride (12032 bytes) exceeds the packed
// row width (3004 * 4 = 12016), then checks that the logical dimensions
// are preserved, the mapping covers the whole strided buffer, and pixels
// written at stride-based offsets read back intact across remaps.
#[test]
#[cfg(target_os = "linux")]
fn image_tensor_with_stride_preserves_logical_width() {
if !is_dma_available() {
eprintln!("SKIPPED: DMA heap not available");
return;
}
let stride = 12032;
let t = Tensor::<u8>::image_with_stride(
3004,
1688,
PixelFormat::Rgba,
stride,
Some(TensorMemory::Dma),
)
.unwrap();
assert_eq!(t.width(), Some(3004));
assert_eq!(t.height(), Some(1688));
assert_eq!(t.shape(), &[1688, 3004, 4]);
assert_eq!(t.effective_row_stride(), Some(stride));
use crate::TensorMapTrait;
// The mapping must expose at least stride * height bytes.
{
let map = t.map().unwrap();
assert!(
map.as_slice().len() >= stride * 1688,
"mapped buffer {} bytes < expected {}",
map.as_slice().len(),
stride * 1688
);
}
// Write a recognizable pattern, addressing rows by the byte stride.
{
let mut map = t.map().unwrap();
let slice = map.as_mut_slice();
for y in 0..1688 {
let row_start = y * stride;
for x in 0..3004 {
let p = row_start + x * 4;
slice[p] = (y & 0xFF) as u8;
slice[p + 1] = (x & 0xFF) as u8;
slice[p + 2] = 0x42;
slice[p + 3] = 0xFF;
}
}
}
// Re-map and verify the pattern survived (first pixel and a mid pixel).
{
let map = t.map().unwrap();
let slice = map.as_slice();
assert_eq!(slice[0], 0x00);
assert_eq!(slice[1], 0x00);
assert_eq!(slice[2], 0x42);
assert_eq!(slice[3], 0xFF);
let mid = 100 * stride + 50 * 4;
assert_eq!(slice[mid], 100);
assert_eq!(slice[mid + 1], 50);
assert_eq!(slice[mid + 2], 0x42);
}
}
// A strided tensor wrapped around a foreign fd must refuse CPU mapping:
// the stride was set after wrapping, so map() cannot verify it against
// the original allocation.
//
// Fix: the original jammed two statements onto one line
// (`set_row_stride(...).unwrap(); let err = t.map();`); split per rustfmt.
#[test]
#[cfg(target_os = "linux")]
fn image_tensor_with_stride_rejects_foreign_strided_map() {
    if !is_dma_available() {
        eprintln!("SKIPPED: DMA heap not available");
        return;
    }
    // Allocate a real DMA buffer, then re-wrap its fd as a "foreign" tensor.
    let backing = Tensor::<u8>::new(&[240 * 320 * 4], Some(TensorMemory::Dma), None).unwrap();
    let fd = backing.clone_fd().unwrap();
    let shape = [240usize, 320, 4];
    let storage = TensorStorage::<u8>::from_fd(fd, &shape, None).unwrap();
    let mut t = Tensor::<u8>::wrap(storage);
    t.set_format(PixelFormat::Bgra).unwrap();
    t.set_row_stride(320 * 4).unwrap();
    let err = t.map();
    assert!(
        matches!(err, Err(Error::InvalidOperation(_))),
        "foreign strided map should error"
    );
}
// Enlarging the row stride after allocation must make map() fail instead
// of reading past the end of the original DMA buffer.
#[test]
#[cfg(target_os = "linux")]
fn image_tensor_with_stride_map_rejects_tampered_stride() {
if !is_dma_available() {
eprintln!("SKIPPED: DMA heap not available");
return;
}
let mut t = Tensor::<u8>::image_with_stride(
640,
480,
PixelFormat::Rgba,
3072,
Some(TensorMemory::Dma),
)
.unwrap();
// 12288 is 4x the allocated stride, so the mapping would overrun.
t.set_row_stride(12288).unwrap();
let err = t.map();
assert!(
matches!(err, Err(Error::InvalidOperation(_))),
"map() with oversized stride must return InvalidOperation"
);
}
// new_with_byte_size must detect overflow when multiplying out the shape
// (usize::MAX * 2 * 2 elements) instead of wrapping silently.
//
// Fix: the `#[cfg(target_os = "linux")]` used to guard an inner block,
// so on non-Linux the test ran with an empty body and vacuously passed.
// Hoisting it to the function attribute (matching the sibling tests)
// omits the test entirely where it cannot run.
#[test]
#[cfg(target_os = "linux")]
fn dma_tensor_new_with_byte_size_rejects_shape_overflow() {
    let err = crate::dma::DmaTensor::<u64>::new_with_byte_size(
        &[usize::MAX, 2, 2],
        usize::MAX,
        None,
    );
    assert!(
        matches!(err, Err(Error::InvalidArgument(_))),
        "new_with_byte_size must detect shape.product() overflow"
    );
}
// A stride smaller than the packed row width (640 * 4 = 2560 bytes) is
// rejected at construction time.
#[test]
#[cfg(target_os = "linux")]
fn image_tensor_with_stride_rejects_too_small_stride() {
let err = Tensor::<u8>::image_with_stride(
640,
480,
PixelFormat::Rgba,
2400,
Some(TensorMemory::Dma),
);
assert!(matches!(err, Err(Error::InvalidArgument(_))));
}
// Strided allocation is only implemented for packed formats; a
// semi-planar format like NV12 yields NotImplemented.
#[test]
#[cfg(target_os = "linux")]
fn image_tensor_with_stride_rejects_non_packed() {
let err = Tensor::<u8>::image_with_stride(
640,
480,
PixelFormat::Nv12,
640,
Some(TensorMemory::Dma),
);
assert!(matches!(err, Err(Error::NotImplemented(_))));
}
// set_format succeeds when the existing shape matches the format's layout
// and derives width/height from the shape.
#[test]
fn set_format_valid() {
let mut t = Tensor::<u8>::new(&[480, 640, 3], None, None).unwrap();
assert!(t.format().is_none());
t.set_format(PixelFormat::Rgb).unwrap();
assert_eq!(t.format(), Some(PixelFormat::Rgb));
assert_eq!(t.width(), Some(640));
assert_eq!(t.height(), Some(480));
}
// RGB requires 3 channels; a 4-channel shape is rejected and the format
// stays unset.
#[test]
fn set_format_invalid_shape() {
let mut t = Tensor::<u8>::new(&[480, 640, 4], None, None).unwrap();
let err = t.set_format(PixelFormat::Rgb);
assert!(err.is_err());
assert!(t.format().is_none());
}
// Reshaping away from the image layout must drop the format metadata.
#[test]
fn reshape_clears_format() {
let mut t = Tensor::<u8>::image(640, 480, PixelFormat::Rgba, None).unwrap();
assert_eq!(t.format(), Some(PixelFormat::Rgba));
t.reshape(&[480 * 640 * 4]).unwrap();
assert!(t.format().is_none());
}
// Building NV12 from separate Y and UV plane tensors yields a multiplane
// image with chroma info and the Y plane's dimensions.
#[test]
fn from_planes_nv12() {
let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
let img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
assert_eq!(img.format(), Some(PixelFormat::Nv12));
assert!(img.is_multiplane());
assert!(img.chroma().is_some());
assert_eq!(img.width(), Some(640));
assert_eq!(img.height(), Some(480));
}
// from_planes only accepts semi-planar formats; packed RGB is rejected.
#[test]
fn from_planes_rejects_non_semiplanar() {
let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
let err = Tensor::from_planes(y, uv, PixelFormat::Rgb);
assert!(err.is_err());
}
// A multiplane image cannot be reshaped into a flat tensor, even when the
// element count would match.
#[test]
fn reshape_multiplane_errors() {
let y = Tensor::<u8>::new(&[480, 640], None, None).unwrap();
let uv = Tensor::<u8>::new(&[240, 640], None, None).unwrap();
let mut img = Tensor::from_planes(y, uv, PixelFormat::Nv12).unwrap();
let err = img.reshape(&[480 * 640 + 240 * 640]);
assert!(err.is_err());
}
}
#[cfg(test)]
mod tests {
#[cfg(target_os = "linux")]
use nix::unistd::{access, AccessFlags};
#[cfg(target_os = "linux")]
use std::io::Write as _;
use std::sync::RwLock;
use super::*;
// Runs once before any test in this binary: installs the default logger
// (info level unless RUST_LOG overrides it).
#[ctor::ctor]
fn init() {
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
}
// Expands to the short name of the enclosing function, derived from
// std::any::type_name of a nested probe fn (the trailing "::f" is
// stripped, then everything before the last path separator is dropped).
// Used by the fd-leak tests for their SKIPPED log messages.
#[cfg(target_os = "linux")]
macro_rules! function {
() => {{
fn f() {}
fn type_name_of<T>(_: T) -> &'static str {
std::any::type_name::<T>()
}
let name = type_name_of(f);
match &name[..name.len() - 3].rfind(':') {
Some(pos) => &name[pos + 1..name.len() - 3],
None => &name[..name.len() - 3],
}
}};
}
// Default memory selection on Linux: DMA when a DMA heap is usable
// (probed with a throwaway DmaTensor), otherwise SHM.
#[test]
#[cfg(target_os = "linux")]
fn test_tensor() {
let _lock = FD_LOCK.read().unwrap();
let shape = vec![1];
let tensor = DmaTensor::<f32>::new(&shape, Some("dma_tensor"));
let dma_enabled = tensor.is_ok();
let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
match dma_enabled {
true => assert_eq!(tensor.memory(), TensorMemory::Dma),
false => assert_eq!(tensor.memory(), TensorMemory::Shm),
}
}
// Default memory selection on non-Linux Unix (e.g. macOS): either SHM or
// plain heap memory is acceptable.
#[test]
#[cfg(all(unix, not(target_os = "linux")))]
fn test_tensor() {
let shape = vec![1];
let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
assert!(
tensor.memory() == TensorMemory::Shm || tensor.memory() == TensorMemory::Mem,
"Expected SHM or Mem on macOS, got {:?}",
tensor.memory()
);
}
// Default memory selection on non-Unix platforms: plain heap memory only.
#[test]
#[cfg(not(unix))]
fn test_tensor() {
let shape = vec![1];
let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
assert_eq!(tensor.memory(), TensorMemory::Mem);
}
// End-to-end DMA tensor exercise: allocation metadata, mapping/filling,
// sharing through a duplicated fd, data persistence across remaps, and
// reshape rules. Returns early (with a warning) when no DMA heap is
// readable and writable.
#[test]
#[cfg(target_os = "linux")]
fn test_dma_tensor() {
let _lock = FD_LOCK.read().unwrap();
// Probe the two known DMA heap device nodes before committing to the test.
match access(
"/dev/dma_heap/linux,cma",
AccessFlags::R_OK | AccessFlags::W_OK,
) {
Ok(_) => println!("/dev/dma_heap/linux,cma is available"),
Err(_) => match access(
"/dev/dma_heap/system",
AccessFlags::R_OK | AccessFlags::W_OK,
) {
Ok(_) => println!("/dev/dma_heap/system is available"),
Err(e) => {
writeln!(
&mut std::io::stdout(),
"[WARNING] DMA Heap is unavailable: {e}"
)
.unwrap();
return;
}
},
}
let shape = vec![2, 3, 4];
let tensor =
DmaTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
const DUMMY_VALUE: f32 = 12.34;
assert_eq!(tensor.memory(), TensorMemory::Dma);
assert_eq!(tensor.name(), "test_tensor");
assert_eq!(tensor.shape(), &shape);
assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
assert_eq!(tensor.len(), 2 * 3 * 4);
// Fill through a mapping scoped so it unmaps before the next step.
{
let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
tensor_map.fill(42.0);
assert!(tensor_map.iter().all(|&x| x == 42.0));
}
// Share the same buffer through a duplicated fd and write through it.
{
let shared = Tensor::<f32>::from_fd(
tensor
.clone_fd()
.expect("Failed to duplicate tensor file descriptor"),
&shape,
Some("test_tensor_shared"),
)
.expect("Failed to create tensor from fd");
assert_eq!(shared.memory(), TensorMemory::Dma);
assert_eq!(shared.name(), "test_tensor_shared");
assert_eq!(shared.shape(), &shape);
let mut tensor_map = shared.map().expect("Failed to map DMA memory from fd");
tensor_map.fill(DUMMY_VALUE);
assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
}
// The write made through the shared fd must be visible via the original.
{
let tensor_map = tensor.map().expect("Failed to map DMA memory");
assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
}
// Reshape: size-changing shapes are rejected, same-size shapes succeed.
let mut tensor = DmaTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
assert_eq!(tensor.shape(), &shape);
let new_shape = vec![3, 4, 4];
assert!(
tensor.reshape(&new_shape).is_err(),
"Reshape should fail due to size mismatch"
);
assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
let new_shape = vec![2, 3, 4];
tensor.reshape(&new_shape).expect("Reshape should succeed");
assert_eq!(
tensor.shape(),
&new_shape,
"Shape should be updated after successful reshape"
);
{
let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
tensor_map.fill(1);
assert!(tensor_map.iter().all(|&x| x == 1));
}
// Indexed access through the map's Deref/DerefMut implementations.
{
let mut tensor_map = tensor.map().expect("Failed to map DMA memory");
tensor_map[2] = 42;
assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
}
}
// Same exercise as test_dma_tensor but for POSIX shared memory: metadata,
// map/fill, sharing via duplicated fd, persistence, and reshape rules.
#[test]
#[cfg(unix)]
fn test_shm_tensor() {
let _lock = FD_LOCK.read().unwrap();
let shape = vec![2, 3, 4];
let tensor =
ShmTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
assert_eq!(tensor.shape(), &shape);
assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
assert_eq!(tensor.name(), "test_tensor");
const DUMMY_VALUE: f32 = 12.34;
// Fill through a mapping scoped so it unmaps before the next step.
{
let mut tensor_map = tensor.map().expect("Failed to map shared memory");
tensor_map.fill(42.0);
assert!(tensor_map.iter().all(|&x| x == 42.0));
}
// Share the same buffer through a duplicated fd and write through it.
{
let shared = Tensor::<f32>::from_fd(
tensor
.clone_fd()
.expect("Failed to duplicate tensor file descriptor"),
&shape,
Some("test_tensor_shared"),
)
.expect("Failed to create tensor from fd");
assert_eq!(shared.memory(), TensorMemory::Shm);
assert_eq!(shared.name(), "test_tensor_shared");
assert_eq!(shared.shape(), &shape);
let mut tensor_map = shared.map().expect("Failed to map shared memory from fd");
tensor_map.fill(DUMMY_VALUE);
assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
}
// The write made through the shared fd must be visible via the original.
{
let tensor_map = tensor.map().expect("Failed to map shared memory");
assert!(tensor_map.iter().all(|&x| x == DUMMY_VALUE));
}
// Reshape: size-changing shapes are rejected, same-size shapes succeed.
let mut tensor = ShmTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
assert_eq!(tensor.shape(), &shape);
let new_shape = vec![3, 4, 4];
assert!(
tensor.reshape(&new_shape).is_err(),
"Reshape should fail due to size mismatch"
);
assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
let new_shape = vec![2, 3, 4];
tensor.reshape(&new_shape).expect("Reshape should succeed");
assert_eq!(
tensor.shape(),
&new_shape,
"Shape should be updated after successful reshape"
);
{
let mut tensor_map = tensor.map().expect("Failed to map shared memory");
tensor_map.fill(1);
assert!(tensor_map.iter().all(|&x| x == 1));
}
// Indexed access through the map's Deref/DerefMut implementations.
{
let mut tensor_map = tensor.map().expect("Failed to map shared memory");
tensor_map[2] = 42;
assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
}
}
// Plain heap-memory tensor: metadata, map/fill, and reshape rules.
// No fd-sharing section since Mem tensors have no file descriptor.
#[test]
fn test_mem_tensor() {
let shape = vec![2, 3, 4];
let tensor =
MemTensor::<f32>::new(&shape, Some("test_tensor")).expect("Failed to create tensor");
assert_eq!(tensor.shape(), &shape);
assert_eq!(tensor.size(), 2 * 3 * 4 * std::mem::size_of::<f32>());
assert_eq!(tensor.name(), "test_tensor");
{
let mut tensor_map = tensor.map().expect("Failed to map memory");
tensor_map.fill(42.0);
assert!(tensor_map.iter().all(|&x| x == 42.0));
}
// Reshape: size-changing shapes are rejected, same-size shapes succeed.
let mut tensor = MemTensor::<u8>::new(&shape, None).expect("Failed to create tensor");
assert_eq!(tensor.shape(), &shape);
let new_shape = vec![3, 4, 4];
assert!(
tensor.reshape(&new_shape).is_err(),
"Reshape should fail due to size mismatch"
);
assert_eq!(tensor.shape(), &shape, "Shape should remain unchanged");
let new_shape = vec![2, 3, 4];
tensor.reshape(&new_shape).expect("Reshape should succeed");
assert_eq!(
tensor.shape(),
&new_shape,
"Shape should be updated after successful reshape"
);
{
let mut tensor_map = tensor.map().expect("Failed to map memory");
tensor_map.fill(1);
assert!(tensor_map.iter().all(|&x| x == 1));
}
// Indexed access through the map's Deref/DerefMut implementations.
{
let mut tensor_map = tensor.map().expect("Failed to map memory");
tensor_map[2] = 42;
assert_eq!(tensor_map[1], 1, "Value at index 1 should be 1");
assert_eq!(tensor_map[2], 42, "Value at index 2 should be 42");
}
}
// Allocates and maps 100 DMA tensors, then checks the process-wide fd
// count (via procfs) is unchanged. Takes the WRITE lock so no other test
// can open descriptors concurrently and skew the count.
#[test]
#[cfg(target_os = "linux")]
fn test_dma_no_fd_leaks() {
let _lock = FD_LOCK.write().unwrap();
if !is_dma_available() {
log::warn!(
"SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
function!()
);
return;
}
let proc = procfs::process::Process::myself()
.expect("Failed to get current process using /proc/self");
let start_open_fds = proc
.fd_count()
.expect("Failed to get open file descriptor count");
for _ in 0..100 {
let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None)
.expect("Failed to create tensor");
let mut map = tensor.map().unwrap();
map.as_mut_slice().fill(233);
}
let end_open_fds = proc
.fd_count()
.expect("Failed to get open file descriptor count");
assert_eq!(
start_open_fds, end_open_fds,
"File descriptor leak detected: {} -> {}",
start_open_fds, end_open_fds
);
}
// Repeatedly wraps a duplicated fd of one DMA tensor (100 iterations) and
// verifies the process fd count is unchanged afterwards; the original is
// dropped before counting so its own fd is accounted for.
#[test]
#[cfg(target_os = "linux")]
fn test_dma_from_fd_no_fd_leaks() {
let _lock = FD_LOCK.write().unwrap();
if !is_dma_available() {
log::warn!(
"SKIPPED: {} - DMA memory allocation not available (permission denied or no DMA-BUF support)",
function!()
);
return;
}
let proc = procfs::process::Process::myself()
.expect("Failed to get current process using /proc/self");
let start_open_fds = proc
.fd_count()
.expect("Failed to get open file descriptor count");
let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Dma), None).unwrap();
for _ in 0..100 {
let tensor =
Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
let mut map = tensor.map().unwrap();
map.as_mut_slice().fill(233);
}
drop(orig);
let end_open_fds = proc.fd_count().unwrap();
assert_eq!(
start_open_fds, end_open_fds,
"File descriptor leak detected: {} -> {}",
start_open_fds, end_open_fds
);
}
// SHM analogue of test_dma_no_fd_leaks: 100 allocate-and-map cycles must
// not change the process fd count. Linux-only because it uses procfs.
#[test]
#[cfg(target_os = "linux")]
fn test_shm_no_fd_leaks() {
let _lock = FD_LOCK.write().unwrap();
if !is_shm_available() {
log::warn!(
"SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
function!()
);
return;
}
let proc = procfs::process::Process::myself()
.expect("Failed to get current process using /proc/self");
let start_open_fds = proc
.fd_count()
.expect("Failed to get open file descriptor count");
for _ in 0..100 {
let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
.expect("Failed to create tensor");
let mut map = tensor.map().unwrap();
map.as_mut_slice().fill(233);
}
let end_open_fds = proc
.fd_count()
.expect("Failed to get open file descriptor count");
assert_eq!(
start_open_fds, end_open_fds,
"File descriptor leak detected: {} -> {}",
start_open_fds, end_open_fds
);
}
// SHM analogue of test_dma_from_fd_no_fd_leaks: wrapping a duplicated fd
// 100 times must not leak descriptors. Linux-only because it uses procfs.
#[test]
#[cfg(target_os = "linux")]
fn test_shm_from_fd_no_fd_leaks() {
let _lock = FD_LOCK.write().unwrap();
if !is_shm_available() {
log::warn!(
"SKIPPED: {} - SHM memory allocation not available (permission denied or no SHM support)",
function!()
);
return;
}
let proc = procfs::process::Process::myself()
.expect("Failed to get current process using /proc/self");
let start_open_fds = proc
.fd_count()
.expect("Failed to get open file descriptor count");
let orig = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None).unwrap();
for _ in 0..100 {
let tensor =
Tensor::<u8>::from_fd(orig.clone_fd().unwrap(), orig.shape(), None).unwrap();
let mut map = tensor.map().unwrap();
map.as_mut_slice().fill(233);
}
drop(orig);
let end_open_fds = proc.fd_count().unwrap();
assert_eq!(
start_open_fds, end_open_fds,
"File descriptor leak detected: {} -> {}",
start_open_fds, end_open_fds
);
}
// ndarray interop (behind the "ndarray" feature): view()/view_mut()
// expose the tensor's shape and alias the same memory as the flat map.
#[cfg(feature = "ndarray")]
#[test]
fn test_ndarray() {
let _lock = FD_LOCK.read().unwrap();
let shape = vec![2, 3, 4];
let tensor = Tensor::<f32>::new(&shape, None, None).expect("Failed to create tensor");
let mut tensor_map = tensor.map().expect("Failed to map tensor memory");
tensor_map.fill(1.0);
let view = tensor_map.view().expect("Failed to get ndarray view");
assert_eq!(view.shape(), &[2, 3, 4]);
assert!(view.iter().all(|&x| x == 1.0));
let mut view_mut = tensor_map
.view_mut()
.expect("Failed to get mutable ndarray view");
view_mut[[0, 0, 0]] = 42.0;
assert_eq!(view_mut[[0, 0, 0]], 42.0);
// Writing via the ndarray view must be visible through the flat map.
assert_eq!(tensor_map[0], 42.0, "Value at index 0 should be 42");
}
/// Every freshly minted `BufferIdentity` must carry a distinct id.
#[test]
fn test_buffer_identity_unique() {
    let first = BufferIdentity::new();
    let second = BufferIdentity::new();
    assert_ne!(
        first.id(),
        second.id(),
        "Two identities should have different ids"
    );
}
// The weak handle stays upgradeable while any clone of the identity is
// alive, and dies only once the last clone is dropped (clones share the
// same id and the same Arc guard).
#[test]
fn test_buffer_identity_clone_shares_guard() {
let id1 = BufferIdentity::new();
let weak = id1.weak();
assert!(
weak.upgrade().is_some(),
"Weak should be alive while original exists"
);
let id2 = id1.clone();
assert_eq!(id1.id(), id2.id(), "Cloned identity should have same id");
drop(id1);
assert!(
weak.upgrade().is_some(),
"Weak should still be alive (clone holds Arc)"
);
drop(id2);
assert!(
weak.upgrade().is_none(),
"Weak should be dead after all clones dropped"
);
}
/// Independently allocated tensors must expose distinct buffer identities.
#[test]
fn test_tensor_buffer_identity() {
    let first = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t1")).unwrap();
    let second = Tensor::<u8>::new(&[100], Some(TensorMemory::Mem), Some("t2")).unwrap();
    assert_ne!(
        first.buffer_identity().id(),
        second.buffer_identity().id(),
        "Different tensors should have different buffer ids"
    );
}
// Per-tensor constructors: the asymmetric form stores a zero point, the
// symmetric form stores none.
#[test]
fn test_quantization_per_tensor_constructors() {
let q = Quantization::per_tensor(0.1, -5);
assert!(q.is_per_tensor());
assert!(!q.is_per_channel());
assert!(!q.is_symmetric());
assert_eq!(q.scale(), &[0.1]);
assert_eq!(q.zero_point(), Some(&[-5][..]));
let qs = Quantization::per_tensor_symmetric(0.05);
assert!(qs.is_per_tensor());
assert!(qs.is_symmetric());
assert_eq!(qs.zero_point(), None);
}
// Per-channel constructors keep the scale vector and the channel axis.
#[test]
fn test_quantization_per_channel_constructors() {
let q = Quantization::per_channel(vec![0.1, 0.2, 0.3], vec![0, -1, 1], 2).unwrap();
assert!(q.is_per_channel());
assert!(!q.is_symmetric());
assert_eq!(q.axis(), Some(2));
assert_eq!(q.scale().len(), 3);
let qs = Quantization::per_channel_symmetric(vec![0.054, 0.089, 0.195], 0).unwrap();
assert!(qs.is_per_channel());
assert!(qs.is_symmetric());
assert_eq!(qs.axis(), Some(0));
}
// scale and zero_point vectors of different lengths are rejected.
#[test]
fn test_quantization_per_channel_length_mismatch_rejected() {
let err = Quantization::per_channel(vec![0.1, 0.2], vec![0, 0, 0], 0).unwrap_err();
assert!(matches!(err, Error::QuantizationInvalid { .. }));
}
// An empty per-channel scale vector is rejected.
#[test]
fn test_quantization_per_channel_empty_rejected() {
let err = Quantization::per_channel_symmetric(vec![], 0).unwrap_err();
assert!(matches!(err, Error::QuantizationInvalid { .. }));
}
// Malformed quantization structures that survive serde deserialization
// (empty scale, per-tensor scale with per-channel zero points, mismatched
// per-channel lengths) must still be rejected by set_quantization.
#[test]
fn test_quantization_validate_rejects_malformed_deserialize() {
let mut t = Tensor::<i8>::new(&[1, 1, 4], Some(TensorMemory::Mem), None).unwrap();
let q: Quantization = serde_json::from_str(r#"{"scale": []}"#).unwrap();
assert!(matches!(
t.set_quantization(q).unwrap_err(),
Error::QuantizationInvalid { .. }
));
let q: Quantization =
serde_json::from_str(r#"{"scale": 0.1, "zero_point": [0, 0, 0]}"#).unwrap();
assert!(matches!(
t.set_quantization(q).unwrap_err(),
Error::QuantizationInvalid { .. }
));
let q: Quantization = serde_json::from_str(
r#"{"scale": [0.1, 0.2, 0.3, 0.4], "zero_point": [0, 0], "axis": 2}"#,
)
.unwrap();
assert!(matches!(
t.set_quantization(q).unwrap_err(),
Error::QuantizationInvalid { .. }
));
}
// mode() must report the variant matching each constructor, carrying the
// original scale/zero-point/axis values.
#[test]
fn test_quantization_mode_dispatch() {
let pt = Quantization::per_tensor(0.1, -5);
assert!(matches!(
pt.mode(),
QuantMode::PerTensor { scale, zero_point } if scale == 0.1 && zero_point == -5
));
let pts = Quantization::per_tensor_symmetric(0.05);
assert!(matches!(
pts.mode(),
QuantMode::PerTensorSymmetric { scale } if scale == 0.05
));
let pc = Quantization::per_channel(vec![0.1, 0.2], vec![0, -1], 2).unwrap();
assert!(matches!(pc.mode(), QuantMode::PerChannel { axis: 2, .. }));
let pcs = Quantization::per_channel_symmetric(vec![0.1, 0.2], 0).unwrap();
assert!(matches!(
pcs.mode(),
QuantMode::PerChannelSymmetric { axis: 0, .. }
));
}
// set_quantization / quantization / clear_quantization round trip on an
// integer tensor.
#[test]
fn test_tensor_quantization_roundtrip_integer() {
let mut t = Tensor::<i8>::new(&[2, 3, 4], Some(TensorMemory::Mem), None).unwrap();
assert!(t.quantization().is_none());
t.set_quantization(Quantization::per_tensor(0.1, -5))
.unwrap();
let q = t.quantization().unwrap();
assert_eq!(q.scale(), &[0.1]);
t.clear_quantization();
assert!(t.quantization().is_none());
}
// Builder-style with_quantization attaches quantization at construction.
#[test]
fn test_tensor_with_quantization_builder() {
let t = Tensor::<i8>::new(&[4, 4], Some(TensorMemory::Mem), None)
.unwrap()
.with_quantization(Quantization::per_tensor_symmetric(0.05))
.unwrap();
assert!(t.quantization().is_some());
}
// TensorDyn float variants never carry quantization metadata.
#[test]
fn test_tensor_dyn_quantization_float_arm_returns_none() {
let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
let td = TensorDyn::F32(t);
assert!(td.quantization().is_none());
}
// Attaching quantization to a TensorDyn float variant is an error.
#[test]
fn test_tensor_dyn_set_quantization_float_arm_errors() {
let t = Tensor::<f32>::new(&[2, 2], Some(TensorMemory::Mem), None).unwrap();
let mut td = TensorDyn::F32(t);
let err = td
.set_quantization(Quantization::per_tensor(0.1, 0))
.unwrap_err();
assert!(matches!(err, Error::QuantizationInvalid { .. }));
}
// Intentionally empty; presumably an anchor referenced by compile-fail
// doctests elsewhere — TODO confirm against the doctest sources.
fn _compile_fail_doctest_anchor() {}
// Serializes fd-sensitive tests: the fd-leak tests take the write lock
// (exclusive) while ordinary tests take read locks, so leak tests never
// race other tests that open file descriptors.
pub static FD_LOCK: RwLock<()> = RwLock::new(());
// The non-Linux stub of is_dma_available() must always report false.
#[test]
#[cfg(not(target_os = "linux"))]
fn test_dma_not_available_on_non_linux() {
assert!(
!is_dma_available(),
"DMA memory allocation should NOT be available on non-Linux platforms"
);
}
// On Unix, SHM allocation must be available, and a SHM tensor must be
// mappable, writable, and readable end to end.
#[test]
#[cfg(unix)]
fn test_shm_available_and_usable() {
assert!(
is_shm_available(),
"SHM memory allocation should be available on Unix systems"
);
let tensor = Tensor::<u8>::new(&[100, 100], Some(TensorMemory::Shm), None)
.expect("Failed to create SHM tensor");
let mut map = tensor.map().expect("Failed to map SHM tensor");
map.as_mut_slice().fill(0xAB);
assert!(
map.as_slice().iter().all(|&b| b == 0xAB),
"SHM tensor data should be writable and readable"
);
}
}