use alloc::vec::Vec;
/// Handle to a node stored in a [`TreeArena`]: the wrapped `u32` is the
/// node's position in each of the arena's parallel vectors.
///
/// The value `u32::MAX` is reserved as the "no node" sentinel, see
/// [`NodeId::NONE`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);

impl NodeId {
    /// Sentinel meaning "no node"; wraps `u32::MAX`.
    pub const NONE: Self = Self(u32::MAX);

    /// Returns `true` when this id is the [`NodeId::NONE`] sentinel.
    #[inline]
    pub fn is_none(self) -> bool {
        self == Self::NONE
    }

    /// Returns the wrapped value as a `usize`, ready for indexing.
    ///
    /// No sentinel check is performed here; `NONE` simply converts to
    /// `u32::MAX as usize`.
    #[inline]
    pub fn idx(self) -> usize {
        let Self(raw) = self;
        raw as usize
    }
}
/// Arena storage for a binary tree, laid out as parallel vectors.
///
/// A node occupies the same index in every vector, and a [`NodeId`] is that
/// index. The methods below always push to (or clear) all vectors together,
/// so the vectors stay the same length as long as the public fields are not
/// resized by hand.
#[derive(Debug, Clone)]
pub struct TreeArena {
    /// Feature index recorded when the node was split; placeholder `0` for
    /// nodes created by `add_leaf` that have not been split.
    pub feature_idx: Vec<u32>,
    /// Threshold recorded when the node was split; placeholder `0.0` for
    /// unsplit nodes. How the threshold is compared is up to the caller.
    pub threshold: Vec<f64>,
    /// Left child id, or [`NodeId::NONE`] while the node is a leaf.
    pub left: Vec<NodeId>,
    /// Right child id, or [`NodeId::NONE`] while the node is a leaf.
    pub right: Vec<NodeId>,
    /// Value returned by `predict` for leaves. Splitting a node does not
    /// overwrite its entry, but `predict` refuses to read it afterwards.
    pub leaf_value: Vec<f64>,
    /// `true` until the node is split.
    pub is_leaf: Vec<bool>,
    /// Depth given to `add_leaf`; children created by a split get the
    /// parent's depth plus one.
    pub depth: Vec<u16>,
    /// Counter starting at `0` and bumped by `increment_sample_count`.
    /// Splitting a node leaves its own counter untouched.
    pub sample_count: Vec<u64>,
    /// `Some(mask)` for nodes split through `split_leaf_categorical`,
    /// `None` otherwise. The meaning of the bits is not defined here.
    pub categorical_mask: Vec<Option<u64>>,
}

impl TreeArena {
    /// Creates an empty arena without allocating.
    pub fn new() -> Self {
        Self {
            feature_idx: Vec::new(),
            threshold: Vec::new(),
            left: Vec::new(),
            right: Vec::new(),
            leaf_value: Vec::new(),
            is_leaf: Vec::new(),
            depth: Vec::new(),
            sample_count: Vec::new(),
            categorical_mask: Vec::new(),
        }
    }

    /// Creates an empty arena with room for at least `cap` nodes in every
    /// vector.
    pub fn with_capacity(cap: usize) -> Self {
        Self {
            feature_idx: Vec::with_capacity(cap),
            threshold: Vec::with_capacity(cap),
            left: Vec::with_capacity(cap),
            right: Vec::with_capacity(cap),
            leaf_value: Vec::with_capacity(cap),
            is_leaf: Vec::with_capacity(cap),
            depth: Vec::with_capacity(cap),
            sample_count: Vec::with_capacity(cap),
            categorical_mask: Vec::with_capacity(cap),
        }
    }

    /// Appends a new leaf at the given `depth` and returns its id.
    ///
    /// The leaf starts with value `0.0`, sample count `0`, no children and
    /// no categorical mask.
    ///
    /// # Panics
    ///
    /// Panics if the new node's index would not fit below `u32::MAX`, which
    /// is reserved for [`NodeId::NONE`].
    pub fn add_leaf(&mut self, depth: u16) -> NodeId {
        let next = self.feature_idx.len();
        // Checked before any push so a failure cannot leave the vectors with
        // different lengths. A bare `as u32` cast would silently truncate or
        // hand out the sentinel value as a real node id.
        assert!(
            next < NodeId::NONE.idx(),
            "TreeArena is full: cannot allocate node index {}",
            next
        );
        let id = NodeId(next as u32);
        self.feature_idx.push(0);
        self.threshold.push(0.0);
        self.left.push(NodeId::NONE);
        self.right.push(NodeId::NONE);
        self.leaf_value.push(0.0);
        self.is_leaf.push(true);
        self.depth.push(depth);
        self.sample_count.push(0);
        self.categorical_mask.push(None);
        id
    }

