#![doc = include_str!("../README.md")]
#![deny(warnings)]
use digit_layout::DigitLayout;
use ndarray_layout::{ArrayLayout, Endian::BigEndian};
use std::{
borrow::Cow,
ops::{Deref, DerefMut},
};
pub extern crate digit_layout;
pub extern crate ndarray_layout;
/// A tensor description: an element data type, an array layout, and a payload.
///
/// The payload type `T` is generic. The constructors on `Tensor<usize, N>` store the
/// total data size in bytes; views such as `as_ref`/`as_mut` carry references instead.
/// `N` is passed straight through to `ArrayLayout<N>` (NOTE(review): presumably the
/// inline dimension capacity of `ArrayLayout` — confirm against ndarray_layout).
#[derive(Clone)]
pub struct Tensor<T, const N: usize> {
    /// Digit layout of one element; its `nbytes()` is used as the element size.
    dt: DigitLayout,
    /// Shape, strides and offset of the elements.
    layout: ArrayLayout<N>,
    /// The payload paired with `dt` and `layout`.
    item: T,
}
impl<const N: usize> Tensor<usize, N> {
    /// Creates a contiguous, row-major (`BigEndian`) tensor description whose
    /// payload is the total data size in bytes.
    ///
    /// This is [`Tensor::from_dim_slice`] for a fixed-size shape.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Tensor::from_dim_slice`].
    pub fn new(dt: DigitLayout, shape: [usize; N]) -> Self {
        Self::from_dim_slice(dt, shape)
    }

    /// Creates a contiguous, row-major (`BigEndian`) tensor description from a
    /// shape slice; the payload is the total data size in bytes.
    ///
    /// `shape` is given in scalar units. When `dt` packs several scalars into one
    /// element (`dt.group_size() > 1`), the last dimension is divided by the group
    /// size so the stored layout counts whole elements. Strides are built with
    /// `dt.nbytes()` as the element size, so they are byte strides.
    ///
    /// # Panics
    ///
    /// Panics if `dt.group_size() > 1` and either `shape` is empty or its last
    /// dimension is not a multiple of the group size.
    pub fn from_dim_slice(dt: DigitLayout, shape: impl AsRef<[usize]>) -> Self {
        let shape = shape.as_ref();
        let shape = match dt.group_size() {
            // Ungrouped types: the shape is used as-is, no allocation needed.
            1 => Cow::Borrowed(shape),
            g => {
                let mut shape = shape.to_vec();
                let Some(last) = shape.last_mut() else {
                    panic!("a grouped digit layout (group size {g}) requires at least one dimension")
                };
                assert_eq!(
                    *last % g,
                    0,
                    "last dimension {last} is not a multiple of the group size {g}"
                );
                // Count whole groups along the innermost dimension.
                *last /= g;
                Cow::Owned(shape)
            }
        };
        let element_size = dt.nbytes();
        let layout = ArrayLayout::new_contiguous(&shape, BigEndian, element_size);
        let size = layout.num_elements() * element_size;
        Self {
            dt,
            layout,
            item: size,
        }
    }
}
impl<T, const N: usize> Tensor<T, N> {
    /// Returns the digit layout describing one element.
    pub const fn dt(&self) -> DigitLayout {
        self.dt
    }

    /// Borrows the full array layout (shape, strides and offset).
    pub const fn layout(&self) -> &ArrayLayout<N> {
        &self.layout
    }

    /// Extent of each dimension, in elements.
    pub fn shape(&self) -> &[usize] {
        self.layout().shape()
    }

    /// Stride of each dimension (byte strides for layouts built by `new`/`use_info`).
    pub fn strides(&self) -> &[isize] {
        self.layout().strides()
    }

    /// Offset of the first element within the underlying data.
    pub fn offset(&self) -> isize {
        self.layout().offset()
    }

    /// Borrows the payload.
    pub const fn get(&self) -> &T {
        &self.item
    }

