use crate::dataset::BinnedDataset;
use crate::defaults::linear_tree as linear_tree_defaults;
use crate::learner::{LinearConfig, TreeConfig};
use crate::tree::{NodeType, Tree};
use crate::Result;
use rkyv::{Archive, Deserialize, Serialize};
/// Configuration for a [`LinearTreeBooster`]: controls both how the tree is
/// grown and how per-leaf ridge models are fitted.
///
/// NOTE: field order is part of the rkyv archived layout — do not reorder.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct LinearTreeConfig {
    /// Parameters controlling tree growth (max depth, max leaves, min samples per leaf).
    pub tree_config: TreeConfig,
    /// Parameters for the per-leaf ridge fit (lambda, max_iter, tol, shrinkage factor).
    pub linear_config: LinearConfig,
    /// Leaves with fewer samples than this get a constant model instead of a linear fit.
    pub min_samples_for_linear: usize,
}
impl Default for LinearTreeConfig {
    /// Assembles a configuration from the crate's linear-tree default constants.
    fn default() -> Self {
        let tree_config = TreeConfig::default()
            .with_max_depth(linear_tree_defaults::LINEAR_TREE_MAX_DEPTH)
            .with_max_leaves(linear_tree_defaults::LINEAR_TREE_MAX_LEAVES)
            .with_min_samples_leaf(linear_tree_defaults::LINEAR_TREE_MIN_SAMPLES_LEAF);
        let linear_config = LinearConfig::default()
            .with_lambda(linear_tree_defaults::LINEAR_TREE_LAMBDA)
            .with_max_iter(linear_tree_defaults::LINEAR_TREE_MAX_ITER);
        Self {
            tree_config,
            linear_config,
            min_samples_for_linear: linear_tree_defaults::MIN_SAMPLES_FOR_LINEAR,
        }
    }
}
impl LinearTreeConfig {
    /// Shorthand for [`LinearTreeConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Consumes `self`, returning it with the tree-growth configuration replaced.
    pub fn with_tree_config(self, config: TreeConfig) -> Self {
        Self {
            tree_config: config,
            ..self
        }
    }

    /// Consumes `self`, returning it with the per-leaf ridge configuration replaced.
    pub fn with_linear_config(self, config: LinearConfig) -> Self {
        Self {
            linear_config: config,
            ..self
        }
    }

    /// Consumes `self`, returning it with the linear-fit sample threshold
    /// replaced; values below 2 are raised to 2 (a linear fit needs at least
    /// two points).
    pub fn with_min_samples_for_linear(self, min_samples: usize) -> Self {
        Self {
            min_samples_for_linear: min_samples.max(2),
            ..self
        }
    }
}
/// Prediction model stored at a single tree leaf: either a fitted linear
/// model (`bias + weights . x`) or a plain constant, selected by `is_linear`.
///
/// NOTE: field order is part of the rkyv archived layout — do not reorder.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct LeafLinearModel {
    /// Per-feature coefficients in original (unstandardized) feature space;
    /// empty when `is_linear` is false.
    pub weights: Vec<f32>,
    /// Intercept; only used when `is_linear` is true.
    pub bias: f32,
    /// Selects between the linear model and the constant fallback.
    pub is_linear: bool,
    /// Constant prediction used when `is_linear` is false.
    pub constant: f32,
}
impl LeafLinearModel {
    /// Builds a model that always predicts `value`.
    pub fn constant(value: f32) -> Self {
        Self {
            weights: Vec::new(),
            bias: 0.0,
            is_linear: false,
            constant: value,
        }
    }

    /// Builds a linear model `bias + weights . x`.
    pub fn linear(weights: Vec<f32>, bias: f32) -> Self {
        Self {
            weights,
            bias,
            is_linear: true,
            constant: 0.0,
        }
    }

