use crate::prelude::*;
/// One dimension of a strided layout: a length plus two strides.
///
/// NOTE(review): `is`/`os` look like input-side and output-side strides
/// (the constructors on `Tensor` set them identically and measure them in
/// elements, not bytes) — confirm against callers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IoDim {
    /// Number of elements along this dimension.
    pub n: usize,
    /// Input stride (may be negative).
    pub is: isize,
    /// Output stride (may be negative).
    pub os: isize,
}
impl IoDim {
    /// Builds a dimension from its length `n` and its two strides, stored verbatim.
    #[must_use]
    #[inline]
    pub const fn new(n: usize, is: isize, os: isize) -> Self {
        Self { n, is, os }
    }

    /// A dimension of length `n` whose input and output strides are both 1.
    #[must_use]
    #[inline]
    pub const fn contiguous(n: usize) -> Self {
        Self { n, is: 1, os: 1 }
    }

    /// Returns `true` when both strides are exactly 1 (the length is not consulted).
    #[must_use]
    #[inline]
    pub const fn is_contiguous(&self) -> bool {
        matches!((self.is, self.os), (1, 1))
    }

    /// Returns `true` when the input and output strides are equal
    /// (the length is not consulted).
    #[must_use]
    #[inline]
    pub const fn is_inplace_compatible(&self) -> bool {
        let (input_stride, output_stride) = (self.is, self.os);
        input_stride == output_stride
    }
}
/// An ordered list of [`IoDim`]s describing a strided multi-dimensional layout.
///
/// The `rankN` constructors and `Tensor::is_contiguous` treat index 0 as the
/// outermost (largest-stride) dimension and the last entry as the innermost.
/// `Default` yields a rank-0 tensor with no dimensions.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct Tensor {
    /// Dimensions, outermost first.
    pub dims: Vec<IoDim>,
}
impl Tensor {
    /// Wraps an explicit list of dimensions, kept in the order given.
    #[inline]
    #[must_use]
    pub fn new(dims: Vec<IoDim>) -> Self {
        Self { dims }
    }

    /// A rank-0 tensor (no dimensions).
    #[inline]
    #[must_use]
    pub fn empty() -> Self {
        Self { dims: Vec::new() }
    }

    /// A contiguous 1-D tensor of length `n` (input and output stride 1).
    #[inline]
    #[must_use]
    pub fn rank1(n: usize) -> Self {
        Self::new(vec![IoDim::contiguous(n)])
    }

    /// A row-major contiguous `n0 x n1` tensor: dim 0 has stride `n1`,
    /// dim 1 has stride 1 (input and output alike).
    ///
    /// `n1` is converted with `as isize`, so it is assumed to fit in `isize`.
    #[inline]
    #[must_use]
    pub fn rank2(n0: usize, n1: usize) -> Self {
        Self::new(vec![
            IoDim::new(n0, n1 as isize, n1 as isize),
            IoDim::contiguous(n1),
        ])
    }

    /// A row-major contiguous `n0 x n1 x n2` tensor: strides are
    /// `n1 * n2`, `n2` and 1 (input and output alike).
    ///
    /// `n1 * n2` is computed without an explicit overflow check and converted
    /// with `as isize`, so it is assumed to fit in `isize`.
    #[inline]
    #[must_use]
    pub fn rank3(n0: usize, n1: usize, n2: usize) -> Self {
        let stride0 = (n1 * n2) as isize;
        let stride1 = n2 as isize;
        Self::new(vec![
            IoDim::new(n0, stride0, stride0),
            IoDim::new(n1, stride1, stride1),
            IoDim::contiguous(n2),
        ])
    }

