use crate::drift::adwin::Adwin;
use crate::drift::ddm::Ddm;
use crate::drift::pht::PageHinkleyTest;
use crate::drift::DriftDetector;
use crate::ensemble::variants::SGBTVariant;
use crate::error::Result;
use crate::tree::leaf_model::LeafModelType;
mod display;
mod tree_config_helper;
mod validation;
pub(crate) use tree_config_helper::build_tree_config;
pub use irithyll_core::feature::FeatureType;
/// Selects which scale strategy the ensemble uses; stored in `SGBTConfig::scale_mode`.
///
/// `Empirical` is the default (`#[default]`). The exact behaviour of each variant lives in the
/// code that consumes this setting, which is not visible here.
/// NOTE(review): presumably `Empirical` is tied to `SGBTConfig::empirical_sigma_alpha` and
/// `TreeChain` derives the scale from trees — confirm against the ensemble implementation.
///
/// Variant order matters: it fixes the serde variant indices, so do not reorder.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum ScaleMode {
/// Default variant.
#[default]
Empirical,
/// Alternative, non-default strategy.
TreeChain,
}
/// Serializable description of which drift detector to use, together with its parameters.
///
/// [`DriftDetectorType::create`] turns a value into a boxed [`DriftDetector`]. The default is
/// `PageHinkley { delta: 0.005, lambda: 50.0 }`. Parameter ranges are checked when an
/// `SGBTConfig` is built through `SGBTConfigBuilder::build`.
///
/// Variant order is observable through serde, so do not reorder.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum DriftDetectorType {
/// Page–Hinkley test, created via `PageHinkleyTest::with_params(delta, lambda)`.
/// The tests in this file show `delta: 0.0` is rejected by the config builder.
PageHinkley {
delta: f64,
lambda: f64,
},
/// ADWIN, created via `Adwin::with_delta(delta)`.
/// The tests in this file show `delta: 1.1` is rejected by the config builder.
Adwin {
delta: f64,
},
/// DDM, created via `Ddm::with_params(warning_level, drift_level, min_instances)`.
/// The tests in this file show `warning_level > drift_level` is rejected by the config builder.
Ddm {
warning_level: f64,
drift_level: f64,
min_instances: u64,
},
}
impl Default for DriftDetectorType {
fn default() -> Self {
DriftDetectorType::PageHinkley {
delta: 0.005,
lambda: 50.0,
}
}
}
impl DriftDetectorType {
    /// Builds a new, boxed detector from the parameters held by `self`.
    ///
    /// Each variant maps one-to-one onto its detector's constructor:
    /// `PageHinkleyTest::with_params`, `Adwin::with_delta`, or `Ddm::with_params`.
    pub fn create(&self) -> Box<dyn DriftDetector> {
        // Every field is `Copy`, so matching on `*self` copies the parameters out directly.
        match *self {
            Self::PageHinkley { delta, lambda } => {
                Box::new(PageHinkleyTest::with_params(delta, lambda))
            }
            Self::Adwin { delta } => Box::new(Adwin::with_delta(delta)),
            Self::Ddm {
                warning_level,
                drift_level,
                min_instances,
            } => Box::new(Ddm::with_params(warning_level, drift_level, min_instances)),
        }
    }
}
/// Complete set of tuning knobs for an SGBT ensemble.
///
/// Obtain one from [`SGBTConfig::builder`] (whose `build` runs validation) or start from
/// [`SGBTConfig::default`]. Fields annotated with `#[serde(default ...)]` may be absent from
/// serialized input and then take the value noted in their doc; the other fields have no serde
/// default.
///
/// NOTE(review): semantics marked "presumably" below are inferred from field names only and
/// should be confirmed against the training code, which is not visible here. Field order is
/// observable through serde, so do not reorder.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct SGBTConfig {
/// Number of boosting steps. Default `100`; the builder rejects `0`.
pub n_steps: usize,
/// Learning rate. Default `0.0125`; tests show `0.0` and `1.1` are rejected.
pub learning_rate: f64,
/// Feature subsampling fraction (presumably per step/tree — confirm). Default `0.75`;
/// tests show `0.0` and `1.1` are rejected.
pub feature_subsample_rate: f64,
/// Maximum tree depth. Default `6`; the builder rejects `0`.
pub max_depth: usize,
/// Number of bins (presumably histogram bins per feature — confirm). Default `64`;
/// tests show `1` is rejected and `2` accepted.
pub n_bins: usize,
/// Regularization term (presumably L2 on leaf values — confirm). Default `1.0`;
/// negative values are rejected.
pub lambda: f64,
/// Split penalty (presumably minimum gain to split — confirm). Default `0.0`;
/// negative values are rejected.
pub gamma: f64,
/// Grace period (presumably samples between split attempts — confirm). Default `200`;
/// the builder rejects `0`.
pub grace_period: usize,
/// Split confidence (presumably the Hoeffding-bound delta — confirm). Default `1e-7`;
/// tests show `0.0` and `1.0` are rejected.
pub delta: f64,
/// Drift detector specification. Default `DriftDetectorType::default()` (Page–Hinkley).
pub drift_detector: DriftDetectorType,
/// Ensemble variant. Default `SGBTVariant::default()`; the tests expect `Standard`.
pub variant: SGBTVariant,
/// Seed (presumably for the randomized parts such as subsampling — confirm).
/// Default `0xDEAD_BEEF_CAFE_4242`.
pub seed: u64,
/// Default `50`; the builder rejects `0`. Presumably the number of targets observed before
/// the base prediction is initialized — confirm.
pub initial_target_count: usize,
/// Optional leaf half-life (presumably decay of leaf statistics, in samples — confirm).
/// `None` by default.
#[serde(default)]
pub leaf_half_life: Option<usize>,
/// Optional per-tree sample cap (presumably triggers tree replacement — confirm).
/// `None` by default.
#[serde(default)]
pub max_tree_samples: Option<u64>,
/// Optional split re-evaluation interval. `None` by default.
#[serde(default)]
pub split_reeval_interval: Option<usize>,
/// Optional feature names. Tests show duplicates are rejected, as is a length mismatch
/// with `feature_types`.
#[serde(default)]
pub feature_names: Option<Vec<String>>,
/// Optional per-feature types. `None` by default.
#[serde(default)]
pub feature_types: Option<Vec<FeatureType>>,
/// Optional gradient clipping threshold (presumably in standard deviations — confirm).
/// `None` by default.
#[serde(default)]
pub gradient_clip_sigma: Option<f64>,
/// Optional per-feature monotonicity constraints. Tests show `2` is rejected
/// (presumably only -1/0/+1 are valid — confirm).
#[serde(default)]
pub monotone_constraints: Option<Vec<i8>>,
/// Optional smoothing factor for quality-based pruning. Tests show `0.0` and `1.0`
/// are rejected.
#[serde(default)]
pub quality_prune_alpha: Option<f64>,
/// Quality-pruning threshold. Default `1e-6`, also the serde fallback
/// (`default_quality_prune_threshold`).
#[serde(default = "default_quality_prune_threshold")]
pub quality_prune_threshold: f64,
/// Quality-pruning patience. Default `500`, also the serde fallback
/// (`default_quality_prune_patience`).
#[serde(default = "default_quality_prune_patience")]
pub quality_prune_patience: u64,
/// Optional error-weighting factor. Tests show `0.0` and `1.0` are rejected.
#[serde(default)]
pub error_weight_alpha: Option<f64>,
/// Enables uncertainty-modulated learning rate. `false` by default.
#[serde(default)]
pub uncertainty_modulated_lr: bool,
/// Scale strategy. `ScaleMode::Empirical` by default.
#[serde(default)]
pub scale_mode: ScaleMode,
/// Default `0.01`, also the serde fallback (`default_empirical_sigma_alpha`).
/// Tests show `1.1` is rejected.
#[serde(default = "default_empirical_sigma_alpha")]
pub empirical_sigma_alpha: f64,
/// Optional cap on leaf outputs. `None` by default.
#[serde(default)]
pub max_leaf_output: Option<f64>,
/// Optional adaptive leaf bound (the builder names its argument `k`). `None` by default.
#[serde(default)]
pub adaptive_leaf_bound: Option<f64>,
/// Optional adaptive-depth factor. `None` by default.
#[serde(default)]
pub adaptive_depth: Option<f64>,
/// Optional minimum hessian sum. `None` by default.
#[serde(default)]
pub min_hessian_sum: Option<f64>,
/// Optional Huber parameter `k`. `None` by default.
#[serde(default)]
pub huber_k: Option<f64>,
/// Optional shadow warmup length. `None` by default.
#[serde(default)]
pub shadow_warmup: Option<usize>,
/// Leaf model type. `LeafModelType::default()` by default.
#[serde(default)]
pub leaf_model_type: LeafModelType,
/// Packed refresh interval. `0` by default (presumably meaning disabled — confirm).
#[serde(default)]
pub packed_refresh_interval: u64,
/// Optional `(base_mts, k)` pair, named after the builder's arguments. `None` by default.
#[serde(default)]
pub adaptive_mts: Option<(u64, f64)>,
/// Floor fraction for adaptive MTS. `0.0` by default; tests show `1.1` is rejected.
#[serde(default)]
pub adaptive_mts_floor: f64,
/// Optional proactive pruning interval. `None` by default.
#[serde(default)]
pub proactive_prune_interval: Option<u64>,
/// Enables accuracy-based pruning. `false` by default.
#[serde(default)]
pub accuracy_based_pruning: bool,
/// Optional pruning half-life. `None` by default.
#[serde(default)]
pub prune_half_life: Option<usize>,
/// Optional `R` value (presumably the range term of the Hoeffding bound — confirm).
/// `None` by default.
#[serde(default)]
pub hoeffding_r: Option<f64>,
}
/// Serde fallback for `SGBTConfig::empirical_sigma_alpha` when the field is missing.
fn default_empirical_sigma_alpha() -> f64 {
    1e-2
}

