use crate::error::FormatError;
use crate::packed::{EnsembleHeader, PackedNode, TreeEntry};
use crate::traverse;
/// Zero-copy, validated view over a serialized tree ensemble.
///
/// Every field borrows directly from the byte buffer handed to
/// `EnsembleView::from_bytes`, so the view is `Copy` and copying it only
/// copies references.
#[derive(Copy, Clone)]
pub struct EnsembleView<'a> {
    /// Fixed-size file header (magic, version, counts, base prediction).
    header: &'a EnsembleHeader,
    /// One entry per tree, locating that tree's nodes inside `nodes`.
    tree_table: &'a [TreeEntry],
    /// Packed node array shared by all trees.
    nodes: &'a [PackedNode],
}
impl<'a> EnsembleView<'a> {
    /// Parses and validates a serialized ensemble in place, borrowing `data`
    /// without copying.
    ///
    /// Expected layout of `data`: an [`EnsembleHeader`], immediately followed by
    /// `n_trees` [`TreeEntry`] records, immediately followed by the packed
    /// [`PackedNode`] array shared by all trees. Each `TreeEntry::offset` is a
    /// byte offset into that node array.
    ///
    /// All validation happens here, so [`predict`](Self::predict) and
    /// [`predict_batch`](Self::predict_batch) can slice trees without further
    /// checks. Split nodes must reference children that are inside their own
    /// tree *and* stored after themselves. The forward-only rule rules out
    /// cycles, so every root-to-leaf walk is bounded by the tree's node count.
    ///
    /// # Errors
    ///
    /// - [`FormatError::Truncated`]: one of these holds:
    ///   - `data` is shorter than the header, tree table or node array require;
    ///   - a size computation overflows;
    ///   - a tree's node range runs past the end of the node array.
    /// - [`FormatError::Unaligned`]: the header, tree table or node array does not
    ///   start at an address aligned for its type.
    /// - [`FormatError::BadMagic`] / [`FormatError::UnsupportedVersion`]: header
    ///   mismatch.
    /// - [`FormatError::MisalignedTreeOffset`]: a tree offset is not a multiple of
    ///   `size_of::<PackedNode>()`.
    /// - [`FormatError::InvalidNodeIndex`]: a split node's child index is out of
    ///   range for its tree or does not point past the split node itself.
    /// - [`FormatError::InvalidFeatureIndex`]: `n_features > 0` and a split node
    ///   references a feature index `>= n_features`.
    pub fn from_bytes(data: &'a [u8]) -> Result<Self, FormatError> {
        use core::mem::size_of;

        /// Returns `true` if `bytes` starts at an address suitably aligned for `T`.
        fn is_aligned_for<T>(bytes: &[u8]) -> bool {
            (bytes.as_ptr() as usize) % core::mem::align_of::<T>() == 0
        }

        // --- Header -------------------------------------------------------
        let header_size = size_of::<EnsembleHeader>();
        if data.len() < header_size {
            return Err(FormatError::Truncated);
        }
        if !is_aligned_for::<EnsembleHeader>(data) {
            return Err(FormatError::Unaligned);
        }
        // SAFETY: both checks above hold:
        // - `data` holds at least `size_of::<EnsembleHeader>()` bytes;
        // - `data` starts at an address aligned for `EnsembleHeader`.
        // The reference borrows `data` for `'a`.
        // NOTE(review): this also relies on `EnsembleHeader` being valid for every
        // bit pattern. Confirm in `crate::packed`.
        let header = unsafe { &*(data.as_ptr() as *const EnsembleHeader) };
        if header.magic != EnsembleHeader::MAGIC {
            return Err(FormatError::BadMagic);
        }
        if header.version != EnsembleHeader::VERSION {
            return Err(FormatError::UnsupportedVersion);
        }

        // --- Tree table: `n_trees` entries directly after the header -------
        let n_trees = header.n_trees as usize;
        let tree_table_offset = header_size;
        let tree_table_size = n_trees
            .checked_mul(size_of::<TreeEntry>())
            .ok_or(FormatError::Truncated)?;
        let nodes_base_offset = tree_table_offset
            .checked_add(tree_table_size)
            .ok_or(FormatError::Truncated)?;
        if data.len() < nodes_base_offset {
            return Err(FormatError::Truncated);
        }
        let tree_table_bytes = &data[tree_table_offset..nodes_base_offset];
        // `from_raw_parts` requires an aligned pointer, even for an empty slice.
        if !is_aligned_for::<TreeEntry>(tree_table_bytes) {
            return Err(FormatError::Unaligned);
        }
        // SAFETY: `tree_table_bytes` spans exactly `n_trees * size_of::<TreeEntry>()`
        // bytes. It is borrowed from `data` for `'a`, is non-null (it comes from a
        // slice) and is aligned for `TreeEntry` (checked above).
        let tree_table = unsafe {
            core::slice::from_raw_parts(tree_table_bytes.as_ptr() as *const TreeEntry, n_trees)
        };

        // --- Node array: sum of all trees' node counts ----------------------
        let node_size = size_of::<PackedNode>();
        let total_nodes = tree_table
            .iter()
            .try_fold(0usize, |acc, entry| acc.checked_add(entry.n_nodes as usize))
            .ok_or(FormatError::Truncated)?;
        let nodes_size = total_nodes
            .checked_mul(node_size)
            .ok_or(FormatError::Truncated)?;
        let total_required = nodes_base_offset
            .checked_add(nodes_size)
            .ok_or(FormatError::Truncated)?;
        if data.len() < total_required {
            return Err(FormatError::Truncated);
        }

        // Every tree must cover a whole-node byte range inside the node array.
        for entry in tree_table {
            let node_byte_offset = entry.offset as usize;
            if node_byte_offset % node_size != 0 {
                return Err(FormatError::MisalignedTreeOffset);
            }
            let tree_end = (entry.n_nodes as usize)
                .checked_mul(node_size)
                .and_then(|tree_bytes| node_byte_offset.checked_add(tree_bytes))
                .ok_or(FormatError::Truncated)?;
            if tree_end > nodes_size {
                return Err(FormatError::Truncated);
            }
        }

        let node_bytes = &data[nodes_base_offset..total_required];
        if !is_aligned_for::<PackedNode>(node_bytes) {
            return Err(FormatError::Unaligned);
        }
        // SAFETY: `node_bytes` spans exactly `total_nodes * size_of::<PackedNode>()`
        // bytes. It is borrowed from `data` for `'a`, is non-null (it comes from a
        // slice) and is aligned for `PackedNode` (checked above).
        let nodes = unsafe {
            core::slice::from_raw_parts(node_bytes.as_ptr() as *const PackedNode, total_nodes)
        };

        // --- Per-node structural validation --------------------------------
        let n_features = header.n_features as usize;
        for entry in tree_table {
            let first = entry.offset as usize / node_size;
            let tree_n_nodes = entry.n_nodes as usize;
            // Cannot panic: the range loop above proved `first + tree_n_nodes <= total_nodes`.
            let tree_nodes = &nodes[first..first + tree_n_nodes];
            for (local_idx, node) in tree_nodes.iter().enumerate() {
                if node.is_leaf() {
                    continue;
                }
                // Children must lie inside this tree AND strictly after their
                // parent. The forward-only rule makes cycles (e.g. a node that is
                // its own child) impossible, so traversal always terminates.
                // NOTE(review): assumes the writer stores parents before children
                // (e.g. pre-order or BFS layout). Confirm against the serializer.
                let children = local_idx + 1..tree_n_nodes;
                if !children.contains(&(node.left_child() as usize))
                    || !children.contains(&(node.right_child() as usize))
                {
                    return Err(FormatError::InvalidNodeIndex);
                }
                // A header `n_features` of 0 disables the feature-index check.
                if n_features > 0 && node.feature_idx() as usize >= n_features {
                    return Err(FormatError::InvalidFeatureIndex);
                }
            }
        }

        Ok(Self {
            header,
            tree_table,
            nodes,
        })
    }

