use alloc::boxed::Box;
use alloc::string::String;
use alloc::vec::Vec;
use crate::drift::adwin::Adwin;
use crate::drift::ddm::Ddm;
use crate::drift::pht::PageHinkleyTest;
use crate::drift::DriftDetector;
use crate::ensemble::variants::SGBTVariant;
use crate::error::Result;
use crate::tree::leaf_model::LeafModelType;
mod display;
mod tree_config_helper;
mod validation;
pub(crate) use tree_config_helper::build_tree_config;
pub use crate::feature::FeatureType;
/// Scale mode selector stored in `SGBTConfig::scale_mode`.
///
/// `Empirical` is the `Default` variant. The enum is `#[non_exhaustive]`, so
/// `match`es written outside this crate need a wildcard arm.
///
/// NOTE(review): the code that interprets the two modes is not in this file.
/// Presumably the mode decides how a scale/sigma estimate is produced
/// (compare `SGBTConfig::empirical_sigma_alpha`) — confirm against the consumer.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(
feature = "_serde_support",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub enum ScaleMode {
/// Default mode.
/// NOTE(review): presumably an empirically tracked estimate — confirm.
#[default]
Empirical,
/// Alternative mode.
/// NOTE(review): presumably an estimate derived from the tree chain — confirm.
TreeChain,
}
/// Plain-data description of which drift detector to use and with which
/// parameters.
///
/// A description is turned into a live detector by `DriftDetectorType::create`,
/// which forwards the stored fields to the matching constructor. `Default`
/// yields `PageHinkley { delta: 0.005, lambda: 50.0 }`. With the
/// `_serde_support` feature enabled the type also derives
/// `Serialize`/`Deserialize`.
///
/// The enum is `#[non_exhaustive]`, so `match`es written outside this crate
/// need a wildcard arm.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "_serde_support",
derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub enum DriftDetectorType {
/// Built via `PageHinkleyTest::with_params(delta, lambda)`.
PageHinkley {
/// First argument to `PageHinkleyTest::with_params`.
/// NOTE(review): presumably the tolerated magnitude of change — confirm.
delta: f64,
/// Second argument to `PageHinkleyTest::with_params`.
/// NOTE(review): presumably the detection threshold — confirm.
lambda: f64,
},
/// Built via `Adwin::with_delta(delta)`.
Adwin {
/// Sole argument to `Adwin::with_delta`.
/// NOTE(review): presumably a confidence parameter — confirm.
delta: f64,
},
/// Built via `Ddm::with_params(warning_level, drift_level, min_instances)`.
Ddm {
/// First argument to `Ddm::with_params`.
warning_level: f64,
/// Second argument to `Ddm::with_params`.
drift_level: f64,
/// Third argument to `Ddm::with_params`.
/// NOTE(review): presumably the sample count required before signalling — confirm.
min_instances: u64,
},
}
impl Default for DriftDetectorType {
fn default() -> Self {
DriftDetectorType::PageHinkley {
delta: 0.005,
lambda: 50.0,
}
}
}
impl DriftDetectorType {
    /// Instantiates the detector described by this value.
    ///
    /// Each variant forwards its stored parameters, unchanged, to the matching
    /// constructor: `PageHinkleyTest::with_params`, `Adwin::with_delta` or
    /// `Ddm::with_params`. Every call returns a newly boxed detector; `self`
    /// is only read.
    pub fn create(&self) -> Box<dyn DriftDetector> {
        // Every payload field is `Copy`, so matching on `*self` binds plain
        // values instead of references.
        match *self {
            Self::PageHinkley { delta, lambda } => {
                Box::new(PageHinkleyTest::with_params(delta, lambda))
            }
            Self::Adwin { delta } => Box::new(Adwin::with_delta(delta)),
            Self::Ddm {
                warning_level,
                drift_level,
                min_instances,
            } => Box::new(Ddm::with_params(warning_level, drift_level, min_instances)),
        }
    }
}
/// Plain-data configuration record for SGBT models.
///
/// Obtain one from `SGBTConfig::default()` or, to have it validated, from
/// `SGBTConfig::builder()` followed by `build()`. All fields are public, so a
/// value assembled or mutated by hand bypasses the builder's validation.
///
/// With the `_serde_support` feature the struct derives
/// `Serialize`/`Deserialize`. Fields carrying a `serde(default ...)` attribute
/// may be missing from serialized input and then take the stated fallback; the
/// fields from `n_steps` through `initial_target_count` carry no such
/// attribute.
///
/// Defaults quoted below are the values set by the `Default` impl in this
/// file. Meanings introduced with "presumably" are inferred from names only,
/// because the consuming code is not in this file.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "_serde_support",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct SGBTConfig {
/// Presumably the number of boosting steps. Default `100`; the tests show
/// the builder rejects `0`.
pub n_steps: usize,
/// Default `0.0125`. The tests show the builder rejects `0.0` and `1.5`
/// and accepts `0.1` and `1.0`.
pub learning_rate: f64,
/// Presumably the fraction of features sampled per tree. Default `0.75`.
pub feature_subsample_rate: f64,
/// Presumably the maximum tree depth. Default `6`.
pub max_depth: usize,
/// Presumably the number of histogram bins per feature. Default `64`; the
/// tests show the builder accepts `2`.
pub n_bins: usize,
/// Default `1.0`. NOTE(review): presumably an L2 regularisation term — confirm.
pub lambda: f64,
/// Default `0.0`. NOTE(review): presumably a minimum split gain — confirm.
pub gamma: f64,
/// Default `200`; the tests show the builder accepts `1`.
/// NOTE(review): presumably samples observed between split attempts — confirm.
pub grace_period: usize,
/// Default `1e-7`. NOTE(review): presumably a split confidence parameter — confirm.
pub delta: f64,
/// Drift detector description. Default `DriftDetectorType::default()`.
pub drift_detector: DriftDetectorType,
/// Ensemble variant. Default `SGBTVariant::default()`.
pub variant: SGBTVariant,
/// Default `0xDEAD_BEEF_CAFE_4242`. Presumably the RNG seed.
pub seed: u64,
/// Default `50`. NOTE(review): presumably the number of targets collected
/// before an initial prediction is fixed — confirm.
pub initial_target_count: usize,
/// Default `None`; the builder setter stores `Some(n)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub leaf_half_life: Option<usize>,
/// Default `None`; the builder setter stores `Some(n)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub max_tree_samples: Option<u64>,
/// Default `None`; the builder setter stores `Some((base_mts, k))`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub adaptive_mts: Option<(u64, f64)>,
/// Default `0.0`; the builder setter names its argument `fraction`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub adaptive_mts_floor: f64,
/// Default `None`; the builder setter stores `Some(interval)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub proactive_prune_interval: Option<u64>,
/// Default `None`; the builder setter stores `Some(n)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub split_reeval_interval: Option<usize>,
/// Optional feature names. Default `None`. The tests show the builder
/// rejects duplicate names and accepts an empty vector.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub feature_names: Option<Vec<String>>,
/// Optional per-feature types. Default `None`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub feature_types: Option<Vec<FeatureType>>,
/// Default `None`; the builder setter stores `Some(sigma)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub gradient_clip_sigma: Option<f64>,
/// Default `None`. NOTE(review): presumably one sign entry per feature — confirm.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub monotone_constraints: Option<Vec<i8>>,
/// Default `None`; the builder setter stores `Some(alpha)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub quality_prune_alpha: Option<f64>,
/// Default `1e-6`; serde falls back to `default_quality_prune_threshold()`.
#[cfg_attr(
feature = "_serde_support",
serde(default = "default_quality_prune_threshold")
)]
pub quality_prune_threshold: f64,
/// Default `500`; serde falls back to `default_quality_prune_patience()`.
#[cfg_attr(
feature = "_serde_support",
serde(default = "default_quality_prune_patience")
)]
pub quality_prune_patience: u64,
/// Default `None`; the builder setter stores `Some(alpha)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub error_weight_alpha: Option<f64>,
/// Default `false`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub uncertainty_modulated_lr: bool,
/// Default `ScaleMode::default()`, i.e. `ScaleMode::Empirical`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub scale_mode: ScaleMode,
/// Default `0.01`; serde falls back to `default_empirical_sigma_alpha()`.
#[cfg_attr(
feature = "_serde_support",
serde(default = "default_empirical_sigma_alpha")
)]
pub empirical_sigma_alpha: f64,
/// Default `None`; the builder setter stores `Some(max)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub max_leaf_output: Option<f64>,
/// Default `None`. The tests show the builder accepts `3.0` and rejects `0.0`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub adaptive_leaf_bound: Option<f64>,
/// Default `None`; the builder setter stores `Some(factor)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub adaptive_depth: Option<f64>,
/// Default `None`; the builder setter stores `Some(min_h)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub min_hessian_sum: Option<f64>,
/// Default `None`; the builder setter stores `Some(k)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub huber_k: Option<f64>,
/// Default `None`; the builder setter stores `Some(warmup)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub shadow_warmup: Option<usize>,
/// Default `LeafModelType::default()`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub leaf_model_type: LeafModelType,
/// Default `0`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub packed_refresh_interval: u64,
/// Default `None`; the builder setter stores `Some(r)`.
#[cfg_attr(feature = "_serde_support", serde(default))]
pub hoeffding_r: Option<f64>,
}
/// Serde fallback for `SGBTConfig::empirical_sigma_alpha` when the field is
/// missing from serialized input; equals the value used by `Default` (`0.01`).
#[cfg(feature = "_serde_support")]
fn default_empirical_sigma_alpha() -> f64 {
    const FALLBACK: f64 = 0.01;
    FALLBACK
}
/// Serde fallback for `SGBTConfig::quality_prune_threshold` when the field is
/// missing from serialized input; equals the value used by `Default` (`1e-6`).
#[cfg(feature = "_serde_support")]
fn default_quality_prune_threshold() -> f64 {
    const FALLBACK: f64 = 1e-6;
    FALLBACK
}
/// Serde fallback for `SGBTConfig::quality_prune_patience` when the field is
/// missing from serialized input; equals the value used by `Default` (`500`).
#[cfg(feature = "_serde_support")]
fn default_quality_prune_patience() -> u64 {
    const FALLBACK: u64 = 500;
    FALLBACK
}
impl Default for SGBTConfig {
    /// Builds the baseline configuration.
    ///
    /// Every `Option` field starts as `None`, `uncertainty_modulated_lr` is
    /// `false`, and the nested component settings come from their own
    /// `Default` impls. The values `1e-6`, `500` and `0.01` match the serde
    /// fallback functions defined in this file.
    fn default() -> Self {
        // Defaults owned by other types.
        let drift_detector = DriftDetectorType::default();
        let variant = SGBTVariant::default();
        let scale_mode = ScaleMode::default();
        let leaf_model_type = LeafModelType::default();

        Self {
            // Delegated component defaults.
            drift_detector,
            variant,
            scale_mode,
            leaf_model_type,

            // Plain scalar settings.
            n_steps: 100,
            learning_rate: 0.0125,
            feature_subsample_rate: 0.75,
            max_depth: 6,
            n_bins: 64,
            lambda: 1.0,
            gamma: 0.0,
            grace_period: 200,
            delta: 1e-7,
            seed: 0xDEAD_BEEF_CAFE_4242,
            initial_target_count: 50,
            adaptive_mts_floor: 0.0,
            quality_prune_threshold: 1e-6,
            quality_prune_patience: 500,
            uncertainty_modulated_lr: false,
            empirical_sigma_alpha: 0.01,
            packed_refresh_interval: 0,

            // Optional settings, all unset.
            leaf_half_life: None,
            max_tree_samples: None,
            adaptive_mts: None,
            proactive_prune_interval: None,
            split_reeval_interval: None,
            feature_names: None,
            feature_types: None,
            gradient_clip_sigma: None,
            monotone_constraints: None,
            quality_prune_alpha: None,
            error_weight_alpha: None,
            max_leaf_output: None,
            adaptive_leaf_bound: None,
            adaptive_depth: None,
            min_hessian_sum: None,
            huber_k: None,
            shadow_warmup: None,
            hoeffding_r: None,
        }
    }
}
impl SGBTConfig {
pub fn builder() -> SGBTConfigBuilder {
SGBTConfigBuilder::default()
}
}
/// Consuming builder for `SGBTConfig`.
///
/// Every setter takes the builder by value and returns the updated builder,
/// so the returned value must be kept: writing `b.n_steps(5);` as a statement
/// discards the change. `#[must_use]` on the type makes the compiler warn
/// about exactly that mistake for every function returning a builder.
///
/// A builder created through `Default` (or `SGBTConfig::builder()`) starts
/// from `SGBTConfig::default()`. Call `build()` to validate and obtain the
/// configuration.
#[must_use = "builder setters return the updated builder; finish the chain with `.build()`"]
#[derive(Debug, Clone, Default)]
pub struct SGBTConfigBuilder {
    /// Configuration being assembled; handed to validation by `build()`.
    config: SGBTConfig,
}
/// Fluent setters for the builder.
///
/// Each setter consumes the builder, overwrites exactly one field of the
/// wrapped `SGBTConfig`, and returns the builder. Setters for `Option` fields
/// wrap their argument in `Some`. No setter validates its input; all checks
/// are deferred to `build()`.
impl SGBTConfigBuilder {
    /// Applies `change` to the wrapped configuration and returns the builder.
    fn amend(mut self, change: impl FnOnce(&mut SGBTConfig)) -> Self {
        change(&mut self.config);
        self
    }

