1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use super::conv;
use crate::backend::Backend;
/// Gradients returned by [`ModuleOps::conv2d_backward`].
///
/// `#[derive(new)]` (presumably the `derive-new` crate) generates a `new`
/// constructor whose arguments follow the field declaration order, so
/// reordering these fields would change that constructor's signature.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the convolution input `x`.
    pub x_grad: B::TensorPrimitive<4>,
    /// Gradient with respect to the convolution `weight`.
    pub weights_grad: B::TensorPrimitive<4>,
    /// Gradient with respect to the bias. Presumably `None` when the forward
    /// pass had no bias — TODO confirm in `conv::conv2d_backward`.
    pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Result of [`ModuleOps::max_pool2d_with_indexes_backward`].
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the pooling input `x`.
    pub x_grad: B::TensorPrimitive<4>,
}
/// Result of [`ModuleOps::max_pool2d_with_indexes`]: the pooled output plus
/// the integer indexes that [`ModuleOps::max_pool2d_with_indexes_backward`]
/// takes as its `indexes` argument.
///
/// `#[derive(new)]` generates a `new(output, indexes)` constructor; keep the
/// field order stable.
#[derive(new)]
pub struct MaxPool2dWithIndexes<B: Backend> {
    /// Pooled tensor (same primitive type as [`ModuleOps::max_pool2d`] returns).
    pub output: B::TensorPrimitive<4>,
    /// Integer tensor handed back to the backward pass. Presumably records,
    /// for each output position, which input element was the maximum —
    /// TODO confirm against backend implementations.
    pub indexes: B::IntTensorPrimitive<4>,
}
/// Gradients returned by [`ModuleOps::conv1d_backward`].
///
/// `#[derive(new)]` (presumably the `derive-new` crate) generates a `new`
/// constructor whose arguments follow the field declaration order, so
/// reordering these fields would change that constructor's signature.
#[derive(new)]
pub struct Conv1dBackward<B: Backend> {
    /// Gradient with respect to the convolution input `x`.
    pub x_grad: B::TensorPrimitive<3>,
    /// Gradient with respect to the convolution `weight`.
    pub weights_grad: B::TensorPrimitive<3>,
    /// Gradient with respect to the bias. Presumably `None` when the forward
    /// pass had no bias — TODO confirm in `conv::conv1d_backward`.
    pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Neural-network module operations provided for a [`Backend`] `B`.
///
/// Every function is an associated function (no `self` receiver) and works
/// only on the primitives passed in. Functions without a body must be
/// implemented by each implementor. `conv2d_backward`, `conv1d` and
/// `conv1d_backward` have default implementations that delegate to the
/// sibling `conv` module; implementors may override them.
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup of the integer `indexes` into `weights`.
    ///
    /// Assumed layout, not enforced here (TODO confirm): `weights` is
    /// `[num_embeddings, d_model]`, `indexes` is `[batch, seq]`, and the
    /// result is `[batch, seq, d_model]`.
    fn embedding(
        weights: B::TensorPrimitive<2>,
        indexes: B::IntTensorPrimitive<2>,
    ) -> B::TensorPrimitive<3>;
    /// Backward pass of [`ModuleOps::embedding`].
    ///
    /// Returns a primitive of the same type as `weights`, presumably the
    /// gradient with respect to `weights`. NOTE(review): `output` has the same
    /// type as `embedding`'s result and is presumably the gradient flowing back
    /// from it rather than the forward output itself — confirm with implementors.
    fn embedding_backward(
        weights: B::TensorPrimitive<2>,
        output: B::TensorPrimitive<3>,
        indexes: B::IntTensorPrimitive<2>,
    ) -> B::TensorPrimitive<2>;
    /// 2D convolution of `x` with `weight`, optionally adding `bias`.
    ///
    /// `stride` and `padding` are given per spatial dimension as `[usize; 2]`.
    fn conv2d(
        x: B::TensorPrimitive<4>,
        weight: B::TensorPrimitive<4>,
        bias: Option<B::TensorPrimitive<1>>,
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> B::TensorPrimitive<4>;
    /// Gradients of [`ModuleOps::conv2d`] with respect to `x`, `weight` and
    /// `bias`, given `output_grad`.
    ///
    /// The default implementation forwards all arguments to
    /// `conv::conv2d_backward`.
    ///
    /// NOTE(review): unlike `conv2d`, this takes no `padding`, so none is
    /// passed to `conv::conv2d_backward`. Confirm it is correct for padded
    /// forward passes (e.g. by inferring padding from shapes). Adding the
    /// parameter would change this trait's interface.
    fn conv2d_backward(
        x: B::TensorPrimitive<4>,
        weight: B::TensorPrimitive<4>,
        bias: Option<B::TensorPrimitive<1>>,
        stride: [usize; 2],
        output_grad: B::TensorPrimitive<4>,
    ) -> Conv2dBackward<B> {
        conv::conv2d_backward(x, weight, bias, stride, output_grad)
    }
    /// 1D convolution of `x` with `weight`, optionally adding `bias`.
    ///
    /// The default implementation delegates to `conv::conv1d_from_conv2d`,
    /// which by its name presumably expresses the operation through
    /// `conv2d`. Backends with a native 1D kernel may override it.
    fn conv1d(
        x: B::TensorPrimitive<3>,
        weight: B::TensorPrimitive<3>,
        bias: Option<B::TensorPrimitive<1>>,
        stride: usize,
        padding: usize,
    ) -> B::TensorPrimitive<3> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, stride, padding)
    }
    /// Gradients of [`ModuleOps::conv1d`] with respect to `x`, `weight` and
    /// `bias`, given `output_grad`.
    ///
    /// The default implementation forwards all arguments to
    /// `conv::conv1d_backward`.
    ///
    /// NOTE(review): like `conv2d_backward`, this takes no `padding` although
    /// `conv1d` does. Confirm that padded forward passes produce correct
    /// gradients.
    fn conv1d_backward(
        x: B::TensorPrimitive<3>,
        weight: B::TensorPrimitive<3>,
        bias: Option<B::TensorPrimitive<1>>,
        stride: usize,
        output_grad: B::TensorPrimitive<3>,
    ) -> Conv1dBackward<B> {
        conv::conv1d_backward(x, weight, bias, stride, output_grad)
    }
    /// 2D max pooling over `x`.
    ///
    /// `kernel_size`, `stride` and `padding` are given per spatial dimension
    /// as `[usize; 2]`.
    fn max_pool2d(
        x: B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> B::TensorPrimitive<4>;
    /// 2D max pooling that also returns integer indexes.
    ///
    /// Takes the same parameters as [`ModuleOps::max_pool2d`]. The returned
    /// [`MaxPool2dWithIndexes::indexes`] is the input expected by
    /// [`ModuleOps::max_pool2d_with_indexes_backward`].
    fn max_pool2d_with_indexes(
        x: B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
    ) -> MaxPool2dWithIndexes<B>;
    /// Backward pass of [`ModuleOps::max_pool2d_with_indexes`].
    ///
    /// Takes the forward input `x`, the pooling configuration (presumably it
    /// must match the forward call), the incoming `output_grad`, and the
    /// `indexes` produced by the forward pass. Returns the gradient with
    /// respect to `x`.
    fn max_pool2d_with_indexes_backward(
        x: B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        output_grad: B::TensorPrimitive<4>,
        indexes: B::IntTensorPrimitive<4>,
    ) -> MaxPool2dBackward<B>;
}