    /// Turns the leaf `leaf_id` into an internal node with two fresh leaf
    /// children and returns `(left, right)`.
    ///
    /// The children are appended left first, then right, one level deeper
    /// than the parent, and carry `left_value` / `right_value`. The parent
    /// records `feature_idx` and `threshold`.
    ///
    /// # Panics
    ///
    /// Panics if `leaf_id` is out of range, if the node is not a leaf, or if
    /// the children's depth would exceed `u16::MAX`.
    pub fn split_leaf(
        &mut self,
        leaf_id: NodeId,
        feature_idx: u32,
        threshold: f64,
        left_value: f64,
        right_value: f64,
    ) -> (NodeId, NodeId) {
        let i = leaf_id.idx();
        assert!(
            self.is_leaf[i],
            "split_leaf called on non-leaf node {:?}",
            leaf_id
        );
        // Plain `+ 1` would wrap to depth 0 in builds without overflow
        // checks; fail loudly instead, before anything is mutated.
        let child_depth = self.depth[i]
            .checked_add(1)
            .expect("split_leaf would exceed the maximum tree depth (u16::MAX)");
        let left_id = self.add_leaf(child_depth);
        self.leaf_value[left_id.idx()] = left_value;
        let right_id = self.add_leaf(child_depth);
        self.leaf_value[right_id.idx()] = right_value;
        self.is_leaf[i] = false;
        self.feature_idx[i] = feature_idx;
        self.threshold[i] = threshold;
        self.left[i] = left_id;
        self.right[i] = right_id;
        (left_id, right_id)
    }

    /// Same as [`TreeArena::split_leaf`], and additionally stores
    /// `Some(mask)` as the categorical mask of the split node.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`TreeArena::split_leaf`].
    pub fn split_leaf_categorical(
        &mut self,
        leaf_id: NodeId,
        feature_idx: u32,
        threshold: f64,
        left_value: f64,
        right_value: f64,
        mask: u64,
    ) -> (NodeId, NodeId) {
        let (left_id, right_id) =
            self.split_leaf(leaf_id, feature_idx, threshold, left_value, right_value);
        self.categorical_mask[leaf_id.idx()] = Some(mask);
        (left_id, right_id)
    }

    /// Returns the categorical mask stored for `id`, if any.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range.
    #[inline]
    pub fn get_categorical_mask(&self, id: NodeId) -> Option<u64> {
        self.categorical_mask[id.idx()]
    }

    /// Returns whether `id` is currently a leaf.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range.
    #[inline]
    pub fn is_leaf(&self, id: NodeId) -> bool {
        self.is_leaf[id.idx()]
    }

    /// Returns the value stored in leaf `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range or refers to an internal node.
    #[inline]
    pub fn predict(&self, id: NodeId) -> f64 {
        let i = id.idx();
        assert!(self.is_leaf[i], "predict called on internal node {:?}", id);
        self.leaf_value[i]
    }

    /// Overwrites the value stored in leaf `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range or refers to an internal node.
    #[inline]
    pub fn set_leaf_value(&mut self, id: NodeId, value: f64) {
        let i = id.idx();
        assert!(
            self.is_leaf[i],
            "set_leaf_value called on internal node {:?}",
            id
        );
        self.leaf_value[i] = value;
    }

    /// Returns the depth recorded for `id`. Panics if `id` is out of range.
    #[inline]
    pub fn get_depth(&self, id: NodeId) -> u16 {
        self.depth[id.idx()]
    }

    /// Returns the feature index recorded for `id` (`0` for a node that was
    /// never split). Panics if `id` is out of range.
    #[inline]
    pub fn get_feature_idx(&self, id: NodeId) -> u32 {
        self.feature_idx[id.idx()]
    }

    /// Returns the threshold recorded for `id` (`0.0` for a node that was
    /// never split). Panics if `id` is out of range.
    #[inline]
    pub fn get_threshold(&self, id: NodeId) -> f64 {
        self.threshold[id.idx()]
    }

    /// Returns the left child of `id`, or [`NodeId::NONE`] for a leaf.
    /// Panics if `id` is out of range.
    #[inline]
    pub fn get_left(&self, id: NodeId) -> NodeId {
        self.left[id.idx()]
    }

    /// Returns the right child of `id`, or [`NodeId::NONE`] for a leaf.
    /// Panics if `id` is out of range.
    #[inline]
    pub fn get_right(&self, id: NodeId) -> NodeId {
        self.right[id.idx()]
    }

    /// Returns the sample counter of `id`. Panics if `id` is out of range.
    #[inline]
    pub fn get_sample_count(&self, id: NodeId) -> u64 {
        self.sample_count[id.idx()]
    }

    /// Adds one to the sample counter of `id`. Panics if `id` is out of
    /// range.
    #[inline]
    pub fn increment_sample_count(&mut self, id: NodeId) {
        self.sample_count[id.idx()] += 1;
    }

