optirs_learned/transformer/architecture/
encoder.rs

1use std::fmt::Debug;
2// Transformer encoder layers and components
3//
4// This module implements the encoder components of the transformer optimizer,
5// including the transformer layer, feed-forward network, and layer normalization.
6
7#[allow(dead_code)]
8use scirs2_core::ndarray::{s, Array1, Array2};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::{CoreRandom as Random, Rng as SCRRng};
11
12use super::super::TransformerOptimizerConfig;
13use super::attention::MultiHeadAttention;
14use crate::error::{OptimError, Result};
15
/// Activation functions for feed-forward networks.
///
/// NOTE: `GLU` and `GeGLU` are currently evaluated as plain GELU in
/// `FeedForwardNetwork::apply_activation` (true gated units would need a
/// split projection, which this network does not have yet).
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    /// ReLU activation: max(0, x)
    ReLU,
    /// GELU activation (Gaussian Error Linear Unit, tanh approximation)
    GELU,
    /// Swish/SiLU activation: x * sigmoid(x)
    Swish,
    /// GLU (Gated Linear Unit) — currently falls back to GELU
    GLU,
    /// GeGLU (GELU variant of GLU) — currently falls back to GELU
    GeGLU,
}
30
/// Single transformer encoder layer.
///
/// Wires self-attention, optional cross-attention, and a position-wise
/// feed-forward network together with residual connections, layer
/// normalization, and dropout. Supports both pre-LN and post-LN ordering
/// via `pre_layer_norm`.
#[derive(Debug, Clone)]
pub struct TransformerLayer<T: Float + Debug + Send + Sync + 'static> {
    /// Multi-head self-attention over the input sequence
    self_attention: MultiHeadAttention<T>,

    /// Cross-attention; present only when the config enables it (multi-task learning)
    cross_attention: Option<MultiHeadAttention<T>>,

    /// Position-wise feed-forward network
    feed_forward: FeedForwardNetwork<T>,

    /// Layer norm for the self-attention sub-layer
    ln1: LayerNorm<T>,
    /// Layer norm for the feed-forward sub-layer
    ln2: LayerNorm<T>,
    /// Layer norm for the cross-attention sub-layer (Some iff cross-attention is enabled)
    ln3: Option<LayerNorm<T>>, // For cross-attention

    /// Dropout after self-attention (uses `attention_dropout` from the config)
    dropout1: DropoutLayer,
    /// Dropout after the feed-forward network (uses `ff_dropout` from the config)
    dropout2: DropoutLayer,
    /// Dropout after cross-attention (Some iff cross-attention is enabled)
    dropout3: Option<DropoutLayer>,

    /// If true, normalize before each sub-layer (pre-LN); otherwise after the residual add (post-LN)
    pre_layer_norm: bool,
}
56
/// Feed-forward network.
///
/// Two-layer position-wise MLP: linear (modeldim → ff_dim), activation,
/// dropout, linear (ff_dim → modeldim). Weight shapes below follow the
/// initialization in `FeedForwardNetwork::new`.
#[derive(Debug, Clone)]
pub struct FeedForwardNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// First linear layer weights, shape (modeldim, ff_dim)
    linear1: Array2<T>,

    /// First linear layer bias, length ff_dim
    bias1: Array1<T>,

    /// Second linear layer weights, shape (ff_dim, modeldim)
    linear2: Array2<T>,

    /// Second linear layer bias, length modeldim
    bias2: Array1<T>,

    /// Activation applied between the two linear layers (GELU by default)
    activation: ActivationFunction,

    /// Dropout applied after the activation
    dropout: DropoutLayer,
}
78
/// Layer normalization.
///
/// Normalizes each row (sequence position) to zero mean and unit variance,
/// then applies a learned per-feature scale (gamma) and shift (beta).
#[derive(Debug, Clone)]
pub struct LayerNorm<T: Float + Debug + Send + Sync + 'static> {
    /// Scale parameters (gamma), length `dim`
    gamma: Array1<T>,

    /// Shift parameters (beta), length `dim`
    beta: Array1<T>,

    /// Epsilon added to the variance for numerical stability
    eps: T,

    /// Feature dimension this layer normalizes over
    dim: usize,
}
94
/// Dropout layer.
///
/// NOTE: `forward` is currently a pass-through in every mode (no elements
/// are ever masked); `prob` and `training` are stored for a future full
/// implementation.
#[derive(Debug, Clone)]
pub struct DropoutLayer {
    /// Dropout probability (not validated on construction)
    prob: f64,