    /// Returns the contiguous node slice belonging to one tree.
    ///
    /// For entries of `self.tree_table` the indexing cannot panic, because
    /// `from_bytes` checked that every tree's range lies inside `self.nodes`.
    #[inline]
    fn tree_nodes(&self, entry: &TreeEntry) -> &'a [PackedNode] {
        let nodes: &'a [PackedNode] = self.nodes;
        let start = entry.offset as usize / core::mem::size_of::<PackedNode>();
        let end = start + entry.n_nodes as usize;
        &nodes[start..end]
    }

    /// Predicts a single sample.
    ///
    /// The result is `base_prediction` plus each tree's output for `features`,
    /// accumulated in tree-table order.
    ///
    /// `features` should hold at least `n_features` values. This is only
    /// checked in debug builds.
    pub fn predict(&self, features: &[f32]) -> f32 {
        debug_assert!(
            features.len() >= self.header.n_features as usize,
            "predict: features.len() ({}) < n_features ({})",
            features.len(),
            self.header.n_features
        );
        self.tree_table
            .iter()
            .fold(self.header.base_prediction, |sum, entry| {
                sum + traverse::predict_tree(self.tree_nodes(entry), features)
            })
    }

    /// Predicts every sample in `samples`, writing result `i` to `out[i]`.
    ///
    /// Samples are processed four at a time through `traverse::predict_tree_x4`.
    /// The 0–3 leftover samples fall back to [`predict`](Self::predict).
    /// Elements of `out` beyond `samples.len()` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `out.len() < samples.len()`.
    pub fn predict_batch(&self, samples: &[&[f32]], out: &mut [f32]) {
        assert!(out.len() >= samples.len());
        let out = &mut out[..samples.len()];

        let mut sample_quads = samples.chunks_exact(4);
        let mut out_quads = out.chunks_exact_mut(4);
        for (quad, dst) in (&mut sample_quads).zip(&mut out_quads) {
            let batch = [quad[0], quad[1], quad[2], quad[3]];
            let mut sums = [self.header.base_prediction; 4];
            for entry in self.tree_table {
                let preds = traverse::predict_tree_x4(self.tree_nodes(entry), batch);
                for (j, sum) in sums.iter_mut().enumerate() {
                    *sum += preds[j];
                }
            }
            dst.copy_from_slice(&sums);
        }

        // Tail that did not fill a group of four.
        for (sample, dst) in sample_quads
            .remainder()
            .iter()
            .zip(out_quads.into_remainder())
        {
            *dst = self.predict(sample);
        }
    }

    /// Number of trees in the ensemble, as stored in the header.
    #[inline]
    pub fn n_trees(&self) -> u16 {
        self.header.n_trees
    }

    /// Feature count declared in the header.
    ///
    /// A value of 0 disables feature-index validation in `from_bytes`.
    #[inline]
    pub fn n_features(&self) -> u16 {
        self.header.n_features
    }

    /// Constant that `predict` adds to the sum of tree outputs.
    #[inline]
    pub fn base_prediction(&self) -> f32 {
        self.header.base_prediction
    }

    /// Total number of packed nodes across all trees.
    #[inline]
    pub fn total_nodes(&self) -> usize {
        self.nodes.len()
    }
}
impl core::fmt::Debug for EnsembleView<'_> {
    /// Summarizes the ensemble's shape without dumping the raw node data.
    ///
    /// Shows the tree, feature and node counts and the base prediction.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut summary = f.debug_struct("EnsembleView");
        summary.field("n_trees", &self.n_trees());
        summary.field("n_features", &self.n_features());
        summary.field("base_prediction", &self.base_prediction());
        summary.field("total_nodes", &self.total_nodes());
        summary.finish()
    }
}
#[cfg(test)]
// Unit tests for `EnsembleView`. Each test builds a serialized ensemble in
// memory and parses it with `EnsembleView::from_bytes`. The buffer is the raw
// bytes of an `EnsembleHeader`, then the `TreeEntry` table, then the
// `PackedNode`s.
// NOTE(review): `from_bytes` rejects unaligned input, but a `Vec<u8>` only
// guarantees 1-byte alignment. These tests rely on the allocator handing back
// memory aligned enough for the packed types.
mod tests {
use super::*;
use crate::packed::{EnsembleHeader, PackedNode, TreeEntry};
use alloc::{format, vec, vec::Vec};
use core::mem::size_of;
// One tree consisting of a single leaf with value `leaf_value`.
// The header declares 1 feature and base prediction `base`.
fn build_single_leaf_binary(leaf_value: f32, base: f32) -> Vec<u8> {
let header = EnsembleHeader {
magic: EnsembleHeader::MAGIC,
version: EnsembleHeader::VERSION,
n_trees: 1,
n_features: 1,
_reserved: 0,
base_prediction: base,
};
let entry = TreeEntry {
n_nodes: 1,
offset: 0,
};
let node = PackedNode::leaf(leaf_value);
let mut buf = Vec::new();
buf.extend_from_slice(as_bytes(&header));
buf.extend_from_slice(as_bytes(&entry));
buf.extend_from_slice(as_bytes(&node));
buf
}
// One tree of 3 nodes: a root split followed by leaves -1.0 (index 1) and
// 1.0 (index 2). The header declares 2 features and base prediction 0.0.
// Presumably `split(threshold, feature, left, right)`. The predict tests
// below expect 3.0 -> -1.0 and 7.0 -> 1.0 on feature 0.
fn build_one_split_binary() -> Vec<u8> {
let header = EnsembleHeader {
magic: EnsembleHeader::MAGIC,
version: EnsembleHeader::VERSION,
n_trees: 1,
n_features: 2,
_reserved: 0,
base_prediction: 0.0,
};
let entry = TreeEntry {
n_nodes: 3,
offset: 0,
};
let nodes = [
PackedNode::split(5.0, 0, 1, 2),
PackedNode::leaf(-1.0),
PackedNode::leaf(1.0),
];
let mut buf = Vec::new();
buf.extend_from_slice(as_bytes(&header));
buf.extend_from_slice(as_bytes(&entry));
for n in &nodes {
buf.extend_from_slice(as_bytes(n));
}
buf
}
// Two trees sharing one node array, with base prediction 1.0.
// - Tree 0 is the 3-node split tree from `build_one_split_binary`.
// - Tree 1 is a single leaf 0.5 stored right after it. Its offset is a byte
//   offset: 3 * size_of::<PackedNode>().
fn build_two_tree_binary() -> Vec<u8> {
let header = EnsembleHeader {
magic: EnsembleHeader::MAGIC,
version: EnsembleHeader::VERSION,
n_trees: 2,
n_features: 2,
_reserved: 0,
base_prediction: 1.0,
};
let entries = [
TreeEntry {
n_nodes: 3,
offset: 0,
},
TreeEntry {
n_nodes: 1,
offset: 3 * size_of::<PackedNode>() as u32,
},
];
let nodes = [
PackedNode::split(5.0, 0, 1, 2),
PackedNode::leaf(-1.0),
PackedNode::leaf(1.0),
PackedNode::leaf(0.5),
];
let mut buf = Vec::new();
buf.extend_from_slice(as_bytes(&header));
for e in &entries {
buf.extend_from_slice(as_bytes(e));
}
for n in &nodes {
buf.extend_from_slice(as_bytes(n));
}
buf
}
// Views a value as its raw in-memory bytes.
// NOTE(review): this would read uninitialized padding if `T` had any.
// It assumes the packed types are padding-free. Confirm in `crate::packed`.
fn as_bytes<T: Sized>(val: &T) -> &[u8] {
unsafe { core::slice::from_raw_parts(val as *const T as *const u8, size_of::<T>()) }
}
#[test]
fn parse_single_leaf() {
let buf = build_single_leaf_binary(42.0, 0.0);
let view = EnsembleView::from_bytes(&buf).unwrap();
assert_eq!(view.n_trees(), 1);
assert_eq!(view.n_features(), 1);
assert_eq!(view.total_nodes(), 1);
}
// Prediction = base (10.0) + leaf (42.0).
#[test]
fn predict_single_leaf() {
let buf = build_single_leaf_binary(42.0, 10.0);
let view = EnsembleView::from_bytes(&buf).unwrap();
let pred = view.predict(&[0.0]);
assert!((pred - 52.0).abs() < 1e-6);
}
// Feature 0 = 3.0 is below the 5.0 threshold, so the -1.0 leaf is expected.
#[test]
fn predict_one_split_left() {
let buf = build_one_split_binary();
let view = EnsembleView::from_bytes(&buf).unwrap();
let pred = view.predict(&[3.0, 0.0]);
assert!((pred - (-1.0)).abs() < 1e-6);
}
// Feature 0 = 7.0 is above the 5.0 threshold, so the 1.0 leaf is expected.
#[test]
fn predict_one_split_right() {
let buf = build_one_split_binary();
let view = EnsembleView::from_bytes(&buf).unwrap();
let pred = view.predict(&[7.0, 0.0]);
assert!((pred - 1.0).abs() < 1e-6);
}
// base 1.0 + tree 0 (-1.0) + tree 1 (0.5) = 0.5.
#[test]
fn predict_two_trees() {
let buf = build_two_tree_binary();
let view = EnsembleView::from_bytes(&buf).unwrap();
let pred = view.predict(&[3.0, 0.0]);
assert!((pred - 0.5).abs() < 1e-6);
}
// 5 samples exercise both paths of `predict_batch`: one group of four through
// the x4 path and one leftover sample through the scalar path.
#[test]
fn predict_batch_matches_single() {
let buf = build_two_tree_binary();
let view = EnsembleView::from_bytes(&buf).unwrap();
let samples: Vec<&[f32]> = vec![
&[3.0, 0.0],
&[7.0, 0.0],
&[5.0, 0.0],
&[0.0, 0.0],
&[10.0, 0.0],
];
let mut out = vec![0.0f32; 5];
view.predict_batch(&samples, &mut out);
for (i, &s) in samples.iter().enumerate() {
let expected = view.predict(s);
assert!(
(out[i] - expected).abs() < 1e-6,
"batch[{}] = {}, expected {}",
i,
out[i],
expected
);
}
}
// Corrupts the first byte of the buffer, which belongs to the header's magic.
#[test]
fn bad_magic_is_rejected() {
let mut buf = build_single_leaf_binary(0.0, 0.0);
buf[0] = 0xFF; assert_eq!(
EnsembleView::from_bytes(&buf).unwrap_err(),
FormatError::BadMagic
);
}
// 4 bytes cannot hold a full header.
#[test]
fn truncated_buffer_is_rejected() {
let buf = build_single_leaf_binary(0.0, 0.0);
assert_eq!(
EnsembleView::from_bytes(&buf[..4]).unwrap_err(),
FormatError::Truncated
);
}
// NOTE(review): assumes `version` is a little-endian u16 at byte offset 4,
// right after a 4-byte magic. Confirm against `EnsembleHeader`'s layout.
#[test]
fn bad_version_is_rejected() {
let mut buf = build_single_leaf_binary(0.0, 0.0);
buf[4] = 99;
buf[5] = 0;
assert_eq!(
EnsembleView::from_bytes(&buf).unwrap_err(),
FormatError::UnsupportedVersion
);
}
// The root split references child index 99, which is outside this 3-node tree.
#[test]
fn invalid_child_index_is_rejected() {
let header = EnsembleHeader {
magic: EnsembleHeader::MAGIC,
version: EnsembleHeader::VERSION,
n_trees: 1,
n_features: 2,
_reserved: 0,
base_prediction: 0.0,
};
let entry = TreeEntry {
n_nodes: 3,
offset: 0,
};
let nodes = [
PackedNode::split(5.0, 0, 1, 99), PackedNode::leaf(-1.0),
PackedNode::leaf(1.0),
];
let mut buf = Vec::new();
buf.extend_from_slice(as_bytes(&header));
buf.extend_from_slice(as_bytes(&entry));
for n in &nodes {
buf.extend_from_slice(as_bytes(n));
}
assert_eq!(
EnsembleView::from_bytes(&buf).unwrap_err(),
FormatError::InvalidNodeIndex
);
}
// The Debug output names the type and its summary fields.
#[test]
fn debug_format() {
let buf = build_single_leaf_binary(0.0, 0.0);
let view = EnsembleView::from_bytes(&buf).unwrap();
let debug = format!("{:?}", view);
assert!(debug.contains("EnsembleView"));
assert!(debug.contains("n_trees"));
}
}