use crate::booster::config::{CalibrationMethod, MissingNodeTreatment};
use crate::objective::Objective;
use crate::{PerpetualBooster, constraints::ConstraintMap};
use std::collections::HashSet;
/// Builder-style setters for [`PerpetualBooster`].
///
/// Each setter takes the booster by value, overwrites exactly one field, and returns the
/// booster so calls can be chained. No value is validated here; whatever is passed is
/// stored as-is.
impl PerpetualBooster {
    /// Sets `cfg.objective`.
    #[must_use]
    pub fn set_objective(mut self, objective: Objective) -> Self {
        self.cfg.objective = objective;
        self
    }

    /// Sets `cfg.budget`.
    #[must_use]
    pub fn set_budget(mut self, budget: f32) -> Self {
        self.cfg.budget = budget;
        self
    }

    /// Sets the booster's `base_score`.
    ///
    /// Unlike the other setters in this impl, this writes a field on the booster itself
    /// rather than on `cfg`.
    #[must_use]
    pub fn set_base_score(mut self, base_score: f64) -> Self {
        self.base_score = base_score;
        self
    }

    /// Sets `cfg.max_bin`.
    #[must_use]
    pub fn set_max_bin(mut self, max_bin: u16) -> Self {
        self.cfg.max_bin = max_bin;
        self
    }

    /// Sets `cfg.num_threads`; `None` is stored as-is and interpreted by the training code.
    #[must_use]
    pub fn set_num_threads(mut self, num_threads: Option<usize>) -> Self {
        self.cfg.num_threads = num_threads;
        self
    }

    /// Sets `cfg.monotone_constraints`.
    #[must_use]
    pub fn set_monotone_constraints(mut self, monotone_constraints: Option<ConstraintMap>) -> Self {
        self.cfg.monotone_constraints = monotone_constraints;
        self
    }

    /// Sets `cfg.interaction_constraints`, given as groups of feature indices.
    #[must_use]
    pub fn set_interaction_constraints(mut self, interaction_constraints: Option<Vec<Vec<usize>>>) -> Self {
        self.cfg.interaction_constraints = interaction_constraints;
        self
    }

    /// Sets `cfg.force_children_to_bound_parent`.
    #[must_use]
    pub fn set_force_children_to_bound_parent(mut self, force_children_to_bound_parent: bool) -> Self {
        self.cfg.force_children_to_bound_parent = force_children_to_bound_parent;
        self
    }

    /// Sets `cfg.missing`, the value stored as the missing-value marker.
    #[must_use]
    pub fn set_missing(mut self, missing: f64) -> Self {
        self.cfg.missing = missing;
        self
    }

    /// Sets `cfg.allow_missing_splits`.
    #[must_use]
    pub fn set_allow_missing_splits(mut self, allow_missing_splits: bool) -> Self {
        self.cfg.allow_missing_splits = allow_missing_splits;
        self
    }

    /// Sets `cfg.create_missing_branch`.
    #[must_use]
    pub fn set_create_missing_branch(mut self, create_missing_branch: bool) -> Self {
        self.cfg.create_missing_branch = create_missing_branch;
        self
    }

    /// Replaces `cfg.terminate_missing_features` with the given set of feature indices.
    #[must_use]
    pub fn set_terminate_missing_features(mut self, terminate_missing_features: HashSet<usize>) -> Self {
        self.cfg.terminate_missing_features = terminate_missing_features;
        self
    }

    /// Sets `cfg.missing_node_treatment`.
    #[must_use]
    pub fn set_missing_node_treatment(mut self, missing_node_treatment: MissingNodeTreatment) -> Self {
        self.cfg.missing_node_treatment = missing_node_treatment;
        self
    }

    /// Sets `cfg.log_iterations`.
    #[must_use]
    pub fn set_log_iterations(mut self, log_iterations: usize) -> Self {
        self.cfg.log_iterations = log_iterations;
        self
    }

    /// Alias for [`Self::set_log_iterations`]: writes the same `cfg.log_iterations` field,
    /// so whichever of the two is called last wins.
    ///
    /// NOTE(review): the `ref` in the name hints that a `&mut self` variant may have been
    /// intended; today it behaves exactly like `set_log_iterations`.
    #[must_use]
    pub fn set_ref_log_iterations(self, log_iterations: usize) -> Self {
        // Delegate instead of duplicating the assignment so the two cannot drift apart.
        self.set_log_iterations(log_iterations)
    }

