1use axonml_autograd::Variable;
18
19use crate::module::Module;
20
21#[derive(Debug, Clone, Copy, Default)]
29pub struct ReLU;
30
31impl ReLU {
32 pub fn new() -> Self {
34 Self
35 }
36}
37
38impl Module for ReLU {
39 fn forward(&self, input: &Variable) -> Variable {
40 input.relu()
41 }
42
43 fn name(&self) -> &'static str {
44 "ReLU"
45 }
46}
47
48#[derive(Debug, Clone, Copy)]
56pub struct LeakyReLU {
57 negative_slope: f32,
58}
59
60impl LeakyReLU {
61 pub fn new() -> Self {
63 Self {
64 negative_slope: 0.01,
65 }
66 }
67
68 pub fn with_slope(negative_slope: f32) -> Self {
70 Self { negative_slope }
71 }
72}
73
74impl Default for LeakyReLU {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl Module for LeakyReLU {
81 fn forward(&self, input: &Variable) -> Variable {
82 input.leaky_relu(self.negative_slope)
83 }
84
85 fn name(&self) -> &'static str {
86 "LeakyReLU"
87 }
88}
89
90#[derive(Debug, Clone, Copy, Default)]
98pub struct Sigmoid;
99
100impl Sigmoid {
101 pub fn new() -> Self {
103 Self
104 }
105}
106
107impl Module for Sigmoid {
108 fn forward(&self, input: &Variable) -> Variable {
109 input.sigmoid()
110 }
111
112 fn name(&self) -> &'static str {
113 "Sigmoid"
114 }
115}
116
117#[derive(Debug, Clone, Copy, Default)]
125pub struct Tanh;
126
127impl Tanh {
128 pub fn new() -> Self {
130 Self
131 }
132}
133
134impl Module for Tanh {
135 fn forward(&self, input: &Variable) -> Variable {
136 input.tanh()
137 }
138
139 fn name(&self) -> &'static str {
140 "Tanh"
141 }
142}
143
144#[derive(Debug, Clone, Copy)]
152pub struct Softmax {
153 dim: i64,
154}
155
156impl Softmax {
157 pub fn new(dim: i64) -> Self {
159 Self { dim }
160 }
161}
162
163impl Default for Softmax {
164 fn default() -> Self {
165 Self::new(-1)
166 }
167}
168
169impl Module for Softmax {
170 fn forward(&self, input: &Variable) -> Variable {
171 input.softmax(self.dim as i32)
172 }
173
174 fn name(&self) -> &'static str {
175 "Softmax"
176 }
177}
178
179#[derive(Debug, Clone, Copy)]
185pub struct LogSoftmax {
186 dim: i64,
187}
188
189impl LogSoftmax {
190 pub fn new(dim: i64) -> Self {
192 Self { dim }
193 }
194}
195
196impl Default for LogSoftmax {
197 fn default() -> Self {
198 Self::new(-1)
199 }
200}
201
202impl Module for LogSoftmax {
203 fn forward(&self, input: &Variable) -> Variable {
204 input.log_softmax(self.dim as i32)
205 }
206
207 fn name(&self) -> &'static str {
208 "LogSoftmax"
209 }
210}
211
212#[derive(Debug, Clone, Copy, Default)]
220pub struct GELU;
221
222impl GELU {
223 pub fn new() -> Self {
225 Self
226 }
227}
228
229impl Module for GELU {
230 fn forward(&self, input: &Variable) -> Variable {
231 input.gelu()
232 }
233
234 fn name(&self) -> &'static str {
235 "GELU"
236 }
237}
238
239#[derive(Debug, Clone, Copy, Default)]
247pub struct SiLU;
248
249impl SiLU {
250 pub fn new() -> Self {
252 Self
253 }
254}
255
256impl Module for SiLU {
257 fn forward(&self, input: &Variable) -> Variable {
258 let sigmoid = input.sigmoid();
259 input.mul_var(&sigmoid)
260 }
261
262 fn name(&self) -> &'static str {
263 "SiLU"
264 }
265}
266
267#[derive(Debug, Clone, Copy)]
275pub struct ELU {
276 alpha: f32,
277}
278
279impl ELU {
280 pub fn new() -> Self {
282 Self { alpha: 1.0 }
283 }
284
285 pub fn with_alpha(alpha: f32) -> Self {
287 Self { alpha }
288 }
289}
290
291impl Default for ELU {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
297impl Module for ELU {
298 fn forward(&self, input: &Variable) -> Variable {
299 input.elu(self.alpha)
300 }
301
302 fn name(&self) -> &'static str {
303 "ELU"
304 }
305}
306
307#[derive(Debug, Clone, Copy, Default)]
313pub struct Identity;
314
315impl Identity {
316 pub fn new() -> Self {
318 Self
319 }
320}
321
322impl Module for Identity {
323 fn forward(&self, input: &Variable) -> Variable {
324 input.clone()
325 }
326
327 fn name(&self) -> &'static str {
328 "Identity"
329 }
330}
331
332#[derive(Debug, Clone, Copy)]
346pub struct Flatten {
347 start_dim: usize,
348}
349
350impl Flatten {
351 pub fn new() -> Self {
353 Self { start_dim: 1 }
354 }
355
356 pub fn from(start_dim: usize) -> Self {
358 Self { start_dim }
359 }
360}
361
362impl Default for Flatten {
363 fn default() -> Self {
364 Self::new()
365 }
366}
367
368impl Module for Flatten {
369 fn forward(&self, input: &Variable) -> Variable {
370 input.flatten(self.start_dim)
371 }
372
373 fn parameters(&self) -> Vec<crate::Parameter> {
374 Vec::new()
375 }
376
377 fn name(&self) -> &'static str {
378 "Flatten"
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    #[test]
    fn test_relu() {
        let module = ReLU::new();
        let x = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap(),
            false,
        );
        let y = module.forward(&x);
        // Negatives clamp to zero; non-negatives pass through.
        assert_eq!(y.data().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let module = Sigmoid::new();
        let x = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let y = module.forward(&x);
        // sigmoid(0) == 0.5
        assert!((y.data().to_vec()[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_softmax() {
        let module = Softmax::new(-1);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let y = module.forward(&x);
        // Softmax outputs must sum to 1 along the normalized axis.
        let total: f32 = y.data().to_vec().iter().sum();
        assert!((total - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_leaky_relu() {
        let module = LeakyReLU::with_slope(0.1);
        let x = Variable::new(Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let y = module.forward(&x);
        // Negative input scaled by the 0.1 slope; the rest unchanged.
        assert_eq!(y.data().to_vec(), vec![-0.1, 0.0, 1.0]);
    }

    #[test]
    fn test_identity() {
        let module = Identity::new();
        let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let y = module.forward(&x);
        // Identity must hand the data back untouched.
        assert_eq!(y.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
}