use super::{TreeNode, LEAF_SENTINEL};
/// One node of a flattened decision tree, stored with `#[repr(C)]` layout.
///
/// For split nodes, `right` holds the index of the right child (the left
/// child is implicitly the next node in the array) and
/// `feature_idx`/`threshold` describe the split. When `right` equals the
/// leaf sentinel, the node is a leaf and `feature_idx` is reused as the
/// leaf's index into the tree's prediction arrays.
#[repr(C)]
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct FlatNode {
    /// Right-child index, or the leaf sentinel for leaves.
    pub right: u32,
    /// Split feature index; for leaves, the leaf index.
    pub feature_idx: u32,
    /// Split threshold (unused for leaves, where it is 0.0).
    pub threshold: f64,
}

impl FlatNode {
    /// Creates a node from its raw parts.
    pub fn new(right: u32, feature_idx: u32, threshold: f64) -> Self {
        FlatNode { right, feature_idx, threshold }
    }
}
/// A decision tree flattened into contiguous arrays for fast traversal.
///
/// Nodes are stored in pre-order: a split node's left child is the next
/// node in `nodes`, and `FlatNode::right` holds the right child's index.
/// Leaves are marked by `right == LEAF_SENTINEL` and use their
/// `feature_idx` field as an index into `predictions` / `leaf_probas`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[allow(clippy::unsafe_derive_deserialize)]
#[non_exhaustive]
pub struct FlatTree {
    /// Flattened nodes in pre-order; index 0 is the root.
    pub nodes: Vec<FlatNode>,
    /// One prediction per leaf, indexed by the leaf index.
    pub predictions: Vec<f64>,
    /// Per-leaf class probabilities, `n_classes_stored` entries per leaf,
    /// stored leaf-major (leaf `li`'s record starts at `li * n_classes_stored`).
    pub leaf_probas: Vec<f32>,
    /// Number of probability entries per leaf record in `leaf_probas`;
    /// may be 0 when no probabilities are stored.
    pub n_classes_stored: u32,
    /// Training-sample count per node; left empty by `new`, populated only
    /// by `from_tree_node`.
    pub node_counts: Vec<usize>,
}
impl FlatTree {
    /// Assembles a tree from pre-flattened parts.
    ///
    /// `node_counts` is left empty here; it is only populated by
    /// [`FlatTree::from_tree_node`].
    pub fn new(
        nodes: Vec<FlatNode>,
        predictions: Vec<f64>,
        leaf_probas: Vec<f32>,
        n_classes_stored: u32,
    ) -> Self {
        Self {
            nodes,
            predictions,
            leaf_probas,
            n_classes_stored,
            node_counts: Vec::new(),
        }
    }

    /// Flattens a pointer-based [`TreeNode`] into the contiguous pre-order
    /// layout used by the predict/apply methods.
    ///
    /// Layout invariants produced here:
    /// - a split's left child is always at `index + 1`, so only the
    ///   right-child index is stored;
    /// - leaves have `right == LEAF_SENTINEL` and carry a dense leaf index
    ///   (in `feature_idx`) into `predictions` / `leaf_probas`;
    /// - `leaf_probas` holds exactly `n_classes` entries per leaf, in leaf
    ///   order.
    pub fn from_tree_node(root: &TreeNode, n_classes: usize) -> Self {
        let mut nodes = Vec::new();
        let mut predictions = Vec::new();
        let mut leaf_probas: Vec<f32> = Vec::new();
        let mut leaf_count: u32 = 0;
        let mut node_counts = Vec::new();
        Self::flatten_dfs(
            root,
            &mut nodes,
            &mut predictions,
            &mut leaf_probas,
            &mut leaf_count,
            n_classes,
            &mut node_counts,
        );
        FlatTree {
            nodes,
            predictions,
            leaf_probas,
            n_classes_stored: n_classes as u32,
            node_counts,
        }
    }

