use rand::prelude::*;
use std::f64;
/// Loss function the booster optimizes.
#[derive(Debug, Clone)]
pub enum XGBObjective {
    /// Squared-error regression: grad = pred - y, hess = 1.
    RegSquareError,
    /// Binary classification via logistic loss; labels must lie in [0, 1].
    BinaryLogistic,
}
/// Hyperparameters controlling tree growth and boosting.
#[derive(Debug, Clone)]
pub struct XGBConfig {
    /// Number of boosting rounds (one tree is grown per round).
    pub n_estimators: usize,
    /// Maximum tree depth; nodes at this depth become leaves.
    pub max_depth: usize,
    /// L2 regularization on leaf weights (added to the hessian sum in the denominator).
    pub lambda: f64,
    /// L1 regularization on leaf weights (soft-thresholds the gradient sum).
    pub alpha: f64,
    /// Minimum split gain: splits with gain below this are pruned to a leaf.
    pub gamma: f64,
    /// Minimum hessian sum required in a node and in each child for a split.
    pub min_child_weight: f64,
    /// Fraction of rows sampled (without replacement) each boosting round.
    pub subsample: f64,
    /// Fraction of feature columns sampled for each tree.
    pub colsample_bytree: f64,
    /// Shrinkage applied to every tree's contribution.
    pub learning_rate: f64,
    /// RNG seed for row/column subsampling; `None` seeds from entropy.
    pub seed: Option<u64>,
}
/// Gradient-boosted tree ensemble trained with second-order (gradient/hessian) boosting.
#[derive(Debug)]
pub struct XGBoostModel {
    /// Objective the model was configured with.
    pub objective: XGBObjective,
    /// Trees grown so far, in boosting order.
    pub trees: Vec<XGBTree>,
    /// Hyperparameters used for fitting.
    pub config: XGBConfig,
    /// Initial prediction before any tree: mean(y) for regression, 0.0 for logistic.
    pub base_score: f64,
}
/// A node in a regression tree: either a leaf weight or a binary split.
#[derive(Debug, Clone)]
enum XGBTreeNode {
    /// Terminal node holding the raw leaf weight added to the prediction.
    Leaf(f64),
    /// Internal split: samples with feature value <= threshold go left.
    Internal {
        feature_index: usize,
        threshold: f64,
        left_child: Box<XGBTreeNode>,
        right_child: Box<XGBTreeNode>,
    },
}
/// A single fitted tree in the ensemble.
#[derive(Debug, Clone)]
pub struct XGBTree {
    // Root of the decision tree; traversal starts here.
    root: XGBTreeNode,
}
impl XGBTree {
    /// Return this tree's raw leaf weight for one sample.
    ///
    /// No learning-rate shrinkage is applied here; callers (e.g.
    /// `XGBoostModel::decision_function_one`) scale the result themselves.
    pub fn predict_one(&self, sample: &[f64]) -> f64 {
        traverse(&self.root, sample)
    }
}
impl XGBoostModel {
    /// Create an untrained model; call [`fit`](Self::fit) before predicting.
    pub fn new(objective: XGBObjective, config: XGBConfig) -> Self {
        Self {
            objective,
            trees: Vec::new(),
            config,
            base_score: 0.0,
        }
    }

