sklears_multioutput/mlp.rs

//! Multi-Layer Perceptron for Multi-Output Learning
//!
//! This module provides a flexible Multi-Layer Perceptron implementation that can handle
//! both regression and classification tasks with multiple outputs. It supports configurable
//! architecture, activation functions, and training parameters.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::RandNormal;
use scirs2_core::random::Rng;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

use crate::activation::ActivationFunction;
use crate::loss::LossFunction;

/// Multi-Layer Perceptron for Multi-Output Learning
///
/// This neural network can handle both regression and classification tasks with multiple outputs.
/// It supports configurable architecture, activation functions, and training parameters.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::mlp::MultiOutputMLP;
/// use sklears_multioutput::activation::ActivationFunction;
/// use sklears_multioutput::loss::LossFunction;
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[0.5, 1.2], [1.0, 2.1], [1.5, 0.8], [2.0, 2.5]]; // Multi-output regression
///
/// let mlp = MultiOutputMLP::new()
///     .hidden_layer_sizes(vec![10, 5])
///     .activation(ActivationFunction::ReLU)
///     .output_activation(ActivationFunction::Linear)
///     .loss_function(LossFunction::MeanSquaredError)
///     .learning_rate(0.01)
///     .max_iter(1000)
///     .random_state(Some(42));
///
/// let trained_mlp = mlp.fit(&X.view(), &y).unwrap();
/// let predictions = trained_mlp.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MultiOutputMLP<S = Untrained> {
    state: S,
    hidden_layer_sizes: Vec<usize>,
    activation: ActivationFunction,
    output_activation: ActivationFunction,
    loss_function: LossFunction,
    learning_rate: Float,
    max_iter: usize,
    tolerance: Float,
    random_state: Option<u64>,
    alpha: Float, // L2 regularization
    batch_size: Option<usize>,
    early_stopping: bool,
    validation_fraction: Float,
}

/// Trained state for MultiOutputMLP
#[derive(Debug, Clone)]
pub struct MultiOutputMLPTrained {
    /// Weights for each layer
    weights: Vec<Array2<Float>>,
    /// Biases for each layer
    biases: Vec<Array1<Float>>,
    /// Number of input features
    n_features: usize,
    /// Number of outputs
    n_outputs: usize,
    /// Training configuration
    hidden_layer_sizes: Vec<usize>,
    activation: ActivationFunction,
    output_activation: ActivationFunction,
    /// Training history
    loss_curve: Vec<Float>,
    /// Number of iterations performed
    n_iter: usize,
}

impl MultiOutputMLP<Untrained> {
    /// Create a new MultiOutputMLP instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            hidden_layer_sizes: vec![100],
            activation: ActivationFunction::ReLU,
            output_activation: ActivationFunction::Linear,
            loss_function: LossFunction::MeanSquaredError,
            learning_rate: 0.001,
            max_iter: 200,
            tolerance: 1e-4,
            random_state: None,
            alpha: 0.0001,
            batch_size: None,
            early_stopping: false,
            validation_fraction: 0.1,
        }
    }

    /// Set hidden layer sizes
    pub fn hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
        self.hidden_layer_sizes = sizes;
        self
    }

    /// Set activation function for hidden layers
    pub fn activation(mut self, activation: ActivationFunction) -> Self {
        self.activation = activation;
        self
    }

    /// Set activation function for output layer
    pub fn output_activation(mut self, activation: ActivationFunction) -> Self {
        self.output_activation = activation;
        self
    }

    /// Set loss function
    pub fn loss_function(mut self, loss_function: LossFunction) -> Self {
        self.loss_function = loss_function;
        self
    }

    /// Set learning rate
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Set L2 regularization parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set batch size for mini-batch gradient descent
    pub fn batch_size(mut self, batch_size: Option<usize>) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Enable early stopping
    pub fn early_stopping(mut self, early_stopping: bool) -> Self {
        self.early_stopping = early_stopping;
        self
    }

    /// Set validation fraction for early stopping
    pub fn validation_fraction(mut self, validation_fraction: Float) -> Self {
        self.validation_fraction = validation_fraction;
        self
    }
}

impl Default for MultiOutputMLP<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiOutputMLP<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<Float>> for MultiOutputMLP<Untrained> {
    type Fitted = MultiOutputMLP<MultiOutputMLPTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<Float>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let (n_samples_y, n_outputs) = y.dim();

        if n_samples != n_samples_y {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Cannot fit with zero samples".to_string(),
            ));
        }

        // Initialize random number generator
        let mut rng = match self.random_state {
            Some(seed) => scirs2_core::random::seeded_rng(seed),
            None => scirs2_core::random::seeded_rng(42),
        };

        // Build network architecture
        let mut layer_sizes = vec![n_features];
        layer_sizes.extend(&self.hidden_layer_sizes);
        layer_sizes.push(n_outputs);

        // Initialize weights and biases
        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for i in 0..layer_sizes.len() - 1 {
            let input_size = layer_sizes[i];
            let output_size = layer_sizes[i + 1];

            // Xavier/Glorot initialization: std = sqrt(2 / (fan_in + fan_out))
            let scale = (2.0 / (input_size + output_size) as Float).sqrt();
            let normal_dist = RandNormal::new(0.0, scale).unwrap();
            let mut weight_matrix = Array2::<Float>::zeros((output_size, input_size));
            for row in 0..output_size {
                for col in 0..input_size {
                    weight_matrix[[row, col]] = rng.sample(normal_dist);
                }
            }
            let bias_vector = Array1::<Float>::zeros(output_size);

            weights.push(weight_matrix);
            biases.push(bias_vector);
        }

        // Training loop
        let mut loss_curve = Vec::new();
        let X_owned = X.to_owned();
        let y_owned = y.to_owned();

        for epoch in 0..self.max_iter {
            // Forward pass
            let (activations, _) = self.forward_pass(&X_owned, &weights, &biases)?;
            let predictions = activations.last().unwrap();

            // Compute loss
            let loss = self.loss_function.compute_loss(predictions, &y_owned);
            loss_curve.push(loss);

            // Check convergence
            if epoch > 0 && (loss_curve[epoch - 1] - loss).abs() < self.tolerance {
                break;
            }

            // Backward pass
            self.backward_pass(&X_owned, &y_owned, &mut weights, &mut biases)?;
        }

        // Record how many epochs actually ran (convergence may stop training before max_iter)
        let n_iter = loss_curve.len();

        let trained_state = MultiOutputMLPTrained {
            weights,
            biases,
            n_features,
            n_outputs,
            hidden_layer_sizes: self.hidden_layer_sizes.clone(),
            activation: self.activation,
            output_activation: self.output_activation,
            loss_curve,
            n_iter,
        };

        Ok(MultiOutputMLP {
            state: trained_state,
            hidden_layer_sizes: self.hidden_layer_sizes,
            activation: self.activation,
            output_activation: self.output_activation,
            loss_function: self.loss_function,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            alpha: self.alpha,
            batch_size: self.batch_size,
            early_stopping: self.early_stopping,
            validation_fraction: self.validation_fraction,
        })
    }
}

impl MultiOutputMLP<Untrained> {
    /// Forward pass through the network
    #[allow(clippy::type_complexity)]
    fn forward_pass(
        &self,
        X: &Array2<Float>,
        weights: &[Array2<Float>],
        biases: &[Array1<Float>],
    ) -> SklResult<(Vec<Array2<Float>>, Vec<Array2<Float>>)> {
        let mut activations = vec![X.clone()];
        let mut z_values = Vec::new();

        for (i, (weight, bias)) in weights.iter().zip(biases.iter()).enumerate() {
            let current_input = activations.last().unwrap();

            // Linear transformation: z = X * W^T + b
            let z = current_input.dot(&weight.t()) + bias.view().insert_axis(Axis(0));
            z_values.push(z.clone());

            // Apply activation function
            let activation_fn = if i == weights.len() - 1 {
                self.output_activation
            } else {
                self.activation
            };

            let activated = activation_fn.apply_2d(&z);
            activations.push(activated);
        }

        Ok((activations, z_values))
    }

    /// Backward pass with gradient computation
    fn backward_pass(
        &self,
        X: &Array2<Float>,
        y: &Array2<Float>,
        weights: &mut [Array2<Float>],
        biases: &mut [Array1<Float>],
    ) -> SklResult<()> {
        let (activations, z_values) = self.forward_pass(X, weights, biases)?;
        let n_samples = X.nrows() as Float;

        // Compute output layer error
        let output_predictions = activations.last().unwrap();
        let mut delta = output_predictions - y;

        // Backpropagate errors from the output layer down to the first hidden layer
        for i in (0..weights.len()).rev() {
            let current_activation = &activations[i];

            // Compute gradients: dW = delta^T . a_i / n_samples, db = mean(delta)
            let weight_gradient = delta.t().dot(current_activation) / n_samples;
            let bias_gradient = delta.mean_axis(Axis(0)).unwrap();

            // Add L2 regularization to weight gradient
            let regularized_weight_gradient = weight_gradient + self.alpha * &weights[i];

            // Propagate the error to the previous layer *before* updating this layer,
            // so the chain rule uses the pre-update weights:
            //   delta_{i-1} = (delta_i . W_i) * activation'(z_{i-1})
            let next_delta = if i > 0 {
                // Layer i-1 is always a hidden layer, so use the hidden activation's
                // derivative (a basic elementwise approximation for unsupported cases)
                let derivative = match self.activation {
                    ActivationFunction::ReLU => {
                        z_values[i - 1].map(|&val| if val > 0.0 { 1.0 } else { 0.0 })
                    }
                    ActivationFunction::Sigmoid => {
                        // sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)); activations[i] holds sigmoid(z_{i-1})
                        activations[i].map(|&val| val * (1.0 - val))
                    }
                    ActivationFunction::Tanh => {
                        // tanh'(z) = 1 - tanh(z)^2; activations[i] holds tanh(z_{i-1})
                        activations[i].map(|&val| 1.0 - val * val)
                    }
                    _ => Array2::ones(z_values[i - 1].dim()),
                };
                Some(delta.dot(&weights[i]) * derivative)
            } else {
                None
            };

            // Gradient-descent update for this layer
            weights[i] = &weights[i] - self.learning_rate * regularized_weight_gradient;
            biases[i] = &biases[i] - self.learning_rate * bias_gradient;

            if let Some(next) = next_delta {
                delta = next;
            }
        }

        Ok(())
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<Float>> for MultiOutputMLP<MultiOutputMLPTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (_n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let X_owned = X.to_owned();
        let (activations, _) = self.forward_pass_trained(&X_owned)?;
        let predictions = activations.last().unwrap().clone();

        Ok(predictions)
    }
}

impl MultiOutputMLP<MultiOutputMLPTrained> {
    /// Forward pass for trained model
    #[allow(clippy::type_complexity)]
    fn forward_pass_trained(
        &self,
        X: &Array2<Float>,
    ) -> SklResult<(Vec<Array2<Float>>, Vec<Array2<Float>>)> {
        let mut activations = vec![X.clone()];
        let mut z_values = Vec::new();

        for (i, (weight, bias)) in self
            .state
            .weights
            .iter()
            .zip(self.state.biases.iter())
            .enumerate()
        {
            let current_input = activations.last().unwrap();

            // Linear transformation: z = X * W^T + b
            let z = current_input.dot(&weight.t()) + bias.view().insert_axis(Axis(0));
            z_values.push(z.clone());

            // Apply activation function
            let activation_fn = if i == self.state.weights.len() - 1 {
                self.state.output_activation
            } else {
                self.state.activation
            };

            let activated = activation_fn.apply_2d(&z);
            activations.push(activated);
        }

        Ok((activations, z_values))
    }

    /// Get the loss curve from training
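    ///
    /// A small inspection sketch (hypothetical usage; assumes a model trained as in the
    /// [`MultiOutputMLP`] example above):
    ///
    /// ```
    /// # use sklears_multioutput::mlp::MultiOutputMLP;
    /// # use sklears_core::traits::Fit;
    /// # use scirs2_core::ndarray::array;
    /// # // Illustrative setup so there is a trained model to inspect:
    /// # let X = array![[1.0, 2.0], [2.0, 3.0]];
    /// # let y = array![[0.5], [1.0]];
    /// # let trained_mlp = MultiOutputMLP::new().max_iter(10).fit(&X.view(), &y).unwrap();
    /// println!("epochs run: {}", trained_mlp.n_iter());
    /// if let Some(final_loss) = trained_mlp.loss_curve().last() {
    ///     println!("final training loss: {}", final_loss);
    /// }
    /// ```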
    pub fn loss_curve(&self) -> &[Float] {
        &self.state.loss_curve
    }

    /// Get the number of iterations performed during training
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Get the network weights
    pub fn weights(&self) -> &[Array2<Float>] {
        &self.state.weights
    }

    /// Get the network biases
    pub fn biases(&self) -> &[Array1<Float>] {
        &self.state.biases
    }
}

/// Multi-Output MLP Classifier
///
/// This is a specialized version of MultiOutputMLP for classification tasks.
/// It configures the network with a sigmoid output layer and binary cross-entropy loss,
/// which suits multi-label classification (an independent binary decision per output).
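///
/// A minimal multi-label sketch (hypothetical data; assumes the builder API and the
/// `Fit`/`Predict` traits shown in the [`MultiOutputMLP`] example):
///
/// ```
/// use sklears_multioutput::mlp::MultiOutputMLPClassifier;
/// use sklears_core::traits::{Fit, Predict};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]];
/// // Two binary labels per sample, encoded as 0.0 / 1.0 (illustrative targets)
/// let y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]];
///
/// let clf = MultiOutputMLPClassifier::new_classifier()
///     .hidden_layer_sizes(vec![8])
///     .max_iter(200)
///     .random_state(Some(0));
///
/// let trained = clf.fit(&X.view(), &y).unwrap();
/// let probabilities = trained.predict(&X.view()).unwrap(); // sigmoid outputs in (0, 1)
/// assert_eq!(probabilities.dim(), (4, 2));
/// ```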
pub type MultiOutputMLPClassifier<S = Untrained> = MultiOutputMLP<S>;

impl MultiOutputMLPClassifier<Untrained> {
    /// Create a new classifier with appropriate defaults
    pub fn new_classifier() -> Self {
        Self::new()
            .output_activation(ActivationFunction::Sigmoid)
            .loss_function(LossFunction::BinaryCrossEntropy)
    }
}

/// Multi-Output MLP Regressor
///
/// This is a specialized version of MultiOutputMLP for regression tasks.
/// It configures the network with a linear output layer and mean-squared-error loss
/// for multi-output regression.
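///
/// A minimal regression sketch (hypothetical data; mirrors the [`MultiOutputMLP`] example
/// but uses the regressor constructor):
///
/// ```
/// use sklears_multioutput::mlp::MultiOutputMLPRegressor;
/// use sklears_core::traits::{Fit, Predict};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[0.5, 1.2], [1.0, 2.1], [1.5, 0.8], [2.0, 2.5]]; // illustrative targets
///
/// let reg = MultiOutputMLPRegressor::new_regressor()
///     .hidden_layer_sizes(vec![8])
///     .learning_rate(0.01)
///     .max_iter(500)
///     .random_state(Some(0));
///
/// let trained = reg.fit(&X.view(), &y).unwrap();
/// let predictions = trained.predict(&X.view()).unwrap();
/// assert_eq!(predictions.dim(), (4, 2));
/// ```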
pub type MultiOutputMLPRegressor<S = Untrained> = MultiOutputMLP<S>;

impl MultiOutputMLPRegressor<Untrained> {
    /// Create a new regressor with appropriate defaults
    pub fn new_regressor() -> Self {
        Self::new()
            .output_activation(ActivationFunction::Linear)
            .loss_function(LossFunction::MeanSquaredError)
    }
}
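
// A minimal smoke-test sketch (added for illustration; the data, hyperparameters, and
// assertions below are assumptions, not part of an existing test suite).
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Predict};

    #[test]
    fn fit_predict_preserves_output_shape() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
        let y = array![[0.5, 1.2], [1.0, 2.1], [1.5, 0.8], [2.0, 2.5]];

        let model = MultiOutputMLP::new()
            .hidden_layer_sizes(vec![4])
            .max_iter(50)
            .random_state(Some(7));

        let trained = model.fit(&x.view(), &y).unwrap();
        let predictions = trained.predict(&x.view()).unwrap();

        // One row per sample, one column per output.
        assert_eq!(predictions.dim(), (4, 2));
        // The recorded iteration count never exceeds the configured maximum.
        assert!(trained.n_iter() <= 50);
    }
}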