//! Layer primitives for the transformer-based learned optimizer
//! (optirs_learned/transformer_based_optimizer/layers.rs).

use super::config::ActivationFunction;
use crate::error::Result;
use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
8
/// Lookup-table embedding mapping integer-like inputs to dense vectors.
///
/// NOTE(review): `forward` sums the embeddings of all sequence positions into
/// a single vector per batch item (bag-of-embeddings) rather than returning
/// one embedding per position — confirm this pooling is intended.
pub struct EmbeddingLayer<T: Float + Debug + Send + Sync + 'static> {
    // Embedding table, shape (input_dimension, output_dimension); one row per index.
    embedding_matrix: Array2<T>,
    // Vocabulary size: number of distinct embedding rows.
    input_dimension: usize,
    // Width of each embedding vector.
    output_dimension: usize,
}
20
21impl<T: Float + Debug + Send + Sync + 'static> EmbeddingLayer<T> {
22 pub fn new(input_dimension: usize, output_dimension: usize) -> Result<Self> {
24 let embedding_matrix = Array2::zeros((input_dimension, output_dimension));
25
26 Ok(Self {
27 embedding_matrix,
28 input_dimension,
29 output_dimension,
30 })
31 }
32
33 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
35 let batch_size = input.shape()[0];
36 let seq_len = input.shape()[1];
37
38 let mut output = Array2::zeros((batch_size, self.output_dimension));
39
40 for i in 0..batch_size {
41 for j in 0..seq_len {
42 let embedding_idx = input[[i, j]].to_usize().unwrap_or(0) % self.input_dimension;
43 let embedding = self.embedding_matrix.row(embedding_idx);
44
45 for k in 0..self.output_dimension {
46 output[[i, k]] = output[[i, k]] + embedding[k];
47 }
48 }
49 }
50
51 Ok(output)
52 }
53
54 pub fn parameter_count(&self) -> usize {
56 self.input_dimension * self.output_dimension
57 }
58
59 pub fn reset(&mut self) -> Result<()> {
61 self.embedding_matrix.fill(T::zero());
62 Ok(())
63 }
64}
65
/// Per-row layer normalization with a learned affine transform.
pub struct LayerNormalization<T: Float + Debug + Send + Sync + 'static> {
    // Number of features normalized per row.
    dimension: usize,
    // Learned scale, one entry per feature (initialized to 1).
    gamma: Array1<T>,
    // Learned shift, one entry per feature (initialized to 0).
    beta: Array1<T>,
    // Small constant added to the variance for numerical stability.
    epsilon: T,
}
80
81impl<T: Float + Debug + Send + Sync + 'static> LayerNormalization<T> {
82 pub fn new(dimension: usize) -> Result<Self> {
84 let gamma = Array1::ones(dimension);
85 let beta = Array1::zeros(dimension);
86 let epsilon = scirs2_core::numeric::NumCast::from(1e-5).unwrap_or_else(|| T::zero());
87
88 Ok(Self {
89 dimension,
90 gamma,
91 beta,
92 epsilon,
93 })
94 }
95
96 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
98 let mut output = input.clone();
99 let batch_size = input.shape()[0];
100
101 for i in 0..batch_size {
102 let row = input.row(i);
103
104 let mean = row.sum() / T::from(row.len()).expect("unwrap failed");
106
107 let variance = row
109 .iter()
110 .map(|&x| {
111 let diff = x - mean;
112 diff * diff
113 })
114 .fold(T::zero(), |acc, x| acc + x)
115 / T::from(row.len()).expect("unwrap failed");
116
117 let std_dev = (variance + self.epsilon).sqrt();
119
120 for j in 0..self.dimension {
121 let normalized = (input[[i, j]] - mean) / std_dev;
122 output[[i, j]] = self.gamma[j] * normalized + self.beta[j];
123 }
124 }
125
126 Ok(output)
127 }
128
129 pub fn parameter_count(&self) -> usize {
131 2 * self.dimension }
133
134 pub fn reset(&mut self) -> Result<()> {
136 self.gamma.fill(T::one());
137 self.beta.fill(T::zero());
138 Ok(())
139 }
140}
141
/// Inverted dropout: zeroes elements during training and rescales survivors
/// so the expected activation is unchanged; identity during inference.
pub struct DropoutLayer {
    // Probability of zeroing an individual element.
    dropout_rate: f64,
    // When false, `forward` is the identity.
    training: bool,
}
150
151impl DropoutLayer {
152 pub fn new(dropout_rate: f64) -> Self {
154 Self {
155 dropout_rate,
156 training: true,
157 }
158 }
159
160 pub fn forward<T: Float + Debug + Send + Sync + 'static>(
162 &self,
163 input: &Array2<T>,
164 ) -> Array2<T> {
165 if !self.training || self.dropout_rate == 0.0 {
166 return input.clone();
167 }
168
169 let mut output = input.clone();
170 let keep_prob = 1.0 - self.dropout_rate;
171 let scale =
172 scirs2_core::numeric::NumCast::from(1.0 / keep_prob).unwrap_or_else(|| T::zero());
173
174 for elem in output.iter_mut() {
175 if scirs2_core::random::random::<f64>() < self.dropout_rate {
176 *elem = T::zero();
177 } else {
178 *elem = *elem * scale;
179 }
180 }
181
182 output
183 }
184
185 pub fn set_training(&mut self, training: bool) {
187 self.training = training;
188 }
189
190 pub fn is_training(&self) -> bool {
192 self.training
193 }
194}
195
/// Dense linear projection: `output = input @ weight + bias`.
pub struct OutputProjection<T: Float + Debug + Send + Sync + 'static> {
    // Weight matrix, shape (input_dim, output_dim).
    weight: Array2<T>,
    // Bias vector, one entry per output feature.
    bias: Array1<T>,
    // Number of input features.
    input_dim: usize,
    // Number of output features.
    output_dim: usize,
}
210
211impl<T: Float + Debug + Send + Sync + 'static> OutputProjection<T> {
212 pub fn new(input_dim: usize, output_dim: usize) -> Result<Self> {
214 let weight = Array2::zeros((input_dim, output_dim));
215 let bias = Array1::zeros(output_dim);
216
217 Ok(Self {
218 weight,
219 bias,
220 input_dim,
221 output_dim,
222 })
223 }
224
225 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
227 let batch_size = input.shape()[0];
228 let mut output = Array2::zeros((batch_size, self.output_dim));
229
230 for i in 0..batch_size {
232 for j in 0..self.output_dim {
233 let mut sum = self.bias[j];
234 for k in 0..self.input_dim {
235 sum = sum + input[[i, k]] * self.weight[[k, j]];
236 }
237 output[[i, j]] = sum;
238 }
239 }
240
241 Ok(output)
242 }
243
244 pub fn parameter_count(&self) -> usize {
246 self.input_dim * self.output_dim + self.output_dim
247 }
248
249 pub fn reset(&mut self) -> Result<()> {
251 self.weight.fill(T::zero());
252 self.bias.fill(T::zero());
253 Ok(())
254 }
255}
256
/// Element-wise residual (skip) connection with an optional scale applied to
/// the sum.
pub struct ResidualConnections<T: Float + Debug + Send + Sync + 'static> {
    // Expected feature width. NOTE(review): not consulted by `add`, which
    // checks full-shape equality instead — confirm whether this field is
    // still needed.
    dimension: usize,
    // When set, the sum `input + residual` is multiplied by this factor.
    scale_factor: Option<T>,
}
265
266impl<T: Float + Debug + Send + Sync + 'static> ResidualConnections<T> {
267 pub fn new(dimension: usize) -> Self {
269 Self {
270 dimension,
271 scale_factor: None,
272 }
273 }
274
275 pub fn new_with_scaling(dimension: usize, initial_scale: T) -> Self {
277 Self {
278 dimension,
279 scale_factor: Some(initial_scale),
280 }
281 }
282
283 pub fn add(&self, input: &Array2<T>, residual: &Array2<T>) -> Result<Array2<T>> {
285 if input.shape() != residual.shape() {
286 return Err(crate::error::OptimError::Other(
287 "Shape mismatch in residual connection".to_string(),
288 ));
289 }
290
291 let mut output = input + residual;
292
293 if let Some(scale) = self.scale_factor {
294 output.mapv_inplace(|x| x * scale);
295 }
296
297 Ok(output)
298 }
299
300 pub fn set_scale_factor(&mut self, scale: T) {
302 self.scale_factor = Some(scale);
303 }
304
305 pub fn get_scale_factor(&self) -> Option<T> {
307 self.scale_factor
308 }
309}
310
311pub struct ActivationLayer;
313
314impl ActivationLayer {
315 pub fn apply<T: Float + Debug + Send + Sync + 'static>(
317 input: &Array2<T>,
318 activation: ActivationFunction,
319 ) -> Array2<T> {
320 match activation {
321 ActivationFunction::ReLU => Self::relu(input),
322 ActivationFunction::GELU => Self::gelu(input),
323 ActivationFunction::Swish => Self::swish(input),
324 ActivationFunction::Tanh => Self::tanh(input),
325 ActivationFunction::Sigmoid => Self::sigmoid(input),
326 ActivationFunction::LeakyReLU => Self::leaky_relu(
327 input,
328 scirs2_core::numeric::NumCast::from(0.01).unwrap_or_else(|| T::zero()),
329 ),
330 }
331 }
332
333 fn relu<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
335 input.map(|&x| if x > T::zero() { x } else { T::zero() })
336 }
337
338 fn gelu<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
340 input.map(|&x| {
341 let half = scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero());
342 let one = T::one();
343 let sqrt_2_pi =
344 scirs2_core::numeric::NumCast::from(0.797884560802865).unwrap_or_else(|| T::zero()); let coeff = scirs2_core::numeric::NumCast::from(0.044715).unwrap_or_else(|| T::zero());
346
347 let tanh_arg = sqrt_2_pi * (x + coeff * x * x * x);
348 let tanh_val = tanh_arg.tanh();
349
350 half * x * (one + tanh_val)
351 })
352 }
353
354 fn swish<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
356 input.map(|&x| x * Self::sigmoid_scalar(x))
357 }
358
359 fn tanh<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
361 input.map(|&x| x.tanh())
362 }
363
364 fn sigmoid<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
366 input.map(|&x| Self::sigmoid_scalar(x))
367 }
368
369 fn leaky_relu<T: Float + Debug + Send + Sync + 'static>(
371 input: &Array2<T>,
372 alpha: T,
373 ) -> Array2<T> {
374 input.map(|&x| if x > T::zero() { x } else { alpha * x })
375 }
376
377 fn sigmoid_scalar<T: Float + Debug + Send + Sync + 'static>(x: T) -> T {
379 let one = T::one();
380 one / (one + (-x).exp())
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    // Construction succeeds and parameter count equals rows * cols.
    #[test]
    fn test_embedding_layer() {
        let embedding = EmbeddingLayer::<f32>::new(100, 64);
        assert!(embedding.is_ok());

        let emb = embedding.expect("unwrap failed");
        assert_eq!(emb.parameter_count(), 100 * 64);
    }

    // forward accepts an input whose width matches the configured dimension.
    #[test]
    fn test_layer_normalization() {
        let layer_norm = LayerNormalization::<f32>::new(128);
        assert!(layer_norm.is_ok());

        let ln = layer_norm.expect("unwrap failed");
        let input = Array2::<f32>::ones((2, 128));
        let result = ln.forward(&input);
        assert!(result.is_ok());
    }

    // In eval mode dropout is the identity; in training mode it preserves shape.
    #[test]
    fn test_dropout_layer() {
        let mut dropout = DropoutLayer::new(0.5);
        let input = Array2::<f32>::ones((4, 128));

        dropout.set_training(false);
        let output = dropout.forward(&input);
        assert_eq!(output, input);

        dropout.set_training(true);
        let output = dropout.forward(&input);
        assert_eq!(output.shape(), input.shape());
    }

    // forward maps (batch, input_dim) to (batch, output_dim).
    #[test]
    fn test_output_projection() {
        let projection = OutputProjection::<f32>::new(128, 64);
        assert!(projection.is_ok());

        let proj = projection.expect("unwrap failed");
        let input = Array2::<f32>::zeros((2, 128));
        let result = proj.forward(&input);
        assert!(result.is_ok());

        let output = result.expect("unwrap failed");
        assert_eq!(output.shape(), &[2, 64]);
    }

    // Unscaled residual addition: 1.0 + 0.5 == 1.5 element-wise.
    #[test]
    fn test_residual_connections() {
        let residual = ResidualConnections::<f32>::new(64);
        let input = Array2::<f32>::ones((2, 64));
        let res_input = Array2::<f32>::ones((2, 64)) * 0.5;

        let result = residual.add(&input, &res_input);
        assert!(result.is_ok());

        let output = result.expect("unwrap failed");
        assert_eq!(output[[0, 0]], 1.5);
    }

    // Spot-checks: ReLU clamps negatives and passes positives, GELU preserves
    // shape, and sigmoid outputs stay within [0, 1].
    #[test]
    fn test_activation_functions() {
        let input = Array2::<f32>::from_shape_vec((2, 2), vec![-1.0, 0.0, 0.5, 1.0])
            .expect("unwrap failed");

        let relu_output = ActivationLayer::apply(&input, ActivationFunction::ReLU);
        assert_eq!(relu_output[[0, 0]], 0.0);
        assert_eq!(relu_output[[1, 1]], 1.0);

        let gelu_output = ActivationLayer::apply(&input, ActivationFunction::GELU);
        assert_eq!(gelu_output.shape(), input.shape());

        let sigmoid_output = ActivationLayer::apply(&input, ActivationFunction::Sigmoid);
        assert!(sigmoid_output.iter().all(|&x| (0.0..=1.0).contains(&x)));
    }
}