    /// Training mode flag — dropout would only apply while training
    training: bool,
}
104
impl<T: Float + Debug + Default + Clone + std::iter::Sum + Send + Sync> TransformerLayer<T> {
    /// Build one encoder layer from `config`.
    ///
    /// Cross-attention, its layer norm (`ln3`), and its dropout (`dropout3`)
    /// are only constructed when `config.cross_attention` is set.
    ///
    /// NOTE(review): `_rng` is unused — `FeedForwardNetwork::new` draws from
    /// its own thread-local RNG, so weight init is not reproducible via the
    /// RNG passed here. Worth confirming whether seeded init is required.
    pub fn new(config: &TransformerOptimizerConfig, _rng: &mut Random) -> Result<Self> {
        let self_attention = MultiHeadAttention::new(config)?;
        let cross_attention = if config.cross_attention {
            Some(MultiHeadAttention::new(config)?)
        } else {
            None
        };

        let feed_forward = FeedForwardNetwork::new(config)?;

        let ln1 = LayerNorm::new(config.modeldim);
        let ln2 = LayerNorm::new(config.modeldim);
        let ln3 = if config.cross_attention {
            Some(LayerNorm::new(config.modeldim))
        } else {
            None
        };

        let dropout1 = DropoutLayer::new(config.attention_dropout);
        let dropout2 = DropoutLayer::new(config.ff_dropout);
        let dropout3 = if config.cross_attention {
            Some(DropoutLayer::new(config.attention_dropout))
        } else {
            None
        };

        Ok(Self {
            self_attention,
            cross_attention,
            feed_forward,
            ln1,
            ln2,
            ln3,
            dropout1,
            dropout2,
            dropout3,
            pre_layer_norm: config.pre_layer_norm,
        })
    }

    /// Run the encoder layer on `input`.
    ///
    /// Sub-layer order: self-attention → optional cross-attention →
    /// feed-forward, each wrapped in a residual connection. With
    /// `pre_layer_norm` the norm is applied before the sub-layer (pre-LN);
    /// otherwise it is applied after the residual add (post-LN).
    pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
        let mut x = input.clone();

        // --- Self-attention sub-layer with residual connection ---
        let residual = x.clone();
        if self.pre_layer_norm {
            x = self.ln1.forward(&x)?;
        }

        // Self-attention: query, key, and value are all the same sequence.
        x = self.self_attention.forward(&x, &x, &x)?;
        x = self.dropout1.forward(&x)?;
        x = x + &residual;

        if !self.pre_layer_norm {
            x = self.ln1.forward(&x)?;
        }

        // --- Cross-attention sub-layer (only when enabled) ---
        if let Some(ref mut cross_attn) = self.cross_attention {
            let residual = x.clone();
            if self.pre_layer_norm {
                if let Some(ref ln3) = self.ln3 {
                    x = ln3.forward(&x)?;
                }
            }

            // For now, use same input as key/value for cross-attention —
            // no external memory is threaded through this signature yet.
            x = cross_attn.forward(&x, &x, &x)?;
            if let Some(ref dropout3) = self.dropout3 {
                x = dropout3.forward(&x)?;
            }
            x = x + &residual;

            if !self.pre_layer_norm {
                if let Some(ref ln3) = self.ln3 {
                    x = ln3.forward(&x)?;
                }
            }
        }

        // --- Feed-forward sub-layer with residual connection ---
        let residual = x.clone();
        if self.pre_layer_norm {
            x = self.ln2.forward(&x)?;
        }

        x = self.feed_forward.forward(&x)?;
        x = self.dropout2.forward(&x)?;
        x = x + &residual;

        if !self.pre_layer_norm {
            x = self.ln2.forward(&x)?;
        }

