Skip to main content

linfa_ensemble/
lib.rs

1//! # Ensemble Learning Algorithms
2//!
3//! Ensemble methods combine the predictions of several base estimators built with a given
4//! learning algorithm in order to improve generalizability / robustness over a single estimator.
5//!
6//! This crate (`linfa-ensemble`) provides pure Rust implementations of popular ensemble techniques, such as:
7//! * [Bootstrap Aggregation](EnsembleLearner)
8//! * [Random Forest](RandomForest)
9//! * [AdaBoost]
10//!
11//! ## Bootstrap Aggregation (aka Bagging)
12//!
13//! A typical example of an ensemble method is Bootstrap Aggregation, which combines the predictions of
14//! several decision trees (see [`linfa-trees`](linfa_trees)), each trained on a different random subset of the training dataset.
15//!
16//! ## Random Forest
17//!
18//! A special case of Bootstrap Aggregation using decision trees (see [`linfa-trees`](linfa_trees)) with random feature
19//! selection. A typical number of features to select at random is $\sqrt{p}$, where $p$ is
20//! the number of available features.
21//!
22//! ## AdaBoost
23//!
24//! AdaBoost (Adaptive Boosting) is a boosting ensemble method that trains weak learners sequentially.
25//! Each subsequent learner focuses on the examples that previous learners misclassified by increasing
26//! their sample weights. The final prediction is a weighted vote of all learners, where better-performing
27//! learners receive higher weights. Unlike bagging methods, boosting creates a strong classifier from
28//! weak learners (typically shallow decision trees or "stumps").
29//!
30//! ## Reference
31//!
32//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
33//! * [An Introduction to Statistical Learning](https://www.statlearning.com/)
34//!
35//! ## Example
36//!
37//! This example shows how to train a bagging model using 100 decision trees,
38//! each trained on 70% of the training data (bootstrap sampling).
39//!
40//! ```no_run
41//! use linfa::prelude::{Fit, Predict};
42//! use linfa_ensemble::EnsembleLearnerParams;
43//! use linfa_trees::DecisionTree;
44//! use ndarray_rand::rand::SeedableRng;
45//! use rand::rngs::SmallRng;
46//!
47//! // Load Iris dataset
48//! let mut rng = SmallRng::seed_from_u64(42);
49//! let (train, test) = linfa_datasets::iris()
50//!     .shuffle(&mut rng)
51//!     .split_with_ratio(0.8);
52//!
53//! // Train the model on the iris dataset
54//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
55//!     .ensemble_size(100)        // Number of Decision Tree to fit
56//!     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
57//!     .fit(&train)
58//!     .unwrap();
59//!
60//! // Make predictions on the test set
61//! let predictions = bagging_model.predict(&test);
62//! ```
63//!
64//! This example shows how to train a [Random Forest](RandomForest) model using 100 decision trees,
65//! each trained on 70% of the training data (bootstrap sampling) and using only
66//! 30% of the available features.
67//!
68//! ```no_run
69//! use linfa::prelude::{Fit, Predict};
70//! use linfa_ensemble::RandomForestParams;
71//! use linfa_trees::DecisionTree;
72//! use ndarray_rand::rand::SeedableRng;
73//! use rand::rngs::SmallRng;
74//!
75//! // Load Iris dataset
76//! let mut rng = SmallRng::seed_from_u64(42);
77//! let (train, test) = linfa_datasets::iris()
78//!     .shuffle(&mut rng)
79//!     .split_with_ratio(0.8);
80//!
81//! // Train the model on the iris dataset
82//! let random_forest = RandomForestParams::new(DecisionTree::params())
83//!     .ensemble_size(100)        // Number of decision trees to fit
84//!     .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
85//!     .feature_proportion(0.3)   // Select only 30% of the features
86//!     .fit(&train)
87//!     .unwrap();
88//!
89//! // Make predictions on the test set
90//! let predictions = random_forest.predict(&test);
91//! ```
92
93mod adaboost;
94mod adaboost_hyperparams;
95mod algorithm;
96mod hyperparams;
97
98pub use adaboost::*;
99pub use adaboost_hyperparams::*;
100pub use algorithm::*;
101pub use hyperparams::*;
102
#[cfg(test)]
mod tests {
    // Tests for the ensemble learners exported by this crate. Every test seeds its RNGs with
    // `SEED`, so dataset shuffles and model fitting are reproducible from run to run.
    use super::*;
    use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    /// Seed shared by all tests for shuffling the dataset and fitting the models.
    const SEED: u64 = 42;

