ort 2.0.0-rc.12

A safe Rust wrapper for ONNX Runtime 1.24 — optimize and accelerate machine learning inference and training.
Documentation
//! Traits and types related to [`Tensor`](crate::value::Tensor)s.

use alloc::{string::String, vec::Vec};
use core::{
	fmt,
	ops::{Deref, DerefMut}
};

use smallvec::{SmallVec, smallvec};

#[derive(Default, Clone, PartialEq, Eq)]
pub struct Shape {
	inner: SmallVec<[i64; 4]>
}

impl Shape {
	pub fn new(dims: impl IntoIterator<Item = i64>) -> Self {
		Self { inner: dims.into_iter().collect() }
	}

	pub fn empty(rank: usize) -> Self {
		Self { inner: smallvec![0; rank] }
	}

	#[doc(alias = "numel")]
	pub fn num_elements(&self) -> usize {
		let mut size = 1usize;
		for dim in &self.inner {
			if *dim < 0 {
				return 0;
			}
			size *= *dim as usize;
		}
		size
	}

	#[cfg(feature = "ndarray")]
	#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
	pub fn to_ixdyn(&self) -> ::ndarray::IxDyn {
		use ::ndarray::IntoDimension;
		self.inner.iter().map(|d| *d as usize).collect::<Vec<usize>>().into_dimension()
	}
}

/// Debug-formats the shape as a list of its dimensions, e.g. `[1, 3, 224, 224]`.
impl fmt::Debug for Shape {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let dims: &[i64] = &self.inner;
		f.debug_list().entries(dims).finish()
	}
}

/// Formats the shape as a bracketed, comma-separated list of dimensions, e.g. `[1, 3, 224, 224]`.
///
/// The output matches the `Debug` representation, including the multi-line layout produced by
/// the alternate (`{:#}`) flag.
impl fmt::Display for Shape {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let mut list = f.debug_list();
		for dim in self.inner.iter() {
			list.entry(dim);
		}
		list.finish()
	}
}

/// Builds a shape from unsigned dimensions; each value is converted to `i64` with an `as` cast.
impl From<Vec<usize>> for Shape {
	fn from(value: Vec<usize>) -> Self {
		let dims = value.iter().map(|&dim| dim as i64);
		Shape::new(dims)
	}
}

/// Builds a shape whose dimensions are the elements of `value`, in order.
impl From<Vec<i64>> for Shape {
	fn from(value: Vec<i64>) -> Self {
		let inner = SmallVec::from_vec(value);
		Self { inner }
	}
}

impl From<&[usize]> for Shape {
	fn from(value: &[usize]) -> Self {
		Shape::new(value.iter().map(|x| *x as i64))
	}
}

/// Builds a shape by copying the dimensions out of `value`.
impl From<&[i64]> for Shape {
	fn from(value: &[i64]) -> Self {
		let inner = SmallVec::from_slice(value);
		Self { inner }
	}
}

impl<const N: usize> From<[usize; N]> for Shape {
	fn from(value: [usize; N]) -> Self {
		Shape::new(value.into_iter().map(|x| x as i64))
	}
}

impl<const N: usize> From<[i64; N]> for Shape {
	fn from(value: [i64; N]) -> Self {
		Self { inner: SmallVec::from_slice(&value) }
	}
}

impl FromIterator<usize> for Shape {
	fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
		Self {
			inner: iter.into_iter().map(|x| x as i64).collect()
		}
	}
}

impl FromIterator<i64> for Shape {
	fn from_iter<T: IntoIterator<Item = i64>>(iter: T) -> Self {
		Self { inner: iter.into_iter().collect() }
	}
}

/// Exposes the dimensions as a slice, so e.g. `shape.len()` is the rank and `shape[i]` is axis `i`.
impl Deref for Shape {
	type Target = [i64];

	fn deref(&self) -> &Self::Target {
		self.inner.as_slice()
	}
}

/// Allows individual dimensions to be modified in place (the rank itself cannot change via a slice).
impl DerefMut for Shape {
	fn deref_mut(&mut self) -> &mut Self::Target {
		self.inner.as_mut_slice()
	}
}

/// Symbolic names for a tensor's dimensions, one `String` per axis in axis order.
///
/// NOTE(review): presumably these correspond axis-by-axis to a [`Shape`]; an empty string appears
/// to mean "no symbolic name" — confirm against callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SymbolicDimensions(SmallVec<[String; 4]>);

impl SymbolicDimensions {
	/// Creates symbolic dimensions from an iterator of names, in axis order.
	pub fn new(dims: impl IntoIterator<Item = String>) -> Self {
		let names: SmallVec<[String; 4]> = dims.into_iter().collect();
		Self(names)
	}

	/// Creates `rank` symbolic dimensions, each with an empty name.
	pub fn empty(rank: usize) -> Self {
		Self::new((0..rank).map(|_| String::new()))
	}
}

impl FromIterator<String> for SymbolicDimensions {
	fn from_iter<T: IntoIterator<Item = String>>(iter: T) -> Self {
		Self(iter.into_iter().collect())
	}
}

/// Exposes the dimension names as a read-only slice of `String`s.
impl Deref for SymbolicDimensions {
	type Target = [String];

	fn deref(&self) -> &Self::Target {
		self.0.as_slice()
	}
}