// linfa_ensemble/adaboost.rs
use crate::AdaBoostValidParams;
use linfa::{
    dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned},
    error::Error,
    traits::*,
    DatasetBase,
};
use ndarray::{Array1, Array2, Axis};
use ndarray_rand::rand::distributions::WeightedIndex;
use ndarray_rand::rand::prelude::*;
use ndarray_rand::rand::Rng;
use std::{
    cmp::Eq,
    collections::{HashMap, HashSet},
    hash::Hash,
};

/// Voting weight given to a base learner whose weighted training error is zero.
///
/// The SAMME learner weight `ln((1 - e) / e)` is unbounded as the weighted
/// error `e` approaches zero, so a large finite constant is used instead.
const PERFECT_MODEL_WEIGHT: f64 = 1_000_000.0;

/// A fitted AdaBoost ensemble whose base learners are combined by a weighted
/// majority vote: `models[i]` votes with weight `model_weights[i]`.
#[derive(Debug, Clone)]
pub struct AdaBoost<M, L> {
    /// Fitted base learners, in the order they were trained.
    pub models: Vec<M>,
    /// Voting weight of each base learner, index-aligned with `models`.
    pub model_weights: Vec<f64>,
    /// Distinct labels seen while fitting, ordered by their class index.
    pub classes: Vec<L>,
}