    /// Sets `cfg.seed`.
    #[must_use]
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.cfg.seed = seed;
        self
    }

    /// Sets `cfg.reset`.
    #[must_use]
    pub fn set_reset(mut self, reset: Option<bool>) -> Self {
        self.cfg.reset = reset;
        self
    }

    /// Replaces `cfg.categorical_features` with the given set of feature indices (or `None`).
    #[must_use]
    pub fn set_categorical_features(mut self, categorical_features: Option<HashSet<usize>>) -> Self {
        self.cfg.categorical_features = categorical_features;
        self
    }

    /// Sets `cfg.timeout`.
    #[must_use]
    pub fn set_timeout(mut self, timeout: Option<f32>) -> Self {
        self.cfg.timeout = timeout;
        self
    }

    /// Sets `cfg.iteration_limit`.
    #[must_use]
    pub fn set_iteration_limit(mut self, iteration_limit: Option<usize>) -> Self {
        self.cfg.iteration_limit = iteration_limit;
        self
    }

    /// Sets `cfg.memory_limit`.
    #[must_use]
    pub fn set_memory_limit(mut self, memory_limit: Option<f32>) -> Self {
        self.cfg.memory_limit = memory_limit;
        self
    }

    /// Sets `cfg.stopping_rounds`.
    #[must_use]
    pub fn set_stopping_rounds(mut self, stopping_rounds: Option<usize>) -> Self {
        self.cfg.stopping_rounds = stopping_rounds;
        self
    }

    /// Sets `cfg.save_node_stats`.
    #[must_use]
    pub fn set_save_node_stats(mut self, save_node_stats: bool) -> Self {
        self.cfg.save_node_stats = save_node_stats;
        self
    }

    /// Sets `cfg.calibration_method`.
    #[must_use]
    pub fn set_calibration_method(mut self, calibration_method: CalibrationMethod) -> Self {
        self.cfg.calibration_method = calibration_method;
        self
    }
}
#[cfg(test)]
mod tests {
    // The glob brings in `PerpetualBooster` plus the parent module's imports
    // (`Objective`, `ConstraintMap`, `MissingNodeTreatment`, `CalibrationMethod`, `HashSet`).
    use super::*;

    /// Chains every setter once and checks that each value lands in its field.
    #[test]
    fn test_setters() {
        let booster = PerpetualBooster::default()
            .set_objective(Objective::LogLoss)
            .set_budget(2.0)
            .set_base_score(0.5)
            .set_max_bin(128)
            .set_num_threads(Some(4))
            .set_monotone_constraints(Some(ConstraintMap::new()))
            .set_interaction_constraints(Some(vec![vec![0, 1]]))
            .set_force_children_to_bound_parent(true)
            .set_missing(0.0)
            .set_allow_missing_splits(false)
            .set_create_missing_branch(false)
            .set_terminate_missing_features(HashSet::from([0]))
            .set_missing_node_treatment(MissingNodeTreatment::AverageNodeWeight)
            .set_log_iterations(10)
            .set_ref_log_iterations(20)
            .set_seed(42)
            .set_reset(Some(true))
            .set_categorical_features(Some(HashSet::from([1])))
            .set_timeout(Some(100.0))
            .set_iteration_limit(Some(50))
            .set_memory_limit(Some(1024.0))
            .set_stopping_rounds(Some(5))
            .set_save_node_stats(true)
            .set_calibration_method(CalibrationMethod::MinMax);

        assert!(matches!(booster.cfg.objective, Objective::LogLoss));
        assert_eq!(booster.cfg.budget, 2.0);
        assert_eq!(booster.base_score, 0.5);
        assert_eq!(booster.cfg.max_bin, 128);
        assert_eq!(booster.cfg.num_threads, Some(4));
        assert!(booster.cfg.monotone_constraints.is_some());
        assert_eq!(booster.cfg.interaction_constraints, Some(vec![vec![0, 1]]));
        assert!(booster.cfg.force_children_to_bound_parent);
        assert_eq!(booster.cfg.missing, 0.0);
        assert!(!booster.cfg.allow_missing_splits);
        assert!(!booster.cfg.create_missing_branch);
        // The setter replaces the whole set, so compare for equality rather than membership.
        assert_eq!(booster.cfg.terminate_missing_features, HashSet::from([0]));
        assert_eq!(
            booster.cfg.missing_node_treatment,
            MissingNodeTreatment::AverageNodeWeight
        );
        // Both log-iteration setters write the same field; the later call wins.
        assert_eq!(booster.cfg.log_iterations, 20);
        assert_eq!(booster.cfg.seed, 42);
        assert_eq!(booster.cfg.reset, Some(true));
        assert_eq!(booster.cfg.categorical_features, Some(HashSet::from([1])));
        assert_eq!(booster.cfg.timeout, Some(100.0));
        assert_eq!(booster.cfg.iteration_limit, Some(50));
        assert_eq!(booster.cfg.memory_limit, Some(1024.0));
        assert_eq!(booster.cfg.stopping_rounds, Some(5));
        assert!(booster.cfg.save_node_stats);
        assert_eq!(booster.cfg.calibration_method, CalibrationMethod::MinMax);
    }

    /// Exercises each log-iteration setter on its own, because `test_setters` only
    /// observes whichever of the two runs last.
    #[test]
    fn test_log_iteration_setters() {
        let booster = PerpetualBooster::default().set_log_iterations(10);
        assert_eq!(booster.cfg.log_iterations, 10);

        let booster = PerpetualBooster::default().set_ref_log_iterations(20);
        assert_eq!(booster.cfg.log_iterations, 20);

        let booster = PerpetualBooster::default()
            .set_ref_log_iterations(20)
            .set_log_iterations(10);
        assert_eq!(booster.cfg.log_iterations, 10);
    }
}