1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
use rand::distributions::Distribution;

use super::*;
use crate::shapes::*;

use std::sync::Arc;

/// The single tensor struct that stores nd arrays and tapes.
///
/// See module level documentation on how to create and use tensors.
///
/// The element buffer is held behind an [Arc], so cloning a tensor shares its
/// data instead of copying it.
///
/// Generics:
/// 1. [Shape] - the shape of the underlying nd array
/// 2. [Dtype] - the type of the data stored in the array
/// 3. [Storage] - the device the array is stored on
/// 4. [Tape] - the tape the tensor has
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// // A 1d tensor with 1000 f32 elements, stored on the Cpu
/// type A = Tensor<Rank1<1000>, f32, Cpu>;
///
/// // A 2d tensor with bool elements, stored on the Cpu
/// type B = Tensor<Rank2<2, 3>, bool, Cpu>;
///
/// // A 3d tensor with usize elements, stored on the Cpu, without any tape
/// type C = Tensor<Rank3<4, 2, 3>, usize, Cpu, NoneTape>;
/// ```
#[derive(Debug, Clone)]
pub struct Tensor<S: Shape, E, D: Storage<E>, T = NoneTape> {
    /// Identifier of this tensor; carried over unchanged when the tape is
    /// swapped or removed.
    pub(crate) id: UniqueId,
    /// The element buffer. It is reference-counted, so clones share it;
    /// in-place fills go through [Arc::make_mut], which copies it first if shared.
    pub(crate) data: Arc<D::Vec>,
    /// The shape of the array.
    pub(crate) shape: S,
    /// Per-dimension strides, stored in the shape's concrete representation.
    pub(crate) strides: S::Concrete,
    /// The device the array is stored on.
    pub(crate) device: D,
    /// The gradient tape; [NoneTape] when gradients are not being tracked.
    pub(crate) tape: T,
}

impl<S, E, D, T> HasShape for Tensor<S, E, D, T>
where
    S: Shape,
    D: Storage<E>,
{
    type Shape = S;
    type WithShape<New: Shape> = Tensor<New, E, D, T>;

    /// Borrows the shape stored in the tensor.
    fn shape(&self) -> &Self::Shape {
        let Self { shape, .. } = self;
        shape
    }
}

impl<S, E, D, T> HasUnitType for Tensor<S, E, D, T>
where
    S: Shape,
    E: Unit,
    D: Storage<E>,
{
    /// The tensor's element type doubles as its unit type.
    type Unit = E;
}

impl<S, E, D, T> HasDtype for Tensor<S, E, D, T>
where
    S: Shape,
    E: Dtype,
    D: Storage<E>,
{
    /// The tensor's element type doubles as its dtype.
    type Dtype = E;
}

impl<S, E, D, T> HasErr for Tensor<S, E, D, T>
where
    S: Shape,
    D: Storage<E>,
{
    /// Errors are whatever the backing storage reports.
    type Err = D::Err;
}

/// Something that can trace gradients
pub trait Trace<E, D: Storage<E>>: Clone {
    /// The type produced once gradient tracking has been attached.
    type Traced;

    /// Starts tracking gradients on a clone of `self`, leaving `self` as-is.
    /// The gradients will never free temporary gradients - see
    /// [Gradients::leaky()] for more info.
    ///
    /// Prefer [Tensor::trace()] with gradients allocated
    /// with [crate::nn::ZeroGrads::alloc_grads()].
    fn leaky_trace(&self) -> Self::Traced {
        let owned: Self = Clone::clone(self);
        owned.leaky_traced()
    }

    /// Consumes `self` and starts tracking gradients. The gradients will never
    /// free temporary gradients - see [Gradients::leaky()] for more info.
    ///
    /// Prefer [Tensor::traced()] with gradients allocated
    /// with [crate::nn::ZeroGrads::alloc_grads()].
    fn leaky_traced(self) -> Self::Traced;

    /// Clones `self` and starts accumulating gradients into `gradients`.
    /// Use [crate::nn::ZeroGrads::alloc_grads()] to create gradients.
    fn trace(&self, gradients: Gradients<E, D>) -> Self::Traced {
        let owned: Self = Clone::clone(self);
        owned.traced(gradients)
    }

    /// Consumes `self` and starts accumulating gradients into `gradients`.
    /// Use [crate::nn::ZeroGrads::alloc_grads()] to create gradients.
    fn traced(self, gradients: Gradients<E, D>) -> Self::Traced;
}

