use std::mem::MaybeUninit;
use tensorgraph_sys::{ptr::Ref, View, ViewMut, device::Device};
use crate::{
dims::{Dimension, RemoveDim},
storage::{IntoOwned, Storage, StorageMut},
};
mod matrix;
mod vector;
pub use matrix::*;
pub use vector::*;
/// A reference to a `[T]` slice located on device `D`.
pub type Slice<T, D> = Ref<[T], D>;
/// The slice type matching a storage `S` (same element type and device).
pub type ViewOf<S> = Slice<<S as Storage>::T, <S as Storage>::Device>;
/// An immutable borrowed tensor over a device slice.
pub type TensorView<'a, T, D, Dim> = Tensor<&'a Slice<T, D>, Dim>;
/// A mutable borrowed tensor over a device slice.
pub type TensorViewMut<'a, T, D, Dim> = Tensor<&'a mut Slice<T, D>, Dim>;
/// A mutable borrowed tensor whose elements may not yet be initialised.
pub type UninitTensor<'a, T, D, Dim> = TensorViewMut<'a, MaybeUninit<T>, D, Dim>;
/// A dense multidimensional array: `Dim`-dimensional indexing over backing
/// storage `S`, addressed via per-axis strides.
///
/// `Copy`/`Clone` are derived, so they apply whenever `S` and `Dim` are
/// themselves `Copy`/`Clone` (e.g. borrowed views and array-backed tensors).
#[derive(Copy, Clone)]
pub struct Tensor<S: Storage, Dim: Dimension> {
    // Extent of each axis.
    shape: Dim,
    // Element step for each axis; `from_shape` fills these in column-major
    // order, but transposition/axis-swaps may reorder them afterwards.
    strides: Dim,
    // Backing elements, living on `S::Device`.
    data: S,
}
/// A tensor can be cheaply re-borrowed as a read-only view sharing the same
/// data, shape and strides.
impl<S: Storage, Dim: Dimension> View for Tensor<S, Dim> {
    // NOTE: this uses the older GAT syntax with the `where` clause before
    // the `=`; keep as-is unless the crate's toolchain is known to accept
    // the stabilised `type Ref<'a> = … where Self: 'a;` form.
    type Ref<'a>
    where
        Self: 'a,
    = Tensor<&'a ViewOf<S>, Dim>;
    /// Borrows `self` immutably: clones shape/strides, re-borrows the data.
    fn view(&self) -> TensorView<S::T, S::Device, Dim> {
        Tensor {
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            data: self.data.as_ref(),
        }
    }
}
/// Mutable counterpart of [`View`]: re-borrows the tensor as a writable view
/// over the same data, shape and strides.
impl<S: StorageMut, Dim: Dimension> ViewMut for Tensor<S, Dim> {
    // Same pre-stabilisation GAT syntax as the `View` impl above.
    type Mut<'a>
    where
        Self: 'a,
    = Tensor<&'a mut ViewOf<S>, Dim>;
    /// Borrows `self` mutably: clones shape/strides, re-borrows the data.
    fn view_mut(&mut self) -> TensorViewMut<S::T, S::Device, Dim> {
        Tensor {
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            data: self.data.as_mut(),
        }
    }
}
impl<S: Storage, Dim: Dimension> Tensor<S, Dim> {
    /// Builds a tensor over `data`, interpreting it with `shape` in
    /// column-major (first-axis-fastest) layout.
    ///
    /// # Panics
    /// Panics if `data` does not hold exactly `shape.size()` elements.
    pub fn from_shape(shape: Dim, data: S) -> Self {
        assert_eq!(data.as_ref().len(), shape.size());
        let strides = shape.column_major_strides();
        Self { shape, strides, data }
    }

    /// Consumes the tensor, returning the underlying storage.
    pub fn into_inner(self) -> S {
        self.data
    }

    /// Reverses the order of all axes in place (a full transpose);
    /// no data is moved, only shape and strides are permuted.
    pub fn reverse_axes(&mut self) {
        self.shape.as_mut().reverse();
        self.strides.as_mut().reverse();
    }

    /// Exchanges axes `i` and `j` in place.
    ///
    /// # Panics
    /// Panics if `i` or `j` is out of bounds for the dimensionality.
    pub fn swap_axes(&mut self, i: usize, j: usize) {
        self.shape.as_mut().swap(i, j);
        self.strides.as_mut().swap(i, j);
    }

    /// Returns a transposed (axes-reversed) borrowed view of this tensor.
    pub fn t(&self) -> Tensor<&ViewOf<S>, Dim> {
        let mut transposed = self.view();
        transposed.reverse_axes();
        transposed
    }

