linfa_ensemble/adaboost_hyperparams.rs
1use linfa::{
2 error::{Error, Result},
3 ParamGuard,
4};
5use rand::rngs::ThreadRng;
6use rand::Rng;
7
/// Validated hyperparameters for the [AdaBoost](crate::AdaBoost) algorithm.
///
/// This is the checked counterpart of the [`AdaBoostParams`] builder: the builder's
/// `check`/`check_ref` methods hand out this type once the values have passed validation.
///
/// ## Parameters
///
/// * `n_estimators`: Upper bound on the number of weak learners trained one after another.
///   Adding estimators tends to improve the fit at the cost of longer training and a higher
///   risk of overfitting. Typical values range from 50 to 500. Default: 50.
///
/// * `learning_rate`: Scales down the contribution of each classifier. It trades off against
///   `n_estimators`: a smaller rate needs more estimators to reach the same performance but
///   may generalize better. Must be positive and finite. Default: 1.0.
///
/// * `model_params`: Parameters of the base learner (weak classifier). Shallow decision trees
///   (stumps with max_depth=1 or max_depth=2) are the usual choice.
///
/// * `rng`: Source of randomness stored alongside the other settings. Supply a seeded
///   generator when reproducible results are required.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdaBoostValidParams<P, R> {
    /// Upper bound on the number of estimators to train.
    pub n_estimators: usize,
    /// Shrinkage factor applied to each estimator's contribution.
    pub learning_rate: f64,
    /// Parameters of the base (weak) learner.
    pub model_params: P,
    /// Random number generator.
    pub rng: R,
}
35
36/// A helper struct for building [AdaBoost](crate::AdaBoost) hyperparameters.
37///
38/// This struct follows the builder pattern, allowing you to chain method calls to configure
39/// the AdaBoost algorithm before fitting.
40///
41/// ## Example
42///
43/// ```no_run
44/// use linfa_ensemble::AdaBoostParams;
45/// use linfa_trees::DecisionTree;
46///
47/// let params = AdaBoostParams::new(DecisionTree::<f64, usize>::params().max_depth(Some(1)))
48/// .n_estimators(100)
49/// .learning_rate(0.5);
50/// ```
51#[derive(Clone, Copy, Debug, PartialEq)]
52pub struct AdaBoostParams<P, R>(AdaBoostValidParams<P, R>);
53
54impl<P> AdaBoostParams<P, ThreadRng> {
55 /// Create a new AdaBoost parameter set with default values and a thread-local RNG.
56 ///
57 /// # Arguments
58 ///
59 /// * `model_params` - The parameters for the base learner (e.g., DecisionTreeParams)
60 ///
61 /// # Default Values
62 ///
63 /// * `n_estimators`: 50
64 /// * `learning_rate`: 1.0
65 pub fn new(model_params: P) -> AdaBoostParams<P, ThreadRng> {
66 Self::new_fixed_rng(model_params, rand::thread_rng())
67 }
68}
69
70impl<P, R: Rng + Clone> AdaBoostParams<P, R> {
71 /// Create a new AdaBoost parameter set with a fixed RNG for reproducibility.
72 ///
73 /// # Arguments
74 ///
75 /// * `model_params` - The parameters for the base learner
76 /// * `rng` - A seeded random number generator for reproducible results
77 ///
78 /// # Example
79 ///
80 /// ```no_run
81 /// use linfa_ensemble::AdaBoostParams;
82 /// use linfa_trees::DecisionTree;
83 /// use ndarray_rand::rand::SeedableRng;
84 /// use rand::rngs::SmallRng;
85 ///
86 /// let rng = SmallRng::seed_from_u64(42);
87 /// let params = AdaBoostParams::new_fixed_rng(
88 /// DecisionTree::<f64, usize>::params().max_depth(Some(1)),
89 /// rng
90 /// );
91 /// ```
92 pub fn new_fixed_rng(model_params: P, rng: R) -> AdaBoostParams<P, R> {
93 Self(AdaBoostValidParams {
94 n_estimators: 50,
95 learning_rate: 1.0,
96 model_params,
97 rng,
98 })
99 }
100
101 /// Set the maximum number of weak learners to train.
102 ///
103 /// # Arguments
104 ///
105 /// * `n_estimators` - Must be at least 1. Typical values: 50-500
106 ///
107 /// # Notes
108 ///
109 /// Higher values generally lead to better training performance but:
110 /// * Increase training time linearly
111 /// * May lead to overfitting
112 /// * Should be balanced with `learning_rate`
113 pub fn n_estimators(mut self, n_estimators: usize) -> Self {
114 self.0.n_estimators = n_estimators;
115 self
116 }
117
118 /// Set the learning rate (shrinkage parameter).
119 ///
120 /// # Arguments
121 ///
122 /// * `learning_rate` - Must be positive. Typical values: 0.01 to 2.0
123 ///
124 /// # Notes
125 ///
126 /// * Values < 1.0 provide regularization and often improve generalization
127 /// * Lower values require more estimators to achieve similar performance
128 /// * A common strategy is to use learning_rate=0.1 with n_estimators=500
129 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
130 self.0.learning_rate = learning_rate;
131 self
132 }
133}
134
135impl<P, R> ParamGuard for AdaBoostParams<P, R> {
136 type Checked = AdaBoostValidParams<P, R>;
137 type Error = Error;
138
139 fn check_ref(&self) -> Result<&Self::Checked> {
140 if self.0.n_estimators < 1 {
141 Err(Error::Parameters(format!(
142 "n_estimators must be at least 1, but was {}",
143 self.0.n_estimators
144 )))
145 } else if self.0.learning_rate <= 0.0 {
146 Err(Error::Parameters(format!(
147 "learning_rate must be positive, but was {}",
148 self.0.learning_rate
149 )))
150 } else if !self.0.learning_rate.is_finite() {
151 Err(Error::Parameters(
152 "learning_rate must be finite (not NaN or infinity)".to_string(),
153 ))
154 } else {
155 Ok(&self.0)
156 }
157 }
158
159 fn check(self) -> Result<Self::Checked> {
160 self.check_ref()?;
161 Ok(self.0)
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    /// Returns `true` when validation rejects the given learning rate
    /// (all other parameters are left at their defaults).
    fn rejects_learning_rate(learning_rate: f64) -> bool {
        let rng = SmallRng::seed_from_u64(42);
        AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(learning_rate)
            .check_ref()
            .is_err()
    }

    /// Returns `true` when validation rejects the given number of estimators
    /// (all other parameters are left at their defaults).
    fn rejects_n_estimators(n_estimators: usize) -> bool {
        let rng = SmallRng::seed_from_u64(42);
        AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(n_estimators)
            .check_ref()
            .is_err()
    }

    #[test]
    fn test_default_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng);
        assert_eq!(params.0.n_estimators, 50);
        assert_eq!(params.0.learning_rate, 1.0);
    }

    #[test]
    fn test_default_params_are_valid() {
        // The defaults set by the constructor must pass validation unchanged.
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng);
        assert!(params.check_ref().is_ok());
    }

    #[test]
    fn test_custom_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(100)
            .learning_rate(0.5);
        assert_eq!(params.0.n_estimators, 100);
        assert_eq!(params.0.learning_rate, 0.5);
    }

    #[test]
    fn test_invalid_n_estimators() {
        assert!(rejects_n_estimators(0));
    }

    #[test]
    fn test_min_n_estimators_is_valid() {
        // 1 is the smallest accepted value (boundary of the `at least 1` rule).
        assert!(!rejects_n_estimators(1));
    }

    #[test]
    fn test_invalid_learning_rate_negative() {
        assert!(rejects_learning_rate(-0.5));
    }

    #[test]
    fn test_invalid_learning_rate_zero() {
        assert!(rejects_learning_rate(0.0));
    }

    #[test]
    fn test_invalid_learning_rate_nan() {
        assert!(rejects_learning_rate(f64::NAN));
    }

    #[test]
    fn test_invalid_learning_rate_infinite() {
        // Positive infinity passes the `> 0` rule and must be caught by the finiteness rule;
        // negative infinity is rejected as non-positive.
        assert!(rejects_learning_rate(f64::INFINITY));
        assert!(rejects_learning_rate(f64::NEG_INFINITY));
    }

    #[test]
    fn test_valid_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(100)
            .learning_rate(0.5);
        assert!(params.check_ref().is_ok());
    }

    #[test]
    fn test_check_returns_validated_values() {
        // The consuming `check` must hand back exactly the values that were configured.
        let rng = SmallRng::seed_from_u64(42);
        let valid = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(10)
            .learning_rate(0.25)
            .check()
            .expect("parameters are inside the accepted ranges");
        assert_eq!(valid.n_estimators, 10);
        assert_eq!(valid.learning_rate, 0.25);
    }

    #[test]
    fn test_check_rejects_invalid_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(0);
        assert!(params.check().is_err());
    }
}