/// Serde fallback for `SGBTConfig::quality_prune_threshold` when the field is missing.
fn default_quality_prune_threshold() -> f64 {
    0.000_001
}

/// Serde fallback for `SGBTConfig::quality_prune_patience` when the field is missing.
fn default_quality_prune_patience() -> u64 {
    500_u64
}
impl Default for SGBTConfig {
fn default() -> Self {
Self {
n_steps: 100,
learning_rate: 0.0125,
feature_subsample_rate: 0.75,
max_depth: 6,
n_bins: 64,
lambda: 1.0,
gamma: 0.0,
grace_period: 200,
delta: 1e-7,
drift_detector: DriftDetectorType::default(),
variant: SGBTVariant::default(),
seed: 0xDEAD_BEEF_CAFE_4242,
initial_target_count: 50,
leaf_half_life: None,
max_tree_samples: None,
split_reeval_interval: None,
feature_names: None,
feature_types: None,
gradient_clip_sigma: None,
monotone_constraints: None,
quality_prune_alpha: None,
quality_prune_threshold: 1e-6,
quality_prune_patience: 500,
error_weight_alpha: None,
uncertainty_modulated_lr: false,
scale_mode: ScaleMode::default(),
empirical_sigma_alpha: 0.01,
max_leaf_output: None,
adaptive_leaf_bound: None,
adaptive_depth: None,
min_hessian_sum: None,
huber_k: None,
shadow_warmup: None,
leaf_model_type: LeafModelType::default(),
packed_refresh_interval: 0,
adaptive_mts: None,
adaptive_mts_floor: 0.0,
proactive_prune_interval: None,
accuracy_based_pruning: false,
prune_half_life: None,
hoeffding_r: None,
}
}
}
impl SGBTConfig {
pub fn builder() -> SGBTConfigBuilder {
SGBTConfigBuilder::default()
}
}
/// Fluent builder for [`SGBTConfig`], normally obtained from [`SGBTConfig::builder`].
///
/// The derived `Default` starts from `SGBTConfig::default()`. Setters consume and return the
/// builder, and `build` validates the accumulated values.
///
/// `#[must_use]` is applied because each setter consumes `self`: dropping a returned
/// builder silently discards every setting made so far.
#[must_use = "a builder does nothing until `.build()` is called"]
#[derive(Debug, Clone, Default)]
pub struct SGBTConfigBuilder {
    config: SGBTConfig,
}
impl SGBTConfigBuilder {
    /// Applies `edit` to the configuration under construction and hands the builder back,
    /// so every public setter reduces to a single field assignment.
    fn edit_config(mut self, edit: impl FnOnce(&mut SGBTConfig)) -> Self {
        edit(&mut self.config);
        self
    }