    /// A random forest of 100 trees (70% bootstrap samples, 30% of features) should reach at
    /// least 90% accuracy on a held-out 20% of Iris.
    #[test]
    fn test_random_forest_accuracy_on_iris_dataset() {
        let mut rng = SmallRng::seed_from_u64(SEED);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        let model = RandomForestParams::new_fixed_rng(DecisionTree::params(), rng)
            .ensemble_size(100)
            .bootstrap_proportion(0.7)
            .feature_proportion(0.3)
            .fit(&train)
            .unwrap();

        let predictions = model.predict(&test);

        let cm = predictions.confusion_matrix(&test).unwrap();
        let acc = cm.accuracy();
        assert!(acc >= 0.9, "Expected accuracy of at least 90%, got {}", acc);
    }

    /// A bagging ensemble of 100 trees (70% bootstrap samples) should reach at least 90%
    /// accuracy on a held-out 20% of Iris.
    #[test]
    fn test_ensemble_learner_accuracy_on_iris_dataset() {
        let mut rng = SmallRng::seed_from_u64(SEED);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        let model = EnsembleLearnerParams::new_fixed_rng(DecisionTree::params(), rng)
            .ensemble_size(100)
            .bootstrap_proportion(0.7)
            .fit(&train)
            .unwrap();

        let predictions = model.predict(&test);

        let cm = predictions.confusion_matrix(&test).unwrap();
        let acc = cm.accuracy();
        assert!(acc >= 0.9, "Expected accuracy of at least 90%, got {}", acc);
    }

    /// AdaBoost with 50 depth-1 trees ("stumps") should reach at least 85% accuracy on a
    /// held-out 20% of Iris.
    #[test]
    fn test_adaboost_accuracy_on_iris_dataset() {
        let mut rng = SmallRng::seed_from_u64(SEED);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        // Train AdaBoost with decision tree stumps (shallow trees)
        let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng)
            .n_estimators(50)
            .learning_rate(1.0)
            .fit(&train)
            .unwrap();

        let predictions = model.predict(&test);

