1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
use crate::prelude::*;
/// A tensor that carries a tape of type [`Tensor::Tape`] and can be split
/// from, or copied without, that tape.
///
/// In this file it is implemented for `Tensor0D` through `Tensor4D` by the
/// `tensor_impl!` macro.
pub trait Tensor:
HasArrayType + HasArrayData + HasDevice + CanUpdateWithGradients + HasUniqueId + IntoPhantom
{
/// The tape this tensor carries. In the `tensor_impl!` impls this is the
/// struct's last generic parameter `H`.
type Tape: Tape;
/// The same tensor type with its tape replaced by `NoneTape`.
///
/// It has the same `Array` and `Dtype` as `Self` and is its own `NoTape`.
/// It can be built via `TensorCreator` and is `Clone`. Putting a
/// `Self::Tape` back onto it (via `PutTape`) yields `Self`.
type NoTape: 'static
+ Tensor<Array = Self::Array, Dtype = Self::Dtype, Tape = NoneTape, NoTape = Self::NoTape>
+ TensorCreator
+ PutTape<Self::Tape, Output = Self>
+ Clone;
/// The same tensor type with its tape set to `OwnedTape`. It shares
/// `Array` and `Dtype` with `Self` and is its own `OwnedTape`.
type OwnedTape: 'static
+ Tensor<
Array = Self::Array,
Dtype = Self::Dtype,
Tape = OwnedTape,
OwnedTape = Self::OwnedTape,
>;
/// Separates `self` into a tape-free tensor and the tape it carried.
///
/// The `tensor_impl!` impls move the data into the returned tensor and keep
/// the same `id`; `test_ids_with_split_and_put` checks that the id also
/// survives `put_tape`.
fn split_tape(self) -> (Self::NoTape, Self::Tape);
/// Returns a tape-free copy of `self` that keeps the same `id`.
///
/// The `tensor_impl!` impls clone the data. This differs from `Clone`, whose
/// impls in this file assign a fresh id via `unique_id()`.
fn duplicate(&self) -> Self::NoTape;
}
/// Implements [`Tensor`] and [`Clone`] for a const-generic tensor struct.
///
/// - `$struct` is the struct name.
/// - `[$($Vs),*]` lists its const `usize` dimension parameters.
/// - The struct's last generic parameter is its tape type.
///
/// The generated code constructs the struct from exactly the fields `id`,
/// `data` and `tape`, so the struct must have precisely those fields.
///
/// Identity semantics of the generated impls:
/// - `split_tape` and `duplicate` keep the original `id`;
/// - `clone` assigns a fresh id obtained from `unique_id()`.
macro_rules! tensor_impl {
    ($struct:ident, [$($Vs:tt),*]) => {
        impl<$(const $Vs: usize, )* H: Tape> Tensor for $struct<$($Vs, )* H> {
            type Tape = H;
            type NoTape = $struct<$($Vs, )* NoneTape>;
            type OwnedTape = $struct<$($Vs, )* OwnedTape>;

            fn split_tape(self) -> (Self::NoTape, Self::Tape) {
                // Take the struct apart. The detached tensor keeps the id and
                // data and gets a `Default` `NoneTape`; the old tape is
                // returned alongside it.
                let Self { id, data, tape } = self;
                let detached = Self::NoTape { id, data, tape: Default::default() };
                (detached, tape)
            }

            fn duplicate(&self) -> Self::NoTape {
                // The copy reuses `self.id`, so it shares the original's id.
                let data = self.data.clone();
                Self::NoTape { id: self.id, data, tape: Default::default() }
            }
        }

        impl<$(const $Vs: usize, )* H: Clone> Clone for $struct<$($Vs, )* H> {
            fn clone(&self) -> Self {
                // A clone is treated as a distinct tensor and gets a new id.
                let id = unique_id();
                let data = self.data.clone();
                let tape = self.tape.clone();
                Self { id, data, tape }
            }
        }
    };
}
// Generate the `Tensor` and `Clone` impls for every tensor struct, from zero
// up to four const dimension parameters.
tensor_impl! { Tensor0D, [] }
tensor_impl! { Tensor1D, [M] }
tensor_impl! { Tensor2D, [M, N] }
tensor_impl! { Tensor3D, [M, N, O] }
tensor_impl! { Tensor4D, [M, N, O, P] }
#[cfg(test)]
mod tests {
    use super::*;

    // `duplicate` must hand back a tensor carrying the same id.
    #[test]
    fn test_ids_with_duplicate() {
        let original: Tensor1D<32> = TensorCreator::zeros();
        let copy: Tensor1D<32, NoneTape> = original.duplicate();
        assert_eq!(copy.id, original.id);
    }

    // `clone` must hand back a tensor carrying a fresh id.
    #[test]
    fn test_ids_with_clone() {
        let original: Tensor1D<32> = TensorCreator::zeros();
        let cloned: Tensor1D<32, NoneTape> = original.clone();
        assert_ne!(cloned.id, original.id);
    }

    // Splitting off the tape and putting it back must keep the id stable.
    #[test]
    fn test_ids_with_split_and_put() {
        let tensor: Tensor1D<32> = TensorCreator::zeros();
        let expected_id = tensor.id;
        let (detached, tape) = tensor.split_tape();
        assert_eq!(detached.id, expected_id);
        let rejoined = detached.put_tape(tape);
        assert_eq!(rejoined.id, expected_id);
    }
}