//! Feed-forward network components for the transformer-based learned optimizer.

use super::layers::ActivationLayer;
use crate::error::Result;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Position-wise feed-forward block: `linear1 -> activation -> dropout -> linear2`.
///
/// The second projection maps back to `input_dimension`, so the block is
/// shape-preserving on its input.
pub struct FeedForwardNetwork<T: Float + Debug + Send + Sync + 'static> {
    // First projection: input_dimension -> hidden_dimension.
    linear1: LinearLayer<T>,
    // Second projection: hidden_dimension -> input_dimension.
    linear2: LinearLayer<T>,
    // Nonlinearity applied between the two projections.
    activation: ActivationFunction,
    // Width of the block's input (and output).
    input_dimension: usize,
    // Width of the hidden (expansion) layer.
    hidden_dimension: usize,
    // Dropout applied after the activation; rate is fixed at construction.
    dropout: super::layers::DropoutLayer,
}
29
30impl<T: Float + Debug + Send + Sync + 'static> FeedForwardNetwork<T> {
31 pub fn new(
33 input_dimension: usize,
34 hidden_dimension: usize,
35 activation: ActivationFunction,
36 ) -> Result<Self> {
37 let linear1 = LinearLayer::new(input_dimension, hidden_dimension)?;
38 let linear2 = LinearLayer::new(hidden_dimension, input_dimension)?;
39 let dropout = super::layers::DropoutLayer::new(0.1);
40
41 Ok(Self {
42 linear1,
43 linear2,
44 activation,
45 input_dimension,
46 hidden_dimension,
47 dropout,
48 })
49 }
50
51 pub fn new_with_dropout(
53 input_dimension: usize,
54 hidden_dimension: usize,
55 activation: ActivationFunction,
56 dropout_rate: f64,
57 ) -> Result<Self> {
58 let linear1 = LinearLayer::new(input_dimension, hidden_dimension)?;
59 let linear2 = LinearLayer::new(hidden_dimension, input_dimension)?;
60 let dropout = super::layers::DropoutLayer::new(dropout_rate);
61
62 Ok(Self {
63 linear1,
64 linear2,
65 activation,
66 input_dimension,
67 hidden_dimension,
68 dropout,
69 })
70 }
71
72 pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
74 let hidden = self.linear1.forward(input)?;
76
77 let activated = ActivationLayer::apply(&hidden, self.activation);
79
80 let dropout_output = self.dropout.forward(&activated);
82
83 let output = self.linear2.forward(&dropout_output)?;
85
86 Ok(output)
87 }
88
89 pub fn parameter_count(&self) -> usize {
91 self.linear1.parameter_count() + self.linear2.parameter_count()
92 }
93
94 pub fn reset(&mut self) -> Result<()> {
96 self.linear1.reset()?;
97 self.linear2.reset()?;
98 Ok(())
99 }
100
101 pub fn set_training(&mut self, training: bool) {
103 self.dropout.set_training(training);
104 }
105
106 pub fn get_activation(&self) -> ActivationFunction {
108 self.activation
109 }
110
111 pub fn set_activation(&mut self, activation: ActivationFunction) {
113 self.activation = activation;
114 }
115}
116
/// Dense (fully-connected) layer computing `output = input · weight + bias`.
pub struct LinearLayer<T: Float + Debug + Send + Sync + 'static> {
    // Weight matrix, shape `(input_dim, output_dim)`.
    weight: Array2<T>,
    // Bias vector, length `output_dim`.
    bias: Array1<T>,
    // Expected number of input features; `forward` rejects other widths.
    input_dim: usize,
    // Number of output features.
    output_dim: usize,
}
131
132impl<T: Float + Debug + Send + Sync + 'static> LinearLayer<T> {
133 pub fn new(input_dim: usize, output_dim: usize) -> Result<Self> {
135 let scale = T::from(2.0 / (input_dim + output_dim) as f64)
137 .expect("unwrap failed")
138 .sqrt();
139 let mut weight = Array2::zeros((input_dim, output_dim));
140 let bias = Array1::zeros(output_dim);
141
142 for i in 0..input_dim {
144 for j in 0..output_dim {
145 let random_f64 = scirs2_core::random::random::<f64>();
146 let scaled_f64 = random_f64 * 2.0 - 1.0;
147 let random_val =
148 <T as scirs2_core::numeric::NumCast>::from(scaled_f64).expect("unwrap failed");
149 weight[[i, j]] = random_val * scale;
150 }
151 }
152
153 Ok(Self {
154 weight,
155 bias,
156 input_dim,
157 output_dim,
158 })
159 }
160
161 pub fn new_he_init(input_dim: usize, output_dim: usize) -> Result<Self> {
163 let scale = scirs2_core::numeric::NumCast::from(2.0 / input_dim as f64)
164 .unwrap_or_else(|| T::zero())
165 .sqrt();
166 let mut weight = Array2::zeros((input_dim, output_dim));
167 let bias = Array1::zeros(output_dim);
168
169 for i in 0..input_dim {
170 for j in 0..output_dim {
171 let random_f64 = scirs2_core::random::random::<f64>();
172 let scaled_f64 = random_f64 * 2.0 - 1.0;
173 let random_val =
174 <T as scirs2_core::numeric::NumCast>::from(scaled_f64).expect("unwrap failed");
175 weight[[i, j]] = random_val * scale;
176 }
177 }
178
179 Ok(Self {
180 weight,
181 bias,
182 input_dim,
183 output_dim,
184 })
185 }
186
187 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
189 let batch_size = input.shape()[0];
190 let input_features = input.shape()[1];
191
192 if input_features != self.input_dim {
193 return Err(crate::error::OptimError::Other(format!(
194 "Input dimension mismatch: expected {}, got {}",
195 self.input_dim, input_features
196 )));
197 }
198
199 let mut output = Array2::zeros((batch_size, self.output_dim));
200
201 for i in 0..batch_size {
203 for j in 0..self.output_dim {
204 let mut sum = self.bias[j];
205 for k in 0..self.input_dim {
206 sum = sum + input[[i, k]] * self.weight[[k, j]];
207 }
208 output[[i, j]] = sum;
209 }
210 }
211
212 Ok(output)
213 }
214
215 pub fn parameter_count(&self) -> usize {
217 self.input_dim * self.output_dim + self.output_dim
218 }
219
220 pub fn reset(&mut self) -> Result<()> {
222 let scale = T::from(2.0 / (self.input_dim + self.output_dim) as f64)
224 .expect("unwrap failed")
225 .sqrt();
226
227 for i in 0..self.input_dim {
228 for j in 0..self.output_dim {
229 let random_f64 = scirs2_core::random::random::<f64>();
230 let scaled_f64 = random_f64 * 2.0 - 1.0;
231 let random_val =
232 <T as scirs2_core::numeric::NumCast>::from(scaled_f64).expect("unwrap failed");
233 self.weight[[i, j]] = random_val * scale;
234 }
235 }
236
237 self.bias.fill(T::zero());
238 Ok(())
239 }
240
241 pub fn get_weights(&self) -> &Array2<T> {
243 &self.weight
244 }
245
246 pub fn get_bias(&self) -> &Array1<T> {
248 &self.bias
249 }
250
251 pub fn update_weights(
253 &mut self,
254 weight_delta: &Array2<T>,
255 bias_delta: &Array1<T>,
256 ) -> Result<()> {
257 if weight_delta.shape() != self.weight.shape() {
258 return Err(crate::error::OptimError::Other(
259 "Weight delta shape mismatch".to_string(),
260 ));
261 }
262
263 if bias_delta.len() != self.bias.len() {
264 return Err(crate::error::OptimError::Other(
265 "Bias delta shape mismatch".to_string(),
266 ));
267 }
268
269 self.weight = &self.weight - weight_delta;
270 self.bias = &self.bias - bias_delta;
271
272 Ok(())
273 }
274}
275
/// Gated Linear Unit (GLU): `sigmoid(gate(x)) * value(x)`.
pub struct GatedLinearUnit<T: Float + Debug + Send + Sync + 'static> {
    // Projection producing the gate logits (sigmoid is applied in `forward`).
    gate_linear: LinearLayer<T>,
    // Projection producing the values to be gated.
    value_linear: LinearLayer<T>,
    // Width of the unit's input.
    input_dimension: usize,
    // Width of the unit's output (both projections map to this size).
    hidden_dimension: usize,
}
290
291impl<T: Float + Debug + Send + Sync + 'static> GatedLinearUnit<T> {
292 pub fn new(input_dimension: usize, hidden_dimension: usize) -> Result<Self> {
294 let gate_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
295 let value_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
296
297 Ok(Self {
298 gate_linear,
299 value_linear,
300 input_dimension,
301 hidden_dimension,
302 })
303 }
304
305 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
307 let gate = self.gate_linear.forward(input)?;
308 let value = self.value_linear.forward(input)?;
309
310 let sigmoid_gate = ActivationLayer::apply(&gate, ActivationFunction::Sigmoid);
312 let output = &sigmoid_gate * &value;
313
314 Ok(output)
315 }
316
317 pub fn parameter_count(&self) -> usize {
319 self.gate_linear.parameter_count() + self.value_linear.parameter_count()
320 }
321
322 pub fn reset(&mut self) -> Result<()> {
324 self.gate_linear.reset()?;
325 self.value_linear.reset()?;
326 Ok(())
327 }
328}
329
/// SwiGLU unit: `swish(gate(x)) * value(x)` — a GLU variant using the Swish
/// activation on the gate branch.
pub struct SwiGLU<T: Float + Debug + Send + Sync + 'static> {
    // Projection producing the gate logits (Swish is applied in `forward`).
    gate_linear: LinearLayer<T>,
    // Projection producing the values to be gated.
    value_linear: LinearLayer<T>,
    // Width of the unit's input.
    input_dimension: usize,
    // Width of the unit's output (both projections map to this size).
    hidden_dimension: usize,
}
344
345impl<T: Float + Debug + Send + Sync + 'static> SwiGLU<T> {
346 pub fn new(input_dimension: usize, hidden_dimension: usize) -> Result<Self> {
348 let gate_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
349 let value_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
350
351 Ok(Self {
352 gate_linear,
353 value_linear,
354 input_dimension,
355 hidden_dimension,
356 })
357 }
358
359 pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
361 let gate = self.gate_linear.forward(input)?;
362 let value = self.value_linear.forward(input)?;
363
364 let swish_gate = ActivationLayer::apply(&gate, ActivationFunction::Swish);
366 let output = &swish_gate * &value;
367
368 Ok(output)
369 }
370
371 pub fn parameter_count(&self) -> usize {
373 self.gate_linear.parameter_count() + self.value_linear.parameter_count()
374 }
375
376 pub fn reset(&mut self) -> Result<()> {
378 self.gate_linear.reset()?;
379 self.value_linear.reset()?;
380 Ok(())
381 }
382}
383
/// Sparse mixture-of-experts layer: a gating network routes each sample to its
/// `top_k` highest-scoring experts and mixes their outputs.
pub struct MixtureOfExperts<T: Float + Debug + Send + Sync + 'static> {
    // Expert feed-forward networks, all with identical dimensions.
    experts: Vec<FeedForwardNetwork<T>>,
    // Gating network producing one routing score per expert.
    gate: LinearLayer<T>,
    // Total number of experts.
    num_experts: usize,
    // Experts consulted per sample; clamped to `num_experts` at construction.
    top_k: usize,
    // Input (and output) width of the layer.
    input_dimension: usize,
    // Hidden width of each expert.
    hidden_dimension: usize,
}
404
405impl<T: Float + Debug + Send + Sync + 'static> MixtureOfExperts<T> {
406 pub fn new(
408 input_dimension: usize,
409 hidden_dimension: usize,
410 num_experts: usize,
411 top_k: usize,
412 activation: ActivationFunction,
413 ) -> Result<Self> {
414 let mut experts = Vec::new();
415 for _ in 0..num_experts {
416 experts.push(FeedForwardNetwork::new(
417 input_dimension,
418 hidden_dimension,
419 activation,
420 )?);
421 }
422
423 let gate = LinearLayer::new(input_dimension, num_experts)?;
424
425 Ok(Self {
426 experts,
427 gate,
428 num_experts,
429 top_k: top_k.min(num_experts),
430 input_dimension,
431 hidden_dimension,
432 })
433 }
434
435 pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
437 let batch_size = input.shape()[0];
438
439 let gate_scores = self.gate.forward(input)?;
441 let gate_probs = self.softmax(&gate_scores);
442
443 let mut output = Array2::zeros((batch_size, self.input_dimension));
445
446 for i in 0..batch_size {
447 let sample_input = input.row(i).insert_axis(Axis(0)).to_owned();
448 let sample_probs = gate_probs.row(i);
449
450 let mut prob_indices: Vec<(usize, T)> = sample_probs
452 .iter()
453 .enumerate()
454 .map(|(idx, &prob)| (idx, prob))
455 .collect();
456
457 prob_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("unwrap failed"));
458
459 let top_k_indices: Vec<usize> = prob_indices
460 .iter()
461 .take(self.top_k)
462 .map(|(idx, _)| *idx)
463 .collect();
464
465 let mut sample_output = Array1::zeros(self.input_dimension);
467 let mut total_weight = T::zero();
468
469 for &expert_idx in &top_k_indices {
470 let expert_output = self.experts[expert_idx].forward(&sample_input)?;
471 let weight = sample_probs[expert_idx];
472
473 total_weight = total_weight + weight;
474
475 for j in 0..self.input_dimension {
476 sample_output[j] = sample_output[j] + weight * expert_output[[0, j]];
477 }
478 }
479
480 if total_weight > T::zero() {
482 for j in 0..self.input_dimension {
483 sample_output[j] = sample_output[j] / total_weight;
484 output[[i, j]] = sample_output[j];
485 }
486 }
487 }
488
489 Ok(output)
490 }
491
492 fn softmax(&self, input: &Array2<T>) -> Array2<T> {
494 let mut output = Array2::zeros(input.raw_dim());
495 let batch_size = input.shape()[0];
496
497 for i in 0..batch_size {
498 let row = input.row(i);
499 let max_val = row.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
500
501 let mut exp_sum = T::zero();
502 let mut exp_row = Array1::zeros(row.len());
503
504 for (j, &val) in row.iter().enumerate() {
505 exp_row[j] = (val - max_val).exp();
506 exp_sum = exp_sum + exp_row[j];
507 }
508
509 for (j, &exp_val) in exp_row.iter().enumerate() {
510 output[[i, j]] = exp_val / exp_sum;
511 }
512 }
513
514 output
515 }
516
517 pub fn parameter_count(&self) -> usize {
519 let expert_params: usize = self
520 .experts
521 .iter()
522 .map(|expert| expert.parameter_count())
523 .sum();
524
525 expert_params + self.gate.parameter_count()
526 }
527
528 pub fn reset(&mut self) -> Result<()> {
530 for expert in &mut self.experts {
531 expert.reset()?;
532 }
533 self.gate.reset()?;
534 Ok(())
535 }
536
537 pub fn set_training(&mut self, training: bool) {
539 for expert in &mut self.experts {
540 expert.set_training(training);
541 }
542 }
543}
544
pub use super::config::ActivationFunction;
547
#[cfg(test)]
mod tests {
    use super::*;

