// optirs_learned/transformer/architecture/feedforward.rs
1use std::fmt::Debug;
// Feed-forward network components for transformer layers
//
// This module implements the feed-forward network (FFN) layers used in
// transformer encoder/decoder blocks, including various activation functions
// and output projection layers.

8#[allow(dead_code)]
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::numeric::Float;
11use scirs2_core::random::{Random, Rng as SCRRng};
12
13use super::super::TransformerOptimizerConfig;
14use crate::error::{OptimError, Result};
15
/// Output transformation types
///
/// Selects the element-wise transformation applied after the linear
/// projection in `OutputProjectionLayer::forward`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputTransformation {
    /// Linear transformation (identity; no extra activation)
    Linear,
    /// Tanh activation
    Tanh,
    /// Sigmoid activation
    Sigmoid,
    /// Learned activation (currently a fixed scaling placeholder)
    LearnedActivation,
    /// Parameter-specific scaling (per-output-dimension scale factor)
    ParameterScaling,
}
30
/// Output projection layer for final transformer output
///
/// Applies `output = transform(input · weights + bias)`, where `transform`
/// is selected by the stored `OutputTransformation`.
#[derive(Debug, Clone)]
pub struct OutputProjectionLayer<T: Float + Debug + Send + Sync + 'static> {
    /// Projection weights, shape (input_dim, output_dim); Xavier-initialized in `new`
    weights: Array2<T>,

    /// Projection bias, length output_dim; zero-initialized in `new`
    bias: Array1<T>,

    /// Output transformation applied after the linear projection
    transformation: OutputTransformation,
}
43
/// Input embedding layer for transformer input processing
///
/// Linearly projects (seq_len, input_dim) inputs into the model dimension:
/// `output = input · weights` (no bias term).
#[derive(Debug, Clone)]
pub struct InputEmbedding<T: Float + Debug + Send + Sync + 'static> {
    /// Embedding weights, shape (input_dim, modeldim); Xavier-initialized in `new`
    weights: Array2<T>,

    /// Input dimension (expected column count of `forward` inputs)
    input_dim: usize,

    /// Model dimension (column count of `forward` outputs)
    modeldim: usize,
}
56
57impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> OutputProjectionLayer<T> {
58    /// Create new output projection layer
59    pub fn new(input_dim: usize, output_dim: usize) -> Result<Self> {
60        let mut rng = scirs2_core::random::thread_rng();
61        let mut weights = Array2::zeros((input_dim, output_dim));
62
63        // Xavier initialization
64        let bound = (6.0 / (input_dim + output_dim) as f64).sqrt();
65        for elem in weights.iter_mut() {
66            *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound).expect("unwrap failed");
67        }
68
69        let bias = Array1::zeros(output_dim);
70
71        Ok(Self {
72            weights,
73            bias,
74            transformation: OutputTransformation::Linear,
75        })
76    }
77
78    /// Create with specific transformation
79    pub fn new_with_transformation(
80        input_dim: usize,
81        output_dim: usize,
82        transformation: OutputTransformation,
83    ) -> Result<Self> {
84        let mut layer = Self::new(input_dim, output_dim)?;
85        layer.transformation = transformation;
86        Ok(layer)
87    }
88
89    /// Forward pass through output projection
90    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
91        let (seq_len, input_dim) = input.dim();
92        let (weight_in, weight_out) = self.weights.dim();
93
94        if input_dim != weight_in {
95            return Err(OptimError::InvalidConfig(
96                "Input dimension doesn't match weight matrix".to_string(),
97            ));
98        }
99
100        let mut output = Array2::zeros((seq_len, weight_out));
101
102        // Linear transformation
103        for i in 0..seq_len {
104            for j in 0..weight_out {
105                let mut sum = T::zero();
106                for k in 0..input_dim {
107                    sum = sum + input[[i, k]] * self.weights[[k, j]];
108                }
109                output[[i, j]] = sum + self.bias[j];
110            }
111        }
112
113        // Apply output transformation
114        match self.transformation {
115            OutputTransformation::Linear => {
116                // No additional transformation
117            }
118            OutputTransformation::Tanh => {
119                output.mapv_inplace(|x| x.tanh());
120            }
121            OutputTransformation::Sigmoid => {
122                output.mapv_inplace(|x| T::one() / (T::one() + (-x).exp()));
123            }
124            OutputTransformation::LearnedActivation => {
125                // For now, use a simple learned scaling
126                output.mapv_inplace(|x| {
127                    x * scirs2_core::numeric::NumCast::from(1.1).unwrap_or_else(|| T::zero())
128                });
129            }
130            OutputTransformation::ParameterScaling => {
131                // Apply different scaling per parameter dimension
132                for j in 0..weight_out {
133                    let scale = T::from(1.0 + 0.1 * (j as f64).sin()).expect("unwrap failed");
134                    for i in 0..seq_len {
135                        output[[i, j]] = output[[i, j]] * scale;
136                    }
137                }
138            }
139        }
140
141        Ok(output)
142    }
143
144    /// Get output transformation type
145    pub fn transformation(&self) -> OutputTransformation {
146        self.transformation
147    }
148
149    /// Set output transformation type
150    pub fn set_transformation(&mut self, transformation: OutputTransformation) {
151        self.transformation = transformation;
152    }
153
154    /// Get projection weights
155    pub fn weights(&self) -> &Array2<T> {
156        &self.weights
157    }
158
159    /// Get projection bias
160    pub fn bias(&self) -> &Array1<T> {
161        &self.bias
162    }
163
164    /// Update weights and bias
165    pub fn update_parameters(&mut self, weights: Array2<T>, bias: Array1<T>) -> Result<()> {
166        let (weight_in, weight_out) = weights.dim();
167        let bias_dim = bias.len();
168
169        if weight_out != bias_dim {
170            return Err(OptimError::InvalidConfig(
171                "Weight output dimension doesn't match bias dimension".to_string(),
172            ));
173        }
174
175        // Update internal dimensions if they match
176        if (weight_in, weight_out) == self.weights.dim() && bias_dim == self.bias.len() {
177            self.weights = weights;
178            self.bias = bias;
179            Ok(())
180        } else {
181            Err(OptimError::InvalidConfig(
182                "New parameter dimensions don't match current layer dimensions".to_string(),
183            ))
184        }
185    }
186}
187
188impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> InputEmbedding<T> {
189    /// Create new input embedding layer
190    pub fn new(input_dim: usize, model_dim: usize) -> Result<Self> {
191        let mut rng = scirs2_core::random::thread_rng();
192        let mut weights = Array2::zeros((input_dim, model_dim));
193
194        // Xavier initialization
195        let bound = (6.0 / (input_dim + model_dim) as f64).sqrt();
196        for elem in weights.iter_mut() {
197            *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound).expect("unwrap failed");
198        }
199
200        Ok(Self {
201            weights,
202            input_dim,
203            modeldim: model_dim,
204        })
205    }
206
207    /// Forward pass through input embedding
208    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
209        let (seq_len, input_dim) = input.dim();
210
211        if input_dim != self.input_dim {
212            return Err(OptimError::InvalidConfig(format!(
213                "Input dimension {} doesn't match embedding input dimension {}",
214                input_dim, self.input_dim
215            )));
216        }
217
218        let mut output = Array2::zeros((seq_len, self.modeldim));
219
220        // Linear transformation
221        for i in 0..seq_len {
222            for j in 0..self.modeldim {
223                let mut sum = T::zero();
224                for k in 0..self.input_dim {
225                    sum = sum + input[[i, k]] * self.weights[[k, j]];
226                }
227                output[[i, j]] = sum;
228            }
229        }
230
231        Ok(output)
232    }
233
234    /// Get input dimension
235    pub fn input_dim(&self) -> usize {
236        self.input_dim
237    }
238
239    /// Get model dimension
240    pub fn model_dim(&self) -> usize {
241        self.modeldim
242    }
243
244    /// Get embedding weights
245    pub fn weights(&self) -> &Array2<T> {
246        &self.weights
247    }
248
249    /// Update embedding weights
250    pub fn update_weights(&mut self, weights: Array2<T>) -> Result<()> {
251        let (weight_in, weight_out) = weights.dim();
252
253        if weight_in != self.input_dim || weight_out != self.modeldim {
254            return Err(OptimError::InvalidConfig(
255                "New weight dimensions don't match embedding dimensions".to_string(),
256            ));
257        }
258
259        self.weights = weights;
260        Ok(())
261    }
262}