automl/lib.rs

#![deny(clippy::correctness)]
#![warn(
    clippy::all,
    clippy::suspicious,
    clippy::complexity,
    clippy::perf,
    clippy::style,
    clippy::pedantic,
    clippy::nursery,
    clippy::missing_docs_in_private_items
)]
#![allow(clippy::module_name_repetitions, clippy::too_many_lines)]
#![warn(missing_docs, rustdoc::missing_doc_code_examples)]
#![doc = include_str!("../README.md")]

pub mod settings;
pub use settings::Settings;
use settings::{Algorithm, Distance, FinalModel, Kernel, Metric, PreProcessing};

pub mod cookbook;

mod algorithms;
use algorithms::{
    CategoricalNaiveBayesClassifierWrapper, DecisionTreeClassifierWrapper,
    DecisionTreeRegressorWrapper, ElasticNetRegressorWrapper, GaussianNaiveBayesClassifierWrapper,
    KNNClassifierWrapper, KNNRegressorWrapper, LassoRegressorWrapper, LinearRegressorWrapper,
    LogisticRegressionWrapper, ModelWrapper, RandomForestClassifierWrapper,
    RandomForestRegressorWrapper, RidgeRegressorWrapper, SupportVectorClassifierWrapper,
    SupportVectorRegressorWrapper,
};

mod utils;
use utils::elementwise_multiply;

use itertools::Itertools;
use smartcore::{
    dataset::Dataset,
    decomposition::{
        pca::{PCAParameters, PCA},
        svd::{SVDParameters, SVD},
    },
    linalg::{naive::dense_matrix::DenseMatrix, BaseMatrix},
    model_selection::{train_test_split, CrossValidationResult},
};
use std::{
    cmp::Ordering::Equal,
    fmt::{Display, Formatter},
    io::{Read, Write},
    time::Duration,
};

#[cfg(feature = "nd")]
use ndarray::{Array1, Array2};

#[cfg(feature = "csv")]
use {
    polars::prelude::{DataFrame, Float32Type},
    utils::validate_and_read,
};

use {
    comfy_table::{
        modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL, Attribute, Cell, Table,
    },
    humantime::format_duration,
};

/// This trait must be implemented by any type passed to `SupervisedModel::new` as data.
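///
/// As a minimal illustration, the built-in tuple implementation pairs a `Vec<Vec<f32>>` of
/// features with a `Vec<f32>` of labels:
/// ```
/// # use automl::IntoSupervisedData;
/// let (_x, y) = (vec![vec![1.0_f32; 3]; 4], vec![0.0_f32; 4]).to_supervised_data();
/// assert_eq!(y.len(), 4);
/// ```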
pub trait IntoSupervisedData {
    /// Converts the struct into paired features and labels
    fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>);
}

/// Types that implement this trait can be paired in a tuple with a type implementing `IntoLabels` to
/// automatically satisfy `IntoSupervisedData`. This trait is also required for data that is passed to `predict`.
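///
/// For example, a minimal sketch using the built-in `Vec<Vec<f32>>` implementation:
/// ```
/// # use automl::IntoFeatures;
/// # use smartcore::linalg::BaseMatrix;
/// let x = vec![vec![1.0_f32, 2.0], vec![3.0, 4.0]].to_dense_matrix();
/// assert_eq!(x.shape(), (2, 2));
/// ```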
pub trait IntoFeatures {
    /// Converts the struct into a dense matrix of features
    fn to_dense_matrix(self) -> DenseMatrix<f32>;
}

/// Types that implement this trait can be paired in a tuple with a type implementing `IntoFeatures`
/// to automatically satisfy `IntoSupervisedData`.
pub trait IntoLabels {
    /// Converts the struct into a vector of labels
    fn into_vec(self) -> Vec<f32>;
}

impl IntoSupervisedData for Dataset<f32, f32> {
    fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) {
        (
            DenseMatrix::from_array(self.num_samples, self.num_features, &self.data),
            self.target,
        )
    }
}

#[cfg(feature = "csv")]
impl IntoSupervisedData for (&str, usize) {
    fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) {
        let (filepath, target_index) = self;
        let df = validate_and_read(filepath);

        // Get target variables
        let target_column_name = df.get_column_names()[target_index];
        let series = df.column(target_column_name).unwrap().clone();
        let target_df = DataFrame::new(vec![series]).unwrap();
        let ndarray = target_df.to_ndarray::<Float32Type>().unwrap();
        let y = ndarray.into_raw_vec();

        // Get the rest of the data
        let features = df.drop(target_column_name).unwrap();
        let (height, width) = features.shape();
        let ndarray = features.to_ndarray::<Float32Type>().unwrap();
        let x = DenseMatrix::from_array(height, width, ndarray.as_slice().unwrap());
        (x, y)
    }
}

#[cfg(feature = "csv")]
impl IntoFeatures for &str {
    fn to_dense_matrix(self) -> DenseMatrix<f32> {
        let df = validate_and_read(self);

        // Convert the full data frame into a dense matrix
        let (height, width) = df.shape();
        let ndarray = df.to_ndarray::<Float32Type>().unwrap();
        DenseMatrix::from_array(height, width, ndarray.as_slice().unwrap())
    }
}

impl<X, Y> IntoSupervisedData for (X, Y)
where
    X: IntoFeatures,
    Y: IntoLabels,
{
    fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) {
        (self.0.to_dense_matrix(), self.1.into_vec())
    }
}

impl IntoFeatures for Vec<Vec<f32>> {
    fn to_dense_matrix(self) -> DenseMatrix<f32> {
        DenseMatrix::from_2d_vec(&self)
    }
}

impl IntoLabels for Vec<f32> {
    fn into_vec(self) -> Vec<f32> {
        self
    }
}

#[cfg(feature = "nd")]
impl IntoFeatures for Array2<f32> {
    fn to_dense_matrix(self) -> DenseMatrix<f32> {
        DenseMatrix::from_array(self.shape()[0], self.shape()[1], self.as_slice().unwrap())
    }
}

#[cfg(feature = "nd")]
impl IntoLabels for Array1<f32> {
    fn into_vec(self) -> Vec<f32> {
        self.to_vec()
    }
}

/// Trains and compares supervised models
#[derive(serde::Serialize, serde::Deserialize)]
pub struct SupervisedModel {
    /// Settings for the model.
    settings: Settings,
    /// The training data.
    x_train: DenseMatrix<f32>,
    /// The training labels.
    y_train: Vec<f32>,
    /// The validation data.
    x_val: DenseMatrix<f32>,
    /// The validation labels.
    y_val: Vec<f32>,
    /// The number of classes in the data.
    number_of_classes: usize,
    /// The results of the model comparison.
    comparison: Vec<Model>,
    /// The final model.
    metamodel: Model,
    /// PCA model for preprocessing.
    preprocessing_pca: Option<PCA<f32, DenseMatrix<f32>>>,
    /// SVD model for preprocessing.
    preprocessing_svd: Option<SVD<f32, DenseMatrix<f32>>>,
}

impl SupervisedModel {
    /// Create a new supervised model. This function accepts several input types. For instance, it works with vectors:
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// let model = automl::SupervisedModel::new(
    ///     (vec![vec![1.0; 5]; 5], vec![1.0; 5]),
    ///     automl::Settings::default_regression(),
    /// );
    /// ```
    /// It also works for some ndarray datatypes:
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// #[cfg(feature = "nd")]
    /// let model = SupervisedModel::new(
    ///     (
    ///         ndarray::arr2(&[[1.0, 2.0], [3.0, 4.0]]),
    ///         ndarray::arr1(&[1.0, 2.0])
    ///     ),
    ///     automl::Settings::default_regression(),
    /// );
    /// ```
    /// But you can also create a new supervised model from a [smartcore toy dataset](https://docs.rs/smartcore/0.2.0/smartcore/dataset/index.html):
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// let model = SupervisedModel::new(
    ///     smartcore::dataset::diabetes::load_dataset(),
    ///     Settings::default_regression()
    /// );
    /// ```
    /// You can even create a new supervised model directly from a CSV!
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// #[cfg(feature = "csv")]
    /// let model = SupervisedModel::new(
    ///     ("data/diabetes.csv", 10),
    ///     Settings::default_regression()
    /// );
    /// ```
    /// And that CSV can even come from a URL:
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// #[cfg(feature = "csv")]
    /// let mut model = automl::SupervisedModel::new(
    ///     (
    ///         "https://raw.githubusercontent.com/plotly/datasets/master/diabetes.csv",
    ///         8,
    ///     ),
    ///     Settings::default_regression(),
    /// );
    /// ```
    pub fn new<D>(data: D, settings: Settings) -> Self
    where
        D: IntoSupervisedData,
    {
        let (x, y) = data.to_supervised_data();
        Self::build(x, y, settings)
    }

    /// Load a supervised model from a previously saved file.
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// # let mut model = SupervisedModel::new(
    /// #    smartcore::dataset::diabetes::load_dataset(),
    /// #    Settings::default_regression()
    /// # );
    /// # model.save("tests/load_that_model.aml");
    /// let model = SupervisedModel::new_from_file("tests/load_that_model.aml");
    /// # std::fs::remove_file("tests/load_that_model.aml");
    /// ```
    #[must_use]
    pub fn new_from_file(file_name: &str) -> Self {
        let mut buf: Vec<u8> = Vec::new();
        std::fs::File::open(file_name)
            .and_then(|mut f| f.read_to_end(&mut buf))
            .expect("Cannot load model from file.");
        bincode::deserialize(&buf).expect("Cannot deserialize the model.")
    }

    /// Predict values using the final model based on a vec.
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// # let mut model = SupervisedModel::new(
    /// #     smartcore::dataset::diabetes::load_dataset(),
    /// #    Settings::default_regression()
    /// # .only(automl::settings::Algorithm::Linear)
    /// # );
    /// # model.train();
    /// model.predict(vec![vec![5.0; 10]; 5]);
    /// ```
    /// Or predict values using the final model based on ndarray.
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// # #[cfg(feature = "nd")]
    /// # let mut model = SupervisedModel::new(
    /// #     smartcore::dataset::diabetes::load_dataset(),
    /// #     Settings::default_regression()
    /// # .only(automl::settings::Algorithm::Linear)
    /// # );
    /// # #[cfg(feature = "nd")]
    /// # model.train();
    /// #[cfg(feature = "nd")]
    /// model.predict(
    ///     ndarray::arr2(&[
    ///         [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
    ///         [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
    ///     ])
    /// );
    /// ```
    /// You can also predict from a CSV file
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// # #[cfg(feature = "csv")]
    /// # let mut model = SupervisedModel::new(
    /// #     ("data/diabetes.csv", 10),
    /// #     Settings::default_regression()
    /// # .only(automl::settings::Algorithm::Linear)
    /// # );
    /// # #[cfg(feature = "csv")]
    /// # model.train();
    /// #[cfg(feature = "csv")]
    /// model.predict("data/diabetes_without_target.csv");
    /// ```
    ///
    /// # Panics
    ///
    /// If the model has not been trained, this function will panic.
    pub fn predict<X: IntoFeatures>(&self, x: X) -> Vec<f32> {
        let x = &self.preprocess(x.to_dense_matrix());
        match self.settings.final_model_approach {
            FinalModel::None => panic!("Cannot predict: no final model has been trained."),
            FinalModel::Best => self.predict_by_model(x, &self.comparison[0]),
            FinalModel::Blending { algorithm, .. } => self.predict_blended_model(x, algorithm),
        }
    }

    /// Runs a model comparison and trains a final model.
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// let mut model = SupervisedModel::new(
    ///     smartcore::dataset::diabetes::load_dataset(),
    ///     Settings::default_regression()
    /// # .only(automl::settings::Algorithm::Linear)
    /// );
    /// model.train();
    /// ```
    pub fn train(&mut self) {
        // Train any necessary preprocessing
        if let PreProcessing::ReplaceWithPCA {
            number_of_components,
        } = self.settings.preprocessing
        {
            self.train_pca(&self.x_train.clone(), number_of_components);
        }
        if let PreProcessing::ReplaceWithSVD {
            number_of_components,
        } = self.settings.preprocessing
        {
            self.train_svd(&self.x_train.clone(), number_of_components);
        }

        // Preprocess the data
        self.x_train = self.preprocess(self.x_train.clone());

        // Split validation data out if blending
        if let FinalModel::Blending {
            meta_training_fraction,
            meta_testing_fraction: _,
            algorithm: _,
        } = &self.settings.final_model_approach
        {
            let (x_train, x_val, y_train, y_val) = train_test_split(
                &self.x_train,
                &self.y_train,
                *meta_training_fraction,
                self.settings.shuffle,
            );
            self.x_train = x_train;
            self.y_train = y_train;
            self.y_val = y_val;
            self.x_val = x_val;
        }

        // Run logistic regression
        if !self
            .settings
            .skiplist
            .contains(&Algorithm::LogisticRegression)
        {
            self.record_model(LogisticRegressionWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        // Run random forest classification
        if !self
            .settings
            .skiplist
            .contains(&Algorithm::RandomForestClassifier)
        {
            self.record_model(RandomForestClassifierWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        // Run k-nearest neighbor classifier
        if !self.settings.skiplist.contains(&Algorithm::KNNClassifier) {
            self.record_model(KNNClassifierWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self
            .settings
            .skiplist
            .contains(&Algorithm::DecisionTreeClassifier)
        {
            self.record_model(DecisionTreeClassifierWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self
            .settings
            .skiplist
            .contains(&Algorithm::GaussianNaiveBayes)
        {
            self.record_model(GaussianNaiveBayesClassifierWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self
            .settings
            .skiplist
            .contains(&Algorithm::CategoricalNaiveBayes)
            && std::mem::discriminant(&self.settings.preprocessing)
                != std::mem::discriminant(&PreProcessing::ReplaceWithPCA {
                    number_of_components: 1,
                })
            && std::mem::discriminant(&self.settings.preprocessing)
                != std::mem::discriminant(&PreProcessing::ReplaceWithSVD {
                    number_of_components: 1,
                })
        {
            self.record_model(CategoricalNaiveBayesClassifierWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if self.number_of_classes == 2 && !self.settings.skiplist.contains(&Algorithm::SVC) {
            self.record_model(SupportVectorClassifierWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self.settings.skiplist.contains(&Algorithm::Linear) {
            self.record_model(LinearRegressorWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self.settings.skiplist.contains(&Algorithm::SVR) {
            self.record_model(SupportVectorRegressorWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self.settings.skiplist.contains(&Algorithm::Lasso) {
            self.record_model(LassoRegressorWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self.settings.skiplist.contains(&Algorithm::Ridge) {
            self.record_model(RidgeRegressorWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self.settings.skiplist.contains(&Algorithm::ElasticNet) {
            self.record_model(ElasticNetRegressorWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self
            .settings
            .skiplist
            .contains(&Algorithm::DecisionTreeRegressor)
        {
            self.record_model(DecisionTreeRegressorWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self
            .settings
            .skiplist
            .contains(&Algorithm::RandomForestRegressor)
        {
            self.record_model(RandomForestRegressorWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if !self.settings.skiplist.contains(&Algorithm::KNNRegressor) {
            self.record_model(KNNRegressorWrapper::cv_model(
                &self.x_train,
                &self.y_train,
                &self.settings,
            ));
        }

        if let FinalModel::Blending {
            algorithm,
            meta_training_fraction,
            meta_testing_fraction,
        } = self.settings.final_model_approach
        {
            self.train_blended_model(algorithm, meta_training_fraction, meta_testing_fraction);
        }
    }

    /// Save the supervised model to a file for later use
    /// ```
    /// # use automl::{SupervisedModel, Settings};
    /// let mut model = SupervisedModel::new(
    ///     smartcore::dataset::diabetes::load_dataset(),
    ///     Settings::default_regression()
    /// );
    /// model.save("tests/save_that_model.aml");
    /// # std::fs::remove_file("tests/save_that_model.aml");
    /// ```
    pub fn save(&self, file_name: &str) {
        let serial = bincode::serialize(&self).expect("Cannot serialize model.");
        std::fs::File::create(file_name)
            .and_then(|mut f| f.write_all(&serial))
            .expect("Cannot write model to file.");
    }

    /// Save the best model for later use as a smartcore native object.
    /// This does nothing unless the final model approach is `FinalModel::Best`.
    /// ```
    /// # use automl::{SupervisedModel, Settings, settings::Algorithm};
    /// let mut model = SupervisedModel::new(
    ///     smartcore::dataset::diabetes::load_dataset(),
    ///     Settings::default_regression()
    /// # .only(Algorithm::Linear)
    /// );
    /// model.train();
    /// model.save_best("tests/save_best.sc");
    /// # std::fs::remove_file("tests/save_best.sc");
    /// ```
    pub fn save_best(&self, file_name: &str) {
        if matches!(self.settings.final_model_approach, FinalModel::Best) {
            std::fs::File::create(file_name)
                .and_then(|mut f| f.write_all(&self.comparison[0].model))
                .expect("Cannot write model to file.");
        }
    }
}

/// Private functions go here
impl SupervisedModel {
    /// Build a new supervised model
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    /// * `y` - The output data
    /// * `settings` - The settings for the model
    fn build(x: DenseMatrix<f32>, y: Vec<f32>, settings: Settings) -> Self {
        Self {
            settings,
            x_train: x,
            number_of_classes: Self::count_classes(&y),
            y_train: y,
            x_val: DenseMatrix::new(0, 0, vec![]),
            y_val: vec![],
            comparison: vec![],
            metamodel: Model::default(),
            preprocessing_pca: None,
            preprocessing_svd: None,
        }
    }

    /// Train the blended (meta) model used by the `Blending` final-model approach.
    ///
    /// # Arguments
    ///
    /// * `algo` - The algorithm to use for the meta-model
    /// * `training_fraction` - The fraction of the held-out data to use for training
    /// * `testing_fraction` - The fraction of the held-out data to use for testing
    fn train_blended_model(
        &mut self,
        algo: Algorithm,
        training_fraction: f32,
        testing_fraction: f32,
    ) {
        // Build the meta-features: each base model's predictions on the held-out data
        let mut meta_x: Vec<Vec<f32>> = Vec::new();
        for model in &self.comparison {
            meta_x.push(self.predict_by_model(&self.x_val, model));
        }
        let xdm = DenseMatrix::from_2d_vec(&meta_x).transpose();

        // Split into datasets
        let (x_train, x_test, y_train, y_test) = train_test_split(
            &xdm,
            &self.y_val,
            training_fraction / (training_fraction + testing_fraction),
            self.settings.shuffle,
        );

        // Train the meta-model
        let model = algo.get_trainer()(&x_train, &y_train, &self.settings);

        // Score the meta-model
        let train_score = self.settings.get_metric()(
            &y_train,
            &algo.get_predictor()(&x_train, &model, &self.settings),
        );
        let test_score = self.settings.get_metric()(
            &y_test,
            &algo.get_predictor()(&x_test, &model, &self.settings),
        );

        self.metamodel = Model {
            score: CrossValidationResult {
                test_score: vec![test_score; 1],
                train_score: vec![train_score; 1],
            },
            name: algo,
            duration: Duration::default(),
            model,
        };
    }

    /// Predict with every trained base model, then combine those predictions with the meta-model.
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    /// * `algo` - The meta-model algorithm
    ///
    /// # Returns
    ///
    /// * The predicted values
    fn predict_blended_model(&self, x: &DenseMatrix<f32>, algo: Algorithm) -> Vec<f32> {
        // Build the meta-features from each base model's predictions
        let mut meta_x: Vec<Vec<f32>> = Vec::new();
        for model in &self.comparison {
            meta_x.push(self.predict_by_model(x, model));
        }

        // Assemble the meta-feature matrix (one column per base model)
        let xdm = DenseMatrix::from_2d_vec(&meta_x).transpose();
        let metamodel = &self.metamodel.model;

        // Predict with the meta-model
        algo.get_predictor()(&xdm, metamodel, &self.settings)
    }

    /// Predict using a single model.
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    /// * `model` - The model to use
    ///
    /// # Returns
    ///
    /// * The predicted values
    fn predict_by_model(&self, x: &DenseMatrix<f32>, model: &Model) -> Vec<f32> {
        model.name.get_predictor()(x, &model.model, &self.settings)
    }

    /// Get interaction features for the data.
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    ///
    /// # Returns
    ///
    /// * The data with pairwise interaction features appended
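    ///
    /// For example, with feature columns `[a, b, c]`, the products `a*b`, `a*c`, and `b*c`
    /// are appended as new columns after the original features.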
    fn interaction_features(mut x: DenseMatrix<f32>) -> DenseMatrix<f32> {
        let (_, width) = x.shape();
        for i in 0..width {
            for j in (i + 1)..width {
                let feature = elementwise_multiply(&x.get_col_as_vec(i), &x.get_col_as_vec(j));
                let new_column = DenseMatrix::from_row_vector(feature).transpose();
                x = x.h_stack(&new_column);
            }
        }
        x
    }

    /// Get polynomial features for the data.
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    /// * `order` - The order of the polynomial
    ///
    /// # Returns
    ///
    /// * The data with polynomial features
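    ///
    /// For example, with two feature columns `[a, b]` and `order = 2`, the columns
    /// `a*a`, `a*b`, and `b*b` are appended after the original features.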
    fn polynomial_features(mut x: DenseMatrix<f32>, order: usize) -> DenseMatrix<f32> {
        let (height, width) = x.shape();
        for n in 2..=order {
            let combinations = (0..width).combinations_with_replacement(n);
            for combo in combinations {
                let mut feature = vec![1.0; height];
                for column in combo {
                    feature = elementwise_multiply(&x.get_col_as_vec(column), &feature);
                }
                let new_column = DenseMatrix::from_row_vector(feature).transpose();
                x = x.h_stack(&new_column);
            }
        }
        x
    }

    /// Train PCA on the data for preprocessing.
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    /// * `n` - The number of components to use
    fn train_pca(&mut self, x: &DenseMatrix<f32>, n: usize) {
        let pca = PCA::fit(
            x,
            PCAParameters::default()
                .with_n_components(n)
                .with_use_correlation_matrix(true),
        )
        .unwrap();
        self.preprocessing_pca = Some(pca);
    }

    /// Get PCA features for the data using the trained PCA preprocessor.
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    fn pca_features(&self, x: &DenseMatrix<f32>, _: usize) -> DenseMatrix<f32> {
        self.preprocessing_pca
            .as_ref()
            .unwrap()
            .transform(x)
            .unwrap()
    }

    /// Train SVD on the data for preprocessing.
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    /// * `n` - The number of components to use
    fn train_svd(&mut self, x: &DenseMatrix<f32>, n: usize) {
        let svd = SVD::fit(x, SVDParameters::default().with_n_components(n)).unwrap();
        self.preprocessing_svd = Some(svd);
    }

    /// Get SVD features for the data.
    fn svd_features(&self, x: &DenseMatrix<f32>, _: usize) -> DenseMatrix<f32> {
        self.preprocessing_svd
            .as_ref()
            .unwrap()
            .transform(x)
            .unwrap()
    }

    /// Preprocess the data.
    ///
    /// # Arguments
    ///
    /// * `x` - The input data
    ///
    /// # Returns
    ///
    /// * The preprocessed data
    fn preprocess(&self, x: DenseMatrix<f32>) -> DenseMatrix<f32> {
        match self.settings.preprocessing {
            PreProcessing::None => x,
            PreProcessing::AddInteractions => Self::interaction_features(x),
            PreProcessing::AddPolynomial { order } => Self::polynomial_features(x, order),
            PreProcessing::ReplaceWithPCA {
                number_of_components,
            } => self.pca_features(&x, number_of_components),
            PreProcessing::ReplaceWithSVD {
                number_of_components,
            } => self.svd_features(&x, number_of_components),
        }
    }

    /// Count the number of classes in the data.
    ///
    /// # Arguments
    ///
    /// * `y` - The data to count the classes in
    ///
    /// # Returns
    ///
    /// * The number of classes
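    ///
    /// For example, `count_classes(&[0.0, 1.0, 1.0, 2.0])` returns `3`.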
    fn count_classes(y: &[f32]) -> usize {
        let mut sorted_targets = y.to_vec();
        sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Equal));
        sorted_targets.dedup();
        sorted_targets.len()
    }

    /// Record a model in the comparison.
    fn record_model(&mut self, model: (CrossValidationResult<f32>, Algorithm, Duration, Vec<u8>)) {
        self.comparison.push(Model {
            score: model.0,
            name: model.1,
            duration: model.2,
            model: model.3,
        });
        self.sort();
    }

    /// Sort the models in the comparison by their mean test scores.
    fn sort(&mut self) {
        self.comparison.sort_by(|a, b| {
            a.score
                .mean_test_score()
                .partial_cmp(&b.score.mean_test_score())
                .unwrap_or(Equal)
        });
        if self.settings.sort_by == Metric::RSquared || self.settings.sort_by == Metric::Accuracy {
            self.comparison.reverse();
        }
    }
}

impl Display for SupervisedModel {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut table = Table::new();
        table.load_preset(UTF8_FULL);
        table.apply_modifier(UTF8_SOLID_INNER_BORDERS);
        table.set_header(vec![
            Cell::new("Model").add_attribute(Attribute::Bold),
            Cell::new("Time").add_attribute(Attribute::Bold),
            Cell::new(format!("Training {}", self.settings.sort_by)).add_attribute(Attribute::Bold),
            Cell::new(format!("Testing {}", self.settings.sort_by)).add_attribute(Attribute::Bold),
        ]);
        for model in &self.comparison {
            let mut row_vec = vec![];
            row_vec.push(format!("{}", &model.name));
            row_vec.push(format!("{}", format_duration(model.duration)));
            let decider =
                ((model.score.mean_train_score() + model.score.mean_test_score()) / 2.0).abs();
            if decider > 0.01 && decider < 1000.0 {
                row_vec.push(format!("{:.2}", &model.score.mean_train_score()));
                row_vec.push(format!("{:.2}", &model.score.mean_test_score()));
            } else {
                row_vec.push(format!("{:.3e}", &model.score.mean_train_score()));
                row_vec.push(format!("{:.3e}", &model.score.mean_test_score()));
            }

            table.add_row(row_vec);
        }

        let mut meta_table = Table::new();
        meta_table.load_preset(UTF8_FULL);
        meta_table.apply_modifier(UTF8_SOLID_INNER_BORDERS);
        meta_table.set_header(vec![
            Cell::new("Meta Model").add_attribute(Attribute::Bold),
            Cell::new(format!("Training {}", self.settings.sort_by)).add_attribute(Attribute::Bold),
            Cell::new(format!("Testing {}", self.settings.sort_by)).add_attribute(Attribute::Bold),
        ]);

        // Populate row
        let mut row_vec = vec![];
        row_vec.push(format!("{}", self.metamodel.name));
        let decider = ((self.metamodel.score.mean_train_score()
            + self.metamodel.score.mean_test_score())
            / 2.0)
            .abs();
        if decider > 0.01 && decider < 1000.0 {
            row_vec.push(format!("{:.2}", self.metamodel.score.mean_train_score()));
            row_vec.push(format!("{:.2}", self.metamodel.score.mean_test_score()));
        } else {
            row_vec.push(format!("{:.3e}", self.metamodel.score.mean_train_score()));
            row_vec.push(format!("{:.3e}", self.metamodel.score.mean_test_score()));
        }

        // Add row to table
        meta_table.add_row(row_vec);

        // Write
        write!(f, "{table}\n{meta_table}")
    }
}

/// This contains the results of a single model
#[derive(serde::Serialize, serde::Deserialize)]
struct Model {
    /// The cross validation score of the model
    #[serde(with = "CrossValidationResultDef")]
    score: CrossValidationResult<f32>,
    /// The algorithm used
    name: Algorithm,
    /// The time it took to train the model
    duration: Duration,
    /// The serialized bytes of the trained model
    model: Vec<u8>,
}

impl Default for Model {
    fn default() -> Self {
        Self {
            score: CrossValidationResult {
                test_score: vec![],
                train_score: vec![],
            },
            name: Algorithm::Linear,
            duration: Duration::default(),
            model: vec![],
        }
    }
}

/// This is a wrapper for the `CrossValidationResult`
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(remote = "CrossValidationResult::<f32>")]
struct CrossValidationResultDef {
    /// Vector with test scores on each cv split
    pub test_score: Vec<f32>,
    /// Vector with training scores on each cv split
    pub train_score: Vec<f32>,
}