impl<S, E, F, D> Trace<E, D> for Tensor<S, F, D, NoneTape>
where
    S: Shape,
    E: Unit,
    F: Unit,
    D: Storage<F> + Storage<E>,
{
    type Traced = Tensor<S, F, D, OwnedTape<E, D>>;

    /// Attaches a default-constructed [OwnedTape] in place of the [NoneTape].
    fn leaky_traced(self) -> Self::Traced {
        let tape: OwnedTape<E, D> = Default::default();
        self.put_tape(tape)
    }

    /// Attaches an [OwnedTape] that accumulates into `gradients` and starts
    /// with no recorded operations.
    fn traced(self, gradients: Gradients<E, D>) -> Self::Traced {
        let tape = OwnedTape {
            operations: std::vec::Vec::new(),
            gradients,
        };
        self.put_tape(tape)
    }
}

impl<S: Shape, E, D: Storage<E>, T> Tensor<S, E, D, T> {
    /// Clone and insert a new tape of type `New` into the tensor
    pub fn retaped<New: Tape<E, D>>(&self) -> Tensor<S, E, D, New> {
        Tensor {
            id: self.id,
            data: self.data.clone(),
            shape: self.shape,
            strides: self.strides,
            device: self.device.clone(),
            tape: Default::default(),
        }
    }

    /// Get a reference to the tensor's `Storage`
    pub fn device(&self) -> &D {
        &self.device
    }
}

/// Put a tape of type `T` into the tensor
pub trait PutTape<T> {
    /// The value produced once a tape of type `T` has been attached.
    type Output;
    /// Consumes `self` and returns it with `tape` attached.
    ///
    /// ```rust
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// let a: Tensor<Rank2<2, 3>, f32, _, NoneTape> = dev.zeros();
    /// let a: Tensor<Rank2<2, 3>, f32, _, OwnedTape<f32, Cpu>> = a.put_tape(Default::default());
    /// ```
    fn put_tape(self, tape: T) -> Self::Output;
}

impl<S, E, D, T> PutTape<T> for Tensor<S, E, D>
where
    S: Shape,
    D: Storage<E>,
{
    type Output = Tensor<S, E, D, T>;

    /// Moves every field of `self` into a new tensor whose tape is `tape`;
    /// the previous [NoneTape] is discarded.
    fn put_tape(self, tape: T) -> Self::Output {
        let Tensor {
            id,
            data,
            shape,
            strides,
            device,
            tape: _,
        } = self;
        Tensor {
            id,
            data,
            shape,
            strides,
            device,
            tape,
        }
    }
}

/// Remove the tape from a tensor.
///
/// The two halves can be recombined: calling [PutTape::put_tape] on the
/// tape-less value with the removed tape yields `Self` again.
pub trait SplitTape {
    /// The type of tape the tensor has now
    type Tape;
    /// The type of `Self` without the tape.
    ///
    /// Putting a [SplitTape::Tape] back onto it produces `Self` again.
    type NoTape: Clone + PutTape<Self::Tape, Output = Self>;
    /// Splits the tape off of `self`. Returns the tape-less value together with
    /// the tape that was removed.
    ///
    /// ```rust
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// # let grads = Gradients::leaky();
    /// let a: Tensor<Rank1<5>, f32, _, OwnedTape<f32, _>> = dev.zeros().traced(grads);
    /// let (a, tape): (Tensor<_, _, _, NoneTape>, OwnedTape<f32, _>) = a.split_tape();
    /// ```
    fn split_tape(self) -> (Self::NoTape, Self::Tape);
}

impl<S, E, D, T> SplitTape for Tensor<S, E, D, T>
where
    S: Shape,
    E: Clone,
    D: Storage<E>,
{
    type Tape = T;
    type NoTape = Tensor<S, E, D>;

    /// Moves the tensor's fields into a tape-less ([NoneTape]) tensor and
    /// returns it together with the tape that was removed.
    fn split_tape(self) -> (Self::NoTape, Self::Tape) {
        let Tensor {
            id,
            data,
            shape,
            strides,
            device,
            tape,
        } = self;
        let untaped = Tensor {
            id,
            data,
            shape,
            strides,
            device,
            tape: NoneTape,
        };
        (untaped, tape)
    }
}

/// Clones self and inserts a new empty tape into the clone.
///
/// The result has the same type as `self`, so the tape *type* is kept; only
/// the tape's contents are replaced.
pub trait WithEmptyTape {
    /// Returns a clone of `self` whose tape has been replaced by an empty one.
    fn with_empty_tape(&self) -> Self;
}

impl<S: Shape, E, D: Storage<E>, T: Default> WithEmptyTape for Tensor<S, E, D, T> {
    fn with_empty_tape(&self) -> Self {
        Tensor {
            id: self.id,
            data: self.data.clone(),
            shape: self.shape,
            strides: self.strides,
            device: self.device.clone(),
            tape: Default::default(),
        }
    }
}