        Ok(x)
    }

    /// Attention weights recorded by the most recent self-attention forward
    /// pass, if the attention module stored them.
    pub fn get_attention_patterns(&self) -> Option<&scirs2_core::ndarray::Array3<T>> {
        self.self_attention.get_attention_patterns()
    }
}
208
209impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> FeedForwardNetwork<T> {
210    pub fn new(config: &TransformerOptimizerConfig) -> Result<Self> {
211        let modeldim = config.modeldim;
212        let ff_dim = config.ff_dim;
213        let mut rng = scirs2_core::random::thread_rng();
214
215        // Initialize weights with Xavier initialization
216        let bound1 = (6.0 / (modeldim + ff_dim) as f64).sqrt();
217        let bound2 = (6.0 / (ff_dim + modeldim) as f64).sqrt();
218
219        let mut linear1 = Array2::zeros((modeldim, ff_dim));
220        let mut linear2 = Array2::zeros((ff_dim, modeldim));
221
222        for elem in linear1.iter_mut() {
223            *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound1).unwrap();
224        }
225        for elem in linear2.iter_mut() {
226            *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound2).unwrap();
227        }
228
229        let bias1 = Array1::zeros(ff_dim);
230        let bias2 = Array1::zeros(modeldim);
231
232        Ok(Self {
233            linear1,
234            bias1,
235            linear2,
236            bias2,
237            activation: ActivationFunction::GELU,
238            dropout: DropoutLayer::new(config.ff_dropout),
239        })
240    }
241
242    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
243        // First linear layer
244        let x1 = self.linear_transform(input, &self.linear1, &self.bias1)?;
245
246        // Activation
247        let x2 = self.apply_activation(&x1)?;
248
249        // Dropout
250        let x3 = self.dropout.forward(&x2)?;
251
252        // Second linear layer
253        let output = self.linear_transform(&x3, &self.linear2, &self.bias2)?;
254
255        Ok(output)
256    }
257
258    fn linear_transform(
259        &self,
260        input: &Array2<T>,
261        weights: &Array2<T>,
262        bias: &Array1<T>,
263    ) -> Result<Array2<T>> {
264        let (seq_len, input_dim) = input.dim();
265        let (weight_in, weight_out) = weights.dim();
266
267        if input_dim != weight_in {
268            return Err(OptimError::InvalidConfig(
269                "Input dimension doesn't match weight matrix".to_string(),
270            ));
271        }
272
273        if bias.len() != weight_out {
274            return Err(OptimError::InvalidConfig(
275                "Bias dimension doesn't match output dimension".to_string(),
276            ));
277        }
278
279        let mut output = Array2::zeros((seq_len, weight_out));
280
281        for i in 0..seq_len {
282            for j in 0..weight_out {
283                let mut sum = T::zero();
284                for k in 0..input_dim {
285                    sum = sum + input[[i, k]] * weights[[k, j]];
286                }
287                output[[i, j]] = sum + bias[j];
288            }
289        }
290
291        Ok(output)
292    }
293
294    fn apply_activation(&self, input: &Array2<T>) -> Result<Array2<T>> {
295        let mut output = input.clone();
296
297        match self.activation {
298            ActivationFunction::ReLU => {
299                output.mapv_inplace(|x| if x > T::zero() { x } else { T::zero() });
300            }
301            ActivationFunction::GELU => {
302                // Approximate GELU: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
303                output.mapv_inplace(|x| {
304                    let sqrt_2_pi = scirs2_core::numeric::NumCast::from(0.7978845608)
305                        .unwrap_or_else(|| T::zero()); // sqrt(2/π)
306                    let coeff =
307                        scirs2_core::numeric::NumCast::from(0.044715).unwrap_or_else(|| T::zero());
308                    let x_cubed = x * x * x;
309                    let inner = sqrt_2_pi * (x + coeff * x_cubed);
310                    let tanh_val = inner.tanh();
311                    scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero())
312                        * x
313                        * (T::one() + tanh_val)
314                });
315            }
316            ActivationFunction::Swish => {
317                output.mapv_inplace(|x| x * x.exp() / (T::one() + x.exp()));
318            }
319            ActivationFunction::GLU => {
320                // For simplicity, treating as GELU for now
321                output.mapv_inplace(|x| {
322                    let sqrt_2_pi = scirs2_core::numeric::NumCast::from(0.7978845608)
323                        .unwrap_or_else(|| T::zero());
324                    let coeff =
325                        scirs2_core::numeric::NumCast::from(0.044715).unwrap_or_else(|| T::zero());
326                    let x_cubed = x * x * x;
327                    let inner = sqrt_2_pi * (x + coeff * x_cubed);
328                    let tanh_val = inner.tanh();
329                    scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero())
330                        * x
331                        * (T::one() + tanh_val)
332                });
333            }
334            ActivationFunction::GeGLU => {
335                // For simplicity, treating as GELU for now
336                output.mapv_inplace(|x| {
337                    let sqrt_2_pi = scirs2_core::numeric::NumCast::from(0.7978845608)
338                        .unwrap_or_else(|| T::zero());
339                    let coeff =
340                        scirs2_core::numeric::NumCast::from(0.044715).unwrap_or_else(|| T::zero());
341                    let x_cubed = x * x * x;
342                    let inner = sqrt_2_pi * (x + coeff * x_cubed);
343                    let tanh_val = inner.tanh();
344                    scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero())
345                        * x
346                        * (T::one() + tanh_val)
347                });
348            }
349        }
350
351        Ok(output)
352    }
353
354    /// Set activation function
355    pub fn set_activation(&mut self, activation: ActivationFunction) {
356        self.activation = activation;
357    }
358}
359
360impl<T: Float + Debug + Default + Clone + std::iter::Sum + Send + Sync> LayerNorm<T> {
361    pub fn new(dim: usize) -> Self {
362        Self {
363            gamma: Array1::ones(dim),
364            beta: Array1::zeros(dim),
365            eps: scirs2_core::numeric::NumCast::from(1e-6).unwrap_or_else(|| T::zero()),
366            dim,
367        }
368    }
369
370    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
371        let (seq_len, input_dim) = input.dim();
372
373        if input_dim != self.dim {
374            return Err(OptimError::InvalidConfig(format!(
375                "Input dimension {} doesn't match layer norm dimension {}",
376                input_dim, self.dim
377            )));
378        }
379
380        let mut output = Array2::zeros((seq_len, input_dim));
381
382        for i in 0..seq_len {
383            let row = input.slice(s![i, ..]);
384
385            // Compute mean
386            let mean = row.iter().cloned().sum::<T>()
387                / scirs2_core::numeric::NumCast::from(input_dim).unwrap_or_else(|| T::zero());
388
389            // Compute variance
390            let variance = row
391                .iter()
392                .map(|&x| {
393                    let diff = x - mean;
394                    diff * diff
395                })
396                .sum::<T>()
397                / scirs2_core::numeric::NumCast::from(input_dim).unwrap_or_else(|| T::zero());
398
399            let std = (variance + self.eps).sqrt();
400
401            // Normalize and scale/shift
402            for j in 0..input_dim {
403                let normalized = (input[[i, j]] - mean) / std;
404                output[[i, j]] = self.gamma[j] * normalized + self.beta[j];
405            }
406        }
407
408        Ok(output)
409    }
410
411    /// Get layer normalization parameters
412    pub fn parameters(&self) -> (&Array1<T>, &Array1<T>) {
413        (&self.gamma, &self.beta)
414    }
415
416    /// Set layer normalization parameters
417    pub fn set_parameters(&mut self, gamma: Array1<T>, beta: Array1<T>) -> Result<()> {
418        if gamma.len() != self.dim || beta.len() != self.dim {
419            return Err(OptimError::InvalidConfig(
420                "Parameter dimensions don't match layer norm dimension".to_string(),
421            ));
422        }
423        self.gamma = gamma;
424        self.beta = beta;
425        Ok(())
426    }
427}
428
429impl DropoutLayer {
430    pub fn new(prob: f64) -> Self {
431        Self {
432            prob,
433            training: true,
434        }
435    }
436
437    pub fn forward<T: Float + Clone>(&self, input: &Array2<T>) -> Result<Array2<T>> {
438        if !self.training || self.prob == 0.0 {
439            return Ok(input.clone());
440        }
441
442        // For simplicity, just return input during inference/testing
443        // In a full implementation, this would apply dropout during training
444        Ok(input.clone())
445    }
446
447    /// Set training mode
448    pub fn set_training(&mut self, training: bool) {
449        self.training = training;
450    }
451
452    /// Get dropout probability
453    pub fn prob(&self) -> f64 {
454        self.prob
455    }
456}