ferrite/tensor/ops/activation.rs

use crate::*;
use std::rc::Rc;

/// Elementwise activation functions, plus `softmax` over a chosen dimension.
pub trait ActivationOps {
  /// Hard threshold step function.
  fn binary_step(&self) -> Self;
  /// Logistic sigmoid: 1 / (1 + e^-x).
  fn sigmoid(&self) -> Self;
  /// Hyperbolic tangent.
  fn tanh(&self) -> Self;
  /// Rectified linear unit: max(0, x).
  fn relu(&self) -> Self;
  /// ReLU with a small fixed slope for negative inputs.
  fn leaky_relu(&self) -> Self;
  /// ReLU with a caller-supplied negative-input slope `a`.
  fn parametric_relu(&self, a: f32) -> Self;
  /// Exponential linear unit with scale `alpha` on the negative side.
  fn elu(&self, alpha: f32) -> Self;
  /// Softmax along dimension `dim`.
  fn softmax(&self, dim: usize) -> Self;
  /// Swish / SiLU: x * sigmoid(x).
  fn swish(&self) -> Self;
}

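// Both `Storage` and `Tensor` implement this trait below; a typical call site
// would be something like `let y = x.relu();` on a `Tensor`, which also records
// the backward graph when gradients are enabled (illustrative usage only).
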
// Storage-level implementations: each op is forwarded to the concrete storage
// backend through the `match_storage!` dispatch macro.
impl ActivationOps for Storage {
  fn binary_step(&self) -> Self {
    match_storage!(unary self, binary_step)
  }

  fn sigmoid(&self) -> Self {
    match_storage!(unary self, sigmoid)
  }

  fn tanh(&self) -> Self {
    match_storage!(unary self, tanh)
  }

  fn relu(&self) -> Self {
    match_storage!(unary self, relu)
  }

  fn leaky_relu(&self) -> Self {
    match_storage!(unary self, leaky_relu)
  }

  fn parametric_relu(&self, a: f32) -> Self {
    match_storage!(unary self, parametric_relu, a)
  }

  fn elu(&self, alpha: f32) -> Self {
    match_storage!(unary self, elu, alpha)
  }

  fn softmax(&self, dim: usize) -> Self {
    match_storage!(unary self, softmax, dim)
  }

  fn swish(&self) -> Self {
    match_storage!(unary self, swish)
  }
}
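
// The elementwise math these kernels are expected to compute, written out as a
// minimal scalar sketch using the standard textbook definitions. This is an
// illustrative assumption only: the real per-device kernels live behind
// `match_storage!` and may differ in detail (e.g. the exact leaky-relu slope,
// or max-subtraction inside softmax). Nothing above dispatches to this module.
#[allow(dead_code)]
mod scalar_reference {
  pub fn binary_step(x: f32) -> f32 { if x >= 0.0 { 1.0 } else { 0.0 } }
  pub fn sigmoid(x: f32) -> f32 { 1.0 / (1.0 + (-x).exp()) }
  pub fn tanh(x: f32) -> f32 { x.tanh() }
  pub fn relu(x: f32) -> f32 { x.max(0.0) }
  // The 0.01 slope is an assumed default; the backend kernel defines the real value.
  pub fn leaky_relu(x: f32) -> f32 { if x >= 0.0 { x } else { 0.01 * x } }
  pub fn parametric_relu(x: f32, a: f32) -> f32 { if x >= 0.0 { x } else { a * x } }
  pub fn elu(x: f32, alpha: f32) -> f32 { if x >= 0.0 { x } else { alpha * (x.exp() - 1.0) } }
  pub fn swish(x: f32) -> f32 { x * sigmoid(x) }

  // Numerically stable softmax over a single row; on real storage the `dim`
  // argument selects which axis is normalized this way.
  pub fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
  }
}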

// Tensor-level implementations: run the forward op on the underlying storage,
// wrap the result in a new `Tensor` on the same device, and attach the matching
// gradient node only when the input requires gradients.
impl ActivationOps for Tensor {
  fn binary_step(&self) -> Self {
    let tensor = self.tensor().binary_step();

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(BinaryStepGrad::new(
        self,
        &result
      ))));
    }

    result
  }

  fn sigmoid(&self) -> Self {
    let tensor = self.tensor().sigmoid();

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(SigmoidGrad::new(
        self,
        &result
      ))));
    }

    result
  }

  fn tanh(&self) -> Self {
    let tensor = self.tensor().tanh();

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(TanhGrad::new(
        self,
        &result
      ))));
    }

    result
  }

  fn relu(&self) -> Self {
    let tensor = self.tensor().relu();

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(ReluGrad::new(
        self,
        &result
      ))));
    }

    result
  }

  fn leaky_relu(&self) -> Self {
    let tensor = self.tensor().leaky_relu();

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(LeakyReluGrad::new(
        self,
        &result
      ))));
    }

    result
  }

  fn parametric_relu(&self, a: f32) -> Self {
    let tensor = self.tensor().parametric_relu(a);

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(ParametricReluGrad::new(
        self,
        a,
        &result
      ))));
    }

    result
  }

  fn elu(&self, alpha: f32) -> Self {
    let tensor = self.tensor().elu(alpha);

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(EluGrad::new(
        self,
        alpha,
        &result
      ))));
    }

    result
  }

  fn softmax(&self, dim: usize) -> Self {
    let tensor = self.tensor().softmax(dim);

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(SoftmaxGrad::new(
        self,
        &result
      ))));
    }

    result
  }

  fn swish(&self) -> Self {
    let tensor = self.tensor().swish();

    // Create result tensor
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);

    // Set up gradient function if needed
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(SwishGrad::new(
        self,
        &result
      ))));
    }

    result
  }
}
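
// A self-contained numeric check of the activation derivatives that backward
// nodes like `SigmoidGrad` or `SwishGrad` conventionally rely on. This only
// exercises the `scalar_reference` sketch above via finite differences; it is
// an illustrative assumption and does not touch the crate's actual grad nodes.
#[cfg(test)]
mod scalar_reference_grad_checks {
  use super::scalar_reference::*;

  // Central-difference approximation of df/dx.
  fn numeric_grad(f: impl Fn(f32) -> f32, x: f32) -> f32 {
    let h = 1e-2;
    (f(x + h) - f(x - h)) / (2.0 * h)
  }

  #[test]
  fn sigmoid_derivative_is_s_times_one_minus_s() {
    for &x in &[-2.0f32, -0.5, 0.0, 0.5, 2.0] {
      let s = sigmoid(x);
      let analytic = s * (1.0 - s);
      assert!((analytic - numeric_grad(sigmoid, x)).abs() < 1e-3);
    }
  }

  #[test]
  fn tanh_derivative_is_one_minus_t_squared() {
    for &x in &[-2.0f32, -0.5, 0.0, 0.5, 2.0] {
      let t = tanh(x);
      let analytic = 1.0 - t * t;
      assert!((analytic - numeric_grad(tanh, x)).abs() < 1e-3);
    }
  }

  #[test]
  fn swish_derivative_matches_product_rule() {
    for &x in &[-2.0f32, -0.5, 0.0, 0.5, 2.0] {
      let s = sigmoid(x);
      let analytic = s + x * s * (1.0 - s);
      assert!((analytic - numeric_grad(swish, x)).abs() < 1e-3);
    }
  }
}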