optirs_learned/transformer_based_optimizer/
feedforward.rs

1// Feed-forward network implementations for transformer layers
2
3use super::layers::ActivationLayer;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_core::numeric::Float;
7use std::fmt::Debug;
8
/// Position-wise feed-forward network used inside a transformer layer.
///
/// Computes `linear2(dropout(activation(linear1(x))))`: the input is expanded
/// to `hidden_dimension`, passed through a nonlinearity and dropout, then
/// projected back to `input_dimension`.
pub struct FeedForwardNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// First linear layer (expansion: input_dimension -> hidden_dimension)
    linear1: LinearLayer<T>,

    /// Second linear layer (projection: hidden_dimension -> input_dimension)
    linear2: LinearLayer<T>,

    /// Activation function applied between the two linear layers
    activation: ActivationFunction,

    /// Input (and output) dimension
    input_dimension: usize,

    /// Hidden dimension (typically 4x input dimension in transformers)
    hidden_dimension: usize,

    /// Dropout layer applied after the activation, active only in training mode
    dropout: super::layers::DropoutLayer,
}
29
30impl<T: Float + Debug + Send + Sync + 'static> FeedForwardNetwork<T> {
31    /// Create new feed-forward network
32    pub fn new(
33        input_dimension: usize,
34        hidden_dimension: usize,
35        activation: ActivationFunction,
36    ) -> Result<Self> {
37        let linear1 = LinearLayer::new(input_dimension, hidden_dimension)?;
38        let linear2 = LinearLayer::new(hidden_dimension, input_dimension)?;
39        let dropout = super::layers::DropoutLayer::new(0.1);
40
41        Ok(Self {
42            linear1,
43            linear2,
44            activation,
45            input_dimension,
46            hidden_dimension,
47            dropout,
48        })
49    }
50
51    /// Create with custom dropout rate
52    pub fn new_with_dropout(
53        input_dimension: usize,
54        hidden_dimension: usize,
55        activation: ActivationFunction,
56        dropout_rate: f64,
57    ) -> Result<Self> {
58        let linear1 = LinearLayer::new(input_dimension, hidden_dimension)?;
59        let linear2 = LinearLayer::new(hidden_dimension, input_dimension)?;
60        let dropout = super::layers::DropoutLayer::new(dropout_rate);
61
62        Ok(Self {
63            linear1,
64            linear2,
65            activation,
66            input_dimension,
67            hidden_dimension,
68            dropout,
69        })
70    }
71
72    /// Forward pass through feed-forward network
73    pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
74        // First linear transformation
75        let hidden = self.linear1.forward(input)?;
76
77        // Apply activation function
78        let activated = ActivationLayer::apply(&hidden, self.activation);
79
80        // Apply dropout
81        let dropout_output = self.dropout.forward(&activated);
82
83        // Second linear transformation
84        let output = self.linear2.forward(&dropout_output)?;
85
86        Ok(output)
87    }
88
89    /// Get parameter count
90    pub fn parameter_count(&self) -> usize {
91        self.linear1.parameter_count() + self.linear2.parameter_count()
92    }
93
94    /// Reset all parameters
95    pub fn reset(&mut self) -> Result<()> {
96        self.linear1.reset()?;
97        self.linear2.reset()?;
98        Ok(())
99    }
100
101    /// Set training mode
102    pub fn set_training(&mut self, training: bool) {
103        self.dropout.set_training(training);
104    }
105
106    /// Get activation function
107    pub fn get_activation(&self) -> ActivationFunction {
108        self.activation
109    }
110
111    /// Set activation function
112    pub fn set_activation(&mut self, activation: ActivationFunction) {
113        self.activation = activation;
114    }
115}
116
/// Fully-connected (affine) layer: `y = x @ W + b`.
///
/// Weights are stored as `(input_dim, output_dim)` so the forward pass is a
/// plain `input @ weight` without transposition.
pub struct LinearLayer<T: Float + Debug + Send + Sync + 'static> {
    /// Weight matrix, shape `(input_dim, output_dim)`
    weight: Array2<T>,

    /// Bias vector, length `output_dim`
    bias: Array1<T>,

    /// Input dimension (expected feature count of forward inputs)
    input_dim: usize,

    /// Output dimension
    output_dim: usize,
}
131
132impl<T: Float + Debug + Send + Sync + 'static> LinearLayer<T> {
133    /// Create new linear layer
134    pub fn new(input_dim: usize, output_dim: usize) -> Result<Self> {
135        // Xavier/Glorot initialization
136        let scale = T::from(2.0 / (input_dim + output_dim) as f64)
137            .unwrap()
138            .sqrt();
139        let mut weight = Array2::zeros((input_dim, output_dim));
140        let bias = Array1::zeros(output_dim);
141
142        // Initialize weights with Xavier initialization
143        for i in 0..input_dim {
144            for j in 0..output_dim {
145                let random_f64 = scirs2_core::random::random::<f64>();
146                let scaled_f64 = random_f64 * 2.0 - 1.0;
147                let random_val = <T as scirs2_core::numeric::NumCast>::from(scaled_f64).unwrap();
148                weight[[i, j]] = random_val * scale;
149            }
150        }
151
152        Ok(Self {
153            weight,
154            bias,
155            input_dim,
156            output_dim,
157        })
158    }
159
160    /// Create with He initialization (better for ReLU)
161    pub fn new_he_init(input_dim: usize, output_dim: usize) -> Result<Self> {
162        let scale = scirs2_core::numeric::NumCast::from(2.0 / input_dim as f64)
163            .unwrap_or_else(|| T::zero())
164            .sqrt();
165        let mut weight = Array2::zeros((input_dim, output_dim));
166        let bias = Array1::zeros(output_dim);
167
168        for i in 0..input_dim {
169            for j in 0..output_dim {
170                let random_f64 = scirs2_core::random::random::<f64>();
171                let scaled_f64 = random_f64 * 2.0 - 1.0;
172                let random_val = <T as scirs2_core::numeric::NumCast>::from(scaled_f64).unwrap();
173                weight[[i, j]] = random_val * scale;
174            }
175        }
176
177        Ok(Self {
178            weight,
179            bias,
180            input_dim,
181            output_dim,
182        })
183    }
184
185    /// Forward pass through linear layer
186    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
187        let batch_size = input.shape()[0];
188        let input_features = input.shape()[1];
189
190        if input_features != self.input_dim {
191            return Err(crate::error::OptimError::Other(format!(
192                "Input dimension mismatch: expected {}, got {}",
193                self.input_dim, input_features
194            )));
195        }
196
197        let mut output = Array2::zeros((batch_size, self.output_dim));
198
199        // Matrix multiplication: input @ weight + bias
200        for i in 0..batch_size {
201            for j in 0..self.output_dim {
202                let mut sum = self.bias[j];
203                for k in 0..self.input_dim {
204                    sum = sum + input[[i, k]] * self.weight[[k, j]];
205                }
206                output[[i, j]] = sum;
207            }
208        }
209
210        Ok(output)
211    }
212
213    /// Get parameter count
214    pub fn parameter_count(&self) -> usize {
215        self.input_dim * self.output_dim + self.output_dim
216    }
217
218    /// Reset parameters
219    pub fn reset(&mut self) -> Result<()> {
220        // Re-initialize with Xavier
221        let scale = T::from(2.0 / (self.input_dim + self.output_dim) as f64)
222            .unwrap()
223            .sqrt();
224
225        for i in 0..self.input_dim {
226            for j in 0..self.output_dim {
227                let random_f64 = scirs2_core::random::random::<f64>();
228                let scaled_f64 = random_f64 * 2.0 - 1.0;
229                let random_val = <T as scirs2_core::numeric::NumCast>::from(scaled_f64).unwrap();
230                self.weight[[i, j]] = random_val * scale;
231            }
232        }
233
234        self.bias.fill(T::zero());
235        Ok(())
236    }
237
238    /// Get weight matrix reference
239    pub fn get_weights(&self) -> &Array2<T> {
240        &self.weight
241    }
242
243    /// Get bias vector reference
244    pub fn get_bias(&self) -> &Array1<T> {
245        &self.bias
246    }
247
248    /// Update weights (for training)
249    pub fn update_weights(
250        &mut self,
251        weight_delta: &Array2<T>,
252        bias_delta: &Array1<T>,
253    ) -> Result<()> {
254        if weight_delta.shape() != self.weight.shape() {
255            return Err(crate::error::OptimError::Other(
256                "Weight delta shape mismatch".to_string(),
257            ));
258        }
259
260        if bias_delta.len() != self.bias.len() {
261            return Err(crate::error::OptimError::Other(
262                "Bias delta shape mismatch".to_string(),
263            ));
264        }
265
266        self.weight = &self.weight - weight_delta;
267        self.bias = &self.bias - bias_delta;
268
269        Ok(())
270    }
271}
272
/// Gated Linear Unit (GLU): `sigmoid(gate(x)) * value(x)` (element-wise).
///
/// Both projections map `input_dimension -> hidden_dimension`, so the output
/// has `hidden_dimension` features.
pub struct GatedLinearUnit<T: Float + Debug + Send + Sync + 'static> {
    /// Linear projection producing the (pre-sigmoid) gate
    gate_linear: LinearLayer<T>,

    /// Linear projection producing the values to be gated
    value_linear: LinearLayer<T>,

    /// Input dimension
    input_dimension: usize,

    /// Hidden (output) dimension
    hidden_dimension: usize,
}
287
288impl<T: Float + Debug + Send + Sync + 'static> GatedLinearUnit<T> {
289    /// Create new GLU
290    pub fn new(input_dimension: usize, hidden_dimension: usize) -> Result<Self> {
291        let gate_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
292        let value_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
293
294        Ok(Self {
295            gate_linear,
296            value_linear,
297            input_dimension,
298            hidden_dimension,
299        })
300    }
301
302    /// Forward pass through GLU
303    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
304        let gate = self.gate_linear.forward(input)?;
305        let value = self.value_linear.forward(input)?;
306
307        // Apply sigmoid to gate and element-wise multiply
308        let sigmoid_gate = ActivationLayer::apply(&gate, ActivationFunction::Sigmoid);
309        let output = &sigmoid_gate * &value;
310
311        Ok(output)
312    }
313
314    /// Get parameter count
315    pub fn parameter_count(&self) -> usize {
316        self.gate_linear.parameter_count() + self.value_linear.parameter_count()
317    }
318
319    /// Reset parameters
320    pub fn reset(&mut self) -> Result<()> {
321        self.gate_linear.reset()?;
322        self.value_linear.reset()?;
323        Ok(())
324    }
325}
326
/// SwiGLU — GLU variant using Swish instead of sigmoid on the gate:
/// `swish(gate(x)) * value(x)` (element-wise).
///
/// Both projections map `input_dimension -> hidden_dimension`.
pub struct SwiGLU<T: Float + Debug + Send + Sync + 'static> {
    /// Linear projection producing the (pre-Swish) gate
    gate_linear: LinearLayer<T>,

    /// Linear projection producing the values to be gated
    value_linear: LinearLayer<T>,

    /// Input dimension
    input_dimension: usize,

    /// Hidden (output) dimension
    hidden_dimension: usize,
}
341
342impl<T: Float + Debug + Send + Sync + 'static> SwiGLU<T> {
343    /// Create new SwiGLU
344    pub fn new(input_dimension: usize, hidden_dimension: usize) -> Result<Self> {
345        let gate_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
346        let value_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
347
348        Ok(Self {
349            gate_linear,
350            value_linear,
351            input_dimension,
352            hidden_dimension,
353        })
354    }
355
356    /// Forward pass through SwiGLU
357    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
358        let gate = self.gate_linear.forward(input)?;
359        let value = self.value_linear.forward(input)?;
360
361        // Apply Swish to gate and element-wise multiply
362        let swish_gate = ActivationLayer::apply(&gate, ActivationFunction::Swish);
363        let output = &swish_gate * &value;
364
365        Ok(output)
366    }
367
368    /// Get parameter count
369    pub fn parameter_count(&self) -> usize {
370        self.gate_linear.parameter_count() + self.value_linear.parameter_count()
371    }
372
373    /// Reset parameters
374    pub fn reset(&mut self) -> Result<()> {
375        self.gate_linear.reset()?;
376        self.value_linear.reset()?;
377        Ok(())
378    }
379}
380
/// Sparse mixture-of-experts feed-forward layer.
///
/// A gating network scores every expert per sample; only the `top_k` highest
/// scoring experts are evaluated, and their outputs are combined as a
/// weighted (renormalized) sum.
pub struct MixtureOfExperts<T: Float + Debug + Send + Sync + 'static> {
    /// Individual expert networks (each a full feed-forward network)
    experts: Vec<FeedForwardNetwork<T>>,

    /// Gating network mapping input features to one score per expert
    gate: LinearLayer<T>,

    /// Number of experts
    num_experts: usize,

    /// Number of experts to activate per sample (top-k routing)
    top_k: usize,

    /// Input (and output) dimension
    input_dimension: usize,

    /// Hidden dimension per expert
    hidden_dimension: usize,
}
401
402impl<T: Float + Debug + Send + Sync + 'static> MixtureOfExperts<T> {
403    /// Create new mixture of experts
404    pub fn new(
405        input_dimension: usize,
406        hidden_dimension: usize,
407        num_experts: usize,
408        top_k: usize,
409        activation: ActivationFunction,
410    ) -> Result<Self> {
411        let mut experts = Vec::new();
412        for _ in 0..num_experts {
413            experts.push(FeedForwardNetwork::new(
414                input_dimension,
415                hidden_dimension,
416                activation,
417            )?);
418        }
419
420        let gate = LinearLayer::new(input_dimension, num_experts)?;
421
422        Ok(Self {
423            experts,
424            gate,
425            num_experts,
426            top_k: top_k.min(num_experts),
427            input_dimension,
428            hidden_dimension,
429        })
430    }
431
432    /// Forward pass through mixture of experts
433    pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
434        let batch_size = input.shape()[0];
435
436        // Compute gating scores
437        let gate_scores = self.gate.forward(input)?;
438        let gate_probs = self.softmax(&gate_scores);
439
440        // Find top-k experts for each sample
441        let mut output = Array2::zeros((batch_size, self.input_dimension));
442
443        for i in 0..batch_size {
444            let sample_input = input.row(i).insert_axis(Axis(0)).to_owned();
445            let sample_probs = gate_probs.row(i);
446
447            // Get top-k indices
448            let mut prob_indices: Vec<(usize, T)> = sample_probs
449                .iter()
450                .enumerate()
451                .map(|(idx, &prob)| (idx, prob))
452                .collect();
453
454            prob_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
455
456            let top_k_indices: Vec<usize> = prob_indices
457                .iter()
458                .take(self.top_k)
459                .map(|(idx, _)| *idx)
460                .collect();
461
462            // Compute weighted sum from top-k experts
463            let mut sample_output = Array1::zeros(self.input_dimension);
464            let mut total_weight = T::zero();
465
466            for &expert_idx in &top_k_indices {
467                let expert_output = self.experts[expert_idx].forward(&sample_input)?;
468                let weight = sample_probs[expert_idx];
469
470                total_weight = total_weight + weight;
471
472                for j in 0..self.input_dimension {
473                    sample_output[j] = sample_output[j] + weight * expert_output[[0, j]];
474                }
475            }
476
477            // Normalize by total weight
478            if total_weight > T::zero() {
479                for j in 0..self.input_dimension {
480                    sample_output[j] = sample_output[j] / total_weight;
481                    output[[i, j]] = sample_output[j];
482                }
483            }
484        }
485
486        Ok(output)
487    }
488
489    /// Softmax activation for gating
490    fn softmax(&self, input: &Array2<T>) -> Array2<T> {
491        let mut output = Array2::zeros(input.raw_dim());
492        let batch_size = input.shape()[0];
493
494        for i in 0..batch_size {
495            let row = input.row(i);
496            let max_val = row.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
497
498            let mut exp_sum = T::zero();
499            let mut exp_row = Array1::zeros(row.len());
500
501            for (j, &val) in row.iter().enumerate() {
502                exp_row[j] = (val - max_val).exp();
503                exp_sum = exp_sum + exp_row[j];
504            }
505
506            for (j, &exp_val) in exp_row.iter().enumerate() {
507                output[[i, j]] = exp_val / exp_sum;
508            }
509        }
510
511        output
512    }
513
514    /// Get parameter count
515    pub fn parameter_count(&self) -> usize {
516        let expert_params: usize = self
517            .experts
518            .iter()
519            .map(|expert| expert.parameter_count())
520            .sum();
521
522        expert_params + self.gate.parameter_count()
523    }
524
525    /// Reset all parameters
526    pub fn reset(&mut self) -> Result<()> {
527        for expert in &mut self.experts {
528            expert.reset()?;
529        }
530        self.gate.reset()?;
531        Ok(())
532    }
533
534    /// Set training mode for all experts
535    pub fn set_training(&mut self, training: bool) {
536        for expert in &mut self.experts {
537            expert.set_training(training);
538        }
539    }
540}
541
542// Re-export for backward compatibility
543pub use super::config::ActivationFunction;
544
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feedforward_network() {
        // Use the re-exported `ActivationFunction` for consistency with the
        // other tests in this module (was a fully-qualified crate path).
        let ffn = FeedForwardNetwork::<f32>::new(128, 512, ActivationFunction::ReLU);
        assert!(ffn.is_ok());

        let mut network = ffn.unwrap();
        let input = Array2::<f32>::ones((4, 128));
        let result = network.forward(&input);
        assert!(result.is_ok());

        // The FFN projects back to the input dimension.
        let output = result.unwrap();
        assert_eq!(output.shape(), &[4, 128]);
    }

    #[test]
    fn test_linear_layer() {
        let linear = LinearLayer::<f32>::new(64, 128);
        assert!(linear.is_ok());

        let layer = linear.unwrap();
        let input = Array2::<f32>::zeros((2, 64));
        let result = layer.forward(&input);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.shape(), &[2, 128]);
        // weights (64 * 128) + biases (128)
        assert_eq!(layer.parameter_count(), 64 * 128 + 128);
    }

    #[test]
    fn test_gated_linear_unit() {
        let glu = GatedLinearUnit::<f32>::new(128, 256);
        assert!(glu.is_ok());

        let unit = glu.unwrap();
        let input = Array2::<f32>::ones((2, 128));
        let result = unit.forward(&input);
        assert!(result.is_ok());

        // GLU output has the hidden dimension.
        let output = result.unwrap();
        assert_eq!(output.shape(), &[2, 256]);
    }

    #[test]
    fn test_swiglu() {
        let swiglu = SwiGLU::<f32>::new(128, 256);
        assert!(swiglu.is_ok());

        let unit = swiglu.unwrap();
        let input = Array2::<f32>::ones((2, 128));
        let result = unit.forward(&input);
        assert!(result.is_ok());

        // SwiGLU output has the hidden dimension.
        let output = result.unwrap();
        assert_eq!(output.shape(), &[2, 256]);
    }

    #[test]
    fn test_mixture_of_experts() {
        // 4 experts, top-2 routing.
        let moe = MixtureOfExperts::<f32>::new(128, 256, 4, 2, ActivationFunction::ReLU);
        assert!(moe.is_ok());

        let mut mixture = moe.unwrap();
        let input = Array2::<f32>::ones((3, 128));
        let result = mixture.forward(&input);
        assert!(result.is_ok());

        // MoE preserves the input dimension.
        let output = result.unwrap();
        assert_eq!(output.shape(), &[3, 128]);
    }

    #[test]
    fn test_linear_layer_initialization() {
        let xavier_layer = LinearLayer::<f32>::new(64, 128).unwrap();
        let he_layer = LinearLayer::<f32>::new_he_init(64, 128).unwrap();

        // Both initialization schemes produce identically-shaped layers.
        assert_eq!(xavier_layer.parameter_count(), he_layer.parameter_count());
        assert_eq!(
            xavier_layer.get_weights().shape(),
            he_layer.get_weights().shape()
        );
    }
}