optirs_learned/transformer_based_optimizer/
layers.rs

1// Core transformer layer implementations
2
3use super::config::ActivationFunction;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
6use scirs2_core::numeric::Float;
7use std::fmt::Debug;
8
/// Embedding layer for input vectors
pub struct EmbeddingLayer<T: Float + Debug + Send + Sync + 'static> {
    /// Embedding matrix of shape (input_dimension, output_dimension);
    /// each row is the embedding vector for one input index
    embedding_matrix: Array2<T>,

    /// Input dimension (number of embedding rows, i.e. the index space)
    input_dimension: usize,

    /// Output dimension (model dimension; width of each embedding row)
    output_dimension: usize,
}
20
21impl<T: Float + Debug + Send + Sync + 'static> EmbeddingLayer<T> {
22    /// Create new embedding layer
23    pub fn new(input_dimension: usize, output_dimension: usize) -> Result<Self> {
24        let embedding_matrix = Array2::zeros((input_dimension, output_dimension));
25
26        Ok(Self {
27            embedding_matrix,
28            input_dimension,
29            output_dimension,
30        })
31    }
32
33    /// Forward pass through embedding layer
34    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
35        let batch_size = input.shape()[0];
36        let seq_len = input.shape()[1];
37
38        let mut output = Array2::zeros((batch_size, self.output_dimension));
39
40        for i in 0..batch_size {
41            for j in 0..seq_len {
42                let embedding_idx = input[[i, j]].to_usize().unwrap_or(0) % self.input_dimension;
43                let embedding = self.embedding_matrix.row(embedding_idx);
44
45                for k in 0..self.output_dimension {
46                    output[[i, k]] = output[[i, k]] + embedding[k];
47                }
48            }
49        }
50
51        Ok(output)
52    }
53
54    /// Get parameter count
55    pub fn parameter_count(&self) -> usize {
56        self.input_dimension * self.output_dimension
57    }
58
59    /// Reset embedding parameters
60    pub fn reset(&mut self) -> Result<()> {
61        self.embedding_matrix.fill(T::zero());
62        Ok(())
63    }
64}
65
/// Layer normalization implementation
pub struct LayerNormalization<T: Float + Debug + Send + Sync + 'static> {
    /// Layer dimension (number of features normalized per row)
    dimension: usize,

    /// Learnable scale parameters, one per feature (initialized to 1)
    gamma: Array1<T>,

    /// Learnable shift parameters, one per feature (initialized to 0)
    beta: Array1<T>,

    /// Small constant added to the variance for numerical stability
    epsilon: T,
}
80
81impl<T: Float + Debug + Send + Sync + 'static> LayerNormalization<T> {
82    /// Create new layer normalization
83    pub fn new(dimension: usize) -> Result<Self> {
84        let gamma = Array1::ones(dimension);
85        let beta = Array1::zeros(dimension);
86        let epsilon = scirs2_core::numeric::NumCast::from(1e-5).unwrap_or_else(|| T::zero());
87
88        Ok(Self {
89            dimension,
90            gamma,
91            beta,
92            epsilon,
93        })
94    }
95
96    /// Forward pass through layer normalization
97    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
98        let mut output = input.clone();
99        let batch_size = input.shape()[0];
100
101        for i in 0..batch_size {
102            let row = input.row(i);
103
104            // Calculate mean
105            let mean = row.sum() / T::from(row.len()).unwrap();
106
107            // Calculate variance
108            let variance = row
109                .iter()
110                .map(|&x| {
111                    let diff = x - mean;
112                    diff * diff
113                })
114                .fold(T::zero(), |acc, x| acc + x)
115                / T::from(row.len()).unwrap();
116
117            // Normalize
118            let std_dev = (variance + self.epsilon).sqrt();
119
120            for j in 0..self.dimension {
121                let normalized = (input[[i, j]] - mean) / std_dev;
122                output[[i, j]] = self.gamma[j] * normalized + self.beta[j];
123            }
124        }
125
126        Ok(output)
127    }
128
129    /// Get parameter count
130    pub fn parameter_count(&self) -> usize {
131        2 * self.dimension // gamma + beta
132    }
133
134    /// Reset normalization parameters
135    pub fn reset(&mut self) -> Result<()> {
136        self.gamma.fill(T::one());
137        self.beta.fill(T::zero());
138        Ok(())
139    }
140}
141
/// Dropout layer for regularization
pub struct DropoutLayer {
    /// Dropout probability (fraction of elements zeroed during training)
    dropout_rate: f64,