    /// Mutably borrows the payload.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.item
    }

    /// Consumes the tensor and returns only its payload.
    pub fn take(self) -> T {
        let Self { item, .. } = self;
        item
    }

    /// Assembles a tensor from its parts; nothing is validated.
    pub const fn from_raw_parts(dt: DigitLayout, layout: ArrayLayout<N>, item: T) -> Self {
        Self { dt, layout, item }
    }

    /// Splits the tensor into `(dt, layout, payload)`.
    pub fn into_raw_parts(self) -> (DigitLayout, ArrayLayout<N>, T) {
        (self.dt, self.layout, self.item)
    }

    /// Returns a contiguous, row-major tensor with the same `dt` and shape whose
    /// payload is the total data size in bytes.
    ///
    /// The original strides and offset are discarded and the payload is not read.
    pub fn use_info(&self) -> Tensor<usize, N> {
        let bytes_per_element = self.dt.nbytes();
        let dense = ArrayLayout::new_contiguous(self.shape(), BigEndian, bytes_per_element);
        Tensor {
            item: dense.num_elements() * bytes_per_element,
            dt: self.dt,
            layout: dense,
        }
    }

    /// Returns `true` when every dimension merges into a single one whose stride
    /// equals `dt.nbytes()`, i.e. the elements are densely packed in row-major
    /// order. The offset is not taken into account.
    pub fn is_contiguous(&self) -> bool {
        let layout = self.layout();
        let Some(merged) = layout.merge_be(0, layout.ndim()) else {
            return false;
        };
        match *merged.strides() {
            [stride] => stride == self.dt.nbytes() as isize,
            _ => unreachable!(),
        }
    }
}
impl<T, const N: usize> Tensor<T, N> {
    /// Returns a view whose payload immutably borrows this tensor's payload.
    ///
    /// `dt` is copied and the layout is cloned.
    pub fn as_ref(&self) -> Tensor<&T, N> {
        let layout = self.layout.clone();
        Tensor {
            dt: self.dt,
            layout,
            item: &self.item,
        }
    }

    /// Returns a view whose payload mutably borrows this tensor's payload.
    ///
    /// `dt` is copied and the layout is cloned.
    pub fn as_mut(&mut self) -> Tensor<&mut T, N> {
        let dt = self.dt;
        let layout = self.layout.clone();
        let item = &mut self.item;
        Tensor { dt, layout, item }
    }

    /// Replaces the layout with `f(layout)`, keeping `dt` and the payload.
    pub fn transform(self, f: impl FnOnce(ArrayLayout<N>) -> ArrayLayout<N>) -> Self {
        Self {
            dt: self.dt,
            layout: f(self.layout),
            item: self.item,
        }
    }

    /// Converts the payload with `f`, keeping `dt` and the layout.
    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Tensor<U, N> {
        Tensor {
            dt: self.dt,
            layout: self.layout,
            item: f(self.item),
        }
    }

