use crate::bin::Bin;
use crate::binning::{bin_columnar_matrix, bin_matrix};
use crate::booster::config::*;
use crate::constants::{
FREE_MEM_ALLOC_FACTOR, GENERALIZATION_THRESHOLD_RELAXED, ITER_LIMIT, MIN_COL_AMOUNT, N_NODES_ALLOC_MAX,
N_NODES_ALLOC_MIN, STOPPING_ROUNDS,
};
use crate::constraints::ConstraintMap;
use crate::data::{ColumnarMatrix, Matrix};
use crate::errors::PerpetualError;
use crate::histogram::{HistogramArena, NodeHistogram, update_cuts};
use crate::objective::{Objective, ObjectiveFunction};
use crate::splitter::{MissingBranchSplitter, MissingImputerSplitter, SplitInfo, SplitInfoSlice, Splitter};
use crate::tree::core::{Tree, TreeStopper};
use log::{info, warn};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand::seq::IndexedRandom;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::mem;
use std::time::Instant;
use sysinfo::{MemoryRefreshKind, RefreshKind, System};
type ImportanceFn = fn(&Tree, &mut HashMap<usize, (f32, usize)>);
#[derive(Clone, Serialize, Deserialize)]
/// A self-tuning gradient-boosting ensemble.
///
/// Hyperparameters live in [`BoosterConfig`]; the remaining fields are the
/// fitted state. `base_score` and `eta` use NaN as an "not yet derived"
/// sentinel (see `fit` and `set_eta`).
pub struct PerpetualBooster {
    /// All training hyperparameters (objective, budget, constraints, limits, ...).
    pub cfg: BoosterConfig,
    /// Initial prediction added to every row; NaN until `fit` derives it from
    /// the objective when no explicit value was supplied.
    #[serde(deserialize_with = "crate::booster::config::parse_missing")]
    pub base_score: f64,
    /// Learning rate, derived from the budget as 10^(-budget) by `set_eta`.
    #[serde(deserialize_with = "crate::booster::config::parse_f32")]
    pub eta: f32,
    /// The fitted tree ensemble, in training order.
    pub trees: Vec<Tree>,
    /// Calibration boosters keyed by method name; each entry holds two
    /// (booster, f64) pairs — presumably lower/upper models with an
    /// associated level; TODO confirm against the calibration module.
    #[serde(default)]
    pub cal_models: HashMap<String, [(PerpetualBooster, f64); 2]>,
    /// Calibration parameter vectors keyed by method name.
    #[serde(default)]
    pub cal_params: HashMap<String, Vec<f64>>,
    /// Optional isotonic-regression calibrator.
    #[serde(default)]
    pub isotonic_calibrator: Option<crate::calibration::isotonic::IsotonicCalibrator>,
    /// Free-form key/value metadata attached by the user.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
impl Default for PerpetualBooster {
    /// Booster with default configuration and no fitted state.
    ///
    /// `base_score` and `eta` start as NaN sentinels: `fit` replaces a NaN
    /// `base_score` with an objective-derived initial value, and `set_eta`
    /// later derives `eta` from the budget.
    fn default() -> Self {
        PerpetualBooster {
            cfg: BoosterConfig::default(),
            base_score: f64::NAN,
            eta: f32::NAN,
            trees: Vec::new(),
            cal_models: HashMap::new(),
            cal_params: HashMap::new(),
            isotonic_calibrator: None,
            metadata: HashMap::new(),
        }
    }
}
impl PerpetualBooster {
#[allow(clippy::too_many_arguments)]
pub fn new(
objective: Objective,
budget: f32,
base_score: f64,
max_bin: u16,
num_threads: Option<usize>,
monotone_constraints: Option<ConstraintMap>,
interaction_constraints: Option<Vec<Vec<usize>>>,
force_children_to_bound_parent: bool,
missing: f64,
allow_missing_splits: bool,
create_missing_branch: bool,
terminate_missing_features: HashSet<usize>,
missing_node_treatment: MissingNodeTreatment,
log_iterations: usize,
seed: u64,
reset: Option<bool>,
categorical_features: Option<HashSet<usize>>,
timeout: Option<f32>,
iteration_limit: Option<usize>,
memory_limit: Option<f32>,
stopping_rounds: Option<usize>,
save_node_stats: bool,
) -> Result<Self, PerpetualError> {
let cfg = BoosterConfig {
objective,
budget,
max_bin,
num_threads,
monotone_constraints,
interaction_constraints,
force_children_to_bound_parent,
missing,
allow_missing_splits,
create_missing_branch,
terminate_missing_features,
missing_node_treatment,
log_iterations,
seed,
reset,
categorical_features,
timeout,
iteration_limit,
memory_limit,
stopping_rounds,
save_node_stats,
calibration_method: CalibrationMethod::default(),
};
let booster = PerpetualBooster {
cfg,
base_score,
eta: f32::NAN,
trees: Vec::new(),
cal_models: HashMap::new(),
cal_params: HashMap::new(),
isotonic_calibrator: None,
metadata: HashMap::new(),
};
booster.validate_parameters()?;
Ok(booster)
}
/// Validate the booster configuration.
///
/// Currently a no-op that always returns `Ok(())`; kept as the single
/// validation point called from `new` so checks can be added later.
pub fn validate_parameters(&self) -> Result<(), PerpetualError> {
    Ok(())
}
/// Fit the booster on a row-major `Matrix`.
///
/// Derives `eta` from the budget, then dispatches to `fit_trees` with
/// either a missing-branch splitter or a missing-imputer splitter,
/// depending on `create_missing_branch`.
pub fn fit(
    &mut self,
    data: &Matrix<f64>,
    y: &[f64],
    sample_weight: Option<&[f64]>,
    group: Option<&[u64]>,
) -> Result<(), PerpetualError> {
    // Resolve monotone constraints; an empty map when none were configured.
    let constraints_map = match self.cfg.monotone_constraints.as_ref() {
        Some(constraints) => constraints.to_owned(),
        None => ConstraintMap::new(),
    };
    self.set_eta(self.cfg.budget);
    if self.cfg.create_missing_branch {
        // Missing values get their own branch in every split.
        let splitter = MissingBranchSplitter::new(
            self.eta,
            self.cfg.allow_missing_splits,
            constraints_map,
            self.cfg.terminate_missing_features.clone(),
            self.cfg.missing_node_treatment,
            self.cfg.force_children_to_bound_parent,
        );
        self.fit_trees(data, y, &splitter, sample_weight, group)?;
    } else {
        // Missing values are imputed toward the better-scoring side.
        let splitter = MissingImputerSplitter::new(
            self.eta,
            self.cfg.allow_missing_splits,
            constraints_map,
            self.cfg.interaction_constraints.clone(),
        );
        self.fit_trees(data, y, &splitter, sample_weight, group)?;
    }
    Ok(())
}
/// Fit the booster on a `ColumnarMatrix` (columnar counterpart of `fit`).
///
/// Derives `eta` from the budget, then dispatches to `fit_trees_columnar`
/// with the splitter implied by `create_missing_branch`.
pub fn fit_columnar(
    &mut self,
    data: &ColumnarMatrix<f64>,
    y: &[f64],
    sample_weight: Option<&[f64]>,
    group: Option<&[u64]>,
) -> Result<(), PerpetualError> {
    // Resolve monotone constraints; an empty map when none were configured.
    let constraints_map = match self.cfg.monotone_constraints.as_ref() {
        Some(constraints) => constraints.to_owned(),
        None => ConstraintMap::new(),
    };
    self.set_eta(self.cfg.budget);
    if self.cfg.create_missing_branch {
        // Missing values get their own branch in every split.
        let splitter = MissingBranchSplitter::new(
            self.eta,
            self.cfg.allow_missing_splits,
            constraints_map,
            self.cfg.terminate_missing_features.clone(),
            self.cfg.missing_node_treatment,
            self.cfg.force_children_to_bound_parent,
        );
        self.fit_trees_columnar(data, y, &splitter, sample_weight, group)?;
    } else {
        // Missing values are imputed toward the better-scoring side.
        let splitter = MissingImputerSplitter::new(
            self.eta,
            self.cfg.allow_missing_splits,
            constraints_map,
            self.cfg.interaction_constraints.clone(),
        );
        self.fit_trees_columnar(data, y, &splitter, sample_weight, group)?;
    }
    Ok(())
}
/// Core boosting loop over a row-major `Matrix`.
///
/// Repeatedly fits trees on the current gradients until one of the
/// automatic stopping conditions fires: repeated tiny/poorly-generalizing
/// trees, no training-loss improvement for `stopping_rounds`, the memory
/// limit, the timeout, or the iteration limit.
pub fn fit_trees<T: Splitter>(
    &mut self,
    data: &Matrix<f64>,
    y: &[f64],
    splitter: &T,
    sample_weight: Option<&[f64]>,
    group: Option<&[u64]>,
) -> Result<(), PerpetualError> {
    let start = Instant::now();
    let objective_fn = &self.cfg.objective;
    // Dedicated rayon pool sized from the config (or all available cores).
    let n_threads_available = std::thread::available_parallelism().unwrap().get();
    let num_threads = match self.cfg.num_threads {
        Some(num_threads) => num_threads,
        None => n_threads_available,
    };
    let builder = rayon::ThreadPoolBuilder::new().num_threads(num_threads);
    let pool = builder.build().unwrap();
    // Warm start: keep existing trees and continue from their predictions,
    // unless `reset` (default true) asks for a fresh ensemble.
    let mut yhat;
    if self.cfg.reset.unwrap_or(true) || self.trees.is_empty() {
        self.trees.clear();
        if self.base_score.is_nan() {
            self.base_score = objective_fn.initial_value(y, sample_weight, group);
        }
        yhat = vec![self.base_score; y.len()];
    } else {
        yhat = self.predict(data, true);
    }
    let (mut grad, mut hess, mut loss) = objective_fn.gradient_and_loss(y, &yhat, sample_weight, group);
    // Reference loss level for the target decrement; warm starts measure it
    // against the base score rather than the current predictions.
    let loss_avg = if self.cfg.reset.unwrap_or(true) || self.trees.is_empty() {
        loss.iter().sum::<f32>() / loss.len() as f32
    } else {
        let loss_base = objective_fn.loss(y, &vec![self.base_score; y.len()], sample_weight, group);
        loss_base.iter().sum::<f32>() / loss_base.len() as f32
    };
    // Budget-driven heuristic: a truncated-series correction `c` scales
    // 10^(-budget) * loss_avg into the per-tree target loss decrement.
    // The derivation of `c` is not documented here.
    let base = 10.0_f32;
    let n = base / self.cfg.budget;
    let reciprocals_of_powers = n / (n - 1.0);
    let truncated_series_sum = reciprocals_of_powers - (1.0 + 1.0 / n);
    let c = 1.0 / n - truncated_series_sum;
    let target_loss_decrement = c * base.powf(-self.cfg.budget) * loss_avg;
    let is_const_hess = hess.is_none();
    // Discretize features once; trees then operate on bin codes.
    let binned_data = bin_matrix(
        data,
        sample_weight,
        self.cfg.max_bin,
        self.cfg.missing,
        self.cfg.categorical_features.as_ref(),
    )?;
    let bdata = Matrix::new(&binned_data.binned_data, data.rows, data.cols);
    let col_index: Vec<usize> = (0..data.cols).collect();
    let mut stopping = 0;
    let mut n_low_loss_rounds = 0;
    let mut best_loss_avg = loss.iter().sum::<f32>() / loss.len() as f32;
    let mut no_improvement_rounds: usize = 0;
    let mut rng = StdRng::seed_from_u64(self.cfg.seed);
    // Column subsampling kicks in when rows/cols is small relative to a
    // budget-dependent limit; col_amount is clamped to [MIN_COL_AMOUNT, cols].
    let row_column_ratio_limit = 10.0_f32.powf(-self.cfg.budget) * 1000.0;
    let colsample_bytree = (data.rows as f32 / data.cols as f32) / row_column_ratio_limit;
    let col_amount = (((col_index.len() as f32) * colsample_bytree).floor() as usize)
        .clamp(usize::min(MIN_COL_AMOUNT, col_index.len()), col_index.len());
    // Estimate per-node histogram memory. With subsampling, fixed-size
    // histograms must be wide enough for the feature with most unique values.
    let mem_bin = mem::size_of::<Bin>();
    let effective_max_bin = if col_amount == col_index.len() {
        self.cfg.max_bin
    } else {
        let max_nunique = *binned_data
            .nunique
            .iter()
            .max()
            .unwrap_or(&(self.cfg.max_bin as usize + 2));
        (max_nunique.saturating_sub(2) as u16).max(self.cfg.max_bin)
    };
    let mem_hist: f32 = if col_amount == col_index.len() {
        (mem_bin * binned_data.nunique.iter().sum::<usize>()) as f32
    } else {
        (mem_bin * (effective_max_bin as usize + 2) * col_amount) as f32
    };
    let mem_hist = mem_hist
        + mem::size_of::<crate::histogram::NodeHistogram>() as f32
        + (mem::size_of::<crate::histogram::FeatureHistogram>() * col_amount) as f32;
    // Rough footprint of the fixed training buffers — presumably the binned
    // u16 matrix, f64 yhat, and f32 grad/loss/hess arrays; TODO confirm.
    let base_memory_bytes = ((data.rows * data.cols * 2)
        + (data.rows * 8)
        + (data.rows * 4)
        + (data.rows * 4)
        + if is_const_hess { 0 } else { data.rows * 4 }) as f32;
    // Honor cgroup limits (containers) before falling back to host free memory.
    let sys = System::new_with_specifics(RefreshKind::nothing().with_memory(MemoryRefreshKind::everything()));
    let mem_available = match sys.cgroup_limits() {
        Some(limits) => limits.free_memory as f32,
        None => sys.available_memory() as f32,
    };
    // Per-node cost of a retained tree (1.3x overhead factor; +48 bytes when
    // node statistics are kept — both heuristic constants).
    let ensemble_node_size = (mem::size_of::<crate::node::Node>() as f32 * 1.3) + if self.cfg.save_node_stats { 48.0 } else { 0.0 };
    let iteration_limit = self.cfg.iteration_limit.unwrap_or(ITER_LIMIT) as f32;
    let avg_nodes_per_tree = 256.0_f32;
    // Histogram-arena slot count: bounded by the memory budget (explicit
    // limit or measured free memory), twice the row count, and global caps.
    let n_nodes_alloc = match self.cfg.memory_limit {
        Some(mem_limit) => {
            let mem_limit_bytes = mem_limit * 1e9_f32;
            let mem_limit_safe = mem_limit_bytes * 0.9;
            let total_predicted_ensemble_mem = iteration_limit * avg_nodes_per_tree * ensemble_node_size;
            let available_for_arena = (mem_limit_safe - base_memory_bytes - total_predicted_ensemble_mem).max(0.0);
            let usable_memory = available_for_arena.min(mem_available);
            let n = (FREE_MEM_ALLOC_FACTOR * (usable_memory / mem_hist)) as usize;
            let data_rows_cap = (data.rows * 2).max(N_NODES_ALLOC_MIN);
            n.max(3).min(data_rows_cap).min(N_NODES_ALLOC_MAX)
        }
        None => {
            let actual_available = (mem_available - base_memory_bytes).max(0.0);
            let n = (FREE_MEM_ALLOC_FACTOR * (actual_available / mem_hist)) as usize;
            let data_rows_cap = (data.rows * 2).max(N_NODES_ALLOC_MIN);
            n.min(data_rows_cap).clamp(N_NODES_ALLOC_MIN, N_NODES_ALLOC_MAX)
        }
    };
    // One histogram arena reused across all trees: exact per-feature cuts
    // when every column is used, fixed-width slots when subsampling (cuts
    // are then rewritten each round).
    let mut hist_arena = if col_amount == col_index.len() {
        HistogramArena::from_cuts(&binned_data.cuts, &col_index, is_const_hess, n_nodes_alloc)
    } else {
        HistogramArena::from_fixed(effective_max_bin, col_amount, is_const_hess, n_nodes_alloc)
    };
    let mut hist_tree: Vec<NodeHistogram> = hist_arena.as_node_histograms();
    let mut split_info_vec: Vec<SplitInfo> = (0..col_amount).map(|_| SplitInfo::default()).collect();
    let mut split_info_slice = SplitInfoSlice::new(&mut split_info_vec);
    // Row-index scratch buffer, recycled between rounds via mem::take below.
    let mut index_buf = data.index.to_owned();
    let mut total_ensemble_bytes = 0_usize;
    for i in 0..self.cfg.iteration_limit.unwrap_or(ITER_LIMIT) {
        let verbose = if self.cfg.log_iterations == 0 {
            false
        } else {
            i % self.cfg.log_iterations == 0
        };
        // Stop targeting the loss decrement after enough low-loss rounds.
        let tld = if n_low_loss_rounds > (self.cfg.stopping_rounds.unwrap_or(STOPPING_ROUNDS) + 1) {
            None
        } else {
            Some(target_loss_decrement)
        };
        // Draw this round's (sorted) column subset when subsampling.
        let col_index_sample: Vec<usize> = if col_amount == col_index.len() {
            Vec::new()
        } else {
            let mut v: Vec<usize> = col_index.sample(&mut rng, col_amount).copied().collect();
            v.sort();
            v
        };
        let col_index_fit = if col_amount == col_index.len() {
            &col_index
        } else {
            &col_index_sample
        };
        // Re-point the fixed-size histograms at the sampled columns' cuts.
        if col_amount != col_index.len() {
            hist_tree.iter().for_each(|h| {
                update_cuts(h, col_index_fit, &binned_data.cuts, true);
            });
        }
        let mut tree = Tree::new();
        index_buf.copy_from_slice(&data.index);
        tree.fit(
            objective_fn,
            &bdata,
            index_buf,
            col_index_fit,
            &mut grad,
            hess.as_deref_mut(),
            splitter,
            &pool,
            tld,
            &loss,
            y,
            &yhat,
            sample_weight,
            group,
            is_const_hess,
            &mut hist_tree,
            self.cfg.categorical_features.as_ref(),
            &mut split_info_slice,
            n_nodes_alloc,
            self.cfg.save_node_stats,
        );
        self.update_predictions_inplace(&mut yhat, &tree, data);
        // Reclaim the index buffer the tree took ownership of.
        index_buf = std::mem::take(&mut tree.train_index);
        // Tiny trees with poor generalization count toward auto-stopping;
        // a root-only tree ends training immediately.
        if tree.nodes.len() < 5 {
            let generalization = tree
                .nodes
                .values()
                .map(|n| n.stats.as_ref().and_then(|s| s.generalization).unwrap_or(0.0))
                .max_by(|a, b| a.total_cmp(b))
                .unwrap_or(0.0);
            if generalization < GENERALIZATION_THRESHOLD_RELAXED && tree.stopper != TreeStopper::StepSize {
                stopping += 1;
                if tree.nodes.len() == 1 {
                    break;
                }
            }
        }
        if tree.stopper != TreeStopper::StepSize {
            n_low_loss_rounds += 1;
        } else {
            n_low_loss_rounds = 0;
        }
        // Refresh gradients/loss in place for the next round.
        objective_fn.gradient_and_loss_into(y, &yhat, sample_weight, group, &mut grad, &mut hess, &mut loss);
        let current_loss_avg = loss.iter().sum::<f32>() / loss.len() as f32;
        if current_loss_avg < best_loss_avg {
            best_loss_avg = current_loss_avg;
            no_improvement_rounds = 0;
        } else {
            no_improvement_rounds += 1;
        }
        if verbose {
            info!(
                "round {:0?}, tree.nodes: {:1?}, tree.depth: {:2?}, tree.stopper: {:3?}, loss: {:4?}",
                i,
                tree.nodes.len(),
                tree.depth,
                tree.stopper,
                current_loss_avg,
            );
        }
        // Drop training-only buffers before storing the tree, then account
        // for retained ensemble memory (leaf_bounds is now empty, so its
        // capacity term contributes zero).
        tree.leaf_bounds = Vec::new();
        tree.train_index = Vec::new();
        let cat_bytes: usize = tree
            .nodes
            .values()
            .map(|n| n.left_cats.as_ref().map_or(0, |c| c.len()))
            .sum();
        let tree_bytes = (tree.nodes.capacity() as f32 * ensemble_node_size) as usize
            + tree.leaf_bounds.capacity() * std::mem::size_of::<(f64, usize, usize)>()
            + cat_bytes;
        total_ensemble_bytes += tree_bytes;
        self.trees.push(tree);
        // Hard stop if buffers + arena + ensemble exceed 90% of the limit.
        if let Some(mem_limit) = self.cfg.memory_limit {
            let mem_limit_safe = mem_limit * 1e9_f32 * 0.9;
            let current_total_bytes =
                base_memory_bytes + (n_nodes_alloc as f32 * mem_hist) + (total_ensemble_bytes as f32);
            if current_total_bytes > mem_limit_safe {
                warn!(
                    "Reached memory limit before auto stopping. Stopped at iteration {}. Try to increase memory_limit.",
                    i
                );
                break;
            }
        }
        if stopping >= self.cfg.stopping_rounds.unwrap_or(STOPPING_ROUNDS) {
            info!("Auto stopping since stopping round limit reached.");
            break;
        }
        if no_improvement_rounds >= self.cfg.stopping_rounds.unwrap_or(STOPPING_ROUNDS) {
            info!(
                "Auto stopping since training loss did not improve for {} consecutive rounds.",
                no_improvement_rounds
            );
            break;
        }
        if self.cfg.timeout.is_some_and(|t| start.elapsed().as_secs_f32() > t) {
            warn!(
                "Reached timeout before auto stopping. Try to decrease the budget or increase the timeout for the best performance."
            );
            break;
        }
        if i == self.cfg.iteration_limit.unwrap_or(ITER_LIMIT) - 1 {
            warn!(
                "Reached iteration limit before auto stopping. Try to decrease the budget or increase the iteration limit for the best performance."
            );
        }
    }
    if self.cfg.log_iterations > 0 {
        info!(
            "Finished training a booster with {0} trees in {1} seconds.",
            self.trees.len(),
            start.elapsed().as_secs()
        );
    }
    Ok(())
}
/// Core boosting loop over a `ColumnarMatrix` — the columnar counterpart of
/// `fit_trees`, with the same pipeline and stopping conditions; only the
/// binning and prediction entry points differ.
pub fn fit_trees_columnar<T: Splitter>(
    &mut self,
    data: &ColumnarMatrix<f64>,
    y: &[f64],
    splitter: &T,
    sample_weight: Option<&[f64]>,
    group: Option<&[u64]>,
) -> Result<(), PerpetualError> {
    let start = Instant::now();
    let objective_fn = &self.cfg.objective;
    // Dedicated rayon pool sized from the config (or all available cores).
    let n_threads_available = std::thread::available_parallelism().unwrap().get();
    let num_threads = match self.cfg.num_threads {
        Some(num_threads) => num_threads,
        None => n_threads_available,
    };
    let builder = rayon::ThreadPoolBuilder::new().num_threads(num_threads);
    let pool = builder.build().unwrap();
    // Warm start unless `reset` (default true) asks for a fresh ensemble.
    let mut yhat;
    if self.cfg.reset.unwrap_or(true) || self.trees.is_empty() {
        self.trees.clear();
        if self.base_score.is_nan() {
            self.base_score = objective_fn.initial_value(y, sample_weight, group);
        }
        yhat = vec![self.base_score; y.len()];
    } else {
        yhat = self.predict_columnar(data, true);
    }
    let (mut grad, mut hess, mut loss) = objective_fn.gradient_and_loss(y, &yhat, sample_weight, group);
    // Reference loss level for the target decrement (base score for warm starts).
    let loss_avg = if self.cfg.reset.unwrap_or(true) || self.trees.is_empty() {
        loss.iter().sum::<f32>() / loss.len() as f32
    } else {
        let loss_base = objective_fn.loss(y, &vec![self.base_score; y.len()], sample_weight, group);
        loss_base.iter().sum::<f32>() / loss_base.len() as f32
    };
    // Budget-driven heuristic scaling 10^(-budget) * loss_avg into the
    // per-tree target loss decrement; derivation of `c` not documented here.
    let base = 10.0_f32;
    let n = base / self.cfg.budget;
    let reciprocals_of_powers = n / (n - 1.0);
    let truncated_series_sum = reciprocals_of_powers - (1.0 + 1.0 / n);
    let c = 1.0 / n - truncated_series_sum;
    let target_loss_decrement = c * base.powf(-self.cfg.budget) * loss_avg;
    let is_const_hess = hess.is_none();
    // Discretize features once; trees then operate on bin codes.
    let binned_data = bin_columnar_matrix(
        data,
        sample_weight,
        self.cfg.max_bin,
        self.cfg.missing,
        self.cfg.categorical_features.as_ref(),
    )?;
    let bdata = Matrix::new(&binned_data.binned_data, data.rows, data.cols);
    let col_index: Vec<usize> = (0..data.cols).collect();
    let mut stopping = 0;
    let mut n_low_loss_rounds = 0;
    let mut best_loss_avg = loss.iter().sum::<f32>() / loss.len() as f32;
    let mut no_improvement_rounds: usize = 0;
    let mut rng = StdRng::seed_from_u64(self.cfg.seed);
    // Column subsampling amount, clamped to [MIN_COL_AMOUNT, cols].
    let row_column_ratio_limit = 10.0_f32.powf(-self.cfg.budget) * 1000.0;
    let colsample_bytree = (data.rows as f32 / data.cols as f32) / row_column_ratio_limit;
    let col_amount = (((col_index.len() as f32) * colsample_bytree).floor() as usize)
        .clamp(usize::min(MIN_COL_AMOUNT, col_index.len()), col_index.len());
    // Per-node histogram memory estimate (fixed-width when subsampling).
    let mem_bin = mem::size_of::<Bin>();
    let effective_max_bin = if col_amount == col_index.len() {
        self.cfg.max_bin
    } else {
        let max_nunique = *binned_data
            .nunique
            .iter()
            .max()
            .unwrap_or(&(self.cfg.max_bin as usize + 2));
        (max_nunique.saturating_sub(2) as u16).max(self.cfg.max_bin)
    };
    let mem_hist: f32 = if col_amount == col_index.len() {
        (mem_bin * binned_data.nunique.iter().sum::<usize>()) as f32
    } else {
        (mem_bin * (effective_max_bin as usize + 2) * col_amount) as f32
    };
    let mem_hist = mem_hist
        + mem::size_of::<crate::histogram::NodeHistogram>() as f32
        + (mem::size_of::<crate::histogram::FeatureHistogram>() * col_amount) as f32;
    // Rough footprint of fixed training buffers — presumably binned u16
    // matrix, f64 yhat, f32 grad/loss/hess arrays; TODO confirm.
    let base_memory_bytes = ((data.rows * data.cols * 2)
        + (data.rows * 8)
        + (data.rows * 4)
        + (data.rows * 4)
        + if is_const_hess { 0 } else { data.rows * 4 }) as f32;
    // Honor cgroup limits before falling back to host free memory.
    let sys = System::new_with_specifics(RefreshKind::nothing().with_memory(MemoryRefreshKind::everything()));
    let mem_available = match sys.cgroup_limits() {
        Some(limits) => limits.free_memory as f32,
        None => sys.available_memory() as f32,
    };
    // Per-node cost of a retained tree (heuristic constants).
    let ensemble_node_size = (mem::size_of::<crate::node::Node>() as f32 * 1.3) + if self.cfg.save_node_stats { 48.0 } else { 0.0 };
    let iteration_limit = self.cfg.iteration_limit.unwrap_or(ITER_LIMIT) as f32;
    let avg_nodes_per_tree = 256.0_f32;
    // Histogram-arena slot count, bounded by memory budget and global caps.
    let n_nodes_alloc = match self.cfg.memory_limit {
        Some(mem_limit) => {
            let mem_limit_bytes = mem_limit * 1e9_f32;
            let mem_limit_safe = mem_limit_bytes * 0.9;
            let total_predicted_ensemble_mem = iteration_limit * avg_nodes_per_tree * ensemble_node_size;
            let available_for_arena = (mem_limit_safe - base_memory_bytes - total_predicted_ensemble_mem).max(0.0);
            let usable_memory = available_for_arena.min(mem_available);
            let n = (FREE_MEM_ALLOC_FACTOR * (usable_memory / mem_hist)) as usize;
            let data_rows_cap = (data.rows * 2).max(N_NODES_ALLOC_MIN);
            n.max(3).min(data_rows_cap).min(N_NODES_ALLOC_MAX)
        }
        None => {
            let actual_available = (mem_available - base_memory_bytes).max(0.0);
            let n = (FREE_MEM_ALLOC_FACTOR * (actual_available / mem_hist)) as usize;
            let data_rows_cap = (data.rows * 2).max(N_NODES_ALLOC_MIN);
            n.min(data_rows_cap).clamp(N_NODES_ALLOC_MIN, N_NODES_ALLOC_MAX)
        }
    };
    // One histogram arena reused across all trees.
    let mut hist_arena = if col_amount == col_index.len() {
        HistogramArena::from_cuts(&binned_data.cuts, &col_index, is_const_hess, n_nodes_alloc)
    } else {
        HistogramArena::from_fixed(effective_max_bin, col_amount, is_const_hess, n_nodes_alloc)
    };
    let mut hist_tree: Vec<NodeHistogram> = hist_arena.as_node_histograms();
    let mut split_info_vec: Vec<SplitInfo> = (0..col_amount).map(|_| SplitInfo::default()).collect();
    let mut split_info_slice = SplitInfoSlice::new(&mut split_info_vec);
    // Row-index scratch buffer, recycled between rounds via mem::take below.
    let mut index_buf = data.index.to_owned();
    let mut total_ensemble_bytes = 0_usize;
    for i in 0..self.cfg.iteration_limit.unwrap_or(ITER_LIMIT) {
        let verbose = if self.cfg.log_iterations == 0 {
            false
        } else {
            i % self.cfg.log_iterations == 0
        };
        // Stop targeting the loss decrement after enough low-loss rounds.
        let tld = if n_low_loss_rounds > (self.cfg.stopping_rounds.unwrap_or(STOPPING_ROUNDS) + 1) {
            None
        } else {
            Some(target_loss_decrement)
        };
        // Draw this round's (sorted) column subset when subsampling.
        let col_index_sample: Vec<usize> = if col_amount == col_index.len() {
            Vec::new()
        } else {
            let mut v: Vec<usize> = col_index.sample(&mut rng, col_amount).copied().collect();
            v.sort();
            v
        };
        let col_index_fit = if col_amount == col_index.len() {
            &col_index
        } else {
            &col_index_sample
        };
        // Re-point the fixed-size histograms at the sampled columns' cuts.
        if col_amount != col_index.len() {
            hist_tree.iter().for_each(|h| {
                update_cuts(h, col_index_fit, &binned_data.cuts, true);
            })
        }
        let mut tree = Tree::new();
        index_buf.copy_from_slice(&data.index);
        tree.fit(
            objective_fn,
            &bdata,
            index_buf,
            col_index_fit,
            &mut grad,
            hess.as_deref_mut(),
            splitter,
            &pool,
            tld,
            &loss,
            y,
            &yhat,
            sample_weight,
            group,
            is_const_hess,
            &mut hist_tree,
            self.cfg.categorical_features.as_ref(),
            &mut split_info_slice,
            n_nodes_alloc,
            self.cfg.save_node_stats,
        );
        self.update_predictions_inplace_columnar(&mut yhat, &tree, data);
        // Reclaim the index buffer the tree took ownership of.
        index_buf = std::mem::take(&mut tree.train_index);
        // Tiny, poorly-generalizing trees count toward auto-stopping;
        // a root-only tree ends training immediately.
        if tree.nodes.len() < 5 {
            let generalization = tree
                .nodes
                .values()
                .map(|n| n.stats.as_ref().and_then(|s| s.generalization).unwrap_or(0.0))
                .max_by(|a, b| a.total_cmp(b))
                .unwrap_or(0.0);
            if generalization < GENERALIZATION_THRESHOLD_RELAXED && tree.stopper != TreeStopper::StepSize {
                stopping += 1;
                if tree.nodes.len() == 1 {
                    break;
                }
            }
        }
        if tree.stopper != TreeStopper::StepSize {
            n_low_loss_rounds += 1;
        } else {
            n_low_loss_rounds = 0;
        }
        // Refresh gradients/loss in place for the next round.
        objective_fn.gradient_and_loss_into(y, &yhat, sample_weight, group, &mut grad, &mut hess, &mut loss);
        let current_loss_avg = loss.iter().sum::<f32>() / loss.len() as f32;
        if current_loss_avg < best_loss_avg {
            best_loss_avg = current_loss_avg;
            no_improvement_rounds = 0;
        } else {
            no_improvement_rounds += 1;
        }
        if verbose {
            info!(
                "round {:0?}, tree.nodes: {:1?}, tree.depth: {:2?}, tree.stopper: {:3?}, loss: {:4?}",
                i,
                tree.nodes.len(),
                tree.depth,
                tree.stopper,
                current_loss_avg,
            );
        }
        // Drop training-only buffers, then account for retained ensemble
        // memory (leaf_bounds is now empty, so its capacity term is zero).
        tree.leaf_bounds = Vec::new();
        tree.train_index = Vec::new();
        let cat_bytes: usize = tree
            .nodes
            .values()
            .map(|n| n.left_cats.as_ref().map_or(0, |c| c.len()))
            .sum();
        let tree_bytes = (tree.nodes.capacity() as f32 * ensemble_node_size) as usize
            + tree.leaf_bounds.capacity() * std::mem::size_of::<(f64, usize, usize)>()
            + cat_bytes;
        total_ensemble_bytes += tree_bytes;
        self.trees.push(tree);
        // Hard stop if buffers + arena + ensemble exceed 90% of the limit.
        if let Some(mem_limit) = self.cfg.memory_limit {
            let mem_limit_safe = mem_limit * 1e9_f32 * 0.9;
            let current_total_bytes =
                base_memory_bytes + (n_nodes_alloc as f32 * mem_hist) + (total_ensemble_bytes as f32);
            if current_total_bytes > mem_limit_safe {
                warn!(
                    "Reached memory limit before auto stopping. Stopped at iteration {}. Try to increase memory_limit.",
                    i
                );
                break;
            }
        }
        if stopping >= self.cfg.stopping_rounds.unwrap_or(STOPPING_ROUNDS) {
            info!("Auto stopping since stopping round limit reached.");
            break;
        }
        if no_improvement_rounds >= self.cfg.stopping_rounds.unwrap_or(STOPPING_ROUNDS) {
            info!(
                "Auto stopping since training loss did not improve for {} consecutive rounds.",
                no_improvement_rounds
            );
            break;
        }
        if self.cfg.timeout.is_some_and(|t| start.elapsed().as_secs_f32() > t) {
            warn!(
                "Reached timeout before auto stopping. Try to decrease the budget or increase the timeout for the best performance."
            );
            break;
        }
        if i == self.cfg.iteration_limit.unwrap_or(ITER_LIMIT) - 1 {
            warn!(
                "Reached iteration limit before auto stopping. Try to decrease the budget or increase the iteration limit for the best performance."
            );
        }
    }
    if self.cfg.log_iterations > 0 {
        info!(
            "Finished training a booster with {0} trees in {1} seconds.",
            self.trees.len(),
            start.elapsed().as_secs()
        );
    }
    Ok(())
}
/// Add a freshly fitted tree's contributions to `yhat`.
///
/// Fast path: when the tree still carries its training-time leaf bounds and
/// row indices, apply each leaf weight directly to the rows it covers.
/// Otherwise fall back to a full prediction pass over the data.
fn update_predictions_inplace(&self, yhat: &mut [f64], tree: &Tree, _data: &Matrix<f64>) {
    let has_leaf_info = !tree.leaf_bounds.is_empty() && !tree.train_index.is_empty();
    if has_leaf_info {
        for &(weight, start, stop) in tree.leaf_bounds.iter() {
            for &row in tree.train_index[start..stop].iter() {
                yhat[row] += weight;
            }
        }
    } else {
        let deltas = tree.predict(_data, true, &self.cfg.missing);
        for (current, delta) in yhat.iter_mut().zip(deltas) {
            *current += delta;
        }
    }
}
/// Add a freshly fitted tree's contributions to `yhat` (columnar data).
///
/// Fast path: when the tree still carries its training-time leaf bounds and
/// row indices, apply each leaf weight directly to the rows it covers.
/// Otherwise fall back to a full columnar prediction pass.
fn update_predictions_inplace_columnar(&self, yhat: &mut [f64], tree: &Tree, _data: &ColumnarMatrix<f64>) {
    let has_leaf_info = !tree.leaf_bounds.is_empty() && !tree.train_index.is_empty();
    if has_leaf_info {
        for &(weight, start, stop) in tree.leaf_bounds.iter() {
            for &row in tree.train_index[start..stop].iter() {
                yhat[row] += weight;
            }
        }
    } else {
        let deltas = tree.predict_columnar(_data, true, &self.cfg.missing);
        for (current, delta) in yhat.iter_mut().zip(deltas) {
            *current += delta;
        }
    }
}
/// Derive the learning rate from the budget: `eta = 10^(-max(budget, 0))`.
/// A negative (or NaN) budget is treated as zero, giving `eta = 1.0`.
pub fn set_eta(&mut self, budget: f32) {
    let clamped_budget = budget.max(0.0);
    self.eta = 10.0_f32.powf(-clamped_budget);
}
/// Trees used for prediction (currently the full fitted ensemble).
pub fn get_prediction_trees(&self) -> &[Tree] {
    &self.trees
}
/// Partial dependence of the ensemble on `feature` at `value`.
///
/// Sums each tree's partial-dependence contribution (in parallel via
/// rayon) and shifts the total by the base score.
///
/// Note: the original code wrapped this in `if true { parallel } else
/// { sequential }`; the sequential arm was unreachable dead code and has
/// been removed.
pub fn value_partial_dependence(&self, feature: usize, value: f64) -> f64 {
    let pd: f64 = self
        .get_prediction_trees()
        .par_iter()
        .map(|t| t.value_partial_dependence(feature, value, &self.cfg.missing))
        .sum();
    pd + self.base_score
}
/// Aggregate per-feature importance across all trees.
///
/// `method` selects the per-tree accumulator and whether totals are
/// averaged over the number of nodes that contributed. When `normalize`
/// is set, each score is divided by the grand total.
pub fn calculate_feature_importance(&self, method: ImportanceMethod, normalize: bool) -> HashMap<usize, f32> {
    let (average, importance_fn): (bool, ImportanceFn) = match method {
        ImportanceMethod::Weight => (false, Tree::calculate_importance_weight),
        ImportanceMethod::Gain => (true, Tree::calculate_importance_gain),
        ImportanceMethod::TotalGain => (false, Tree::calculate_importance_gain),
        ImportanceMethod::Cover => (true, Tree::calculate_importance_cover),
        ImportanceMethod::TotalCover => (false, Tree::calculate_importance_cover),
    };
    // Accumulate (total, count) per feature over every tree.
    let mut stats = HashMap::new();
    self.trees.iter().for_each(|tree| importance_fn(tree, &mut stats));
    let importance: HashMap<usize, f32> = stats
        .iter()
        .map(|(feature, (total, count))| {
            let score = if average { total / (*count as f32) } else { *total };
            (*feature, score)
        })
        .collect();
    if !normalize {
        return importance;
    }
    // Sum scores in ascending order — presumably for a reproducible
    // floating-point total regardless of HashMap iteration order; TODO confirm.
    let mut sorted_scores: Vec<f32> = importance.values().copied().collect();
    sorted_scores.sort_by(|a, b| a.total_cmp(b));
    let total: f32 = sorted_scores.iter().sum();
    importance.iter().map(|(feature, score)| (*feature, score / total)).collect()
}
/// Attach (or overwrite) a free-form metadata entry on the booster.
pub fn insert_metadata(&mut self, key: String, value: String) {
    self.metadata.insert(key, value);
}
/// Look up a metadata entry by key, returning an owned copy of the value.
///
/// The parameter was widened from `&String` to `&str` (the idiomatic
/// borrowed form); existing call sites passing `&String` still compile via
/// deref coercion, and `HashMap<String, _>::get` accepts `&str` through
/// `Borrow<str>`.
pub fn get_metadata(&self, key: &str) -> Option<String> {
    self.metadata.get(key).cloned()
}
}
pub(crate) fn fix_legacy_value(value: &mut serde_json::Value) {
if let Some(map) = value.as_object_mut() {
if let Some(nodes) = map.get_mut("nodes").and_then(|n| n.as_object_mut()) {
for node in nodes.values_mut() {
fix_legacy_node(node);
}
}
for v in map.values_mut() {
fix_legacy_value(v);
}
} else if let serde_json::Value::Array(arr) = value {
for v in arr {
fix_legacy_value(v);
}
}
}
/// Upgrade one serialized tree node from the legacy categorical encoding
/// (`left_cats`/`right_cats` as lists of category indices) to the current
/// encoding (`left_cats` as an 8192-byte bitset, `right_cats` removed).
///
/// A node is treated as legacy when its `left_cats` array is not already
/// 8192 entries long and is either non-empty or accompanied by a
/// `right_cats` key. Nodes already in the new format pass through with only
/// the (now redundant) `right_cats` key removed.
pub(crate) fn fix_legacy_node(node: &mut serde_json::Value) {
    if let Some(node_obj) = node.as_object_mut() {
        if let Some(left_cats_arr) = node_obj
            .get("left_cats")
            .and_then(|v| v.as_array())
            // len == 8192 means it is already a bitset; skip conversion.
            .filter(|arr| arr.len() != 8192 && (!arr.is_empty() || node_obj.contains_key("right_cats")))
        {
            let left_cats_indices: Vec<u16> = left_cats_arr
                .iter()
                .filter_map(|v| v.as_u64().map(|n| n as u16))
                .collect();
            let right_cats_indices: Vec<u16> = node_obj
                .get("right_cats")
                .and_then(|v| v.as_array())
                .map(|arr| arr.iter().filter_map(|v| v.as_u64().map(|n| n as u16)).collect())
                .unwrap_or_default();
            if !left_cats_indices.is_empty() || !right_cats_indices.is_empty() {
                let missing_node = node_obj.get("missing_node").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
                let left_child = node_obj.get("left_child").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
                // 8192 bytes = 65536 category bits; bit set => category goes left.
                let mut bitset = vec![0u8; 8192];
                if missing_node == left_child {
                    // Missing goes left: default every category to "left" and
                    // clear only the explicitly right-going ones.
                    bitset.fill(255);
                    for &cat in &right_cats_indices {
                        let byte_idx = (cat >> 3) as usize;
                        let bit_idx = (cat & 7) as u8;
                        if byte_idx < 8192 {
                            bitset[byte_idx] &= !(1 << bit_idx);
                        }
                    }
                } else {
                    // Otherwise set only the explicitly left-going categories.
                    for &cat in &left_cats_indices {
                        let byte_idx = (cat >> 3) as usize;
                        let bit_idx = (cat & 7) as u8;
                        if byte_idx < 8192 {
                            bitset[byte_idx] |= 1 << bit_idx;
                        }
                    }
                }
                node_obj.insert(
                    "left_cats".to_string(),
                    serde_json::Value::Array(
                        bitset
                            .into_iter()
                            .map(|b| serde_json::Value::Number(b.into()))
                            .collect(),
                    ),
                );
            } else {
                // Legacy marker present but both lists empty: no categorical split.
                node_obj.insert("left_cats".to_string(), serde_json::Value::Null);
            }
        }
        // The bitset alone now encodes the split; drop the legacy key.
        node_obj.remove("right_cats");
    }
}
impl BoosterIO for PerpetualBooster {
    /// Deserialize a booster from JSON, first upgrading any legacy node
    /// encoding (index-list `left_cats`/`right_cats`) in place via
    /// `fix_legacy_value` so old model files keep loading.
    fn from_json(json_str: &str) -> Result<Self, PerpetualError> {
        let mut value: serde_json::Value =
            serde_json::from_str(json_str).map_err(|e| PerpetualError::UnableToRead(e.to_string()))?;
        fix_legacy_value(&mut value);
        serde_json::from_value::<Self>(value).map_err(|e| PerpetualError::UnableToRead(e.to_string()))
    }
}
#[cfg(test)]
mod perpetual_booster_test {
use crate::booster::config::*;
use crate::constraints::{Constraint, ConstraintMap};
use crate::metrics::ranking::{GainScheme, ndcg_at_k_metric};
use crate::objective::{Objective, ObjectiveFunction};
use crate::utils::between;
use crate::{Matrix, PerpetualBooster};
use approx::assert_relative_eq;
use rand::RngExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::HashSet;
use std::error::Error;
use std::fs;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
/// Test helper: load a California-housing style CSV into a flattened
/// feature vector (columns concatenated end-to-end — presumably the
/// layout `Matrix::new` expects here; TODO confirm) plus the
/// `MedHouseVal` target column. Empty or unparseable cells become NaN.
fn read_data(path: &str) -> Result<(Vec<f64>, Vec<f64>), Box<dyn Error>> {
    let feature_names = [
        "MedInc",
        "HouseAge",
        "AveRooms",
        "AveBedrms",
        "Population",
        "AveOccup",
        "Latitude",
        "Longitude",
    ];
    let target_name = "MedHouseVal";
    let file = File::open(path)?;
    let reader = BufReader::new(file);
    let mut csv_reader = csv::ReaderBuilder::new().has_headers(true).from_reader(reader);
    // Map each wanted column name to its position in the header row.
    let headers = csv_reader.headers()?.clone();
    let feature_indices: Vec<usize> = feature_names
        .iter()
        .map(|&name| headers.iter().position(|h| h == name).unwrap())
        .collect();
    let target_index = headers.iter().position(|h| h == target_name).unwrap();
    let mut data_columns: Vec<Vec<f64>> = vec![Vec::new(); feature_names.len()];
    let mut y = Vec::new();
    for result in csv_reader.records() {
        let record = result?;
        let target_str = &record[target_index];
        let target_val = if target_str.is_empty() {
            f64::NAN
        } else {
            target_str.parse::<f64>().unwrap_or(f64::NAN)
        };
        y.push(target_val);
        for (i, &idx) in feature_indices.iter().enumerate() {
            let val_str = &record[idx];
            let val = if val_str.is_empty() {
                f64::NAN
            } else {
                val_str.parse::<f64>().unwrap_or(f64::NAN)
            };
            data_columns[i].push(val);
        }
    }
    // Concatenate the per-feature columns into one flat buffer.
    let data: Vec<f64> = data_columns.into_iter().flatten().collect();
    Ok((data, y))
}
#[test]
// Smoke test: fit on the 891x5 fixture with missing values, then check that
// average contributions have one entry per feature plus the bias term.
fn test_booster_fit() {
    let file =
        fs::read_to_string("resources/contiguous_with_missing.csv").expect("Something went wrong reading the file");
    let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap_or(f64::NAN)).collect();
    let file = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file");
    let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
    let data = Matrix::new(&data_vec, 891, 5);
    let mut booster = PerpetualBooster::default().set_budget(0.3);
    booster.fit(&data, &y, None, None).unwrap();
    let preds = booster.predict(&data, false);
    let contribs = booster.predict_contributions(&data, ContributionsMethod::Average, false);
    // One contribution column per feature plus the bias column.
    assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
    println!("{}", booster.trees[0]);
    println!("{}", booster.trees[0].nodes.len());
    println!("{}", booster.trees.last().unwrap().nodes.len());
    println!("{:?}", &preds[0..10]);
}
#[test]
// Same smoke test as `test_booster_fit`, but on the fare target with
// SquaredLoss and no preset base score, so `fit` must derive it from the
// objective's initial value.
fn test_booster_fit_no_fitted_base_score() {
    let file =
        fs::read_to_string("resources/contiguous_with_missing.csv").expect("Something went wrong reading the file");
    let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap_or(f64::NAN)).collect();
    let file = fs::read_to_string("resources/performance-fare.csv").expect("Something went wrong reading the file");
    let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
    let data = Matrix::new(&data_vec, 891, 5);
    let mut booster = PerpetualBooster::default()
        .set_objective(Objective::SquaredLoss)
        .set_max_bin(300)
        .set_budget(0.3);
    booster.fit(&data, &y, None, None).unwrap();
    let preds = booster.predict(&data, false);
    let contribs = booster.predict_contributions(&data, ContributionsMethod::Average, false);
    // One contribution column per feature plus the bias column.
    assert_eq!(contribs.len(), (data.cols + 1) * data.rows);
    println!("{}", booster.trees[0]);
    println!("{}", booster.trees[0].nodes.len());
    println!("{}", booster.trees.last().unwrap().nodes.len());
    println!("{:?}", &preds[0..10]);
}
#[test]
// Round-trip test: a saved-then-reloaded booster must reproduce the same
// predictions, and config mutations (here `missing`) must survive the
// save/load cycle.
fn test_tree_save() {
    let file =
        fs::read_to_string("resources/contiguous_with_missing.csv").expect("Something went wrong reading the file");
    let data_vec: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap_or(f64::NAN)).collect();
    let data = Matrix::new(&data_vec, 891, 5);
    let file = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file");
    let y: Vec<f64> = file.lines().map(|x| x.parse::<f64>().unwrap()).collect();
    let mut booster = PerpetualBooster::default()
        .set_max_bin(300)
        .set_base_score(0.5)
        .set_budget(0.3);
    booster.fit(&data, &y, None, None).unwrap();
    let preds = booster.predict(&data, true);
    booster.save_booster("resources/model64.json").unwrap();
    let booster2 = PerpetualBooster::load_booster("resources/model64.json").unwrap();
    // Reloaded model must predict identically.
    assert_eq!(booster2.predict(&data, true)[0..10], preds[0..10]);
    booster.cfg.missing = 0.0;
    booster.save_booster("resources/modelmissing.json").unwrap();
    let booster3 = PerpetualBooster::load_booster("resources/modelmissing.json").unwrap();
    assert_eq!(booster3.cfg.missing, 0.);
    assert_eq!(booster3.cfg.missing, booster.cfg.missing);
}
#[test]
fn test_gbm_categorical() -> Result<(), Box<dyn Error>> {
    let n_columns = 13;

    // Fit on the "test" split of the Titanic data.
    let raw = fs::read_to_string("resources/titanic_test_y.csv").expect("Something went wrong reading the file");
    let y: Vec<f64> = raw.lines().map(|l| l.parse::<f64>().unwrap()).collect();
    let raw = fs::read_to_string("resources/titanic_test_flat.csv").expect("Something went wrong reading the file");
    let x: Vec<f64> = raw.lines().map(|l| l.parse::<f64>().unwrap_or(f64::NAN)).collect();
    let data = Matrix::new(&x, y.len(), n_columns);

    let cat_index = HashSet::from([0, 3, 4, 6, 7, 8, 10, 11]);
    let mut booster = PerpetualBooster::default()
        .set_budget(0.1)
        .set_categorical_features(Some(cat_index));
    booster.fit(&data, &y, None, None).unwrap();

    // Score on the "train" split.
    let raw = fs::read_to_string("resources/titanic_train_y.csv").expect("Something went wrong reading the file");
    let y: Vec<f64> = raw.lines().map(|l| l.parse::<f64>().unwrap()).collect();
    let raw = fs::read_to_string("resources/titanic_train_flat.csv").expect("Something went wrong reading the file");
    let x: Vec<f64> = raw.lines().map(|l| l.parse::<f64>().unwrap_or(f64::NAN)).collect();
    let data = Matrix::new(&x, y.len(), n_columns);

    let probabilities = booster.predict_proba(&data, true, false);
    // Accuracy = fraction of rounded probabilities matching the label.
    let n_correct = probabilities
        .iter()
        .zip(y.iter())
        .filter(|(p, label)| p.round() == **label)
        .count();
    let accuracy = n_correct as f32 / y.len() as f32;
    println!("accuracy: {}", accuracy);
    assert!(between(0.76, 0.78, accuracy));
    Ok(())
}
#[test]
fn test_gbm_parallel() -> Result<(), Box<dyn Error>> {
    // Two identically-configured models differing only in thread count must
    // build the same number of trees/leaves and near-identical predictions.
    let (data_train, y_train) = read_data("resources/cal_housing_train.csv")?;
    let (data_test, y_test) = read_data("resources/cal_housing_test.csv")?;
    let matrix_train = Matrix::new(&data_train, y_train.len(), 8);
    let matrix_test = Matrix::new(&data_test, y_test.len(), 8);

    let mut model1 = PerpetualBooster::default()
        .set_objective(Objective::SquaredLoss)
        .set_max_bin(10)
        .set_num_threads(Some(1))
        .set_budget(0.1);
    let mut model2 = PerpetualBooster::default()
        .set_objective(Objective::SquaredLoss)
        .set_max_bin(10)
        .set_num_threads(Some(2))
        .set_budget(0.1);

    // NOTE(review): both models are fitted on the *test* matrix and scored on
    // the *train* matrix — confirm this cross-use is intentional.
    model1.fit(&matrix_test, &y_test, None, None)?;
    model2.fit(&matrix_test, &y_test, None, None)?;

    let trees1 = model1.get_prediction_trees();
    let trees2 = model2.get_prediction_trees();
    assert_eq!(trees1.len(), trees2.len());

    // ceil(nodes / 2) is the leaf count of a full binary tree.
    let n_leaves1: usize = trees1.iter().map(|t| t.nodes.len().div_ceil(2)).sum();
    let n_leaves2: usize = trees2.iter().map(|t| t.nodes.len().div_ceil(2)).sum();
    assert_eq!(n_leaves1, n_leaves2);
    println!("{}", trees1.last().unwrap());
    println!("{}", trees2.last().unwrap());

    // Shared helper: mean squared error against the train targets.
    let mse = |preds: &[f64]| -> f64 {
        preds
            .iter()
            .zip(y_train.iter())
            .map(|(p, t)| (p - t) * (p - t))
            .sum::<f64>()
            / y_train.len() as f64
    };
    let mse1 = mse(&model1.predict(&matrix_train, true));
    let mse2 = mse(&model2.predict(&matrix_train, true));
    // NOTE(review): max_relative = 0.99 allows ~99% relative difference, which
    // makes this assertion almost vacuous — consider tightening the tolerance.
    assert_relative_eq!(mse1, mse2, max_relative = 0.99);
    Ok(())
}
#[test]
fn test_gbm_sensory() -> Result<(), Box<dyn Error>> {
    let n_columns = 11;
    let iter_limit = 10;
    let raw = fs::read_to_string("resources/sensory_y.csv").expect("Something went wrong reading the file");
    let y: Vec<f64> = raw.lines().map(|l| l.parse::<f64>().unwrap()).collect();
    let raw = fs::read_to_string("resources/sensory_flat.csv").expect("Something went wrong reading the file");
    let x: Vec<f64> = raw.lines().map(|l| l.parse::<f64>().unwrap_or(f64::NAN)).collect();
    let data = Matrix::new(&x, y.len(), n_columns);

    // Every column is treated as categorical.
    let cat_index = HashSet::from([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    let mut booster = PerpetualBooster::default()
        .set_log_iterations(1)
        .set_objective(Objective::SquaredLoss)
        .set_categorical_features(Some(cat_index))
        .set_iteration_limit(Some(iter_limit))
        .set_memory_limit(Some(0.0001))
        .set_save_node_stats(true)
        .set_budget(1.0);
    booster.fit(&data, &y, None, None).unwrap();

    // Expected root split per tree (regression fixtures).
    let split_features_test = vec![6, 6, 6, 1, 6, 1, 6, 9, 1, 6];
    let split_gains_test = vec![
        31.172, 25.249, 20.452, 17.503, 16.566, 14.345, 13.418, 12.505, 12.232, 10.869,
    ];
    for (i, tree) in booster.get_prediction_trees().iter().enumerate() {
        let nodes = &tree.nodes;
        let root = nodes.get(&0).unwrap();
        println!("Tree {}: nodes.len = {}", i, nodes.len());
        // Each tree is a single split: one root plus two leaves.
        assert_eq!(3, nodes.len());
        assert_eq!(root.split_feature, split_features_test[i]);
        assert_relative_eq!(root.split_gain, split_gains_test[i], max_relative = 0.99);
    }
    assert_eq!(iter_limit, booster.get_prediction_trees().len());

    // predict_nodes yields one leaf index per row, per tree.
    let pred_nodes = booster.predict_nodes(&data, true);
    println!("pred_nodes.len: {}", pred_nodes.len());
    assert_eq!(booster.get_prediction_trees().len(), pred_nodes.len());
    assert_eq!(data.rows, pred_nodes[0].len());
    Ok(())
}
#[test]
fn test_booster_fit_subsample() {
    // NOTE(review): despite the name, no subsampling option is configured here —
    // the body matches the plain fit tests. Confirm whether a setter call is
    // missing or the name is stale.
    let raw = fs::read_to_string("resources/contiguous_with_missing.csv")
        .expect("Something went wrong reading the file");
    let x: Vec<f64> = raw.lines().map(|l| l.parse::<f64>().unwrap_or(f64::NAN)).collect();
    let raw = fs::read_to_string("resources/performance.csv").expect("Something went wrong reading the file");
    let y: Vec<f64> = raw.lines().map(|l| l.parse::<f64>().unwrap()).collect();
    let data = Matrix::new(&x, 891, 5);

    let mut booster = PerpetualBooster::default()
        .set_max_bin(300)
        .set_base_score(0.5)
        .set_budget(0.3);
    booster.fit(&data, &y, None, None).unwrap();

    let preds = booster.predict(&data, false);
    let contribs = booster.predict_contributions(&data, ContributionsMethod::Average, false);
    // One contribution per feature plus the bias term, for every row.
    assert_eq!(contribs.len(), (data.cols + 1) * data.rows);

    println!("{}", booster.trees[0]);
    println!("{}", booster.trees[0].nodes.len());
    println!("{}", booster.trees.last().unwrap().nodes.len());
    println!("{:?}", &preds[0..10]);
}
#[test]
fn test_huber_loss() -> Result<(), Box<dyn Error>> {
    // Huber loss with a fixed delta; the tree count is a regression fixture.
    let (data_test, y_test) = read_data("resources/cal_housing_test.csv")?;
    let matrix_test = Matrix::new(&data_test, y_test.len(), 8);
    let mut model = PerpetualBooster::default()
        .set_objective(Objective::HuberLoss { delta: Some(1.0) })
        .set_max_bin(10)
        .set_budget(0.1);
    model.fit(&matrix_test, &y_test, None, None)?;
    let n_trees = model.get_prediction_trees().len();
    println!("trees = {}", n_trees);
    assert_eq!(n_trees, 41);
    Ok(())
}
#[test]
fn test_adaptive_huber_loss() -> Result<(), Box<dyn Error>> {
    // Adaptive Huber loss driven by a quantile; tree count is a regression fixture.
    let (data_test, y_test) = read_data("resources/cal_housing_test.csv")?;
    let matrix_test = Matrix::new(&data_test, y_test.len(), 8);
    let mut model = PerpetualBooster::default()
        .set_objective(Objective::AdaptiveHuberLoss { quantile: Some(0.5) })
        .set_max_bin(10)
        .set_budget(0.1);
    model.fit(&matrix_test, &y_test, None, None)?;
    let n_trees = model.get_prediction_trees().len();
    println!("trees = {}", n_trees);
    assert_eq!(n_trees, 3);
    Ok(())
}
#[test]
fn test_custom_objective_function() -> Result<(), Box<dyn Error>> {
    // A user-supplied objective replicating squared loss must produce the
    // same predictions as the built-in SquaredLoss objective.
    #[derive(Clone, Serialize, Deserialize)]
    struct CustomSquaredLoss;

    impl ObjectiveFunction for CustomSquaredLoss {
        fn loss(&self, y: &[f64], yhat: &[f64], sample_weight: Option<&[f64]>, _group: Option<&[u64]>) -> Vec<f32> {
            // Weighted squared error per sample; weight defaults to 1.0.
            y.iter()
                .zip(yhat)
                .enumerate()
                .map(|(i, (yi, yhi))| {
                    let sq = (yi - yhi) * (yi - yhi);
                    let w = sample_weight.map_or(1.0, |w| w[i]);
                    (sq * w) as f32
                })
                .collect()
        }

        fn gradient(
            &self,
            y: &[f64],
            yhat: &[f64],
            sample_weight: Option<&[f64]>,
            _group: Option<&[u64]>,
        ) -> (Vec<f32>, Option<Vec<f32>>) {
            // First derivative of squared loss; no hessian is provided.
            let grad: Vec<f32> = y
                .iter()
                .zip(yhat)
                .enumerate()
                .map(|(i, (yi, yhi))| {
                    let g = yhi - yi;
                    let w = sample_weight.map_or(1.0, |w| w[i]);
                    (g * w) as f32
                })
                .collect();
            (grad, None)
        }
    }

    let (data, y) = read_data("resources/cal_housing_test.csv")?;
    let matrix = Matrix::new(&data, y.len(), 8);

    let mut custom_booster = PerpetualBooster::default()
        .set_objective(Objective::Custom(Arc::new(CustomSquaredLoss)))
        .set_max_bin(10)
        .set_budget(0.1)
        .set_iteration_limit(Some(10));
    let mut booster = PerpetualBooster::default()
        .set_objective(Objective::SquaredLoss)
        .set_max_bin(10)
        .set_budget(0.1)
        .set_iteration_limit(Some(10));

    booster.fit(&matrix, &y, None, None)?;
    custom_booster.fit(&matrix, &y, None, None)?;

    let custom_prediction = custom_booster.predict(&matrix, false);
    let booster_prediction = booster.predict(&matrix, false);
    assert_relative_eq!(custom_prediction[..5], booster_prediction[..5], max_relative = 1e-6);
    Ok(())
}
#[test]
fn test_listnet_loss() -> Result<(), Box<dyn std::error::Error>> {
    // Learning-to-rank smoke test: a ListNet model must beat random scores
    // on NDCG for the goodreads ranking data.
    let file = File::open("resources/goodreads.csv")?;
    let reader = BufReader::new(file);
    let mut csv_reader = csv::ReaderBuilder::new().has_headers(true).from_reader(reader);
    let headers = csv_reader.headers()?.clone();
    let year_idx = headers.iter().position(|h| h == "year").unwrap();
    let category_idx = headers.iter().position(|h| h == "category").unwrap();
    let rank_idx = headers.iter().position(|h| h == "rank").unwrap();
    let feature_names = [
        "avg_rating",
        "pages",
        "5stars",
        "4stars",
        "3stars",
        "2stars",
        "1stars",
        "ratings",
    ];
    let feature_indices: Vec<usize> = feature_names
        .iter()
        .map(|&name| headers.iter().position(|h| h == name).unwrap())
        .collect();

    // Rows sharing (year, category) form one ranking group; groups get dense
    // ids 0..n in order of first appearance.
    let mut groups: Vec<u64> = Vec::new();
    let mut y_raw: Vec<i64> = Vec::new();
    let mut data_columns: Vec<Vec<f64>> = vec![Vec::new(); feature_names.len()];
    let mut group_map: HashMap<(i64, String), u64> = HashMap::new();
    let mut current_group_id = 0;
    for result in csv_reader.records() {
        let record = result?;
        let year = record[year_idx].parse::<i64>().unwrap();
        let category = record[category_idx].to_string();
        let group_id = *group_map.entry((year, category)).or_insert_with(|| {
            let id = current_group_id;
            current_group_id += 1;
            id
        });
        groups.push(group_id);
        y_raw.push(record[rank_idx].parse::<i64>().unwrap());
        for (i, &idx) in feature_indices.iter().enumerate() {
            let val_str = &record[idx];
            // Empty or unparsable cells are treated as 0.0.
            let val = if val_str.is_empty() {
                0.0
            } else {
                val_str.parse::<f64>().unwrap_or(0.0)
            };
            data_columns[i].push(val);
        }
    }

    // Invert rank so that a higher value means more relevant.
    let max_rank = *y_raw.iter().max().unwrap();
    let y: Vec<f64> = y_raw.iter().map(|&v| (max_rank - v) as f64).collect();
    let data: Vec<f64> = data_columns.into_iter().flatten().collect();

    // Group ids are dense (0..current_group_id), so counts can be accumulated
    // directly into a vector instead of going through a HashMap twice.
    let mut group_counts_vec: Vec<u64> = vec![0; current_group_id as usize];
    for &group_id in &groups {
        group_counts_vec[group_id as usize] += 1;
    }

    let matrix = Matrix::new(&data, y.len(), feature_names.len());
    let mut booster = PerpetualBooster::default()
        .set_objective(Objective::ListNetLoss)
        .set_budget(0.1)
        .set_iteration_limit(Some(10))
        .set_max_bin(10)
        .set_memory_limit(Some(0.001));
    booster.fit(&matrix, &y, None, Some(&group_counts_vec))?;

    let objective_fn = &booster.cfg.objective;
    let final_yhat = booster.predict(&matrix, true);
    // Exercise the loss path even though the value is not asserted on.
    let _final_loss: f32 = objective_fn
        .loss(&y, &final_yhat, None, Some(&group_counts_vec))
        .iter()
        .sum();

    let sample_weight = vec![1.0; y.len()];
    let final_ndcg = ndcg_at_k_metric(
        &y,
        &final_yhat,
        &sample_weight,
        &group_counts_vec,
        None,
        &GainScheme::Burges,
    );
    // A fitted model must rank better than uniform random scores.
    let mut rng = rand::rng();
    let random_guesses: Vec<f64> = (0..y.len()).map(|_| rng.random::<f64>()).collect();
    let random_ndcg = ndcg_at_k_metric(
        &y,
        &random_guesses,
        &sample_weight,
        &group_counts_vec,
        None,
        &GainScheme::Burges,
    );
    assert!(final_ndcg > random_ndcg);
    Ok(())
}
#[test]
fn test_booster_timeout() {
    // With an effectively zero timeout the fit must still terminate cleanly
    // rather than erroring or hanging.
    let (data, y) = read_data("resources/cal_housing_test.csv").unwrap();
    let matrix = Matrix::new(&data, y.len(), 8);
    let mut booster = PerpetualBooster::default().set_budget(2.0).set_timeout(Some(0.001));
    booster.fit(&matrix, &y, None, None).unwrap();
}
#[test]
fn test_booster_constraints() {
    // Fitting a tiny 2x2 matrix with both monotone and interaction
    // constraints configured must succeed.
    let data = Matrix::new(&[1.0, 2.0, 3.0, 4.0], 2, 2);
    let y = vec![1.0, 2.0];
    let mut constraints = ConstraintMap::new();
    constraints.insert(0, Constraint::Positive);
    let mut booster = PerpetualBooster::default()
        .set_budget(0.1)
        .set_monotone_constraints(Some(constraints))
        .set_interaction_constraints(Some(vec![vec![0, 1]]));
    booster.fit(&data, &y, None, None).unwrap();
}
#[test]
fn test_booster_categorical() {
    // Fitting a tiny 2x2 matrix with column 0 marked categorical must succeed.
    let data = Matrix::new(&[1.0, 2.0, 3.0, 4.0], 2, 2);
    let y = vec![1.0, 2.0];
    let mut booster = PerpetualBooster::default()
        .set_budget(0.1)
        .set_categorical_features(Some(HashSet::from([0])));
    booster.fit(&data, &y, None, None).unwrap();
}
}