//! Activation function modules: ReLU, LeakyReLU, Sigmoid, Tanh, Softmax,
//! LogSoftmax, GELU, SiLU, ELU, and Identity.

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::module::Module;

/// Rectified Linear Unit activation: `max(0, x)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ReLU;

impl ReLU {
    /// Creates a new `ReLU` module.
    pub fn new() -> Self {
        Self
    }
}

impl Module for ReLU {
    fn forward(&self, input: &Variable) -> Variable {
        input.relu()
    }

    fn name(&self) -> &'static str {
        "ReLU"
    }
}

/// Leaky Rectified Linear Unit: `x` for `x > 0`, `negative_slope * x` otherwise.
#[derive(Debug, Clone, Copy)]
pub struct LeakyReLU {
    negative_slope: f32,
}

impl LeakyReLU {
    /// Creates a `LeakyReLU` with the default negative slope of 0.01.
    pub fn new() -> Self {
        Self {
            negative_slope: 0.01,
        }
    }

    /// Creates a `LeakyReLU` with a custom negative slope.
    pub fn with_slope(negative_slope: f32) -> Self {
        Self { negative_slope }
    }
}

impl Default for LeakyReLU {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for LeakyReLU {
    fn forward(&self, input: &Variable) -> Variable {
        let data = input.data();
        let result: Vec<f32> = data
            .to_vec()
            .iter()
            .map(|&x| if x > 0.0 { x } else { x * self.negative_slope })
            .collect();
        Variable::new(
            Tensor::from_vec(result, data.shape()).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "LeakyReLU"
    }
}

/// Sigmoid activation: `1 / (1 + e^(-x))`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sigmoid;

impl Sigmoid {
    /// Creates a new `Sigmoid` module.
    pub fn new() -> Self {
        Self
    }
}

impl Module for Sigmoid {
    fn forward(&self, input: &Variable) -> Variable {
        input.sigmoid()
    }

    fn name(&self) -> &'static str {
        "Sigmoid"
    }
}

/// Hyperbolic tangent activation.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tanh;

impl Tanh {
    /// Creates a new `Tanh` module.
    pub fn new() -> Self {
        Self
    }
}

impl Module for Tanh {
    fn forward(&self, input: &Variable) -> Variable {
        input.tanh()
    }

    fn name(&self) -> &'static str {
        "Tanh"
    }
}

/// Softmax activation: `exp(x_i) / sum_j exp(x_j)` along `dim`.
#[derive(Debug, Clone, Copy)]
pub struct Softmax {
    dim: i64,
}

impl Softmax {
    /// Creates a `Softmax` over the given dimension; negative values count
    /// back from the last dimension.
    pub fn new(dim: i64) -> Self {
        Self { dim }
    }
}

impl Default for Softmax {
    fn default() -> Self {
        Self::new(-1)
    }
}

impl Module for Softmax {
    fn forward(&self, input: &Variable) -> Variable {
        let data = input.data();
        let shape = data.shape().to_vec();
        let data_vec = data.to_vec();

        let ndim = shape.len();
        let dim = if self.dim < 0 {
            (ndim as i64 + self.dim) as usize
        } else {
            self.dim as usize
        };

        let outer_size: usize = shape[..dim].iter().product();
        let dim_size = shape[dim];
        let inner_size: usize = shape[dim + 1..].iter().product();

        let mut result = vec![0.0f32; data_vec.len()];

        for outer in 0..outer_size {
            for inner in 0..inner_size {
                // Find the maximum along `dim` for numerical stability.
                let mut max_val = f32::NEG_INFINITY;
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    max_val = max_val.max(data_vec[idx]);
                }

                // Exponentiate the shifted values and accumulate their sum.
                let mut sum = 0.0f32;
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    let exp_val = (data_vec[idx] - max_val).exp();
                    result[idx] = exp_val;
                    sum += exp_val;
                }

                // Normalize so each slice along `dim` sums to 1.
                for d in 0..dim_size {
                    let idx = outer * dim_size * inner_size + d * inner_size + inner;
                    result[idx] /= sum;
                }
            }
        }

        Variable::new(
            Tensor::from_vec(result, &shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "Softmax"
    }
}

/// Log-Softmax activation: `ln(softmax(x))` along `dim`.
#[derive(Debug, Clone, Copy)]
pub struct LogSoftmax {
    dim: i64,
}

impl LogSoftmax {
    /// Creates a `LogSoftmax` over the given dimension; negative values count
    /// back from the last dimension.
    pub fn new(dim: i64) -> Self {
        Self { dim }
    }
}

impl Default for LogSoftmax {
    fn default() -> Self {
        Self::new(-1)
    }
}

impl Module for LogSoftmax {
    fn forward(&self, input: &Variable) -> Variable {
        let softmax = Softmax::new(self.dim);
        let sm = softmax.forward(input);
        let sm_vec = sm.data().to_vec();
        let result: Vec<f32> = sm_vec.iter().map(|&x| x.ln()).collect();
        Variable::new(
            Tensor::from_vec(result, sm.data().shape()).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "LogSoftmax"
    }
}

/// Gaussian Error Linear Unit activation (tanh approximation).
#[derive(Debug, Clone, Copy, Default)]
pub struct GELU;

impl GELU {
    /// Creates a new `GELU` module.
    pub fn new() -> Self {
        Self
    }
}

impl Module for GELU {
    fn forward(&self, input: &Variable) -> Variable {
        let data = input.data();
        let data_vec = data.to_vec();
        let sqrt_2_over_pi = (2.0_f32 / std::f32::consts::PI).sqrt();
        // Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
        let result: Vec<f32> = data_vec
            .iter()
            .map(|&x| {
                let inner = sqrt_2_over_pi * (x + 0.044715 * x.powi(3));
                0.5 * x * (1.0 + inner.tanh())
            })
            .collect();
        Variable::new(
            Tensor::from_vec(result, data.shape()).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "GELU"
    }
}

/// Sigmoid Linear Unit (swish) activation: `x * sigmoid(x)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SiLU;

impl SiLU {
    /// Creates a new `SiLU` module.
    pub fn new() -> Self {
        Self
    }
}

impl Module for SiLU {
    fn forward(&self, input: &Variable) -> Variable {
        let sigmoid = input.sigmoid();
        input.mul_var(&sigmoid)
    }

    fn name(&self) -> &'static str {
        "SiLU"
    }
}

/// Exponential Linear Unit: `x` for `x > 0`, `alpha * (e^x - 1)` otherwise.
#[derive(Debug, Clone, Copy)]
pub struct ELU {
    alpha: f32,
}

impl ELU {
    /// Creates an `ELU` with the default `alpha` of 1.0.
    pub fn new() -> Self {
        Self { alpha: 1.0 }
    }

    /// Creates an `ELU` with a custom `alpha`.
    pub fn with_alpha(alpha: f32) -> Self {
        Self { alpha }
    }
}

impl Default for ELU {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for ELU {
    fn forward(&self, input: &Variable) -> Variable {
        let data = input.data();
        let result: Vec<f32> = data
            .to_vec()
            .iter()
            .map(|&x| {
                if x > 0.0 {
                    x
                } else {
                    self.alpha * (x.exp() - 1.0)
                }
            })
            .collect();
        Variable::new(
            Tensor::from_vec(result, data.shape()).unwrap(),
            input.requires_grad(),
        )
    }

    fn name(&self) -> &'static str {
        "ELU"
    }
}

/// Identity module: returns the input unchanged.
#[derive(Debug, Clone, Copy, Default)]
pub struct Identity;

impl Identity {
    /// Creates a new `Identity` module.
    pub fn new() -> Self {
        Self
    }
}

impl Module for Identity {
    fn forward(&self, input: &Variable) -> Variable {
        input.clone()
    }

    fn name(&self) -> &'static str {
        "Identity"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu() {
        let relu = ReLU::new();
        let input = Variable::new(
            Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap(),
            false,
        );
        let output = relu.forward(&input);
        assert_eq!(output.data().to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let sigmoid = Sigmoid::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let output = sigmoid.forward(&input);
        assert!((output.data().to_vec()[0] - 0.5).abs() < 1e-6);
    }

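    // Added sketch test (not in the original suite): exercises the Tanh
    // wrapper using the same Variable/Tensor API as the tests above.
    // Expected values are tanh(0) = 0 and tanh(1) ≈ 0.761594.
    #[test]
    fn test_tanh() {
        let tanh = Tanh::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap(), false);
        let output = tanh.forward(&input);
        let out = output.data().to_vec();
        assert!(out[0].abs() < 1e-6);
        assert!((out[1] - 0.761_594_2).abs() < 1e-5);
    }
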
    #[test]
    fn test_softmax() {
        let softmax = Softmax::new(-1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = softmax.forward(&input);
        let sum: f32 = output.data().to_vec().iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

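    // Added sketch tests (not in the original suite): the first checks Softmax
    // along a leading, non-default dimension; the second checks that LogSoftmax
    // outputs exponentiate back to a probability distribution. Both reuse the
    // Variable/Tensor API from the tests above.
    #[test]
    fn test_softmax_dim0() {
        let softmax = Softmax::new(0);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
            false,
        );
        let output = softmax.forward(&input);
        let out = output.data().to_vec();
        // Each slice along dim 0 (i.e., each column) should sum to 1.
        assert!((out[0] + out[2] - 1.0).abs() < 1e-5);
        assert!((out[1] + out[3] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_log_softmax() {
        let log_softmax = LogSoftmax::new(-1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = log_softmax.forward(&input);
        let sum: f32 = output.data().to_vec().iter().map(|&x| x.exp()).sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
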
    #[test]
    fn test_leaky_relu() {
        let leaky = LeakyReLU::with_slope(0.1);
        let input = Variable::new(Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let output = leaky.forward(&input);
        assert_eq!(output.data().to_vec(), vec![-0.1, 0.0, 1.0]);
    }

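    // Added sketch test (not in the original suite): ELU with the default
    // alpha = 1.0 should pass positive inputs through and map negative inputs
    // to exp(x) - 1 (ELU(-1) ≈ -0.632121).
    #[test]
    fn test_elu() {
        let elu = ELU::new();
        let input = Variable::new(Tensor::from_vec(vec![-1.0, 0.0, 2.0], &[3]).unwrap(), false);
        let output = elu.forward(&input);
        let out = output.data().to_vec();
        assert!((out[0] + 0.632_120_6).abs() < 1e-5);
        assert!(out[1].abs() < 1e-6);
        assert!((out[2] - 2.0).abs() < 1e-6);
    }
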
    #[test]
    fn test_identity() {
        let id = Identity::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let output = id.forward(&input);
        assert_eq!(output.data().to_vec(), vec![1.0, 2.0, 3.0]);
    }
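
    // Added sketch tests (not in the original suite): GELU uses the tanh
    // approximation implemented above (GELU(0) = 0, GELU(1) ≈ 0.8412), and
    // SiLU is x * sigmoid(x) (SiLU(0) = 0, SiLU(1) ≈ 0.731059), assuming
    // `mul_var` multiplies element-wise as in the forward pass above.
    #[test]
    fn test_gelu() {
        let gelu = GELU::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap(), false);
        let output = gelu.forward(&input);
        let out = output.data().to_vec();
        assert!(out[0].abs() < 1e-6);
        assert!((out[1] - 0.8412).abs() < 1e-3);
    }

    #[test]
    fn test_silu() {
        let silu = SiLU::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap(), false);
        let output = silu.forward(&input);
        let out = output.data().to_vec();
        assert!(out[0].abs() < 1e-6);
        assert!((out[1] - 0.731_058_6).abs() < 1e-5);
    }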
}