use std::collections::VecDeque;
use crate::ensemble::distributional::DistributionalSGBT;
use crate::ensemble::SGBT;
use crate::loss::Loss;
use crate::tree::node::NodeId;
use irithyll_core::packed::{EnsembleHeader, PackedNode, TreeEntry};
use irithyll_core::packed_i16::{PackedNodeI16, QuantizedEnsembleHeader};
/// Serializes a trained SGBT ensemble into the compact little-endian
/// binary format understood by `irithyll_core::EnsembleView`.
///
/// Layout: `EnsembleHeader`, then one `TreeEntry` per tree, then every
/// tree's nodes in BFS order. Leaf values are pre-scaled by the learning
/// rate so the reader can sum raw node values directly.
pub fn export_packed<L: Loss>(model: &SGBT<L>, n_features: usize) -> Vec<u8> {
    let lr = model.config().learning_rate;
    let n_trees = model.steps().len();

    // Pack every tree up front so all sizes are known before writing.
    let mut packed_trees: Vec<Vec<PackedNode>> = Vec::with_capacity(n_trees);
    for step in model.steps() {
        let tree = step.slot().active_tree();
        packed_trees.push(bfs_pack_tree(tree.arena(), tree.root(), lr));
    }

    // Per-tree directory: node count plus byte offset into the node region.
    let node_size = core::mem::size_of::<PackedNode>() as u32;
    let mut directory: Vec<TreeEntry> = Vec::with_capacity(n_trees);
    let mut cursor: u32 = 0;
    for nodes in &packed_trees {
        directory.push(TreeEntry {
            n_nodes: nodes.len() as u32,
            offset: cursor,
        });
        cursor += nodes.len() as u32 * node_size;
    }

    let header = EnsembleHeader {
        magic: EnsembleHeader::MAGIC,
        version: EnsembleHeader::VERSION,
        n_trees: n_trees as u16,
        n_features: n_features as u16,
        _reserved: 0,
        base_prediction: model.base_prediction() as f32,
    };

    let expected_len = core::mem::size_of::<EnsembleHeader>()
        + n_trees * core::mem::size_of::<TreeEntry>()
        + cursor as usize;
    let mut out: Vec<u8> = Vec::with_capacity(expected_len);
    header.push_le_bytes(&mut out);
    for entry in &directory {
        entry.push_le_bytes(&mut out);
    }
    for nodes in &packed_trees {
        for node in nodes {
            node.push_le_bytes(&mut out);
        }
    }
    debug_assert_eq!(out.len(), expected_len);
    out
}
/// Flattens a tree into breadth-first order as `PackedNode`s.
///
/// Child links are rewritten from arena node ids to BFS positions (u16),
/// and leaf values are pre-multiplied by `learning_rate`. An empty or
/// rootless tree packs to a single zero-valued leaf.
///
/// Panics if the tree exceeds `u16::MAX` nodes (links would not fit).
fn bfs_pack_tree(
    arena: &crate::tree::node::TreeArena,
    root: NodeId,
    learning_rate: f64,
) -> Vec<PackedNode> {
    if root.is_none() || arena.n_nodes() == 0 {
        return vec![PackedNode::leaf(0.0)];
    }

    // First pass: record the BFS visitation order.
    let mut order: Vec<NodeId> = Vec::new();
    let mut frontier = VecDeque::new();
    frontier.push_back(root);
    while let Some(id) = frontier.pop_front() {
        order.push(id);
        let i = id.idx();
        if !arena.is_leaf[i] {
            frontier.push_back(arena.left[i]);
            frontier.push_back(arena.right[i]);
        }
    }

    let count = order.len();
    assert!(
        count <= u16::MAX as usize,
        "tree has {} nodes, exceeds u16::MAX (65535)",
        count
    );

    // Arena-id -> BFS-position lookup, sized by the largest id visited.
    let table_len = order.iter().map(|id| id.idx()).max().unwrap_or(0) + 1;
    let mut bfs_pos = vec![u16::MAX; table_len];
    for (pos, id) in order.iter().enumerate() {
        bfs_pos[id.idx()] = pos as u16;
    }

    // Second pass: emit nodes with child links rewritten to BFS positions.
    order
        .iter()
        .map(|id| {
            let i = id.idx();
            if arena.is_leaf[i] {
                PackedNode::leaf((learning_rate * arena.leaf_value[i]) as f32)
            } else {
                PackedNode::split(
                    arena.threshold[i] as f32,
                    arena.feature_idx[i] as u16,
                    bfs_pos[arena.left[i].idx()],
                    bfs_pos[arena.right[i].idx()],
                )
            }
        })
        .collect()
}
/// Compares full-precision model predictions against the packed binary's
/// predictions over `test_features`, returning the largest absolute
/// difference observed (0.0 for an empty feature set).
///
/// Panics if `packed` is not a valid ensemble binary.
pub fn validate_export<L: Loss>(model: &SGBT<L>, packed: &[u8], test_features: &[Vec<f64>]) -> f64 {
    let view = irithyll_core::EnsembleView::from_bytes(packed)
        .expect("validate_export: invalid packed binary");
    test_features
        .iter()
        .map(|features| {
            let reference = model.predict(features);
            // The packed reader consumes f32 features.
            let as_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
            (reference - view.predict(&as_f32) as f64).abs()
        })
        .fold(0.0f64, f64::max)
}
/// Serializes a trained SGBT ensemble into the quantized (i16) binary
/// format: split thresholds are scaled per-feature and leaf values
/// globally into the i16 range, with the scales written alongside so the
/// reader can dequantize.
///
/// Layout: `QuantizedEnsembleHeader`, leaf scale (f32), one f32 scale per
/// feature, the per-tree `TreeEntry` table, then all nodes in BFS order.
pub fn export_packed_i16<L: Loss>(model: &SGBT<L>, n_features: usize) -> Vec<u8> {
    let lr = model.config().learning_rate;
    let n_trees = model.steps().len();

    // First sweep: gather every split threshold (grouped by feature) and
    // every learning-rate-scaled leaf value, so quantization scales can be
    // chosen from the actual value ranges.
    let mut thresholds_by_feature: Vec<Vec<f64>> = vec![Vec::new(); n_features];
    let mut leaf_values: Vec<f64> = Vec::new();
    for step in model.steps() {
        let arena = step.slot().active_tree().arena();
        let root = step.slot().active_tree().root();
        if root.is_none() || arena.n_nodes() == 0 {
            // Empty tree packs as a single zero leaf.
            leaf_values.push(0.0);
            continue;
        }
        let mut frontier = VecDeque::new();
        frontier.push_back(root);
        while let Some(id) = frontier.pop_front() {
            let i = id.idx();
            if arena.is_leaf[i] {
                leaf_values.push(lr * arena.leaf_value[i]);
            } else {
                let feat = arena.feature_idx[i] as usize;
                // Out-of-range feature indices are skipped here; packing
                // falls back to scale 1.0 for them.
                if feat < n_features {
                    thresholds_by_feature[feat].push(arena.threshold[i]);
                }
                frontier.push_back(arena.left[i]);
                frontier.push_back(arena.right[i]);
            }
        }
    }

    // Scale mapping the largest |value| onto i16::MAX; 1.0 when no values.
    let scale_for = |max_abs: f64| -> f32 {
        if max_abs == 0.0 {
            1.0
        } else {
            (32767.0 / max_abs) as f32
        }
    };
    let feature_scales: Vec<f32> = thresholds_by_feature
        .iter()
        .map(|ts| scale_for(ts.iter().map(|t| t.abs()).fold(0.0f64, f64::max)))
        .collect();
    let leaf_scale: f32 =
        scale_for(leaf_values.iter().map(|v| v.abs()).fold(0.0f64, f64::max));

    // Second sweep: quantize every tree.
    let mut packed_trees: Vec<Vec<PackedNodeI16>> = Vec::with_capacity(n_trees);
    for step in model.steps() {
        let tree = step.slot().active_tree();
        packed_trees.push(bfs_pack_tree_i16(
            tree.arena(),
            tree.root(),
            lr,
            &feature_scales,
            leaf_scale,
        ));
    }

    // Per-tree directory into the node region.
    let node_size = core::mem::size_of::<PackedNodeI16>() as u32;
    let mut directory: Vec<TreeEntry> = Vec::with_capacity(n_trees);
    let mut cursor: u32 = 0;
    for nodes in &packed_trees {
        directory.push(TreeEntry {
            n_nodes: nodes.len() as u32,
            offset: cursor,
        });
        cursor += nodes.len() as u32 * node_size;
    }

    let header = QuantizedEnsembleHeader {
        magic: QuantizedEnsembleHeader::MAGIC,
        version: QuantizedEnsembleHeader::VERSION,
        n_trees: n_trees as u16,
        n_features: n_features as u16,
        _reserved: 0,
        base_prediction: model.base_prediction() as f32,
    };

    let expected_len = core::mem::size_of::<QuantizedEnsembleHeader>()
        + core::mem::size_of::<f32>()               // leaf scale
        + n_features * core::mem::size_of::<f32>()  // per-feature scales
        + n_trees * core::mem::size_of::<TreeEntry>()
        + cursor as usize;
    let mut out: Vec<u8> = Vec::with_capacity(expected_len);
    header.push_le_bytes(&mut out);
    leaf_scale.push_le_bytes(&mut out);
    for scale in &feature_scales {
        scale.push_le_bytes(&mut out);
    }
    for entry in &directory {
        entry.push_le_bytes(&mut out);
    }
    for nodes in &packed_trees {
        for node in nodes {
            node.push_le_bytes(&mut out);
        }
    }
    debug_assert_eq!(out.len(), expected_len);
    out
}
/// Flattens a tree into breadth-first order as quantized `PackedNodeI16`s.
///
/// Leaf values (after learning-rate scaling) are multiplied by
/// `leaf_scale` and rounded to the nearest i16; thresholds are scaled by
/// the per-feature scale. Child links are rewritten from arena ids to BFS
/// positions (u16). An empty or rootless tree packs to a single zero leaf.
///
/// Panics if the tree exceeds `u16::MAX` nodes.
fn bfs_pack_tree_i16(
    arena: &crate::tree::node::TreeArena,
    root: NodeId,
    learning_rate: f64,
    feature_scales: &[f32],
    leaf_scale: f32,
) -> Vec<PackedNodeI16> {
    if root.is_none() || arena.n_nodes() == 0 {
        return vec![PackedNodeI16::leaf(0)];
    }
    // Round-to-nearest leaf quantization: plain `as i16` truncates toward
    // zero, which biases every leaf value and doubles the worst-case
    // quantization error. The clamp guards against scales that rounded up
    // when narrowed to f32, which can push `value * scale` marginally past
    // i16::MAX (the `as` cast would saturate anyway; the clamp makes the
    // intent explicit).
    fn quantize_leaf(value: f64) -> i16 {
        value.round().clamp(i16::MIN as f64, i16::MAX as f64) as i16
    }
    let mut queue = VecDeque::new();
    let mut bfs_order: Vec<NodeId> = Vec::new();
    queue.push_back(root);
    while let Some(node_id) = queue.pop_front() {
        bfs_order.push(node_id);
        let idx = node_id.idx();
        if !arena.is_leaf[idx] {
            queue.push_back(arena.left[idx]);
            queue.push_back(arena.right[idx]);
        }
    }
    let n_nodes = bfs_order.len();
    assert!(
        n_nodes <= u16::MAX as usize,
        "tree has {} nodes, exceeds u16::MAX (65535)",
        n_nodes
    );
    // Map arena node ids to their BFS positions for link rewriting.
    let max_id = bfs_order.iter().map(|id| id.0).max().unwrap_or(0) as usize;
    let mut id_to_bfs = vec![u16::MAX; max_id + 1];
    for (bfs_idx, &node_id) in bfs_order.iter().enumerate() {
        id_to_bfs[node_id.idx()] = bfs_idx as u16;
    }
    let mut packed = Vec::with_capacity(n_nodes);
    for &node_id in &bfs_order {
        let idx = node_id.idx();
        if arena.is_leaf[idx] {
            let leaf_f64 = learning_rate * arena.leaf_value[idx];
            packed.push(PackedNodeI16::leaf(quantize_leaf(
                leaf_f64 * leaf_scale as f64,
            )));
        } else {
            // Features never seen during scale collection fall back to 1.0.
            let feat = arena.feature_idx[idx] as usize;
            let scale = feature_scales.get(feat).copied().unwrap_or(1.0);
            // NOTE(review): thresholds deliberately keep the truncating
            // `as i16` cast — presumably the reader quantizes incoming
            // features the same way, and the two must match. Confirm the
            // reader's feature quantization before switching to rounding.
            let threshold_i16 = (arena.threshold[idx] * scale as f64) as i16;
            packed.push(PackedNodeI16::split(
                threshold_i16,
                feat as u16,
                id_to_bfs[arena.left[idx].idx()],
                id_to_bfs[arena.right[idx].idx()],
            ));
        }
    }
    packed
}
/// Compares full-precision model predictions against the quantized
/// binary's predictions over `test_features`, returning the largest
/// absolute difference observed (0.0 for an empty feature set).
///
/// Panics if `packed` is not a valid quantized ensemble binary.
pub fn validate_export_i16<L: Loss>(
    model: &SGBT<L>,
    packed: &[u8],
    test_features: &[Vec<f64>],
) -> f64 {
    let view = irithyll_core::QuantizedEnsembleView::from_bytes(packed)
        .expect("validate_export_i16: invalid packed binary");
    test_features
        .iter()
        .map(|features| {
            let reference = model.predict(features);
            // The quantized reader consumes f32 features.
            let as_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
            (reference - view.predict(&as_f32) as f64).abs()
        })
        .fold(0.0f64, f64::max)
}
/// Serializes the location-model trees of a `DistributionalSGBT` into the
/// standard packed f32 format.
///
/// The header's `base_prediction` is written as 0.0; the caller receives
/// the location base as the second tuple element and must add it to the
/// packed prediction itself.
pub fn export_distributional_packed(
    model: &DistributionalSGBT,
    n_features: usize,
) -> (Vec<u8>, f64) {
    let lr = model.learning_rate();
    let steps = model.location_steps();
    let n_trees = steps.len();

    // Pack every location tree first so all sizes are known up front.
    let mut packed_trees: Vec<Vec<PackedNode>> = Vec::with_capacity(n_trees);
    for step in steps {
        let tree = step.slot().active_tree();
        packed_trees.push(bfs_pack_tree(tree.arena(), tree.root(), lr));
    }

    // Per-tree directory into the node region.
    let node_size = core::mem::size_of::<PackedNode>() as u32;
    let mut directory: Vec<TreeEntry> = Vec::with_capacity(n_trees);
    let mut cursor: u32 = 0;
    for nodes in &packed_trees {
        directory.push(TreeEntry {
            n_nodes: nodes.len() as u32,
            offset: cursor,
        });
        cursor += nodes.len() as u32 * node_size;
    }

    let header = EnsembleHeader {
        magic: EnsembleHeader::MAGIC,
        version: EnsembleHeader::VERSION,
        n_trees: n_trees as u16,
        n_features: n_features as u16,
        _reserved: 0,
        // The location base travels out-of-band (second return value).
        base_prediction: 0.0,
    };

    let expected_len = core::mem::size_of::<EnsembleHeader>()
        + n_trees * core::mem::size_of::<TreeEntry>()
        + cursor as usize;
    let mut out: Vec<u8> = Vec::with_capacity(expected_len);
    header.push_le_bytes(&mut out);
    for entry in &directory {
        entry.push_le_bytes(&mut out);
    }
    for nodes in &packed_trees {
        for node in nodes {
            node.push_le_bytes(&mut out);
        }
    }
    debug_assert_eq!(out.len(), expected_len);
    (out, model.location_base())
}
/// Little-endian serialization into a growing byte buffer.
///
/// Implementations must append their fields in the exact order and width
/// that the `irithyll_core` reader types expect; the order of the
/// `extend_from_slice` calls IS the wire format.
trait PushBytes {
    // Appends this value's little-endian byte representation to `buf`.
    fn push_le_bytes(&self, buf: &mut Vec<u8>);
}
impl PushBytes for EnsembleHeader {
    // Wire order: magic, version, n_trees, n_features, _reserved,
    // base_prediction — all little-endian. Must match the reader's
    // `EnsembleHeader` layout; do not reorder.
    fn push_le_bytes(&self, buf: &mut Vec<u8>) {
        buf.extend_from_slice(&self.magic.to_le_bytes());
        buf.extend_from_slice(&self.version.to_le_bytes());
        buf.extend_from_slice(&self.n_trees.to_le_bytes());
        buf.extend_from_slice(&self.n_features.to_le_bytes());
        buf.extend_from_slice(&self._reserved.to_le_bytes());
        buf.extend_from_slice(&self.base_prediction.to_le_bytes());
    }
}
impl PushBytes for TreeEntry {
    // Wire order: n_nodes then offset (both u32 LE) — matches the reader's
    // per-tree directory entry.
    fn push_le_bytes(&self, buf: &mut Vec<u8>) {
        buf.extend_from_slice(&self.n_nodes.to_le_bytes());
        buf.extend_from_slice(&self.offset.to_le_bytes());
    }
}
impl PushBytes for PackedNode {
    // Wire order: value, children, feature_flags, _reserved — all LE.
    // Must match the reader's `PackedNode` layout; do not reorder.
    fn push_le_bytes(&self, buf: &mut Vec<u8>) {
        buf.extend_from_slice(&self.value.to_le_bytes());
        buf.extend_from_slice(&self.children.to_le_bytes());
        buf.extend_from_slice(&self.feature_flags.to_le_bytes());
        buf.extend_from_slice(&self._reserved.to_le_bytes());
    }
}
impl PushBytes for QuantizedEnsembleHeader {
    // Wire order: magic, version, n_trees, n_features, _reserved,
    // base_prediction — all LE. Mirrors the f32 header; must match the
    // reader's `QuantizedEnsembleHeader` layout.
    fn push_le_bytes(&self, buf: &mut Vec<u8>) {
        buf.extend_from_slice(&self.magic.to_le_bytes());
        buf.extend_from_slice(&self.version.to_le_bytes());
        buf.extend_from_slice(&self.n_trees.to_le_bytes());
        buf.extend_from_slice(&self.n_features.to_le_bytes());
        buf.extend_from_slice(&self._reserved.to_le_bytes());
        buf.extend_from_slice(&self.base_prediction.to_le_bytes());
    }
}
impl PushBytes for PackedNodeI16 {
    // Wire order: value, feature_flags, children — all LE.
    // NOTE(review): this order differs from `PackedNode` (which writes
    // children before feature_flags) — presumably it matches
    // `PackedNodeI16`'s declared field layout; verify against the reader.
    fn push_le_bytes(&self, buf: &mut Vec<u8>) {
        buf.extend_from_slice(&self.value.to_le_bytes());
        buf.extend_from_slice(&self.feature_flags.to_le_bytes());
        buf.extend_from_slice(&self.children.to_le_bytes());
    }
}
impl PushBytes for f32 {
    // Bare f32 scalars (the leaf scale and per-feature scales) are written
    // as 4 little-endian bytes.
    fn push_le_bytes(&self, buf: &mut Vec<u8>) {
        buf.extend_from_slice(&self.to_le_bytes());
    }
}
/// Quantizes a raw weight vector via `irithyll_core::turbo_quant` and
/// returns its serialized byte form.
pub fn export_turbo_quantized_weights(weights: &[f64]) -> Vec<u8> {
    irithyll_core::turbo_quant::quantize_weights(weights).to_bytes()
}
/// Measures the worst-case dequantization error of a turbo-quantized
/// weight vector.
///
/// Each stored weight is recovered by predicting against the i-th
/// standard basis vector (the dot product isolates weight i); returns the
/// largest absolute difference from the original weights.
///
/// Panics if `packed` is invalid or the weight counts disagree.
pub fn validate_turbo_quantized(weights: &[f64], packed: &[u8]) -> f64 {
    let view = irithyll_core::turbo_quant::TurboQuantizedView::from_bytes(packed)
        .expect("validate_turbo_quantized: invalid packed binary");
    assert_eq!(
        view.n_weights(),
        weights.len(),
        "weight count mismatch: original {} vs packed {}",
        weights.len(),
        view.n_weights()
    );
    // Reuse one basis vector across iterations instead of allocating and
    // zero-filling a fresh n-length Vec per weight (previously O(n^2)
    // fill work and n heap allocations).
    let mut unit = vec![0.0; weights.len()];
    let mut max_diff: f64 = 0.0;
    for (i, &w) in weights.iter().enumerate() {
        unit[i] = 1.0;
        let dequant = view.predict(&unit);
        unit[i] = 0.0; // restore the zero vector for the next iteration
        let diff = (w - dequant).abs();
        if diff > max_diff {
            max_diff = diff;
        }
    }
    max_diff
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ensemble::config::SGBTConfig;
    use crate::sample::Sample;

    /// Shared fixture: a small 5-tree model trained on a linear target
    /// over 3 engineered features.
    fn trained_model() -> SGBT {
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);
        for i in 0..100 {
            let x = (i as f64) * 0.1;
            model.train_one(&Sample::new(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }
        model
    }

    /// Exported f32 binary parses and reports the expected tree/feature counts.
    #[test]
    fn export_produces_valid_binary() {
        let model = trained_model();
        let packed = export_packed(&model, 3);
        let view = irithyll_core::EnsembleView::from_bytes(&packed);
        assert!(view.is_ok(), "exported binary should be valid");
        let view = view.unwrap();
        assert_eq!(view.n_trees(), 5);
        assert_eq!(view.n_features(), 3);
    }

    /// The header round-trips the model's base prediction (narrowed to f32).
    #[test]
    fn export_preserves_base_prediction() {
        let model = trained_model();
        let packed = export_packed(&model, 3);
        let view = irithyll_core::EnsembleView::from_bytes(&packed).unwrap();
        let expected = model.base_prediction() as f32;
        assert!(
            (view.base_prediction() - expected).abs() < 1e-6,
            "base prediction mismatch: got {}, expected {}",
            view.base_prediction(),
            expected
        );
    }

    /// Packed predictions track the full-precision model within f32 tolerance.
    #[test]
    fn export_predictions_match_within_tolerance() {
        let model = trained_model();
        let packed = export_packed(&model, 3);
        let test_data: Vec<Vec<f64>> = (0..50)
            .map(|i| {
                let x = (i as f64) * 0.2;
                vec![x, x * 2.0, x * 0.5]
            })
            .collect();
        let max_diff = validate_export(&model, &packed, &test_data);
        assert!(
            max_diff < 0.1,
            "max prediction difference {} exceeds tolerance",
            max_diff
        );
    }

    /// An untrained model still exports (empty trees pack as zero leaves)
    /// and produces a finite prediction.
    #[test]
    fn export_untrained_model() {
        let config = SGBTConfig::builder().n_steps(3).build().unwrap();
        let model = SGBT::new(config);
        let packed = export_packed(&model, 5);
        let view = irithyll_core::EnsembleView::from_bytes(&packed).unwrap();
        assert_eq!(view.n_trees(), 3);
        let pred = view.predict(&[0.0, 0.0, 0.0, 0.0, 0.0]);
        assert!(pred.is_finite());
    }

    /// Size sanity: at least the fixed layout minimum (16-byte header,
    /// 8-byte TreeEntry, 12-byte PackedNode, >= 1 node per tree) and far
    /// below pathological sizes.
    #[test]
    fn binary_size_is_compact() {
        let model = trained_model();
        let packed = export_packed(&model, 3);
        let header_size = 16;
        let table_size = 5 * 8;
        let min_size = header_size + table_size + 5 * 12; assert!(
            packed.len() >= min_size,
            "packed binary too small: {} bytes",
            packed.len()
        );
        assert!(
            packed.len() < 100_000,
            "packed binary unexpectedly large: {} bytes",
            packed.len()
        );
    }

    /// Single-tree export round-trips and predicts finitely.
    #[test]
    fn roundtrip_single_tree() {
        let config = SGBTConfig::builder()
            .n_steps(1)
            .learning_rate(0.05)
            .grace_period(5)
            .max_depth(2)
            .n_bins(8)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);
        for i in 0..50 {
            let x = (i as f64) * 0.1;
            model.train_one(&Sample::new(vec![x, x * 2.0], x + 1.0));
        }
        let packed = export_packed(&model, 2);
        let view = irithyll_core::EnsembleView::from_bytes(&packed).unwrap();
        assert_eq!(view.n_trees(), 1);
        let pred = view.predict(&[2.5, 5.0]);
        assert!(pred.is_finite());
    }

    /// Exported i16 binary parses and reports the expected counts.
    #[test]
    fn export_i16_produces_valid_binary() {
        let model = trained_model();
        let packed = export_packed_i16(&model, 3);
        let view = irithyll_core::QuantizedEnsembleView::from_bytes(&packed);
        assert!(view.is_ok(), "exported i16 binary should be valid");
        let view = view.unwrap();
        assert_eq!(view.n_trees(), 5);
        assert_eq!(view.n_features(), 3);
    }

    /// The quantized header round-trips the base prediction (not quantized).
    #[test]
    fn export_i16_preserves_base_prediction() {
        let model = trained_model();
        let packed = export_packed_i16(&model, 3);
        let view = irithyll_core::QuantizedEnsembleView::from_bytes(&packed).unwrap();
        let expected = model.base_prediction() as f32;
        assert!(
            (view.base_prediction() - expected).abs() < 1e-6,
            "i16 base prediction mismatch: got {}, expected {}",
            view.base_prediction(),
            expected
        );
    }

    /// Quantized predictions stay within a looser tolerance (0.5) that
    /// allows for i16 quantization error.
    #[test]
    fn export_i16_predictions_within_tolerance() {
        let model = trained_model();
        let packed = export_packed_i16(&model, 3);
        let test_data: Vec<Vec<f64>> = (0..50)
            .map(|i| {
                let x = (i as f64) * 0.2;
                vec![x, x * 2.0, x * 0.5]
            })
            .collect();
        let max_diff = validate_export_i16(&model, &packed, &test_data);
        assert!(
            max_diff < 0.5,
            "i16 max prediction difference {} exceeds tolerance 0.5",
            max_diff
        );
    }

    /// Untrained model exports valid i16 binary with finite predictions.
    #[test]
    fn export_i16_untrained_model() {
        let config = SGBTConfig::builder().n_steps(3).build().unwrap();
        let model = SGBT::new(config);
        let packed = export_packed_i16(&model, 5);
        let view = irithyll_core::QuantizedEnsembleView::from_bytes(&packed).unwrap();
        assert_eq!(view.n_trees(), 3);
        let pred = view.predict(&[0.0, 0.0, 0.0, 0.0, 0.0]);
        assert!(pred.is_finite());
    }

    /// Single-tree i16 export round-trips and predicts finitely.
    #[test]
    fn export_i16_single_tree_roundtrip() {
        let config = SGBTConfig::builder()
            .n_steps(1)
            .learning_rate(0.05)
            .grace_period(5)
            .max_depth(2)
            .n_bins(8)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);
        for i in 0..50 {
            let x = (i as f64) * 0.1;
            model.train_one(&Sample::new(vec![x, x * 2.0], x + 1.0));
        }
        let packed = export_packed_i16(&model, 2);
        let view = irithyll_core::QuantizedEnsembleView::from_bytes(&packed).unwrap();
        assert_eq!(view.n_trees(), 1);
        let pred = view.predict(&[2.5, 5.0]);
        assert!(pred.is_finite());
    }

    /// Distributional export: header base is 0.0, and location_base plus
    /// the packed prediction reproduces the full model's mu within f32
    /// tolerance.
    #[test]
    fn export_distributional_packed_roundtrip() {
        use crate::ensemble::distributional::DistributionalSGBT;
        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(5)
            .max_depth(3)
            .n_bins(8)
            .initial_target_count(10)
            .build()
            .unwrap();
        let mut model = DistributionalSGBT::new(config);
        let n_features = 3;
        for i in 0..100 {
            let x = (i as f64) * 0.1;
            model.train_one(&(vec![x, x * 2.0, x * 0.5], x * 3.0));
        }
        let (packed, location_base) = export_distributional_packed(&model, n_features);
        let view = irithyll_core::EnsembleView::from_bytes(&packed)
            .expect("exported distributional binary should be valid");
        assert_eq!(view.n_trees(), 5);
        assert_eq!(view.n_features(), 3);
        assert!(
            view.base_prediction().abs() < 1e-6,
            "header base_prediction should be 0.0, got {}",
            view.base_prediction()
        );
        let test_features: Vec<Vec<f64>> = (0..20)
            .map(|i| {
                let x = (i as f64) * 0.5;
                vec![x, x * 2.0, x * 0.5]
            })
            .collect();
        let mut max_diff: f64 = 0.0;
        for features in &test_features {
            let full_mu = model.predict(features).mu;
            let features_f32: Vec<f32> = features.iter().map(|&v| v as f32).collect();
            // Reader prediction is tree sum only; add the out-of-band base.
            let packed_mu = location_base + view.predict(&features_f32) as f64;
            let diff = (full_mu - packed_mu).abs();
            if diff > max_diff {
                max_diff = diff;
            }
        }
        assert!(
            max_diff < 0.1,
            "max mu difference {} between full tree and packed export exceeds f32 tolerance",
            max_diff
        );
    }
}