1use anyhow::{anyhow, Result};
31use scirs2_core::ndarray_ext::{Array1, Array2};
32use scirs2_core::random::{Random, Rng};
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::sync::Arc;
36use tokio::sync::{Mutex, RwLock};
37use tracing::info;
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
41pub enum TaskType {
42 Classification,
44 Regression,
46 TimeSeries,
48 AnomalyDetection,
50 Clustering,
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
56pub enum Algorithm {
57 LinearRegression,
59 LogisticRegression,
61 DecisionTree,
63 RandomForest,
65 GradientBoosting,
67 NeuralNetwork,
69 KNN,
71 SVM,
73 NaiveBayes,
75 OnlineSGD,
77 ARIMA,
79 IsolationForest,
81 KMeans,
83}
84
85impl Algorithm {
86 pub fn for_task(task: TaskType) -> Vec<Algorithm> {
88 match task {
89 TaskType::Classification => vec![
90 Algorithm::LogisticRegression,
91 Algorithm::DecisionTree,
92 Algorithm::RandomForest,
93 Algorithm::GradientBoosting,
94 Algorithm::NeuralNetwork,
95 Algorithm::KNN,
96 Algorithm::NaiveBayes,
97 ],
98 TaskType::Regression => vec![
99 Algorithm::LinearRegression,
100 Algorithm::DecisionTree,
101 Algorithm::RandomForest,
102 Algorithm::GradientBoosting,
103 Algorithm::NeuralNetwork,
104 Algorithm::KNN,
105 Algorithm::SVM,
106 ],
107 TaskType::TimeSeries => vec![
108 Algorithm::ARIMA,
109 Algorithm::LinearRegression,
110 Algorithm::NeuralNetwork,
111 Algorithm::GradientBoosting,
112 ],
113 TaskType::AnomalyDetection => vec![
114 Algorithm::IsolationForest,
115 Algorithm::OnlineSGD,
116 Algorithm::NeuralNetwork,
117 ],
118 TaskType::Clustering => vec![Algorithm::KMeans],
119 }
120 }
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct HyperParameters {
126 pub learning_rate: f64,
128 pub n_estimators: usize,
130 pub max_depth: Option<usize>,
132 pub regularization: f64,
134 pub n_neighbors: usize,
136 pub batch_size: usize,
138 pub random_seed: u64,
140}
141
142impl Default for HyperParameters {
143 fn default() -> Self {
144 Self {
145 learning_rate: 0.01,
146 n_estimators: 100,
147 max_depth: Some(5),
148 regularization: 0.1,
149 n_neighbors: 5,
150 batch_size: 32,
151 random_seed: 42,
152 }
153 }
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct ModelPerformance {
159 pub algorithm: Algorithm,
161 pub hyperparameters: HyperParameters,
163 pub accuracy: Option<f64>,
165 pub precision: Option<f64>,
167 pub recall: Option<f64>,
169 pub f1_score: Option<f64>,
171 pub mse: Option<f64>,
173 pub r_squared: Option<f64>,
175 pub training_time_secs: f64,
177 pub inference_time_ms: f64,
179 pub complexity_score: f64,
181 pub cv_score: f64,
183}
184
185impl ModelPerformance {
186 pub fn overall_score(&self) -> f64 {
188 let perf_score = self.cv_score;
190 let time_penalty = (-self.training_time_secs / 60.0).exp(); let complexity_penalty = (-self.complexity_score / 100.0).exp(); perf_score * time_penalty * complexity_penalty
194 }
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct AutoMLConfig {
200 pub task_type: TaskType,
202 pub max_training_time_secs: u64,
204 pub n_trials: usize,
206 pub cv_folds: usize,
208 pub enable_ensemble: bool,
210 pub enable_meta_learning: bool,
212 pub early_stopping_patience: usize,
214 pub optimization_metric: String,
216 pub auto_feature_engineering: bool,
218 pub max_ensemble_size: usize,
220}
221
222impl Default for AutoMLConfig {
223 fn default() -> Self {
224 Self {
225 task_type: TaskType::Classification,
226 max_training_time_secs: 600,
227 n_trials: 50,
228 cv_folds: 5,
229 enable_ensemble: true,
230 enable_meta_learning: false,
231 early_stopping_patience: 10,
232 optimization_metric: "cv_score".to_string(),
233 auto_feature_engineering: true,
234 max_ensemble_size: 5,
235 }
236 }
237}
238
239#[derive(Debug, Clone)]
241pub struct TrainedModel {
242 pub algorithm: Algorithm,
244 pub hyperparameters: HyperParameters,
246 pub parameters: ModelParameters,
248 pub performance: ModelPerformance,
250}
251
/// Learned parameters of a trained model.
///
/// Prediction evaluates `bias + Σ weights[i] * features[i]`. The
/// `Default` value is an empty model: no weights, zero bias and no extra
/// parameter vectors.
// `Default` is derived: the previous hand-written impl produced exactly the
// field defaults (`Vec::new()`, `0.0`, `HashMap::new()`).
#[derive(Debug, Clone, Default)]
pub struct ModelParameters {
    /// Per-feature linear weights.
    pub weights: Vec<f64>,
    /// Intercept added to every prediction.
    pub bias: f64,
    /// Additional named parameter vectors (not populated by the current training code).
    pub extra: HashMap<String, Vec<f64>>,
}
272
273#[derive(Debug, Clone, Serialize, Deserialize)]
275pub struct AutoMLStats {
276 pub total_trials: u64,
278 pub best_score: f64,
280 pub total_training_time_secs: f64,
282 pub ensemble_size: usize,
284 pub best_algorithm: Option<Algorithm>,
286 pub predictions_count: u64,
288 pub avg_prediction_time_ms: f64,
290}
291
292impl Default for AutoMLStats {
293 fn default() -> Self {
294 Self {
295 total_trials: 0,
296 best_score: 0.0,
297 total_training_time_secs: 0.0,
298 ensemble_size: 0,
299 best_algorithm: None,
300 predictions_count: 0,
301 avg_prediction_time_ms: 0.0,
302 }
303 }
304}
305
306pub struct AutoML {
308 config: AutoMLConfig,
309 best_model: Arc<RwLock<Option<TrainedModel>>>,
311 ensemble: Arc<RwLock<Vec<TrainedModel>>>,
313 trial_history: Arc<RwLock<Vec<ModelPerformance>>>,
315 stats: Arc<RwLock<AutoMLStats>>,
317 #[allow(clippy::arc_with_non_send_sync)]
319 rng: Arc<Mutex<Random>>,
320}
321
322impl AutoML {
323 #[allow(clippy::arc_with_non_send_sync)]
325 pub fn new(config: AutoMLConfig) -> Result<Self> {
326 Ok(Self {
327 config,
328 best_model: Arc::new(RwLock::new(None)),
329 ensemble: Arc::new(RwLock::new(Vec::new())),
330 trial_history: Arc::new(RwLock::new(Vec::new())),
331 stats: Arc::new(RwLock::new(AutoMLStats::default())),
332 rng: Arc::new(Mutex::new(Random::default())),
333 })
334 }
335
336 pub async fn fit(&mut self, features: &Array2<f64>, labels: &Array1<f64>) -> Result<()> {
338 info!(
339 "Starting AutoML training with task {:?}, {} samples, {} features",
340 self.config.task_type,
341 features.shape()[0],
342 features.shape()[1]
343 );
344
345 let start_time = std::time::Instant::now();
346 let candidate_algorithms = Algorithm::for_task(self.config.task_type);
347
348 let mut best_overall_score = f64::NEG_INFINITY;
349 let mut trials_without_improvement = 0;
350
351 for trial in 0..self.config.n_trials {
352 if start_time.elapsed().as_secs() >= self.config.max_training_time_secs {
354 info!("Time budget exhausted, stopping AutoML");
355 break;
356 }
357
358 let algorithm = {
360 let mut rng = self.rng.lock().await;
361 let idx = rng.random_range(0..candidate_algorithms.len());
362 candidate_algorithms[idx]
363 };
364
365 let hyperparams = self.generate_hyperparameters(algorithm).await?;
367
368 let performance = self
370 .train_and_evaluate(algorithm, &hyperparams, features, labels)
371 .await?;
372
373 self.trial_history.write().await.push(performance.clone());
375
376 let overall_score = performance.overall_score();
377
378 info!(
379 "Trial {}: {:?} - CV score: {:.4}, Overall score: {:.4}",
380 trial, algorithm, performance.cv_score, overall_score
381 );
382
383 if overall_score > best_overall_score {
385 best_overall_score = overall_score;
386 trials_without_improvement = 0;
387
388 let model = TrainedModel {
389 algorithm,
390 hyperparameters: hyperparams.clone(),
391 parameters: self
392 .train_final_model(algorithm, &hyperparams, features, labels)
393 .await?,
394 performance: performance.clone(),
395 };
396
397 *self.best_model.write().await = Some(model.clone());
398
399 if self.config.enable_ensemble {
401 self.update_ensemble(model).await?;
402 }
403
404 let mut stats = self.stats.write().await;
406 stats.best_score = best_overall_score;
407 stats.best_algorithm = Some(algorithm);
408 } else {
409 trials_without_improvement += 1;
410 }
411
412 if trials_without_improvement >= self.config.early_stopping_patience {
414 info!(
415 "Early stopping triggered after {} trials without improvement",
416 trials_without_improvement
417 );
418 break;
419 }
420
421 let mut stats = self.stats.write().await;
423 stats.total_trials = trial as u64 + 1;
424 }
425
426 let mut stats = self.stats.write().await;
428 stats.total_training_time_secs = start_time.elapsed().as_secs_f64();
429 stats.ensemble_size = self.ensemble.read().await.len();
430
431 info!(
432 "AutoML training complete: {} trials, best score: {:.4}, algorithm: {:?}",
433 stats.total_trials, stats.best_score, stats.best_algorithm
434 );
435
436 Ok(())
437 }
438
439 async fn generate_hyperparameters(&self, algorithm: Algorithm) -> Result<HyperParameters> {
441 let mut rng = self.rng.lock().await;
442
443 let _base = if self.config.enable_meta_learning {
445 self.get_meta_learning_initialization(algorithm).await
446 } else {
447 HyperParameters::default()
448 };
449
450 Ok(HyperParameters {
452 learning_rate: rng.random_range(0.0001..0.1),
453 n_estimators: rng.random_range(10..500),
454 max_depth: Some(rng.random_range(3..20)),
455 regularization: rng.random_range(0.0..1.0),
456 n_neighbors: rng.random_range(3..20),
457 batch_size: rng.random_range(16..256),
458 random_seed: rng.random::<u64>(),
459 })
460 }
461
462 async fn get_meta_learning_initialization(&self, _algorithm: Algorithm) -> HyperParameters {
464 HyperParameters::default()
466 }
467
468 async fn train_and_evaluate(
470 &self,
471 algorithm: Algorithm,
472 hyperparams: &HyperParameters,
473 features: &Array2<f64>,
474 labels: &Array1<f64>,
475 ) -> Result<ModelPerformance> {
476 let start_time = std::time::Instant::now();
477
478 let cv_scores = self
480 .cross_validate(algorithm, hyperparams, features, labels)
481 .await?;
482 let cv_score = cv_scores.iter().sum::<f64>() / cv_scores.len() as f64;
483
484 let (accuracy, precision, recall, f1, mse, r_squared) = self
486 .compute_metrics(algorithm, hyperparams, features, labels)
487 .await?;
488
489 let training_time = start_time.elapsed().as_secs_f64();
490
491 let complexity_score = match algorithm {
493 Algorithm::LinearRegression | Algorithm::LogisticRegression => 10.0,
494 Algorithm::DecisionTree => 30.0,
495 Algorithm::RandomForest | Algorithm::GradientBoosting => 60.0,
496 Algorithm::NeuralNetwork => 80.0,
497 _ => 40.0,
498 };
499
500 Ok(ModelPerformance {
501 algorithm,
502 hyperparameters: hyperparams.clone(),
503 accuracy,
504 precision,
505 recall,
506 f1_score: f1,
507 mse,
508 r_squared,
509 training_time_secs: training_time,
510 inference_time_ms: 1.0, complexity_score,
512 cv_score,
513 })
514 }
515
516 async fn cross_validate(
518 &self,
519 algorithm: Algorithm,
520 hyperparams: &HyperParameters,
521 features: &Array2<f64>,
522 labels: &Array1<f64>,
523 ) -> Result<Vec<f64>> {
524 let n_samples = features.shape()[0];
525 let fold_size = n_samples / self.config.cv_folds;
526
527 let mut scores = Vec::new();
528
529 for fold in 0..self.config.cv_folds {
530 let val_start = fold * fold_size;
531 let val_end = ((fold + 1) * fold_size).min(n_samples);
532
533 let score = self
535 .evaluate_fold(algorithm, hyperparams, features, labels, val_start, val_end)
536 .await?;
537 scores.push(score);
538 }
539
540 Ok(scores)
541 }
542
543 async fn evaluate_fold(
545 &self,
546 _algorithm: Algorithm,
547 _hyperparams: &HyperParameters,
548 _features: &Array2<f64>,
549 _labels: &Array1<f64>,
550 _val_start: usize,
551 _val_end: usize,
552 ) -> Result<f64> {
553 let mut rng = self.rng.lock().await;
559 Ok(0.7 + rng.random::<f64>() * 0.3) }
561
562 async fn compute_metrics(
564 &self,
565 _algorithm: Algorithm,
566 _hyperparams: &HyperParameters,
567 _features: &Array2<f64>,
568 _labels: &Array1<f64>,
569 ) -> Result<(
570 Option<f64>,
571 Option<f64>,
572 Option<f64>,
573 Option<f64>,
574 Option<f64>,
575 Option<f64>,
576 )> {
577 let mut rng = self.rng.lock().await;
579
580 match self.config.task_type {
581 TaskType::Classification => {
582 let accuracy = Some(0.7 + rng.random::<f64>() * 0.3);
583 let precision = Some(0.7 + rng.random::<f64>() * 0.3);
584 let recall = Some(0.7 + rng.random::<f64>() * 0.3);
585 let f1 = Some(0.7 + rng.random::<f64>() * 0.3);
586 Ok((accuracy, precision, recall, f1, None, None))
587 }
588 TaskType::Regression | TaskType::TimeSeries => {
589 let mse = Some(0.1 + rng.random::<f64>() * 0.9);
590 let r_squared = Some(0.5 + rng.random::<f64>() * 0.5);
591 Ok((None, None, None, None, mse, r_squared))
592 }
593 _ => Ok((None, None, None, None, None, None)),
594 }
595 }
596
597 async fn train_final_model(
599 &self,
600 _algorithm: Algorithm,
601 _hyperparams: &HyperParameters,
602 features: &Array2<f64>,
603 _labels: &Array1<f64>,
604 ) -> Result<ModelParameters> {
605 let n_features = features.shape()[1];
607
608 let mut rng = self.rng.lock().await;
609 let weights: Vec<f64> = (0..n_features).map(|_| rng.random::<f64>() - 0.5).collect();
610 let bias = rng.random::<f64>() - 0.5;
611
612 Ok(ModelParameters {
613 weights,
614 bias,
615 extra: HashMap::new(),
616 })
617 }
618
619 async fn update_ensemble(&self, model: TrainedModel) -> Result<()> {
621 let mut ensemble = self.ensemble.write().await;
622
623 ensemble.push(model);
625
626 if ensemble.len() > self.config.max_ensemble_size {
628 ensemble.sort_by(|a, b| {
629 b.performance
630 .overall_score()
631 .partial_cmp(&a.performance.overall_score())
632 .unwrap_or(std::cmp::Ordering::Equal)
633 });
634 ensemble.truncate(self.config.max_ensemble_size);
635 }
636
637 Ok(())
638 }
639
640 pub async fn predict(&self, features: &Array1<f64>) -> Result<f64> {
642 let start_time = std::time::Instant::now();
643
644 let prediction = if self.config.enable_ensemble {
645 self.ensemble_predict(features).await?
646 } else {
647 self.single_model_predict(features).await?
648 };
649
650 let mut stats = self.stats.write().await;
652 stats.predictions_count += 1;
653 let elapsed_ms = start_time.elapsed().as_secs_f64() * 1000.0;
654 stats.avg_prediction_time_ms =
655 (stats.avg_prediction_time_ms * (stats.predictions_count - 1) as f64 + elapsed_ms)
656 / stats.predictions_count as f64;
657
658 Ok(prediction)
659 }
660
661 async fn single_model_predict(&self, features: &Array1<f64>) -> Result<f64> {
663 let model = self.best_model.read().await;
664
665 match &*model {
666 Some(m) => {
667 let mut pred = m.parameters.bias;
669 for (i, &weight) in m.parameters.weights.iter().enumerate() {
670 if i < features.len() {
671 pred += weight * features[i];
672 }
673 }
674
675 if matches!(self.config.task_type, TaskType::Classification) {
677 pred = 1.0 / (1.0 + (-pred).exp()); }
679
680 Ok(pred)
681 }
682 None => Err(anyhow!("No trained model available")),
683 }
684 }
685
686 async fn ensemble_predict(&self, features: &Array1<f64>) -> Result<f64> {
688 let ensemble = self.ensemble.read().await;
689
690 if ensemble.is_empty() {
691 return self.single_model_predict(features).await;
692 }
693
694 let mut predictions = Vec::new();
695 let mut weights = Vec::new();
696
697 for model in ensemble.iter() {
698 let mut pred = model.parameters.bias;
699 for (i, &weight) in model.parameters.weights.iter().enumerate() {
700 if i < features.len() {
701 pred += weight * features[i];
702 }
703 }
704
705 if matches!(self.config.task_type, TaskType::Classification) {
706 pred = 1.0 / (1.0 + (-pred).exp());
707 }
708
709 predictions.push(pred);
710 weights.push(model.performance.overall_score());
711 }
712
713 let total_weight: f64 = weights.iter().sum();
715 let weighted_pred = predictions
716 .iter()
717 .zip(&weights)
718 .map(|(p, w)| p * w)
719 .sum::<f64>()
720 / total_weight;
721
722 Ok(weighted_pred)
723 }
724
725 pub async fn get_stats(&self) -> AutoMLStats {
727 self.stats.read().await.clone()
728 }
729
730 pub async fn get_trial_history(&self) -> Vec<ModelPerformance> {
732 self.trial_history.read().await.clone()
733 }
734
735 pub async fn get_best_model_info(
737 &self,
738 ) -> Option<(Algorithm, HyperParameters, ModelPerformance)> {
739 let model = self.best_model.read().await;
740 model.as_ref().map(|m| {
741 (
742 m.algorithm,
743 m.hyperparameters.clone(),
744 m.performance.clone(),
745 )
746 })
747 }
748
749 pub async fn export_model(&self) -> Result<String> {
751 let model = self.best_model.read().await;
752
753 match &*model {
754 Some(m) => {
755 let export = serde_json::json!({
756 "algorithm": format!("{:?}", m.algorithm),
757 "hyperparameters": m.hyperparameters,
758 "parameters": {
759 "weights": m.parameters.weights,
760 "bias": m.parameters.bias,
761 },
762 "performance": m.performance,
763 });
764 Ok(serde_json::to_string_pretty(&export)?)
765 }
766 None => Err(anyhow!("No model to export")),
767 }
768 }
769}
770
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten samples with features `(x, x + 1)` for `x` in `1..=10`, labelled `2x + 1`.
    fn linear_dataset() -> (Array2<f64>, Array1<f64>) {
        let values: Vec<f64> = (1..=10)
            .flat_map(|x| [x as f64, x as f64 + 1.0])
            .collect();
        let features = Array2::from_shape_vec((10, 2), values).unwrap();
        let labels = Array1::from_vec((1..=10).map(|x| 2.0 * x as f64 + 1.0).collect());
        (features, labels)
    }

    /// A `rows` x 2 matrix filled row-major with 0, 1, 2, ...
    fn ramp_features(rows: usize) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * 2).map(|v| v as f64).collect();
        Array2::from_shape_vec((rows, 2), values).unwrap()
    }

    /// Labels `0.0, 1.0, ..., (n - 1)`.
    fn ramp_labels(n: usize) -> Array1<f64> {
        Array1::from_vec((0..n).map(|v| v as f64).collect())
    }

    #[test]
    fn test_algorithm_for_task() {
        assert!(!Algorithm::for_task(TaskType::Classification).is_empty());

        let expectations = [
            (TaskType::Classification, Algorithm::LogisticRegression),
            (TaskType::Regression, Algorithm::LinearRegression),
            (TaskType::TimeSeries, Algorithm::ARIMA),
        ];
        for &(task, expected) in expectations.iter() {
            assert!(Algorithm::for_task(task).contains(&expected));
        }
    }

    #[test]
    fn test_hyperparameters_default() {
        let defaults = HyperParameters::default();
        assert_eq!(defaults.learning_rate, 0.01);
        assert_eq!(defaults.n_estimators, 100);
        assert_eq!(defaults.max_depth, Some(5));
    }

    #[test]
    fn test_model_performance_overall_score() {
        let perf = ModelPerformance {
            algorithm: Algorithm::LinearRegression,
            hyperparameters: HyperParameters::default(),
            accuracy: None,
            precision: None,
            recall: None,
            f1_score: None,
            mse: Some(0.5),
            r_squared: Some(0.9),
            training_time_secs: 10.0,
            inference_time_ms: 1.0,
            complexity_score: 20.0,
            cv_score: 0.85,
        };

        let score = perf.overall_score();
        assert!(score > 0.0 && score <= 1.0);
    }

    #[tokio::test]
    async fn test_automl_creation() {
        assert!(AutoML::new(AutoMLConfig::default()).is_ok());
    }

    #[tokio::test]
    async fn test_automl_generate_hyperparameters() {
        let automl = AutoML::new(AutoMLConfig::default()).unwrap();

        let params = automl
            .generate_hyperparameters(Algorithm::LinearRegression)
            .await
            .expect("hyper-parameter generation should succeed");
        assert!(params.learning_rate > 0.0);
        assert!(params.n_estimators > 0);
    }

    #[tokio::test]
    async fn test_automl_fit_small_dataset() {
        let mut automl = AutoML::new(AutoMLConfig {
            task_type: TaskType::Regression,
            max_training_time_secs: 5,
            n_trials: 3,
            cv_folds: 2,
            enable_ensemble: false,
            ..Default::default()
        })
        .unwrap();
        let (features, labels) = linear_dataset();

        assert!(automl.fit(&features, &labels).await.is_ok());

        let trials = automl.get_stats().await.total_trials;
        assert!((1..=3).contains(&trials));
    }

    #[tokio::test]
    async fn test_automl_prediction() {
        let mut automl = AutoML::new(AutoMLConfig {
            task_type: TaskType::Regression,
            max_training_time_secs: 5,
            n_trials: 2,
            ..Default::default()
        })
        .unwrap();
        let (features, labels) = linear_dataset();
        automl.fit(&features, &labels).await.unwrap();

        let query = Array1::from_vec(vec![5.5, 6.5]);
        assert!(automl.predict(&query).await.is_ok());
    }

    #[tokio::test]
    async fn test_ensemble_prediction() {
        let mut automl = AutoML::new(AutoMLConfig {
            task_type: TaskType::Classification,
            enable_ensemble: true,
            max_ensemble_size: 3,
            n_trials: 5,
            max_training_time_secs: 10,
            ..Default::default()
        })
        .unwrap();
        let features = ramp_features(20);
        let labels = Array1::from_vec((0..20).map(|x| (x % 2) as f64).collect());
        automl.fit(&features, &labels).await.unwrap();

        let query = Array1::from_vec(vec![5.0, 10.0]);
        let prediction = automl.predict(&query).await;
        assert!(prediction.is_ok());
        // Classification predictions are sigmoid outputs, so they are probabilities.
        assert!((0.0..=1.0).contains(&prediction.unwrap()));
    }

    #[tokio::test]
    async fn test_get_best_model_info() {
        let mut automl = AutoML::new(AutoMLConfig {
            n_trials: 2,
            max_training_time_secs: 5,
            ..Default::default()
        })
        .unwrap();
        automl
            .fit(&ramp_features(10), &ramp_labels(10))
            .await
            .unwrap();

        let best_info = automl.get_best_model_info().await;
        assert!(best_info.is_some());

        let (_algorithm, _hyperparams, performance) = best_info.unwrap();
        assert!(performance.cv_score >= 0.0);
    }

    #[tokio::test]
    async fn test_export_model() {
        let mut automl = AutoML::new(AutoMLConfig {
            n_trials: 1,
            max_training_time_secs: 5,
            ..Default::default()
        })
        .unwrap();
        automl
            .fit(&ramp_features(10), &ramp_labels(10))
            .await
            .unwrap();

        let export = automl.export_model().await;
        assert!(export.is_ok());

        let json = export.unwrap();
        for key in ["algorithm", "hyperparameters"].iter() {
            assert!(json.contains(key));
        }
    }

    #[tokio::test]
    async fn test_trial_history() {
        let mut automl = AutoML::new(AutoMLConfig {
            n_trials: 3,
            max_training_time_secs: 5,
            ..Default::default()
        })
        .unwrap();
        automl
            .fit(&ramp_features(10), &ramp_labels(10))
            .await
            .unwrap();

        let history = automl.get_trial_history().await;
        assert!(!history.is_empty());
        assert!(history.len() <= 3);
    }

    #[tokio::test]
    async fn test_early_stopping() {
        // A generous trial budget: only early stopping should end the search.
        let mut automl = AutoML::new(AutoMLConfig {
            n_trials: 100,
            max_training_time_secs: 60,
            early_stopping_patience: 3,
            ..Default::default()
        })
        .unwrap();
        automl
            .fit(&ramp_features(10), &ramp_labels(10))
            .await
            .unwrap();

        assert!(automl.get_stats().await.total_trials < 100);
    }
}