    /// Sets [`SGBTConfig::n_steps`]. `build` rejects `0`.
    pub fn n_steps(self, n: usize) -> Self {
        self.edit_config(|c| c.n_steps = n)
    }

    /// Sets [`SGBTConfig::learning_rate`]. Tests show `0.0` and `1.1` are rejected by `build`.
    pub fn learning_rate(self, lr: f64) -> Self {
        self.edit_config(|c| c.learning_rate = lr)
    }

    /// Sets [`SGBTConfig::feature_subsample_rate`]. Tests show `0.0` and `1.1` are rejected.
    pub fn feature_subsample_rate(self, rate: f64) -> Self {
        self.edit_config(|c| c.feature_subsample_rate = rate)
    }

    /// Sets [`SGBTConfig::max_depth`]. `build` rejects `0`.
    pub fn max_depth(self, depth: usize) -> Self {
        self.edit_config(|c| c.max_depth = depth)
    }

    /// Sets [`SGBTConfig::n_bins`]. Tests show `1` is rejected and `2` accepted.
    pub fn n_bins(self, bins: usize) -> Self {
        self.edit_config(|c| c.n_bins = bins)
    }

    /// Sets [`SGBTConfig::lambda`]. Negative values are rejected by `build`.
    pub fn lambda(self, l: f64) -> Self {
        self.edit_config(|c| c.lambda = l)
    }