    /// Training mode flag; when false, forward is the identity
    training: bool,
}
150
151impl DropoutLayer {
152    /// Create new dropout layer
153    pub fn new(dropout_rate: f64) -> Self {
154        Self {
155            dropout_rate,
156            training: true,
157        }
158    }
159
160    /// Forward pass through dropout
161    pub fn forward<T: Float + Debug + Send + Sync + 'static>(
162        &self,
163        input: &Array2<T>,
164    ) -> Array2<T> {
165        if !self.training || self.dropout_rate == 0.0 {
166            return input.clone();
167        }
168
169        let mut output = input.clone();
170        let keep_prob = 1.0 - self.dropout_rate;
171        let scale =
172            scirs2_core::numeric::NumCast::from(1.0 / keep_prob).unwrap_or_else(|| T::zero());
173
174        for elem in output.iter_mut() {
175            if scirs2_core::random::random::<f64>() < self.dropout_rate {
176                *elem = T::zero();
177            } else {
178                *elem = *elem * scale;
179            }
180        }
181
182        output
183    }
184
185    /// Set training mode
186    pub fn set_training(&mut self, training: bool) {
187        self.training = training;
188    }
189
190    /// Get training mode
191    pub fn is_training(&self) -> bool {
192        self.training
193    }
194}
195
/// Output projection layer
pub struct OutputProjection<T: Float + Debug + Send + Sync + 'static> {
    /// Weight matrix of shape (input_dim, output_dim)
    weight: Array2<T>,

    /// Bias vector of length output_dim
    bias: Array1<T>,

    /// Input dimension (expected feature count of incoming rows)
    input_dim: usize,

    /// Output dimension (feature count of projected rows)
    output_dim: usize,
}
210
211impl<T: Float + Debug + Send + Sync + 'static> OutputProjection<T> {
212    /// Create new output projection
213    pub fn new(input_dim: usize, output_dim: usize) -> Result<Self> {
214        let weight = Array2::zeros((input_dim, output_dim));
215        let bias = Array1::zeros(output_dim);
216
217        Ok(Self {
218            weight,
219            bias,
220            input_dim,
221            output_dim,
222        })
223    }
224
225    /// Forward pass through projection
226    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
227        let batch_size = input.shape()[0];
228        let mut output = Array2::zeros((batch_size, self.output_dim));
229
230        // Matrix multiplication: input @ weight + bias
231        for i in 0..batch_size {
232            for j in 0..self.output_dim {
233                let mut sum = self.bias[j];
234                for k in 0..self.input_dim {
235                    sum = sum + input[[i, k]] * self.weight[[k, j]];
236                }
237                output[[i, j]] = sum;
238            }
239        }
240
241        Ok(output)
242    }
243
244    /// Get parameter count
245    pub fn parameter_count(&self) -> usize {
246        self.input_dim * self.output_dim + self.output_dim
247    }
248
249    /// Reset projection parameters
250    pub fn reset(&mut self) -> Result<()> {
251        self.weight.fill(T::zero());
252        self.bias.fill(T::zero());
253        Ok(())
254    }
255}
256
/// Residual connections manager
pub struct ResidualConnections<T: Float + Debug + Send + Sync + 'static> {
    /// Model dimension
    // NOTE(review): stored but never read by `add`; presumably kept for
    // future shape validation — confirm before removing.
    dimension: usize,

    /// Optional learnable scaling factor applied after the residual sum
    scale_factor: Option<T>,
}
265
266impl<T: Float + Debug + Send + Sync + 'static> ResidualConnections<T> {
267    /// Create new residual connections
268    pub fn new(dimension: usize) -> Self {
269        Self {
270            dimension,
271            scale_factor: None,
272        }
273    }
274
275    /// Create with learnable scaling
276    pub fn new_with_scaling(dimension: usize, initial_scale: T) -> Self {
277        Self {
278            dimension,
279            scale_factor: Some(initial_scale),
280        }
281    }
282
283    /// Add residual connection
284    pub fn add(&self, input: &Array2<T>, residual: &Array2<T>) -> Result<Array2<T>> {
285        if input.shape() != residual.shape() {
286            return Err(crate::error::OptimError::Other(
287                "Shape mismatch in residual connection".to_string(),
288            ));
289        }
290
291        let mut output = input + residual;
292
293        if let Some(scale) = self.scale_factor {
294            output.mapv_inplace(|x| x * scale);
295        }
296
297        Ok(output)
298    }
299
300    /// Set scaling factor
301    pub fn set_scale_factor(&mut self, scale: T) {
302        self.scale_factor = Some(scale);
303    }
304
305    /// Get scaling factor
306    pub fn get_scale_factor(&self) -> Option<T> {
307        self.scale_factor
308    }
309}
310
/// Activation layer with various activation functions
///
/// Stateless zero-sized type; all activations are exposed as associated
/// functions (see `ActivationLayer::apply`).
pub struct ActivationLayer;
313
314impl ActivationLayer {
315    /// Apply activation function
316    pub fn apply<T: Float + Debug + Send + Sync + 'static>(
317        input: &Array2<T>,
318        activation: ActivationFunction,
319    ) -> Array2<T> {
320        match activation {
321            ActivationFunction::ReLU => Self::relu(input),
322            ActivationFunction::GELU => Self::gelu(input),
323            ActivationFunction::Swish => Self::swish(input),
324            ActivationFunction::Tanh => Self::tanh(input),
325            ActivationFunction::Sigmoid => Self::sigmoid(input),
326            ActivationFunction::LeakyReLU => Self::leaky_relu(
327                input,
328                scirs2_core::numeric::NumCast::from(0.01).unwrap_or_else(|| T::zero()),
329            ),
330        }
331    }
332
333    /// ReLU activation
334    fn relu<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
335        input.map(|&x| if x > T::zero() { x } else { T::zero() })
336    }
337
338    /// GELU activation (approximation)
339    fn gelu<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
340        input.map(|&x| {
341            let half = scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero());
342            let one = T::one();
343            let sqrt_2_pi =
344                scirs2_core::numeric::NumCast::from(0.797884560802865).unwrap_or_else(|| T::zero()); // sqrt(2/π)
345            let coeff = scirs2_core::numeric::NumCast::from(0.044715).unwrap_or_else(|| T::zero());
346
347            let tanh_arg = sqrt_2_pi * (x + coeff * x * x * x);
348            let tanh_val = tanh_arg.tanh();
349
350            half * x * (one + tanh_val)
351        })
352    }
353
354    /// Swish activation
355    fn swish<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
356        input.map(|&x| x * Self::sigmoid_scalar(x))
357    }
358
359    /// Tanh activation
360    fn tanh<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
361        input.map(|&x| x.tanh())
362    }
363
364    /// Sigmoid activation
365    fn sigmoid<T: Float + Debug + Send + Sync + 'static>(input: &Array2<T>) -> Array2<T> {
366        input.map(|&x| Self::sigmoid_scalar(x))
367    }
368
369    /// Leaky ReLU activation
370    fn leaky_relu<T: Float + Debug + Send + Sync + 'static>(
371        input: &Array2<T>,
372        alpha: T,
373    ) -> Array2<T> {
374        input.map(|&x| if x > T::zero() { x } else { alpha * x })
375    }
376
377    /// Sigmoid scalar function
378    fn sigmoid_scalar<T: Float + Debug + Send + Sync + 'static>(x: T) -> T {
379        let one = T::one();
380        one / (one + (-x).exp())
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_layer() {
        // Construction succeeds; parameter count equals rows * cols.
        let layer = EmbeddingLayer::<f32>::new(100, 64).expect("embedding layer should build");
        assert_eq!(layer.parameter_count(), 100 * 64);
    }

    #[test]
    fn test_layer_normalization() {
        // A constant input must normalize without error.
        let norm = LayerNormalization::<f32>::new(128).expect("layer norm should build");
        let batch = Array2::<f32>::ones((2, 128));
        assert!(norm.forward(&batch).is_ok());
    }

    #[test]
    fn test_dropout_layer() {
        let mut layer = DropoutLayer::new(0.5);
        let batch = Array2::<f32>::ones((4, 128));

        // Evaluation mode: identity.
        layer.set_training(false);
        assert_eq!(layer.forward(&batch), batch);

        // Training mode: shape is preserved.
        layer.set_training(true);
        assert_eq!(layer.forward(&batch).shape(), batch.shape());
    }

    #[test]
    fn test_output_projection() {
        let proj = OutputProjection::<f32>::new(128, 64).expect("projection should build");
        let batch = Array2::<f32>::zeros((2, 128));
        let projected = proj.forward(&batch).expect("forward should succeed");
        assert_eq!(projected.shape(), &[2, 64]);
    }

    #[test]
    fn test_residual_connections() {
        let residual = ResidualConnections::<f32>::new(64);
        let base = Array2::<f32>::ones((2, 64));
        let skip = Array2::<f32>::ones((2, 64)) * 0.5;

        // 1.0 + 0.5 with no scaling configured.
        let summed = residual.add(&base, &skip).expect("shapes match");
        assert_eq!(summed[[0, 0]], 1.5);
    }

    #[test]
    fn test_activation_functions() {
        let batch = Array2::<f32>::from_shape_vec((2, 2), vec![-1.0, 0.0, 0.5, 1.0]).unwrap();

        // ReLU zeroes negatives and passes positives through.
        let relu = ActivationLayer::apply(&batch, ActivationFunction::ReLU);
        assert_eq!(relu[[0, 0]], 0.0);
        assert_eq!(relu[[1, 1]], 1.0);

        // GELU preserves the input shape.
        let gelu = ActivationLayer::apply(&batch, ActivationFunction::GELU);
        assert_eq!(gelu.shape(), batch.shape());

        // Sigmoid output is bounded in [0, 1].
        let sigmoid = ActivationLayer::apply(&batch, ActivationFunction::Sigmoid);
        assert!(sigmoid.iter().all(|&v| (0.0..=1.0).contains(&v)));
    }
}