    /// Pre-order (node, left subtree, right subtree) flattening.
    /// Recursion depth equals tree depth.
    fn flatten_dfs(
        node: &TreeNode,
        nodes: &mut Vec<FlatNode>,
        predictions: &mut Vec<f64>,
        leaf_probas: &mut Vec<f32>,
        leaf_count: &mut u32,
        n_classes: usize,
        node_counts: &mut Vec<usize>,
    ) {
        match node {
            TreeNode::Leaf {
                prediction,
                n_samples,
                class_counts,
                ..
            } => {
                let li = *leaf_count;
                *leaf_count += 1;
                nodes.push(FlatNode {
                    right: LEAF_SENTINEL,
                    // Leaves reuse `feature_idx` as the dense leaf index.
                    feature_idx: li,
                    threshold: 0.0,
                });
                node_counts.push(*n_samples);
                predictions.push(*prediction);
                Self::append_proba(leaf_probas, class_counts, *n_samples, n_classes);
            }
            TreeNode::Split {
                feature_idx,
                threshold,
                left,
                right,
                n_samples,
                ..
            } => {
                let my_idx = nodes.len();
                // `right` is patched below once the left subtree's size is
                // known.
                nodes.push(FlatNode {
                    right: 0,
                    feature_idx: *feature_idx as u32,
                    threshold: *threshold,
                });
                node_counts.push(*n_samples);
                Self::flatten_dfs(
                    left,
                    nodes,
                    predictions,
                    leaf_probas,
                    leaf_count,
                    n_classes,
                    node_counts,
                );
                nodes[my_idx].right = nodes.len() as u32;
                Self::flatten_dfs(
                    right,
                    nodes,
                    predictions,
                    leaf_probas,
                    leaf_count,
                    n_classes,
                    node_counts,
                );
            }
        }
    }

    /// Appends exactly `n_classes` probability entries for one leaf.
    ///
    /// Emitting a fixed-size record per leaf is essential:
    /// `predict_proba_sample` locates leaf `li`'s probabilities at offset
    /// `li * n_classes_stored`. The previous version skipped leaves with
    /// `n_samples == 0` entirely, which shifted every later leaf's offset
    /// and corrupted probability lookups; an empty leaf now contributes
    /// `n_classes` zero entries instead.
    fn append_proba(
        probas: &mut Vec<f32>,
        class_counts: &[usize],
        n_samples: usize,
        n_classes: usize,
    ) {
        if n_classes == 0 {
            return;
        }
        if n_samples == 0 {
            // Keep the leaf-major layout aligned: pad with zeros.
            probas.resize(probas.len() + n_classes, 0.0);
            return;
        }
        let total = n_samples as f64;
        probas.reserve(n_classes);
        for i in 0..n_classes {
            // Missing trailing counts are treated as zero.
            let count = class_counts.get(i).copied().unwrap_or(0);
            probas.push((count as f64 / total) as f32);
        }
    }

    /// Predicts the value for a single sample by walking from the root.
    ///
    /// Indexing is unchecked in release builds: the tree must be well-formed
    /// (every `right` index and leaf index in range) and `sample` must cover
    /// every feature index stored in the tree.
    #[inline(always)]
    #[allow(unsafe_code, clippy::inline_always)]
    pub fn predict_sample(&self, sample: &[f64]) -> f64 {
        let nodes = self.nodes.as_slice();
        let preds = self.predictions.as_slice();
        debug_assert!(!nodes.is_empty());
        let mut idx = 0usize;
        loop {
            debug_assert!(idx < nodes.len());
            // SAFETY: `idx` starts at 0 (nodes is non-empty) and is only
            // ever advanced to `idx + 1` or `node.right`, both in-bounds for
            // a well-formed flattened tree.
            let node = unsafe { nodes.get_unchecked(idx) };
            if node.right == LEAF_SENTINEL {
                let li = node.feature_idx as usize;
                // SAFETY: leaf indices are assigned densely during
                // flattening, one per entry pushed into `predictions`.
                return unsafe { *preds.get_unchecked(li) };
            }
            // SAFETY: the caller must supply a sample covering every
            // feature index stored in the tree.
            let feat_val = unsafe { *sample.get_unchecked(node.feature_idx as usize) };
            idx = if feat_val <= node.threshold {
                idx + 1
            } else {
                node.right as usize
            };
        }
    }