    /// Number of dimensions.
    #[inline]
    #[must_use]
    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    /// Returns `true` when the tensor has rank 0.
    ///
    /// This is about rank, not element count: a rank-0 tensor still has a
    /// `total_size` of 1 (the empty product).
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.dims.is_empty()
    }

    /// Product of all dimension lengths; 1 for a rank-0 tensor.
    #[inline]
    #[must_use]
    pub fn total_size(&self) -> usize {
        self.dims.iter().map(|d| d.n).product()
    }

    /// Returns `true` when the dims describe a dense row-major layout.
    ///
    /// The innermost (last) dim must have input and output stride 1, and every
    /// outer dim must have both strides equal to the product of the lengths of
    /// all dims after it. A rank-0 tensor is contiguous.
    ///
    /// A required stride that does not fit in `isize` can never be matched,
    /// so such layouts report `false` rather than overflowing.
    #[inline]
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        // `None` means the running product of inner lengths overflowed `isize`.
        // That only matters if another (outer) dim still has to be checked; the
        // product after the outermost dim is never compared.
        let mut expected_stride = Some(1isize);
        for dim in self.dims.iter().rev() {
            let stride = match expected_stride {
                Some(stride) => stride,
                None => return false,
            };
            if dim.is != stride || dim.os != stride {
                return false;
            }
            expected_stride = stride.checked_mul(dim.n as isize);
        }
        true
    }

    /// Returns `true` when every dim has equal input and output strides
    /// (vacuously `true` for a rank-0 tensor).
    #[inline]
    #[must_use]
    pub fn is_inplace_compatible(&self) -> bool {
        self.dims.iter().all(|d| d.is_inplace_compatible())
    }

    /// Splits into `(dims[..axis], dims[axis..])`, copying both halves.
    ///
    /// # Panics
    ///
    /// Panics if `axis > self.rank()`.
    #[must_use]
    pub fn split(&self, axis: usize) -> (Self, Self) {
        let (outer, inner) = self.dims.split_at(axis);
        (Self::new(outer.to_vec()), Self::new(inner.to_vec()))
    }

    /// The first (outermost) dimension, if any.
    #[inline]
    #[must_use]
    pub fn first(&self) -> Option<&IoDim> {
        self.dims.first()
    }

    /// The last (innermost) dimension, if any.
    #[inline]
    #[must_use]
    pub fn last(&self) -> Option<&IoDim> {
        self.dims.last()
    }

    /// Returns a copy of the first dimension together with a new tensor holding
    /// the remaining dimensions, or `None` for a rank-0 tensor. `self` is not
    /// modified.
    #[must_use]
    pub fn pop_front(&self) -> Option<(IoDim, Self)> {
        self.dims
            .split_first()
            .map(|(first, rest)| (first.clone(), Self::new(rest.to_vec())))
    }

    /// Appends `dim` as the new innermost dimension.
    pub fn push(&mut self, dim: IoDim) {
        self.dims.push(dim);
    }

    /// Inserts `dim` as the new outermost dimension (shifts existing dims; O(rank)).
    pub fn push_front(&mut self, dim: IoDim) {
        self.dims.insert(0, dim);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_rank1() {
        let t = Tensor::rank1(256);
        assert_eq!(t.rank(), 1);
        assert_eq!(t.total_size(), 256);
        assert!(t.is_contiguous());
    }

    #[test]
    fn test_tensor_rank2() {
        let t = Tensor::rank2(64, 64);
        assert_eq!(t.rank(), 2);
        assert_eq!(t.total_size(), 4096);
        assert!(t.is_contiguous());
    }

    #[test]
    fn test_tensor_rank3_is_row_major() {
        let t = Tensor::rank3(4, 8, 16);
        assert_eq!(
            t.dims,
            vec![
                IoDim::new(4, 128, 128),
                IoDim::new(8, 16, 16),
                IoDim::contiguous(16),
            ]
        );
        assert_eq!(t.total_size(), 512);
        assert!(t.is_contiguous());
        assert!(t.is_inplace_compatible());
    }

    #[test]
    fn test_tensor_split() {
        let t = Tensor::rank3(4, 8, 16);
        let (outer, inner) = t.split(1);
        assert_eq!(outer.rank(), 1);
        assert_eq!(inner.rank(), 2);
        // Check which dims land on each side, not just how many.
        assert_eq!(outer.dims, vec![IoDim::new(4, 128, 128)]);
        assert_eq!(inner.dims, vec![IoDim::new(8, 16, 16), IoDim::contiguous(16)]);
    }

    #[test]
    fn test_empty_tensor() {
        let t = Tensor::empty();
        assert!(t.is_empty());
        assert_eq!(t.rank(), 0);
        // Rank 0 is a single element (empty product), not zero elements.
        assert_eq!(t.total_size(), 1);
        assert!(t.is_contiguous());
        assert!(t.is_inplace_compatible());
        assert!(t.pop_front().is_none());
        assert_eq!(t, Tensor::default());
    }

    #[test]
    fn test_non_contiguous_strides() {
        let strided = Tensor::new(vec![IoDim::new(4, 2, 2)]);
        assert!(!strided.is_contiguous());
        assert!(strided.is_inplace_compatible());

        let mismatched = Tensor::new(vec![IoDim::new(4, 1, 2)]);
        assert!(!mismatched.is_contiguous());
        assert!(!mismatched.is_inplace_compatible());
    }

    #[test]
    fn test_pop_front_push_front_roundtrip() {
        let t = Tensor::rank3(2, 3, 4);
        let (head, mut rest) = t.pop_front().expect("rank-3 tensor has a front dim");
        assert_eq!(head, IoDim::new(2, 12, 12));
        assert_eq!(rest.rank(), 2);
        rest.push_front(head);
        assert_eq!(rest, t);
    }

    #[test]
    fn test_push_appends_innermost() {
        let mut t = Tensor::rank1(3);
        t.push(IoDim::contiguous(5));
        assert_eq!(t.last(), Some(&IoDim::contiguous(5)));
        assert_eq!(t.first(), Some(&IoDim::contiguous(3)));
        assert!(!t.is_contiguous());
    }
}