Skip to main content

oxirs_stream/
automl_stream.rs

//! # AutoML for Stream Processing
//!
//! This module provides automated machine learning capabilities for streaming data,
//! including automatic algorithm selection, hyperparameter search, and model
//! ensembling with minimal manual intervention.
//!
//! ## Features
//! - Automatic algorithm selection from a pool of task-specific candidates
//! - Hyperparameter search by uniform random sampling (Bayesian optimization is not yet implemented)
//! - Adaptive model selection based on data drift (not yet implemented)
//! - Ensemble methods for improved robustness
//! - Online performance tracking and model swapping
//! - Meta-learning hook for warm-starting new tasks (currently a placeholder)
//!
//! Note: model training and fold evaluation are currently simplified placeholders;
//! cross-validation scores and metrics are generated randomly rather than computed
//! from the data.
//!
//! ## Example Usage
//! ```rust,ignore
//! use oxirs_stream::automl_stream::{AutoML, AutoMLConfig, TaskType};
//!
//! let config = AutoMLConfig {
//!     task_type: TaskType::Classification,
//!     max_training_time_secs: 300,
//!     ..Default::default()
//! };
//!
//! // `features`: Array2<f64> with one sample per row; `labels`: Array1<f64>, one per row.
//! let mut automl = AutoML::new(config)?;
//! automl.fit(&features, &labels).await?;
//! // `sample`: Array1<f64> holding the features of a single observation.
//! let prediction = automl.predict(&sample).await?;
//! ```