    // Shape-preservation check for the feed-forward block.
    #[test]
    fn test_feedforward_network() {
        let built = FeedForwardNetwork::<f32>::new(128, 512, ActivationFunction::ReLU);
        assert!(built.is_ok());

        let mut ffn = built.expect("unwrap failed");
        let batch = Array2::<f32>::ones((4, 128));
        let forwarded = ffn.forward(&batch);
        assert!(forwarded.is_ok());

        assert_eq!(forwarded.expect("unwrap failed").shape(), &[4, 128]);
    }

    // Output shape and parameter accounting for a dense layer.
    #[test]
    fn test_linear_layer() {
        let built = LinearLayer::<f32>::new(64, 128);
        assert!(built.is_ok());

        let layer = built.expect("unwrap failed");
        let batch = Array2::<f32>::zeros((2, 64));
        let forwarded = layer.forward(&batch);
        assert!(forwarded.is_ok());

        assert_eq!(forwarded.expect("unwrap failed").shape(), &[2, 128]);
        assert_eq!(layer.parameter_count(), 64 * 128 + 128);
    }

    // GLU projects to the hidden dimension.
    #[test]
    fn test_gated_linear_unit() {
        let built = GatedLinearUnit::<f32>::new(128, 256);
        assert!(built.is_ok());

        let glu = built.expect("unwrap failed");
        let batch = Array2::<f32>::ones((2, 128));
        let forwarded = glu.forward(&batch);
        assert!(forwarded.is_ok());

        assert_eq!(forwarded.expect("unwrap failed").shape(), &[2, 256]);
    }