    /// Returns the total number of nodes (leaves and internal nodes).
    #[inline]
    pub fn n_nodes(&self) -> usize {
        self.is_leaf.len()
    }

    /// Returns the number of leaves by scanning the `is_leaf` flags.
    pub fn n_leaves(&self) -> usize {
        self.is_leaf.iter().filter(|&&b| b).count()
    }

    /// Removes every node while keeping the vectors' allocated capacity, so
    /// the arena can be refilled without reallocating. Ids handed out
    /// before the reset must not be used afterwards.
    pub fn reset(&mut self) {
        self.feature_idx.clear();
        self.threshold.clear();
        self.left.clear();
        self.right.clear();
        self.leaf_value.clear();
        self.is_leaf.clear();
        self.depth.clear();
        self.sample_count.clear();
        self.categorical_mask.clear();
    }
}
impl Default for TreeArena {
fn default() -> Self {
Self::new()
}
}
/// Unit tests for `NodeId` and `TreeArena`.
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly added leaf has blank defaults and no children.
    #[test]
    fn single_leaf() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        assert_eq!(root, NodeId(0));
        assert!(arena.is_leaf(root));
        assert_eq!(arena.predict(root), 0.0);
        assert_eq!(arena.get_depth(root), 0);
        assert_eq!(arena.get_sample_count(root), 0);
        assert_eq!(arena.get_left(root), NodeId::NONE);
        assert_eq!(arena.get_right(root), NodeId::NONE);
    }

    /// Splitting records the split on the parent and creates two leaves.
    #[test]
    fn split_leaf_basic() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let (left, right) = arena.split_leaf(root, 3, 1.5, -0.25, 0.75);
        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_feature_idx(root), 3);
        assert_eq!(arena.get_threshold(root), 1.5);
        assert_eq!(arena.get_left(root), left);
        assert_eq!(arena.get_right(root), right);
        assert!(arena.is_leaf(left));
        assert_eq!(arena.predict(left), -0.25);
        assert_eq!(arena.get_depth(left), 1);
        assert!(arena.is_leaf(right));
        assert_eq!(arena.predict(right), 0.75);
        assert_eq!(arena.get_depth(right), 1);
    }

    /// Splitting a child yields grandchildren at depth 2.
    #[test]
    fn split_child_three_levels() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let (left, right) = arena.split_leaf(root, 0, 5.0, 0.0, 0.0);
        let (ll, lr) = arena.split_leaf(left, 1, 2.0, -1.0, 1.0);
        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_depth(root), 0);
        assert!(!arena.is_leaf(left));
        assert_eq!(arena.get_depth(left), 1);
        assert_eq!(arena.get_feature_idx(left), 1);
        assert_eq!(arena.get_threshold(left), 2.0);
        assert_eq!(arena.get_left(left), ll);
        assert_eq!(arena.get_right(left), lr);
        assert!(arena.is_leaf(right));
        assert_eq!(arena.get_depth(right), 1);
        assert!(arena.is_leaf(ll));
        assert_eq!(arena.get_depth(ll), 2);
        assert_eq!(arena.predict(ll), -1.0);
        assert!(arena.is_leaf(lr));
        assert_eq!(arena.get_depth(lr), 2);
        assert_eq!(arena.predict(lr), 1.0);
    }

    /// Each split adds two nodes and one net leaf.
    #[test]
    fn node_and_leaf_counting() {
        let mut arena = TreeArena::new();
        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);
        let root = arena.add_leaf(0);
        assert_eq!(arena.n_nodes(), 1);
        assert_eq!(arena.n_leaves(), 1);
        let (_left, right) = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.n_nodes(), 3);
        assert_eq!(arena.n_leaves(), 2);
        let _ = arena.split_leaf(right, 1, 2.0, 0.0, 0.0);
        assert_eq!(arena.n_nodes(), 5);
        assert_eq!(arena.n_leaves(), 3);
    }

    /// `NodeId::NONE` wraps `u32::MAX` and differs from real ids.
    #[test]
    fn node_id_none_sentinel() {
        let none = NodeId::NONE;
        assert!(none.is_none());
        assert_eq!(none.0, u32::MAX);
        let valid = NodeId(0);
        assert!(!valid.is_none());
        assert_ne!(valid, NodeId::NONE);
    }

    /// `reset` empties the arena but keeps capacity; ids restart at 0.
    #[test]
    fn reset_clears_everything() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.n_nodes(), 3);
        arena.reset();
        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);
        assert!(arena.feature_idx.capacity() >= 3);
        assert!(arena.is_leaf.capacity() >= 3);
        let new_root = arena.add_leaf(0);
        assert_eq!(new_root, NodeId(0));
        assert_eq!(arena.n_nodes(), 1);
        assert_eq!(arena.n_leaves(), 1);
    }

    /// Counters are per node and survive a split of that node.
    #[test]
    fn sample_count_tracking() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        assert_eq!(arena.get_sample_count(root), 0);
        arena.increment_sample_count(root);
        assert_eq!(arena.get_sample_count(root), 1);
        arena.increment_sample_count(root);
        arena.increment_sample_count(root);
        assert_eq!(arena.get_sample_count(root), 3);
        let (left, right) = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.get_sample_count(left), 0);
        assert_eq!(arena.get_sample_count(right), 0);
        assert_eq!(arena.get_sample_count(root), 3);
        arena.increment_sample_count(left);
        assert_eq!(arena.get_sample_count(left), 1);
        assert_eq!(arena.get_sample_count(right), 0);
    }

    /// Every parallel vector, including the mask vector, is preallocated.
    #[test]
    fn with_capacity_preallocates() {
        let arena = TreeArena::with_capacity(64);
        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);
        assert!(arena.feature_idx.capacity() >= 64);
        assert!(arena.threshold.capacity() >= 64);
        assert!(arena.left.capacity() >= 64);
        assert!(arena.right.capacity() >= 64);
        assert!(arena.leaf_value.capacity() >= 64);
        assert!(arena.is_leaf.capacity() >= 64);
        assert!(arena.depth.capacity() >= 64);
        assert!(arena.sample_count.capacity() >= 64);
        assert!(arena.categorical_mask.capacity() >= 64);
    }

    /// `set_leaf_value` overwrites what `predict` returns.
    #[test]
    fn set_leaf_value_updates() {
        let mut arena = TreeArena::new();
        let leaf = arena.add_leaf(0);
        assert_eq!(arena.predict(leaf), 0.0);
        arena.set_leaf_value(leaf, 42.5);
        assert_eq!(arena.predict(leaf), 42.5);
        arena.set_leaf_value(leaf, -3.25);
        assert_eq!(arena.predict(leaf), -3.25);
    }

    #[test]
    #[should_panic(expected = "predict called on internal node")]
    fn predict_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        let _ = arena.predict(root);
    }

    #[test]
    #[should_panic(expected = "set_leaf_value called on internal node")]
    fn set_leaf_value_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        arena.set_leaf_value(root, 1.0);
    }

    #[test]
    #[should_panic(expected = "split_leaf called on non-leaf node")]
    fn split_leaf_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        let _ = arena.split_leaf(root, 1, 2.0, 0.0, 0.0);
    }

    /// `Default` and `new` both produce an empty arena.
    #[test]
    fn default_matches_new() {
        let a = TreeArena::new();
        let b = TreeArena::default();
        assert_eq!(a.n_nodes(), b.n_nodes());
        assert_eq!(a.n_leaves(), b.n_leaves());
    }

    /// A categorical split behaves like a plain split and also stores the
    /// mask on the parent only.
    #[test]
    fn categorical_split_records_mask() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        assert_eq!(arena.get_categorical_mask(root), None);
        let (left, right) = arena.split_leaf_categorical(root, 2, 0.0, -1.0, 1.0, 0b1010);
        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_feature_idx(root), 2);
        assert_eq!(arena.get_threshold(root), 0.0);
        assert_eq!(arena.get_left(root), left);
        assert_eq!(arena.get_right(root), right);
        assert_eq!(arena.get_categorical_mask(root), Some(0b1010));
        assert_eq!(arena.get_categorical_mask(left), None);
        assert_eq!(arena.get_categorical_mask(right), None);
        assert_eq!(arena.predict(left), -1.0);
        assert_eq!(arena.predict(right), 1.0);
        assert_eq!(arena.get_depth(left), 1);
        assert_eq!(arena.get_depth(right), 1);
    }

    /// A plain split never sets a categorical mask.
    #[test]
    fn numeric_split_leaves_mask_unset() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let (left, right) = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.get_categorical_mask(root), None);
        assert_eq!(arena.get_categorical_mask(left), None);
        assert_eq!(arena.get_categorical_mask(right), None);
    }

    /// Masks from before a `reset` do not leak into reused slots.
    #[test]
    fn reset_clears_categorical_masks() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf_categorical(root, 0, 0.0, 0.0, 0.0, u64::MAX);
        assert_eq!(arena.get_categorical_mask(root), Some(u64::MAX));
        arena.reset();
        assert!(arena.categorical_mask.is_empty());
        let new_root = arena.add_leaf(0);
        assert_eq!(arena.get_categorical_mask(new_root), None);
    }

    #[test]
    #[should_panic(expected = "split_leaf called on non-leaf node")]
    fn split_leaf_categorical_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        let _ = arena.split_leaf_categorical(root, 1, 2.0, 0.0, 0.0, 1);
    }
}