use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
use linfa_ensemble::{EnsembleLearnerParams, RandomForestParams};
use linfa_trees::DecisionTree;
use ndarray_rand::rand::SeedableRng;
use rand::rngs::SmallRng;
/// Trains a bagging ensemble of decision trees on the Iris dataset and prints
/// the test-set predictions, the confusion matrix and the test accuracy (in %).
///
/// The dataset is shuffled with a `SmallRng` seeded with `42` and split 80/20
/// into train/test sets. The same RNG is then moved into the learner via
/// `new_fixed_rng`, so repeated runs produce the same output.
///
/// # Arguments
/// * `ensemble_size` - number of models in the ensemble (forwarded to
///   `EnsembleLearnerParams::ensemble_size`).
/// * `bootstrap_proportion` - forwarded to
///   `EnsembleLearnerParams::bootstrap_proportion`; presumably the fraction of
///   the training set sampled for each model.
///
/// # Panics
/// Panics if fitting the ensemble fails (presumably on invalid
/// hyperparameters) or if the confusion matrix cannot be built.
fn ensemble_learner(ensemble_size: usize, bootstrap_proportion: f64) {
    // Fixed seed so that shuffling (and the learner's own sampling) is reproducible.
    let mut rng = SmallRng::seed_from_u64(42);
    let (train, test) = linfa_datasets::iris()
        .shuffle(&mut rng)
        .split_with_ratio(0.8);

    // The already-advanced RNG is handed to the learner, so its randomness
    // continues from the same seeded stream rather than an unseeded source.
    let model = EnsembleLearnerParams::new_fixed_rng(DecisionTree::params(), rng)
        .ensemble_size(ensemble_size)
        .bootstrap_proportion(bootstrap_proportion)
        .fit(&train)
        .expect("failed to fit bagging ensemble; check ensemble_size and bootstrap_proportion");

    let final_predictions_ensemble = model.predict(&test);
    println!("Final Predictions: \n{final_predictions_ensemble:?}");

    let cm = final_predictions_ensemble
        .confusion_matrix(&test)
        .expect("failed to build confusion matrix from predictions and test targets");
    println!("{cm:?}");

    println!(
        "Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {ensemble_size},\n Bootstrap Proportion: {bootstrap_proportion}.\n",
        100.0 * cm.accuracy()
    );
}
/// Trains a random forest of decision trees on the Iris dataset and prints the
/// test-set predictions, the confusion matrix and the test accuracy (in %).
///
/// The dataset is shuffled with a `SmallRng` seeded with `42` and split 80/20
/// into train/test sets. The same RNG is then moved into the learner via
/// `new_fixed_rng`, so repeated runs produce the same output.
///
/// # Arguments
/// * `ensemble_size` - number of trees (forwarded to
///   `RandomForestParams::ensemble_size`).
/// * `bootstrap_proportion` - forwarded to
///   `RandomForestParams::bootstrap_proportion`; presumably the fraction of the
///   training set sampled for each tree.
/// * `feature_proportion` - forwarded to
///   `RandomForestParams::feature_proportion`; presumably the fraction of
///   features each tree is allowed to use.
///
/// # Panics
/// Panics if fitting the forest fails (presumably on invalid
/// hyperparameters) or if the confusion matrix cannot be built.
fn random_forest(ensemble_size: usize, bootstrap_proportion: f64, feature_proportion: f64) {
    // Fixed seed so that shuffling (and the learner's own sampling) is reproducible.
    let mut rng = SmallRng::seed_from_u64(42);
    let (train, test) = linfa_datasets::iris()
        .shuffle(&mut rng)
        .split_with_ratio(0.8);

    // The already-advanced RNG is handed to the learner, so its randomness
    // continues from the same seeded stream rather than an unseeded source.
    let model = RandomForestParams::new_fixed_rng(DecisionTree::params(), rng)
        .ensemble_size(ensemble_size)
        .bootstrap_proportion(bootstrap_proportion)
        .feature_proportion(feature_proportion)
        .fit(&train)
        .expect(
            "failed to fit random forest; check ensemble_size, bootstrap_proportion and feature_proportion",
        );

    let final_predictions_ensemble = model.predict(&test);
    println!("Final Predictions: \n{final_predictions_ensemble:?}");

    let cm = final_predictions_ensemble
        .confusion_matrix(&test)
        .expect("failed to build confusion matrix from predictions and test targets");
    println!("{cm:?}");

    println!(
        "Test accuracy: {} \n with default Decision Tree params, \n Ensemble Size: {ensemble_size},\n Bootstrap Proportion: {bootstrap_proportion}\n Feature selection proportion: {feature_proportion}.\n",
        100.0 * cm.accuracy()
    );
}
/// Runs both Iris examples in turn: plain bagging of decision trees first,
/// then a random forest, using the same ensemble size and bootstrap proportion.
fn main() {
    const ENSEMBLE_SIZE: usize = 100;
    const BOOTSTRAP_PROPORTION: f64 = 0.7;
    // Only the random-forest example takes a feature proportion.
    const FEATURE_PROPORTION: f64 = 0.2;

    println!("An example using Bagging with Decision Tree on Iris Dataset");
    ensemble_learner(ENSEMBLE_SIZE, BOOTSTRAP_PROPORTION);

    println!("An example using a Random Forest on Iris Dataset");
    random_forest(ENSEMBLE_SIZE, BOOTSTRAP_PROPORTION, FEATURE_PROPORTION);
}