impl<S, E, D, T> Tensor<S, E, D, T>
where
    S: Shape,
    E: Dtype,
    D: ZeroFillStorage<E>,
{
    /// Overwrites every element of the tensor with zero.
    ///
    /// If the data buffer is shared with other tensors, [Arc::make_mut] copies
    /// it first, so those tensors are unaffected.
    ///
    /// # Panics
    /// Panics if the device reports an error; use
    /// [Tensor::try_fill_with_zeros] to handle it instead.
    pub fn fill_with_zeros(&mut self) {
        self.try_fill_with_zeros().unwrap()
    }

    /// Fallible version of [Tensor::fill_with_zeros]
    pub fn try_fill_with_zeros(&mut self) -> Result<(), D::Err> {
        let buf = Arc::make_mut(&mut self.data);
        self.device.try_fill_with_zeros(buf)
    }
}

impl<S, E, D, T> Tensor<S, E, D, T>
where
    S: Shape,
    E: Dtype,
    D: OneFillStorage<E>,
{
    /// Overwrites every element of the tensor with one.
    ///
    /// If the data buffer is shared with other tensors, [Arc::make_mut] copies
    /// it first, so those tensors are unaffected.
    ///
    /// # Panics
    /// Panics if the device reports an error; use
    /// [Tensor::try_fill_with_ones] to handle it instead.
    pub fn fill_with_ones(&mut self) {
        self.try_fill_with_ones().unwrap()
    }

    /// Fallible version of [Tensor::fill_with_ones]
    pub fn try_fill_with_ones(&mut self) -> Result<(), D::Err> {
        let buf = Arc::make_mut(&mut self.data);
        self.device.try_fill_with_ones(buf)
    }
}

impl<S, E, D, T> Tensor<S, E, D, T>
where
    S: Shape,
    E: Unit,
    D: SampleTensor<E>,
{
    /// Overwrites the tensor's elements with values drawn from `distr`. The
    /// sampling itself is delegated to the device.
    ///
    /// If the data buffer is shared with other tensors, [Arc::make_mut] copies
    /// it first, so those tensors are unaffected.
    ///
    /// # Panics
    /// Panics if the device reports an error; use
    /// [Tensor::try_fill_with_distr] to handle it instead.
    pub fn fill_with_distr<Distr: Distribution<E>>(&mut self, distr: Distr) {
        self.try_fill_with_distr(distr).unwrap()
    }

    /// Fallible version of [Tensor::fill_with_distr]
    pub fn try_fill_with_distr<Distr: Distribution<E>>(
        &mut self,
        distr: Distr,
    ) -> Result<(), D::Err> {
        let buf = Arc::make_mut(&mut self.data);
        self.device.try_fill_with_distr(buf, distr)
    }
}

// Shorthand aliases for `f32` tensors stored on the `Cpu` with compile-time
// dimensions. `Tape` defaults to `NoneTape` (no gradient tracking).

/// A rank-0 `f32` tensor stored on the [Cpu].
pub type Tensor0D<Tape = NoneTape> = Tensor<Rank0, f32, Cpu, Tape>;
/// A 1d `f32` tensor with `M` elements, stored on the [Cpu].
pub type Tensor1D<const M: usize, Tape = NoneTape> = Tensor<Rank1<M>, f32, Cpu, Tape>;
/// A 2d `f32` tensor with dimensions `M`, `N`, stored on the [Cpu].
pub type Tensor2D<const M: usize, const N: usize, Tape = NoneTape> =
    Tensor<Rank2<M, N>, f32, Cpu, Tape>;
/// A 3d `f32` tensor with dimensions `M`, `N`, `O`, stored on the [Cpu].
pub type Tensor3D<const M: usize, const N: usize, const O: usize, Tape = NoneTape> =
    Tensor<Rank3<M, N, O>, f32, Cpu, Tape>;
/// A 4d `f32` tensor with dimensions `M`, `N`, `O`, `P`, stored on the [Cpu].
pub type Tensor4D<const M: usize, const N: usize, const O: usize, const P: usize, Tape = NoneTape> =
    Tensor<Rank4<M, N, O, P>, f32, Cpu, Tape>;
/// A 5d `f32` tensor with dimensions `M`, `N`, `O`, `P`, `Q`, stored on the [Cpu].
pub type Tensor5D<
    const M: usize,
    const N: usize,
    const O: usize,
    const P: usize,
    const Q: usize,
    Tape = NoneTape,
> = Tensor<Rank5<M, N, O, P, Q>, f32, Cpu, Tape>;
/// A 6d `f32` tensor with dimensions `M`, `N`, `O`, `P`, `Q`, `R`, stored on the [Cpu].
pub type Tensor6D<
    const M: usize,
    const N: usize,
    const O: usize,
    const P: usize,
    const Q: usize,
    const R: usize,
    Tape = NoneTape,
> = Tensor<Rank6<M, N, O, P, Q, R>, f32, Cpu, Tape>;