use alloc::vec::Vec;
use crate::feature::FeatureType;
use crate::tree::leaf_model::LeafModelType;
/// Configuration options for a tree, adjusted through the consuming builder methods
/// on `TreeConfig` (start from `TreeConfig::new()` or `TreeConfig::default()`).
///
/// The defaults listed on each field are the values produced by this file's `Default`
/// impl; every `Option` field defaults to `None`.
///
/// NOTE(review): the per-field meanings below are inferred from the field names — the
/// code that consumes these values is not in this file, so confirm against it.
#[derive(Debug, Clone)]
pub struct TreeConfig {
/// Maximum tree depth. Default: `6`.
pub max_depth: usize,
/// Number of bins used to discretize feature values (presumably histogram bins). Default: `64`.
pub n_bins: usize,
/// Regularization strength, presumably an L2 penalty on leaf values. Default: `1.0`.
pub lambda: f64,
/// Presumably the minimum gain / complexity penalty required for a split. Default: `0.0`.
pub gamma: f64,
/// Presumably the number of samples observed before a split is attempted. Default: `200`.
pub grace_period: usize,
/// Presumably the confidence parameter of a Hoeffding-style split test. Default: `1e-7`.
pub delta: f64,
/// Fraction of features to consider (presumably per split). The builder setter clamps
/// it into `[0.0, 1.0]`; assigning this public field directly bypasses that clamp.
/// Default: `1.0`.
pub feature_subsample_rate: f64,
/// Random seed (presumably used for feature subsampling). Default: `42`.
pub seed: u64,
/// Optional decay factor for leaf statistics. Default: `None`.
pub leaf_decay_alpha: Option<f64>,
/// Optional interval for re-evaluating existing splits. Default: `None`.
pub split_reeval_interval: Option<usize>,
/// Optional per-feature type annotations. Default: `None`.
pub feature_types: Option<Vec<FeatureType>>,
/// Optional gradient clipping threshold, presumably in standard deviations. Default: `None`.
pub gradient_clip_sigma: Option<f64>,
/// Optional per-feature monotonicity constraints (presumably `-1`, `0`, `+1`). Default: `None`.
pub monotone_constraints: Option<Vec<i8>>,
/// Optional cap on the magnitude of leaf outputs. Default: `None`.
pub max_leaf_output: Option<f64>,
/// Optional factor for an adaptive leaf-output bound. Default: `None`.
pub adaptive_leaf_bound: Option<f64>,
/// Optional factor controlling adaptive depth. Default: `None`.
pub adaptive_depth: Option<f64>,
/// Optional minimum hessian sum (presumably required for a leaf/child). Default: `None`.
pub min_hessian_sum: Option<f64>,
/// Model type used inside leaves. Default: `LeafModelType::default()`.
pub leaf_model_type: LeafModelType,
/// Optional range value `R`, presumably for a Hoeffding bound. Default: `None`.
pub hoeffding_r: Option<f64>,
}
impl Default for TreeConfig {
fn default() -> Self {
Self {
max_depth: 6,
n_bins: 64,
lambda: 1.0,
gamma: 0.0,
grace_period: 200,
delta: 1e-7,
feature_subsample_rate: 1.0,
seed: 42,
leaf_decay_alpha: None,
split_reeval_interval: None,
feature_types: None,
gradient_clip_sigma: None,
monotone_constraints: None,
max_leaf_output: None,
adaptive_leaf_bound: None,
adaptive_depth: None,
min_hessian_sum: None,
leaf_model_type: LeafModelType::default(),
hoeffding_r: None,
}
}
}
// Consuming builder-style setters for `TreeConfig`.
//
// Every setter takes `self` by value and returns the updated config, so calls chain
// from `TreeConfig::new()`. Each `Option` field has a plain setter that stores
// `Some(value)` and an `_opt` variant that stores the given `Option` verbatim (pass
// `None` to clear the field).
impl TreeConfig {
    /// Creates a config with default settings; equivalent to `TreeConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `max_depth` field.
    #[inline]
    #[must_use]
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the `n_bins` field.
    #[inline]
    #[must_use]
    pub fn n_bins(mut self, n_bins: usize) -> Self {
        self.n_bins = n_bins;
        self
    }

    /// Sets the `lambda` field.
    #[inline]
    #[must_use]
    pub fn lambda(mut self, lambda: f64) -> Self {
        self.lambda = lambda;
        self
    }

    /// Sets the `gamma` field.
    #[inline]
    #[must_use]
    pub fn gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    /// Sets the `grace_period` field.
    #[inline]
    #[must_use]
    pub fn grace_period(mut self, grace_period: usize) -> Self {
        self.grace_period = grace_period;
        self
    }

    /// Sets the `delta` field.
    #[inline]
    #[must_use]
    pub fn delta(mut self, delta: f64) -> Self {
        self.delta = delta;
        self
    }