    /// Sets [`SGBTConfig::gamma`]. Negative values are rejected by `build`.
    pub fn gamma(self, g: f64) -> Self {
        self.edit_config(|c| c.gamma = g)
    }

    /// Sets [`SGBTConfig::grace_period`]. `build` rejects `0`.
    pub fn grace_period(self, gp: usize) -> Self {
        self.edit_config(|c| c.grace_period = gp)
    }

    /// Sets [`SGBTConfig::delta`]. Tests show `0.0` and `1.0` are rejected.
    pub fn delta(self, d: f64) -> Self {
        self.edit_config(|c| c.delta = d)
    }

    /// Sets [`SGBTConfig::drift_detector`]; its parameters are checked by `build`.
    pub fn drift_detector(self, dt: DriftDetectorType) -> Self {
        self.edit_config(|c| c.drift_detector = dt)
    }

    /// Sets [`SGBTConfig::variant`].
    pub fn variant(self, v: SGBTVariant) -> Self {
        self.edit_config(|c| c.variant = v)
    }

    /// Sets [`SGBTConfig::seed`].
    pub fn seed(self, seed: u64) -> Self {
        self.edit_config(|c| c.seed = seed)
    }

    /// Sets [`SGBTConfig::initial_target_count`]. `build` rejects `0`.
    pub fn initial_target_count(self, count: usize) -> Self {
        self.edit_config(|c| c.initial_target_count = count)
    }

    /// Sets [`SGBTConfig::leaf_half_life`] to `Some(n)`.
    pub fn leaf_half_life(self, n: usize) -> Self {
        self.edit_config(|c| c.leaf_half_life = Some(n))
    }

    /// Sets [`SGBTConfig::max_tree_samples`] to `Some(n)`.
    pub fn max_tree_samples(self, n: u64) -> Self {
        self.edit_config(|c| c.max_tree_samples = Some(n))
    }

    /// Sets [`SGBTConfig::split_reeval_interval`] to `Some(n)`.
    pub fn split_reeval_interval(self, n: usize) -> Self {
        self.edit_config(|c| c.split_reeval_interval = Some(n))
    }

    /// Sets [`SGBTConfig::feature_names`]. Tests show duplicate names, and a length mismatch
    /// with `feature_types`, are rejected by `build`.
    pub fn feature_names(self, names: Vec<String>) -> Self {
        self.edit_config(|c| c.feature_names = Some(names))
    }