    /// Converts this tensor into one that owns its data, cloning/copying
    /// the storage if it was borrowed. Shape and strides are preserved.
    pub fn into_owned(self) -> Tensor<S::Owned, Dim>
    where
        S: IntoOwned,
        S::Owned: Storage<T = S::T, Device = S::Device>,
    {
        Tensor {
            shape: self.shape,
            strides: self.strides,
            data: self.data.into_owned(),
        }
    }

    /// Takes the `n`-th hyperplane along `axis`, producing a view with one
    /// fewer dimension. The view's data starts at the plane's first element
    /// and runs to the end of the buffer; the reduced strides control which
    /// of those elements actually belong to the slice.
    ///
    /// # Panics
    /// Panics if `axis` is out of range or `n` exceeds that axis' extent.
    pub fn slice_axis(&self, axis: usize, n: usize) -> Tensor<&ViewOf<S>, Dim::Smaller>
    where
        Dim: RemoveDim,
    {
        assert!(axis < self.shape.as_ref().len());
        let (shape, axis_len) = self.shape.remove(axis);
        let (strides, stride) = self.strides.remove(axis);
        assert!(n < axis_len);
        let offset = stride * n;
        Tensor {
            shape,
            strides,
            data: &self.data.as_ref()[offset..],
        }
    }
}
impl<'a, T, D: Device, Dim: Dimension> UninitTensor<'a, T, D, Dim> {
    /// Converts a mutable tensor over `MaybeUninit<T>` into one over `T`,
    /// keeping the same shape and strides.
    ///
    /// # Safety
    /// The caller must guarantee that every element reachable through this
    /// tensor has been fully initialised before calling this.
    pub unsafe fn assume_init(self) -> TensorViewMut<'a, T, D, Dim> {
        Tensor {
            shape: self.shape,
            strides: self.strides,
            data: self.data.assume_init_mut(),
        }
    }
}
impl<'a, T, D: Device, Dim: Dimension> TensorView<'a, MaybeUninit<T>, D, Dim> {
    /// Converts an immutable tensor over `MaybeUninit<T>` into one over `T`,
    /// keeping the same shape and strides.
    ///
    /// # Safety
    /// The caller must guarantee that every element reachable through this
    /// tensor has been fully initialised before calling this.
    pub unsafe fn assume_init(self) -> TensorView<'a, T, D, Dim> {
        Tensor {
            shape: self.shape,
            strides: self.strides,
            data: self.data.assume_init(),
        }
    }
}
#[cfg(test)]
mod tests {
    use tensorgraph_sys::{View, ViewMut};

    use crate::tensor::{gemm, Tensor};

    #[test]
    fn matmul() {
        // (3x2) * (2x2), column-major layout.
        let a = Tensor::from_shape([3, 2], [0., 2., 4., 1., 3., 5.]);
        let b = Tensor::from_shape([2, 2], [0., 2., 1., 3.]);

        let c = a.matmul(b.view());
        assert_eq!(c.into_inner().into_std(), [2., 6., 10., 3., 11., 19.]);
    }

    #[test]
    fn matmul_t() {
        let a = Tensor::from_shape([2, 2], [1., 2., 3., 4.]);
        let b = Tensor::from_shape([2, 2], [5., 6., 7., 8.]);

        // transpose(a) * transpose(b)
        let c1 = a.t().matmul(b.t());
        assert_eq!(c1.into_inner().into_std(), [19.0, 43.0, 22.0, 50.0]);

        // a * transpose(b)
        let c2 = a.matmul(b.t());
        assert_eq!(c2.into_inner().into_std(), [26.0, 38.0, 30.0, 44.0]);

        // transpose(a) * b
        let c3 = a.t().matmul(b.view());
        assert_eq!(c3.into_inner().into_std(), [17.0, 39.0, 23.0, 53.0]);
    }