    /// Swaps the payload for `u`; returns the old payload and the new tensor.
    pub fn replace<U>(self, u: U) -> (T, Tensor<U, N>) {
        let replaced = Tensor {
            dt: self.dt,
            layout: self.layout,
            item: u,
        };
        (self.item, replaced)
    }
}
impl<T: Deref, const N: usize> Tensor<T, N> {
    /// Returns a view whose payload is the dereferenced payload (`&*item`).
    ///
    /// `dt` is copied and the layout is cloned.
    pub fn as_deref(&self) -> Tensor<&<T as Deref>::Target, N> {
        let target = Deref::deref(&self.item);
        Tensor {
            dt: self.dt,
            layout: self.layout.clone(),
            item: target,
        }
    }
}
impl<T: DerefMut, const N: usize> Tensor<T, N> {
    /// Returns a view whose payload is the mutably dereferenced payload (`&mut *item`).
    ///
    /// `dt` is copied and the layout is cloned.
    pub fn as_deref_mut(&mut self) -> Tensor<&mut <T as Deref>::Target, N> {
        let dt = self.dt;
        let layout = self.layout.clone();
        let target = DerefMut::deref_mut(&mut self.item);
        Tensor {
            dt,
            layout,
            item: target,
        }
    }
}
#[test]
fn test_basic_functions() {
    digit_layout::layout!(GROUP u(8); 32);
    // 7 x 1024 scalars grouped 32-wide => 7 x 32 elements of 32 bytes each.
    let original = Tensor::new(GROUP, [7, 1024]);
    assert_eq!(original.dt(), GROUP);

    let expected: ArrayLayout<2> = ArrayLayout::new_contiguous(&[7, 1024 / 32], BigEndian, 32);
    let actual = original.layout();
    assert_eq!(actual.shape(), expected.shape());
    assert_eq!(actual.strides(), expected.strides());
    assert_eq!(actual.offset(), expected.offset());

    assert_eq!(original.shape(), [7, 1024 / 32]);
    assert_eq!(original.strides(), &[1024, 32]);
    assert_eq!(original.offset(), 0);
    assert_eq!(*original.get(), 7 * 1024);

    // Mutating a clone must not affect the original.
    let mut copy = original.clone();
    *copy.get_mut() += 1;
    assert_eq!(*copy.get(), 7 * 1024 + 1);
    assert_eq!(original.take(), 7 * 1024)
}
#[test]
fn test_extra_functions() {
    digit_layout::layout!(GROUP u(8); 32);
    let payload = 7 * 1024;
    let reference: ArrayLayout<2> = ArrayLayout::new_contiguous(&[7, 1024], BigEndian, 32);

    // Round trip through from_raw_parts / into_raw_parts keeps every part intact.
    let tensor = Tensor::from_raw_parts(GROUP, reference.clone(), payload);
    assert_eq!(tensor.dt(), GROUP);
    assert_eq!(tensor.layout().shape(), reference.shape());
    assert_eq!(tensor.layout().strides(), reference.strides());
    assert_eq!(tensor.layout().offset(), reference.offset());
    assert_eq!(*tensor.get(), payload);

    let (dt, parts_layout, parts_item) = tensor.into_raw_parts();
    assert_eq!(dt, GROUP);
    assert_eq!(parts_layout.shape(), reference.shape());
    assert_eq!(parts_layout.strides(), reference.strides());
    assert_eq!(parts_layout.offset(), reference.offset());
    assert_eq!(parts_item, payload);

    // use_info rebuilds a contiguous description whose payload is the byte size.
    let tensor = Tensor::from_raw_parts(GROUP, reference, payload);
    let info = tensor.use_info();
    assert_eq!(info.dt(), GROUP);
    assert_eq!(info.shape(), tensor.shape());
    let expected_size = info.layout().num_elements() * GROUP.nbytes();
    assert_eq!(info.take(), expected_size);

    // The row-major layout is contiguous; one with strides [1, 7] is not.
    assert!(tensor.is_contiguous());
    let scattered: ArrayLayout<2> = ArrayLayout::new(&[7, 1024], &[1, 7], 0);
    assert!(!Tensor::from_raw_parts(GROUP, scattered, payload).is_contiguous())
}
#[test]
fn test_as_ref() {
    digit_layout::layout!(GROUP u(8); 32);
    let owner = Tensor::new(GROUP, [7, 1024]);
    // The borrowed view sees the same byte-size payload as its owner.
    let view = owner.as_ref();
    assert_eq!(**view.get(), 7168)
}
#[test]
fn test_as_mut() {
    digit_layout::layout!(GROUP u(8); 32);
    let mut owner = Tensor::new(GROUP, [7, 1024]);
    // A write through the mutable view must land in the owner's payload.
    let mut view = owner.as_mut();
    **view.get_mut() += 1;
    assert_eq!(*owner.get(), 7169)
}
#[test]
fn test_transform() {
    digit_layout::layout!(GROUP u(8); 32);
    let t1 = Tensor::new(GROUP, [7, 1024]);
    // Identity permutation: the layout must come back unchanged.
    fn trans(layout: ArrayLayout<2>) -> ArrayLayout<2> {
        layout.transpose(&[0, 1])
    }
    let t2 = t1.transform(trans);
    assert_eq!(t2.shape(), [7, 32]);
    assert_eq!(t2.strides(), &[1024, 32]);
    assert_eq!(t2.offset(), 0);

    // The identity case alone cannot detect a `transform` that ignores its
    // closure, so also install a visibly different layout. It must be adopted
    // verbatim, while `dt` and the payload stay untouched.
    let t3 = t2.transform(|_| ArrayLayout::new(&[32, 7], &[32, 1024], 0));
    assert_eq!(t3.dt(), GROUP);
    assert_eq!(t3.shape(), [32, 7]);
    assert_eq!(t3.strides(), &[32, 1024]);
    assert_eq!(t3.offset(), 0);
    assert_eq!(t3.take(), 7 * 1024)
}
#[test]
fn test_map() {
    digit_layout::layout!(GROUP u(8); 32);
    let bytes = Tensor::new(GROUP, [7, 1024]);
    // Convert the usize byte-size payload into an isize.
    let signed = bytes.map(|n| n as isize);
    assert_eq!(*signed.get(), 7 * 1024_isize)
}
#[test]
fn test_as_deref() {
    struct Wrapper<T>(T);
    impl<T> Deref for Wrapper<T> {
        type Target = T;
        fn deref(&self) -> &T {
            &self.0
        }
    }
    digit_layout::layout!(GROUP u(8); 32);
    let tensor = Tensor::from_raw_parts(
        GROUP,
        ArrayLayout::<2>::new_contiguous(&[7, 1024], BigEndian, 32),
        Wrapper(42),
    );
    // The view's payload is the value behind the wrapper.
    let view = tensor.as_deref();
    assert_eq!(**view.get(), 42)
}
#[test]
fn test_as_deref_mut() {
    struct Wrapper<T>(T);
    impl<T> Deref for Wrapper<T> {
        type Target = T;
        fn deref(&self) -> &T {
            &self.0
        }
    }
    impl<T> DerefMut for Wrapper<T> {
        fn deref_mut(&mut self) -> &mut T {
            &mut self.0
        }
    }
    digit_layout::layout!(GROUP u(8); 32);
    let mut tensor = Tensor::from_raw_parts(
        GROUP,
        ArrayLayout::<2>::new_contiguous(&[7, 1024], BigEndian, 32),
        Wrapper(42),
    );
    // A write through the dereferenced view must reach the wrapped value.
    let mut view = tensor.as_deref_mut();
    **view.get_mut() += 1;
    assert_eq!(**tensor.get(), 43)
}