    /// Sets `n_steps`.
    pub fn n_steps(self, n: usize) -> Self {
        self.amend(|cfg| cfg.n_steps = n)
    }

    /// Sets `learning_rate`.
    pub fn learning_rate(self, lr: f64) -> Self {
        self.amend(|cfg| cfg.learning_rate = lr)
    }

    /// Sets `feature_subsample_rate`.
    pub fn feature_subsample_rate(self, rate: f64) -> Self {
        self.amend(|cfg| cfg.feature_subsample_rate = rate)
    }

    /// Sets `max_depth`.
    pub fn max_depth(self, depth: usize) -> Self {
        self.amend(|cfg| cfg.max_depth = depth)
    }

    /// Sets `n_bins`.
    pub fn n_bins(self, bins: usize) -> Self {
        self.amend(|cfg| cfg.n_bins = bins)
    }

    /// Sets `lambda`.
    pub fn lambda(self, l: f64) -> Self {
        self.amend(|cfg| cfg.lambda = l)
    }

    /// Sets `gamma`.
    pub fn gamma(self, g: f64) -> Self {
        self.amend(|cfg| cfg.gamma = g)
    }

    /// Sets `grace_period`.
    pub fn grace_period(self, gp: usize) -> Self {
        self.amend(|cfg| cfg.grace_period = gp)
    }