    /// Evaluates the model on one feature row. For a linear model, weights
    /// beyond `features.len()` are ignored (the zip stops at the shorter of
    /// the two slices).
    #[inline]
    pub fn predict(&self, features: &[f32]) -> f32 {
        if self.is_linear {
            self.weights
                .iter()
                .zip(features)
                .fold(self.bias, |acc, (&w, &x)| acc + w * x)
        } else {
            self.constant
        }
    }
}
/// A single decision tree whose leaves hold ridge-regression models
/// ("linear tree"), fitted against gradient/hessian targets.
///
/// NOTE: field order is part of the rkyv archived layout — do not reorder.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct LinearTreeBooster {
    // The grown tree; None until fit_on_gradients succeeds.
    tree: Option<Tree>,
    // (node_index, model) pairs, one entry per leaf of `tree`.
    leaf_models: Vec<(usize, LeafLinearModel)>,
    // Configuration used for both tree growth and leaf fitting.
    config: LinearTreeConfig,
    // Number of features seen at fit time.
    num_features: usize,
    // Per-feature means used to standardize features during leaf fitting.
    feature_means: Vec<f32>,
    // Per-feature standard deviations (forced to 1.0 for near-constant features).
    feature_stds: Vec<f32>,
}
impl LinearTreeBooster {
    /// Creates an unfitted booster; all model state is populated by
    /// [`LinearTreeBooster::fit_on_gradients`].
    pub fn new(config: LinearTreeConfig) -> Self {
        Self {
            tree: None,
            leaf_models: Vec::new(),
            config,
            num_features: 0,
            feature_means: Vec::new(),
            feature_stds: Vec::new(),
        }
    }

    /// Returns the configuration this booster was built with.
    pub fn config(&self) -> &LinearTreeConfig {
        &self.config
    }

    /// True once a tree has been fitted and stored.
    pub fn is_fitted(&self) -> bool {
        self.tree.is_some()
    }

    /// The fitted tree, or `None` before fitting.
    pub fn tree(&self) -> Option<&Tree> {
        self.tree.as_ref()
    }

    /// Number of `(leaf, model)` pairs currently stored.
    pub fn num_leaf_models(&self) -> usize {
        self.leaf_models.len()
    }

    /// Fits the booster on first/second-order gradient statistics.
    ///
    /// * `dataset` - binned feature view used for growing and routing.
    /// * `raw_features` - row-major `num_rows x num_features` matrix of
    ///   unbinned values (layout established by the `i * num_features + j`
    ///   indexing below) used by the per-leaf linear models.
    /// * `gradients` / `hessians` - per-row loss derivatives.
    pub fn fit_on_gradients(
        &mut self,
        dataset: &BinnedDataset,
        raw_features: &[f32],
        num_features: usize,
        gradients: &[f32],
        hessians: &[f32],
    ) -> Result<()> {
        let num_rows = dataset.num_rows();
        self.num_features = num_features;
        // Means/stds are needed later to standardize features for ridge fitting.
        self.compute_feature_stats(raw_features, num_features, num_rows);
        // Grow an ordinary tree on the gradient statistics first...
        let grower = self.config.tree_config.build_grower(num_features, None);
        let tree = grower.grow(dataset, gradients, hessians);
        // ...then fit a linear (or constant) model inside each leaf.
        let leaf_assignments = self.assign_samples_to_leaves(&tree, dataset, num_rows);
        self.fit_leaf_models(
            &tree,
            &leaf_assignments,
            raw_features,
            num_features,
            gradients,
            hessians,
        );
        self.tree = Some(tree);
        Ok(())
    }

    /// Computes per-feature mean and (population) standard deviation over all
    /// rows, accumulating in f64 for numerical stability.
    fn compute_feature_stats(&mut self, features: &[f32], num_features: usize, num_rows: usize) {
        self.feature_means = vec![0.0; num_features];
        self.feature_stds = vec![1.0; num_features];
        if num_rows == 0 {
            // No data: keep the neutral defaults (mean 0, std 1).
            return;
        }
        for j in 0..num_features {
            let mut sum = 0.0f64;
            let mut sum_sq = 0.0f64;
            for i in 0..num_rows {
                // Row-major layout: row i, feature j.
                let val = features[i * num_features + j] as f64;
                sum += val;
                sum_sq += val * val;
            }
            let mean = sum / num_rows as f64;
            let variance = (sum_sq / num_rows as f64) - mean * mean;
            // max(0.0) guards against tiny negative variance from rounding.
            let std = variance.max(0.0).sqrt();
            self.feature_means[j] = mean as f32;
            // Near-constant features get std 1.0 so standardization is a no-op
            // (and never divides by ~zero).
            self.feature_stds[j] = if std > 1e-10 { std as f32 } else { 1.0 };
        }
    }

