use crate::error::{ConfigError, Result};
use super::DriftDetectorType;
use super::SGBTConfig;
use crate::ensemble::variants::SGBTVariant;
/// Validates every field of an [`SGBTConfig`] and returns it unchanged on success.
///
/// Checks run in a fixed order and stop at the first violation, so the returned
/// error always describes exactly one offending field.
///
/// Floating-point parameters are explicitly rejected when they are NaN. NaN
/// compares false against every bound, so a plain `x <= 0.0 || x > 1.0` test
/// would silently accept it.
///
/// # Errors
///
/// Returns a [`ConfigError`] (converted into the crate error type) when:
///
/// * a scalar field lies outside its allowed range (`out_of_range`), or
/// * related fields are inconsistent with each other, e.g. duplicate feature
///   names, mismatched `feature_types` / `feature_names` lengths, a
///   `split_reeval_interval` below `grace_period`, or a DDM `drift_level` that
///   does not exceed `warning_level` (`invalid`).
pub(super) fn validate_and_build(config: SGBTConfig) -> Result<SGBTConfig> {
    let c = &config;

    // --- Core boosting parameters -------------------------------------------
    if c.n_steps == 0 {
        return Err(ConfigError::out_of_range("n_steps", "must be > 0", c.n_steps).into());
    }
    if c.learning_rate.is_nan() || c.learning_rate <= 0.0 || c.learning_rate > 1.0 {
        return Err(ConfigError::out_of_range(
            "learning_rate",
            "must be in (0, 1]",
            c.learning_rate,
        )
        .into());
    }
    if c.feature_subsample_rate.is_nan()
        || c.feature_subsample_rate <= 0.0
        || c.feature_subsample_rate > 1.0
    {
        return Err(ConfigError::out_of_range(
            "feature_subsample_rate",
            "must be in (0, 1]",
            c.feature_subsample_rate,
        )
        .into());
    }
    if c.max_depth == 0 {
        return Err(ConfigError::out_of_range("max_depth", "must be > 0", c.max_depth).into());
    }
    if c.n_bins < 2 {
        return Err(ConfigError::out_of_range("n_bins", "must be >= 2", c.n_bins).into());
    }
    if c.lambda.is_nan() || c.lambda < 0.0 {
        return Err(ConfigError::out_of_range("lambda", "must be >= 0", c.lambda).into());
    }
    if c.gamma.is_nan() || c.gamma < 0.0 {
        return Err(ConfigError::out_of_range("gamma", "must be >= 0", c.gamma).into());
    }
    if c.grace_period == 0 {
        return Err(
            ConfigError::out_of_range("grace_period", "must be > 0", c.grace_period).into(),
        );
    }
    if c.delta.is_nan() || c.delta <= 0.0 || c.delta >= 1.0 {
        return Err(ConfigError::out_of_range("delta", "must be in (0, 1)", c.delta).into());
    }
    if c.initial_target_count == 0 {
        return Err(ConfigError::out_of_range(
            "initial_target_count",
            "must be > 0",
            c.initial_target_count,
        )
        .into());
    }

    // --- Optional tree-level settings (validated only when present) ----------
    if let Some(hl) = c.leaf_half_life {
        if hl == 0 {
            return Err(ConfigError::out_of_range("leaf_half_life", "must be >= 1", hl).into());
        }
    }
    if let Some(max) = c.max_tree_samples {
        if max < 100 {
            return Err(
                ConfigError::out_of_range("max_tree_samples", "must be >= 100", max).into(),
            );
        }
    }
    if let Some(interval) = c.split_reeval_interval {
        // Cross-field constraint: the interval may not be shorter than grace_period.
        if interval < c.grace_period {
            return Err(ConfigError::invalid(
                "split_reeval_interval",
                format!(
                    "must be >= grace_period ({}), got {}",
                    c.grace_period, interval
                ),
            )
            .into());
        }
    }

    // --- Feature metadata -----------------------------------------------------
    if let Some(ref names) = c.feature_names {
        // `insert` returns false on a repeat, which pinpoints the first duplicate.
        let mut seen = std::collections::HashSet::new();
        for name in names {
            if !seen.insert(name.as_str()) {
                return Err(ConfigError::invalid(
                    "feature_names",
                    format!("duplicate feature name: '{}'", name),
                )
                .into());
            }
        }
    }
    if let Some(ref types) = c.feature_types {
        if let Some(ref names) = c.feature_names {
            // Lengths are only compared when both lists are non-empty.
            if !names.is_empty() && !types.is_empty() && names.len() != types.len() {
                return Err(ConfigError::invalid(
                    "feature_types",
                    format!(
                        "length ({}) must match feature_names length ({})",
                        types.len(),
                        names.len()
                    ),
                )
                .into());
            }
        }
    }

    // --- Optional numeric knobs ----------------------------------------------
    if let Some(sigma) = c.gradient_clip_sigma {
        if sigma.is_nan() || sigma <= 0.0 {
            return Err(
                ConfigError::out_of_range("gradient_clip_sigma", "must be > 0", sigma).into(),
            );
        }
    }
    if let Some(ref mc) = c.monotone_constraints {
        for (i, &v) in mc.iter().enumerate() {
            if v != -1 && v != 0 && v != 1 {
                return Err(ConfigError::invalid(
                    "monotone_constraints",
                    format!("feature {}: must be -1, 0, or +1, got {}", i, v),
                )
                .into());
            }
        }
    }
    if let Some(max) = c.max_leaf_output {
        if max.is_nan() || max <= 0.0 {
            return Err(ConfigError::out_of_range("max_leaf_output", "must be > 0", max).into());
        }
    }
    if let Some(k) = c.adaptive_leaf_bound {
        if k.is_nan() || k <= 0.0 {
            return Err(ConfigError::out_of_range("adaptive_leaf_bound", "must be > 0", k).into());
        }
    }
    if let Some(factor) = c.adaptive_depth {
        if factor.is_nan() || factor <= 0.0 {
            return Err(ConfigError::out_of_range("adaptive_depth", "must be > 0", factor).into());
        }
    }
    if let Some(min_h) = c.min_hessian_sum {
        if min_h.is_nan() || min_h <= 0.0 {
            return Err(ConfigError::out_of_range("min_hessian_sum", "must be > 0", min_h).into());
        }
    }
    if let Some(k) = c.huber_k {
        if k.is_nan() || k <= 0.0 {
            return Err(ConfigError::out_of_range("huber_k", "must be > 0", k).into());
        }
    }
    // `RangeInclusive::contains` is already false for NaN, so no extra guard is needed.
    if !(0.0..=1.0).contains(&c.empirical_sigma_alpha) {
        return Err(ConfigError::out_of_range(
            "empirical_sigma_alpha",
            "must be in [0.0, 1.0]",
            c.empirical_sigma_alpha,
        )
        .into());
    }
    if let Some(warmup) = c.shadow_warmup {
        if warmup == 0 {
            return Err(
                ConfigError::out_of_range("shadow_warmup", "must be > 0", warmup as f64).into(),
            );
        }
    }

    // --- Quality pruning / error weighting -----------------------------------
    if let Some(alpha) = c.quality_prune_alpha {
        if alpha.is_nan() || alpha <= 0.0 || alpha >= 1.0 {
            return Err(ConfigError::out_of_range(
                "quality_prune_alpha",
                "must be in (0, 1)",
                alpha,
            )
            .into());
        }
    }
    if c.quality_prune_threshold.is_nan() || c.quality_prune_threshold <= 0.0 {
        return Err(ConfigError::out_of_range(
            "quality_prune_threshold",
            "must be > 0",
            c.quality_prune_threshold,
        )
        .into());
    }
    if c.quality_prune_patience == 0 {
        return Err(ConfigError::out_of_range(
            "quality_prune_patience",
            "must be > 0",
            c.quality_prune_patience,
        )
        .into());
    }
    if let Some(alpha) = c.error_weight_alpha {
        if alpha.is_nan() || alpha <= 0.0 || alpha >= 1.0 {
            return Err(ConfigError::out_of_range(
                "error_weight_alpha",
                "must be in (0, 1)",
                alpha,
            )
            .into());
        }
    }

    // --- Adaptive max-tree-samples and maintenance intervals -----------------
    if let Some((base_mts, k)) = c.adaptive_mts {
        if base_mts < 100 {
            return Err(ConfigError::out_of_range(
                "adaptive_mts.base_mts",
                "must be >= 100",
                base_mts,
            )
            .into());
        }
        if k.is_nan() || k <= 0.0 {
            return Err(ConfigError::out_of_range("adaptive_mts.k", "must be > 0", k).into());
        }
    }
    if !(0.0..=1.0).contains(&c.adaptive_mts_floor) {
        return Err(ConfigError::out_of_range(
            "adaptive_mts_floor",
            "must be in [0.0, 1.0]",
            c.adaptive_mts_floor,
        )
        .into());
    }
    if let Some(interval) = c.proactive_prune_interval {
        if interval < 10 {
            return Err(ConfigError::out_of_range(
                "proactive_prune_interval",
                "must be >= 10",
                interval,
            )
            .into());
        }
    }
    if let Some(r) = c.hoeffding_r {
        // Written as a negated conjunction so NaN and infinities are both rejected.
        if !(r > 0.0 && r.is_finite()) {
            return Err(
                ConfigError::out_of_range("hoeffding_r", "must be finite and > 0", r).into(),
            );
        }
    }
    // Zero is accepted as "disabled" (see the error message); otherwise at least 10.
    if c.packed_refresh_interval != 0 && c.packed_refresh_interval < 10 {
        return Err(ConfigError::out_of_range(
            "packed_refresh_interval",
            "must be 0 (disabled) or >= 10",
            c.packed_refresh_interval,
        )
        .into());
    }

    // --- Drift detector parameters -------------------------------------------
    match &c.drift_detector {
        DriftDetectorType::PageHinkley { delta, lambda } => {
            if delta.is_nan() || *delta <= 0.0 {
                return Err(ConfigError::out_of_range(
                    "drift_detector.PageHinkley.delta",
                    "must be > 0",
                    delta,
                )
                .into());
            }
            if lambda.is_nan() || *lambda <= 0.0 {
                return Err(ConfigError::out_of_range(
                    "drift_detector.PageHinkley.lambda",
                    "must be > 0",
                    lambda,
                )
                .into());
            }
        }
        DriftDetectorType::Adwin { delta } => {
            if delta.is_nan() || *delta <= 0.0 || *delta >= 1.0 {
                return Err(ConfigError::out_of_range(
                    "drift_detector.Adwin.delta",
                    "must be in (0, 1)",
                    delta,
                )
                .into());
            }
        }
        DriftDetectorType::Ddm {
            warning_level,
            drift_level,
            min_instances,
        } => {
            if warning_level.is_nan() || *warning_level <= 0.0 {
                return Err(ConfigError::out_of_range(
                    "drift_detector.Ddm.warning_level",
                    "must be > 0",
                    warning_level,
                )
                .into());
            }
            if drift_level.is_nan() || *drift_level <= 0.0 {
                return Err(ConfigError::out_of_range(
                    "drift_detector.Ddm.drift_level",
                    "must be > 0",
                    drift_level,
                )
                .into());
            }
            // Both levels are known to be non-NaN here, so this comparison is reliable.
            if *drift_level <= *warning_level {
                return Err(ConfigError::invalid(
                    "drift_detector.Ddm.drift_level",
                    format!(
                        "must be > warning_level ({}), got {}",
                        warning_level, drift_level
                    ),
                )
                .into());
            }
            if *min_instances == 0 {
                return Err(ConfigError::out_of_range(
                    "drift_detector.Ddm.min_instances",
                    "must be > 0",
                    min_instances,
                )
                .into());
            }
        }
    }

    // --- Ensemble variant parameters -----------------------------------------
    match &c.variant {
        SGBTVariant::Standard => {}
        SGBTVariant::Skip { k } => {
            if *k == 0 {
                return Err(ConfigError::out_of_range("variant.Skip.k", "must be > 0", k).into());
            }
        }
        SGBTVariant::MultipleIterations { multiplier } => {
            if multiplier.is_nan() || *multiplier <= 0.0 {
                return Err(ConfigError::out_of_range(
                    "variant.MultipleIterations.multiplier",
                    "must be > 0",
                    multiplier,
                )
                .into());
            }
        }
    }

    Ok(config)
}