use rand::distributions::Distribution;
use super::*;
use crate::shapes::*;
use std::sync::Arc;
/// A tensor value, generic over its shape `S`, element type `E`, storage `D`
/// and tape `T` (the tape parameter defaults to `NoneTape`).
///
/// `Clone` and `Debug` are derived, so they are available whenever every type
/// parameter provides them. The buffer lives behind an [`Arc`], so cloning a
/// tensor clones the handle rather than the buffer itself.
#[derive(Debug, Clone)]
pub struct Tensor<S: Shape, E, D: Storage<E>, T = NoneTape> {
/// Identifier of this tensor; it is copied unchanged whenever the tape is
/// attached, removed or replaced by the helpers in this file.
pub(crate) id: UniqueId,
/// Reference-counted handle to the storage-specific buffer type `D::Vec`.
pub(crate) data: Arc<D::Vec>,
/// The shape value of type `S`.
pub(crate) shape: S,
/// Strides in the shape's `Concrete` representation.
/// NOTE(review): presumably one stride per dimension of `shape` - confirm
/// against the definition of `Shape::Concrete`.
pub(crate) strides: S::Concrete,
/// The storage/device value that `data` belongs to.
pub(crate) device: D,
/// The tape carried by this tensor (`NoneTape` unless specified otherwise).
pub(crate) tape: T,
}
/// Shape access for [`Tensor`], regardless of which tape it carries.
impl<S: Shape, E, D: Storage<E>, T> HasShape for Tensor<S, E, D, T> {
    /// The shape type is the tensor's own `S` parameter.
    type Shape = S;

    /// The same tensor type with only the shape parameter swapped for `New`.
    type WithShape<New: Shape> = Tensor<New, E, D, T>;

    /// Borrows the shape value stored in the tensor.
    fn shape(&self) -> &S {
        &self.shape
    }
}
// Exposes the tensor's element type parameter `E` as its `Unit` whenever
// `E` implements `Unit`.
impl<S: Shape, E: Unit, D: Storage<E>, T> HasUnitType for Tensor<S, E, D, T> {
type Unit = E;
}
// Exposes the tensor's element type parameter `E` as its `Dtype` whenever
// `E` implements `Dtype`.
impl<S: Shape, E: Dtype, D: Storage<E>, T> HasDtype for Tensor<S, E, D, T> {
type Dtype = E;
}
// A tensor reports the same error type as its storage `D`.
impl<S: Shape, E, D: Storage<E>, T> HasErr for Tensor<S, E, D, T> {
type Err = D::Err;
}
/// Conversion of a value into its traced counterpart, [`Trace::Traced`].
///
/// The two by-value methods are required; the two by-reference methods are
/// provided and simply clone `self` before delegating.
pub trait Trace<E, D: Storage<E>>: Clone {
    /// The type produced once tracing has been started.
    type Traced;

    /// Clones `self` and forwards the clone to [`Trace::leaky_traced`].
    fn leaky_trace(&self) -> Self::Traced {
        let owned = self.clone();
        owned.leaky_traced()
    }

    /// Consumes `self` and starts tracing without caller-supplied gradients.
    ///
    /// NOTE(review): the "leaky" naming suggests gradient storage created this
    /// way is not reclaimed - confirm against the implementors.
    fn leaky_traced(self) -> Self::Traced;

    /// Clones `self` and forwards the clone, together with `gradients`, to
    /// [`Trace::traced`].
    fn trace(&self, gradients: Gradients<E, D>) -> Self::Traced {
        let owned = self.clone();
        owned.traced(gradients)
    }