    /// Sets `feature_subsample_rate`, clamped into `[0.0, 1.0]`.
    ///
    /// A NaN `rate` is replaced by `1.0` (the default rate). `f64::clamp` passes NaN
    /// through unchanged, which would otherwise leave the field outside the range this
    /// setter guarantees.
    #[inline]
    #[must_use]
    pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
        self.feature_subsample_rate = if rate.is_nan() {
            1.0
        } else {
            rate.clamp(0.0, 1.0)
        };
        self
    }

    /// Sets the `seed` field.
    #[inline]
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Sets `leaf_decay_alpha` to `Some(alpha)`.
    #[inline]
    #[must_use]
    pub fn leaf_decay_alpha(mut self, alpha: f64) -> Self {
        self.leaf_decay_alpha = Some(alpha);
        self
    }

    /// Sets `leaf_decay_alpha` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn leaf_decay_alpha_opt(mut self, alpha: Option<f64>) -> Self {
        self.leaf_decay_alpha = alpha;
        self
    }

    /// Sets `split_reeval_interval` to `Some(interval)`.
    #[inline]
    #[must_use]
    pub fn split_reeval_interval(mut self, interval: usize) -> Self {
        self.split_reeval_interval = Some(interval);
        self
    }

    /// Sets `split_reeval_interval` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn split_reeval_interval_opt(mut self, interval: Option<usize>) -> Self {
        self.split_reeval_interval = interval;
        self
    }

    /// Sets `feature_types` to `Some(types)`.
    #[inline]
    #[must_use]
    pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
        self.feature_types = Some(types);
        self
    }

    /// Sets `feature_types` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn feature_types_opt(mut self, types: Option<Vec<FeatureType>>) -> Self {
        self.feature_types = types;
        self
    }

    /// Sets `gradient_clip_sigma` to `Some(sigma)`.
    #[inline]
    #[must_use]
    pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
        self.gradient_clip_sigma = Some(sigma);
        self
    }

    /// Sets `gradient_clip_sigma` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn gradient_clip_sigma_opt(mut self, sigma: Option<f64>) -> Self {
        self.gradient_clip_sigma = sigma;
        self
    }

    /// Sets `monotone_constraints` to `Some(constraints)`.
    #[inline]
    #[must_use]
    pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
        self.monotone_constraints = Some(constraints);
        self
    }

    /// Sets `monotone_constraints` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn monotone_constraints_opt(mut self, constraints: Option<Vec<i8>>) -> Self {
        self.monotone_constraints = constraints;
        self
    }

    /// Sets `max_leaf_output` to `Some(max)`.
    #[inline]
    #[must_use]
    pub fn max_leaf_output(mut self, max: f64) -> Self {
        self.max_leaf_output = Some(max);
        self
    }

    /// Sets `max_leaf_output` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn max_leaf_output_opt(mut self, max: Option<f64>) -> Self {
        self.max_leaf_output = max;
        self
    }

    /// Sets `adaptive_leaf_bound` to `Some(k)`.
    ///
    /// Provided for parity with the other optional fields, which all have both a
    /// plain setter and an `_opt` setter.
    #[inline]
    #[must_use]
    pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
        self.adaptive_leaf_bound = Some(k);
        self
    }

    /// Sets `adaptive_leaf_bound` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn adaptive_leaf_bound_opt(mut self, k: Option<f64>) -> Self {
        self.adaptive_leaf_bound = k;
        self
    }

    /// Sets `adaptive_depth` to `Some(factor)`.
    #[inline]
    #[must_use]
    pub fn adaptive_depth(mut self, factor: f64) -> Self {
        self.adaptive_depth = Some(factor);
        self
    }

    /// Sets `adaptive_depth` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn adaptive_depth_opt(mut self, factor: Option<f64>) -> Self {
        self.adaptive_depth = factor;
        self
    }

    /// Sets `min_hessian_sum` to `Some(min_h)`.
    #[inline]
    #[must_use]
    pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
        self.min_hessian_sum = Some(min_h);
        self
    }

    /// Sets `min_hessian_sum` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn min_hessian_sum_opt(mut self, min_h: Option<f64>) -> Self {
        self.min_hessian_sum = min_h;
        self
    }

    /// Sets the `leaf_model_type` field.
    #[inline]
    #[must_use]
    pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
        self.leaf_model_type = lmt;
        self
    }

    /// Sets `hoeffding_r` to `Some(r)`.
    #[inline]
    #[must_use]
    pub fn hoeffding_r(mut self, r: f64) -> Self {
        self.hoeffding_r = Some(r);
        self
    }

    /// Sets `hoeffding_r` to the given option (`None` clears it).
    #[inline]
    #[must_use]
    pub fn hoeffding_r_opt(mut self, r: Option<f64>) -> Self {
        self.hoeffding_r = r;
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar defaults match the values documented for the `Default` impl.
    #[test]
    fn default_values() {
        let cfg = TreeConfig::default();
        assert_eq!(cfg.max_depth, 6);
        assert_eq!(cfg.n_bins, 64);
        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 200);
        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 1.0).abs() < f64::EPSILON);
        assert_eq!(cfg.seed, 42);
    }

    /// Every optional setting starts out unset.
    #[test]
    fn optional_fields_default_none() {
        let cfg = TreeConfig::default();
        assert!(cfg.leaf_decay_alpha.is_none());
        assert!(cfg.split_reeval_interval.is_none());
        assert!(cfg.feature_types.is_none());
        assert!(cfg.gradient_clip_sigma.is_none());
        assert!(cfg.monotone_constraints.is_none());
        assert!(cfg.max_leaf_output.is_none());
        assert!(cfg.adaptive_leaf_bound.is_none());
        assert!(cfg.adaptive_depth.is_none());
        assert!(cfg.min_hessian_sum.is_none());
        assert!(cfg.hoeffding_r.is_none());
    }

    /// `new()` must agree with `default()` on every field, not just the scalars.
    #[test]
    fn new_equals_default() {
        let a = TreeConfig::new();
        let b = TreeConfig::default();
        assert_eq!(a.max_depth, b.max_depth);
        assert_eq!(a.n_bins, b.n_bins);
        assert_eq!(a.grace_period, b.grace_period);
        assert_eq!(a.seed, b.seed);
        // The derived `Debug` output covers every field, including the ones whose
        // types may not implement `PartialEq`.
        assert_eq!(alloc::format!("{:?}", a), alloc::format!("{:?}", b));
    }

    #[test]
    fn builder_chain() {
        let cfg = TreeConfig::new()
            .max_depth(10)
            .n_bins(128)
            .lambda(0.5)
            .gamma(0.1)
            .grace_period(500)
            .delta(1e-3)
            .feature_subsample_rate(0.8)
            .seed(7);
        assert_eq!(cfg.max_depth, 10);
        assert_eq!(cfg.n_bins, 128);
        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 500);
        assert!((cfg.delta - 1e-3).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.8).abs() < f64::EPSILON);
        assert_eq!(cfg.seed, 7);
    }

    #[test]
    fn feature_subsample_rate_clamped() {
        let cfg = TreeConfig::new().feature_subsample_rate(1.5);
        assert!((cfg.feature_subsample_rate - 1.0).abs() < f64::EPSILON);
        let cfg = TreeConfig::new().feature_subsample_rate(-0.3);
        assert!((cfg.feature_subsample_rate - 0.0).abs() < f64::EPSILON);
    }

    /// Plain setters store `Some(..)`; `_opt(None)` clears the field again.
    #[test]
    fn option_setters_set_and_clear() {
        let cfg = TreeConfig::new()
            .leaf_decay_alpha(0.9)
            .split_reeval_interval(100)
            .feature_types(Vec::new())
            .gradient_clip_sigma(3.0)
            .monotone_constraints(alloc::vec![1, 0, -1])
            .max_leaf_output(2.0)
            .adaptive_leaf_bound_opt(Some(4.0))
            .adaptive_depth(0.5)
            .min_hessian_sum(1.0)
            .hoeffding_r(1.0);
        assert_eq!(cfg.leaf_decay_alpha, Some(0.9));
        assert_eq!(cfg.split_reeval_interval, Some(100));
        assert_eq!(cfg.feature_types.as_ref().map(Vec::len), Some(0));
        assert_eq!(cfg.gradient_clip_sigma, Some(3.0));
        assert_eq!(cfg.monotone_constraints, Some(alloc::vec![1, 0, -1]));
        assert_eq!(cfg.max_leaf_output, Some(2.0));
        assert_eq!(cfg.adaptive_leaf_bound, Some(4.0));
        assert_eq!(cfg.adaptive_depth, Some(0.5));
        assert_eq!(cfg.min_hessian_sum, Some(1.0));
        assert_eq!(cfg.hoeffding_r, Some(1.0));

        let cleared = cfg
            .leaf_decay_alpha_opt(None)
            .split_reeval_interval_opt(None)
            .feature_types_opt(None)
            .gradient_clip_sigma_opt(None)
            .monotone_constraints_opt(None)
            .max_leaf_output_opt(None)
            .adaptive_leaf_bound_opt(None)
            .adaptive_depth_opt(None)
            .min_hessian_sum_opt(None)
            .hoeffding_r_opt(None);
        assert!(cleared.leaf_decay_alpha.is_none());
        assert!(cleared.split_reeval_interval.is_none());
        assert!(cleared.feature_types.is_none());
        assert!(cleared.gradient_clip_sigma.is_none());
        assert!(cleared.monotone_constraints.is_none());
        assert!(cleared.max_leaf_output.is_none());
        assert!(cleared.adaptive_leaf_bound.is_none());
        assert!(cleared.adaptive_depth.is_none());
        assert!(cleared.min_hessian_sum.is_none());
        assert!(cleared.hoeffding_r.is_none());
    }

    #[test]
    fn max_leaf_output_builder() {
        let cfg = TreeConfig::new().max_leaf_output(1.5);
        assert_eq!(cfg.max_leaf_output, Some(1.5));
    }

    #[test]
    fn min_hessian_sum_builder() {
        let cfg = TreeConfig::new().min_hessian_sum(10.0);
        assert_eq!(cfg.min_hessian_sum, Some(10.0));
    }

    #[test]
    fn max_leaf_output_default_none() {
        let cfg = TreeConfig::default();
        assert!(cfg.max_leaf_output.is_none());
        assert!(cfg.min_hessian_sum.is_none());
    }
}