    /// Sets [`SGBTConfig::feature_types`].
    pub fn feature_types(self, types: Vec<FeatureType>) -> Self {
        self.edit_config(|c| c.feature_types = Some(types))
    }

    /// Sets [`SGBTConfig::gradient_clip_sigma`] to `Some(sigma)`.
    pub fn gradient_clip_sigma(self, sigma: f64) -> Self {
        self.edit_config(|c| c.gradient_clip_sigma = Some(sigma))
    }

    /// Sets [`SGBTConfig::monotone_constraints`]. Tests show a constraint of `2` is rejected.
    pub fn monotone_constraints(self, constraints: Vec<i8>) -> Self {
        self.edit_config(|c| c.monotone_constraints = Some(constraints))
    }

    /// Sets [`SGBTConfig::quality_prune_alpha`]. Tests show `0.0` and `1.0` are rejected.
    pub fn quality_prune_alpha(self, alpha: f64) -> Self {
        self.edit_config(|c| c.quality_prune_alpha = Some(alpha))
    }

    /// Sets [`SGBTConfig::quality_prune_threshold`].
    pub fn quality_prune_threshold(self, threshold: f64) -> Self {
        self.edit_config(|c| c.quality_prune_threshold = threshold)
    }

    /// Sets [`SGBTConfig::quality_prune_patience`].
    pub fn quality_prune_patience(self, patience: u64) -> Self {
        self.edit_config(|c| c.quality_prune_patience = patience)
    }

    /// Sets [`SGBTConfig::error_weight_alpha`]. Tests show `0.0` and `1.0` are rejected.
    pub fn error_weight_alpha(self, alpha: f64) -> Self {
        self.edit_config(|c| c.error_weight_alpha = Some(alpha))
    }

    /// Sets [`SGBTConfig::uncertainty_modulated_lr`].
    pub fn uncertainty_modulated_lr(self, enabled: bool) -> Self {
        self.edit_config(|c| c.uncertainty_modulated_lr = enabled)
    }

    /// Sets [`SGBTConfig::scale_mode`].
    pub fn scale_mode(self, mode: ScaleMode) -> Self {
        self.edit_config(|c| c.scale_mode = mode)
    }

    /// Sets [`SGBTConfig::empirical_sigma_alpha`]. Tests show `1.1` is rejected.
    pub fn empirical_sigma_alpha(self, alpha: f64) -> Self {
        self.edit_config(|c| c.empirical_sigma_alpha = alpha)
    }

    /// Sets [`SGBTConfig::max_leaf_output`] to `Some(max)`.
    pub fn max_leaf_output(self, max: f64) -> Self {
        self.edit_config(|c| c.max_leaf_output = Some(max))
    }

    /// Sets [`SGBTConfig::adaptive_leaf_bound`] to `Some(k)`.
    pub fn adaptive_leaf_bound(self, k: f64) -> Self {
        self.edit_config(|c| c.adaptive_leaf_bound = Some(k))
    }

    /// Sets [`SGBTConfig::adaptive_depth`] to `Some(factor)`.
    pub fn adaptive_depth(self, factor: f64) -> Self {
        self.edit_config(|c| c.adaptive_depth = Some(factor))
    }

    /// Sets [`SGBTConfig::min_hessian_sum`] to `Some(min_h)`.
    pub fn min_hessian_sum(self, min_h: f64) -> Self {
        self.edit_config(|c| c.min_hessian_sum = Some(min_h))
    }

    /// Sets [`SGBTConfig::huber_k`] to `Some(k)`.
    pub fn huber_k(self, k: f64) -> Self {
        self.edit_config(|c| c.huber_k = Some(k))
    }

    /// Sets [`SGBTConfig::shadow_warmup`] to `Some(warmup)`.
    pub fn shadow_warmup(self, warmup: usize) -> Self {
        self.edit_config(|c| c.shadow_warmup = Some(warmup))
    }

    /// Sets [`SGBTConfig::leaf_model_type`].
    pub fn leaf_model_type(self, lmt: LeafModelType) -> Self {
        self.edit_config(|c| c.leaf_model_type = lmt)
    }