    /// Returns per-class probabilities for a single sample.
    ///
    /// The result has `n_classes` entries; if fewer classes were stored at
    /// build time (`n_classes_stored < n_classes`), the tail stays zero.
    #[inline(always)]
    #[allow(unsafe_code, clippy::inline_always)]
    pub fn predict_proba_sample(&self, sample: &[f64], n_classes: usize) -> Vec<f64> {
        let nodes = self.nodes.as_slice();
        let nc = self.n_classes_stored as usize;
        debug_assert!(!nodes.is_empty());
        let mut idx = 0usize;
        loop {
            debug_assert!(idx < nodes.len());
            // SAFETY: same traversal invariant as `predict_sample` — `idx`
            // is always a valid node index for a well-formed tree.
            let node = unsafe { nodes.get_unchecked(idx) };
            if node.right == LEAF_SENTINEL {
                let li = node.feature_idx as usize;
                // Leaf `li`'s fixed-size record lives at `li * nc..`.
                let start = li * nc;
                let mut result = vec![0.0; n_classes];
                let copy_len = n_classes.min(nc);
                for (i, p) in self.leaf_probas[start..start + copy_len].iter().enumerate() {
                    result[i] = f64::from(*p);
                }
                return result;
            }
            // SAFETY: the caller must supply a sample covering every
            // feature index stored in the tree.
            let feat_val = unsafe { *sample.get_unchecked(node.feature_idx as usize) };
            idx = if feat_val <= node.threshold {
                idx + 1
            } else {
                node.right as usize
            };
        }
    }

    /// Total number of nodes (splits and leaves).
    #[inline]
    pub fn n_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the prediction stored at leaf `node_idx`.
    ///
    /// Debug builds assert that the node is a leaf; in release, a non-leaf
    /// index silently reads `predictions` at the node's feature index.
    #[inline]
    pub fn leaf_prediction(&self, node_idx: usize) -> f64 {
        let node = &self.nodes[node_idx];
        debug_assert!(
            node.right == LEAF_SENTINEL,
            "node_idx {node_idx} is not a leaf"
        );
        self.predictions[node.feature_idx as usize]
    }

    /// Overwrites the prediction stored at leaf `node_idx`.
    #[inline]
    pub fn set_leaf_prediction(&mut self, node_idx: usize, value: f64) {
        let node = &self.nodes[node_idx];
        debug_assert!(
            node.right == LEAF_SENTINEL,
            "node_idx {node_idx} is not a leaf"
        );
        self.predictions[node.feature_idx as usize] = value;
    }

    /// Whether the node at `node_idx` is a leaf.
    #[inline]
    pub fn is_leaf(&self, node_idx: usize) -> bool {
        self.nodes[node_idx].right == LEAF_SENTINEL
    }

    /// Predicts one value per row of `features`.
    pub fn predict(&self, features: &[Vec<f64>]) -> Vec<f64> {
        features
            .iter()
            .map(|row| self.predict_sample(row))
            .collect()
    }

    /// Returns the node index of the leaf a single sample lands in.
    ///
    /// Same well-formedness requirements as `predict_sample`.
    #[inline(always)]
    #[allow(unsafe_code, clippy::inline_always)]
    pub(crate) fn apply_sample(&self, sample: &[f64]) -> usize {
        let nodes = self.nodes.as_slice();
        debug_assert!(!nodes.is_empty());
        let mut idx = 0usize;
        loop {
            debug_assert!(idx < nodes.len());
            // SAFETY: same traversal invariant as `predict_sample`.
            let node = unsafe { nodes.get_unchecked(idx) };
            if node.right == LEAF_SENTINEL {
                return idx;
            }
            // SAFETY: the caller must supply a sample covering every
            // feature index stored in the tree.
            let feat_val = unsafe { *sample.get_unchecked(node.feature_idx as usize) };
            idx = if feat_val <= node.threshold {
                idx + 1
            } else {
                node.right as usize
            };
        }
    }

    /// Returns the landing leaf-node index for every row of `features`.
    pub(crate) fn apply(&self, features: &[Vec<f64>]) -> Vec<usize> {
        features.iter().map(|row| self.apply_sample(row)).collect()
    }

