// optirs_learned/transformer/architecture/encoder.rs
use std::fmt::Debug;
2#[allow(dead_code)]
8use scirs2_core::ndarray::{s, Array1, Array2};
9use scirs2_core::numeric::Float;
10use scirs2_core::random::{CoreRandom as Random, Rng as SCRRng};
11
12use super::super::TransformerOptimizerConfig;
13use super::attention::MultiHeadAttention;
14use crate::error::{OptimError, Result};
15
/// Element-wise non-linearity used inside the feed-forward network.
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    /// Rectified linear unit: max(0, x).
    ReLU,
    /// Gaussian error linear unit (tanh approximation).
    GELU,
    /// x * sigmoid(x).
    Swish,
    /// Gated linear unit.
    GLU,
    /// GELU-gated linear unit.
    GeGLU,
}
30
/// One transformer block: self-attention, optional cross-attention, and a
/// position-wise feed-forward network, each wrapped with layer norm, dropout,
/// and a residual connection.
#[derive(Debug, Clone)]
pub struct TransformerLayer<T: Float + Debug + Send + Sync + 'static> {
    // Multi-head self-attention over the input sequence.
    self_attention: MultiHeadAttention<T>,

    // Optional cross-attention; present only when the config enables it.
    cross_attention: Option<MultiHeadAttention<T>>,

    // Position-wise feed-forward sub-layer.
    feed_forward: FeedForwardNetwork<T>,

    // Layer norm for the self-attention sub-layer.
    ln1: LayerNorm<T>,
    // Layer norm for the feed-forward sub-layer.
    ln2: LayerNorm<T>,
    // Layer norm for cross-attention; Some iff `cross_attention` is Some.
    ln3: Option<LayerNorm<T>>,
    // Dropout after self-attention.
    dropout1: DropoutLayer,
    // Dropout after the feed-forward network.
    dropout2: DropoutLayer,
    // Dropout after cross-attention; Some iff `cross_attention` is Some.
    dropout3: Option<DropoutLayer>,

    // If true, normalize before each sub-layer (pre-LN); otherwise after (post-LN).
    pre_layer_norm: bool,
}
56
/// Position-wise feed-forward network: linear1 -> activation -> dropout -> linear2.
#[derive(Debug, Clone)]
pub struct FeedForwardNetwork<T: Float + Debug + Send + Sync + 'static> {
    // First projection weights, shape (model_dim, ff_dim).
    linear1: Array2<T>,

    // First projection bias, length ff_dim.
    bias1: Array1<T>,

    // Second projection weights, shape (ff_dim, model_dim).
    linear2: Array2<T>,

    // Second projection bias, length model_dim.
    bias2: Array1<T>,

    // Element-wise non-linearity applied between the two projections.
    activation: ActivationFunction,

    // Dropout applied after the activation.
    dropout: DropoutLayer,
}
78
/// Layer normalization over the last dimension, with learned scale and shift.
#[derive(Debug, Clone)]
pub struct LayerNorm<T: Float + Debug + Send + Sync + 'static> {
    // Learned scale, length `dim`.
    gamma: Array1<T>,

    // Learned shift, length `dim`.
    beta: Array1<T>,

    // Small constant added to the variance for numerical stability.
    eps: T,

    // Size of the normalized (last) dimension.
    dim: usize,
}
94
/// Dropout configuration: drop probability plus a train/eval mode flag.
#[derive(Debug, Clone)]
pub struct DropoutLayer {
    // Drop probability, expected in [0, 1].
    prob: f64,