    /// Maps a raw feature value to z-score space using the fitted statistics.
    #[inline]
    fn standardize(&self, value: f32, feature_idx: usize) -> f32 {
        (value - self.feature_means[feature_idx]) / self.feature_stds[feature_idx]
    }

    /// Routes every row down the tree and groups row indices by the node
    /// index of the leaf they land in. The result is indexed by node index
    /// (length `num_nodes`), so entries for internal nodes stay empty.
    fn assign_samples_to_leaves(
        &self,
        tree: &Tree,
        dataset: &BinnedDataset,
        num_rows: usize,
    ) -> Vec<Vec<usize>> {
        let num_nodes = tree.num_nodes();
        let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
        for row_idx in 0..num_rows {
            let leaf_idx = self.find_leaf_index(tree, dataset, row_idx);
            assignments[leaf_idx].push(row_idx);
        }
        assignments
    }

    /// Walks one row from the root to a leaf using the binned features:
    /// `bin <= bin_threshold` goes left, otherwise right.
    fn find_leaf_index(&self, tree: &Tree, dataset: &BinnedDataset, row_idx: usize) -> usize {
        let mut node_idx = 0;
        loop {
            let node = tree.get_node(node_idx);
            match node.node_type {
                NodeType::Leaf { .. } => return node_idx,
                NodeType::Internal {
                    feature_idx,
                    bin_threshold,
                    left_child,
                    right_child,
                    ..
                } => {
                    let bin = dataset.get_bin(row_idx, feature_idx);
                    node_idx = if bin <= bin_threshold {
                        left_child
                    } else {
                        right_child
                    };
                }
            }
        }
    }

    /// Fits one model per leaf: a ridge regression when the leaf holds at
    /// least `min_samples_for_linear` samples, otherwise a constant equal to
    /// the leaf's plain tree value.
    fn fit_leaf_models(
        &mut self,
        tree: &Tree,
        leaf_assignments: &[Vec<usize>],
        raw_features: &[f32],
        num_features: usize,
        gradients: &[f32],
        hessians: &[f32],
    ) {
        self.leaf_models.clear();
        for (node_idx, sample_indices) in leaf_assignments.iter().enumerate() {
            let node = tree.get_node(node_idx);
            if !node.is_leaf() {
                // Internal nodes never get models.
                continue;
            }
            let num_samples = sample_indices.len();
            // The plain tree prediction doubles as the constant fallback and
            // as the initial bias for the ridge fit.
            let default_value = match node.node_type {
                NodeType::Leaf { value } => value,
                _ => 0.0,
            };
            if num_samples < self.config.min_samples_for_linear {
                self.leaf_models
                    .push((node_idx, LeafLinearModel::constant(default_value)));
                continue;
            }
            let model = self.fit_ridge_in_leaf(
                sample_indices,
                raw_features,
                num_features,
                gradients,
                hessians,
                default_value,
            );
            self.leaf_models.push((node_idx, model));
        }
    }

