ferrite/tensor/ops/
activation.rs1use crate::*;
2use std::rc::Rc;
3
/// Activation functions, implemented in this module both for raw `Storage`
/// and for `Tensor` (the latter additionally attaches an autograd node when
/// its input requires gradients).
///
/// Every method borrows its receiver and returns a new `Self`.
pub trait ActivationOps {
    /// Binary step activation.
    fn binary_step(&self) -> Self;

    /// Logistic sigmoid activation.
    fn sigmoid(&self) -> Self;

    /// Hyperbolic tangent activation.
    fn tanh(&self) -> Self;

    /// Rectified linear unit.
    fn relu(&self) -> Self;

    /// Leaky ReLU. The negative-side slope is not a parameter of this method,
    /// so its value is fixed by the implementor.
    /// NOTE(review): confirm the slope used by the storage backend.
    fn leaky_relu(&self) -> Self;

    /// Parametric ReLU; `slope` is the caller-supplied coefficient
    /// (presumably applied to negative inputs — confirm in the backend).
    fn parametric_relu(&self, slope: f32) -> Self;

    /// Exponential linear unit parameterised by `alpha`.
    fn elu(&self, alpha: f32) -> Self;

    /// Softmax taken along dimension `dim`.
    fn softmax(&self, dim: usize) -> Self;

    /// Swish activation.
    fn swish(&self) -> Self;
}
15
16impl ActivationOps for Storage {
17 fn binary_step(&self) -> Self {
18 match_storage!(unary self, binary_step)
19 }
20
21 fn sigmoid(&self) -> Self {
22 match_storage!(unary self, sigmoid)
23 }
24
25 fn tanh(&self) -> Self {
26 match_storage!(unary self, tanh)
27 }
28
29 fn relu(&self) -> Self {
30 match_storage!(unary self, relu)
31 }
32
33 fn leaky_relu(&self) -> Self {
34 match_storage!(unary self, leaky_relu)
35 }
36
37 fn parametric_relu(&self, a: f32) -> Self {
38 match_storage!(unary self, parametric_relu, a)
39 }
40
41 fn elu(&self, alpha: f32) -> Self {
42 match_storage!(unary self, elu, alpha)
43 }
44
45 fn softmax(&self, dim: usize) -> Self {
46 match_storage!(unary self, softmax, dim)
47 }
48
49 fn swish(&self) -> Self {
50 match_storage!(unary self, swish)
51 }
52}
53
54
55impl ActivationOps for Tensor {
56 fn binary_step(&self) -> Self {
57 let tensor = self.tensor().binary_step();
58
59 let requires_grad = *self.requires_grad();
61 let mut result = Tensor::new(tensor, self.device(), requires_grad);
62
63 if requires_grad {
65 result.set_grad_fn(Some(Rc::new(BinaryStepGrad::new(
66 self,
67 &result
68 ))));
69 }
70
71 result
72 }
73
74 fn sigmoid(&self) -> Self {
75 let tensor = self.tensor().sigmoid();
76
77 let requires_grad = *self.requires_grad();
79 let mut result = Tensor::new(tensor, self.device(), requires_grad);
80
81 if requires_grad {
83 result.set_grad_fn(Some(Rc::new(SigmoidGrad::new(
84 self,
85 &result
86 ))));
87 }
88
89 result
90 }
91
92 fn tanh(&self) -> Self {
93 let tensor = self.tensor().tanh();
94
95 let requires_grad = *self.requires_grad();
97 let mut result = Tensor::new(tensor, self.device(), requires_grad);
98
99 if requires_grad {
101 result.set_grad_fn(Some(Rc::new(TanhGrad::new(
102 self,
103 &result
104 ))));
105 }
106
107 result
108 }
109
110 fn relu(&self) -> Self {
111 let tensor = self.tensor().relu();
112
113 let requires_grad = *self.requires_grad();
115 let mut result = Tensor::new(tensor, self.device(), requires_grad);
116
117 if requires_grad {
119 result.set_grad_fn(Some(Rc::new(ReluGrad::new(
120 self,
121 &result
122 ))));
123 }
124
125 result
126 }
127
128 fn leaky_relu(&self) -> Self {
129 let tensor = self.tensor().leaky_relu();
130
131 let requires_grad = *self.requires_grad();
133 let mut result = Tensor::new(tensor, self.device(), requires_grad);
134
135 if requires_grad {
137 result.set_grad_fn(Some(Rc::new(LeakyReluGrad::new(
138 self,
139 &result
140 ))));
141 }
142
143 result
144 }
145
146 fn parametric_relu(&self, a: f32) -> Self {
147 let tensor = self.tensor().parametric_relu(a);
148
149 let requires_grad = *self.requires_grad();
151 let mut result = Tensor::new(tensor, self.device(), requires_grad);
152
153 if requires_grad {
155 result.set_grad_fn(Some(Rc::new(ParametricReluGrad::new(
156 self,
157 a,
158 &result
159 ))));
160 }
161
162 result
163 }
164
165 fn elu(&self, alpha: f32) -> Self {
166 let tensor = self.tensor().elu(alpha);
167
168 let requires_grad = *self.requires_grad();
170 let mut result = Tensor::new(tensor, self.device(), requires_grad);
171
172 if requires_grad {
174 result.set_grad_fn(Some(Rc::new(EluGrad::new(
175 self,
176 alpha,
177 &result
178 ))));
179 }
180
181 result
182 }
183
184 fn softmax(&self, dim: usize) -> Self {
185 let tensor = self.tensor().softmax(dim);
186
187 let requires_grad = *self.requires_grad();
189 let mut result = Tensor::new(tensor, self.device(), requires_grad);
190
191 if requires_grad {
193 result.set_grad_fn(Some(Rc::new(SoftmaxGrad::new(
194 self,
195 &result
196 ))));
197 }
198
199 result
200 }
201
202 fn swish(&self) -> Self {
203 let tensor = self.tensor().swish();
204
205 let requires_grad = *self.requires_grad();
207 let mut result = Tensor::new(tensor, self.device(), requires_grad);
208
209 if requires_grad {
211 result.set_grad_fn(Some(Rc::new(SwishGrad::new(
212 self,
213 &result
214 ))));
215 }
216
217 result
218 }
219}