    /// Sets [`SGBTConfig::packed_refresh_interval`].
    pub fn packed_refresh_interval(self, interval: u64) -> Self {
        self.edit_config(|c| c.packed_refresh_interval = interval)
    }

    /// Sets [`SGBTConfig::adaptive_mts`] to `Some((base_mts, k))`.
    pub fn adaptive_mts(self, base_mts: u64, k: f64) -> Self {
        self.edit_config(|c| c.adaptive_mts = Some((base_mts, k)))
    }

    /// Sets [`SGBTConfig::adaptive_mts_floor`]. Tests show `1.1` is rejected.
    pub fn adaptive_mts_floor(self, fraction: f64) -> Self {
        self.edit_config(|c| c.adaptive_mts_floor = fraction)
    }

    /// Sets [`SGBTConfig::proactive_prune_interval`] to `Some(interval)`.
    pub fn proactive_prune_interval(self, interval: u64) -> Self {
        self.edit_config(|c| c.proactive_prune_interval = Some(interval))
    }

    /// Sets [`SGBTConfig::accuracy_based_pruning`].
    pub fn accuracy_based_pruning(self, enabled: bool) -> Self {
        self.edit_config(|c| c.accuracy_based_pruning = enabled)
    }

    /// Sets [`SGBTConfig::prune_half_life`] to `Some(n)`.
    pub fn prune_half_life(self, n: usize) -> Self {
        self.edit_config(|c| c.prune_half_life = Some(n))
    }

    /// Sets [`SGBTConfig::hoeffding_r`] to `Some(r)`.
    pub fn hoeffding_r(self, r: f64) -> Self {
        self.edit_config(|c| c.hoeffding_r = Some(r))
    }

    /// Validates the accumulated settings and returns the finished [`SGBTConfig`].
    ///
    /// # Errors
    ///
    /// Propagates the error from `validation::validate_and_build` when a value is out of
    /// range. For example, the tests in this file show that `n_steps(0)`,
    /// `learning_rate(1.1)` and duplicate feature names are rejected.
    pub fn build(self) -> Result<SGBTConfig> {
        validation::validate_and_build(self.config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_values() {
        let cfg = SGBTConfig::default();
        assert_eq!(cfg.n_steps, 100);
        assert!((cfg.learning_rate - 0.0125).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.75).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 6);
        assert_eq!(cfg.n_bins, 64);
        assert!((cfg.lambda - 1.0).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.0).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 200);
        assert!((cfg.delta - 1e-7).abs() < f64::EPSILON);
        assert_eq!(cfg.variant, SGBTVariant::Standard);
    }

    /// The serde `default = "..."` helpers and `SGBTConfig::default()` must agree, otherwise a
    /// config deserialized without these fields differs from the in-code default.
    #[test]
    fn serde_default_helpers_match_default_impl() {
        let cfg = SGBTConfig::default();
        assert_eq!(
            cfg.quality_prune_threshold.to_bits(),
            default_quality_prune_threshold().to_bits()
        );
        assert_eq!(cfg.quality_prune_patience, default_quality_prune_patience());
        assert_eq!(
            cfg.empirical_sigma_alpha.to_bits(),
            default_empirical_sigma_alpha().to_bits()
        );
    }

    #[test]
    fn drift_detector_default_is_page_hinkley() {
        assert_eq!(
            DriftDetectorType::default(),
            DriftDetectorType::PageHinkley {
                delta: 0.005,
                lambda: 50.0,
            }
        );
    }

    #[test]
    fn scale_mode_default_is_empirical() {
        assert_eq!(ScaleMode::default(), ScaleMode::Empirical);
    }

    #[test]
    fn builder_starts_from_default() {
        assert_eq!(SGBTConfig::builder().config, SGBTConfig::default());
    }

    #[test]
    fn builder_optional_setters_store_some() {
        let b = SGBTConfig::builder()
            .leaf_half_life(10)
            .max_tree_samples(1_000)
            .adaptive_mts(200, 0.5)
            .hoeffding_r(2.0);
        assert_eq!(b.config.leaf_half_life, Some(10));
        assert_eq!(b.config.max_tree_samples, Some(1_000));
        assert_eq!(b.config.adaptive_mts, Some((200, 0.5)));
        assert_eq!(b.config.hoeffding_r, Some(2.0));
    }