        let cm = predictions.confusion_matrix(&test).unwrap();
        let acc = cm.accuracy();
        assert!(
            acc >= 0.85,
            "Expected accuracy of at least 85%, got {}",
            acc
        );
    }

    /// A lower learning rate (0.5) combined with more estimators (100, depth-2 trees) should
    /// still reach at least 85% accuracy on a held-out 20% of Iris.
    #[test]
    fn test_adaboost_with_low_learning_rate() {
        let mut rng = SmallRng::seed_from_u64(SEED);
        let (train, test) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        // Train AdaBoost with lower learning rate and more estimators
        let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(2)), rng)
            .n_estimators(100)
            .learning_rate(0.5)
            .fit(&train)
            .unwrap();

        let predictions = model.predict(&test);

        let cm = predictions.confusion_matrix(&test).unwrap();
        let acc = cm.accuracy();
        assert!(
            acc >= 0.85,
            "Expected accuracy of at least 85%, got {}",
            acc
        );
    }

    /// Every estimator weight of a 10-stump AdaBoost model is strictly positive, and all 10
    /// requested estimators are kept.
    #[test]
    fn test_adaboost_model_weights() {
        let mut rng = SmallRng::seed_from_u64(SEED);
        let (train, _) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng)
            .n_estimators(10)
            .fit(&train)
            .unwrap();

        // Verify that model weights are positive
        for weight in model.weights() {
            assert!(
                *weight > 0.0,
                "Model weight should be positive, got {}",
                weight
            );
        }

        // Verify we have the expected number of models
        assert_eq!(model.n_estimators(), 10);
    }

    /// Two otherwise identical AdaBoost models trained with learning rates 1.0 and 0.5 must end
    /// up with at least one noticeably different estimator weight.
    #[test]
    fn test_adaboost_different_learning_rates() {
        // Identical seeds make the learning rate the only difference between the two fits.
        let rng1 = SmallRng::seed_from_u64(SEED);
        let rng2 = SmallRng::seed_from_u64(SEED);
        let (train, _) = linfa_datasets::iris()
            .shuffle(&mut SmallRng::seed_from_u64(SEED))
            .split_with_ratio(0.8);

        // Train with learning_rate = 1.0
        let model1 = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng1)
            .n_estimators(10)
            .learning_rate(1.0)
            .fit(&train)
            .unwrap();

        // Train with learning_rate = 0.5
        let model2 = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng2)
            .n_estimators(10)
            .learning_rate(0.5)
            .fit(&train)
            .unwrap();

        // At least one pair of corresponding weights should differ by more than 0.01. `zip`
        // stops at the shorter sequence, so only estimators present in both models are compared.
        let has_difference = model1
            .weights()
            .iter()
            .zip(model2.weights().iter())
            .any(|(w1, w2)| (w1 - w2).abs() > 0.01);

        assert!(
            has_difference,
            "Different learning rates should produce different model weights"
        );
    }

    /// On a trivially separable two-class dataset, AdaBoost should stop adding estimators
    /// before reaching the requested cap of 50.
    #[test]
    fn test_adaboost_early_stopping_on_perfect_fit() {
        use linfa::DatasetBase;
        use ndarray::Array2;

        // Create a simple linearly separable dataset
        let records = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, // class 0
                0.1, 0.1, // class 0
                0.2, 0.2, // class 0
                1.0, 1.0, // class 1
                1.1, 1.1, // class 1
                1.2, 1.2, // class 1
            ],
        )
        .unwrap();
        let targets = ndarray::array![0, 0, 0, 1, 1, 1];
        let dataset = DatasetBase::new(records, targets);

        let rng = SmallRng::seed_from_u64(SEED);
        let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(3)), rng)
            .n_estimators(50)
            .fit(&dataset)
            .unwrap();

        // Should stop early due to perfect classification
        assert!(
            model.n_estimators() < 50,
            "Expected early stopping, but got {} estimators",
            model.n_estimators()
        );
    }

    /// Fitting AdaBoost on a dataset whose targets all share one class must return an error.
    #[test]
    fn test_adaboost_single_class_error() {
        use linfa::DatasetBase;
        use ndarray::Array2;

        // Create dataset with only one class
        let records =
            Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3]).unwrap();
        let targets = ndarray::array![0, 0, 0, 0]; // All same class
        let dataset = DatasetBase::new(records, targets);

        let rng = SmallRng::seed_from_u64(SEED);
        let result = AdaBoostParams::new_fixed_rng(DecisionTree::params(), rng)
            .n_estimators(10)
            .fit(&dataset);

        assert!(result.is_err(), "Should fail with single class dataset");
    }

    /// A model fitted on Iris should expose its three class labels as `[0, 1, 2]`.
    #[test]
    fn test_adaboost_classes_method() {
        let mut rng = SmallRng::seed_from_u64(SEED);
        let (train, _) = linfa_datasets::iris()
            .shuffle(&mut rng)
            .split_with_ratio(0.8);

        let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng)
            .n_estimators(10)
            .fit(&train)
            .unwrap();

        // Verify classes are properly stored
        let classes = &model.classes;
        assert_eq!(classes.len(), 3, "Iris has 3 classes");
        assert_eq!(classes, &vec![0, 1, 2], "Classes should be [0, 1, 2]");
    }
}