linfa_ensemble/lib.rs
1//! # Ensemble Learning Algorithms
2//!
3//! Ensemble methods combine the predictions of several base estimators built with a given
4//! learning algorithm in order to improve generalizability / robustness over a single estimator.
5//!
//! This crate (`linfa-ensemble`) provides pure Rust implementations of popular ensemble techniques, such as
//! * [Bootstrap Aggregation](EnsembleLearner)
8//! * [Random Forest](RandomForest)
9//! * [AdaBoost]
10//!
11//! ## Bootstrap Aggregation (aka Bagging)
12//!
13//! A typical example of ensemble method is Bootstrap Aggregation, which combines the predictions of
//! several decision trees (see [`linfa-trees`](linfa_trees)) trained on different sample subsets of the training dataset.
15//!
16//! ## Random Forest
17//!
//! A special case of Bootstrap Aggregation using decision trees (see [`linfa-trees`](linfa_trees)) with random feature
//! selection. A typical number of features to select at random is $\sqrt{p}$, with $p$ being
//! the number of available features.
21//!
22//! ## AdaBoost
23//!
24//! AdaBoost (Adaptive Boosting) is a boosting ensemble method that trains weak learners sequentially.
25//! Each subsequent learner focuses on the examples that previous learners misclassified by increasing
26//! their sample weights. The final prediction is a weighted vote of all learners, where better-performing
27//! learners receive higher weights. Unlike bagging methods, boosting creates a strong classifier from
28//! weak learners (typically shallow decision trees or "stumps").
29//!
30//! ## Reference
31//!
32//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
33//! * [An Introduction to Statistical Learning](https://www.statlearning.com/)
34//!
35//! ## Example
36//!
37//! This example shows how to train a bagging model using 100 decision trees,
38//! each trained on 70% of the training data (bootstrap sampling).
39//!
40//! ```no_run
41//! use linfa::prelude::{Fit, Predict};
42//! use linfa_ensemble::EnsembleLearnerParams;
43//! use linfa_trees::DecisionTree;
44//! use ndarray_rand::rand::SeedableRng;
45//! use rand::rngs::SmallRng;
46//!
47//! // Load Iris dataset
48//! let mut rng = SmallRng::seed_from_u64(42);
49//! let (train, test) = linfa_datasets::iris()
50//! .shuffle(&mut rng)
51//! .split_with_ratio(0.8);
52//!
53//! // Train the model on the iris dataset
54//! let bagging_model = EnsembleLearnerParams::new(DecisionTree::params())
//! .ensemble_size(100) // Number of decision trees to fit
56//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
57//! .fit(&train)
58//! .unwrap();
59//!
60//! // Make predictions on the test set
61//! let predictions = bagging_model.predict(&test);
62//! ```
63//!
64//! This example shows how to train a [Random Forest](RandomForest) model using 100 decision trees,
65//! each trained on 70% of the training data (bootstrap sampling) and using only
66//! 30% of the available features.
67//!
68//! ```no_run
69//! use linfa::prelude::{Fit, Predict};
70//! use linfa_ensemble::RandomForestParams;
71//! use linfa_trees::DecisionTree;
72//! use ndarray_rand::rand::SeedableRng;
73//! use rand::rngs::SmallRng;
74//!
75//! // Load Iris dataset
76//! let mut rng = SmallRng::seed_from_u64(42);
77//! let (train, test) = linfa_datasets::iris()
78//! .shuffle(&mut rng)
79//! .split_with_ratio(0.8);
80//!
81//! // Train the model on the iris dataset
82//! let random_forest = RandomForestParams::new(DecisionTree::params())
//! .ensemble_size(100) // Number of decision trees to fit
84//! .bootstrap_proportion(0.7) // Select only 70% of the data via bootstrap
//! .feature_proportion(0.3) // Select only 30% of the features
86//! .fit(&train)
87//! .unwrap();
88//!
89//! // Make predictions on the test set
90//! let predictions = random_forest.predict(&test);
91//! ```
92
93mod adaboost;
94mod adaboost_hyperparams;
95mod algorithm;
96mod hyperparams;
97
98pub use adaboost::*;
99pub use adaboost_hyperparams::*;
100pub use algorithm::*;
101pub use hyperparams::*;
102
#[cfg(test)]
mod tests {
    // Checks for the bagging, random forest and AdaBoost learners exported by this crate,
    // run against the Iris dataset and a couple of tiny hand-built datasets.
    use super::*;
    use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    /// Fits a 100-tree random forest (bootstrap proportion 0.7, feature proportion 0.3)
    /// on the shuffled Iris split and requires an accuracy of at least 0.9 on the test part.
    #[test]
    fn test_random_forest_accuracy_on_iris_dataset() {
        let mut rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut rng);
        let (train, test) = shuffled.split_with_ratio(0.8);