    /// Fits a ridge regression inside one leaf by coordinate descent on the
    /// second-order (Newton) objective, then maps the coefficients from
    /// standardized space back to the original feature space.
    fn fit_ridge_in_leaf(
        &self,
        sample_indices: &[usize],
        raw_features: &[f32],
        num_features: usize,
        gradients: &[f32],
        hessians: &[f32],
        default_value: f32,
    ) -> LeafLinearModel {
        let n = sample_indices.len();
        let lambda = self.config.linear_config.lambda;
        let max_iter = self.config.linear_config.max_iter;
        let tol = self.config.linear_config.tol;
        let shrinkage = self.config.linear_config.shrinkage_factor;
        let mut weights = vec![0.0f32; num_features];
        // Start from the tree's constant prediction for this leaf.
        let mut bias = default_value;
        // Standardize the leaf's feature rows once, up front.
        let mut std_features: Vec<Vec<f32>> = Vec::with_capacity(n);
        for &idx in sample_indices {
            let mut row = Vec::with_capacity(num_features);
            for j in 0..num_features {
                let val = raw_features[idx * num_features + j];
                row.push(self.standardize(val, j));
            }
            std_features.push(row);
        }
        for _iter in 0..max_iter {
            let mut max_change = 0.0f32;
            // Coordinate step for the bias first.
            {
                let mut grad_bias = 0.0f32;
                let mut hess_bias = 0.0f32;
                for (local_idx, &global_idx) in sample_indices.iter().enumerate() {
                    let h = hessians[global_idx];
                    let g = gradients[global_idx];
                    let mut pred = bias;
                    for (j, &w) in weights.iter().enumerate() {
                        pred += w * std_features[local_idx][j];
                    }
                    // Working residual: prediction minus the Newton target
                    // (-g/h); h.max(1e-10) guards against ~zero hessians.
                    let residual = pred + g / h.max(1e-10);
                    grad_bias += h * residual;
                    hess_bias += h;
                }
                let delta = -grad_bias / (hess_bias + lambda);
                // Clamp so a single update cannot explode.
                let delta = delta.clamp(-10.0, 10.0);
                bias += shrinkage * delta;
                max_change = max_change.max(delta.abs());
            }
            // Then one coordinate-descent pass over the feature weights.
            for j in 0..num_features {
                let mut grad_j = 0.0f32;
                let mut hess_j = 0.0f32;
                for (local_idx, &global_idx) in sample_indices.iter().enumerate() {
                    let h = hessians[global_idx];
                    let g = gradients[global_idx];
                    let x_j = std_features[local_idx][j];
                    let mut pred = bias;
                    for (k, &w) in weights.iter().enumerate() {
                        pred += w * std_features[local_idx][k];
                    }
                    let residual = pred + g / h.max(1e-10);
                    grad_j += h * residual * x_j;
                    hess_j += h * x_j * x_j;
                }
                // L2 penalty: lambda * w in the gradient, lambda in the curvature.
                grad_j += lambda * weights[j];
                let delta = -grad_j / (hess_j + lambda);
                let delta = delta.clamp(-10.0, 10.0);
                weights[j] += shrinkage * delta;
                max_change = max_change.max(delta.abs());
            }
            // Converged once the largest raw step falls below tolerance.
            if max_change < tol {
                break;
            }
        }
        // Undo standardization: w_orig = w / std, folding -w * mean / std into
        // the bias (the map closure deliberately mutates bias_orig as it goes).
        let mut bias_orig = bias;
        let weights_orig: Vec<f32> = weights
            .iter()
            .zip(self.feature_means.iter())
            .zip(self.feature_stds.iter())
            .map(|((&w, &mean), &std)| {
                bias_orig -= w * mean / std;
                w / std
            })
            .collect();
        LeafLinearModel::linear(weights_orig, bias_orig)
    }

    /// Predicts every row of `dataset`, using the row's leaf model when one
    /// exists and falling back to the plain tree prediction otherwise.
    /// Returns all zeros when the booster is unfitted.
    pub fn predict_batch(
        &self,
        dataset: &BinnedDataset,
        raw_features: &[f32],
        num_features: usize,
    ) -> Vec<f32> {
        let tree = match &self.tree {
            Some(t) => t,
            None => return vec![0.0; dataset.num_rows()],
        };
        let num_rows = dataset.num_rows();
        let mut predictions = Vec::with_capacity(num_rows);
        // Index leaf models by node id for O(1) lookup per row.
        let mut leaf_lookup: std::collections::HashMap<usize, &LeafLinearModel> =
            std::collections::HashMap::new();
        for (node_idx, model) in &self.leaf_models {
            leaf_lookup.insert(*node_idx, model);
        }
        for row_idx in 0..num_rows {
            let leaf_idx = self.find_leaf_index(tree, dataset, row_idx);
            let pred = if let Some(model) = leaf_lookup.get(&leaf_idx) {
                // Gather the row's raw features for the linear model.
                let row_features: Vec<f32> = (0..num_features)
                    .map(|j| raw_features[row_idx * num_features + j])
                    .collect();
                model.predict(&row_features)
            } else {
                tree.predict_row(dataset, row_idx)
            };
            predictions.push(pred);
        }
        predictions
    }