    /// Train the ensemble on rows `x` with targets `y`.
    ///
    /// # Panics
    /// Panics if `x` is empty, if `x` and `y` differ in length, or if the
    /// objective is `BinaryLogistic` and any label falls outside `[0, 1]`.
    pub fn fit(&mut self, x: &[Vec<f64>], y: &[f64]) {
        let n = x.len();
        if n == 0 {
            panic!("No training data provided to XGBoostModel.");
        }
        if y.len() != n {
            panic!("Features and labels must match in length.");
        }
        if let XGBObjective::BinaryLogistic = self.objective {
            for &lbl in y {
                if !(0.0..=1.0).contains(&lbl) {
                    panic!("BinaryLogistic expects labels in [0,1], got {}", lbl);
                }
            }
        }
        // Initial constant prediction: label mean for squared error,
        // 0.0 (the logit of p = 0.5) for logistic loss.
        self.base_score = match self.objective {
            XGBObjective::RegSquareError => mean(y),
            XGBObjective::BinaryLogistic => 0.0,
        };
        let mut rng = match self.config.seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => StdRng::from_entropy(),
        };
        let mut preds = vec![self.base_score; n];
        self.trees.clear();
        for _round in 0..self.config.n_estimators {
            // First- and second-order derivatives of the loss at the current
            // raw predictions.
            let (grad, hess) = self.grad_hess(&preds, y);
            let (sample_mask, col_mask) = subsample_masks(
                n,
                x[0].len(),
                self.config.subsample,
                self.config.colsample_bytree,
                &mut rng,
            );
            let tree = build_xgb_tree(x, &grad, &hess, &sample_mask, &col_mask, &self.config, 0);
            // BUG FIX: update the running predictions for EVERY row, not only
            // the subsampled ones. The finished ensemble applies every tree to
            // every input (see `decision_function_one`), so skipping
            // out-of-sample rows here made `preds` drift away from the model's
            // actual output whenever subsample < 1.0, biasing later gradients.
            for (pred, row) in preds.iter_mut().zip(x) {
                *pred += traverse(&tree, row) * self.config.learning_rate;
            }
            // Move the tree in (the old code cloned it needlessly).
            self.trees.push(XGBTree { root: tree });
        }
    }

    /// Per-sample gradient and hessian of the configured loss at `preds`.
    fn grad_hess(&self, preds: &[f64], y: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let n = preds.len();
        match self.objective {
            XGBObjective::RegSquareError => {
                // L = 0.5 * (pred - y)^2  =>  grad = pred - y, hess = 1.
                let g = preds.iter().zip(y).map(|(p, t)| p - t).collect();
                (g, vec![1.0; n])
            }
            XGBObjective::BinaryLogistic => {
                let mut g = vec![0.0; n];
                let mut h = vec![0.0; n];
                for i in 0..n {
                    // Logistic loss: grad = sigmoid(pred) - y, hess = p*(1-p).
                    let p = 1.0 / (1.0 + (-preds[i]).exp());
                    g[i] = p - y[i];
                    h[i] = p * (1.0 - p);
                }
                (g, h)
            }
        }
    }

    /// Predict one sample: the raw score for regression, or a hard 0/1 class
    /// label (sigmoid of the score thresholded at 0.5) for logistic.
    pub fn predict_one(&self, sample: &[f64]) -> f64 {
        let score = self.decision_function_one(sample);
        match self.objective {
            XGBObjective::RegSquareError => score,
            XGBObjective::BinaryLogistic => {
                let p = 1.0 / (1.0 + (-score).exp());
                if p >= 0.5 {
                    1.0
                } else {
                    0.0
                }
            }
        }
    }

    /// Raw additive score: base score plus the learning-rate-scaled output of
    /// every tree.
    pub fn decision_function_one(&self, sample: &[f64]) -> f64 {
        let mut sum_val = self.base_score;
        for tree in &self.trees {
            sum_val += tree.predict_one(sample) * self.config.learning_rate;
        }
        sum_val
    }

    /// Predict a batch of samples via [`predict_one`](Self::predict_one).
    pub fn predict_batch(&self, data: &[Vec<f64>]) -> Vec<f64> {
        data.iter().map(|row| self.predict_one(row)).collect()
    }
}
/// Result of a split search: the winning feature/threshold plus the row
/// partition it induces and the gain it achieves.
#[derive(Clone, Debug)]
struct XGBNodeSplit {
    feature_index: usize,
    /// Midpoint between the two adjacent sorted feature values.
    threshold: f64,
    /// Row indices going to the left child (value <= threshold).
    left_index: Vec<usize>,
    /// Row indices going to the right child.
    right_index: Vec<usize>,
    /// Gain over leaving the node unsplit (gamma pruning is applied later).
    gain: f64,
}
/// Bundled arguments for split search, keeping the helper signatures short.
struct SplitParams<'a> {
    /// Full feature matrix (row-major).
    x: &'a [Vec<f64>],
    /// Per-row gradients for the current boosting round.
    grad: &'a [f64],
    /// Per-row hessians for the current boosting round.
    hess: &'a [f64],
    /// Row indices belonging to the node being split.
    indices: &'a [usize],
    /// Feature column under consideration (placeholder 0 at the node level;
    /// `find_best_xgb_split` fills in each masked column).
    feat_idx: usize,
    /// L2 regularization.
    lambda: f64,
    /// L1 regularization.
    alpha: f64,
    /// Gradient sum over `indices`.
    g_node: f64,
    /// Hessian sum over `indices`.
    h_node: f64,
    /// Minimum hessian mass required in each child.
    min_child_weight: f64,
    /// Columns eligible for splitting in this tree.
    col_mask: &'a [bool],
}
/// Recursively grow one tree over the rows flagged in `sample_mask`,
/// splitting only on the feature columns flagged in `col_mask`.
fn build_xgb_tree(
    x: &[Vec<f64>],
    grad: &[f64],
    hess: &[f64],
    sample_mask: &[bool],
    col_mask: &[bool],
    config: &XGBConfig,
    depth: usize,
) -> XGBTreeNode {
    // Materialize the row indices reaching this node.
    let mut indices = Vec::new();
    for (i, &m) in sample_mask.iter().enumerate() {
        if m {
            indices.push(i);
        }
    }
    // Stop at the depth limit (or if no rows reached this node;
    // compute_leaf_weight yields 0.0 for an empty index set).
    if indices.is_empty() || depth >= config.max_depth {
        let leaf_val = compute_leaf_weight(grad, hess, &indices, config);
        return XGBTreeNode::Leaf(leaf_val);
    }
    let (g_node, h_node) = sum_grad_hess(grad, hess, &indices);
    // Node too light (insufficient hessian mass) to split further.
    if h_node < config.min_child_weight {
        let leaf_val = calc_gamma(g_node, h_node, config.lambda, config.alpha);
        return XGBTreeNode::Leaf(leaf_val);
    }
    // feat_idx 0 is a placeholder; find_best_xgb_split iterates every column
    // allowed by col_mask.
    let best_split = find_best_xgb_split(SplitParams {
        x,
        grad,
        hess,
        indices: &indices,
        feat_idx: 0,
        lambda: config.lambda,
        alpha: config.alpha,
        g_node,
        h_node,
        min_child_weight: config.min_child_weight,
        col_mask,
    });
    match best_split {
        // No valid split anywhere: emit a leaf with the optimal weight.
        None => {
            let leaf_val = calc_gamma(g_node, h_node, config.lambda, config.alpha);
            XGBTreeNode::Leaf(leaf_val)
        }
        Some(sp) => {
            // Gamma pruning: discard splits whose gain is below the threshold.
            if sp.gain < config.gamma {
                let leaf_val = calc_gamma(g_node, h_node, config.lambda, config.alpha);
                return XGBTreeNode::Leaf(leaf_val);
            }
            // Convert the child index lists back into full-length row masks
            // for the recursive calls.
            let mut left_mask = vec![false; sample_mask.len()];
            for &i in sp.left_index.iter() {
                left_mask[i] = true;
            }
            let mut right_mask = vec![false; sample_mask.len()];
            for &i in sp.right_index.iter() {
                right_mask[i] = true;
            }
            let left_child = build_xgb_tree(x, grad, hess, &left_mask, col_mask, config, depth + 1);
            let right_child =
                build_xgb_tree(x, grad, hess, &right_mask, col_mask, config, depth + 1);
            XGBTreeNode::Internal {
                feature_index: sp.feature_index,
                threshold: sp.threshold,
                left_child: Box::new(left_child),
                right_child: Box::new(right_child),
            }
        }
    }
}
/// Pick the highest-gain split across all active (col-masked) features.
///
/// Ties are broken in favor of the earliest feature examined (strict `>`
/// comparison). Returns `None` when no feature yields a valid split.
fn find_best_xgb_split(params: SplitParams) -> Option<XGBNodeSplit> {
    // Score of the node if left unsplit; children must jointly beat this.
    let parent_score = calc_gain(params.g_node, params.h_node, params.lambda, params.alpha);
    let mut winner: Option<XGBNodeSplit> = None;
    for feat in 0..params.col_mask.len() {
        if !params.col_mask[feat] {
            continue;
        }
        // Re-bundle the node-level parameters with this feature selected.
        let per_feature = SplitParams {
            feat_idx: feat,
            x: params.x,
            grad: params.grad,
            hess: params.hess,
            indices: params.indices,
            lambda: params.lambda,
            alpha: params.alpha,
            g_node: params.g_node,
            h_node: params.h_node,
            min_child_weight: params.min_child_weight,
            col_mask: params.col_mask,
        };
        let candidate =
            find_best_split_for_feature(per_feature, parent_score, params.min_child_weight);
        if let Some(cand) = candidate {
            let improves = winner.as_ref().map_or(true, |w| cand.gain > w.gain);
            if improves {
                winner = Some(cand);
            }
        }
    }
    winner
}
/// Scan a single feature column for the best split among the node's rows.
///
/// Sorts the rows by feature value, then sweeps left-to-right accumulating
/// gradient/hessian prefix sums; a candidate split is evaluated only where
/// two consecutive sorted values actually differ. `base_score` is the
/// unsplit node's gain; returns `None` unless some candidate achieves
/// strictly positive improvement.
///
/// NOTE(review): assumes `params.indices` is non-empty — `vals.len() - 1`
/// would underflow on an empty node. `build_xgb_tree` guarantees this today;
/// confirm before calling from elsewhere.
fn find_best_split_for_feature(
    params: SplitParams,
    base_score: f64,
    min_child_weight: f64,
) -> Option<XGBNodeSplit> {
    // (feature value, row index) pairs for every row in this node.
    let mut vals = Vec::with_capacity(params.indices.len());
    for &i in params.indices.iter() {
        vals.push((params.x[i][params.feat_idx], i));
    }
    // NaN-tolerant sort: incomparable pairs are treated as equal.
    vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    let mut best_gain = 0.0;
    let mut best_split = None;
    // Running prefix sums over the sorted order.
    let mut g_left = 0.0;
    let mut h_left = 0.0;
    let mut left_idx = Vec::new();
    for i in 0..vals.len() - 1 {
        let (v, idx) = vals[i];
        g_left += params.grad[idx];
        h_left += params.hess[idx];
        left_idx.push(idx);
        let next_val = vals[i + 1].0;
        // Only split between genuinely distinct feature values.
        if (v - next_val).abs() > f64::EPSILON {
            // Right-side sums derived by subtraction from the node totals.
            let g_right = params.g_node - g_left;
            let h_right = params.h_node - h_left;
            if h_left >= min_child_weight && h_right >= min_child_weight {
                let gain_left = calc_gain(g_left, h_left, params.lambda, params.alpha);
                let gain_right = calc_gain(g_right, h_right, params.lambda, params.alpha);
                // Improvement over keeping the node whole (gamma pruning is
                // applied later by the caller).
                let gain = gain_left + gain_right - base_score;
                if gain > best_gain {
                    best_gain = gain;
                    let right_idx: Vec<usize> =
                        vals[(i + 1)..].iter().map(|(_, idx)| *idx).collect();
                    best_split = Some(XGBNodeSplit {
                        feature_index: params.feat_idx,
                        // Threshold sits midway between the adjacent values.
                        threshold: (v + next_val) / 2.0,
                        left_index: left_idx.clone(),
                        right_index: right_idx,
                        gain,
                    });
                }
            }
        }
    }
    best_split
}
/// Draw boolean row and column masks for one boosting round.
///
/// Each mask selects `ceil(ratio * len)` distinct positions uniformly at
/// random (via a full shuffle); a ratio of 1.0 (or more) selects everything
/// without consuming any randomness.
fn subsample_masks(
    n_rows: usize,
    n_cols: usize,
    subsample_ratio: f64,
    colsample_ratio: f64,
    rng: &mut impl Rng,
) -> (Vec<bool>, Vec<bool>) {
    // Shared routine for both masks; captures the RNG mutably.
    let mut pick = |len: usize, ratio: f64| -> Vec<bool> {
        let want = (ratio * len as f64).ceil() as usize;
        if want >= len {
            // Full selection: no shuffle, so the RNG stream is untouched
            // (matters for seeded reproducibility).
            vec![true; len]
        } else {
            let mut mask = vec![false; len];
            let mut order: Vec<usize> = (0..len).collect();
            order.shuffle(rng);
            for &pos in &order[..want] {
                mask[pos] = true;
            }
            mask
        }
    };
    let row_mask = pick(n_rows, subsample_ratio);
    let col_mask = pick(n_cols, colsample_ratio);
    (row_mask, col_mask)
}
/// Sum the gradients and hessians over the rows named in `indices`.
fn sum_grad_hess(grad: &[f64], hess: &[f64], indices: &[usize]) -> (f64, f64) {
    // Fold in index order — same accumulation sequence as a plain loop, so
    // the floating-point result is bit-identical.
    indices
        .iter()
        .fold((0.0, 0.0), |(g, h), &i| (g + grad[i], h + hess[i]))
}
/// Optimal leaf weight for the rows in `indices`: aggregate their
/// gradient/hessian sums, then apply the regularized weight formula.
/// Yields 0.0 for an empty index set (zero hessian mass).
fn compute_leaf_weight(grad: &[f64], hess: &[f64], indices: &[usize], cfg: &XGBConfig) -> f64 {
    let (g_total, h_total) = sum_grad_hess(grad, hess, indices);
    calc_gamma(g_total, h_total, cfg.lambda, cfg.alpha)
}
/// Regularized leaf weight: `-sign(g) * max(|g| - alpha, 0) / (h + lambda)`.
///
/// `alpha` soft-thresholds the gradient sum (L1) and `lambda` inflates the
/// hessian sum (L2). A (near-)zero hessian yields 0.0 to avoid dividing by
/// a vanishing denominator.
fn calc_gamma(g: f64, h: f64, lambda: f64, alpha: f64) -> f64 {
    if h.abs() < f64::EPSILON {
        0.0
    } else {
        let shrunk = (g.abs() - alpha).max(0.0);
        // Step opposes the gradient direction; g == 0.0 counts as non-positive.
        let direction = if g > 0.0 { -1.0 } else { 1.0 };
        direction * shrunk / (h + lambda)
    }
}
/// Structure score of a node: `0.5 * g^2 / (h + lambda)`.
///
/// Split gain is computed by callers as `gain(left) + gain(right) -
/// gain(parent)`. `_alpha` is accepted for signature symmetry with
/// `calc_gamma` but unused here. Zero hessian mass scores 0.0.
fn calc_gain(g: f64, h: f64, lambda: f64, _alpha: f64) -> f64 {
    if h.abs() >= f64::EPSILON {
        0.5 * ((g * g) / (h + lambda))
    } else {
        0.0
    }
}
/// Walk `sample` down the tree from `node` and return the leaf weight.
/// Rows with feature value <= threshold descend left, otherwise right.
fn traverse(node: &XGBTreeNode, sample: &[f64]) -> f64 {
    // Iterative descent instead of recursion: same visit order, no call
    // stack growth.
    let mut current = node;
    loop {
        match current {
            XGBTreeNode::Leaf(weight) => return *weight,
            XGBTreeNode::Internal {
                feature_index,
                threshold,
                left_child,
                right_child,
            } => {
                current = if sample[*feature_index] <= *threshold {
                    left_child
                } else {
                    right_child
                };
            }
        }
    }
}
/// Arithmetic mean of `arr`; defined as 0.0 for an empty slice.
fn mean(arr: &[f64]) -> f64 {
    match arr.len() {
        0 => 0.0,
        len => arr.iter().sum::<f64>() / len as f64,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression on a noiseless linear target: with 100 rounds and full
    /// sampling the ensemble should fit each training point to within 1.0.
    #[test]
    fn test_xgb_reg_square_error() {
        let x = vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 1.0],
            vec![0.0, 0.0],
            vec![4.0, 2.0],
        ];
        // Deterministic target: y = x0 + 2 * x1.
        let y: Vec<f64> = x.iter().map(|row| row[0] + 2.0 * row[1]).collect();
        let config = XGBConfig {
            n_estimators: 100,
            max_depth: 3,
            lambda: 1.0,
            alpha: 0.0,
            gamma: 0.0,
            min_child_weight: 0.1,
            subsample: 1.0,
            colsample_bytree: 1.0,
            learning_rate: 0.3,
            seed: Some(42),
        };
        let mut model = XGBoostModel::new(XGBObjective::RegSquareError, config);
        model.fit(&x, &y);
        // Check training-set fit quality point by point.
        for i in 0..x.len() {
            let pred = model.predict_one(&x[i]);
            let err = (pred - y[i]).abs();
            assert!(err < 1.0, "Prediction error is too large: err={}", err);
        }
    }

    /// Binary classification on a linearly separable toy set: every training
    /// row should be classified correctly.
    #[test]
    fn test_xgb_binary_logistic() {
        let x = vec![
            vec![0.0, 0.0], vec![5.0, 5.0], vec![0.0, 1.0], vec![5.0, 4.0], vec![0.5, 0.5], ];
        // Label 1 iff x0 + x1 > 3.
        let y: Vec<f64> = x
            .iter()
            .map(|row| if row[0] + row[1] > 3.0 { 1.0 } else { 0.0 })
            .collect();
        let config = XGBConfig {
            n_estimators: 100,
            max_depth: 3,
            lambda: 1.0,
            alpha: 0.0,
            gamma: 0.0,
            min_child_weight: 0.1,
            subsample: 1.0,
            colsample_bytree: 1.0,
            learning_rate: 0.3,
            seed: Some(123),
        };
        let mut model = XGBoostModel::new(XGBObjective::BinaryLogistic, config);
        model.fit(&x, &y);
        // predict_one returns a hard 0/1 label for BinaryLogistic.
        for i in 0..x.len() {
            let pred = model.predict_one(&x[i]);
            let truth = y[i];
            let is_correct = (pred - truth).abs() < 0.5;
            assert!(
                is_correct,
                "Wrong classification for row {} => pred={}",
                i, pred
            );
        }
    }
}