sklears_multioutput/performance.rs

//! Performance Optimization for Multi-Output Learning
//!
//! This module provides optimized algorithms and utilities for improving computational
//! efficiency in multi-output learning scenarios.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

// ============================================================================
// Early Stopping Criteria
// ============================================================================

/// Early stopping configuration
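///
/// # Examples
///
/// A minimal sketch of overriding a couple of fields while keeping the
/// remaining defaults (the values are illustrative, not recommendations):
///
/// ```rust
/// use sklears_multioutput::performance::EarlyStoppingConfig;
///
/// let config = EarlyStoppingConfig {
///     patience: 5,
///     min_delta: 1e-3,
///     ..Default::default()
/// };
/// assert!(!config.mode_max); // by default, lower metric values are better
/// ```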
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Minimum improvement required to continue training
    pub min_delta: Float,
    /// Number of iterations with no improvement before stopping
    pub patience: usize,
    /// Metric to monitor ("loss" or "validation_score")
    pub monitor: String,
    /// Whether higher metric values are better
    pub mode_max: bool,
    /// Restore best weights when stopping
    pub restore_best_weights: bool,
}

impl Default for EarlyStoppingConfig {
    fn default() -> Self {
        Self {
            min_delta: 1e-4,
            patience: 10,
            monitor: "loss".to_string(),
            mode_max: false,
            restore_best_weights: true,
        }
    }
}

/// Early stopping tracker
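///
/// # Examples
///
/// A minimal sketch of driving the tracker from a training loop; the loss
/// values are made up for illustration:
///
/// ```rust
/// use sklears_multioutput::performance::{EarlyStopping, EarlyStoppingConfig};
///
/// let mut es = EarlyStopping::new(EarlyStoppingConfig {
///     patience: 2,
///     min_delta: 0.05,
///     ..Default::default()
/// });
///
/// for (iter, &loss) in [1.0, 0.8, 0.79, 0.78].iter().enumerate() {
///     if es.update(loss, iter) {
///         break; // patience exhausted after two updates without improvement
///     }
/// }
/// assert!(es.should_stop());
/// assert_eq!(es.best_iteration(), 1);
/// ```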
#[derive(Debug, Clone)]
pub struct EarlyStopping {
    config: EarlyStoppingConfig,
    best_value: Option<Float>,
    best_iteration: usize,
    wait_count: usize,
    should_stop: bool,
}

impl EarlyStopping {
    /// Create a new early stopping tracker
    pub fn new(config: EarlyStoppingConfig) -> Self {
        Self {
            config,
            best_value: None,
            best_iteration: 0,
            wait_count: 0,
            should_stop: false,
        }
    }

    /// Update the tracker with a new metric value.
    ///
    /// Returns `true` when patience is exhausted and training should stop.
    pub fn update(&mut self, value: Float, iteration: usize) -> bool {
        match self.best_value {
            None => {
                self.best_value = Some(value);
                self.best_iteration = iteration;
                false
            }
            Some(best) => {
                let is_improvement = if self.config.mode_max {
                    value > best + self.config.min_delta
                } else {
                    value < best - self.config.min_delta
                };

                if is_improvement {
                    self.best_value = Some(value);
                    self.best_iteration = iteration;
                    self.wait_count = 0;
                    false
                } else {
                    self.wait_count += 1;
                    if self.wait_count >= self.config.patience {
                        self.should_stop = true;
                        true
                    } else {
                        false
                    }
                }
            }
        }
    }

    /// Check whether training should stop
    pub fn should_stop(&self) -> bool {
        self.should_stop
    }

    /// Get best value
    pub fn best_value(&self) -> Option<Float> {
        self.best_value
    }

    /// Get best iteration
    pub fn best_iteration(&self) -> usize {
        self.best_iteration
    }
}

// ============================================================================
// Warm Start Multi-Output Regressor
// ============================================================================

/// Configuration for warm start regressor
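///
/// # Examples
///
/// A minimal sketch of tightening the convergence tolerance while keeping the
/// other defaults:
///
/// ```rust
/// use sklears_multioutput::performance::WarmStartRegressorConfig;
///
/// let config = WarmStartRegressorConfig {
///     tol: 1e-6,
///     ..Default::default()
/// };
/// assert_eq!(config.max_iter, 1000);
/// ```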
#[derive(Debug, Clone)]
pub struct WarmStartRegressorConfig {
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Learning rate
    pub learning_rate: Float,
    /// L2 regularization
    pub alpha: Float,
    /// Tolerance for convergence
    pub tol: Float,
    /// Early stopping configuration
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// Verbosity level
    pub verbose: bool,
}

impl Default for WarmStartRegressorConfig {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            learning_rate: 0.01,
            alpha: 0.0001,
            tol: 1e-4,
            early_stopping: Some(EarlyStoppingConfig::default()),
            verbose: false,
        }
    }
}

/// Warm Start Multi-Output Regressor
///
/// Multi-output regressor with warm start capabilities for iterative optimization.
/// Supports resuming training from previous state and early stopping.
///
/// # Examples
///
/// ```rust
/// use sklears_multioutput::performance::{WarmStartRegressor, WarmStartRegressorConfig};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Predict};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
///
/// let mut config = WarmStartRegressorConfig::default();
/// config.max_iter = 100;
///
/// let model = WarmStartRegressor::new().config(config);
/// let trained = model.fit(&X.view(), &y.view()).unwrap();
///
/// // Continue training with warm start
/// let continued = trained.continue_training(&X.view(), &y.view(), 50).unwrap();
///
/// let predictions = continued.predict(&X.view()).unwrap();
/// assert_eq!(predictions.dim(), (3, 2));
/// ```
#[derive(Debug, Clone)]
pub struct WarmStartRegressor<S = Untrained> {
    state: S,
    config: WarmStartRegressorConfig,
}

/// Trained state for Warm Start Regressor
#[derive(Debug, Clone)]
pub struct WarmStartRegressorTrained {
    /// Coefficient matrix
    pub coef: Array2<Float>,
    /// Intercept vector
    pub intercept: Array1<Float>,
    /// Number of features
    pub n_features: usize,
    /// Number of outputs
    pub n_outputs: usize,
    /// Number of iterations performed
    pub n_iter: usize,
    /// Loss history
    pub loss_history: Vec<Float>,
    /// Best loss achieved
    pub best_loss: Float,
    /// Best iteration
    pub best_iter: usize,
    /// Best coefficients (if early stopping enabled)
    pub best_coef: Option<Array2<Float>>,
    /// Best intercept (if early stopping enabled)
    pub best_intercept: Option<Array1<Float>>,
    /// Whether converged
    pub converged: bool,
    /// Configuration
    pub config: WarmStartRegressorConfig,
}

impl WarmStartRegressor<Untrained> {
    /// Create a new warm start regressor
    pub fn new() -> Self {
        Self {
            state: Untrained,
            config: WarmStartRegressorConfig::default(),
        }
    }

    /// Set the configuration
    pub fn config(mut self, config: WarmStartRegressorConfig) -> Self {
        self.config = config;
        self
    }

    /// Set maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    /// Set learning rate
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.config.learning_rate = lr;
        self
    }

    /// Enable early stopping
    pub fn early_stopping(mut self, config: EarlyStoppingConfig) -> Self {
        self.config.early_stopping = Some(config);
        self
    }
}

impl Default for WarmStartRegressor<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for WarmStartRegressor<Untrained> {
    type Fitted = WarmStartRegressor<WarmStartRegressorTrained>;

    fn fit(self, X: &ArrayView2<Float>, y: &ArrayView2<Float>) -> SklResult<Self::Fitted> {
        if X.nrows() != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        let n_samples = X.nrows();
        let n_features = X.ncols();
        let n_outputs = y.ncols();

        // Initialize coefficients
        let mut coef = Array2::zeros((n_features, n_outputs));
        let mut intercept = Array1::zeros(n_outputs);

        let mut loss_history = Vec::new();
        let mut best_loss = Float::INFINITY;
        let mut best_iter = 0;
        let mut best_coef = None;
        let mut best_intercept = None;

        let mut early_stopping = self
            .config
            .early_stopping
            .as_ref()
            .map(|cfg| EarlyStopping::new(cfg.clone()));

        let mut converged = false;

        // Gradient descent with early stopping
        for iter in 0..self.config.max_iter {
            let mut total_loss = 0.0;

            // Compute predictions and gradients
            for i in 0..n_samples {
                let x_i = X.row(i);
                let y_i = y.row(i);

                // Prediction
                let pred = coef.t().dot(&x_i) + &intercept;

                // Error
                let error = &y_i - &pred;
                total_loss += error.mapv(|x| x.powi(2)).sum();

                // Update coefficients
                for j in 0..n_features {
                    for k in 0..n_outputs {
                        let gradient = -error[k] * x_i[j] + self.config.alpha * coef[[j, k]];
                        coef[[j, k]] -= self.config.learning_rate * gradient;
                    }
                }

                // Update intercept
                for k in 0..n_outputs {
                    intercept[k] += self.config.learning_rate * error[k];
                }
            }

            // Average loss
            let avg_loss = total_loss / (n_samples as Float * n_outputs as Float);
            loss_history.push(avg_loss);

            // Track best model
            if avg_loss < best_loss {
                best_loss = avg_loss;
                best_iter = iter;
                if self.config.early_stopping.is_some() {
                    best_coef = Some(coef.clone());
                    best_intercept = Some(intercept.clone());
                }
            }

            // Check convergence
            if iter > 0 && (loss_history[iter - 1] - avg_loss).abs() < self.config.tol {
                converged = true;
                if self.config.verbose {
                    println!("Converged at iteration {}", iter);
                }
                break;
            }

            // Early stopping
            if let Some(ref mut es) = early_stopping {
                if es.update(avg_loss, iter) {
                    if self.config.verbose {
                        println!("Early stopping at iteration {}", iter);
                    }
                    break;
                }
            }

            if self.config.verbose && iter % 100 == 0 {
                println!("Iteration {}: loss = {:.6}", iter, avg_loss);
            }
        }

        // Restore best weights if early stopping is enabled
        if let Some(cfg) = &self.config.early_stopping {
            if cfg.restore_best_weights {
                if let Some(ref best_c) = best_coef {
                    coef = best_c.clone();
                }
                if let Some(ref best_i) = best_intercept {
                    intercept = best_i.clone();
                }
            }
        }

        // Keep the actual training configuration on the trained model instead
        // of resetting it to defaults.
        let config = self.config.clone();
        Ok(WarmStartRegressor {
            state: WarmStartRegressorTrained {
                coef,
                intercept,
                n_features,
                n_outputs,
                n_iter: loss_history.len(),
                loss_history,
                best_loss,
                best_iter,
                best_coef,
                best_intercept,
                converged,
                config: self.config,
            },
            config,
        })
    }
}

impl WarmStartRegressor<WarmStartRegressorTrained> {
    /// Continue training from the current coefficients for up to
    /// `additional_iterations` more iterations (warm start)
    pub fn continue_training(
        mut self,
        X: &ArrayView2<Float>,
        y: &ArrayView2<Float>,
        additional_iterations: usize,
    ) -> SklResult<Self> {
        if X.nrows() != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        if X.ncols() != self.state.n_features || y.ncols() != self.state.n_outputs {
            return Err(SklearsError::InvalidInput(
                "Feature or output dimensions do not match".to_string(),
            ));
        }

        let n_samples = X.nrows();

        let mut early_stopping = self
            .state
            .config
            .early_stopping
            .as_ref()
            .map(|cfg| EarlyStopping::new(cfg.clone()));

        // Continue from where we left off
        for iter in 0..additional_iterations {
            let mut total_loss = 0.0;

            // Gradient descent step
            for i in 0..n_samples {
                let x_i = X.row(i);
                let y_i = y.row(i);

                let pred = self.state.coef.t().dot(&x_i) + &self.state.intercept;
                let error = &y_i - &pred;
                total_loss += error.mapv(|x| x.powi(2)).sum();

                // Update coefficients
                for j in 0..self.state.n_features {
                    for k in 0..self.state.n_outputs {
                        let gradient =
                            -error[k] * x_i[j] + self.state.config.alpha * self.state.coef[[j, k]];
                        self.state.coef[[j, k]] -= self.state.config.learning_rate * gradient;
                    }
                }

                // Update intercept
                for k in 0..self.state.n_outputs {
                    self.state.intercept[k] += self.state.config.learning_rate * error[k];
                }
            }

            let avg_loss = total_loss / (n_samples as Float * self.state.n_outputs as Float);
            self.state.loss_history.push(avg_loss);

            // Update best
            if avg_loss < self.state.best_loss {
                self.state.best_loss = avg_loss;
                self.state.best_iter = self.state.n_iter + iter;
                if self.state.config.early_stopping.is_some() {
                    self.state.best_coef = Some(self.state.coef.clone());
                    self.state.best_intercept = Some(self.state.intercept.clone());
                }
            }

            // Check convergence
            let loss_len = self.state.loss_history.len();
            if loss_len > 1 {
                let prev_loss = self.state.loss_history[loss_len - 2];
                if (prev_loss - avg_loss).abs() < self.state.config.tol {
                    self.state.converged = true;
                    break;
                }
            }

            // Early stopping
            if let Some(ref mut es) = early_stopping {
                if es.update(avg_loss, self.state.n_iter + iter) {
                    break;
                }
            }
        }

        // Record the number of iterations actually performed; the loop above
        // may exit early on convergence or early stopping, so counting the
        // loss history is exact while adding `additional_iterations` is not.
        self.state.n_iter = self.state.loss_history.len();
        Ok(self)
    }

    /// Get training history
    pub fn loss_history(&self) -> &[Float] {
        &self.state.loss_history
    }

    /// Get best loss
    pub fn best_loss(&self) -> Float {
        self.state.best_loss
    }

    /// Check if converged
    pub fn converged(&self) -> bool {
        self.state.converged
    }

    /// Get coefficients
    pub fn coef(&self) -> &Array2<Float> {
        &self.state.coef
    }

    /// Get number of iterations performed
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<Float>>
    for WarmStartRegressor<WarmStartRegressorTrained>
{
    fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Expected {} features, got {}",
                self.state.n_features,
                X.ncols()
            )));
        }

        let n_samples = X.nrows();
        let mut predictions = Array2::zeros((n_samples, self.state.n_outputs));

        for i in 0..n_samples {
            let x_i = X.row(i);
            let pred = self.state.coef.t().dot(&x_i) + &self.state.intercept;
            predictions.row_mut(i).assign(&pred);
        }

        Ok(predictions)
    }
}

impl Estimator for WarmStartRegressor<Untrained> {
    type Config = WarmStartRegressorConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Estimator for WarmStartRegressor<WarmStartRegressorTrained> {
    type Config = WarmStartRegressorConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.state.config
    }
}

// ============================================================================
// Fast Prediction Cache
// ============================================================================

/// Prediction cache for fast repeated predictions
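///
/// # Examples
///
/// A minimal sketch of a cache-then-predict pattern; the cached array below is
/// a placeholder rather than real model output:
///
/// ```rust
/// use sklears_multioutput::performance::PredictionCache;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let mut cache = PredictionCache::new(16);
/// let X = array![[1.0, 2.0]];
///
/// if cache.get(&X.view()).is_none() {
///     // In real code this would come from a trained model's `predict`.
///     cache.put(&X.view(), array![[0.5, 1.5]]);
/// }
/// assert!(cache.get(&X.view()).is_some());
///
/// let (hits, misses, hit_rate) = cache.stats();
/// assert_eq!((hits, misses), (1, 1));
/// assert!(hit_rate > 0.0);
/// ```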
#[derive(Debug, Clone)]
pub struct PredictionCache {
    /// Cached predictions keyed by input hash
    cache: HashMap<u64, Array2<Float>>,
    /// Maximum cache size
    max_size: usize,
    /// Number of cache hits
    hits: usize,
    /// Number of cache misses
    misses: usize,
}

impl PredictionCache {
    /// Create a new prediction cache
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_size,
            hits: 0,
            misses: 0,
        }
    }

    /// Get cached prediction
    pub fn get(&mut self, X: &ArrayView2<Float>) -> Option<Array2<Float>> {
        let hash = self.hash_input(X);
        if let Some(pred) = self.cache.get(&hash) {
            self.hits += 1;
            Some(pred.clone())
        } else {
            self.misses += 1;
            None
        }
    }

    /// Store prediction in cache
    pub fn put(&mut self, X: &ArrayView2<Float>, prediction: Array2<Float>) {
        let hash = self.hash_input(X);
        // Only evict when inserting a new key; overwriting an existing key does
        // not grow the cache. HashMap iteration order is unspecified, so this
        // evicts an arbitrary entry rather than the oldest one.
        if !self.cache.contains_key(&hash) && self.cache.len() >= self.max_size {
            if let Some(evict_key) = self.cache.keys().next().copied() {
                self.cache.remove(&evict_key);
            }
        }
        self.cache.insert(hash, prediction);
    }

    /// Clear cache
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Get cache statistics as `(hits, misses, hit_rate)`
    pub fn stats(&self) -> (usize, usize, Float) {
        let total = self.hits + self.misses;
        let hit_rate = if total > 0 {
            self.hits as Float / total as Float
        } else {
            0.0
        };
        (self.hits, self.misses, hit_rate)
    }

    /// Hash the shape and the bit patterns of all input elements
    /// (collisions are theoretically possible but unlikely)
    fn hash_input(&self, X: &ArrayView2<Float>) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        // Include the dimensions so inputs with the same elements but
        // different shapes do not collide.
        X.dim().hash(&mut hasher);
        for &val in X.iter() {
            val.to_bits().hash(&mut hasher);
        }
        hasher.finish()
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;

    #[test]
    fn test_early_stopping_basic() {
        let config = EarlyStoppingConfig {
            min_delta: 0.1,
            patience: 3,
            mode_max: false,
            ..Default::default()
        };

        let mut es = EarlyStopping::new(config);

        assert!(!es.update(1.0, 0));
        assert!(!es.update(0.8, 1)); // Improvement (1.0 - 0.8 = 0.2 > min_delta)
        assert!(!es.update(0.79, 2)); // No improvement #1 (0.8 - 0.79 = 0.01 < min_delta)
        assert!(!es.update(0.78, 3)); // No improvement #2
        assert!(es.update(0.77, 4)); // No improvement #3, should stop after patience (3)
    }

    #[test]
    fn test_early_stopping_mode_max() {
        let config = EarlyStoppingConfig {
            min_delta: 0.01,
            patience: 2,
            mode_max: true,
            ..Default::default()
        };

        let mut es = EarlyStopping::new(config);

        assert!(!es.update(0.5, 0));
        assert!(!es.update(0.6, 1)); // Improvement
        assert!(!es.update(0.59, 2)); // No improvement
        assert!(es.update(0.58, 3)); // Should stop
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_warm_start_regressor_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let model = WarmStartRegressor::new().max_iter(100).learning_rate(0.1);

        let trained = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (3, 2));
        assert!(trained.n_iter() > 0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_warm_start_continue_training() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0]];

        let model = WarmStartRegressor::new().max_iter(10).learning_rate(0.1);

        let trained = model.fit(&X.view(), &y.view()).unwrap();
        let initial_iter = trained.n_iter();
        let initial_loss = trained.loss_history().last().copied().unwrap();

        // Continue training
        let continued = trained.continue_training(&X.view(), &y.view(), 20).unwrap();
        let final_loss = continued.loss_history().last().copied().unwrap();

        assert!(continued.n_iter() > initial_iter);
        // Loss should generally decrease (or stay similar)
        assert!(final_loss <= initial_loss + 1.0); // Allow some tolerance
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_warm_start_with_early_stopping() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let es_config = EarlyStoppingConfig {
            patience: 5,
            min_delta: 1e-6,
            ..Default::default()
        };

        let model = WarmStartRegressor::new()
            .max_iter(1000)
            .early_stopping(es_config)
            .learning_rate(0.1);

        let trained = model.fit(&X.view(), &y.view()).unwrap();

        // Should stop early due to convergence
        assert!(trained.n_iter() < 1000);
        assert!(trained.best_loss() < Float::INFINITY);
    }

    #[test]
    fn test_prediction_cache_basic() {
        let mut cache = PredictionCache::new(10);

        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let pred = array![[1.0, 2.0], [2.0, 3.0]];

        // Cache miss
        assert!(cache.get(&X.view()).is_none());

        // Store and retrieve
        cache.put(&X.view(), pred.clone());
        let cached = cache.get(&X.view()).unwrap();

        assert_eq!(cached.dim(), pred.dim());
        assert_eq!(cache.stats().0, 1); // 1 hit
        assert_eq!(cache.stats().1, 1); // 1 miss
    }

    #[test]
    fn test_prediction_cache_eviction() {
        let mut cache = PredictionCache::new(2);

        let X1 = array![[1.0, 2.0]];
        let X2 = array![[2.0, 3.0]];
        let X3 = array![[3.0, 4.0]];
        let pred = array![[1.0, 2.0]];

        cache.put(&X1.view(), pred.clone());
        cache.put(&X2.view(), pred.clone());
        cache.put(&X3.view(), pred.clone()); // Evicts an arbitrary entry to stay within capacity

        assert_eq!(cache.cache.len(), 2);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = PredictionCache::new(10);

        let X = array![[1.0, 2.0]];
        let pred = array![[1.0, 2.0]];

        cache.get(&X.view()); // miss
        cache.put(&X.view(), pred);
        cache.get(&X.view()); // hit
        cache.get(&X.view()); // hit

        let (hits, misses, hit_rate) = cache.stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
        assert_abs_diff_eq!(hit_rate, 2.0 / 3.0, epsilon = 1e-6);
    }
}