    /// Sets `delta`.
    pub fn delta(self, d: f64) -> Self {
        self.amend(|cfg| cfg.delta = d)
    }

    /// Sets `drift_detector`.
    pub fn drift_detector(self, dt: DriftDetectorType) -> Self {
        self.amend(|cfg| cfg.drift_detector = dt)
    }

    /// Sets `variant`.
    pub fn variant(self, v: SGBTVariant) -> Self {
        self.amend(|cfg| cfg.variant = v)
    }

    /// Sets `seed`.
    pub fn seed(self, seed: u64) -> Self {
        self.amend(|cfg| cfg.seed = seed)
    }

    /// Sets `initial_target_count`.
    pub fn initial_target_count(self, count: usize) -> Self {
        self.amend(|cfg| cfg.initial_target_count = count)
    }

    /// Sets `leaf_half_life` to `Some(n)`.
    pub fn leaf_half_life(self, n: usize) -> Self {
        self.amend(|cfg| cfg.leaf_half_life = Some(n))
    }

    /// Sets `max_tree_samples` to `Some(n)`.
    pub fn max_tree_samples(self, n: u64) -> Self {
        self.amend(|cfg| cfg.max_tree_samples = Some(n))
    }

    /// Sets `adaptive_mts` to `Some((base_mts, k))`.
    pub fn adaptive_mts(self, base_mts: u64, k: f64) -> Self {
        self.amend(|cfg| cfg.adaptive_mts = Some((base_mts, k)))
    }