30use anyhow::{anyhow, Result};
31use scirs2_core::ndarray_ext::{Array1, Array2};
32use scirs2_core::random::{Random, Rng};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::sync::Arc;
36use tokio::sync::{Mutex, RwLock};
37use tracing::info;
38
39/// Machine learning task type
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
41pub enum TaskType {
42    /// Binary or multi-class classification
43    Classification,
44    /// Regression (continuous values)
45    Regression,
46    /// Time series forecasting
47    TimeSeries,
48    /// Anomaly detection
49    AnomalyDetection,
50    /// Clustering
51    Clustering,
52}
53
54/// Algorithm candidates for AutoML
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
56pub enum Algorithm {
57    /// Linear regression
58    LinearRegression,
59    /// Logistic regression
60    LogisticRegression,
61    /// Decision tree
62    DecisionTree,
63    /// Random forest
64    RandomForest,
65    /// Gradient boosting
66    GradientBoosting,
67    /// Neural network
68    NeuralNetwork,
69    /// K-Nearest Neighbors
70    KNN,
71    /// Support Vector Machine
72    SVM,
73    /// Naive Bayes
74    NaiveBayes,
75    /// Online learning (SGD)
76    OnlineSGD,
77    /// ARIMA for time series
78    ARIMA,
79    /// Isolation Forest for anomaly detection
80    IsolationForest,
81    /// K-Means for clustering
82    KMeans,
83}
84
85impl Algorithm {
86    /// Get compatible algorithms for a task type
87    pub fn for_task(task: TaskType) -> Vec<Algorithm> {
88        match task {
89            TaskType::Classification => vec![
90                Algorithm::LogisticRegression,
91                Algorithm::DecisionTree,
92                Algorithm::RandomForest,
93                Algorithm::GradientBoosting,
94                Algorithm::NeuralNetwork,
95                Algorithm::KNN,
96                Algorithm::NaiveBayes,
97            ],
98            TaskType::Regression => vec![
99                Algorithm::LinearRegression,
100                Algorithm::DecisionTree,
101                Algorithm::RandomForest,
102                Algorithm::GradientBoosting,
103                Algorithm::NeuralNetwork,
104                Algorithm::KNN,
105                Algorithm::SVM,
106            ],
107            TaskType::TimeSeries => vec![
108                Algorithm::ARIMA,
109                Algorithm::LinearRegression,
110                Algorithm::NeuralNetwork,
111                Algorithm::GradientBoosting,
112            ],
113            TaskType::AnomalyDetection => vec![
114                Algorithm::IsolationForest,
115                Algorithm::OnlineSGD,
116                Algorithm::NeuralNetwork,
117            ],
118            TaskType::Clustering => vec![Algorithm::KMeans],
119        }
120    }
121}
122
123/// Hyperparameter configuration
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct HyperParameters {
126    /// Learning rate
127    pub learning_rate: f64,
128    /// Number of estimators (trees, epochs, etc.)
129    pub n_estimators: usize,
130    /// Maximum depth (for tree-based models)
131    pub max_depth: Option<usize>,
132    /// Regularization strength
133    pub regularization: f64,
134    /// Number of neighbors (for KNN)
135    pub n_neighbors: usize,
136    /// Batch size (for neural networks)
137    pub batch_size: usize,
138    /// Random seed
139    pub random_seed: u64,
140}
141
142impl Default for HyperParameters {
143    fn default() -> Self {
144        Self {
145            learning_rate: 0.01,
146            n_estimators: 100,
147            max_depth: Some(5),
148            regularization: 0.1,
149            n_neighbors: 5,
150            batch_size: 32,
151            random_seed: 42,
152        }
153    }
154}
155
156/// Model performance metrics
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct ModelPerformance {
159    /// Algorithm used
160    pub algorithm: Algorithm,
161    /// Hyperparameters
162    pub hyperparameters: HyperParameters,
163    /// Accuracy (for classification)
164    pub accuracy: Option<f64>,
165    /// Precision
166    pub precision: Option<f64>,
167    /// Recall
168    pub recall: Option<f64>,
169    /// F1 score
170    pub f1_score: Option<f64>,
171    /// Mean squared error (for regression)
172    pub mse: Option<f64>,
173    /// R² score
174    pub r_squared: Option<f64>,
175    /// Training time (seconds)
176    pub training_time_secs: f64,
177    /// Inference time (milliseconds)
178    pub inference_time_ms: f64,
179    /// Model complexity score
180    pub complexity_score: f64,
181    /// Cross-validation score
182    pub cv_score: f64,
183}
184
185impl ModelPerformance {
186    /// Get overall score for model selection
187    pub fn overall_score(&self) -> f64 {
188        // Weighted combination of metrics
189        let perf_score = self.cv_score;
190        let time_penalty = (-self.training_time_secs / 60.0).exp(); // Penalize long training
191        let complexity_penalty = (-self.complexity_score / 100.0).exp(); // Penalize complexity
192
193        perf_score * time_penalty * complexity_penalty
194    }
195}
196
197/// AutoML configuration
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct AutoMLConfig {
200    /// Task type
201    pub task_type: TaskType,
202    /// Maximum training time (seconds) for AutoML search
203    pub max_training_time_secs: u64,
204    /// Number of hyperparameter optimization trials
205    pub n_trials: usize,
206    /// Cross-validation folds
207    pub cv_folds: usize,
208    /// Enable ensemble methods
209    pub enable_ensemble: bool,
210    /// Enable meta-learning
211    pub enable_meta_learning: bool,
212    /// Early stopping patience
213    pub early_stopping_patience: usize,
214    /// Metric to optimize
215    pub optimization_metric: String,
216    /// Enable automatic feature engineering
217    pub auto_feature_engineering: bool,
218    /// Maximum number of models to keep in ensemble
219    pub max_ensemble_size: usize,
220}
221
222impl Default for AutoMLConfig {
223    fn default() -> Self {
224        Self {
225            task_type: TaskType::Classification,
226            max_training_time_secs: 600,
227            n_trials: 50,
228            cv_folds: 5,
229            enable_ensemble: true,
230            enable_meta_learning: false,
231            early_stopping_patience: 10,
232            optimization_metric: "cv_score".to_string(),
233            auto_feature_engineering: true,
234            max_ensemble_size: 5,
235        }
236    }
237}
238
239/// Trained model representation
240#[derive(Debug, Clone)]
241pub struct TrainedModel {
242    /// Algorithm used
243    pub algorithm: Algorithm,
244    /// Hyperparameters
245    pub hyperparameters: HyperParameters,
246    /// Model weights/parameters
247    pub parameters: ModelParameters,
248    /// Performance metrics
249    pub performance: ModelPerformance,
250}
251
/// Learned parameters of a model (simplified).
///
/// The engine evaluates these as a linear model: a weight vector plus a bias.
/// `extra` holds named, algorithm-specific parameter vectors.
///
/// `Default` yields an empty weight vector, a zero bias, and no extra
/// parameters.
#[derive(Debug, Clone, Default)]
pub struct ModelParameters {
    /// Weight vector.
    pub weights: Vec<f64>,
    /// Bias term.
    pub bias: f64,
    /// Additional parameters (algorithm-specific).
    pub extra: HashMap<String, Vec<f64>>,
}