    /// Consumes `self` and starts tracing using the supplied `gradients`.
    fn traced(self, gradients: Gradients<E, D>) -> Self::Traced;
}
/// Tracing for tape-less tensors: the result is the same tensor carrying an
/// `OwnedTape<E, D>`. The tape's element type `E` may differ from the tensor's
/// element type `F`, as long as `D` is a storage for both.
impl<S: Shape, E: Unit, F: Unit, D: Storage<F> + Storage<E>> Trace<E, D>
    for Tensor<S, F, D, NoneTape>
{
    type Traced = Tensor<S, F, D, OwnedTape<E, D>>;

    /// Attaches a default-constructed `OwnedTape`.
    fn leaky_traced(self) -> Self::Traced {
        let fresh_tape: OwnedTape<E, D> = Default::default();
        self.put_tape(fresh_tape)
    }

    /// Attaches an `OwnedTape` that wraps `gradients` and starts with an empty
    /// list of operations.
    fn traced(self, gradients: Gradients<E, D>) -> Self::Traced {
        let operations = std::vec::Vec::new();
        self.put_tape(OwnedTape {
            gradients,
            operations,
        })
    }
}
impl<S: Shape, E, D: Storage<E>, T> Tensor<S, E, D, T> {
pub fn retaped<New: Tape<E, D>>(&self) -> Tensor<S, E, D, New> {
Tensor {
id: self.id,
data: self.data.clone(),
shape: self.shape,
strides: self.strides,
device: self.device.clone(),
tape: Default::default(),
}
}
pub fn device(&self) -> &D {
&self.device
}
}
/// Attaches a tape of type `T` to a value, producing [`PutTape::Output`].
pub trait PutTape<T> {
/// The type that results from attaching the tape.
type Output;
/// Consumes `self` and `tape`, returning the combined value.
fn put_tape(self, tape: T) -> Self::Output;
}
/// Any tape type can be attached to a tape-less tensor.
impl<S: Shape, E, D: Storage<E>, T> PutTape<T> for Tensor<S, E, D> {
    type Output = Tensor<S, E, D, T>;

    /// Moves every field of `self` into a new tensor whose tape is `tape`.
    /// The id and buffer handle are moved, not cloned.
    fn put_tape(self, tape: T) -> Self::Output {
        // Take the tensor apart; its existing (empty) tape is not carried over.
        let Tensor {
            id,
            data,
            shape,
            strides,
            device,
            tape: _,
        } = self;
        Tensor {
            id,
            data,
            shape,
            strides,
            device,
            tape,
        }
    }
}
/// Separates a value from the tape it carries.
pub trait SplitTape {
/// The type of the tape that is split off.
type Tape;
/// The tape-less counterpart of `Self`. Putting a `Self::Tape` back onto it
/// via `PutTape` yields `Self` again, as required by the bound below.
type NoTape: Clone + PutTape<Self::Tape, Output = Self>;
/// Consumes `self`, returning the tape-less value together with the tape.
fn split_tape(self) -> (Self::NoTape, Self::Tape);
}
/// Splitting a tensor yields the same tensor with a `NoneTape`, plus the tape
/// it was carrying.
impl<S: Shape, E: Clone, D: Storage<E>, T> SplitTape for Tensor<S, E, D, T> {
    type Tape = T;
    type NoTape = Tensor<S, E, D>;

    /// Moves all non-tape fields into a tape-less tensor and returns it
    /// alongside the original tape. Nothing is cloned.
    fn split_tape(self) -> (Self::NoTape, Self::Tape) {
        let Tensor {
            id,
            data,
            shape,
            strides,
            device,
            tape,
        } = self;
        let untaped = Tensor {
            id,
            data,
            shape,
            strides,
            device,
            tape: NoneTape,
        };
        (untaped, tape)
    }
}
/// Produces a value of the same type as `self` from a shared reference.
/// NOTE(review): judging by the name, the result is expected to carry an empty
/// tape - the trait itself does not enforce this; see the implementors.
pub trait WithEmptyTape {
/// Builds the new value without consuming `self`.
fn with_empty_tape(&self) -> Self;
}
impl<S: Shape, E, D: Storage<E>, T: Default> WithEmptyTape for Tensor<S, E, D, T> {
fn with_empty_tape(&self) -> Self {
Tensor {
id: self.id,
data: self.data.clone(),
shape: self.shape,
strides: self.strides,
device: self.device.clone(),
tape: Default::default(),
}
}
}
impl<S: Shape, E: Dtype, D: ZeroFillStorage<E>, T> Tensor<S, E, D, T> {
    /// Fills the tensor via [`Self::try_fill_with_zeros`].
    ///
    /// # Panics
    ///
    /// Panics if the fallible variant returns an error.
    pub fn fill_with_zeros(&mut self) {
        let outcome = self.try_fill_with_zeros();
        outcome.unwrap()
    }

