use crate::dataset::feature_extractor::FeatureExtractor;
use crate::defaults::{
ensemble as ensemble_defaults, gbdt as gbdt_defaults, seeds as seeds_defaults,
tree as tree_defaults, universal as universal_defaults,
};
use crate::learner::{LinearConfig, LinearPreset, TreeConfig, TreePreset};
use crate::model::universal::mode::BoostingMode;
use rkyv::{Archive, Deserialize, Serialize};
/// Selects how the outputs of ensemble members are combined.
///
/// The default (see the `Default` impl) is the `Ridge` variant populated from
/// the `ensemble_defaults` constants.
///
/// NOTE(review): the code that consumes this enum is not in this file, so the
/// field descriptions below are inferred from names only — confirm against
/// the stacking implementation. Reordering variants or fields may change the
/// derived rkyv/serde encodings — verify before doing so.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub enum StackingStrategy {
/// Parameterised stacking (presumably a ridge-regularised linear combiner — TODO confirm).
Ridge {
/// Presumably the regularisation strength — TODO confirm.
alpha: f32,
/// Presumably toggles rank-transforming the inputs before fitting — TODO confirm.
rank_transform: bool,
/// Presumably toggles fitting an intercept term — TODO confirm.
fit_intercept: bool,
/// Presumably a lower bound applied to member weights — TODO confirm.
min_weight: f32,
},
/// Parameter-free alternative (presumably a plain mean of member outputs — TODO confirm).
Average,
}
impl Default for StackingStrategy {
fn default() -> Self {
Self::Ridge {
alpha: ensemble_defaults::DEFAULT_STACKING_ALPHA,
rank_transform: ensemble_defaults::DEFAULT_RANK_TRANSFORM,
fit_intercept: ensemble_defaults::DEFAULT_FIT_INTERCEPT,
min_weight: ensemble_defaults::DEFAULT_MIN_WEIGHT,
}
}
}
/// Configuration for the universal model, normally assembled through the
/// `with_*` builder methods on this type.
///
/// All fields are public, so the clamping done by the builder methods does
/// not apply to values assigned directly.
///
/// NOTE(review): how the trainer consumes these fields is not visible in this
/// file; descriptions marked "presumably" are inferred from names only.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct UniversalConfig {
/// Boosting mode; `Default` sets `BoostingMode::PureTree`.
pub mode: BoostingMode,
/// Number of rounds (presumably boosting iterations — TODO confirm).
pub num_rounds: usize,
/// Settings for the tree learner.
pub tree_config: TreeConfig,
/// Settings for the linear learner.
pub linear_config: LinearConfig,
/// Learning rate; `with_learning_rate` clamps it to `[0.0, 1.0]`.
pub learning_rate: f32,
/// Subsample ratio; `with_subsample` clamps it to `[0.0, 1.0]`.
pub subsample: f32,
/// Validation ratio; `with_validation_ratio` clamps it to `[0.0, 0.5]`.
pub validation_ratio: f32,
/// Early-stopping round count (presumably rounds without improvement — TODO confirm).
pub early_stopping_rounds: usize,
/// Calibration ratio; `with_conformal_calibration` clamps it to `[0.0, 0.5]`.
pub calibration_ratio: f32,
/// Conformal quantile; `with_conformal_calibration` clamps it to `[0.5, 0.99]`.
pub conformal_quantile: f32,
/// Primary random seed.
pub seed: u64,
/// Number of linear rounds (presumably used by the linear-then-tree mode — TODO confirm).
pub linear_rounds: usize,
/// Memory limit for the linear stage (presumably in megabytes, per the name — TODO confirm).
pub max_linear_memory_mb: usize,
/// Optional feature extractor. Excluded from both the rkyv archive and the
/// serde representation by the attributes below; `Default` sets `None`.
#[rkyv(with = rkyv::with::Skip)]
#[serde(skip)]
pub feature_extractor: Option<FeatureExtractor>,
/// Optional list of seeds for ensemble members; `Default` sets `None`.
pub ensemble_seeds: Option<Vec<u64>>,
/// Strategy used to combine ensemble members.
pub stacking_strategy: StackingStrategy,
}
impl Default for UniversalConfig {
fn default() -> Self {
Self {
mode: BoostingMode::PureTree,
num_rounds: gbdt_defaults::DEFAULT_NUM_ROUNDS,
tree_config: TreeConfig::default(),
linear_config: LinearConfig::default(),
learning_rate: tree_defaults::DEFAULT_LEARNING_RATE,
subsample: gbdt_defaults::DEFAULT_SUBSAMPLE,
validation_ratio: gbdt_defaults::DEFAULT_VALIDATION_RATIO,
early_stopping_rounds: gbdt_defaults::DEFAULT_EARLY_STOPPING_ROUNDS,
calibration_ratio: gbdt_defaults::DEFAULT_CALIBRATION_RATIO,
conformal_quantile: gbdt_defaults::DEFAULT_CONFORMAL_QUANTILE,
seed: seeds_defaults::DEFAULT_SEED,
linear_rounds: universal_defaults::DEFAULT_LINEAR_ROUNDS, max_linear_memory_mb: universal_defaults::DEFAULT_MAX_LINEAR_MEMORY_MB, feature_extractor: None,
ensemble_seeds: None, stacking_strategy: StackingStrategy::default(),
}
}
}
/// Named presets accepted by `UniversalConfig::with_preset`.
///
/// Each variant's description lists exactly what `with_preset` changes for
/// it; all other configuration fields are left as they were.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UniversalPreset {
/// Sets the mode to `BoostingMode::PureTree`.
PureTree,
/// Sets the mode to `BoostingMode::LinearThenTree`.
LinearThenTree,
/// Sets the mode to `BoostingMode::RandomForest`.
RandomForest,
/// Sets the mode to `BoostingMode::LinearThenTree` and applies
/// `LinearPreset::Aggressive` to the linear config.
TimeSeries,
/// Sets the mode to `BoostingMode::RandomForest` and applies
/// `TreePreset::Regularized` to the tree config.
NoisyTabular,
/// Sets the mode to `BoostingMode::PureTree`, the calibration ratio to
/// `gbdt_defaults::CONFORMAL_CALIBRATION_RATIO`, and the conformal quantile
/// to `gbdt_defaults::DEFAULT_CONFORMAL_QUANTILE`.
UncertaintyAware,
}
/// Consuming builder API for [`UniversalConfig`].
///
/// Every `with_*` method takes `self` by value and returns the updated
/// configuration so calls can be chained.
impl UniversalConfig {
    /// Creates a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Applies a named preset on top of the current configuration.
    ///
    /// Every preset overwrites `mode`. `TimeSeries` additionally applies
    /// `LinearPreset::Aggressive` to `linear_config`, `NoisyTabular` applies
    /// `TreePreset::Regularized` to `tree_config`, and `UncertaintyAware`
    /// overwrites `calibration_ratio` and `conformal_quantile` with crate
    /// constants (written directly, without clamping). All other fields are
    /// left untouched.
    pub fn with_preset(mut self, preset: UniversalPreset) -> Self {
        match preset {
            UniversalPreset::PureTree => {
                self.mode = BoostingMode::PureTree;
            }
            UniversalPreset::LinearThenTree => {
                self.mode = BoostingMode::LinearThenTree;
            }
            UniversalPreset::RandomForest => {
                self.mode = BoostingMode::RandomForest;
            }
            UniversalPreset::TimeSeries => {
                self.mode = BoostingMode::LinearThenTree;
                self.linear_config = self.linear_config.with_preset(LinearPreset::Aggressive);
            }
            UniversalPreset::NoisyTabular => {
                self.mode = BoostingMode::RandomForest;
                self.tree_config = self.tree_config.with_preset(TreePreset::Regularized);
            }
            UniversalPreset::UncertaintyAware => {
                self.mode = BoostingMode::PureTree;
                self.calibration_ratio = gbdt_defaults::CONFORMAL_CALIBRATION_RATIO;
                self.conformal_quantile = gbdt_defaults::DEFAULT_CONFORMAL_QUANTILE;
            }
        }
        self
    }