    /// Sets `adaptive_mts_floor`.
    pub fn adaptive_mts_floor(self, fraction: f64) -> Self {
        self.amend(|cfg| cfg.adaptive_mts_floor = fraction)
    }

    /// Sets `proactive_prune_interval` to `Some(interval)`.
    pub fn proactive_prune_interval(self, interval: u64) -> Self {
        self.amend(|cfg| cfg.proactive_prune_interval = Some(interval))
    }

    /// Sets `split_reeval_interval` to `Some(n)`.
    pub fn split_reeval_interval(self, n: usize) -> Self {
        self.amend(|cfg| cfg.split_reeval_interval = Some(n))
    }

    /// Sets `feature_names` to `Some(names)`.
    pub fn feature_names(self, names: Vec<String>) -> Self {
        self.amend(|cfg| cfg.feature_names = Some(names))
    }

    /// Sets `feature_types` to `Some(types)`.
    pub fn feature_types(self, types: Vec<FeatureType>) -> Self {
        self.amend(|cfg| cfg.feature_types = Some(types))
    }

    /// Sets `gradient_clip_sigma` to `Some(sigma)`.
    pub fn gradient_clip_sigma(self, sigma: f64) -> Self {
        self.amend(|cfg| cfg.gradient_clip_sigma = Some(sigma))
    }