    /// Predicts a single row (0.0 when unfitted). Uses a linear scan over the
    /// stored leaf models instead of building a hash map.
    pub fn predict_row(
        &self,
        dataset: &BinnedDataset,
        raw_features: &[f32],
        num_features: usize,
        row_idx: usize,
    ) -> f32 {
        let tree = match &self.tree {
            Some(t) => t,
            None => return 0.0,
        };
        let leaf_idx = self.find_leaf_index(tree, dataset, row_idx);
        for (node_idx, model) in &self.leaf_models {
            if *node_idx == leaf_idx {
                let row_features: Vec<f32> = (0..num_features)
                    .map(|j| raw_features[row_idx * num_features + j])
                    .collect();
                return model.predict(&row_features);
            }
        }
        // No model for this leaf: fall back to the plain tree prediction.
        tree.predict_row(dataset, row_idx)
    }

    /// Adds this booster's predictions into `predictions` element-wise.
    /// Panics if `predictions` is shorter than `dataset.num_rows()`.
    pub fn predict_batch_add(
        &self,
        dataset: &BinnedDataset,
        raw_features: &[f32],
        num_features: usize,
        predictions: &mut [f32],
    ) {
        let batch_preds = self.predict_batch(dataset, raw_features, num_features);
        for (i, p) in batch_preds.into_iter().enumerate() {
            predictions[i] += p;
        }
    }

    /// Rough parameter count: one per tree leaf, plus each leaf model's
    /// weights and bias (or 1 for a constant model).
    pub fn num_params(&self) -> usize {
        let tree_params = self.tree.as_ref().map(|t| t.num_leaves()).unwrap_or(0);
        let linear_params: usize = self
            .leaf_models
            .iter()
            .map(|(_, m)| if m.is_linear { m.weights.len() + 1 } else { 1 })
            .sum();
        tree_params + linear_params
    }