    // Whether the layer is in training mode (dropout active).
    training: bool,
}
104
impl<T: Float + Debug + Default + Clone + std::iter::Sum + Send + Sync> TransformerLayer<T> {
    /// Builds a transformer layer from `config`.
    ///
    /// Cross-attention (plus its layer norm and dropout) is created only when
    /// `config.cross_attention` is set. `_rng` is currently unused; the
    /// sub-modules obtain their own randomness internally.
    ///
    /// # Errors
    /// Propagates any error from constructing the attention or feed-forward
    /// sub-modules.
    pub fn new(config: &TransformerOptimizerConfig, _rng: &mut Random) -> Result<Self> {
        let self_attention = MultiHeadAttention::new(config)?;
        let cross_attention = if config.cross_attention {
            Some(MultiHeadAttention::new(config)?)
        } else {
            None
        };

        let feed_forward = FeedForwardNetwork::new(config)?;

        // One norm per sub-layer; ln3 mirrors the optional cross-attention.
        let ln1 = LayerNorm::new(config.modeldim);
        let ln2 = LayerNorm::new(config.modeldim);
        let ln3 = if config.cross_attention {
            Some(LayerNorm::new(config.modeldim))
        } else {
            None
        };

        let dropout1 = DropoutLayer::new(config.attention_dropout);
        let dropout2 = DropoutLayer::new(config.ff_dropout);
        let dropout3 = if config.cross_attention {
            Some(DropoutLayer::new(config.attention_dropout))
        } else {
            None
        };

        Ok(Self {
            self_attention,
            cross_attention,
            feed_forward,
            ln1,
            ln2,
            ln3,
            dropout1,
            dropout2,
            dropout3,
            pre_layer_norm: config.pre_layer_norm,
        })
    }

    /// Runs the layer on `input` (presumably seq_len x model_dim — dimensions
    /// are validated downstream by `LayerNorm`/`FeedForwardNetwork`).
    ///
    /// Sub-layer ordering depends on `pre_layer_norm`:
    /// - pre-LN:  x = x + Dropout(Sublayer(LN(x)))
    /// - post-LN: x = LN(x + Dropout(Sublayer(x)))
    pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
        let mut x = input.clone();

        // --- Self-attention sub-layer ---
        let residual = x.clone();
        if self.pre_layer_norm {
            x = self.ln1.forward(&x)?;
        }

        x = self.self_attention.forward(&x, &x, &x)?;
        x = self.dropout1.forward(&x)?;
        x = x + &residual;

        if !self.pre_layer_norm {
            x = self.ln1.forward(&x)?;
        }

        // --- Optional cross-attention sub-layer ---
        if let Some(ref mut cross_attn) = self.cross_attention {
            let residual = x.clone();
            if self.pre_layer_norm {
                if let Some(ref ln3) = self.ln3 {
                    x = ln3.forward(&x)?;
                }
            }

            // NOTE(review): query, key and value are all `x`, so this currently
            // degenerates to a second self-attention pass. True cross-attention
            // would take encoder memory for K/V — confirm this is intended.
            x = cross_attn.forward(&x, &x, &x)?;
            if let Some(ref dropout3) = self.dropout3 {
                x = dropout3.forward(&x)?;
            }
            x = x + &residual;

            if !self.pre_layer_norm {
                if let Some(ref ln3) = self.ln3 {
                    x = ln3.forward(&x)?;
                }
            }
        }

        // --- Feed-forward sub-layer ---
        let residual = x.clone();
        if self.pre_layer_norm {
            x = self.ln2.forward(&x)?;
        }

        x = self.feed_forward.forward(&x)?;
        x = self.dropout2.forward(&x)?;
        x = x + &residual;

        if !self.pre_layer_norm {
            x = self.ln2.forward(&x)?;
        }