        // The RNG that shuffled the data is moved into the learner afterwards.
        let forest = RandomForestParams::new_fixed_rng(DecisionTree::params(), rng)
            .ensemble_size(100)
            .bootstrap_proportion(0.7)
            .feature_proportion(0.3)
            .fit(&train)
            .unwrap();

        let predicted = forest.predict(&test);
        let accuracy = predicted.confusion_matrix(&test).unwrap().accuracy();
        assert!(
            accuracy >= 0.9,
            "Expected accuracy to be above 90%, got {}",
            accuracy
        );
    }

    /// Fits a 100-tree bagging ensemble (bootstrap proportion 0.7) on the shuffled Iris
    /// split and requires an accuracy of at least 0.9 on the test part.
    #[test]
    fn test_ensemble_learner_accuracy_on_iris_dataset() {
        let mut rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut rng);
        let (train, test) = shuffled.split_with_ratio(0.8);

        let bagging = EnsembleLearnerParams::new_fixed_rng(DecisionTree::params(), rng)
            .ensemble_size(100)
            .bootstrap_proportion(0.7)
            .fit(&train)
            .unwrap();

        let predicted = bagging.predict(&test);
        let accuracy = predicted.confusion_matrix(&test).unwrap().accuracy();
        assert!(
            accuracy >= 0.9,
            "Expected accuracy to be above 90%, got {}",
            accuracy
        );
    }

    /// Fits AdaBoost with 50 depth-1 trees and learning rate 1.0 on the shuffled Iris
    /// split and requires an accuracy of at least 0.85 on the test part.
    #[test]
    fn test_adaboost_accuracy_on_iris_dataset() {
        let mut rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut rng);
        let (train, test) = shuffled.split_with_ratio(0.8);

        // Decision tree stumps (shallow trees) serve as the base learner.
        let stump = DecisionTree::params().max_depth(Some(1));
        let boosted = AdaBoostParams::new_fixed_rng(stump, rng)
            .n_estimators(50)
            .learning_rate(1.0)
            .fit(&train)
            .unwrap();

        let predicted = boosted.predict(&test);
        let accuracy = predicted.confusion_matrix(&test).unwrap().accuracy();
        assert!(
            accuracy >= 0.85,
            "Expected accuracy to be above 85%, got {}",
            accuracy
        );
    }

    /// Fits AdaBoost with 100 depth-2 trees and learning rate 0.5 on the shuffled Iris
    /// split and requires an accuracy of at least 0.85 on the test part.
    #[test]
    fn test_adaboost_with_low_learning_rate() {
        let mut rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut rng);
        let (train, test) = shuffled.split_with_ratio(0.8);

        // Lower learning rate, compensated by a larger number of estimators.
        let shallow_tree = DecisionTree::params().max_depth(Some(2));
        let boosted = AdaBoostParams::new_fixed_rng(shallow_tree, rng)
            .n_estimators(100)
            .learning_rate(0.5)
            .fit(&train)
            .unwrap();

        let predicted = boosted.predict(&test);
        let accuracy = predicted.confusion_matrix(&test).unwrap().accuracy();
        assert!(
            accuracy >= 0.85,
            "Expected accuracy to be above 85%, got {}",
            accuracy
        );
    }

    /// After fitting 10 depth-1 estimators on Iris, every model weight must be strictly
    /// positive and the fitted model must report exactly 10 estimators.
    #[test]
    fn test_adaboost_model_weights() {
        let mut rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut rng);
        let (train, _) = shuffled.split_with_ratio(0.8);

        let stump = DecisionTree::params().max_depth(Some(1));
        let boosted = AdaBoostParams::new_fixed_rng(stump, rng)
            .n_estimators(10)
            .fit(&train)
            .unwrap();

        // Each learner's weight is checked individually so a failure reports the value.
        for &weight in boosted.weights() {
            assert!(
                weight > 0.0,
                "Model weight should be positive, got {}",
                weight
            );
        }

        // The number of fitted models must match the requested count.
        assert_eq!(boosted.n_estimators(), 10);
    }

    /// Two AdaBoost models that differ only in learning rate (1.0 vs 0.5) must end up
    /// with at least one pair of model weights that differ by more than 0.01.
    #[test]
    fn test_adaboost_different_learning_rates() {
        // Both learners get their own RNG built from the same seed.
        let full_rate_rng = SmallRng::seed_from_u64(42);
        let half_rate_rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut SmallRng::seed_from_u64(42));
        let (train, _) = shuffled.split_with_ratio(0.8);

        // learning_rate = 1.0
        let full_rate_model = AdaBoostParams::new_fixed_rng(
            DecisionTree::params().max_depth(Some(1)),
            full_rate_rng,
        )
        .n_estimators(10)
        .learning_rate(1.0)
        .fit(&train)
        .unwrap();

        // learning_rate = 0.5
        let half_rate_model = AdaBoostParams::new_fixed_rng(
            DecisionTree::params().max_depth(Some(1)),
            half_rate_rng,
        )
        .n_estimators(10)
        .learning_rate(0.5)
        .fit(&train)
        .unwrap();

        // Compare the weights pairwise; one noticeably different pair is enough.
        let has_difference = full_rate_model
            .weights()
            .iter()
            .zip(half_rate_model.weights().iter())
            .any(|(w1, w2)| (w1 - w2).abs() > 0.01);

        assert!(
            has_difference,
            "Different learning rates should produce different model weights"
        );
    }

    /// On a tiny dataset whose two classes are cleanly separated, a run configured for
    /// 50 estimators is expected to end with fewer than 50 fitted estimators.
    #[test]
    fn test_adaboost_early_stopping_on_perfect_fit() {
        use linfa::DatasetBase;

        // Six 2-D points: the first three rows are class 0, the last three class 1.
        let records = ndarray::array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.2],
            [1.0, 1.0],
            [1.1, 1.1],
            [1.2, 1.2],
        ];
        let targets = ndarray::array![0, 0, 0, 1, 1, 1];
        let dataset = DatasetBase::new(records, targets);

        let deep_tree = DecisionTree::params().max_depth(Some(3));
        let boosted = AdaBoostParams::new_fixed_rng(deep_tree, SmallRng::seed_from_u64(42))
            .n_estimators(50)
            .fit(&dataset)
            .unwrap();

        // Fewer estimators than requested means the fit stopped early.
        assert!(
            boosted.n_estimators() < 50,
            "Expected early stopping, but got {} estimators",
            boosted.n_estimators()
        );
    }

    /// Fitting AdaBoost on a dataset whose targets all share one class must return an error.
    #[test]
    fn test_adaboost_single_class_error() {
        use linfa::DatasetBase;
        use ndarray::Array2;

        // Four 2-D points, all labelled with the same class.
        let flat_values = vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3];
        let records = Array2::from_shape_vec((4, 2), flat_values).unwrap();
        let targets = ndarray::array![0, 0, 0, 0];
        let dataset = DatasetBase::new(records, targets);

        let outcome =
            AdaBoostParams::new_fixed_rng(DecisionTree::params(), SmallRng::seed_from_u64(42))
                .n_estimators(10)
                .fit(&dataset);

        assert!(outcome.is_err(), "Should fail with single class dataset");
    }

    /// The fitted AdaBoost model must expose the three Iris classes as `[0, 1, 2]`
    /// through its `classes` field.
    #[test]
    fn test_adaboost_classes_method() {
        let mut rng = SmallRng::seed_from_u64(42);
        let shuffled = linfa_datasets::iris().shuffle(&mut rng);
        let (train, _) = shuffled.split_with_ratio(0.8);

        let stump = DecisionTree::params().max_depth(Some(1));
        let boosted = AdaBoostParams::new_fixed_rng(stump, rng)
            .n_estimators(10)
            .fit(&train)
            .unwrap();

        // Check both the count and the exact contents of the stored classes.
        let stored_classes = &boosted.classes;
        assert_eq!(stored_classes.len(), 3, "Iris has 3 classes");
        assert_eq!(
            stored_classes,
            &vec![0, 1, 2],
            "Classes should be [0, 1, 2]"
        );
    }
}
340}