use linfa::{
error::{Error, Result},
ParamGuard,
};
use rand::rngs::ThreadRng;
use rand::Rng;
/// Validated AdaBoost hyper-parameters.
///
/// This is the `Checked` type produced by validating [`AdaBoostParams`] through
/// [`ParamGuard`]. It is only handed out after `n_estimators >= 1` and a positive,
/// finite `learning_rate` have been confirmed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdaBoostValidParams<P, R> {
    /// Number of estimators; validation requires at least 1.
    pub n_estimators: usize,
    /// Learning rate; validation requires it to be strictly positive and finite.
    pub learning_rate: f64,
    /// Parameters of the wrapped model (presumably the weak learner — confirm at the fit site).
    pub model_params: P,
    /// Random number generator carried along with the parameters.
    pub rng: R,
}
/// Unchecked AdaBoost hyper-parameters, configured with builder methods.
///
/// Validate with [`ParamGuard::check`] or [`ParamGuard::check_ref`] to obtain the inner
/// [`AdaBoostValidParams`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AdaBoostParams<P, R>(AdaBoostValidParams<P, R>);
impl<P> AdaBoostParams<P, ThreadRng> {
    /// Creates AdaBoost parameters around `model_params` using the thread-local RNG.
    ///
    /// Delegates to [`AdaBoostParams::new_fixed_rng`] with `rand::thread_rng()`, so the
    /// same default `n_estimators` and `learning_rate` apply.
    pub fn new(model_params: P) -> AdaBoostParams<P, ThreadRng> {
        let thread_local_rng = rand::thread_rng();
        Self::new_fixed_rng(model_params, thread_local_rng)
    }
}
impl<P, R: Rng + Clone> AdaBoostParams<P, R> {
    /// Creates AdaBoost parameters with an explicitly supplied random number generator.
    ///
    /// Starts from `n_estimators = 50` and `learning_rate = 1.0`. Both can be overridden
    /// with the setters in this impl. Values are not checked here; validation happens
    /// through [`ParamGuard`].
    pub fn new_fixed_rng(model_params: P, rng: R) -> AdaBoostParams<P, R> {
        let defaults = AdaBoostValidParams {
            model_params,
            rng,
            n_estimators: 50,
            learning_rate: 1.0,
        };
        Self(defaults)
    }

    /// Replaces the number of estimators.
    ///
    /// No validation happens here; a value of `0` is rejected later by
    /// [`ParamGuard::check_ref`] / [`ParamGuard::check`].
    pub fn n_estimators(self, n_estimators: usize) -> Self {
        Self(AdaBoostValidParams {
            n_estimators,
            ..self.0
        })
    }

    /// Replaces the learning rate.
    ///
    /// No validation happens here; non-positive, NaN or infinite values are rejected later
    /// by [`ParamGuard::check_ref`] / [`ParamGuard::check`].
    pub fn learning_rate(self, learning_rate: f64) -> Self {
        Self(AdaBoostValidParams {
            learning_rate,
            ..self.0
        })
    }
}
impl<P, R> ParamGuard for AdaBoostParams<P, R> {
type Checked = AdaBoostValidParams<P, R>;
type Error = Error;
fn check_ref(&self) -> Result<&Self::Checked> {
if self.0.n_estimators < 1 {
Err(Error::Parameters(format!(
"n_estimators must be at least 1, but was {}",
self.0.n_estimators
)))
} else if self.0.learning_rate <= 0.0 {
Err(Error::Parameters(format!(
"learning_rate must be positive, but was {}",
self.0.learning_rate
)))
} else if !self.0.learning_rate.is_finite() {
Err(Error::Parameters(
"learning_rate must be finite (not NaN or infinity)".to_string(),
))
} else {
Ok(&self.0)
}
}
fn check(self) -> Result<Self::Checked> {
self.check_ref()?;
Ok(self.0)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use linfa_trees::DecisionTree;
    use ndarray_rand::rand::SeedableRng;
    use rand::rngs::SmallRng;

    #[test]
    fn test_default_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng);
        assert_eq!(params.0.n_estimators, 50);
        assert_eq!(params.0.learning_rate, 1.0);
    }

    #[test]
    fn test_custom_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(100)
            .learning_rate(0.5);
        assert_eq!(params.0.n_estimators, 100);
        assert_eq!(params.0.learning_rate, 0.5);
    }

    #[test]
    fn test_invalid_n_estimators() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(0);
        assert!(params.check_ref().is_err());
    }

    #[test]
    fn test_invalid_learning_rate_negative() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(-0.5);
        assert!(params.check_ref().is_err());
    }

    #[test]
    fn test_invalid_learning_rate_zero() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(0.0);
        assert!(params.check_ref().is_err());
    }

    #[test]
    fn test_invalid_learning_rate_nan() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(f64::NAN);
        assert!(params.check_ref().is_err());
    }

    /// +infinity passes the `<= 0.0` guard and must be caught by the finiteness check.
    #[test]
    fn test_invalid_learning_rate_infinite() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(f64::INFINITY);
        assert!(params.check_ref().is_err());
    }

    /// -infinity is rejected (by the positivity guard).
    #[test]
    fn test_invalid_learning_rate_neg_infinite() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .learning_rate(f64::NEG_INFINITY);
        assert!(params.check_ref().is_err());
    }

    #[test]
    fn test_valid_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(100)
            .learning_rate(0.5);
        assert!(params.check_ref().is_ok());
    }

    /// The consuming `check` returns the inner validated parameters unchanged.
    #[test]
    fn test_check_returns_valid_params() {
        let rng = SmallRng::seed_from_u64(42);
        let valid = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(10)
            .learning_rate(0.25)
            .check()
            .expect("parameters are valid");
        assert_eq!(valid.n_estimators, 10);
        assert_eq!(valid.learning_rate, 0.25);
    }

    /// The consuming `check` rejects the same inputs as `check_ref`.
    #[test]
    fn test_check_rejects_invalid_params() {
        let rng = SmallRng::seed_from_u64(42);
        let params = AdaBoostParams::new_fixed_rng(DecisionTree::<f64, usize>::params(), rng)
            .n_estimators(0);
        assert!(params.check().is_err());
    }
}