    /// Sets `monotone_constraints` to `Some(constraints)`.
    pub fn monotone_constraints(self, constraints: Vec<i8>) -> Self {
        self.amend(|cfg| cfg.monotone_constraints = Some(constraints))
    }

    /// Sets `quality_prune_alpha` to `Some(alpha)`.
    pub fn quality_prune_alpha(self, alpha: f64) -> Self {
        self.amend(|cfg| cfg.quality_prune_alpha = Some(alpha))
    }

    /// Sets `quality_prune_threshold`.
    pub fn quality_prune_threshold(self, threshold: f64) -> Self {
        self.amend(|cfg| cfg.quality_prune_threshold = threshold)
    }

    /// Sets `quality_prune_patience`.
    pub fn quality_prune_patience(self, patience: u64) -> Self {
        self.amend(|cfg| cfg.quality_prune_patience = patience)
    }

    /// Sets `error_weight_alpha` to `Some(alpha)`.
    pub fn error_weight_alpha(self, alpha: f64) -> Self {
        self.amend(|cfg| cfg.error_weight_alpha = Some(alpha))
    }

    /// Sets `uncertainty_modulated_lr`.
    pub fn uncertainty_modulated_lr(self, enabled: bool) -> Self {
        self.amend(|cfg| cfg.uncertainty_modulated_lr = enabled)
    }

    /// Sets `scale_mode`.
    pub fn scale_mode(self, mode: ScaleMode) -> Self {
        self.amend(|cfg| cfg.scale_mode = mode)
    }

    /// Sets `empirical_sigma_alpha`.
    pub fn empirical_sigma_alpha(self, alpha: f64) -> Self {
        self.amend(|cfg| cfg.empirical_sigma_alpha = alpha)
    }

    /// Sets `max_leaf_output` to `Some(max)`.
    pub fn max_leaf_output(self, max: f64) -> Self {
        self.amend(|cfg| cfg.max_leaf_output = Some(max))
    }

    /// Sets `adaptive_leaf_bound` to `Some(k)`.
    pub fn adaptive_leaf_bound(self, k: f64) -> Self {
        self.amend(|cfg| cfg.adaptive_leaf_bound = Some(k))
    }

    /// Sets `adaptive_depth` to `Some(factor)`.
    pub fn adaptive_depth(self, factor: f64) -> Self {
        self.amend(|cfg| cfg.adaptive_depth = Some(factor))
    }

    /// Sets `min_hessian_sum` to `Some(min_h)`.
    pub fn min_hessian_sum(self, min_h: f64) -> Self {
        self.amend(|cfg| cfg.min_hessian_sum = Some(min_h))
    }

    /// Sets `huber_k` to `Some(k)`.
    pub fn huber_k(self, k: f64) -> Self {
        self.amend(|cfg| cfg.huber_k = Some(k))
    }

    /// Sets `shadow_warmup` to `Some(warmup)`.
    pub fn shadow_warmup(self, warmup: usize) -> Self {
        self.amend(|cfg| cfg.shadow_warmup = Some(warmup))
    }

    /// Sets `leaf_model_type`.
    pub fn leaf_model_type(self, lmt: LeafModelType) -> Self {
        self.amend(|cfg| cfg.leaf_model_type = lmt)
    }

    /// Sets `packed_refresh_interval`.
    pub fn packed_refresh_interval(self, interval: u64) -> Self {
        self.amend(|cfg| cfg.packed_refresh_interval = interval)
    }

    /// Sets `hoeffding_r` to `Some(r)`.
    pub fn hoeffding_r(self, r: f64) -> Self {
        self.amend(|cfg| cfg.hoeffding_r = Some(r))
    }