    /// Clears the fitted tree and leaf models; the configuration and feature
    /// statistics are kept as-is.
    pub fn reset(&mut self) {
        self.tree = None;
        self.leaf_models.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::{FeatureInfo, FeatureType};

    /// Builds a small binned dataset with deterministic bin values, laid out
    /// feature-major (all rows of feature 0 first, then feature 1, ...).
    fn create_test_dataset(num_rows: usize, num_features: usize) -> BinnedDataset {
        let features: Vec<u8> = (0..num_features)
            .flat_map(|f| (0..num_rows).map(move |r| ((r * 3 + f * 7) % 256) as u8))
            .collect();
        let targets: Vec<f32> = (0..num_rows).map(|i| (i as f32) * 0.1).collect();
        let feature_info = (0..num_features)
            .map(|i| FeatureInfo {
                name: format!("f{}", i),
                feature_type: FeatureType::Numeric,
                num_bins: 255,
                bin_boundaries: (0..255).map(f64::from).collect(),
            })
            .collect();
        BinnedDataset::new(num_rows, features, targets, feature_info)
    }

    /// Raw feature matrix matching `create_test_dataset`, but row-major as
    /// `fit_on_gradients` expects.
    fn create_raw_features(num_rows: usize, num_features: usize) -> Vec<f32> {
        (0..num_rows)
            .flat_map(|r| (0..num_features).map(move |f| ((r * 3 + f * 7) % 256) as f32))
            .collect()
    }

    /// Shared setup: fit a booster on `num_rows` synthetic rows (3 features)
    /// with gradients -0.1 * i and unit hessians.
    fn fitted_booster(
        num_rows: usize,
        min_samples: usize,
    ) -> (BinnedDataset, Vec<f32>, LinearTreeBooster) {
        let dataset = create_test_dataset(num_rows, 3);
        let raw = create_raw_features(num_rows, 3);
        let gradients: Vec<f32> = (0..num_rows).map(|i| -(i as f32) * 0.1).collect();
        let hessians = vec![1.0; num_rows];
        let config = LinearTreeConfig::default().with_min_samples_for_linear(min_samples);
        let mut booster = LinearTreeBooster::new(config);
        booster
            .fit_on_gradients(&dataset, &raw, 3, &gradients, &hessians)
            .unwrap();
        (dataset, raw, booster)
    }

    #[test]
    fn test_linear_tree_config_defaults() {
        let cfg = LinearTreeConfig::default();
        assert_eq!(cfg.tree_config.max_depth, 4);
        assert_eq!(cfg.min_samples_for_linear, 10);
    }

    #[test]
    fn test_linear_tree_config_builder() {
        let cfg = LinearTreeConfig::new()
            .with_tree_config(TreeConfig::default().with_max_depth(3))
            .with_min_samples_for_linear(20);
        assert_eq!(cfg.min_samples_for_linear, 20);
        assert_eq!(cfg.tree_config.max_depth, 3);
    }

    #[test]
    fn test_linear_tree_booster_creation() {
        let booster = LinearTreeBooster::new(LinearTreeConfig::default());
        assert!(!booster.is_fitted());
        assert!(booster.tree().is_none());
    }

    #[test]
    fn test_leaf_linear_model_constant() {
        let model = LeafLinearModel::constant(5.0);
        assert!(!model.is_linear);
        // Constant models ignore the feature vector entirely.
        assert_eq!(model.predict(&[1.0, 2.0, 3.0]), 5.0);
    }

    #[test]
    fn test_leaf_linear_model_linear() {
        let model = LeafLinearModel::linear(vec![1.0, 2.0], 0.5);
        assert!(model.is_linear);
        // 0.5 + 1*1 + 2*2 = 5.5
        assert!((model.predict(&[1.0, 2.0]) - 5.5).abs() < 1e-5);
    }

    #[test]
    fn test_linear_tree_booster_fit() {
        let (_dataset, _raw, booster) = fitted_booster(100, 5);
        assert!(booster.is_fitted());
        assert!(booster.tree().is_some());
        assert!(booster.num_leaf_models() > 0);
    }

    #[test]
    fn test_linear_tree_booster_predict() {
        let (dataset, raw, booster) = fitted_booster(100, 5);
        let predictions = booster.predict_batch(&dataset, &raw, 3);
        assert_eq!(predictions.len(), 100);
        assert!(predictions.iter().all(|p| p.is_finite()));
    }

    #[test]
    fn test_linear_tree_single_row_matches_batch() {
        let (dataset, raw, booster) = fitted_booster(50, 3);
        let batch_preds = booster.predict_batch(&dataset, &raw, 3);
        for (i, &batch_pred) in batch_preds.iter().enumerate() {
            let single_pred = booster.predict_row(&dataset, &raw, 3, i);
            assert!(
                (batch_pred - single_pred).abs() < 1e-5,
                "Mismatch at row {}: batch={}, single={}",
                i,
                batch_pred,
                single_pred
            );
        }
    }

    #[test]
    fn test_linear_tree_booster_reset() {
        // Uses constant gradients and the default config, so it stays inline
        // rather than going through fitted_booster.
        let dataset = create_test_dataset(50, 3);
        let raw = create_raw_features(50, 3);
        let gradients = vec![-1.0; 50];
        let hessians = vec![1.0; 50];
        let mut booster = LinearTreeBooster::new(LinearTreeConfig::default());
        booster
            .fit_on_gradients(&dataset, &raw, 3, &gradients, &hessians)
            .unwrap();
        assert!(booster.is_fitted());
        booster.reset();
        assert!(!booster.is_fitted());
    }
}