    /// Depth counted in nodes along the longest root-to-leaf path
    /// (a lone leaf has depth 1; an empty tree has depth 0).
    pub fn depth(&self) -> usize {
        if self.nodes.is_empty() {
            return 0;
        }
        Self::depth_at(&self.nodes, 0, 1)
    }

    /// Recursive helper: depth of the subtree rooted at `idx`, where `d` is
    /// the depth of `idx` itself. Out-of-range indices contribute 0.
    /// Recursion depth equals tree depth.
    fn depth_at(nodes: &[FlatNode], idx: usize, d: usize) -> usize {
        if idx >= nodes.len() {
            return 0;
        }
        let node = &nodes[idx];
        if node.right == LEAF_SENTINEL {
            d
        } else {
            let l = Self::depth_at(nodes, idx + 1, d + 1);
            let r = Self::depth_at(nodes, node.right as usize, d + 1);
            l.max(r)
        }
    }

    /// Number of leaf nodes.
    pub fn n_leaves(&self) -> usize {
        self.nodes
            .iter()
            .filter(|n| n.right == LEAF_SENTINEL)
            .count()
    }
}
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Tree with one leaf: every sample predicts 42.0; probas are [0.3, 0.7].
    fn single_leaf_tree() -> FlatTree {
        let nodes = vec![FlatNode::new(LEAF_SENTINEL, 0, 0.0)];
        let predictions = vec![42.0];
        let leaf_probas = vec![0.3f32, 0.7];
        FlatTree::new(nodes, predictions, leaf_probas, 2)
    }

    /// Depth-3 tree in pre-order layout:
    ///   0: split f0 <= 0.5 (right -> 4)
    ///   1:   split f1 <= 0.3 (right -> 3)
    ///   2:     leaf 0 (pred 1.0)
    ///   3:     leaf 1 (pred 2.0)
    ///   4:   split f1 <= 0.7 (right -> 6)
    ///   5:     leaf 2 (pred 3.0)
    ///   6:     leaf 3 (pred 4.0)
    fn balanced_tree() -> FlatTree {
        let nodes = vec![
            FlatNode::new(4, 0, 0.5),
            FlatNode::new(3, 1, 0.3),
            FlatNode::new(LEAF_SENTINEL, 0, 0.0),
            FlatNode::new(LEAF_SENTINEL, 1, 0.0),
            FlatNode::new(6, 1, 0.7),
            FlatNode::new(LEAF_SENTINEL, 2, 0.0),
            FlatNode::new(LEAF_SENTINEL, 3, 0.0),
        ];
        let predictions = vec![1.0, 2.0, 3.0, 4.0];
        // Two probability entries per leaf, stored leaf-major.
        let leaf_probas = vec![
            1.0, 0.0, 0.0, 1.0, 0.6, 0.4, 0.2, 0.8,
        ];
        FlatTree::new(nodes, predictions, leaf_probas, 2)
    }

    /// Degenerate tree whose split chain continues through the left child;
    /// each split's right child is a leaf. `depth` is the number of splits
    /// in the chain.
    fn all_left_tree(depth: usize) -> FlatTree {
        let mut nodes = Vec::new();
        let mut predictions = Vec::new();
        let mut leaf_count = 0u32;
        // Recursive builder: push the split, emit the left subtree at
        // `my_idx + 1`, then patch the split's right index and append the
        // right-hand leaf. Statement order matters for the flat layout.
        fn build(
            nodes: &mut Vec<FlatNode>,
            predictions: &mut Vec<f64>,
            leaf_count: &mut u32,
            depth: usize,
            max_depth: usize,
        ) {
            if depth >= max_depth {
                let li = *leaf_count;
                *leaf_count += 1;
                nodes.push(FlatNode::new(LEAF_SENTINEL, li, 0.0));
                predictions.push(li as f64);
                return;
            }
            let my_idx = nodes.len();
            // `right` is patched after the left subtree is emitted.
            nodes.push(FlatNode::new(0, 0, 0.5));
            build(nodes, predictions, leaf_count, depth + 1, max_depth);
            nodes[my_idx].right = nodes.len() as u32;
            let li = *leaf_count;
            *leaf_count += 1;
            nodes.push(FlatNode::new(LEAF_SENTINEL, li, 0.0));
            predictions.push(li as f64);
        }
        build(&mut nodes, &mut predictions, &mut leaf_count, 0, depth);
        FlatTree::new(nodes, predictions, vec![], 0)
    }

    /// Mirror of `all_left_tree`: the split chain continues through the
    /// right child, and each split's left child is a leaf.
    fn all_right_tree(depth: usize) -> FlatTree {
        let mut nodes = Vec::new();
        let mut predictions = Vec::new();
        let mut leaf_count = 0u32;
        // Recursive builder: push the split, append the left-hand leaf at
        // `my_idx + 1`, then patch the right index and recurse rightwards.
        fn build(
            nodes: &mut Vec<FlatNode>,
            predictions: &mut Vec<f64>,
            leaf_count: &mut u32,
            depth: usize,
            max_depth: usize,
        ) {
            if depth >= max_depth {
                let li = *leaf_count;
                *leaf_count += 1;
                nodes.push(FlatNode::new(LEAF_SENTINEL, li, 0.0));
                predictions.push(li as f64);
                return;
            }
            let my_idx = nodes.len();
            // `right` is patched once the left-hand leaf is in place.
            nodes.push(FlatNode::new(0, 0, 0.5));
            let li = *leaf_count;
            *leaf_count += 1;
            nodes.push(FlatNode::new(LEAF_SENTINEL, li, 0.0));
            predictions.push(li as f64);
            nodes[my_idx].right = nodes.len() as u32;
            build(nodes, predictions, leaf_count, depth + 1, max_depth);
        }
        build(&mut nodes, &mut predictions, &mut leaf_count, 0, depth);
        FlatTree::new(nodes, predictions, vec![], 0)
    }

    #[test]
    fn test_flat_tree_predict_boundaries() {
        // Single leaf: every sample routes to node 0.
        let tree = single_leaf_tree();
        assert_eq!(tree.predict_sample(&[0.0, 1.0]), 42.0);
        assert_eq!(tree.predict_sample(&[99.0]), 42.0);
        let proba = tree.predict_proba_sample(&[0.0], 2);
        assert!((proba[0] - 0.3).abs() < 1e-5);
        assert!((proba[1] - 0.7).abs() < 1e-5);
        assert_eq!(tree.apply_sample(&[0.0]), 0);
        // Balanced tree: hit all four leaves via both split outcomes.
        let tree = balanced_tree();
        assert_eq!(tree.predict_sample(&[0.0, 0.0]), 1.0);
        assert_eq!(tree.predict_sample(&[0.0, 0.5]), 2.0);
        assert_eq!(tree.predict_sample(&[0.8, 0.5]), 3.0);
        assert_eq!(tree.predict_sample(&[0.8, 0.9]), 4.0);
        let proba = tree.predict_proba_sample(&[0.8, 0.9], 2);
        assert!((proba[0] - 0.2).abs() < 1e-5);
        assert!((proba[1] - 0.8).abs() < 1e-5);
        assert_eq!(tree.apply_sample(&[0.0, 0.0]), 2);
        assert_eq!(tree.apply_sample(&[0.8, 0.9]), 6);
        // Left-leaning chain of 5 splits plus a leaf -> depth 6.
        let tree = all_left_tree(5);
        assert_eq!(tree.depth(), 6);
        let _ = tree.predict_sample(&[0.0]);
        let _ = tree.predict_sample(&[1.0]);
        let preds = tree.predict(&[vec![0.0], vec![1.0], vec![0.5]]);
        assert_eq!(preds.len(), 3);
        // Right-leaning chain of 5 splits plus a leaf -> depth 6.
        let tree = all_right_tree(5);
        assert_eq!(tree.depth(), 6);
        let _ = tree.predict_sample(&[0.0]);
        let _ = tree.predict_sample(&[1.0]);
        let preds = tree.predict(&[vec![0.0], vec![1.0]]);
        assert_eq!(preds.len(), 2);
    }
}