pub mod histogram;
pub mod learner;
pub mod split;
use serde::{Deserialize, Serialize};
pub use learner::TreeLearner;
/// Which child a split sends a row to when the split feature's value is missing.
///
/// Both prediction paths on `Tree` consult this: `predict_raw` for non-finite raw
/// values, and `predict_on_columns` for bins equal to `Bin::MISSING`.
/// The explicit `u8` discriminants are kept stable on purpose.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MissingDir {
    /// Missing values follow `left_child`.
    Left = 0,
    /// Missing values follow `right_child`.
    Right = 1,
}
/// One internal (split) node of a `Tree`.
///
/// Child references use a signed encoding: a non-negative value is an index into
/// `Tree::nodes`, and a negative value `c` denotes the leaf `Tree::leaf_values[!c]`
/// (so `-1` is leaf 0, `-2` is leaf 1, ...).
// NOTE(review): `repr(C)` presumably exists to pin the in-memory layout for some
// consumer outside this file; confirm before reordering or resizing fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct SplitNode {
    /// Feature index tested by this split (index into a raw row or into the column list).
    pub feature: u32,
    /// Threshold in bin space: non-missing bins `<=` this value go left.
    pub threshold_bin: u16,
    /// Child that receives rows whose feature value is missing.
    pub missing_dir: MissingDir,
    /// Left child reference (node index if `>= 0`, encoded leaf `!idx` if `< 0`).
    pub left_child: i32,
    /// Right child reference (node index if `>= 0`, encoded leaf `!idx` if `< 0`).
    pub right_child: i32,
}
/// A binary decision tree stored as parallel vectors.
///
/// Traversal starts at `nodes[0]`. A tree with no nodes is a single leaf whose value
/// is `leaf_values[0]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tree {
    /// Split nodes; the root is at index 0.
    pub nodes: Vec<SplitNode>,
    /// Raw-value threshold per node (parallel to `nodes`), used by `predict_raw`:
    /// finite values `<=` the threshold go left.
    pub node_thresholds: Vec<f64>,
    /// Per-node values parallel to `nodes`; not read by the prediction code here.
    // NOTE(review): presumably the split gain recorded at training time — confirm.
    pub node_gains: Vec<f64>,
    /// Leaf outputs, addressed by decoding a negative child reference `c` as `!c`.
    pub leaf_values: Vec<f64>,
}
impl Tree {
pub fn constant(value: f64) -> Self {
Self {
nodes: Vec::new(),
node_thresholds: Vec::new(),
node_gains: Vec::new(),
leaf_values: vec![value],
}
}
pub fn predict_raw(&self, row: &[f64]) -> f64 {
if self.nodes.is_empty() {
return self.leaf_values[0];
}
let mut node_idx: i32 = 0;
loop {
let i = node_idx as usize;
let node = &self.nodes[i];
let threshold = self.node_thresholds[i];
let v = row[node.feature as usize];
let go_left = if !v.is_finite() {
matches!(node.missing_dir, MissingDir::Left)
} else {
v <= threshold
};
let next = if go_left {
node.left_child
} else {
node.right_child
};
if next < 0 {
return self.leaf_values[(!next) as usize];
}
node_idx = next;
}
}
pub fn predict_on_dataset(&self, dataset: &crate::dataset::Dataset, row: usize) -> f64 {
use crate::dataset::with_columns;
let feats: Vec<usize> = (0..dataset.n_features()).collect();
with_columns!(dataset, feats, |cols| { self.predict_on_columns(&cols, row) })
}
#[inline]
pub fn predict_on_columns<B: crate::dataset::Bin>(
&self,
columns: &[&[B]],
row: usize,
) -> f64 {
if self.nodes.is_empty() {
return self.leaf_values[0];
}
let mut node_idx: i32 = 0;
unsafe {
loop {
let node = self.nodes.get_unchecked(node_idx as usize);
let col = *columns.get_unchecked(node.feature as usize);
let bin = *col.get_unchecked(row);
let go_left = if bin == B::MISSING {
matches!(node.missing_dir, MissingDir::Left)
} else {
bin.as_usize() <= node.threshold_bin as usize
};
let next = if go_left {
node.left_child
} else {
node.right_child
};
if next < 0 {
return *self.leaf_values.get_unchecked((!next) as usize);
}
node_idx = next;
}
}
}
}