    #[test]
    fn slice() {
        let t = Tensor::from_shape([2, 3], [0., 1., 2., 3., 4., 5.]);

        // Slicing along axis 0 leaves shape [3] with stride 2.
        let s = t.slice_axis(0, 0);
        assert_eq!(&**s.into_inner(), [0., 1., 2., 3., 4., 5.]);
        assert_eq!(s.shape, [3]);
        assert_eq!(s.strides, [2]);

        let s = t.slice_axis(0, 1);
        assert_eq!(&**s.into_inner(), [1., 2., 3., 4., 5.]);
        assert_eq!(s.shape, [3]);
        assert_eq!(s.strides, [2]);

        // Slicing along axis 1 leaves shape [2] with stride 1.
        let s = t.slice_axis(1, 0);
        assert_eq!(&**s.into_inner(), [0., 1., 2., 3., 4., 5.]);
        assert_eq!(s.shape, [2]);
        assert_eq!(s.strides, [1]);

        let s = t.slice_axis(1, 1);
        assert_eq!(&**s.into_inner(), [2., 3., 4., 5.]);
        assert_eq!(s.shape, [2]);
        assert_eq!(s.strides, [1]);

        let s = t.slice_axis(1, 2);
        assert_eq!(&**s.into_inner(), [4., 5.]);
        assert_eq!(s.shape, [2]);
        assert_eq!(s.strides, [1]);
    }

    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda() {
        use crate::blas::cublas::CublasContext;
        use tensorgraph_sys::{
            device::cuda::{Context, Stream},
            Vec,
        };

        let ctx = Context::quick_init().unwrap();
        let stream = Stream::new(&ctx).unwrap();
        let stream = &*stream;

        // Upload the operands to the device.
        let a = Vec::copy_from_host_in(&[0., 2., 4., 1., 3., 5.], stream);
        let b = Vec::copy_from_host_in(&[0., 2., 1., 3.], stream);

        let cublas = CublasContext::new();
        let cublas = cublas.with_stream(Some(stream));

        let a = Tensor::from_shape([3, 2], a);
        let b = Tensor::from_shape([2, 2], b);

        let c = a.matmul_into(b.view(), cublas, stream);

        // Download the product and compare against the expected values.
        let mut out = vec![0.0_f32; 6];
        c.data.copy_to_host(&mut out);
        assert_eq!(out, vec![2., 6., 10., 3., 11., 19.]);
    }

    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda_global() {
        use crate::blas::cublas::CublasContext;
        use tensorgraph_sys::{
            device::cuda::{Context, Cuda, Stream},
            DefaultVec,
        };

        let ctx = Context::quick_init().unwrap();
        let stream = Stream::new(&ctx).unwrap();
        // Keep the guards alive for the duration of the test.
        let _handle = stream.as_global();

        let a = DefaultVec::<f32, Cuda>::copy_from_host(&[0., 2., 4., 1., 3., 5.]);
        let b = DefaultVec::<f32, Cuda>::copy_from_host(&[0., 2., 1., 3.]);

        let cublas = CublasContext::new();
        let _handle = cublas.with_stream(Some(&stream)).as_global();

        let a = Tensor::from_shape([3, 2], a);
        let b = Tensor::from_shape([2, 2], b);

        let c = a.matmul(b.view());

        let mut out = vec![0.0_f32; 6];
        c.data.copy_to_host(&mut out);
        assert_eq!(out, vec![2., 6., 10., 3., 11., 19.]);
    }

    #[test]
    fn matmul2() {
        let init = [0.001, 1.0, 1.0, 0.];
        let mut a = Tensor::from_shape([2, 2], init);
        let b = Tensor::from_shape([2, 2], init);
        let mut c = Tensor::from_shape([2, 2], [0.; 4]);

        // Iterate c = a * b, feeding each product back in as the next `a`.
        for _ in 0..1000 {
            gemm(1., a.view(), b.view(), 0., c.view_mut());
            std::mem::swap(&mut a, &mut c);
        }

        let out = c.into_inner();
        let expected = [
            1.1278865019586632,
            0.5210952168646452,
            0.5210952168646452,
            1.1273654067417986,
        ];
        assert_eq!(out, expected);
    }

    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda2() {
        use crate::{blas::cublas::CublasContext, tensor::gemm_ctx};
        use tensorgraph_sys::{
            device::cuda::{Context, Stream},
            Vec,
        };

        let ctx = Context::quick_init().unwrap();
        let stream = Stream::new(&ctx).unwrap();
        let stream = &*stream;

        let a = Vec::copy_from_host_in(&[0.001, 1.0, 1.0, 0.], stream);
        let b = a.clone();
        let c = b.clone();

        let cublas = CublasContext::new();
        let cublas = cublas.with_stream(Some(stream));

        let mut a = Tensor::from_shape([2, 2], a);
        let b = Tensor::from_shape([2, 2], b);
        let mut c = Tensor::from_shape([2, 2], c);

        // Same iteration as `matmul2`, but on the device.
        for _ in 0..1000 {
            gemm_ctx(cublas, 1., a.view(), b.view(), 0., c.view_mut());
            std::mem::swap(&mut a, &mut c);
        }

        let mut out = vec![0.; 4];
        c.data.copy_to_host(&mut out);

        let expected = [
            1.1278865019586632,
            0.5210952168646452,
            0.5210952168646452,
            1.1273654067417986,
        ];
        assert_eq!(out[0], expected[0]);
        assert_eq!(out[1], expected[1]);
        assert_eq!(out[2], expected[2]);
        assert_eq!(out[3], expected[3]);
    }
}