use crate::analysis::{compute_mse, compute_r2};
use crate::booster::{GBDTConfig, GBDTModel};
use crate::dataset::{BinnedDataset, FeatureInfo, QuantileBinner};
use crate::defaults::{
linear as linear_defaults, ltt as ltt_defaults, seeds as seeds_defaults, tree as tree_defaults,
};
use crate::learner::{LinearBooster, LinearConfig, WeakLearner};
use crate::{Result, TreeBoostError};
use std::time::{Duration, Instant};
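/// Hyperparameters for the linear phase of the linear-then-trees (LTT) pipeline.
///
/// `lambda` is the regularization strength and `l1_ratio` mixes the L1 and L2
/// penalties (0.0 = ridge, 1.0 = lasso, as fixed by the presets below).
/// `shrinkage_factor` scales the linear predictions before tree residuals are
/// computed, and `extrapolation_damping` is forwarded to
/// `LinearConfig::with_extrapolation_damping`.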
#[derive(Debug, Clone, Copy)]
pub struct LinearHyperparams {
pub lambda: f32,
pub l1_ratio: f32,
pub shrinkage_factor: f32,
pub extrapolation_damping: f32,
}
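/// Presets that fix `l1_ratio` on a `LinearHyperparams`.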
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinearHyperparamsPreset {
Ridge,
Lasso,
ElasticNet,
}
impl Default for LinearHyperparams {
fn default() -> Self {
Self {
lambda: linear_defaults::DEFAULT_LAMBDA,
            l1_ratio: linear_defaults::DEFAULT_L1_RATIO,
            shrinkage_factor: ltt_defaults::DEFAULT_LTT_SHRINKAGE,
extrapolation_damping: linear_defaults::DEFAULT_EXTRAPOLATION_DAMPING,
}
}
}
impl LinearHyperparams {
pub fn to_config(&self) -> LinearConfig {
LinearConfig::default()
.with_lambda(self.lambda)
.with_l1_ratio(self.l1_ratio)
.with_shrinkage_factor(self.shrinkage_factor)
.with_extrapolation_damping(self.extrapolation_damping)
}
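    /// Applies a preset by setting `l1_ratio`; all other fields are left
    /// unchanged and can still be overridden afterwards.
    ///
    /// ```ignore
    /// let lasso = LinearHyperparams::default().with_preset(LinearHyperparamsPreset::Lasso);
    /// assert_eq!(lasso.l1_ratio, 1.0);
    /// ```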
pub fn with_preset(mut self, preset: LinearHyperparamsPreset) -> Self {
match preset {
LinearHyperparamsPreset::Ridge => {
self.l1_ratio = 0.0;
}
LinearHyperparamsPreset::Lasso => {
self.l1_ratio = 1.0;
}
LinearHyperparamsPreset::ElasticNet => {
self.l1_ratio = 0.5;
}
}
self
}
}
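/// Hyperparameters for the tree phase, which fits gradient-boosted trees to
/// the residuals left by the (shrunken) linear model.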
#[derive(Debug, Clone, Copy)]
pub struct TreeHyperparams {
pub max_depth: u32,
pub learning_rate: f32,
pub num_rounds: u32,
pub min_child_weight: f32,
pub subsample: f32,
pub colsample_bytree: f32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeHyperparamsPreset {
Conservative,
Aggressive,
}
impl Default for TreeHyperparams {
fn default() -> Self {
Self {
max_depth: tree_defaults::DEFAULT_MAX_DEPTH as u32,
learning_rate: tree_defaults::DEFAULT_LEARNING_RATE,
num_rounds: 500,
min_child_weight: 1.0,
subsample: 1.0,
colsample_bytree: 1.0,
}
}
}
impl TreeHyperparams {
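    /// Applies a preset: `Conservative` uses shallower trees, a smaller
    /// learning rate, more rounds, and row/column subsampling, while
    /// `Aggressive` uses deeper trees and a larger learning rate with no
    /// subsampling.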
pub fn with_preset(mut self, preset: TreeHyperparamsPreset) -> Self {
match preset {
TreeHyperparamsPreset::Conservative => {
self.max_depth = 4;
self.learning_rate = 0.05;
self.num_rounds = 1000;
self.min_child_weight = 3.0;
self.subsample = 0.8;
self.colsample_bytree = 0.8;
}
TreeHyperparamsPreset::Aggressive => {
self.max_depth = 8;
self.learning_rate = 0.15;
self.num_rounds = 500;
self.min_child_weight = 1.0;
self.subsample = 1.0;
self.colsample_bytree = 1.0;
}
}
self
}
}
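/// Combined configuration for an LTT model: the linear phase plus the tree
/// phase that models its residuals.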
#[derive(Debug, Clone, Copy, Default)]
pub struct LttConfig {
pub linear: LinearHyperparams,
pub tree: TreeHyperparams,
}
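/// Outcome of a full tuning run: the selected hyperparameters, validation
/// metrics, per-phase wall-clock times, and the complete trial history.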
#[derive(Debug, Clone)]
pub struct LttTuningResult {
pub linear_params: LinearHyperparams,
pub tree_params: TreeHyperparams,
pub linear_r2: f32,
pub final_rmse: f32,
pub total_time: Duration,
pub phase_times: PhaseTimes,
pub history: LttTuningHistory,
}
#[derive(Debug, Clone, Default)]
pub struct PhaseTimes {
pub linear_phase: Duration,
pub tree_phase: Duration,
pub joint_phase: Duration,
}
#[derive(Debug, Clone, Default)]
pub struct LttTuningHistory {
pub linear_trials: Vec<LinearTrial>,
pub tree_trials: Vec<TreeTrial>,
pub joint_trials: Vec<JointTrial>,
}
#[derive(Debug, Clone)]
pub struct LinearTrial {
pub lambda: f32,
pub l1_ratio: f32,
pub r2: f32,
pub rmse: f32,
}
#[derive(Debug, Clone)]
pub struct TreeTrial {
pub max_depth: u32,
pub learning_rate: f32,
pub num_rounds: u32,
pub residual_rmse: f32,
}
#[derive(Debug, Clone)]
pub struct JointTrial {
pub extrapolation_damping: f32,
pub combined_rmse: f32,
}
struct DataSplit<'a> {
train_features: &'a [f32],
train_targets: &'a [f32],
val_features: &'a [f32],
val_targets: &'a [f32],
num_features: usize,
}
impl<'a> DataSplit<'a> {
fn new(
train_features: &'a [f32],
train_targets: &'a [f32],
val_features: &'a [f32],
val_targets: &'a [f32],
num_features: usize,
) -> Self {
Self {
train_features,
train_targets,
val_features,
val_targets,
num_features,
}
}
}
struct LinearEvalResult {
train_preds: Vec<f32>,
val_preds: Vec<f32>,
r2: f32,
rmse: f32,
}
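/// Search grids and options for `LttTuner`.
///
/// Each `*_values` vector is one axis of a grid search; `val_ratio` is the
/// fraction of trailing rows held out for validation, and
/// `enable_joint_refinement` toggles the third phase that tunes
/// `extrapolation_damping`.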
#[derive(Debug, Clone)]
pub struct LttTunerConfig {
pub val_ratio: f32,
pub lambda_values: Vec<f32>,
pub l1_ratio_values: Vec<f32>,
pub max_depth_values: Vec<u32>,
pub learning_rate_values: Vec<f32>,
pub num_rounds_values: Vec<u32>,
pub shrinkage_factor_values: Vec<f32>,
pub extrapolation_damping_values: Vec<f32>,
pub enable_joint_refinement: bool,
pub seed: u64,
}
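/// Preset search budgets, from the coarse `Quick` grids to the dense
/// `Thorough` grids; `ShrinkageOnly` pins every axis except the shrinkage
/// factor to its default.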
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LttTunerPreset {
Quick,
Standard,
Thorough,
ShrinkageOnly,
}
impl Default for LttTunerConfig {
fn default() -> Self {
Self {
val_ratio: ltt_defaults::DEFAULT_LTT_VAL_RATIO,
lambda_values: ltt_defaults::DEFAULT_LAMBDA_GRID.to_vec(),
            l1_ratio_values: ltt_defaults::DEFAULT_L1_RATIO_GRID.to_vec(),
            max_depth_values: ltt_defaults::DEFAULT_MAX_DEPTH_GRID.to_vec(),
learning_rate_values: ltt_defaults::DEFAULT_LR_GRID.to_vec(),
num_rounds_values: ltt_defaults::DEFAULT_ROUNDS_GRID.to_vec(),
shrinkage_factor_values: ltt_defaults::DEFAULT_SHRINKAGE_GRID.to_vec(),
extrapolation_damping_values: ltt_defaults::DEFAULT_EXTRAPOLATION_DAMPING_GRID.to_vec(),
enable_joint_refinement: true,
seed: seeds_defaults::DEFAULT_SEED,
}
}
}
impl LttTunerConfig {
pub fn with_preset(self, preset: LttTunerPreset) -> Self {
match preset {
LttTunerPreset::Quick => Self {
val_ratio: ltt_defaults::DEFAULT_LTT_VAL_RATIO,
lambda_values: ltt_defaults::QUICK_LAMBDA_GRID.to_vec(),
l1_ratio_values: ltt_defaults::QUICK_L1_RATIO_GRID.to_vec(),
max_depth_values: ltt_defaults::QUICK_MAX_DEPTH_GRID.to_vec(),
learning_rate_values: ltt_defaults::QUICK_LR_GRID.to_vec(),
num_rounds_values: ltt_defaults::QUICK_ROUNDS_GRID.to_vec(),
shrinkage_factor_values: ltt_defaults::QUICK_SHRINKAGE_GRID.to_vec(),
extrapolation_damping_values: ltt_defaults::QUICK_EXTRAPOLATION_DAMPING_GRID
.to_vec(),
                enable_joint_refinement: false,
                seed: seeds_defaults::DEFAULT_SEED,
},
LttTunerPreset::Standard => Self::default(),
LttTunerPreset::Thorough => Self {
val_ratio: ltt_defaults::DEFAULT_LTT_VAL_RATIO,
lambda_values: ltt_defaults::THOROUGH_LAMBDA_GRID.to_vec(),
l1_ratio_values: ltt_defaults::THOROUGH_L1_RATIO_GRID.to_vec(),
max_depth_values: ltt_defaults::THOROUGH_MAX_DEPTH_GRID.to_vec(),
learning_rate_values: ltt_defaults::THOROUGH_LR_GRID.to_vec(),
num_rounds_values: ltt_defaults::THOROUGH_ROUNDS_GRID.to_vec(),
shrinkage_factor_values: ltt_defaults::THOROUGH_SHRINKAGE_GRID.to_vec(),
extrapolation_damping_values: ltt_defaults::THOROUGH_EXTRAPOLATION_DAMPING_GRID
.to_vec(),
enable_joint_refinement: true,
seed: seeds_defaults::DEFAULT_SEED,
},
LttTunerPreset::ShrinkageOnly => {
let mut config = Self::default();
let linear_defaults = LinearHyperparams::default();
let tree_defaults = TreeHyperparams::default();
config.lambda_values = vec![linear_defaults.lambda];
config.l1_ratio_values = vec![linear_defaults.l1_ratio];
config.max_depth_values = vec![tree_defaults.max_depth];
config.learning_rate_values = vec![tree_defaults.learning_rate];
config.num_rounds_values = vec![tree_defaults.num_rounds];
config.enable_joint_refinement = false;
config
}
}
}
fn validate(&self) -> Result<()> {
if self.val_ratio <= 0.0 || self.val_ratio >= 1.0 {
return Err(TreeBoostError::Config(format!(
"val_ratio must be in (0.0, 1.0), got {}",
self.val_ratio
)));
}
if self.lambda_values.is_empty() {
return Err(TreeBoostError::Config(
"lambda_values cannot be empty".into(),
));
}
if self.l1_ratio_values.is_empty() {
return Err(TreeBoostError::Config(
"l1_ratio_values cannot be empty".into(),
));
}
if self.max_depth_values.is_empty() {
return Err(TreeBoostError::Config(
"max_depth_values cannot be empty".into(),
));
}
if self.learning_rate_values.is_empty() {
return Err(TreeBoostError::Config(
"learning_rate_values cannot be empty".into(),
));
}
if self.num_rounds_values.is_empty() {
return Err(TreeBoostError::Config(
"num_rounds_values cannot be empty".into(),
));
}
if self.shrinkage_factor_values.is_empty() {
return Err(TreeBoostError::Config(
"shrinkage_factor_values cannot be empty".into(),
));
}
if self.enable_joint_refinement && self.extrapolation_damping_values.is_empty() {
return Err(TreeBoostError::Config(
"extrapolation_damping_values cannot be empty when joint refinement is enabled"
.into(),
));
}
Ok(())
}
}
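/// Three-phase hyperparameter tuner for the LTT pipeline.
///
/// Phase 1 grid-searches the linear model, phase 2 picks tree hyperparameters
/// for its residuals, and an optional phase 3 refines the extrapolation
/// damping. A minimal usage sketch (marked `ignore` since it needs caller
/// data):
///
/// ```ignore
/// let config = LttTunerConfig::default().with_preset(LttTunerPreset::Quick);
/// let tuner = LttTuner::new(config);
/// let result = tuner.tune(&features, num_features, &targets)?;
/// println!("linear R² = {:.3}, RMSE = {:.3}", result.linear_r2, result.final_rmse);
/// ```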
pub struct LttTuner {
config: LttTunerConfig,
}
impl LttTuner {
pub fn new(config: LttTunerConfig) -> Self {
Self { config }
}
pub fn with_defaults() -> Self {
Self::new(LttTunerConfig::default())
}
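    /// Total number of trials the configured grids imply, useful for progress
    /// reporting: linear pairs + tree triples + (optionally) damping values.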
pub fn estimated_trials(&self) -> usize {
let linear_trials = self.config.lambda_values.len() * self.config.l1_ratio_values.len();
let tree_trials = self.config.max_depth_values.len()
* self.config.learning_rate_values.len()
* self.config.num_rounds_values.len();
let joint_trials = if self.config.enable_joint_refinement {
self.config.extrapolation_damping_values.len()
} else {
0
};
linear_trials + tree_trials + joint_trials
}
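    /// Runs the full tuning pipeline on a row-major feature matrix.
    ///
    /// The last `val_ratio` fraction of rows (rounded up) is held out for
    /// validation; rows are not shuffled here, so callers with ordered data
    /// should shuffle beforehand.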
pub fn tune(
&self,
features: &[f32],
num_features: usize,
targets: &[f32],
) -> Result<LttTuningResult> {
self.validate_inputs(features, num_features, targets)?;
let start = Instant::now();
let mut history = LttTuningHistory::default();
let mut phase_times = PhaseTimes::default();
let num_rows = targets.len();
let val_size = ((num_rows as f32) * self.config.val_ratio).ceil() as usize;
let train_size = num_rows - val_size;
if train_size == 0 {
return Err(TreeBoostError::Data(format!(
"Train/val split produced empty training set (val_ratio={}, num_rows={})",
self.config.val_ratio, num_rows
)));
}
if val_size == 0 {
return Err(TreeBoostError::Data(format!(
"Train/val split produced empty validation set (val_ratio={}, num_rows={})",
self.config.val_ratio, num_rows
)));
}
let train_indices: Vec<usize> = (0..train_size).collect();
let val_indices: Vec<usize> = (train_size..num_rows).collect();
let (train_features, train_targets) =
Self::extract_split(features, targets, num_features, &train_indices);
let (val_features, val_targets) =
Self::extract_split(features, targets, num_features, &val_indices);
let split = DataSplit::new(
&train_features,
&train_targets,
&val_features,
&val_targets,
num_features,
);
let phase1_start = Instant::now();
let (mut best_linear, linear_r2, linear_train_preds, linear_val_preds) =
self.tune_linear_phase(&split, &mut history)?;
phase_times.linear_phase = phase1_start.elapsed();
let best_shrinkage =
self.select_shrinkage_factor(&linear_train_preds, &linear_val_preds, &split);
best_linear.shrinkage_factor = best_shrinkage;
let train_residuals: Vec<f32> = train_targets
.iter()
.zip(linear_train_preds.iter())
.map(|(&t, &p)| t - best_shrinkage * p)
.collect();
let phase2_start = Instant::now();
let best_tree = self.tune_tree_phase(&train_residuals, &mut history)?;
phase_times.tree_phase = phase2_start.elapsed();
let mut final_linear = best_linear;
if self.config.enable_joint_refinement {
let phase3_start = Instant::now();
final_linear = self.tune_joint_phase(&split, &best_linear, &mut history)?;
phase_times.joint_phase = phase3_start.elapsed();
}
let final_rmse = self.compute_final_rmse(&final_linear, &split)?;
Ok(LttTuningResult {
linear_params: final_linear,
tree_params: best_tree,
linear_r2,
final_rmse,
total_time: start.elapsed(),
phase_times,
history,
})
}
fn validate_inputs(
&self,
features: &[f32],
num_features: usize,
targets: &[f32],
) -> Result<()> {
self.config.validate()?;
if num_features == 0 {
return Err(TreeBoostError::Data(
"num_features must be greater than 0".into(),
));
}
if targets.is_empty() {
return Err(TreeBoostError::Data("targets cannot be empty".into()));
}
if features.is_empty() {
return Err(TreeBoostError::Data("features cannot be empty".into()));
}
let num_rows = targets.len();
let expected_features = num_rows * num_features;
if features.len() != expected_features {
return Err(TreeBoostError::Data(format!(
"Feature matrix size mismatch: expected {} ({}×{}), got {}",
expected_features,
num_rows,
num_features,
features.len()
)));
}
        let min_rows = 10;
        if num_rows < min_rows {
return Err(TreeBoostError::Data(format!(
"Insufficient data for tuning: need at least {} rows, got {}",
min_rows, num_rows
)));
}
Ok(())
}
fn extract_split(
features: &[f32],
targets: &[f32],
num_features: usize,
indices: &[usize],
) -> (Vec<f32>, Vec<f32>) {
let mut split_features = Vec::with_capacity(indices.len() * num_features);
let mut split_targets = Vec::with_capacity(indices.len());
for &idx in indices {
for f in 0..num_features {
split_features.push(features[idx * num_features + f]);
}
split_targets.push(targets[idx]);
}
(split_features, split_targets)
}
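    /// Fits a `LinearBooster` with the given config on the training split and
    /// returns train/val predictions plus validation R² and RMSE.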
fn evaluate_linear_config(config: LinearConfig, split: &DataSplit) -> Result<LinearEvalResult> {
let mut booster = LinearBooster::new(split.num_features, config);
let train_preds = booster.fit_direct(
split.train_features,
split.num_features,
split.train_targets,
)?;
let val_preds = booster.predict_batch(split.val_features, split.num_features);
let r2 = compute_r2(split.val_targets, &val_preds);
let mse = compute_mse(split.val_targets, &val_preds);
let rmse = mse.sqrt();
Ok(LinearEvalResult {
train_preds,
val_preds,
r2,
rmse,
})
}
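    /// Phase 1: grid search over `lambda` × `l1_ratio`, keeping the pair with
    /// the best validation R². `extrapolation_damping` starts at 0.0 here and
    /// is only raised later by the joint phase, when enabled.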
fn tune_linear_phase(
&self,
split: &DataSplit,
history: &mut LttTuningHistory,
) -> Result<(LinearHyperparams, f32, Vec<f32>, Vec<f32>)> {
let mut best_params = LinearHyperparams::default();
let mut best_r2 = f32::NEG_INFINITY;
let mut best_train_preds: Option<Vec<f32>> = None;
let mut best_val_preds: Option<Vec<f32>> = None;
for &lambda in &self.config.lambda_values {
for &l1_ratio in &self.config.l1_ratio_values {
let config = LinearConfig::default()
.with_lambda(lambda)
.with_l1_ratio(l1_ratio);
let eval_result = Self::evaluate_linear_config(config, split)?;
history.linear_trials.push(LinearTrial {
lambda,
l1_ratio,
r2: eval_result.r2,
rmse: eval_result.rmse,
});
if eval_result.r2 > best_r2 {
best_r2 = eval_result.r2;
best_params = LinearHyperparams {
lambda,
l1_ratio,
shrinkage_factor: ltt_defaults::DEFAULT_LTT_SHRINKAGE,
extrapolation_damping: 0.0,
};
best_train_preds = Some(eval_result.train_preds);
best_val_preds = Some(eval_result.val_preds);
}
}
}
let train_preds = best_train_preds
.ok_or_else(|| TreeBoostError::Training("Linear phase produced no results".into()))?;
let val_preds = best_val_preds
.ok_or_else(|| TreeBoostError::Training("Linear phase produced no results".into()))?;
Ok((best_params, best_r2, train_preds, val_preds))
}
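    /// Picks a shrinkage factor by probing each candidate with a small GBDT
    /// trained on the shrunken residuals, scoring the combined prediction by
    /// a blend of RMSE, MAE, and the spread of the absolute errors.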
fn select_shrinkage_factor(
&self,
linear_train_preds: &[f32],
linear_val_preds: &[f32],
split: &DataSplit,
) -> f32 {
if self.config.shrinkage_factor_values.len() <= 1 {
return self
.config
.shrinkage_factor_values
.first()
.copied()
.unwrap_or(ltt_defaults::DEFAULT_LTT_SHRINKAGE);
}
let candidates = self.config.shrinkage_factor_values.clone();
let (train_binned, val_binned, feature_info) = Self::build_probe_binned_datasets(split);
struct ShrinkageScore {
shrinkage: f32,
score: f32,
}
let mut scores: Vec<ShrinkageScore> = Vec::with_capacity(candidates.len());
for &shrinkage in &candidates {
let train_residuals: Vec<f32> = split
.train_targets
.iter()
.zip(linear_train_preds.iter())
.map(|(&t, &p)| t - shrinkage * p)
.collect();
let val_residuals: Vec<f32> = split
.val_targets
.iter()
.zip(linear_val_preds.iter())
.map(|(&t, &p)| t - shrinkage * p)
.collect();
let train_dataset = BinnedDataset::new(
split.train_targets.len(),
train_binned.clone(),
train_residuals,
feature_info.clone(),
);
let val_dataset = BinnedDataset::new(
split.val_targets.len(),
val_binned.clone(),
val_residuals,
feature_info.clone(),
);
let probe_config = GBDTConfig::new()
.with_mse_loss()
.with_num_rounds(ltt_defaults::SHRINKAGE_PROBE_ROUNDS)
.with_learning_rate(ltt_defaults::SHRINKAGE_PROBE_LR)
.with_max_depth(ltt_defaults::SHRINKAGE_PROBE_DEPTH)
.with_min_samples_leaf(ltt_defaults::SHRINKAGE_PROBE_MIN_SAMPLES_LEAF)
.with_seed(self.config.seed);
let probe_model =
GBDTModel::train_binned(&train_dataset, probe_config).unwrap_or_else(|_| {
GBDTModel::train_binned(
&train_dataset,
GBDTConfig::new()
.with_mse_loss()
.with_num_rounds(1)
.with_max_depth(1)
.with_seed(self.config.seed),
)
.expect("Probe fallback should succeed")
});
let residual_preds = probe_model.predict(&val_dataset);
let combined_preds: Vec<f32> = linear_val_preds
.iter()
.zip(residual_preds.iter())
.map(|(&p, &r)| shrinkage * p + r)
.collect();
let abs_errors: Vec<f32> = combined_preds
.iter()
.zip(split.val_targets.iter())
.map(|(&p, &t)| (p - t).abs())
.collect();
let mae = abs_errors.iter().sum::<f32>() / abs_errors.len().max(1) as f32;
let rmse = compute_mse(split.val_targets, &combined_preds).sqrt();
            // Composite score: RMSE dominates, with MAE and the spread of the
            // absolute errors as secondary terms.
            let std = {
                let var = abs_errors
                    .iter()
                    .map(|&e| {
                        let d = e - mae;
                        d * d
                    })
                    .sum::<f32>()
                    / abs_errors.len().max(1) as f32;
                var.sqrt()
            };
            let score = rmse + 0.5 * mae + 0.2 * std;
scores.push(ShrinkageScore { shrinkage, score });
}
        // The grid has at least two candidates here, so `scores` is non-empty.
        scores
            .iter()
            .min_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
            .map(|s| s.shrinkage)
            .expect("shrinkage score list is non-empty")
}
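    /// Bins features for the shrinkage probe. Note that bin boundaries are
    /// computed over the combined train and validation values; the probe is
    /// only used to rank shrinkage candidates, not to report metrics.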
fn build_probe_binned_datasets(split: &DataSplit) -> (Vec<u8>, Vec<u8>, Vec<FeatureInfo>) {
let num_train_rows = split.train_targets.len();
let num_val_rows = split.val_targets.len();
let num_features = split.num_features;
let binner = QuantileBinner::new(ltt_defaults::SHRINKAGE_PROBE_BINS);
let mut feature_info = Vec::with_capacity(num_features);
let mut boundaries: Vec<Vec<f64>> = Vec::with_capacity(num_features);
for feat in 0..num_features {
let mut combined = Vec::with_capacity(num_train_rows + num_val_rows);
for row in 0..num_train_rows {
combined.push(split.train_features[row * num_features + feat] as f64);
}
for row in 0..num_val_rows {
combined.push(split.val_features[row * num_features + feat] as f64);
}
let bins = binner.compute_boundaries(&combined);
boundaries.push(bins.clone());
feature_info.push(binner.create_feature_info(format!("f{}", feat), bins));
}
let train_binned = Self::bin_features(
split.train_features,
num_train_rows,
num_features,
&boundaries,
&binner,
);
let val_binned = Self::bin_features(
split.val_features,
num_val_rows,
num_features,
&boundaries,
&binner,
);
(train_binned, val_binned, feature_info)
}
fn bin_features(
features: &[f32],
num_rows: usize,
num_features: usize,
boundaries: &[Vec<f64>],
binner: &QuantileBinner,
) -> Vec<u8> {
let mut binned = Vec::with_capacity(num_rows * num_features);
for feat in 0..num_features {
let mut values = Vec::with_capacity(num_rows);
for row in 0..num_rows {
values.push(features[row * num_features + feat] as f64);
}
let column = binner.bin_column(&values, &boundaries[feat]);
binned.extend_from_slice(&column);
}
binned
}
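    /// Phase 2: selects tree hyperparameters with a closed-form complexity
    /// heuristic rather than by training models; the `residual_rmse` recorded
    /// in the history is likewise a simulated value derived from that score.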
fn tune_tree_phase(
&self,
residuals: &[f32],
history: &mut LttTuningHistory,
) -> Result<TreeHyperparams> {
let mut best_params = TreeHyperparams::default();
let mut best_score = f32::NEG_INFINITY;
let residual_std = crate::analysis::compute_std(residuals);
let residual_range = crate::analysis::compute_range(residuals);
let is_high_variance = residual_std > ltt_defaults::HIGH_VARIANCE_THRESHOLD;
for &max_depth in &self.config.max_depth_values {
for &learning_rate in &self.config.learning_rate_values {
for &num_rounds in &self.config.num_rounds_values {
let complexity_score = if is_high_variance {
(max_depth as f32 * ltt_defaults::DEPTH_WEIGHT_HIGH_VAR)
+ (num_rounds as f32 * ltt_defaults::ROUNDS_WEIGHT_HIGH_VAR)
- (learning_rate * ltt_defaults::LR_PENALTY_HIGH_VAR)
} else {
(max_depth as f32 * ltt_defaults::DEPTH_WEIGHT_LOW_VAR)
+ (learning_rate * ltt_defaults::LR_WEIGHT_LOW_VAR)
- (num_rounds as f32 * ltt_defaults::ROUNDS_PENALTY_LOW_VAR)
};
let depth_penalty = if max_depth > ltt_defaults::MAX_DEPTH_THRESHOLD {
ltt_defaults::EXTREME_CONFIG_PENALTY
} else {
0.0
};
let lr_penalty = if learning_rate < ltt_defaults::MIN_LR_THRESHOLD {
ltt_defaults::EXTREME_CONFIG_PENALTY
} else {
0.0
};
let score = complexity_score - depth_penalty - lr_penalty;
let simulated_rmse = residual_range / (1.0 + score.abs());
history.tree_trials.push(TreeTrial {
max_depth,
learning_rate,
num_rounds,
residual_rmse: simulated_rmse,
});
if score > best_score {
best_score = score;
best_params = TreeHyperparams {
max_depth,
learning_rate,
num_rounds,
min_child_weight: 1.0,
subsample: 1.0,
colsample_bytree: 1.0,
};
}
}
}
}
Ok(best_params)
}
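    /// Phase 3 (optional): re-fits the linear model for each candidate
    /// `extrapolation_damping` and keeps the value with the lowest validation
    /// RMSE.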
fn tune_joint_phase(
&self,
split: &DataSplit,
linear_params: &LinearHyperparams,
history: &mut LttTuningHistory,
) -> Result<LinearHyperparams> {
let mut best_params = *linear_params;
let mut best_rmse = f32::INFINITY;
for &damping in &self.config.extrapolation_damping_values {
let config = LinearConfig::default()
.with_lambda(linear_params.lambda)
.with_l1_ratio(linear_params.l1_ratio)
.with_shrinkage_factor(linear_params.shrinkage_factor)
.with_extrapolation_damping(damping);
let eval_result = Self::evaluate_linear_config(config, split)?;
history.joint_trials.push(JointTrial {
extrapolation_damping: damping,
combined_rmse: eval_result.rmse,
});
if eval_result.rmse < best_rmse {
best_rmse = eval_result.rmse;
best_params.extrapolation_damping = damping;
}
}
Ok(best_params)
}
fn compute_final_rmse(
&self,
linear_params: &LinearHyperparams,
split: &DataSplit,
) -> Result<f32> {
let config = linear_params.to_config();
let eval_result = Self::evaluate_linear_config(config, split)?;
Ok(eval_result.rmse)
}
}
impl Default for LttTuner {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_linear_hyperparams_default() {
let params = LinearHyperparams::default();
assert_eq!(params.lambda, 1.0);
        assert_eq!(params.l1_ratio, 0.0);
        assert_eq!(params.shrinkage_factor, ltt_defaults::DEFAULT_LTT_SHRINKAGE);
assert_eq!(params.extrapolation_damping, 0.0);
}
#[test]
fn test_tree_hyperparams_default() {
let params = TreeHyperparams::default();
assert_eq!(params.max_depth, 6);
assert_eq!(params.learning_rate, 0.1);
assert_eq!(params.num_rounds, 500);
}
#[test]
fn test_ltt_tuner_config_presets() {
let quick = LttTunerConfig::default().with_preset(LttTunerPreset::Quick);
let thorough = LttTunerConfig::default().with_preset(LttTunerPreset::Thorough);
assert!(quick.lambda_values.len() < thorough.lambda_values.len());
assert!(quick.max_depth_values.len() < thorough.max_depth_values.len());
}
#[test]
fn test_ltt_tuner_estimated_trials() {
let config = LttTunerConfig::default();
let tuner = LttTuner::new(config.clone());
let linear_trials = config.lambda_values.len() * config.l1_ratio_values.len();
let tree_trials = config.max_depth_values.len()
* config.learning_rate_values.len()
* config.num_rounds_values.len();
let joint_trials = config.extrapolation_damping_values.len();
assert_eq!(
tuner.estimated_trials(),
linear_trials + tree_trials + joint_trials
);
}
#[test]
fn test_config_validation_empty_grids() {
let mut config = LttTunerConfig::default();
config.lambda_values = vec![];
let result = config.validate();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("lambda_values"));
}
#[test]
fn test_config_validation_bad_val_ratio() {
let mut config = LttTunerConfig::default();
config.val_ratio = 1.5;
let result = config.validate();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("val_ratio"));
}
#[test]
fn test_input_validation_empty_targets() {
let tuner = LttTuner::with_defaults();
let features = vec![1.0, 2.0, 3.0];
let targets: Vec<f32> = vec![];
let result = tuner.tune(&features, 1, &targets);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("empty"));
}
#[test]
fn test_input_validation_zero_features() {
let tuner = LttTuner::with_defaults();
let features = vec![1.0, 2.0, 3.0];
let targets = vec![1.0, 2.0, 3.0];
let result = tuner.tune(&features, 0, &targets);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("num_features"));
}
#[test]
fn test_input_validation_dimension_mismatch() {
let tuner = LttTuner::with_defaults();
        let features = vec![1.0, 2.0, 3.0];
        let targets = vec![1.0, 2.0];
let result = tuner.tune(&features, 1, &targets);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("mismatch"));
}
#[test]
fn test_ltt_tuner_tune() {
let num_features = 1;
let num_rows = 100;
let mut features = Vec::with_capacity(num_rows * num_features);
let mut targets = Vec::with_capacity(num_rows);
for i in 0..num_rows {
let x = (i as f32) / 10.0;
features.push(x);
            targets.push(2.0 * x + 1.0 + (i as f32 % 3.0) * 0.1);
        }
let config = LttTunerConfig::default().with_preset(LttTunerPreset::Quick);
let tuner = LttTuner::new(config);
let result = tuner
.tune(&features, num_features, &targets)
.expect("LTT tuner should successfully fit linear data");
assert!(result.linear_r2 > 0.5, "R² should be > 0.5 for linear data");
assert!(result.final_rmse < 5.0, "RMSE should be reasonable");
assert!(!result.history.linear_trials.is_empty());
assert!(!result.history.tree_trials.is_empty());
}
#[test]
fn test_linear_params_to_config() {
let params = LinearHyperparams {
lambda: 0.5,
l1_ratio: 0.3,
shrinkage_factor: 0.8,
extrapolation_damping: 0.1,
};
let config = params.to_config();
assert!((config.lambda - 0.5).abs() < 1e-6);
assert!((config.l1_ratio - 0.3).abs() < 1e-6);
assert!((config.shrinkage_factor - 0.8).abs() < 1e-6);
assert!((config.extrapolation_damping - 0.1).abs() < 1e-6);
}
#[test]
fn test_shrinkage_factor_selection_strong_linear() {
let mut config = LttTunerConfig::default();
let shrinkage_grid = vec![0.3, 0.7];
config.shrinkage_factor_values = shrinkage_grid.clone();
let tuner = LttTuner::new(config);
let train_features = vec![0.0, 1.0, 2.0, 3.0];
let val_features = vec![4.0, 5.0];
let train_targets = vec![0.0, 1.0, 2.0, 3.0];
let val_targets = vec![4.0, 5.0];
let split = DataSplit::new(
&train_features,
&train_targets,
&val_features,
&val_targets,
1,
);
let train_preds = vec![0.0, 1.0, 2.0, 3.0];
let val_preds = vec![4.0, 5.0];
let shrinkage = tuner.select_shrinkage_factor(&train_preds, &val_preds, &split);
assert!(
shrinkage_grid.contains(&shrinkage),
"Shrinkage should be selected from configured grid"
);
}
#[test]
fn test_shrinkage_factor_selection_weak_linear() {
let mut config = LttTunerConfig::default();
let shrinkage_grid = vec![0.3, 0.7];
config.shrinkage_factor_values = shrinkage_grid.clone();
let tuner = LttTuner::new(config);
let train_features = vec![0.0, 1.0, 2.0, 3.0];
let val_features = vec![4.0, 5.0];
let train_targets = vec![0.0, 1.0, 2.0, 3.0];
let val_targets = vec![4.0, 5.0];
let split = DataSplit::new(
&train_features,
&train_targets,
&val_features,
&val_targets,
1,
);
let train_preds = vec![3.0, 2.0, 1.0, 0.0];
let val_preds = vec![1.0, 0.0];
let shrinkage = tuner.select_shrinkage_factor(&train_preds, &val_preds, &split);
assert!(
shrinkage_grid.contains(&shrinkage),
"Shrinkage should be selected from configured grid"
);
}
#[test]
fn test_constants_are_reasonable() {
assert!(ltt_defaults::STRONG_LINEAR_R2 > ltt_defaults::WEAK_LINEAR_R2);
assert!(ltt_defaults::HIGH_SHRINKAGE_MIN > 0.0 && ltt_defaults::HIGH_SHRINKAGE_MIN < 1.0);
assert!(ltt_defaults::LOW_SHRINKAGE_MAX > 0.0 && ltt_defaults::LOW_SHRINKAGE_MAX < 1.0);
assert!(
ltt_defaults::DEFAULT_LTT_SHRINKAGE > 0.0 && ltt_defaults::DEFAULT_LTT_SHRINKAGE <= 1.0
);
assert!(ltt_defaults::HIGH_VARIANCE_THRESHOLD > 0.0);
assert!(ltt_defaults::MAX_DEPTH_THRESHOLD > 0);
assert!(ltt_defaults::MIN_LR_THRESHOLD > 0.0);
}
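    // Additional checks exercising behavior documented above: presets only
    // touch the fields they name, and `estimated_trials` drops the joint
    // phase when refinement is disabled.
    #[test]
    fn test_linear_presets_set_l1_ratio() {
        let base = LinearHyperparams::default();
        assert_eq!(base.with_preset(LinearHyperparamsPreset::Ridge).l1_ratio, 0.0);
        assert_eq!(base.with_preset(LinearHyperparamsPreset::Lasso).l1_ratio, 1.0);
        assert_eq!(
            base.with_preset(LinearHyperparamsPreset::ElasticNet).l1_ratio,
            0.5
        );
    }
    #[test]
    fn test_shrinkage_only_preset_pins_grids() {
        let config = LttTunerConfig::default().with_preset(LttTunerPreset::ShrinkageOnly);
        assert_eq!(config.lambda_values.len(), 1);
        assert_eq!(config.l1_ratio_values.len(), 1);
        assert_eq!(config.max_depth_values.len(), 1);
        assert_eq!(config.learning_rate_values.len(), 1);
        assert_eq!(config.num_rounds_values.len(), 1);
        assert!(!config.enable_joint_refinement);
    }
    #[test]
    fn test_estimated_trials_without_joint_refinement() {
        let mut config = LttTunerConfig::default();
        config.enable_joint_refinement = false;
        let expected = config.lambda_values.len() * config.l1_ratio_values.len()
            + config.max_depth_values.len()
                * config.learning_rate_values.len()
                * config.num_rounds_values.len();
        let tuner = LttTuner::new(config);
        assert_eq!(tuner.estimated_trials(), expected);
    }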
}