Skip to main content

linfa_ensemble/
adaboost_hyperparams.rs

1use linfa::{
2    error::{Error, Result},
3    ParamGuard,
4};
5use rand::rngs::ThreadRng;
6use rand::Rng;
7
/// The validated hyperparameters for the [AdaBoost](crate::AdaBoost) algorithm.
///
/// Values of this type are normally produced by checking an [`AdaBoostParams`] builder through
/// [`ParamGuard`], which rejects invalid settings. The fields are public, so a value built by
/// hand bypasses those checks.
///
/// ## Parameters
///
/// * `n_estimators`: The maximum number of weak learners to train sequentially.
///   More estimators generally improve performance but increase training time and risk overfitting.
///   Typical values range from 50 to 500. Must be at least 1. Default: 50.
///
/// * `learning_rate`: Shrinks the contribution of each classifier. There is a trade-off between
///   `learning_rate` and `n_estimators`. Lower values require more estimators to achieve the same
///   performance but may generalize better. Must be positive and finite (not NaN or infinity).
///   Default: 1.0.
///
/// * `model_params`: The parameters for the base learner (weak classifier). Typically, shallow
///   decision trees (stumps with max_depth=1 or max_depth=2) are used as weak learners.
///
/// * `rng`: Random number generator handed to the fitting procedure. Supply a seeded generator
///   for reproducible results.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct AdaBoostValidParams<P, R> {
    /// The maximum number of estimators to train (at least 1 once validated)
    pub n_estimators: usize,
    /// The learning rate (shrinkage parameter); positive and finite once validated
    pub learning_rate: f64,
    /// The base learner parameters
    pub model_params: P,
    /// Random number generator
    pub rng: R,
}
35
36/// A helper struct for building [AdaBoost](crate::AdaBoost) hyperparameters.
37///
38/// This struct follows the builder pattern, allowing you to chain method calls to configure
39/// the AdaBoost algorithm before fitting.
40///
41/// ## Example
42///
43/// ```no_run
44/// use linfa_ensemble::AdaBoostParams;
45/// use linfa_trees::DecisionTree;
46///
47/// let params = AdaBoostParams::new(DecisionTree::<f64, usize>::params().max_depth(Some(1)))
48///     .n_estimators(100)
49///     .learning_rate(0.5);
50/// ```
51#[derive(Clone, Copy, Debug, PartialEq)]
52pub struct AdaBoostParams<P, R>(AdaBoostValidParams<P, R>);
53
54impl<P> AdaBoostParams<P, ThreadRng> {
55    /// Create a new AdaBoost parameter set with default values and a thread-local RNG.
56    ///
57    /// # Arguments
58    ///
59    /// * `model_params` - The parameters for the base learner (e.g., DecisionTreeParams)
60    ///
61    /// # Default Values
62    ///
63    /// * `n_estimators`: 50
64    /// * `learning_rate`: 1.0
65    pub fn new(model_params: P) -> AdaBoostParams<P, ThreadRng> {
66        Self::new_fixed_rng(model_params, rand::thread_rng())
67    }
68}
69
70impl<P, R: Rng + Clone> AdaBoostParams<P, R> {
71    /// Create a new AdaBoost parameter set with a fixed RNG for reproducibility.
72    ///
73    /// # Arguments
74    ///
75    /// * `model_params` - The parameters for the base learner
76    /// * `rng` - A seeded random number generator for reproducible results
77    ///
78    /// # Example
79    ///
80    /// ```no_run
81    /// use linfa_ensemble::AdaBoostParams;
82    /// use linfa_trees::DecisionTree;
83    /// use ndarray_rand::rand::SeedableRng;
84    /// use rand::rngs::SmallRng;
85    ///
86    /// let rng = SmallRng::seed_from_u64(42);
87    /// let params = AdaBoostParams::new_fixed_rng(
88    ///     DecisionTree::<f64, usize>::params().max_depth(Some(1)),
89    ///     rng
90    /// );
91    /// ```
92    pub fn new_fixed_rng(model_params: P, rng: R) -> AdaBoostParams<P, R> {
93        Self(AdaBoostValidParams {
94            n_estimators: 50,
95            learning_rate: 1.0,
96            model_params,
97            rng,
98        })
99    }
100
101    /// Set the maximum number of weak learners to train.
102    ///
103    /// # Arguments
104    ///
105    /// * `n_estimators` - Must be at least 1. Typical values: 50-500
106    ///
107    /// # Notes
108    ///
109    /// Higher values generally lead to better training performance but:
110    /// * Increase training time linearly
111    /// * May lead to overfitting
112    /// * Should be balanced with `learning_rate`
113    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
114        self.0.n_estimators = n_estimators;
115        self
116    }
117
118    /// Set the learning rate (shrinkage parameter).
119    ///
120    /// # Arguments
121    ///
122    /// * `learning_rate` - Must be positive. Typical values: 0.01 to 2.0
123    ///
124    /// # Notes
125    ///
126    /// * Values < 1.0 provide regularization and often improve generalization
127    /// * Lower values require more estimators to achieve similar performance
128    /// * A common strategy is to use learning_rate=0.1 with n_estimators=500
129    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
130        self.0.learning_rate = learning_rate;
131        self
132    }
133}
134
135impl<P, R> ParamGuard for AdaBoostParams<P, R> {
136    type Checked = AdaBoostValidParams<P, R>;
137    type Error = Error;
138
139    fn check_ref(&self) -> Result<&Self::Checked> {
140        if self.0.n_estimators < 1 {
141            Err(Error::Parameters(format!(
142                "n_estimators must be at least 1, but was {}",
143                self.0.n_estimators
144            )))
145        } else if self.0.learning_rate <= 0.0 {
146            Err(Error::Parameters(format!(
147                "learning_rate must be positive, but was {}",
148                self.0.learning_rate
149            )))
150        } else if !self.0.learning_rate.is_finite() {
151            Err(Error::Parameters(
152                "learning_rate must be finite (not NaN or infinity)".to_string(),
153            ))
154        } else {
155            Ok(&self.0)
156        }
157    }
158
159    fn check(self) -> Result<Self::Checked> {
160        self.check_ref()?;
161        Ok(self.0)
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    #[test]
    fn test_default_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng);
        assert_eq!(params.0.n_estimators, 50);
        assert_eq!(params.0.learning_rate, 1.0);
    }

    #[test]
    fn test_new_uses_default_params() {
        // The thread-local-RNG constructor must apply the same defaults.
        let params = AdaBoostParams::new(DecisionTree::<f64, usize>::params());
        assert_eq!(params.0.n_estimators, 50);
        assert_eq!(params.0.learning_rate, 1.0);
    }

    #[test]
    fn test_custom_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(100)
            .learning_rate(0.5);
        assert_eq!(params.0.n_estimators, 100);
        assert_eq!(params.0.learning_rate, 0.5);
    }

    #[test]
    fn test_invalid_n_estimators() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(0);
        assert!(params.check_ref().is_err());
    }

    #[test]
    fn test_invalid_learning_rate_negative() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(-0.5);
        assert!(params.check_ref().is_err());
    }

    #[test]
    fn test_invalid_learning_rate_zero() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(0.0);
        assert!(params.check_ref().is_err());
    }

    #[test]
    fn test_invalid_learning_rate_nan() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(f64::NAN);
        assert!(params.check_ref().is_err());
    }

    #[test]
    fn test_invalid_learning_rate_infinite() {
        // +inf passes the positivity check and is caught only by the finiteness check.
        // -inf is caught by the positivity check.
        for lr in [f64::INFINITY, f64::NEG_INFINITY] {
            let rng = SmallRng::seed_from_u64(42);
            let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
                .learning_rate(lr);
            assert!(
                params.check_ref().is_err(),
                "learning_rate {} should be rejected",
                lr
            );
        }
    }

    #[test]
    fn test_valid_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(100)
            .learning_rate(0.5);
        assert!(params.check_ref().is_ok());
    }

    #[test]
    fn test_check_returns_inner_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(10)
            .learning_rate(0.25);
        let checked = params.check().expect("valid params should pass validation");
        assert_eq!(checked.n_estimators, 10);
        assert_eq!(checked.learning_rate, 0.25);
    }

    #[test]
    fn test_check_rejects_invalid() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(0);
        assert!(params.check().is_err());
    }
}