    /// Sets the boosting mode.
    pub fn with_mode(mut self, mode: BoostingMode) -> Self {
        self.mode = mode;
        self
    }

    /// Sets the number of rounds.
    pub fn with_num_rounds(mut self, rounds: usize) -> Self {
        self.num_rounds = rounds;
        self
    }

    /// Replaces the tree learner configuration.
    pub fn with_tree_config(mut self, config: TreeConfig) -> Self {
        self.tree_config = config;
        self
    }

    /// Replaces the linear learner configuration.
    pub fn with_linear_config(mut self, config: LinearConfig) -> Self {
        self.linear_config = config;
        self
    }

    /// Sets the learning rate, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN argument is stored unchanged, because `f32::clamp` returns NaN
    /// for a NaN input.
    pub fn with_learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr.clamp(0.0, 1.0);
        self
    }

    /// Sets the subsample ratio, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN argument is stored unchanged (`f32::clamp` semantics).
    pub fn with_subsample(mut self, ratio: f32) -> Self {
        self.subsample = ratio.clamp(0.0, 1.0);
        self
    }

    /// Sets the validation ratio, clamped to `[0.0, 0.5]`.
    ///
    /// A NaN argument is stored unchanged (`f32::clamp` semantics).
    pub fn with_validation_ratio(mut self, ratio: f32) -> Self {
        self.validation_ratio = ratio.clamp(0.0, 0.5);
        self
    }

    /// Sets the early-stopping round count.
    pub fn with_early_stopping_rounds(mut self, rounds: usize) -> Self {
        self.early_stopping_rounds = rounds;
        self
    }

    /// Sets both conformal-calibration parameters.
    ///
    /// `ratio` is clamped to `[0.0, 0.5]` and `quantile` to `[0.5, 0.99]`.
    /// NaN arguments are stored unchanged (`f32::clamp` semantics).
    pub fn with_conformal_calibration(mut self, ratio: f32, quantile: f32) -> Self {
        self.calibration_ratio = ratio.clamp(0.0, 0.5);
        self.conformal_quantile = quantile.clamp(0.5, 0.99);
        self
    }

    /// Sets the primary random seed.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Sets the number of linear rounds.
    pub fn with_linear_rounds(mut self, rounds: usize) -> Self {
        self.linear_rounds = rounds;
        self
    }

    /// Sets the memory limit recorded in `max_linear_memory_mb`.
    pub fn with_max_linear_memory_mb(mut self, mb: usize) -> Self {
        self.max_linear_memory_mb = mb;
        self
    }

    /// Sets or clears (`None`) the feature extractor.
    pub fn with_feature_extractor(mut self, extractor: Option<FeatureExtractor>) -> Self {
        self.feature_extractor = extractor;
        self
    }

    /// Estimates a byte count as `num_rows * num_features * 4`.
    ///
    /// The product saturates at `usize::MAX` instead of overflowing. An
    /// unchecked multiplication would panic in debug builds and silently wrap
    /// in release builds, and a wrapped (small) estimate could pass a memory
    /// limit check it should fail.
    ///
    /// The result does not depend on any field of `self`.
    pub fn estimate_linear_memory(&self, num_rows: usize, num_features: usize) -> usize {
        // Presumably the size of one `f32` matrix entry — TODO confirm against
        // the storage type used by the linear stage.
        const BYTES_PER_VALUE: usize = 4;

        num_rows
            .saturating_mul(num_features)
            .saturating_mul(BYTES_PER_VALUE)
    }

    /// Stores the given seeds as `Some(seeds)` in `ensemble_seeds`.
    pub fn with_ensemble_seeds(mut self, seeds: Vec<u64>) -> Self {
        self.ensemble_seeds = Some(seeds);
        self
    }

    /// Sets the strategy used to combine ensemble members.
    pub fn with_stacking_strategy(mut self, strategy: StackingStrategy) -> Self {
        self.stacking_strategy = strategy;
        self
    }
}