    // SwiGLU projects to the hidden dimension.
    #[test]
    fn test_swiglu() {
        let built = SwiGLU::<f32>::new(128, 256);
        assert!(built.is_ok());

        let unit = built.expect("unwrap failed");
        let batch = Array2::<f32>::ones((2, 128));
        let forwarded = unit.forward(&batch);
        assert!(forwarded.is_ok());

        assert_eq!(forwarded.expect("unwrap failed").shape(), &[2, 256]);
    }

    // MoE output keeps the input width regardless of expert count / top-k.
    #[test]
    fn test_mixture_of_experts() {
        let built = MixtureOfExperts::<f32>::new(128, 256, 4, 2, ActivationFunction::ReLU);
        assert!(built.is_ok());

        let mut moe = built.expect("unwrap failed");
        let batch = Array2::<f32>::ones((3, 128));
        let forwarded = moe.forward(&batch);
        assert!(forwarded.is_ok());

        assert_eq!(forwarded.expect("unwrap failed").shape(), &[3, 128]);
    }

    // Xavier and He initialization only differ in scale, not in shape or
    // parameter count.
    #[test]
    fn test_linear_layer_initialization() {
        let xavier = LinearLayer::<f32>::new(64, 128).expect("unwrap failed");
        let he = LinearLayer::<f32>::new_he_init(64, 128).expect("unwrap failed");

        assert_eq!(xavier.parameter_count(), he.parameter_count());
        assert_eq!(xavier.get_weights().shape(), he.get_weights().shape());
    }
}