// optirs_learned/transformer_based_optimizer/feedforward.rs
1// Feed-forward network implementations for transformer layers
2
3use super::layers::ActivationLayer;
4use crate::error::Result;
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_core::numeric::Float;
7use std::fmt::Debug;
8
/// Feed-forward network implementation
///
/// Classic transformer position-wise FFN: expand the input to a wider
/// hidden dimension, apply an activation and dropout, then project back
/// to the original dimension.
pub struct FeedForwardNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// First linear layer (expansion: input_dimension -> hidden_dimension)
    linear1: LinearLayer<T>,

    /// Second linear layer (projection: hidden_dimension -> input_dimension)
    linear2: LinearLayer<T>,

    /// Activation function applied between the two linear layers
    activation: ActivationFunction,

    /// Input dimension (also the block's output dimension)
    input_dimension: usize,

    /// Hidden dimension (typically 4x input dimension)
    hidden_dimension: usize,

    /// Dropout layer for regularization, applied after the activation
    dropout: super::layers::DropoutLayer,
}
29
30impl<T: Float + Debug + Send + Sync + 'static> FeedForwardNetwork<T> {
31    /// Create new feed-forward network
32    pub fn new(
33        input_dimension: usize,
34        hidden_dimension: usize,
35        activation: ActivationFunction,
36    ) -> Result<Self> {
37        let linear1 = LinearLayer::new(input_dimension, hidden_dimension)?;
38        let linear2 = LinearLayer::new(hidden_dimension, input_dimension)?;
39        let dropout = super::layers::DropoutLayer::new(0.1);
40
41        Ok(Self {
42            linear1,
43            linear2,
44            activation,
45            input_dimension,
46            hidden_dimension,
47            dropout,
48        })
49    }
50
51    /// Create with custom dropout rate
52    pub fn new_with_dropout(
53        input_dimension: usize,
54        hidden_dimension: usize,
55        activation: ActivationFunction,
56        dropout_rate: f64,
57    ) -> Result<Self> {
58        let linear1 = LinearLayer::new(input_dimension, hidden_dimension)?;
59        let linear2 = LinearLayer::new(hidden_dimension, input_dimension)?;
60        let dropout = super::layers::DropoutLayer::new(dropout_rate);
61
62        Ok(Self {
63            linear1,
64            linear2,
65            activation,
66            input_dimension,
67            hidden_dimension,
68            dropout,
69        })
70    }
71
72    /// Forward pass through feed-forward network
73    pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
74        // First linear transformation
75        let hidden = self.linear1.forward(input)?;
76
77        // Apply activation function
78        let activated = ActivationLayer::apply(&hidden, self.activation);
79
80        // Apply dropout
81        let dropout_output = self.dropout.forward(&activated);
82
83        // Second linear transformation
84        let output = self.linear2.forward(&dropout_output)?;
85
86        Ok(output)
87    }
88
89    /// Get parameter count
90    pub fn parameter_count(&self) -> usize {
91        self.linear1.parameter_count() + self.linear2.parameter_count()
92    }
93
94    /// Reset all parameters
95    pub fn reset(&mut self) -> Result<()> {
96        self.linear1.reset()?;
97        self.linear2.reset()?;
98        Ok(())
99    }
100
101    /// Set training mode
102    pub fn set_training(&mut self, training: bool) {
103        self.dropout.set_training(training);
104    }
105
106    /// Get activation function
107    pub fn get_activation(&self) -> ActivationFunction {
108        self.activation
109    }
110
111    /// Set activation function
112    pub fn set_activation(&mut self, activation: ActivationFunction) {
113        self.activation = activation;
114    }
115}
116
/// Linear layer implementation
///
/// A dense (fully-connected) layer computing `input @ weight + bias`.
pub struct LinearLayer<T: Float + Debug + Send + Sync + 'static> {
    /// Weight matrix, shape `(input_dim, output_dim)`
    weight: Array2<T>,

    /// Bias vector, length `output_dim`
    bias: Array1<T>,

    /// Input dimension (fan-in)
    input_dim: usize,

    /// Output dimension (fan-out)
    output_dim: usize,
}
131
132impl<T: Float + Debug + Send + Sync + 'static> LinearLayer<T> {
133    /// Create new linear layer
134    pub fn new(input_dim: usize, output_dim: usize) -> Result<Self> {
135        // Xavier/Glorot initialization
136        let scale = T::from(2.0 / (input_dim + output_dim) as f64)
137            .expect("unwrap failed")
138            .sqrt();
139        let mut weight = Array2::zeros((input_dim, output_dim));
140        let bias = Array1::zeros(output_dim);
141
142        // Initialize weights with Xavier initialization
143        for i in 0..input_dim {
144            for j in 0..output_dim {
145                let random_f64 = scirs2_core::random::random::<f64>();
146                let scaled_f64 = random_f64 * 2.0 - 1.0;
147                let random_val =
148                    <T as scirs2_core::numeric::NumCast>::from(scaled_f64).expect("unwrap failed");
149                weight[[i, j]] = random_val * scale;
150            }
151        }
152
153        Ok(Self {
154            weight,
155            bias,
156            input_dim,
157            output_dim,
158        })
159    }
160
161    /// Create with He initialization (better for ReLU)
162    pub fn new_he_init(input_dim: usize, output_dim: usize) -> Result<Self> {
163        let scale = scirs2_core::numeric::NumCast::from(2.0 / input_dim as f64)
164            .unwrap_or_else(|| T::zero())
165            .sqrt();
166        let mut weight = Array2::zeros((input_dim, output_dim));
167        let bias = Array1::zeros(output_dim);
168
169        for i in 0..input_dim {
170            for j in 0..output_dim {
171                let random_f64 = scirs2_core::random::random::<f64>();
172                let scaled_f64 = random_f64 * 2.0 - 1.0;
173                let random_val =
174                    <T as scirs2_core::numeric::NumCast>::from(scaled_f64).expect("unwrap failed");
175                weight[[i, j]] = random_val * scale;
176            }
177        }
178
179        Ok(Self {
180            weight,
181            bias,
182            input_dim,
183            output_dim,
184        })
185    }
186
187    /// Forward pass through linear layer
188    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
189        let batch_size = input.shape()[0];
190        let input_features = input.shape()[1];
191
192        if input_features != self.input_dim {
193            return Err(crate::error::OptimError::Other(format!(
194                "Input dimension mismatch: expected {}, got {}",
195                self.input_dim, input_features
196            )));
197        }
198
199        let mut output = Array2::zeros((batch_size, self.output_dim));
200
201        // Matrix multiplication: input @ weight + bias
202        for i in 0..batch_size {
203            for j in 0..self.output_dim {
204                let mut sum = self.bias[j];
205                for k in 0..self.input_dim {
206                    sum = sum + input[[i, k]] * self.weight[[k, j]];
207                }
208                output[[i, j]] = sum;
209            }
210        }
211
212        Ok(output)
213    }
214
215    /// Get parameter count
216    pub fn parameter_count(&self) -> usize {
217        self.input_dim * self.output_dim + self.output_dim
218    }
219
220    /// Reset parameters
221    pub fn reset(&mut self) -> Result<()> {
222        // Re-initialize with Xavier
223        let scale = T::from(2.0 / (self.input_dim + self.output_dim) as f64)
224            .expect("unwrap failed")
225            .sqrt();
226
227        for i in 0..self.input_dim {
228            for j in 0..self.output_dim {
229                let random_f64 = scirs2_core::random::random::<f64>();
230                let scaled_f64 = random_f64 * 2.0 - 1.0;
231                let random_val =
232                    <T as scirs2_core::numeric::NumCast>::from(scaled_f64).expect("unwrap failed");
233                self.weight[[i, j]] = random_val * scale;
234            }
235        }
236
237        self.bias.fill(T::zero());
238        Ok(())
239    }
240
241    /// Get weight matrix reference
242    pub fn get_weights(&self) -> &Array2<T> {
243        &self.weight
244    }
245
246    /// Get bias vector reference
247    pub fn get_bias(&self) -> &Array1<T> {
248        &self.bias
249    }
250
251    /// Update weights (for training)
252    pub fn update_weights(
253        &mut self,
254        weight_delta: &Array2<T>,
255        bias_delta: &Array1<T>,
256    ) -> Result<()> {
257        if weight_delta.shape() != self.weight.shape() {
258            return Err(crate::error::OptimError::Other(
259                "Weight delta shape mismatch".to_string(),
260            ));
261        }
262
263        if bias_delta.len() != self.bias.len() {
264            return Err(crate::error::OptimError::Other(
265                "Bias delta shape mismatch".to_string(),
266            ));
267        }
268
269        self.weight = &self.weight - weight_delta;
270        self.bias = &self.bias - bias_delta;
271
272        Ok(())
273    }
274}
275
/// Gated Linear Unit (GLU) implementation
///
/// Computes `sigmoid(gate_linear(x)) * value_linear(x)` element-wise,
/// projecting from `input_dimension` to `hidden_dimension`.
pub struct GatedLinearUnit<T: Float + Debug + Send + Sync + 'static> {
    /// Linear layer for gate
    gate_linear: LinearLayer<T>,

    /// Linear layer for values
    value_linear: LinearLayer<T>,

    /// Input dimension
    input_dimension: usize,

    /// Hidden dimension (output dimension of both projections)
    hidden_dimension: usize,
}
290
291impl<T: Float + Debug + Send + Sync + 'static> GatedLinearUnit<T> {
292    /// Create new GLU
293    pub fn new(input_dimension: usize, hidden_dimension: usize) -> Result<Self> {
294        let gate_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
295        let value_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
296
297        Ok(Self {
298            gate_linear,
299            value_linear,
300            input_dimension,
301            hidden_dimension,
302        })
303    }
304
305    /// Forward pass through GLU
306    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
307        let gate = self.gate_linear.forward(input)?;
308        let value = self.value_linear.forward(input)?;
309
310        // Apply sigmoid to gate and element-wise multiply
311        let sigmoid_gate = ActivationLayer::apply(&gate, ActivationFunction::Sigmoid);
312        let output = &sigmoid_gate * &value;
313
314        Ok(output)
315    }
316
317    /// Get parameter count
318    pub fn parameter_count(&self) -> usize {
319        self.gate_linear.parameter_count() + self.value_linear.parameter_count()
320    }
321
322    /// Reset parameters
323    pub fn reset(&mut self) -> Result<()> {
324        self.gate_linear.reset()?;
325        self.value_linear.reset()?;
326        Ok(())
327    }
328}
329
/// Swish GLU variant
///
/// Like [`GatedLinearUnit`] but gates with the Swish activation instead of
/// sigmoid: `swish(gate_linear(x)) * value_linear(x)`.
pub struct SwiGLU<T: Float + Debug + Send + Sync + 'static> {
    /// Linear layer for gate
    gate_linear: LinearLayer<T>,

    /// Linear layer for values
    value_linear: LinearLayer<T>,

    /// Input dimension
    input_dimension: usize,

    /// Hidden dimension (output dimension of both projections)
    hidden_dimension: usize,
}
344
345impl<T: Float + Debug + Send + Sync + 'static> SwiGLU<T> {
346    /// Create new SwiGLU
347    pub fn new(input_dimension: usize, hidden_dimension: usize) -> Result<Self> {
348        let gate_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
349        let value_linear = LinearLayer::new(input_dimension, hidden_dimension)?;
350
351        Ok(Self {
352            gate_linear,
353            value_linear,
354            input_dimension,
355            hidden_dimension,
356        })
357    }
358
359    /// Forward pass through SwiGLU
360    pub fn forward(&self, input: &Array2<T>) -> Result<Array2<T>> {
361        let gate = self.gate_linear.forward(input)?;
362        let value = self.value_linear.forward(input)?;
363
364        // Apply Swish to gate and element-wise multiply
365        let swish_gate = ActivationLayer::apply(&gate, ActivationFunction::Swish);
366        let output = &swish_gate * &value;
367
368        Ok(output)
369    }
370
371    /// Get parameter count
372    pub fn parameter_count(&self) -> usize {
373        self.gate_linear.parameter_count() + self.value_linear.parameter_count()
374    }
375
376    /// Reset parameters
377    pub fn reset(&mut self) -> Result<()> {
378        self.gate_linear.reset()?;
379        self.value_linear.reset()?;
380        Ok(())
381    }
382}
383
/// Expert mixture for sparse feed-forward networks
///
/// Routes each input row through its `top_k` highest-scoring experts
/// (as chosen by a learned gating layer) and blends their outputs.
pub struct MixtureOfExperts<T: Float + Debug + Send + Sync + 'static> {
    /// Individual expert networks (each a full feed-forward block)
    experts: Vec<FeedForwardNetwork<T>>,

    /// Gating network producing one score per expert
    gate: LinearLayer<T>,

    /// Number of experts
    num_experts: usize,

    /// Number of experts to activate (top-k), clamped to `num_experts`
    top_k: usize,

    /// Input dimension (also the output dimension)
    input_dimension: usize,

    /// Hidden dimension per expert
    hidden_dimension: usize,
}
404
405impl<T: Float + Debug + Send + Sync + 'static> MixtureOfExperts<T> {
406    /// Create new mixture of experts
407    pub fn new(
408        input_dimension: usize,
409        hidden_dimension: usize,
410        num_experts: usize,
411        top_k: usize,
412        activation: ActivationFunction,
413    ) -> Result<Self> {
414        let mut experts = Vec::new();
415        for _ in 0..num_experts {
416            experts.push(FeedForwardNetwork::new(
417                input_dimension,
418                hidden_dimension,
419                activation,
420            )?);
421        }
422
423        let gate = LinearLayer::new(input_dimension, num_experts)?;
424
425        Ok(Self {
426            experts,
427            gate,
428            num_experts,
429            top_k: top_k.min(num_experts),
430            input_dimension,
431            hidden_dimension,
432        })
433    }
434
435    /// Forward pass through mixture of experts
436    pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
437        let batch_size = input.shape()[0];
438
439        // Compute gating scores
440        let gate_scores = self.gate.forward(input)?;
441        let gate_probs = self.softmax(&gate_scores);
442
443        // Find top-k experts for each sample
444        let mut output = Array2::zeros((batch_size, self.input_dimension));
445
446        for i in 0..batch_size {
447            let sample_input = input.row(i).insert_axis(Axis(0)).to_owned();
448            let sample_probs = gate_probs.row(i);
449
450            // Get top-k indices
451            let mut prob_indices: Vec<(usize, T)> = sample_probs
452                .iter()
453                .enumerate()
454                .map(|(idx, &prob)| (idx, prob))
455                .collect();
456
457            prob_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("unwrap failed"));
458
459            let top_k_indices: Vec<usize> = prob_indices
460                .iter()
461                .take(self.top_k)
462                .map(|(idx, _)| *idx)
463                .collect();
464
465            // Compute weighted sum from top-k experts
466            let mut sample_output = Array1::zeros(self.input_dimension);
467            let mut total_weight = T::zero();
468
469            for &expert_idx in &top_k_indices {
470                let expert_output = self.experts[expert_idx].forward(&sample_input)?;
471                let weight = sample_probs[expert_idx];
472
473                total_weight = total_weight + weight;
474
475                for j in 0..self.input_dimension {
476                    sample_output[j] = sample_output[j] + weight * expert_output[[0, j]];
477                }
478            }
479
480            // Normalize by total weight
481            if total_weight > T::zero() {
482                for j in 0..self.input_dimension {
483                    sample_output[j] = sample_output[j] / total_weight;
484                    output[[i, j]] = sample_output[j];
485                }
486            }
487        }
488
489        Ok(output)
490    }
491
492    /// Softmax activation for gating
493    fn softmax(&self, input: &Array2<T>) -> Array2<T> {
494        let mut output = Array2::zeros(input.raw_dim());
495        let batch_size = input.shape()[0];
496
497        for i in 0..batch_size {
498            let row = input.row(i);
499            let max_val = row.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
500
501            let mut exp_sum = T::zero();
502            let mut exp_row = Array1::zeros(row.len());
503
504            for (j, &val) in row.iter().enumerate() {
505                exp_row[j] = (val - max_val).exp();
506                exp_sum = exp_sum + exp_row[j];
507            }
508
509            for (j, &exp_val) in exp_row.iter().enumerate() {
510                output[[i, j]] = exp_val / exp_sum;
511            }
512        }
513
514        output
515    }
516
517    /// Get parameter count
518    pub fn parameter_count(&self) -> usize {
519        let expert_params: usize = self
520            .experts
521            .iter()
522            .map(|expert| expert.parameter_count())
523            .sum();
524
525        expert_params + self.gate.parameter_count()
526    }
527
528    /// Reset all parameters
529    pub fn reset(&mut self) -> Result<()> {
530        for expert in &mut self.experts {
531            expert.reset()?;
532        }
533        self.gate.reset()?;
534        Ok(())
535    }
536
537    /// Set training mode for all experts
538    pub fn set_training(&mut self, training: bool) {
539        for expert in &mut self.experts {
540            expert.set_training(training);
541        }
542    }
543}
544
545// Re-export for backward compatibility
546pub use super::config::ActivationFunction;
547
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feedforward_network() {
        // Use the module-level re-export of `ActivationFunction` for
        // consistency with the other tests (e.g. `test_mixture_of_experts`).
        let ffn = FeedForwardNetwork::<f32>::new(128, 512, ActivationFunction::ReLU);
        assert!(ffn.is_ok());

        let mut network = ffn.expect("unwrap failed");
        let input = Array2::<f32>::ones((4, 128));
        let result = network.forward(&input);
        assert!(result.is_ok());

        // FFN projects back to the input dimension.
        let output = result.expect("unwrap failed");
        assert_eq!(output.shape(), &[4, 128]);
    }

    #[test]
    fn test_linear_layer() {
        let linear = LinearLayer::<f32>::new(64, 128);
        assert!(linear.is_ok());

        let layer = linear.expect("unwrap failed");
        let input = Array2::<f32>::zeros((2, 64));
        let result = layer.forward(&input);
        assert!(result.is_ok());

        let output = result.expect("unwrap failed");
        assert_eq!(output.shape(), &[2, 128]);
        // weights (64 * 128) + bias (128)
        assert_eq!(layer.parameter_count(), 64 * 128 + 128);
    }

    #[test]
    fn test_gated_linear_unit() {
        let glu = GatedLinearUnit::<f32>::new(128, 256);
        assert!(glu.is_ok());

        let unit = glu.expect("unwrap failed");
        let input = Array2::<f32>::ones((2, 128));
        let result = unit.forward(&input);
        assert!(result.is_ok());

        // GLU projects to the hidden dimension.
        let output = result.expect("unwrap failed");
        assert_eq!(output.shape(), &[2, 256]);
    }

    #[test]
    fn test_swiglu() {
        let swiglu = SwiGLU::<f32>::new(128, 256);
        assert!(swiglu.is_ok());

        let unit = swiglu.expect("unwrap failed");
        let input = Array2::<f32>::ones((2, 128));
        let result = unit.forward(&input);
        assert!(result.is_ok());

        let output = result.expect("unwrap failed");
        assert_eq!(output.shape(), &[2, 256]);
    }

    #[test]
    fn test_mixture_of_experts() {
        let moe = MixtureOfExperts::<f32>::new(128, 256, 4, 2, ActivationFunction::ReLU);
        assert!(moe.is_ok());

        let mut mixture = moe.expect("unwrap failed");
        let input = Array2::<f32>::ones((3, 128));
        let result = mixture.forward(&input);
        assert!(result.is_ok());

        // MoE output matches the input dimension.
        let output = result.expect("unwrap failed");
        assert_eq!(output.shape(), &[3, 128]);
    }

    #[test]
    fn test_linear_layer_initialization() {
        let xavier_layer = LinearLayer::<f32>::new(64, 128).expect("unwrap failed");
        let he_layer = LinearLayer::<f32>::new_he_init(64, 128).expect("unwrap failed");

        // Both init schemes produce identically-shaped parameter sets.
        assert_eq!(xavier_layer.parameter_count(), he_layer.parameter_count());
        assert_eq!(
            xavier_layer.get_weights().shape(),
            he_layer.get_weights().shape()
        );
    }
}