    #[test]
    fn builder_chain() {
        let cfg = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.1)
            .feature_subsample_rate(0.5)
            .max_depth(10)
            .n_bins(128)
            .lambda(0.5)
            .gamma(0.1)
            .grace_period(500)
            .delta(1e-5)
            .build()
            .expect("valid config");
        assert_eq!(cfg.n_steps, 50);
        assert!((cfg.learning_rate - 0.1).abs() < f64::EPSILON);
        assert!((cfg.feature_subsample_rate - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.max_depth, 10);
        assert_eq!(cfg.n_bins, 128);
        assert!((cfg.lambda - 0.5).abs() < f64::EPSILON);
        assert!((cfg.gamma - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.grace_period, 500);
        assert!((cfg.delta - 1e-5).abs() < f64::EPSILON);
    }

    #[test]
    fn validation_n_steps_zero() {
        let cfg = SGBTConfig::builder().n_steps(0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_learning_rate_zero() {
        let cfg = SGBTConfig::builder().learning_rate(0.0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_learning_rate_too_high() {
        let cfg = SGBTConfig::builder().learning_rate(1.1).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_feature_subsample_rate_zero() {
        let cfg = SGBTConfig::builder().feature_subsample_rate(0.0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_feature_subsample_rate_too_high() {
        let cfg = SGBTConfig::builder().feature_subsample_rate(1.1).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_max_depth_zero() {
        let cfg = SGBTConfig::builder().max_depth(0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_n_bins_too_small() {
        let cfg = SGBTConfig::builder().n_bins(1).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_n_bins_two_ok() {
        let cfg = SGBTConfig::builder().n_bins(2).build();
        assert!(cfg.is_ok());
    }

    #[test]
    fn validation_lambda_negative() {
        let cfg = SGBTConfig::builder().lambda(-0.1).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_gamma_negative() {
        let cfg = SGBTConfig::builder().gamma(-0.1).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_grace_period_zero() {
        let cfg = SGBTConfig::builder().grace_period(0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_delta_zero() {
        let cfg = SGBTConfig::builder().delta(0.0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_delta_one() {
        let cfg = SGBTConfig::builder().delta(1.0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_initial_target_count_zero() {
        let cfg = SGBTConfig::builder().initial_target_count(0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_duplicate_feature_names() {
        let cfg = SGBTConfig::builder()
            .feature_names(vec!["a".into(), "a".into()])
            .build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_feature_names_types_mismatch() {
        let cfg = SGBTConfig::builder()
            .feature_names(vec!["a".into(), "b".into()])
            .feature_types(vec![FeatureType::Continuous])
            .build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_bad_monotone_constraint() {
        let cfg = SGBTConfig::builder().monotone_constraints(vec![2]).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_quality_prune_alpha_zero() {
        let cfg = SGBTConfig::builder().quality_prune_alpha(0.0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_quality_prune_alpha_one() {
        let cfg = SGBTConfig::builder().quality_prune_alpha(1.0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_error_weight_alpha_zero() {
        let cfg = SGBTConfig::builder().error_weight_alpha(0.0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_error_weight_alpha_one() {
        let cfg = SGBTConfig::builder().error_weight_alpha(1.0).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_empirical_sigma_alpha_too_high() {
        let cfg = SGBTConfig::builder().empirical_sigma_alpha(1.1).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_adaptive_mts_floor_too_high() {
        let cfg = SGBTConfig::builder().adaptive_mts_floor(1.1).build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_drift_detector_pht_bad_delta() {
        let cfg = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::PageHinkley {
                delta: 0.0,
                lambda: 1.0,
            })
            .build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_drift_detector_adwin_bad_delta() {
        let cfg = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Adwin { delta: 1.1 })
            .build();
        assert!(cfg.is_err());
    }

    #[test]
    fn validation_drift_detector_ddm_bad_levels() {
        let cfg = SGBTConfig::builder()
            .drift_detector(DriftDetectorType::Ddm {
                warning_level: 3.0,
                drift_level: 2.0,
                min_instances: 30,
            })
            .build();
        assert!(cfg.is_err());
    }
}