        Ok(x)
    }

    /// Attention weights recorded by the self-attention module, if any.
    pub fn get_attention_patterns(&self) -> Option<&scirs2_core::ndarray::Array3<T>> {
        self.self_attention.get_attention_patterns()
    }
}
208
209impl<T: Float + Debug + Default + Clone + Send + Sync + 'static> FeedForwardNetwork<T> {
210 pub fn new(config: &TransformerOptimizerConfig) -> Result<Self> {
211 let modeldim = config.modeldim;
212 let ff_dim = config.ff_dim;
213 let mut rng = scirs2_core::random::thread_rng();
214
215 let bound1 = (6.0 / (modeldim + ff_dim) as f64).sqrt();
217 let bound2 = (6.0 / (ff_dim + modeldim) as f64).sqrt();
218
219 let mut linear1 = Array2::zeros((modeldim, ff_dim));
220 let mut linear2 = Array2::zeros((ff_dim, modeldim));
221
222 for elem in linear1.iter_mut() {
223 *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound1).unwrap();
224 }
225 for elem in linear2.iter_mut() {
226 *elem = T::from((rng.random::<f64>() - 0.5) * 2.0 * bound2).unwrap();
227 }
228
229 let bias1 = Array1::zeros(ff_dim);
230 let bias2 = Array1::zeros(modeldim);
231
232 Ok(Self {
233 linear1,
234 bias1,
235 linear2,
236 bias2,
237 activation: ActivationFunction::GELU,
238 dropout: DropoutLayer::new(config.ff_dropout),
239 })
240 }
241
242 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
243 let x1 = self.linear_transform(input, &self.linear1, &self.bias1)?;
245
246 let x2 = self.apply_activation(&x1)?;
248
249 let x3 = self.dropout.forward(&x2)?;
251
252 let output = self.linear_transform(&x3, &self.linear2, &self.bias2)?;
254
255 Ok(output)
256 }
257
258 fn linear_transform(
259 &self,
260 input: &Array2<T>,
261 weights: &Array2<T>,
262 bias: &Array1<T>,
263 ) -> Result<Array2<T>> {
264 let (seq_len, input_dim) = input.dim();
265 let (weight_in, weight_out) = weights.dim();
266
267 if input_dim != weight_in {
268 return Err(OptimError::InvalidConfig(
269 "Input dimension doesn't match weight matrix".to_string(),
270 ));
271 }
272
273 if bias.len() != weight_out {
274 return Err(OptimError::InvalidConfig(
275 "Bias dimension doesn't match output dimension".to_string(),
276 ));
277 }
278
279 let mut output = Array2::zeros((seq_len, weight_out));
280
281 for i in 0..seq_len {
282 for j in 0..weight_out {
283 let mut sum = T::zero();
284 for k in 0..input_dim {
285 sum = sum + input[[i, k]] * weights[[k, j]];
286 }
287 output[[i, j]] = sum + bias[j];
288 }
289 }
290
291 Ok(output)
292 }
293
294 fn apply_activation(&self, input: &Array2<T>) -> Result<Array2<T>> {
295 let mut output = input.clone();
296
297 match self.activation {
298 ActivationFunction::ReLU => {
299 output.mapv_inplace(|x| if x > T::zero() { x } else { T::zero() });
300 }
301 ActivationFunction::GELU => {
302 output.mapv_inplace(|x| {
304 let sqrt_2_pi = scirs2_core::numeric::NumCast::from(0.7978845608)
305 .unwrap_or_else(|| T::zero()); let coeff =
307 scirs2_core::numeric::NumCast::from(0.044715).unwrap_or_else(|| T::zero());
308 let x_cubed = x * x * x;
309 let inner = sqrt_2_pi * (x + coeff * x_cubed);
310 let tanh_val = inner.tanh();
311 scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero())
312 * x
313 * (T::one() + tanh_val)
314 });
315 }
316 ActivationFunction::Swish => {
317 output.mapv_inplace(|x| x * x.exp() / (T::one() + x.exp()));
318 }
319 ActivationFunction::GLU => {
320 output.mapv_inplace(|x| {
322 let sqrt_2_pi = scirs2_core::numeric::NumCast::from(0.7978845608)
323 .unwrap_or_else(|| T::zero());
324 let coeff =
325 scirs2_core::numeric::NumCast::from(0.044715).unwrap_or_else(|| T::zero());
326 let x_cubed = x * x * x;
327 let inner = sqrt_2_pi * (x + coeff * x_cubed);
328 let tanh_val = inner.tanh();
329 scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero())
330 * x
331 * (T::one() + tanh_val)
332 });
333 }
334 ActivationFunction::GeGLU => {
335 output.mapv_inplace(|x| {
337 let sqrt_2_pi = scirs2_core::numeric::NumCast::from(0.7978845608)
338 .unwrap_or_else(|| T::zero());
339 let coeff =
340 scirs2_core::numeric::NumCast::from(0.044715).unwrap_or_else(|| T::zero());
341 let x_cubed = x * x * x;
342 let inner = sqrt_2_pi * (x + coeff * x_cubed);
343 let tanh_val = inner.tanh();
344 scirs2_core::numeric::NumCast::from(0.5).unwrap_or_else(|| T::zero())
345 * x
346 * (T::one() + tanh_val)
347 });
348 }
349 }
350
351 Ok(output)
352 }
353
354 pub fn set_activation(&mut self, activation: ActivationFunction) {
356 self.activation = activation;
357 }
358}
359
360impl<T: Float + Debug + Default + Clone + std::iter::Sum + Send + Sync> LayerNorm<T> {
361 pub fn new(dim: usize) -> Self {
362 Self {
363 gamma: Array1::ones(dim),
364 beta: Array1::zeros(dim),
365 eps: scirs2_core::numeric::NumCast::from(1e-6).unwrap_or_else(|| T::zero()),
366 dim,
367 }
368 }
369
370 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
371 let (seq_len, input_dim) = input.dim();
372
373 if input_dim != self.dim {
374 return Err(OptimError::InvalidConfig(format!(
375 "Input dimension {} doesn't match layer norm dimension {}",
376 input_dim, self.dim
377 )));
378 }
379
380 let mut output = Array2::zeros((seq_len, input_dim));
381
382 for i in 0..seq_len {
383 let row = input.slice(s![i, ..]);
384
385 let mean = row.iter().cloned().sum::<T>()
387 / scirs2_core::numeric::NumCast::from(input_dim).unwrap_or_else(|| T::zero());
388
389 let variance = row
391 .iter()
392 .map(|&x| {
393 let diff = x - mean;
394 diff * diff
395 })
396 .sum::<T>()
397 / scirs2_core::numeric::NumCast::from(input_dim).unwrap_or_else(|| T::zero());
398
399 let std = (variance + self.eps).sqrt();
400
401 for j in 0..input_dim {
403 let normalized = (input[[i, j]] - mean) / std;
404 output[[i, j]] = self.gamma[j] * normalized + self.beta[j];
405 }
406 }
407
408 Ok(output)
409 }
410
411 pub fn parameters(&self) -> (&Array1<T>, &Array1<T>) {
413 (&self.gamma, &self.beta)
414 }
415
416 pub fn set_parameters(&mut self, gamma: Array1<T>, beta: Array1<T>) -> Result<()> {
418 if gamma.len() != self.dim || beta.len() != self.dim {
419 return Err(OptimError::InvalidConfig(
420 "Parameter dimensions don't match layer norm dimension".to_string(),
421 ));
422 }
423 self.gamma = gamma;
424 self.beta = beta;
425 Ok(())
426 }
427}
428
429impl DropoutLayer {
430 pub fn new(prob: f64) -> Self {
431 Self {
432 prob,
433 training: true,
434 }
435 }
436
437 pub fn forward<T: Float + Clone>(&self, input: &Array2<T>) -> Result<Array2<T>> {
438 if !self.training || self.prob == 0.0 {
439 return Ok(input.clone());
440 }
441
442 Ok(input.clone())
445 }
446
447 pub fn set_training(&mut self, training: bool) {
449 self.training = training;
450 }
451
452 pub fn prob(&self) -> f64 {
454 self.prob
455 }
456}