    /// Asks the device to zero-fill this tensor's buffer.
    ///
    /// # Errors
    ///
    /// Returns whatever error the device's `try_fill_with_zeros` reports.
    pub fn try_fill_with_zeros(&mut self) -> Result<(), D::Err> {
        // `Arc::make_mut` hands out a unique reference, cloning the buffer
        // first if other handles still point at it.
        let buffer = Arc::make_mut(&mut self.data);
        self.device.try_fill_with_zeros(buffer)
    }
}
impl<S: Shape, E: Dtype, D: OneFillStorage<E>, T> Tensor<S, E, D, T> {
    /// Fills the tensor via [`Self::try_fill_with_ones`].
    ///
    /// # Panics
    ///
    /// Panics if the fallible variant returns an error.
    pub fn fill_with_ones(&mut self) {
        let outcome = self.try_fill_with_ones();
        outcome.unwrap()
    }

    /// Asks the device to fill this tensor's buffer with ones.
    ///
    /// # Errors
    ///
    /// Returns whatever error the device's `try_fill_with_ones` reports.
    pub fn try_fill_with_ones(&mut self) -> Result<(), D::Err> {
        // `Arc::make_mut` hands out a unique reference, cloning the buffer
        // first if other handles still point at it.
        let buffer = Arc::make_mut(&mut self.data);
        self.device.try_fill_with_ones(buffer)
    }
}
impl<S: Shape, E: Unit, D: SampleTensor<E>, T> Tensor<S, E, D, T> {
    /// Fills the tensor via [`Self::try_fill_with_distr`].
    ///
    /// # Panics
    ///
    /// Panics if the fallible variant returns an error.
    pub fn fill_with_distr<Distr: Distribution<E>>(&mut self, distr: Distr) {
        let outcome = self.try_fill_with_distr(distr);
        outcome.unwrap()
    }

    /// Asks the device to fill this tensor's buffer using the distribution
    /// `distr`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the device's `try_fill_with_distr` reports.
    pub fn try_fill_with_distr<Distr: Distribution<E>>(
        &mut self,
        distr: Distr,
    ) -> Result<(), D::Err> {
        // `Arc::make_mut` hands out a unique reference, cloning the buffer
        // first if other handles still point at it.
        let buffer = Arc::make_mut(&mut self.data);
        self.device.try_fill_with_distr(buffer, distr)
    }
}
/// An `f32` tensor of shape `Rank0` on the `Cpu` device.
pub type Tensor0D<Tape = NoneTape> = Tensor<Rank0, f32, Cpu, Tape>;
/// An `f32` tensor of shape `Rank1<M>` on the `Cpu` device.
pub type Tensor1D<const M: usize, Tape = NoneTape> = Tensor<Rank1<M>, f32, Cpu, Tape>;
/// An `f32` tensor of shape `Rank2<M, N>` on the `Cpu` device.
pub type Tensor2D<const M: usize, const N: usize, Tape = NoneTape> =
Tensor<Rank2<M, N>, f32, Cpu, Tape>;
/// An `f32` tensor of shape `Rank3<M, N, O>` on the `Cpu` device.
pub type Tensor3D<const M: usize, const N: usize, const O: usize, Tape = NoneTape> =
Tensor<Rank3<M, N, O>, f32, Cpu, Tape>;
/// An `f32` tensor of shape `Rank4<M, N, O, P>` on the `Cpu` device.
pub type Tensor4D<const M: usize, const N: usize, const O: usize, const P: usize, Tape = NoneTape> =
Tensor<Rank4<M, N, O, P>, f32, Cpu, Tape>;
/// An `f32` tensor of shape `Rank5<M, N, O, P, Q>` on the `Cpu` device.
pub type Tensor5D<
const M: usize,
const N: usize,
const O: usize,
const P: usize,
const Q: usize,
Tape = NoneTape,
> = Tensor<Rank5<M, N, O, P, Q>, f32, Cpu, Tape>;
/// An `f32` tensor of shape `Rank6<M, N, O, P, Q, R>` on the `Cpu` device.
pub type Tensor6D<
const M: usize,
const N: usize,
const O: usize,
const P: usize,
const Q: usize,
const R: usize,
Tape = NoneTape,
> = Tensor<Rank6<M, N, O, P, Q, R>, f32, Cpu, Tape>;