//! Feed-forward network building blocks for the transformer-based learned
//! optimizer (optirs_learned/transformer_based_optimizer/feedforward.rs).
use super::layers::ActivationLayer;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_core::numeric::Float;
7use std::fmt::Debug;
8
/// Position-wise feed-forward network: linear1 → activation → dropout → linear2.
pub struct FeedForwardNetwork<T: Float + Debug + Send + Sync + 'static> {
    // First projection: input_dimension → hidden_dimension.
    linear1: LinearLayer<T>,

    // Second projection: hidden_dimension → input_dimension.
    linear2: LinearLayer<T>,

    // Activation applied between the two projections.
    activation: ActivationFunction,

    // Feature width of inputs (and, by construction, of outputs).
    input_dimension: usize,

    // Width of the intermediate hidden representation.
    hidden_dimension: usize,

    // Dropout applied to the activated hidden representation.
    dropout: super::layers::DropoutLayer,
}
29
30impl<T: Float + Debug + Send + Sync + 'static> FeedForwardNetwork<T> {
31 pub fn new(
33 input_dimension: usize,
34 hidden_dimension: usize,
35 activation: ActivationFunction,
36 ) -> Result<Self> {
37 let linear1 = LinearLayer::new(input_dimension, hidden_dimension)?;
38 let linear2 = LinearLayer::new(hidden_dimension, input_dimension)?;
39 let dropout = super::layers::DropoutLayer::new(0.1);
40
41 Ok(Self {
42 linear1,
43 linear2,
44 activation,
45 input_dimension,
46 hidden_dimension,
47 dropout,
48 })
49 }
50
51 pub fn new_with_dropout(
53 input_dimension: usize,
54 hidden_dimension: usize,
55 activation: ActivationFunction,
56 dropout_rate: f64,
57 ) -> Result<Self> {
58 let linear1 = LinearLayer::new(input_dimension, hidden_dimension)?;
59 let linear2 = LinearLayer::new(hidden_dimension, input_dimension)?;
60 let dropout = super::layers::DropoutLayer::new(dropout_rate);
61
62 Ok(Self {
63 linear1,
64 linear2,
65 activation,
66 input_dimension,
67 hidden_dimension,
68 dropout,
69 })
70 }
71
72 pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
74 let hidden = self.linear1.forward(input)?;
76
77 let activated = ActivationLayer::apply(&hidden, self.activation);
79
80 let dropout_output = self.dropout.forward(&activated);
82
83 let output = self.linear2.forward(&dropout_output)?;
85
86 Ok(output)
87 }
88
89 pub fn parameter_count(&self) -> usize {
91 self.linear1.parameter_count() + self.linear2.parameter_count()
92 }
93
94 pub fn reset(&mut self) -> Result<()> {
96 self.linear1.reset()?;
97 self.linear2.reset()?;
98 Ok(())
99 }
100
101 pub fn set_training(&mut self, training: bool) {
103 self.dropout.set_training(training);
104 }
105
106 pub fn get_activation(&self) -> ActivationFunction {
108 self.activation
109 }
110
111 pub fn set_activation(&mut self, activation: ActivationFunction) {
113 self.activation = activation;
114 }
115}
116
/// Dense affine layer computing `y = x W + b`.
pub struct LinearLayer<T: Float + Debug + Send + Sync + 'static> {
    // Weight matrix, shape (input_dim, output_dim).
    weight: Array2<T>,

    // Bias vector, length output_dim.
    bias: Array1<T>,

    // Expected input feature width.
    input_dim: usize,

    // Produced output feature width.
    output_dim: usize,
}
131
132impl<T: Float + Debug + Send + Sync + 'static> LinearLayer<T> {
133 pub fn new(input_dim: usize, output_dim: usize) -> Result<Self> {
135 let scale = T::from(2.0 / (input_dim + output_dim) as f64)
137 .unwrap()
138 .sqrt();
139 let mut weight = Array2::zeros((input_dim, output_dim));
140 let bias = Array1::zeros(output_dim);
141
142 for i in 0..input_dim {
144 for j in 0..output_dim {
145 let random_f64 = scirs2_core::random::random::<f64>();
146 let scaled_f64 = random_f64 * 2.0 - 1.0;
147 let random_val = <T as scirs2_core::numeric::NumCast>::from(scaled_f64).unwrap();
148 weight[[i, j]] = random_val * scale;
149 }
150 }
151
152 Ok(Self {
153 weight,
154 bias,
155 input_dim,
156 output_dim,
157 })
158 }
159
160 pub fn new_he_init(input_dim: usize, output_dim: usize) -> Result<Self> {
162 let scale = scirs2_core::numeric::NumCast::from(2.0 / input_dim as f64)
163 .unwrap_or_else(|| T::zero())
164 .sqrt();
165 let mut weight = Array2::zeros((input_dim, output_dim));
166 let bias = Array1::zeros(output_dim);
167
168 for i in 0..input_dim {
169 for j in 0..output_dim {
170 let random_f64 = scirs2_core::random::random::<f64>();
171 let scaled_f64 = random_f64 * 2.0 - 1.0;
172 let random_val = <T as scirs2_core::numeric::NumCast>::from(scaled_f64).unwrap();
173 weight[[i, j]] = random_val * scale;
174 }
175 }
176
177 Ok(Self {
178 weight,
179 bias,
180 input_dim,
181 output_dim,
182 })
183 }
184
185 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
187 let batch_size = input.shape()[0];
188 let input_features = input.shape()[1];
189
190 if input_features != self.input_dim {
191 return Err(crate::error::OptimError::Other(format!(
192 "Input dimension mismatch: expected {}, got {}",
193 self.input_dim, input_features
194 )));
195 }
196
197 let mut output = Array2::zeros((batch_size, self.output_dim));
198
199 for i in 0..batch_size {
201 for j in 0..self.output_dim {
202 let mut sum = self.bias[j];
203 for k in 0..self.input_dim {
204 sum = sum + input[[i, k]] * self.weight[[k, j]];
205 }
206 output[[i, j]] = sum;
207 }
208 }
209
210 Ok(output)
211 }
212
213 pub fn parameter_count(&self) -> usize {
215 self.input_dim * self.output_dim + self.output_dim
216 }
217
218 pub fn reset(&mut self) -> Result<()> {
220 let scale = T::from(2.0 / (self.input_dim + self.output_dim) as f64)
222 .unwrap()
223 .sqrt();
224
225 for i in 0..self.input_dim {
226 for j in 0..self.output_dim {
227 let random_f64 = scirs2_core::random::random::<f64>();
228 let scaled_f64 = random_f64 * 2.0 - 1.0;
229 let random_val = <T as scirs2_core::numeric::NumCast>::from(scaled_f64).unwrap();
230 self.weight[[i, j]] = random_val * scale;
231 }
232 }
233
234 self.bias.fill(T::zero());
235 Ok(())
236 }
237
238 pub fn get_weights(&self) -> &Array2<T> {
240 &self.weight
241 }
242
243 pub fn get_bias(&self) -> &Array1<T> {
245 &self.bias
246 }
247
248 pub fn update_weights(
250 &mut self,
251 weight_delta: &Array2<T>,
252 bias_delta: &Array1<T>,
253 ) -> Result<()> {
254 if weight_delta.shape() != self.weight.shape() {
255 return Err(crate::error::OptimError::Other(
256 "Weight delta shape mismatch".to_string(),
257 ));
258 }
259
260 if bias_delta.len() != self.bias.len() {
261 return Err(crate::error::OptimError::Other(
262 "Bias delta shape mismatch".to_string(),
263 ));
264 }
265
266 self.weight = &self.weight - weight_delta;
267 self.bias = &self.bias - bias_delta;
268
269 Ok(())
270 }
271}
272
/// Gated Linear Unit: output = sigmoid(gate(x)) ⊙ value(x).
pub struct GatedLinearUnit<T: Float + Debug + Send + Sync + 'static> {
    // Projection producing the gate pre-activation.
    gate_linear: LinearLayer<T>,

    // Projection producing the values to be gated.
    value_linear: LinearLayer<T>,

    // Input feature width.
    input_dimension: usize,

    // Output (hidden) feature width.
    hidden_dimension: usize,
}
287
288impl<T: Float + Debug + Send + Sync + 'static> GatedLinearUnit<T> {
289 pub fn new(input_dimension: usize, hidden_dimension: usize) -> Result<Self> {
291 let gate_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
292 let value_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
293
294 Ok(Self {
295 gate_linear,
296 value_linear,
297 input_dimension,
298 hidden_dimension,
299 })
300 }
301
302 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
304 let gate = self.gate_linear.forward(input)?;
305 let value = self.value_linear.forward(input)?;
306
307 let sigmoid_gate = ActivationLayer::apply(&gate, ActivationFunction::Sigmoid);
309 let output = &sigmoid_gate * &value;
310
311 Ok(output)
312 }
313
314 pub fn parameter_count(&self) -> usize {
316 self.gate_linear.parameter_count() + self.value_linear.parameter_count()
317 }
318
319 pub fn reset(&mut self) -> Result<()> {
321 self.gate_linear.reset()?;
322 self.value_linear.reset()?;
323 Ok(())
324 }
325}
326
/// SwiGLU unit: output = Swish(gate(x)) ⊙ value(x).
pub struct SwiGLU<T: Float + Debug + Send + Sync + 'static> {
    // Projection producing the gate pre-activation (Swish is applied to it).
    gate_linear: LinearLayer<T>,

    // Projection producing the values to be gated.
    value_linear: LinearLayer<T>,

    // Input feature width.
    input_dimension: usize,

    // Output (hidden) feature width.
    hidden_dimension: usize,
}
341
342impl<T: Float + Debug + Send + Sync + 'static> SwiGLU<T> {
343 pub fn new(input_dimension: usize, hidden_dimension: usize) -> Result<Self> {
345 let gate_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
346 let value_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
347
348 Ok(Self {
349 gate_linear,
350 value_linear,
351 input_dimension,
352 hidden_dimension,
353 })
354 }
355
356 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
358 let gate = self.gate_linear.forward(input)?;
359 let value = self.value_linear.forward(input)?;
360
361 let swish_gate = ActivationLayer::apply(&gate, ActivationFunction::Swish);
363 let output = &swish_gate * &value;
364
365 Ok(output)
366 }
367
368 pub fn parameter_count(&self) -> usize {
370 self.gate_linear.parameter_count() + self.value_linear.parameter_count()
371 }
372
373 pub fn reset(&mut self) -> Result<()> {
375 self.gate_linear.reset()?;
376 self.value_linear.reset()?;
377 Ok(())
378 }
379}
380
/// Sparse mixture-of-experts layer: a linear gate scores each expert per
/// sample and the top-k experts' outputs are blended by gate probability.
pub struct MixtureOfExperts<T: Float + Debug + Send + Sync + 'static> {
    // Expert feed-forward networks, all with identical dimensions.
    experts: Vec<FeedForwardNetwork<T>>,

    // Gating network producing one score per expert per sample.
    gate: LinearLayer<T>,

    // Total number of experts.
    num_experts: usize,

    // Number of experts each sample is routed to (clamped to num_experts).
    top_k: usize,

    // Input/output feature width.
    input_dimension: usize,

    // Hidden width of each expert.
    hidden_dimension: usize,
}
401
402impl<T: Float + Debug + Send + Sync + 'static> MixtureOfExperts<T> {
403 pub fn new(
405 input_dimension: usize,
406 hidden_dimension: usize,
407 num_experts: usize,
408 top_k: usize,
409 activation: ActivationFunction,
410 ) -> Result<Self> {
411 let mut experts = Vec::new();
412 for _ in 0..num_experts {
413 experts.push(FeedForwardNetwork::new(
414 input_dimension,
415 hidden_dimension,
416 activation,
417 )?);
418 }
419
420 let gate = LinearLayer::new(input_dimension, num_experts)?;
421
422 Ok(Self {
423 experts,
424 gate,
425 num_experts,
426 top_k: top_k.min(num_experts),
427 input_dimension,
428 hidden_dimension,
429 })
430 }
431
432 pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
434 let batch_size = input.shape()[0];
435
436 let gate_scores = self.gate.forward(input)?;
438 let gate_probs = self.softmax(&gate_scores);
439
440 let mut output = Array2::zeros((batch_size, self.input_dimension));
442
443 for i in 0..batch_size {
444 let sample_input = input.row(i).insert_axis(Axis(0)).to_owned();
445 let sample_probs = gate_probs.row(i);
446
447 let mut prob_indices: Vec<(usize, T)> = sample_probs
449 .iter()
450 .enumerate()
451 .map(|(idx, &prob)| (idx, prob))
452 .collect();
453
454 prob_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
455
456 let top_k_indices: Vec<usize> = prob_indices
457 .iter()
458 .take(self.top_k)
459 .map(|(idx, _)| *idx)
460 .collect();
461
462 let mut sample_output = Array1::zeros(self.input_dimension);
464 let mut total_weight = T::zero();
465
466 for &expert_idx in &top_k_indices {
467 let expert_output = self.experts[expert_idx].forward(&sample_input)?;
468 let weight = sample_probs[expert_idx];
469
470 total_weight = total_weight + weight;
471
472 for j in 0..self.input_dimension {
473 sample_output[j] = sample_output[j] + weight * expert_output[[0, j]];
474 }
475 }
476
477 if total_weight > T::zero() {
479 for j in 0..self.input_dimension {
480 sample_output[j] = sample_output[j] / total_weight;
481 output[[i, j]] = sample_output[j];
482 }
483 }
484 }
485
486 Ok(output)
487 }
488
489 fn softmax(&self, input: &Array2<T>) -> Array2<T> {
491 let mut output = Array2::zeros(input.raw_dim());
492 let batch_size = input.shape()[0];
493
494 for i in 0..batch_size {
495 let row = input.row(i);
496 let max_val = row.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
497
498 let mut exp_sum = T::zero();
499 let mut exp_row = Array1::zeros(row.len());
500
501 for (j, &val) in row.iter().enumerate() {
502 exp_row[j] = (val - max_val).exp();
503 exp_sum = exp_sum + exp_row[j];
504 }
505
506 for (j, &exp_val) in exp_row.iter().enumerate() {
507 output[[i, j]] = exp_val / exp_sum;
508 }
509 }
510
511 output
512 }
513
514 pub fn parameter_count(&self) -> usize {
516 let expert_params: usize = self
517 .experts
518 .iter()
519 .map(|expert| expert.parameter_count())
520 .sum();
521
522 expert_params + self.gate.parameter_count()
523 }
524
525 pub fn reset(&mut self) -> Result<()> {
527 for expert in &mut self.experts {
528 expert.reset()?;
529 }
530 self.gate.reset()?;
531 Ok(())
532 }
533
534 pub fn set_training(&mut self, training: bool) {
536 for expert in &mut self.experts {
537 expert.set_training(training);
538 }
539 }
540}
541
542pub use super::config::ActivationFunction;
544
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feedforward_network() {
        // Use this module's `ActivationFunction` re-export rather than the
        // fully-qualified config path, matching the other tests.
        let ffn = FeedForwardNetwork::<f32>::new(128, 512, ActivationFunction::ReLU);
        assert!(ffn.is_ok());

        let mut network = ffn.unwrap();
        let input = Array2::<f32>::ones((4, 128));
        let result = network.forward(&input);
        assert!(result.is_ok());

        // Output width must match the input width.
        let output = result.unwrap();
        assert_eq!(output.shape(), &[4, 128]);
    }

    #[test]
    fn test_linear_layer() {
        let linear = LinearLayer::<f32>::new(64, 128);
        assert!(linear.is_ok());

        let layer = linear.unwrap();
        let input = Array2::<f32>::zeros((2, 64));
        let result = layer.forward(&input);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.shape(), &[2, 128]);
        // weights (64*128) + biases (128)
        assert_eq!(layer.parameter_count(), 64 * 128 + 128);
    }

    #[test]
    fn test_gated_linear_unit() {
        let glu = GatedLinearUnit::<f32>::new(128, 256);
        assert!(glu.is_ok());

        let unit = glu.unwrap();
        let input = Array2::<f32>::ones((2, 128));
        let result = unit.forward(&input);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.shape(), &[2, 256]);
    }

    #[test]
    fn test_swiglu() {
        let swiglu = SwiGLU::<f32>::new(128, 256);
        assert!(swiglu.is_ok());

        let unit = swiglu.unwrap();
        let input = Array2::<f32>::ones((2, 128));
        let result = unit.forward(&input);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.shape(), &[2, 256]);
    }

    #[test]
    fn test_mixture_of_experts() {
        let moe = MixtureOfExperts::<f32>::new(128, 256, 4, 2, ActivationFunction::ReLU);
        assert!(moe.is_ok());

        let mut mixture = moe.unwrap();
        let input = Array2::<f32>::ones((3, 128));
        let result = mixture.forward(&input);
        assert!(result.is_ok());

        // MoE output width matches the input width regardless of top_k.
        let output = result.unwrap();
        assert_eq!(output.shape(), &[3, 128]);
    }

    #[test]
    fn test_linear_layer_initialization() {
        let xavier_layer = LinearLayer::<f32>::new(64, 128).unwrap();
        let he_layer = LinearLayer::<f32>::new_he_init(64, 128).unwrap();

        // Both schemes produce identically-shaped parameter sets.
        assert_eq!(xavier_layer.parameter_count(), he_layer.parameter_count());
        assert_eq!(
            xavier_layer.get_weights().shape(),
            he_layer.get_weights().shape()
        );
    }
}