87impl<M, L> AdaBoost<M, L> {
88 pub fn n_estimators(&self) -> usize {
90 self.models.len()
91 }
92
93 pub fn weights(&self) -> &[f64] {
95 &self.model_weights
96 }
97}
98
99impl<F: Clone, T, M, L> PredictInplace<Array2<F>, T> for AdaBoost<M, L>
100where
101 M: PredictInplace<Array2<F>, T>,
102 <T as AsTargets>::Elem: Copy + Eq + Hash + std::fmt::Debug + Into<usize>,
103 T: AsTargets + AsTargetsMut<Elem = <T as AsTargets>::Elem>,
104 usize: Into<<T as AsTargets>::Elem>,
105{
106 fn predict_inplace(&self, x: &Array2<F>, y: &mut T) {
107 let y_array = y.as_targets();
108 assert_eq!(
109 x.nrows(),
110 y_array.len_of(Axis(0)),
111 "The number of data points must match the number of outputs."
112 );
113
114 let mut all_predictions = Vec::with_capacity(self.models.len());
116 for model in &self.models {
117 let mut pred = model.default_target(x);
118 model.predict_inplace(x, &mut pred);
119 all_predictions.push(pred);
120 }
121
122 let mut prediction_maps = y_array.map(|_| HashMap::new());
124
125 for (model_idx, prediction) in all_predictions.iter().enumerate() {
127 let pred_array = prediction.as_targets();
128 let weight = self.model_weights[model_idx];
129
130 for (vote_map, &pred_val) in prediction_maps.iter_mut().zip(pred_array.iter()) {
132 let class_idx: usize = pred_val.into();
133 *vote_map.entry(class_idx).or_insert(0.0) += weight;
134 }
135 }
136
137 let final_predictions = prediction_maps.map(|votes| {
139 votes
140 .iter()
141 .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap())
142 .map(|(k, _)| (*k).into())
143 .expect("No predictions found")
144 });
145
146 let mut y_array_mut = y.as_targets_mut();
148 for (y, pred) in y_array_mut.iter_mut().zip(final_predictions.iter()) {
149 *y = *pred;
150 }
151 }
152
153 fn default_target(&self, x: &Array2<F>) -> T {
154 self.models[0].default_target(x)
155 }
156}
157
158impl<D, T, P, R> Fit<Array2<D>, T, Error> for AdaBoostValidParams<P, R>
159where
160 D: Clone + ndarray::ScalarOperand,
161 T: FromTargetArrayOwned<Owned = T> + AsTargets + Clone,
162 T::Elem: Copy + Eq + Hash + std::fmt::Debug + Into<usize>,
163 P: Fit<Array2<D>, T, Error> + Clone,
164 P::Object: PredictInplace<Array2<D>, T>,
165 R: Rng + Clone,
166 usize: Into<T::Elem>,
167{
168 type Object = AdaBoost<P::Object, T::Elem>;
169
170 fn fit(
171 &self,
172 dataset: &DatasetBase<Array2<D>, T>,
173 ) -> core::result::Result<Self::Object, Error> {
174 let n_samples = dataset.records.nrows();
175
176 if n_samples == 0 {
177 return Err(Error::Parameters(
178 "Cannot fit AdaBoost on empty dataset".to_string(),
179 ));
180 }
181
182 let target_array = dataset.targets.as_targets();
184 let mut classes_set: Vec<T::Elem> = target_array
185 .iter()
186 .copied()
187 .collect::<std::collections::HashSet<_>>()
188 .into_iter()
189 .collect();
190 classes_set.sort_unstable_by_key(|x| (*x).into());
192
193 if classes_set.len() < 2 {
194 return Err(Error::Parameters(
195 "AdaBoost requires at least 2 classes".to_string(),
196 ));
197 }
198
199 let mut sample_weights = Array1::from_elem(n_samples, 1.0 / n_samples as f64);
201
202 let mut models = Vec::with_capacity(self.n_estimators);
203 let mut model_weights = Vec::with_capacity(self.n_estimators);
204
205 let mut rng = self.rng.clone();
206
207 for iteration in 0..self.n_estimators {
208 let weight_sum = sample_weights.sum();
210 if weight_sum <= 0.0 {
211 return Err(Error::NotConverged(format!(
212 "Sample weights sum to zero at iteration {}",
213 iteration
214 )));
215 }
216 sample_weights /= weight_sum;
217
218 let dist = WeightedIndex::new(sample_weights.iter().copied())
221 .map_err(|_| Error::Parameters("Invalid sample weights".to_string()))?;
222
223 let bootstrap_indices: Vec<usize> =
224 (0..n_samples).map(|_| dist.sample(&mut rng)).collect();
225
226 let bootstrap_records = dataset.records.select(Axis(0), &bootstrap_indices);
228 let bootstrap_targets_array = target_array.select(Axis(0), &bootstrap_indices);
229
230 let bootstrap_targets = T::new_targets(bootstrap_targets_array);
232 let bootstrap_dataset = DatasetBase::new(bootstrap_records, bootstrap_targets);
233
234 let model = self.model_params.fit(&bootstrap_dataset).map_err(|e| {
236 Error::NotConverged(format!(
237 "Base learner failed to fit at iteration {}: {}",
238 iteration, e
239 ))
240 })?;
241
242 let mut predictions = model.default_target(&dataset.records);
244 model.predict_inplace(&dataset.records, &mut predictions);
245 let pred_array = predictions.as_targets();
246
247 let mut weighted_error = 0.0;
249 for ((true_label, pred_label), weight) in target_array
250 .iter()
251 .zip(pred_array.iter())
252 .zip(sample_weights.iter())
253 {
254 let true_idx: usize = (*true_label).into();
255 let pred_idx: usize = (*pred_label).into();
256
257 if true_idx != pred_idx {
258 weighted_error += *weight;
259 }
260 }
261
262 if weighted_error <= 0.0 {
264 model_weights.push(PERFECT_MODEL_WEIGHT); models.push(model);
267 break;
268 }
269
270 let k = classes_set.len() as f64;
272 let error_threshold = (k - 1.0) / k;
273
274 if weighted_error >= error_threshold {
275 if models.is_empty() {
277 return Err(Error::NotConverged(format!(
278 "First base learner performs worse than random guessing (error: {:.4}, threshold: {:.4})",
279 weighted_error, error_threshold
280 )));
281 }
282 break;
283 }
284
285 let error_ratio = (1.0 - weighted_error) / weighted_error;
289 let alpha = self.learning_rate * (error_ratio.ln() + (k - 1.0).ln());
290
291 for ((true_label, pred_label), weight) in target_array
293 .iter()
294 .zip(pred_array.iter())
295 .zip(sample_weights.iter_mut())
296 {
297 let true_idx: usize = (*true_label).into();
298 let pred_idx: usize = (*pred_label).into();
299
300 if true_idx != pred_idx {
301 *weight *= alpha.exp();
303 }
304 }
305
306 model_weights.push(alpha);
307 models.push(model);
308 }
309
310 if models.is_empty() {
311 return Err(Error::NotConverged(
312 "No models were successfully trained".to_string(),
313 ));
314 }
315
316 Ok(AdaBoost {
317 models,
318 model_weights,
319 classes: classes_set,
320 })
321 }
322}