1#![deny(clippy::correctness)]
2#![warn(
3 clippy::all,
4 clippy::suspicious,
5 clippy::complexity,
6 clippy::perf,
7 clippy::style,
8 clippy::pedantic,
9 clippy::nursery,
10 clippy::missing_docs_in_private_items
11)]
12#![allow(clippy::module_name_repetitions, clippy::too_many_lines)]
13#![warn(missing_docs, rustdoc::missing_doc_code_examples)]
14#![doc = include_str!("../README.md")]
15
16pub mod settings;
17pub use settings::Settings;
18use settings::{Algorithm, Distance, FinalModel, Kernel, Metric, PreProcessing};
19
20pub mod cookbook;
21
22mod algorithms;
23use algorithms::{
24 CategoricalNaiveBayesClassifierWrapper, DecisionTreeClassifierWrapper,
25 DecisionTreeRegressorWrapper, ElasticNetRegressorWrapper, GaussianNaiveBayesClassifierWrapper,
26 KNNClassifierWrapper, KNNRegressorWrapper, LassoRegressorWrapper, LinearRegressorWrapper,
27 LogisticRegressionWrapper, ModelWrapper, RandomForestClassifierWrapper,
28 RandomForestRegressorWrapper, RidgeRegressorWrapper, SupportVectorClassifierWrapper,
29 SupportVectorRegressorWrapper,
30};
31
32mod utils;
33use utils::elementwise_multiply;
34
35use itertools::Itertools;
36use smartcore::{
37 dataset::Dataset,
38 decomposition::{
39 pca::{PCAParameters, PCA},
40 svd::{SVDParameters, SVD},
41 },
42 linalg::{naive::dense_matrix::DenseMatrix, BaseMatrix},
43 model_selection::{train_test_split, CrossValidationResult},
44};
45use std::{
46 cmp::Ordering::Equal,
47 fmt::{Display, Formatter},
48 io::{Read, Write},
49 time::Duration,
50};
51
52#[cfg(any(feature = "nd"))]
53use ndarray::{Array1, Array2};
54
55#[cfg(any(feature = "csv"))]
56use {
57 polars::prelude::{DataFrame, Float32Type},
58 utils::validate_and_read,
59};
60
61use {
62 comfy_table::{
63 modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL, Attribute, Cell, Table,
64 },
65 humantime::format_duration,
66};
67
68pub trait IntoSupervisedData {
70 fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>);
72}
73
74pub trait IntoFeatures {
77 fn to_dense_matrix(self) -> DenseMatrix<f32>;
79}
80
/// Conversion of a data source into a flat vector of target values.
pub trait IntoLabels {
    /// Consumes `self` and returns the targets as a `Vec<f32>`.
    fn into_vec(self) -> Vec<f32>;
}
87
88impl IntoSupervisedData for Dataset<f32, f32> {
89 fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) {
90 (
91 DenseMatrix::from_array(self.num_samples, self.num_features, &self.data),
92 self.target,
93 )
94 }
95}
96
/// Reads supervised data from a file: `(path, index of the target column)`.
#[cfg(feature = "csv")]
impl IntoSupervisedData for (&str, usize) {
    /// Loads the file with `validate_and_read`, splits off the column at the
    /// given index as the target vector and converts every remaining column
    /// into the feature matrix.
    ///
    /// # Panics
    ///
    /// Panics if the target index is out of range or if any of the polars
    /// conversions fails (for example when a column cannot be cast to `f32`).
    fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) {
        let (filepath, target_index) = self;
        let df = validate_and_read(filepath);

        // Target: a single column, so its backing buffer is the label vector
        // whatever memory order the array uses.
        let target_column_name = df.get_column_names()[target_index];
        let series = df.column(target_column_name).unwrap().clone();
        let target_df = DataFrame::new(vec![series]).unwrap();
        let ndarray = target_df.to_ndarray::<Float32Type>().unwrap();
        let y = ndarray.into_raw_vec();

        // Features: everything except the target column.
        let features = df.drop(target_column_name).unwrap();
        let (height, width) = features.shape();
        let ndarray = features.to_ndarray::<Float32Type>().unwrap();
        // `as_slice` alone returns `None` (and the old `unwrap` panicked) when
        // the array is not in standard layout. `as_standard_layout` borrows
        // when it already is and copies into standard layout otherwise.
        let row_major = ndarray.as_standard_layout();
        let values = row_major
            .as_slice()
            .expect("a standard-layout array is always sliceable");
        let x = DenseMatrix::from_array(height, width, values);
        (x, y)
    }
}
118
/// Reads a feature matrix from the file at the given path.
#[cfg(feature = "csv")]
impl IntoFeatures for &str {
    /// Loads the file with `validate_and_read` and converts every column into
    /// the feature matrix.
    ///
    /// # Panics
    ///
    /// Panics if the data frame cannot be converted to an `f32` array.
    fn to_dense_matrix(self) -> DenseMatrix<f32> {
        let df = validate_and_read(self);

        let (height, width) = df.shape();
        let ndarray = df.to_ndarray::<Float32Type>().unwrap();
        // `as_slice` alone returns `None` (and the old `unwrap` panicked) when
        // the array is not in standard layout. `as_standard_layout` borrows
        // when it already is and copies into standard layout otherwise.
        let row_major = ndarray.as_standard_layout();
        let values = row_major
            .as_slice()
            .expect("a standard-layout array is always sliceable");
        DenseMatrix::from_array(height, width, values)
    }
}
130
131impl<X, Y> IntoSupervisedData for (X, Y)
132where
133 X: IntoFeatures,
134 Y: IntoLabels,
135{
136 fn to_supervised_data(self) -> (DenseMatrix<f32>, Vec<f32>) {
137 (self.0.to_dense_matrix(), self.1.into_vec())
138 }
139}
140
141impl IntoFeatures for Vec<Vec<f32>> {
142 fn to_dense_matrix(self) -> DenseMatrix<f32> {
143 DenseMatrix::from_2d_vec(&self)
144 }
145}
146
147impl IntoLabels for Vec<f32> {
148 fn into_vec(self) -> Vec<f32> {
149 self
150 }
151}
152
/// Two-dimensional `ndarray` arrays can be used as features.
#[cfg(feature = "nd")]
impl IntoFeatures for Array2<f32> {
    /// Copies the array into a `DenseMatrix` with the same number of rows and
    /// columns.
    ///
    /// Arrays in any memory layout are accepted: an array that is not in
    /// standard layout (for example a transposed or Fortran-ordered one) is
    /// first copied into standard layout instead of causing a panic.
    fn to_dense_matrix(self) -> DenseMatrix<f32> {
        let (rows, columns) = self.dim();
        // `as_slice` alone returns `None` for non-standard layouts, which made
        // the previous `unwrap` panic. `as_standard_layout` borrows when the
        // array already is in standard layout and copies otherwise.
        let row_major = self.as_standard_layout();
        let values = row_major
            .as_slice()
            .expect("a standard-layout array is always sliceable");
        DenseMatrix::from_array(rows, columns, values)
    }
}
159
/// One-dimensional `ndarray` arrays can be used as labels.
#[cfg(feature = "nd")]
impl IntoLabels for Array1<f32> {
    /// Copies the elements, in order, into a new vector.
    fn into_vec(self) -> Vec<f32> {
        self.iter().copied().collect()
    }
}
166
167#[derive(serde::Serialize, serde::Deserialize)]
169pub struct SupervisedModel {
170 settings: Settings,
172 x_train: DenseMatrix<f32>,
174 y_train: Vec<f32>,
176 x_val: DenseMatrix<f32>,
178 y_val: Vec<f32>,
180 number_of_classes: usize,
182 comparison: Vec<Model>,
184 metamodel: Model,
186 preprocessing_pca: Option<PCA<f32, DenseMatrix<f32>>>,
188 preprocessing_svd: Option<SVD<f32, DenseMatrix<f32>>>,
190}
191
192impl SupervisedModel {
193 pub fn new<D>(data: D, settings: Settings) -> Self
243 where
244 D: IntoSupervisedData,
245 {
246 let (x, y) = data.to_supervised_data();
247 Self::build(x, y, settings)
248 }
249
250 #[must_use]
262 pub fn new_from_file(file_name: &str) -> Self {
263 let mut buf: Vec<u8> = Vec::new();
264 std::fs::File::open(file_name)
265 .and_then(|mut f| f.read_to_end(&mut buf))
266 .expect("Cannot load model from file.");
267 bincode::deserialize(&buf).expect("Can not deserialize the model")
268 }
269
270 pub fn predict<X: IntoFeatures>(&self, x: X) -> Vec<f32> {
319 let x = &self.preprocess(x.to_dense_matrix());
320 match self.settings.final_model_approach {
321 FinalModel::None => panic!(""),
322 FinalModel::Best => self.predict_by_model(x, &self.comparison[0]),
323 FinalModel::Blending { algorithm, .. } => self.predict_blended_model(x, algorithm),
324 }
325 }
326
327 pub fn train(&mut self) {
338 if let PreProcessing::ReplaceWithPCA {
340 number_of_components,
341 } = self.settings.preprocessing
342 {
343 self.train_pca(&self.x_train.clone(), number_of_components);
344 }
345 if let PreProcessing::ReplaceWithSVD {
346 number_of_components,
347 } = self.settings.preprocessing
348 {
349 self.train_svd(&self.x_train.clone(), number_of_components);
350 }
351
352 self.x_train = self.preprocess(self.x_train.clone());
354
355 if let FinalModel::Blending {
357 meta_training_fraction,
358 meta_testing_fraction: _,
359 algorithm: _,
360 } = &self.settings.final_model_approach
361 {
362 let (x_train, x_val, y_train, y_val) = train_test_split(
363 &self.x_train,
364 &self.y_train,
365 *meta_training_fraction,
366 self.settings.shuffle,
367 );
368 self.x_train = x_train;
369 self.y_train = y_train;
370 self.y_val = y_val;
371 self.x_val = x_val;
372 }
373
374 if !self
376 .settings
377 .skiplist
378 .contains(&Algorithm::LogisticRegression)
379 {
380 self.record_model(LogisticRegressionWrapper::cv_model(
381 &self.x_train,
382 &self.y_train,
383 &self.settings,
384 ));
385 }
386
387 if !self
389 .settings
390 .skiplist
391 .contains(&Algorithm::RandomForestClassifier)
392 {
393 self.record_model(RandomForestClassifierWrapper::cv_model(
394 &self.x_train,
395 &self.y_train,
396 &self.settings,
397 ));
398 }
399
400 if !self.settings.skiplist.contains(&Algorithm::KNNClassifier) {
402 self.record_model(KNNClassifierWrapper::cv_model(
403 &self.x_train,
404 &self.y_train,
405 &self.settings,
406 ));
407 }
408
409 if !self
410 .settings
411 .skiplist
412 .contains(&Algorithm::DecisionTreeClassifier)
413 {
414 self.record_model(DecisionTreeClassifierWrapper::cv_model(
415 &self.x_train,
416 &self.y_train,
417 &self.settings,
418 ));
419 }
420
421 if !self
422 .settings
423 .skiplist
424 .contains(&Algorithm::GaussianNaiveBayes)
425 {
426 self.record_model(GaussianNaiveBayesClassifierWrapper::cv_model(
427 &self.x_train,
428 &self.y_train,
429 &self.settings,
430 ));
431 }
432
433 if !self
434 .settings
435 .skiplist
436 .contains(&Algorithm::CategoricalNaiveBayes)
437 && std::mem::discriminant(&self.settings.preprocessing)
438 != std::mem::discriminant(&PreProcessing::ReplaceWithPCA {
439 number_of_components: 1,
440 })
441 && std::mem::discriminant(&self.settings.preprocessing)
442 != std::mem::discriminant(&PreProcessing::ReplaceWithSVD {
443 number_of_components: 1,
444 })
445 {
446 self.record_model(CategoricalNaiveBayesClassifierWrapper::cv_model(
447 &self.x_train,
448 &self.y_train,
449 &self.settings,
450 ));
451 }
452
453 if self.number_of_classes == 2 && !self.settings.skiplist.contains(&Algorithm::SVC) {
454 self.record_model(SupportVectorClassifierWrapper::cv_model(
455 &self.x_train,
456 &self.y_train,
457 &self.settings,
458 ));
459 }
460
461 if !self.settings.skiplist.contains(&Algorithm::Linear) {
462 self.record_model(LinearRegressorWrapper::cv_model(
463 &self.x_train,
464 &self.y_train,
465 &self.settings,
466 ));
467 }
468
469 if !self.settings.skiplist.contains(&Algorithm::SVR) {
470 self.record_model(SupportVectorRegressorWrapper::cv_model(
471 &self.x_train,
472 &self.y_train,
473 &self.settings,
474 ));
475 }
476
477 if !self.settings.skiplist.contains(&Algorithm::Lasso) {
478 self.record_model(RidgeRegressorWrapper::cv_model(
479 &self.x_train,
480 &self.y_train,
481 &self.settings,
482 ));
483 }
484
485 if !self.settings.skiplist.contains(&Algorithm::Ridge) {
486 self.record_model(LassoRegressorWrapper::cv_model(
487 &self.x_train,
488 &self.y_train,
489 &self.settings,
490 ));
491 }
492
493 if !self.settings.skiplist.contains(&Algorithm::ElasticNet) {
494 self.record_model(ElasticNetRegressorWrapper::cv_model(
495 &self.x_train,
496 &self.y_train,
497 &self.settings,
498 ));
499 }
500
501 if !self
502 .settings
503 .skiplist
504 .contains(&Algorithm::DecisionTreeRegressor)
505 {
506 self.record_model(DecisionTreeRegressorWrapper::cv_model(
507 &self.x_train,
508 &self.y_train,
509 &self.settings,
510 ));
511 }
512
513 if !self
514 .settings
515 .skiplist
516 .contains(&Algorithm::RandomForestRegressor)
517 {
518 self.record_model(RandomForestRegressorWrapper::cv_model(
519 &self.x_train,
520 &self.y_train,
521 &self.settings,
522 ));
523 }
524
525 if !self.settings.skiplist.contains(&Algorithm::KNNRegressor) {
526 self.record_model(KNNRegressorWrapper::cv_model(
527 &self.x_train,
528 &self.y_train,
529 &self.settings,
530 ));
531 }
532
533 if let FinalModel::Blending {
534 algorithm,
535 meta_training_fraction,
536 meta_testing_fraction,
537 } = self.settings.final_model_approach
538 {
539 self.train_blended_model(algorithm, meta_training_fraction, meta_testing_fraction);
540 }
541 }
542
543 pub fn save(&self, file_name: &str) {
554 let serial = bincode::serialize(&self).expect("Cannot serialize model.");
555 std::fs::File::create(file_name)
556 .and_then(|mut f| f.write_all(&serial))
557 .expect("Cannot write model to file.");
558 }
559
560 pub fn save_best(&self, file_name: &str) {
575 if matches!(self.settings.final_model_approach, FinalModel::Best) {
576 std::fs::File::create(file_name)
577 .and_then(|mut f| f.write_all(&self.comparison[0].model))
578 .expect("Cannot write model to file.");
579 }
580 }
581}
582
583impl SupervisedModel {
585 fn build(x: DenseMatrix<f32>, y: Vec<f32>, settings: Settings) -> Self {
593 Self {
594 settings,
595 x_train: x,
596 number_of_classes: Self::count_classes(&y),
597 y_train: y,
598 x_val: DenseMatrix::new(0, 0, vec![]),
599 y_val: vec![],
600 comparison: vec![],
601 metamodel: Model::default(),
602 preprocessing_pca: None,
603 preprocessing_svd: None,
604 }
605 }
606
607 fn train_blended_model(
615 &mut self,
616 algo: Algorithm,
617 training_fraction: f32,
618 testing_fraction: f32,
619 ) {
620 let mut meta_x: Vec<Vec<f32>> = Vec::new();
622 for model in &self.comparison {
623 meta_x.push(self.predict_by_model(&self.x_val, model));
624 }
625 let xdm = DenseMatrix::from_2d_vec(&meta_x).transpose();
626
627 let (x_train, x_test, y_train, y_test) = train_test_split(
629 &xdm,
630 &self.y_val,
631 training_fraction / (training_fraction + testing_fraction),
632 self.settings.shuffle,
633 );
634
635 let model = algo.get_trainer()(&x_train, &y_train, &self.settings);
638
639 let train_score = self.settings.get_metric()(
641 &y_train,
642 &algo.get_predictor()(&x_train, &model, &self.settings),
643 );
645 let test_score = self.settings.get_metric()(
646 &y_test,
647 &algo.get_predictor()(&x_test, &model, &self.settings),
648 );
650
651 self.metamodel = Model {
652 score: CrossValidationResult {
653 test_score: vec![test_score; 1],
654 train_score: vec![train_score; 1],
655 },
656 name: algo,
657 duration: Duration::default(),
658 model,
659 };
660 }
661
662 fn predict_blended_model(&self, x: &DenseMatrix<f32>, algo: Algorithm) -> Vec<f32> {
673 let mut meta_x: Vec<Vec<f32>> = Vec::new();
675 for i in 0..self.comparison.len() {
676 let model = &self.comparison[i];
677 meta_x.push(self.predict_by_model(x, model));
678 }
679
680 let xdm = DenseMatrix::from_2d_vec(&meta_x).transpose();
682 let metamodel = &self.metamodel.model;
683
684 algo.get_predictor()(&xdm, metamodel, &self.settings)
686 }
687
688 fn predict_by_model(&self, x: &DenseMatrix<f32>, model: &Model) -> Vec<f32> {
699 model.name.get_predictor()(x, &model.model, &self.settings)
700 }
701
702 fn interaction_features(mut x: DenseMatrix<f32>) -> DenseMatrix<f32> {
706 let (_, width) = x.shape();
707 for i in 0..width {
708 for j in (i + 1)..width {
709 let feature = elementwise_multiply(&x.get_col_as_vec(i), &x.get_col_as_vec(j));
710 let new_column = DenseMatrix::from_row_vector(feature).transpose();
711 x = x.h_stack(&new_column);
712 }
713 }
714 x
715 }
716
717 fn polynomial_features(mut x: DenseMatrix<f32>, order: usize) -> DenseMatrix<f32> {
728 let (height, width) = x.shape();
729 for n in 2..=order {
730 let combinations = (0..width).combinations_with_replacement(n);
731 for combo in combinations {
732 let mut feature = vec![1.0; height];
733 for column in combo {
734 feature = elementwise_multiply(&x.get_col_as_vec(column), &feature);
735 }
736 let new_column = DenseMatrix::from_row_vector(feature).transpose();
737 x = x.h_stack(&new_column);
738 }
739 }
740 x
741 }
742
743 fn train_pca(&mut self, x: &DenseMatrix<f32>, n: usize) {
750 let pca = PCA::fit(
751 x,
752 PCAParameters::default()
753 .with_n_components(n)
754 .with_use_correlation_matrix(true),
755 )
756 .unwrap();
757 self.preprocessing_pca = Some(pca);
758 }
759
760 fn pca_features(&self, x: &DenseMatrix<f32>, _: usize) -> DenseMatrix<f32> {
766 self.preprocessing_pca
767 .as_ref()
768 .unwrap()
769 .transform(x)
770 .unwrap()
771 }
772
773 fn train_svd(&mut self, x: &DenseMatrix<f32>, n: usize) {
780 let svd = SVD::fit(x, SVDParameters::default().with_n_components(n)).unwrap();
781 self.preprocessing_svd = Some(svd);
782 }
783
784 fn svd_features(&self, x: &DenseMatrix<f32>, _: usize) -> DenseMatrix<f32> {
786 self.preprocessing_svd
787 .as_ref()
788 .unwrap()
789 .transform(x)
790 .unwrap()
791 }
792
793 fn preprocess(&self, x: DenseMatrix<f32>) -> DenseMatrix<f32> {
803 match self.settings.preprocessing {
804 PreProcessing::None => x,
805 PreProcessing::AddInteractions => Self::interaction_features(x),
806 PreProcessing::AddPolynomial { order } => Self::polynomial_features(x, order),
807 PreProcessing::ReplaceWithPCA {
808 number_of_components,
809 } => self.pca_features(&x, number_of_components),
810 PreProcessing::ReplaceWithSVD {
811 number_of_components,
812 } => self.svd_features(&x, number_of_components),
813 }
814 }
815
/// Counts the distinct values in a target vector.
///
/// The values are copied, sorted and de-duplicated, so `y` itself is left
/// untouched. Sorting uses [`f32::total_cmp`], which is a total order even
/// when `y` contains `NaN`. The previous `partial_cmp(..).unwrap_or(Equal)`
/// comparator is not a consistent order in that case, and `sort_by` is then
/// allowed to panic or to leave the data in an arbitrary order.
///
/// # Arguments
///
/// * `y` - Target values.
///
/// # Returns
///
/// The number of distinct values, `0` for an empty slice. `NaN` never compares
/// equal to anything, so every `NaN` counts as a value of its own.
fn count_classes(y: &[f32]) -> usize {
    let mut sorted_targets = y.to_vec();
    sorted_targets.sort_unstable_by(f32::total_cmp);
    // Equal values are adjacent after sorting, so `dedup` keeps one of each.
    sorted_targets.dedup();
    sorted_targets.len()
}
831
832 fn record_model(&mut self, model: (CrossValidationResult<f32>, Algorithm, Duration, Vec<u8>)) {
834 self.comparison.push(Model {
835 score: model.0,
836 name: model.1,
837 duration: model.2,
838 model: model.3,
839 });
840 self.sort();
841 }
842
843 fn sort(&mut self) {
845 self.comparison.sort_by(|a, b| {
846 a.score
847 .mean_test_score()
848 .partial_cmp(&b.score.mean_test_score())
849 .unwrap_or(Equal)
850 });
851 if self.settings.sort_by == Metric::RSquared || self.settings.sort_by == Metric::Accuracy {
852 self.comparison.reverse();
853 }
854 }
855}
856
857impl Display for SupervisedModel {
858 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
859 let mut table = Table::new();
860 table.load_preset(UTF8_FULL);
861 table.apply_modifier(UTF8_SOLID_INNER_BORDERS);
862 table.set_header(vec![
863 Cell::new("Model").add_attribute(Attribute::Bold),
864 Cell::new("Time").add_attribute(Attribute::Bold),
865 Cell::new(format!("Training {}", self.settings.sort_by)).add_attribute(Attribute::Bold),
866 Cell::new(format!("Testing {}", self.settings.sort_by)).add_attribute(Attribute::Bold),
867 ]);
868 for model in &self.comparison {
869 let mut row_vec = vec![];
870 row_vec.push(format!("{}", &model.name));
871 row_vec.push(format!("{}", format_duration(model.duration)));
872 let decider =
873 ((model.score.mean_train_score() + model.score.mean_test_score()) / 2.0).abs();
874 if decider > 0.01 && decider < 1000.0 {
875 row_vec.push(format!("{:.2}", &model.score.mean_train_score()));
876 row_vec.push(format!("{:.2}", &model.score.mean_test_score()));
877 } else {
878 row_vec.push(format!("{:.3e}", &model.score.mean_train_score()));
879 row_vec.push(format!("{:.3e}", &model.score.mean_test_score()));
880 }
881
882 table.add_row(row_vec);
883 }
884
885 let mut meta_table = Table::new();
886 meta_table.load_preset(UTF8_FULL);
887 meta_table.apply_modifier(UTF8_SOLID_INNER_BORDERS);
888 meta_table.set_header(vec![
889 Cell::new("Meta Model").add_attribute(Attribute::Bold),
890 Cell::new(format!("Training {}", self.settings.sort_by)).add_attribute(Attribute::Bold),
891 Cell::new(format!("Testing {}", self.settings.sort_by)).add_attribute(Attribute::Bold),
892 ]);
893
894 let mut row_vec = vec![];
896 row_vec.push(format!("{}", self.metamodel.name));
897 let decider = ((self.metamodel.score.mean_train_score()
898 + self.metamodel.score.mean_test_score())
899 / 2.0)
900 .abs();
901 if decider > 0.01 && decider < 1000.0 {
902 row_vec.push(format!("{:.2}", self.metamodel.score.mean_train_score()));
903 row_vec.push(format!("{:.2}", self.metamodel.score.mean_test_score()));
904 } else {
905 row_vec.push(format!("{:.3e}", self.metamodel.score.mean_train_score()));
906 row_vec.push(format!("{:.3e}", self.metamodel.score.mean_test_score()));
907 }
908
909 meta_table.add_row(row_vec);
911
912 write!(f, "{table}\n{meta_table}")
914 }
915}
916
917#[derive(serde::Serialize, serde::Deserialize)]
919struct Model {
920 #[serde(with = "CrossValidationResultDef")]
922 score: CrossValidationResult<f32>,
923 name: Algorithm,
925 duration: Duration,
927 model: Vec<u8>,
929}
930
931impl Default for Model {
932 fn default() -> Self {
933 Self {
934 score: CrossValidationResult {
935 test_score: vec![],
936 train_score: vec![],
937 },
938 name: Algorithm::Linear,
939 duration: Duration::default(),
940 model: vec![],
941 }
942 }
943}
944
945#[derive(serde::Serialize, serde::Deserialize)]
947#[serde(remote = "CrossValidationResult::<f32>")]
948struct CrossValidationResultDef {
949 pub test_score: Vec<f32>,
951 pub train_score: Vec<f32>,
953}