273/// AutoML statistics
274#[derive(Debug, Clone, Serialize, Deserialize)]
275pub struct AutoMLStats {
276    /// Total trials executed
277    pub total_trials: u64,
278    /// Best model score
279    pub best_score: f64,
280    /// Total training time (seconds)
281    pub total_training_time_secs: f64,
282    /// Number of models in ensemble
283    pub ensemble_size: usize,
284    /// Current best algorithm
285    pub best_algorithm: Option<Algorithm>,
286    /// Number of predictions made
287    pub predictions_count: u64,
288    /// Average prediction time (ms)
289    pub avg_prediction_time_ms: f64,
290}
291
292impl Default for AutoMLStats {
293    fn default() -> Self {
294        Self {
295            total_trials: 0,
296            best_score: 0.0,
297            total_training_time_secs: 0.0,
298            ensemble_size: 0,
299            best_algorithm: None,
300            predictions_count: 0,
301            avg_prediction_time_ms: 0.0,
302        }
303    }
304}
305
306/// Main AutoML engine
307pub struct AutoML {
308    config: AutoMLConfig,
309    /// Best model found
310    best_model: Arc<RwLock<Option<TrainedModel>>>,
311    /// Ensemble of models
312    ensemble: Arc<RwLock<Vec<TrainedModel>>>,
313    /// Trial history
314    trial_history: Arc<RwLock<Vec<ModelPerformance>>>,
315    /// Statistics
316    stats: Arc<RwLock<AutoMLStats>>,
317    /// Random number generator
318    #[allow(clippy::arc_with_non_send_sync)]
319    rng: Arc<Mutex<Random>>,
320}
321
322impl AutoML {
323    /// Create a new AutoML instance
324    #[allow(clippy::arc_with_non_send_sync)]
325    pub fn new(config: AutoMLConfig) -> Result<Self> {
326        Ok(Self {
327            config,
328            best_model: Arc::new(RwLock::new(None)),
329            ensemble: Arc::new(RwLock::new(Vec::new())),
330            trial_history: Arc::new(RwLock::new(Vec::new())),
331            stats: Arc::new(RwLock::new(AutoMLStats::default())),
332            rng: Arc::new(Mutex::new(Random::default())),
333        })
334    }
335
336    /// Fit AutoML on training data
337    pub async fn fit(&mut self, features: &Array2<f64>, labels: &Array1<f64>) -> Result<()> {
338        info!(
339            "Starting AutoML training with task {:?}, {} samples, {} features",
340            self.config.task_type,
341            features.shape()[0],
342            features.shape()[1]
343        );
344
345        let start_time = std::time::Instant::now();
346        let candidate_algorithms = Algorithm::for_task(self.config.task_type);
347
348        let mut best_overall_score = f64::NEG_INFINITY;
349        let mut trials_without_improvement = 0;
350
351        for trial in 0..self.config.n_trials {
352            // Check time budget
353            if start_time.elapsed().as_secs() >= self.config.max_training_time_secs {
354                info!("Time budget exhausted, stopping AutoML");
355                break;
356            }
357
358            // Select algorithm
359            let algorithm = {
360                let mut rng = self.rng.lock().await;
361                let idx = rng.random_range(0..candidate_algorithms.len());
362                candidate_algorithms[idx]
363            };
364
365            // Generate hyperparameters
366            let hyperparams = self.generate_hyperparameters(algorithm).await?;
367
368            // Train and evaluate model
369            let performance = self
370                .train_and_evaluate(algorithm, &hyperparams, features, labels)
371                .await?;
372
373            // Record trial
374            self.trial_history.write().await.push(performance.clone());
375
376            let overall_score = performance.overall_score();
377
378            info!(
379                "Trial {}: {:?} - CV score: {:.4}, Overall score: {:.4}",
380                trial, algorithm, performance.cv_score, overall_score
381            );
382
383            // Update best model
384            if overall_score > best_overall_score {
385                best_overall_score = overall_score;
386                trials_without_improvement = 0;
387
388                let model = TrainedModel {
389                    algorithm,
390                    hyperparameters: hyperparams.clone(),
391                    parameters: self
392                        .train_final_model(algorithm, &hyperparams, features, labels)
393                        .await?,
394                    performance: performance.clone(),
395                };
396
397                *self.best_model.write().await = Some(model.clone());
398
399                // Update ensemble if enabled
400                if self.config.enable_ensemble {
401                    self.update_ensemble(model).await?;
402                }
403
404                // Update stats
405                let mut stats = self.stats.write().await;
406                stats.best_score = best_overall_score;
407                stats.best_algorithm = Some(algorithm);
408            } else {
409                trials_without_improvement += 1;
410            }
411
412            // Early stopping
413            if trials_without_improvement >= self.config.early_stopping_patience {
414                info!(
415                    "Early stopping triggered after {} trials without improvement",
416                    trials_without_improvement
417                );
418                break;
419            }
420
421            // Update stats
422            let mut stats = self.stats.write().await;
423            stats.total_trials = trial as u64 + 1;
424        }
425
426        // Final stats update
427        let mut stats = self.stats.write().await;
428        stats.total_training_time_secs = start_time.elapsed().as_secs_f64();
429        stats.ensemble_size = self.ensemble.read().await.len();
430
431        info!(
432            "AutoML training complete: {} trials, best score: {:.4}, algorithm: {:?}",
433            stats.total_trials, stats.best_score, stats.best_algorithm
434        );
435
436        Ok(())
437    }
438
439    /// Generate hyperparameters for an algorithm
440    async fn generate_hyperparameters(&self, algorithm: Algorithm) -> Result<HyperParameters> {
441        let mut rng = self.rng.lock().await;
442
443        // Use meta-learning to initialize if enabled
444        let _base = if self.config.enable_meta_learning {
445            self.get_meta_learning_initialization(algorithm).await
446        } else {
447            HyperParameters::default()
448        };
449
450        // Apply random perturbations
451        Ok(HyperParameters {
452            learning_rate: rng.random_range(0.0001..0.1),
453            n_estimators: rng.random_range(10..500),
454            max_depth: Some(rng.random_range(3..20)),
455            regularization: rng.random_range(0.0..1.0),
456            n_neighbors: rng.random_range(3..20),
457            batch_size: rng.random_range(16..256),
458            random_seed: rng.random::<u64>(),
459        })
460    }
461
462    /// Get meta-learning initialization (placeholder)
463    async fn get_meta_learning_initialization(&self, _algorithm: Algorithm) -> HyperParameters {
464        // In production, this would use historical performance data
465        HyperParameters::default()
466    }
467
468    /// Train and evaluate a model with cross-validation
469    async fn train_and_evaluate(
470        &self,
471        algorithm: Algorithm,
472        hyperparams: &HyperParameters,
473        features: &Array2<f64>,
474        labels: &Array1<f64>,
475    ) -> Result<ModelPerformance> {
476        let start_time = std::time::Instant::now();
477
478        // Perform cross-validation
479        let cv_scores = self
480            .cross_validate(algorithm, hyperparams, features, labels)
481            .await?;
482        let cv_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
483
484        // Compute additional metrics
485        let (accuracy, precision, recall, f1, mse, r_squared) = self
486            .compute_metrics(algorithm, hyperparams, features, labels)
487            .await?;
488
489        let training_time = start_time.elapsed().as_secs_f64();
490
491        // Estimate complexity (simplified)
492        let complexity_score = match algorithm {
493            Algorithm::LinearRegression | Algorithm::LogisticRegression => 10.0,
494            Algorithm::DecisionTree => 30.0,
495            Algorithm::RandomForest | Algorithm::GradientBoosting => 60.0,
496            Algorithm::NeuralNetwork => 80.0,
497            _ => 40.0,
498        };
499
500        Ok(ModelPerformance {
501            algorithm,
502            hyperparameters: hyperparams.clone(),
503            accuracy,
504            precision,
505            recall,
506            f1_score: f1,
507            mse,
508            r_squared,
509            training_time_secs: training_time,
510            inference_time_ms: 1.0, // Placeholder
511            complexity_score,
512            cv_score,
513        })
514    }
515
516    /// Perform k-fold cross-validation
517    async fn cross_validate(
518        &self,
519        algorithm: Algorithm,
520        hyperparams: &HyperParameters,
521        features: &Array2<f64>,
522        labels: &Array1<f64>,
523    ) -> Result<Vec<f64>> {
524        let n_samples = features.shape()[0];
525        let fold_size = n_samples / self.config.cv_folds;
526
527        let mut scores = Vec::new();
528
529        for fold in 0..self.config.cv_folds {
530            let val_start = fold * fold_size;
531            let val_end = ((fold + 1) * fold_size).min(n_samples);
532
533            // Simple train/val split (in production, use proper indexing)
534            let score = self
535                .evaluate_fold(algorithm, hyperparams, features, labels, val_start, val_end)
536                .await?;
537            scores.push(score);
538        }
539
540        Ok(scores)
541    }
542
543    /// Evaluate a single fold
544    async fn evaluate_fold(
545        &self,
546        _algorithm: Algorithm,
547        _hyperparams: &HyperParameters,
548        _features: &Array2<f64>,
549        _labels: &Array1<f64>,
550        _val_start: usize,
551        _val_end: usize,
552    ) -> Result<f64> {
553        // Simplified evaluation - train on all data except validation fold
554        // and evaluate on validation fold
555
556        // For simplicity, return a random score
557        // In production, actually train and evaluate
558        let mut rng = self.rng.lock().await;
559        Ok(0.7 + rng.random::<f64>() * 0.3) // Score between 0.7 and 1.0
560    }
561
562    /// Compute various performance metrics
563    async fn compute_metrics(
564        &self,
565        _algorithm: Algorithm,
566        _hyperparams: &HyperParameters,
567        _features: &Array2<f64>,
568        _labels: &Array1<f64>,
569    ) -> Result<(
570        Option<f64>,
571        Option<f64>,
572        Option<f64>,
573        Option<f64>,
574        Option<f64>,
575        Option<f64>,
576    )> {
577        // Simplified metrics computation
578        let mut rng = self.rng.lock().await;
579
580        match self.config.task_type {
581            TaskType::Classification => {
582                let accuracy = Some(0.7 + rng.random::<f64>() * 0.3);
583                let precision = Some(0.7 + rng.random::<f64>() * 0.3);
584                let recall = Some(0.7 + rng.random::<f64>() * 0.3);
585                let f1 = Some(0.7 + rng.random::<f64>() * 0.3);
586                Ok((accuracy, precision, recall, f1, None, None))
587            }
588            TaskType::Regression | TaskType::TimeSeries => {
589                let mse = Some(0.1 + rng.random::<f64>() * 0.9);
590                let r_squared = Some(0.5 + rng.random::<f64>() * 0.5);
591                Ok((None, None, None, None, mse, r_squared))
592            }
593            _ => Ok((None, None, None, None, None, None)),
594        }
595    }
596
597    /// Train final model with best hyperparameters
598    async fn train_final_model(
599        &self,
600        _algorithm: Algorithm,
601        _hyperparams: &HyperParameters,
602        features: &Array2<f64>,
603        _labels: &Array1<f64>,
604    ) -> Result<ModelParameters> {
605        // Simplified model training - just create placeholder parameters
606        let n_features = features.shape()[1];
607
608        let mut rng = self.rng.lock().await;
609        let weights: Vec<f64> = (0..n_features).map(|_| rng.random::<f64>() - 0.5).collect();
610        let bias = rng.random::<f64>() - 0.5;
611
612        Ok(ModelParameters {
613            weights,
614            bias,
615            extra: HashMap::new(),
616        })
617    }
618
619    /// Update ensemble with new model
620    async fn update_ensemble(&self, model: TrainedModel) -> Result<()> {
621        let mut ensemble = self.ensemble.write().await;
622
623        // Add model to ensemble
624        ensemble.push(model);
625
626        // Keep only top models
627        if ensemble.len() > self.config.max_ensemble_size {
628            ensemble.sort_by(|a, b| {
629                b.performance
630                    .overall_score()
631                    .partial_cmp(&a.performance.overall_score())
632                    .unwrap_or(std::cmp::Ordering::Equal)
633            });
634            ensemble.truncate(self.config.max_ensemble_size);
635        }
636
637        Ok(())
638    }
639
640    /// Make prediction using the best model or ensemble
641    pub async fn predict(&self, features: &Array1<f64>) -> Result<f64> {
642        let start_time = std::time::Instant::now();
643
644        let prediction = if self.config.enable_ensemble {
645            self.ensemble_predict(features).await?
646        } else {
647            self.single_model_predict(features).await?
648        };
649
650        // Update stats
651        let mut stats = self.stats.write().await;
652        stats.predictions_count += 1;
653        let elapsed_ms = start_time.elapsed().as_secs_f64() * 1000.0;
654        stats.avg_prediction_time_ms =
655            (stats.avg_prediction_time_ms * (stats.predictions_count - 1) as f64 + elapsed_ms)
656                / stats.predictions_count as f64;
657
658        Ok(prediction)
659    }
660
661    /// Predict using single best model
662    async fn single_model_predict(&self, features: &Array1<f64>) -> Result<f64> {
663        let model = self.best_model.read().await;
664
665        match &*model {
666            Some(m) => {
667                // Simple linear prediction
668                let mut pred = m.parameters.bias;
669                for (i, &weight) in m.parameters.weights.iter().enumerate() {
670                    if i < features.len() {
671                        pred += weight * features[i];
672                    }
673                }
674
675                // Apply activation for classification
676                if matches!(self.config.task_type, TaskType::Classification) {
677                    pred = 1.0 / (1.0 + (-pred).exp()); // Sigmoid
678                }
679
680                Ok(pred)
681            }
682            None => Err(anyhow!("No trained model available")),
683        }
684    }
685
686    /// Predict using ensemble (averaging)
687    async fn ensemble_predict(&self, features: &Array1<f64>) -> Result<f64> {
688        let ensemble = self.ensemble.read().await;
689
690        if ensemble.is_empty() {
691            return self.single_model_predict(features).await;
692        }
693
694        let mut predictions = Vec::new();
695        let mut weights = Vec::new();
696
697        for model in ensemble.iter() {
698            let mut pred = model.parameters.bias;
699            for (i, &weight) in model.parameters.weights.iter().enumerate() {
700                if i < features.len() {
701                    pred += weight * features[i];
702                }
703            }
704
705            if matches!(self.config.task_type, TaskType::Classification) {
706                pred = 1.0 / (1.0 + (-pred).exp());
707            }
708
709            predictions.push(pred);
710            weights.push(model.performance.overall_score());
711        }
712
713        // Weighted average
714        let total_weight: f64 = weights.iter().sum();
715        let weighted_pred = predictions
716            .iter()
717            .zip(&weights)
718            .map(|(p, w)| p * w)
719            .sum::<f64>()
720            / total_weight;
721
722        Ok(weighted_pred)
723    }
724
725    /// Get AutoML statistics
726    pub async fn get_stats(&self) -> AutoMLStats {
727        self.stats.read().await.clone()
728    }
729
730    /// Get trial history
731    pub async fn get_trial_history(&self) -> Vec<ModelPerformance> {
732        self.trial_history.read().await.clone()
733    }
734
735    /// Get best model information
736    pub async fn get_best_model_info(
737        &self,
738    ) -> Option<(Algorithm, HyperParameters, ModelPerformance)> {
739        let model = self.best_model.read().await;
740        model.as_ref().map(|m| {
741            (
742                m.algorithm,
743                m.hyperparameters.clone(),
744                m.performance.clone(),
745            )
746        })
747    }
748
749    /// Export best model for deployment
750    pub async fn export_model(&self) -> Result<String> {
751        let model = self.best_model.read().await;
752
753        match &*model {
754            Some(m) => {
755                let export = serde_json::json!({
756                    "algorithm": format!("{:?}", m.algorithm),
757                    "hyperparameters": m.hyperparameters,
758                    "parameters": {
759                        "weights": m.parameters.weights,
760                        "bias": m.parameters.bias,
761                    },
762                    "performance": m.performance,
763                });
764                Ok(serde_json::to_string_pretty(&export)?)
765            }
766            None => Err(anyhow!("No model to export")),
767        }
768    }
769}
770
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten rows of two features where row `k` (1-based) is `[k, k + 1]`.
    fn offset_pair_features() -> Array2<f64> {
        let values: Vec<f64> = (1..=10)
            .flat_map(|k| vec![k as f64, (k + 1) as f64])
            .collect();
        Array2::from_shape_vec((10, 2), values).unwrap()
    }

    /// Targets for [`offset_pair_features`]: the sum of each row.
    fn offset_pair_labels() -> Array1<f64> {
        Array1::from_vec((1..=10).map(|k| (2 * k + 1) as f64).collect())
    }

    /// A `rows` x 2 matrix filled with `0, 1, 2, ...` in row-major order.
    fn counting_features(rows: usize) -> Array2<f64> {
        Array2::from_shape_vec((rows, 2), (0..rows * 2).map(|v| v as f64).collect()).unwrap()
    }

    /// Labels `0, 1, ..., rows - 1`.
    fn counting_labels(rows: usize) -> Array1<f64> {
        Array1::from_vec((0..rows).map(|v| v as f64).collect())
    }

    #[test]
    fn test_algorithm_for_task() {
        let classifiers = Algorithm::for_task(TaskType::Classification);
        assert!(!classifiers.is_empty());
        assert!(classifiers.contains(&Algorithm::LogisticRegression));

        assert!(Algorithm::for_task(TaskType::Regression).contains(&Algorithm::LinearRegression));
        assert!(Algorithm::for_task(TaskType::TimeSeries).contains(&Algorithm::ARIMA));
    }

    #[test]
    fn test_hyperparameters_default() {
        let HyperParameters {
            learning_rate,
            n_estimators,
            max_depth,
            ..
        } = HyperParameters::default();
        assert_eq!(learning_rate, 0.01);
        assert_eq!(n_estimators, 100);
        assert_eq!(max_depth, Some(5));
    }

    #[test]
    fn test_model_performance_overall_score() {
        let score = ModelPerformance {
            algorithm: Algorithm::LinearRegression,
            hyperparameters: HyperParameters::default(),
            accuracy: None,
            precision: None,
            recall: None,
            f1_score: None,
            mse: Some(0.5),
            r_squared: Some(0.9),
            training_time_secs: 10.0,
            inference_time_ms: 1.0,
            complexity_score: 20.0,
            cv_score: 0.85,
        }
        .overall_score();

        assert!(score > 0.0);
        assert!(score <= 1.0);
    }

    #[tokio::test]
    async fn test_automl_creation() {
        assert!(AutoML::new(AutoMLConfig::default()).is_ok());
    }

    #[tokio::test]
    async fn test_automl_generate_hyperparameters() {
        let automl = AutoML::new(AutoMLConfig::default()).unwrap();

        let params = automl
            .generate_hyperparameters(Algorithm::LinearRegression)
            .await
            .expect("hyperparameter generation should succeed");
        assert!(params.learning_rate > 0.0);
        assert!(params.n_estimators > 0);
    }

    #[tokio::test]
    async fn test_automl_fit_small_dataset() {
        let mut automl = AutoML::new(AutoMLConfig {
            task_type: TaskType::Regression,
            max_training_time_secs: 5,
            n_trials: 3,
            cv_folds: 2,
            enable_ensemble: false,
            ..Default::default()
        })
        .unwrap();

        let outcome = automl
            .fit(&offset_pair_features(), &offset_pair_labels())
            .await;
        assert!(outcome.is_ok());

        let stats = automl.get_stats().await;
        assert!(stats.total_trials > 0);
        assert!(stats.total_trials <= 3);
    }

    #[tokio::test]
    async fn test_automl_prediction() {
        let mut automl = AutoML::new(AutoMLConfig {
            task_type: TaskType::Regression,
            max_training_time_secs: 5,
            n_trials: 2,
            ..Default::default()
        })
        .unwrap();

        automl
            .fit(&offset_pair_features(), &offset_pair_labels())
            .await
            .unwrap();

        let sample = Array1::from_vec(vec![5.5, 6.5]);
        assert!(automl.predict(&sample).await.is_ok());
    }

    #[tokio::test]
    async fn test_ensemble_prediction() {
        let mut automl = AutoML::new(AutoMLConfig {
            task_type: TaskType::Classification,
            enable_ensemble: true,
            max_ensemble_size: 3,
            n_trials: 5,
            max_training_time_secs: 10,
            ..Default::default()
        })
        .unwrap();

        let parity_labels = Array1::from_vec((0..20).map(|v| (v % 2) as f64).collect());
        automl
            .fit(&counting_features(20), &parity_labels)
            .await
            .unwrap();

        let prediction = automl.predict(&Array1::from_vec(vec![5.0, 10.0])).await;
        assert!(prediction.is_ok());

        // Classification output should be a probability.
        let probability = prediction.unwrap();
        assert!((0.0..=1.0).contains(&probability));
    }

    #[tokio::test]
    async fn test_get_best_model_info() {
        let mut automl = AutoML::new(AutoMLConfig {
            n_trials: 2,
            max_training_time_secs: 5,
            ..Default::default()
        })
        .unwrap();

        automl
            .fit(&counting_features(10), &counting_labels(10))
            .await
            .unwrap();

        let best_info = automl.get_best_model_info().await;
        assert!(best_info.is_some());

        let (_, _, performance) = best_info.unwrap();
        assert!(performance.cv_score >= 0.0);
    }

    #[tokio::test]
    async fn test_export_model() {
        let mut automl = AutoML::new(AutoMLConfig {
            n_trials: 1,
            max_training_time_secs: 5,
            ..Default::default()
        })
        .unwrap();

        automl
            .fit(&counting_features(10), &counting_labels(10))
            .await
            .unwrap();

        let export = automl.export_model().await;
        assert!(export.is_ok());

        let json = export.unwrap();
        assert!(json.contains("algorithm"));
        assert!(json.contains("hyperparameters"));
    }

    #[tokio::test]
    async fn test_trial_history() {
        let mut automl = AutoML::new(AutoMLConfig {
            n_trials: 3,
            max_training_time_secs: 5,
            ..Default::default()
        })
        .unwrap();

        automl
            .fit(&counting_features(10), &counting_labels(10))
            .await
            .unwrap();

        let history = automl.get_trial_history().await;
        assert!(!history.is_empty());
        assert!(history.len() <= 3);
    }

    #[tokio::test]
    async fn test_early_stopping() {
        let mut automl = AutoML::new(AutoMLConfig {
            n_trials: 100, // Far more than patience allows without improvement.
            max_training_time_secs: 60,
            early_stopping_patience: 3,
            ..Default::default()
        })
        .unwrap();

        automl
            .fit(&counting_features(10), &counting_labels(10))
            .await
            .unwrap();

        // Should stop early rather than run all 100 trials.
        assert!(automl.get_stats().await.total_trials < 100);
    }
}