use rkyv::{Archive, Deserialize, Serialize};
/// Kind-specific payload of a [`Node`]: either a split (`Internal`) or a terminal
/// prediction (`Leaf`).
///
/// Variant and field order are significant to the derived serialization formats
/// (rkyv's archived layout and serde's index-based encodings); reordering them
/// changes the encoded form and breaks previously serialized models.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Archive,
Serialize,
Deserialize,
serde::Serialize,
serde::Deserialize,
)]
pub enum NodeType {
/// A split node that routes samples to one of two children.
Internal {
/// Index of the feature this node splits on.
feature_idx: usize,
/// Split threshold in binned (`u8`) feature space.
// NOTE(review): presumably samples with bin <= threshold go left — confirm against the traversal code.
bin_threshold: u8,
/// Split threshold in raw feature-value units.
// NOTE(review): presumably the raw value corresponding to `bin_threshold` — confirm.
split_value: f64,
/// Left child reference (presumably an index into the owning tree's node storage — TODO confirm).
left_child: usize,
/// Right child reference (presumably an index into the owning tree's node storage — TODO confirm).
right_child: usize,
},
/// A terminal node.
Leaf {
/// Prediction stored in the leaf (returned by `Node::leaf_value`).
value: f32,
},
}
/// A single tree node: its split/leaf payload plus the statistics supplied when it
/// was constructed.
///
/// Field order is significant to the derived rkyv/serde encodings; do not reorder.
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct Node {
/// Split or leaf payload.
pub node_type: NodeType,
/// Depth of the node, as supplied by the caller (presumably 0 at the root — confirm).
pub depth: usize,
/// Sample count recorded for this node (presumably the training samples routed here).
pub num_samples: usize,
/// Gradient sum recorded for this node (passed in as `sum_g` by the constructors).
pub sum_gradients: f32,
/// Hessian sum recorded for this node (passed in as `sum_h` by the constructors).
pub sum_hessians: f32,
}
impl Node {
    /// Creates a leaf node that predicts `value`.
    ///
    /// `depth` and `num_samples` are stored verbatim; `sum_g` and `sum_h` are stored
    /// in `sum_gradients` and `sum_hessians` respectively.
    pub fn leaf(value: f32, depth: usize, num_samples: usize, sum_g: f32, sum_h: f32) -> Self {
        Self {
            node_type: NodeType::Leaf { value },
            depth,
            num_samples,
            sum_gradients: sum_g,
            sum_hessians: sum_h,
        }
    }

    /// Creates an internal (split) node.
    ///
    /// The split parameters are stored in a [`NodeType::Internal`]. `left_child` and
    /// `right_child` are stored as plain `usize` values (presumably indices into the
    /// owning tree's node storage — confirm against the tree builder). The remaining
    /// arguments are stored exactly as in [`Node::leaf`].
    // The flat argument list mirrors the fields of `NodeType::Internal` plus the shared
    // node statistics; grouping them into a struct would only add ceremony at call sites.
    #[allow(clippy::too_many_arguments)]
    pub fn internal(
        feature_idx: usize,
        bin_threshold: u8,
        split_value: f64,
        left_child: usize,
        right_child: usize,
        depth: usize,
        num_samples: usize,
        sum_g: f32,
        sum_h: f32,
    ) -> Self {
        Self {
            node_type: NodeType::Internal {
                feature_idx,
                bin_threshold,
                split_value,
                left_child,
                right_child,
            },
            depth,
            num_samples,
            sum_gradients: sum_g,
            sum_hessians: sum_h,
        }
    }

    /// Returns `true` if this node is a [`NodeType::Leaf`].
    #[inline]
    pub fn is_leaf(&self) -> bool {
        matches!(self.node_type, NodeType::Leaf { .. })
    }

    /// Returns the leaf's value, or `None` for an internal node.
    #[inline]
    pub fn leaf_value(&self) -> Option<f32> {
        match self.node_type {
            NodeType::Leaf { value } => Some(value),
            NodeType::Internal { .. } => None,
        }
    }

    /// Returns `(feature_idx, bin_threshold, split_value, left_child, right_child)` for
    /// an internal node, or `None` for a leaf.
    #[inline]
    pub fn split_info(&self) -> Option<(usize, u8, f64, usize, usize)> {
        match self.node_type {
            NodeType::Internal {
                feature_idx,
                bin_threshold,
                split_value,
                left_child,
                right_child,
            } => Some((
                feature_idx,
                bin_threshold,
                split_value,
                left_child,
                right_child,
            )),
            NodeType::Leaf { .. } => None,
        }
    }

    /// Computes the leaf weight `-sum_g / (sum_h + lambda)`.
    ///
    /// `lambda` is added to the hessian sum in the denominator, which shrinks the
    /// weight's magnitude (the L2 term in XGBoost-style boosting).
    ///
    /// Returns `0.0` when `sum_h + lambda` is not strictly positive (zero, negative, or
    /// NaN). Without this guard, a zero denominator would produce `inf`/`NaN` and a
    /// negative one a sign-flipped weight, either of which corrupts every prediction
    /// reaching the leaf. XGBoost's weight computation likewise emits zero for a
    /// non-positive hessian sum.
    #[inline]
    pub fn compute_leaf_weight(sum_g: f32, sum_h: f32, lambda: f32) -> f32 {
        let denominator = sum_h + lambda;
        // `>` is false for NaN, so a NaN denominator also falls through to 0.0.
        if denominator > 0.0 {
            -sum_g / denominator
        } else {
            0.0
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A leaf exposes its value, no split info, and every statistic it was built with.
    #[test]
    fn test_leaf_node() {
        let node = Node::leaf(0.5, 2, 100, 10.0, 20.0);
        assert!(node.is_leaf());
        assert_eq!(node.leaf_value(), Some(0.5));
        assert_eq!(node.split_info(), None);
        assert_eq!(node.depth, 2);
        assert_eq!(node.num_samples, 100);
        assert_eq!(node.sum_gradients, 10.0);
        assert_eq!(node.sum_hessians, 20.0);
    }

    /// An internal node exposes its split tuple, no leaf value, and its statistics.
    #[test]
    fn test_internal_node() {
        let node = Node::internal(3, 128, 5.5, 1, 2, 1, 200, 15.0, 30.0);
        assert!(!node.is_leaf());
        assert_eq!(node.leaf_value(), None);

        let (feature, threshold, value, left, right) = node
            .split_info()
            .expect("internal node must expose split info");
        assert_eq!(feature, 3);
        assert_eq!(threshold, 128);
        assert!((value - 5.5).abs() < 1e-10);
        assert_eq!(left, 1);
        assert_eq!(right, 2);

        assert_eq!(node.depth, 1);
        assert_eq!(node.num_samples, 200);
        assert_eq!(node.sum_gradients, 15.0);
        assert_eq!(node.sum_hessians, 30.0);
    }

    /// `lambda` is added to the denominator and shrinks the weight's magnitude.
    #[test]
    fn test_leaf_weight() {
        let weight = Node::compute_leaf_weight(-10.0, 20.0, 0.0);
        assert!((weight - 0.5).abs() < 1e-6);

        let weight_reg = Node::compute_leaf_weight(-10.0, 20.0, 10.0);
        assert!((weight_reg - (10.0 / 30.0)).abs() < 1e-6);
        assert!(weight_reg < weight);
    }
}