    /// Finishes the builder.
    ///
    /// Hands the accumulated configuration to `validation::validate_and_build`
    /// and returns its result unchanged.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validation::validate_and_build` reports. The
    /// tests in this file show, for example, that `n_steps == 0`, a
    /// `learning_rate` of `0.0` or `1.5`, duplicate feature names and an
    /// `adaptive_leaf_bound` of `0.0` are rejected.
    pub fn build(self) -> Result<SGBTConfig> {
        let Self { config } = self;
        validation::validate_and_build(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// `Default` yields the baseline step count and learning rate.
    #[test]
    fn default_config_valid() {
        let SGBTConfig {
            n_steps,
            learning_rate,
            ..
        } = SGBTConfig::default();
        assert_eq!(n_steps, 100);
        assert_eq!(learning_rate, 0.0125);
    }

    /// Values set through the builder survive `build()`.
    #[test]
    fn builder_basic() {
        let built = SGBTConfig::builder().n_steps(50).learning_rate(0.05).build();
        let cfg = built.unwrap();
        assert_eq!(cfg.n_steps, 50);
        assert_eq!(cfg.learning_rate, 0.05);
    }

    #[test]
    fn validation_rejects_zero_n_steps() {
        assert!(SGBTConfig::builder().n_steps(0).build().is_err());
    }

    #[test]
    fn validation_accepts_valid_learning_rate() {
        assert!(SGBTConfig::builder().learning_rate(0.1).build().is_ok());
    }

    #[test]
    fn validation_rejects_zero_learning_rate() {
        assert!(SGBTConfig::builder().learning_rate(0.0).build().is_err());
    }

    #[test]
    fn validation_rejects_learning_rate_above_one() {
        assert!(SGBTConfig::builder().learning_rate(1.5).build().is_err());
    }

    #[test]
    fn validation_accepts_learning_rate_one() {
        assert!(SGBTConfig::builder().learning_rate(1.0).build().is_ok());
    }

    /// A Page-Hinkley detector built from its description must flag a jump
    /// from a constant stream of `1.0` to a constant stream of `10.0`.
    #[test]
    fn drift_detector_type_create() {
        let description = DriftDetectorType::PageHinkley {
            delta: 0.005,
            lambda: 50.0,
        };
        let mut detector = description.create();

        // Feed 500 samples at the first level; the signals are ignored.
        (0..500).for_each(|_| {
            detector.update(1.0);
        });

        // `any` stops feeding at the first drift signal, like the `break` it replaces.
        let drifted =
            (0..500).any(|_| detector.update(10.0) == crate::drift::DriftSignal::Drift);
        assert!(drifted);
    }

    #[test]
    fn boundary_n_bins_two_accepted() {
        assert!(SGBTConfig::builder().n_bins(2).build().is_ok());
    }

    #[test]
    fn boundary_grace_period_one_accepted() {
        assert!(SGBTConfig::builder().grace_period(1).build().is_ok());
    }

    /// Distinct feature names are stored as given.
    #[test]
    fn feature_names_accepted() {
        let names = vec!["price".into(), "volume".into(), "spread".into()];
        let cfg = SGBTConfig::builder().feature_names(names).build().unwrap();
        let stored = cfg.feature_names.as_ref().unwrap();
        assert_eq!(stored, &["price", "volume", "spread"]);
    }

    /// A repeated feature name fails validation.
    #[test]
    fn feature_names_rejects_duplicates() {
        let names = vec!["price".into(), "volume".into(), "price".into()];
        assert!(SGBTConfig::builder().feature_names(names).build().is_err());
    }

    /// An explicitly empty name list is valid and stays `Some(vec![])`.
    #[test]
    fn feature_names_empty_vec_accepted() {
        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
        let stored = cfg.feature_names.unwrap();
        assert!(stored.is_empty());
    }

    #[test]
    fn builder_adaptive_leaf_bound() {
        let built = SGBTConfig::builder().adaptive_leaf_bound(3.0).build();
        assert_eq!(built.unwrap().adaptive_leaf_bound, Some(3.0));
    }

    #[test]
    fn validation_rejects_zero_adaptive_leaf_bound() {
        assert!(SGBTConfig::builder().adaptive_leaf_bound(0.0).build().is_err());
    }
}