1use crate::error::FormatError;
8use crate::packed::{EnsembleHeader, PackedNode, TreeEntry};
9use crate::traverse;
10
11#[derive(Clone, Copy)]
21pub struct EnsembleView<'a> {
22 header: &'a EnsembleHeader,
23 tree_table: &'a [TreeEntry],
24 nodes: &'a [PackedNode],
25}
26
27impl<'a> EnsembleView<'a> {
28 pub fn from_bytes(data: &'a [u8]) -> Result<Self, FormatError> {
41 use core::mem::{align_of, size_of};
42
43 let header_size = size_of::<EnsembleHeader>();
44 if data.len() < header_size {
45 return Err(FormatError::Truncated);
46 }
47
48 if (data.as_ptr() as usize) % align_of::<EnsembleHeader>() != 0 {
52 return Err(FormatError::Unaligned);
53 }
54
55 let header = unsafe { &*(data.as_ptr() as *const EnsembleHeader) };
57
58 if header.magic != EnsembleHeader::MAGIC {
59 return Err(FormatError::BadMagic);
60 }
61 if header.version != EnsembleHeader::VERSION {
62 return Err(FormatError::UnsupportedVersion);
63 }
64
65 let n_trees = header.n_trees as usize;
66 let tree_table_size = n_trees * size_of::<TreeEntry>();
67 let tree_table_offset = header_size;
68
69 if data.len() < tree_table_offset + tree_table_size {
70 return Err(FormatError::Truncated);
71 }
72
73 let tree_table_ptr = unsafe { data.as_ptr().add(tree_table_offset) } as *const TreeEntry;
76 let tree_table = unsafe { core::slice::from_raw_parts(tree_table_ptr, n_trees) };
77
78 let nodes_base_offset = tree_table_offset + tree_table_size;
80 let mut total_nodes: usize = 0;
81 for entry in tree_table {
82 total_nodes = total_nodes
83 .checked_add(entry.n_nodes as usize)
84 .ok_or(FormatError::Truncated)?;
85 }
86
87 let nodes_size = total_nodes
88 .checked_mul(size_of::<PackedNode>())
89 .ok_or(FormatError::Truncated)?;
90 let total_required = nodes_base_offset
91 .checked_add(nodes_size)
92 .ok_or(FormatError::Truncated)?;
93 if data.len() < total_required {
94 return Err(FormatError::Truncated);
95 }
96
97 for entry in tree_table {
99 let node_byte_offset = entry.offset as usize;
100 if node_byte_offset % size_of::<PackedNode>() != 0 {
102 return Err(FormatError::MisalignedTreeOffset);
103 }
104 let tree_bytes = (entry.n_nodes as usize)
105 .checked_mul(size_of::<PackedNode>())
106 .ok_or(FormatError::Truncated)?;
107 let tree_end = node_byte_offset
108 .checked_add(tree_bytes)
109 .ok_or(FormatError::Truncated)?;
110 if tree_end > nodes_size {
111 return Err(FormatError::Truncated);
112 }
113 }
114
115 let nodes_ptr = unsafe { data.as_ptr().add(nodes_base_offset) } as *const PackedNode;
116 let nodes = unsafe { core::slice::from_raw_parts(nodes_ptr, total_nodes) };
117
118 let n_features = header.n_features as usize;
120 for (tree_idx, entry) in tree_table.iter().enumerate() {
121 let tree_node_offset = entry.offset as usize / size_of::<PackedNode>();
122 let tree_n_nodes = entry.n_nodes as usize;
123
124 for local_idx in 0..tree_n_nodes {
125 let global_idx = tree_node_offset + local_idx;
126 let node = &nodes[global_idx];
127
128 if !node.is_leaf() {
129 let left = node.left_child() as usize;
130 let right = node.right_child() as usize;
131
132 if left >= tree_n_nodes || right >= tree_n_nodes {
134 return Err(FormatError::InvalidNodeIndex);
135 }
136
137 if n_features > 0 && node.feature_idx() as usize >= n_features {
138 return Err(FormatError::InvalidFeatureIndex);
139 }
140 }
141 }
142
143 let _ = tree_idx; }
145
146 Ok(Self {
147 header,
148 tree_table,
149 nodes,
150 })
151 }
152
153 pub fn predict(&self, features: &[f32]) -> f32 {
164 debug_assert!(
165 features.len() >= self.header.n_features as usize,
166 "predict: features.len() ({}) < n_features ({})",
167 features.len(),
168 self.header.n_features
169 );
170 let mut sum = self.header.base_prediction;
171 for entry in self.tree_table {
172 let start = entry.offset as usize / core::mem::size_of::<PackedNode>();
173 let end = start + entry.n_nodes as usize;
174 let tree_nodes = &self.nodes[start..end];
175 sum += traverse::predict_tree(tree_nodes, features);
176 }
177 sum
178 }
179
180 pub fn predict_batch(&self, samples: &[&[f32]], out: &mut [f32]) {
190 assert!(out.len() >= samples.len());
191
192 let n = samples.len();
193 let mut i = 0;
194
195 while i + 4 <= n {
197 let batch = [samples[i], samples[i + 1], samples[i + 2], samples[i + 3]];
198 let mut sums = [self.header.base_prediction; 4];
200 for entry in self.tree_table {
201 let start = entry.offset as usize / core::mem::size_of::<PackedNode>();
202 let end = start + entry.n_nodes as usize;
203 let tree_nodes = &self.nodes[start..end];
204 let preds = traverse::predict_tree_x4(tree_nodes, batch);
205 for j in 0..4 {
206 sums[j] += preds[j];
207 }
208 }
209 out[i] = sums[0];
210 out[i + 1] = sums[1];
211 out[i + 2] = sums[2];
212 out[i + 3] = sums[3];
213 i += 4;
214 }
215
216 while i < n {
218 out[i] = self.predict(samples[i]);
219 i += 1;
220 }
221 }
222
223 #[inline]
225 pub fn n_trees(&self) -> u16 {
226 self.header.n_trees
227 }
228
229 #[inline]
231 pub fn n_features(&self) -> u16 {
232 self.header.n_features
233 }
234
235 #[inline]
237 pub fn base_prediction(&self) -> f32 {
238 self.header.base_prediction
239 }
240
241 #[inline]
243 pub fn total_nodes(&self) -> usize {
244 self.nodes.len()
245 }
246}
247
248impl<'a> core::fmt::Debug for EnsembleView<'a> {
249 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
250 f.debug_struct("EnsembleView")
251 .field("n_trees", &self.n_trees())
252 .field("n_features", &self.n_features())
253 .field("base_prediction", &self.base_prediction())
254 .field("total_nodes", &self.total_nodes())
255 .finish()
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed::{EnsembleHeader, PackedNode, TreeEntry};
    use alloc::{format, vec, vec::Vec};
    use core::mem::size_of;

    /// 64 bytes of storage whose only purpose is to force 64-byte alignment.
    #[derive(Clone, Copy)]
    #[repr(C, align(64))]
    struct AlignChunk([u8; 64]);

    /// Owned copy of a byte buffer whose first byte is 64-byte aligned.
    ///
    /// `EnsembleView::from_bytes` rejects buffers that are not aligned for the header,
    /// but a `Vec<u8>` is only guaranteed 1-byte alignment. Copying into over-aligned
    /// storage keeps these tests independent of allocator behaviour (e.g. under Miri).
    struct AlignedBuf {
        chunks: Vec<AlignChunk>,
        len: usize,
    }

    impl AlignedBuf {
        fn new(bytes: &[u8]) -> Self {
            let mut chunks = vec![AlignChunk([0; 64]); (bytes.len() + 63) / 64];
            for (chunk, src) in chunks.iter_mut().zip(bytes.chunks(64)) {
                chunk.0[..src.len()].copy_from_slice(src);
            }
            Self {
                chunks,
                len: bytes.len(),
            }
        }

        fn as_bytes(&self) -> &[u8] {
            // SAFETY: `chunks` owns `chunks.len() * 64 >= len` initialised bytes
            // (`AlignChunk` is a plain byte array), and the slice borrows `self`.
            unsafe { core::slice::from_raw_parts(self.chunks.as_ptr().cast::<u8>(), self.len) }
        }
    }

    /// Parses `bytes` from aligned storage and returns the expected error.
    fn parse_err(bytes: &[u8]) -> FormatError {
        let data = AlignedBuf::new(bytes);
        EnsembleView::from_bytes(data.as_bytes()).unwrap_err()
    }

    fn header(n_trees: u16, n_features: u16, base_prediction: f32) -> EnsembleHeader {
        EnsembleHeader {
            magic: EnsembleHeader::MAGIC,
            version: EnsembleHeader::VERSION,
            n_trees,
            n_features,
            _reserved: 0,
            base_prediction,
        }
    }

    /// Serialises header, tree table and nodes back to back, as `from_bytes` expects.
    fn build_binary(header: &EnsembleHeader, entries: &[TreeEntry], nodes: &[PackedNode]) -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(header));
        for entry in entries {
            buf.extend_from_slice(as_bytes(entry));
        }
        for node in nodes {
            buf.extend_from_slice(as_bytes(node));
        }
        buf
    }

    fn build_single_leaf_binary(leaf_value: f32, base: f32) -> Vec<u8> {
        build_binary(
            &header(1, 1, base),
            &[TreeEntry {
                n_nodes: 1,
                offset: 0,
            }],
            &[PackedNode::leaf(leaf_value)],
        )
    }

    fn build_one_split_binary() -> Vec<u8> {
        build_binary(
            &header(1, 2, 0.0),
            &[TreeEntry {
                n_nodes: 3,
                offset: 0,
            }],
            &[
                PackedNode::split(5.0, 0, 1, 2),
                PackedNode::leaf(-1.0),
                PackedNode::leaf(1.0),
            ],
        )
    }

    fn build_two_tree_binary() -> Vec<u8> {
        build_binary(
            &header(2, 2, 1.0),
            &[
                TreeEntry {
                    n_nodes: 3,
                    offset: 0,
                },
                TreeEntry {
                    n_nodes: 1,
                    offset: 3 * size_of::<PackedNode>() as u32,
                },
            ],
            &[
                PackedNode::split(5.0, 0, 1, 2),
                PackedNode::leaf(-1.0),
                PackedNode::leaf(1.0),
                PackedNode::leaf(0.5),
            ],
        )
    }

    fn as_bytes<T: Sized>(val: &T) -> &[u8] {
        // SAFETY: `val` is a live reference to `size_of::<T>()` bytes. Assumes the
        // packed types have no padding bytes — TODO confirm in `crate::packed`.
        unsafe { core::slice::from_raw_parts(val as *const T as *const u8, size_of::<T>()) }
    }

    #[test]
    fn parse_single_leaf() {
        let data = AlignedBuf::new(&build_single_leaf_binary(42.0, 0.0));
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();
        assert_eq!(view.n_trees(), 1);
        assert_eq!(view.n_features(), 1);
        assert_eq!(view.total_nodes(), 1);
    }

    #[test]
    fn predict_single_leaf() {
        let data = AlignedBuf::new(&build_single_leaf_binary(42.0, 10.0));
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();
        let pred = view.predict(&[0.0]);
        assert!((pred - 52.0).abs() < 1e-6);
    }

    #[test]
    fn predict_one_split_left() {
        let data = AlignedBuf::new(&build_one_split_binary());
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();
        let pred = view.predict(&[3.0, 0.0]);
        assert!((pred - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn predict_one_split_right() {
        let data = AlignedBuf::new(&build_one_split_binary());
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();
        let pred = view.predict(&[7.0, 0.0]);
        assert!((pred - 1.0).abs() < 1e-6);
    }

    #[test]
    fn predict_two_trees() {
        let data = AlignedBuf::new(&build_two_tree_binary());
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();
        // base 1.0 + left leaf -1.0 + single leaf 0.5
        let pred = view.predict(&[3.0, 0.0]);
        assert!((pred - 0.5).abs() < 1e-6);
    }

    #[test]
    fn predict_batch_matches_single() {
        let data = AlignedBuf::new(&build_two_tree_binary());
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();

        // Five samples: one 4-wide batch plus a remainder of one.
        let samples: Vec<&[f32]> = vec![
            &[3.0, 0.0],
            &[7.0, 0.0],
            &[5.0, 0.0],
            &[0.0, 0.0],
            &[10.0, 0.0],
        ];
        let mut out = vec![0.0f32; 5];
        view.predict_batch(&samples, &mut out);

        for (i, &s) in samples.iter().enumerate() {
            let expected = view.predict(s);
            assert!(
                (out[i] - expected).abs() < 1e-6,
                "batch[{}] = {}, expected {}",
                i,
                out[i],
                expected
            );
        }
    }

    #[test]
    fn predict_batch_empty_is_noop() {
        let data = AlignedBuf::new(&build_single_leaf_binary(1.0, 0.0));
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();
        view.predict_batch(&[], &mut []);
    }

    #[test]
    #[should_panic]
    fn predict_batch_panics_on_short_output() {
        let data = AlignedBuf::new(&build_single_leaf_binary(1.0, 0.0));
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();
        let samples: [&[f32]; 2] = [&[0.0], &[0.0]];
        let mut out = [0.0f32; 1];
        view.predict_batch(&samples, &mut out);
    }

    #[test]
    fn bad_magic_is_rejected() {
        let mut buf = build_single_leaf_binary(0.0, 0.0);
        buf[0] = 0xFF;
        assert_eq!(parse_err(&buf), FormatError::BadMagic);
    }

    #[test]
    fn truncated_buffer_is_rejected() {
        let buf = build_single_leaf_binary(0.0, 0.0);
        assert_eq!(parse_err(&buf[..4]), FormatError::Truncated);
    }

    #[test]
    fn unaligned_buffer_is_rejected() {
        // Only meaningful when the header actually requires alignment.
        if core::mem::align_of::<EnsembleHeader>() > 1 {
            let mut shifted = vec![0u8];
            shifted.extend_from_slice(&build_single_leaf_binary(0.0, 0.0));
            let data = AlignedBuf::new(&shifted);
            assert_eq!(
                EnsembleView::from_bytes(&data.as_bytes()[1..]).unwrap_err(),
                FormatError::Unaligned
            );
        }
    }

    #[test]
    fn bad_version_is_rejected() {
        let mut buf = build_single_leaf_binary(0.0, 0.0);
        // Version occupies bytes 4..6 (little-endian).
        buf[4] = 99;
        buf[5] = 0;
        assert_eq!(parse_err(&buf), FormatError::UnsupportedVersion);
    }

    #[test]
    fn invalid_child_index_is_rejected() {
        let buf = build_binary(
            &header(1, 2, 0.0),
            &[TreeEntry {
                n_nodes: 3,
                offset: 0,
            }],
            &[
                // Right child 99 is outside the 3-node tree.
                PackedNode::split(5.0, 0, 1, 99),
                PackedNode::leaf(-1.0),
                PackedNode::leaf(1.0),
            ],
        );
        assert_eq!(parse_err(&buf), FormatError::InvalidNodeIndex);
    }

    #[test]
    fn invalid_feature_index_is_rejected() {
        let buf = build_binary(
            &header(1, 2, 0.0),
            &[TreeEntry {
                n_nodes: 3,
                offset: 0,
            }],
            &[
                // Feature 7 is not below n_features == 2.
                PackedNode::split(5.0, 7, 1, 2),
                PackedNode::leaf(-1.0),
                PackedNode::leaf(1.0),
            ],
        );
        assert_eq!(parse_err(&buf), FormatError::InvalidFeatureIndex);
    }

    #[test]
    fn misaligned_tree_offset_is_rejected() {
        let buf = build_binary(
            &header(1, 1, 0.0),
            &[TreeEntry {
                n_nodes: 1,
                offset: 1,
            }],
            &[PackedNode::leaf(0.0)],
        );
        assert_eq!(parse_err(&buf), FormatError::MisalignedTreeOffset);
    }

    #[test]
    fn tree_past_node_region_is_rejected() {
        // One node is stored, but the tree claims to start right after it.
        let buf = build_binary(
            &header(1, 1, 0.0),
            &[TreeEntry {
                n_nodes: 1,
                offset: size_of::<PackedNode>() as u32,
            }],
            &[PackedNode::leaf(0.0)],
        );
        assert_eq!(parse_err(&buf), FormatError::Truncated);
    }

    #[test]
    fn debug_format() {
        let data = AlignedBuf::new(&build_single_leaf_binary(0.0, 0.0));
        let view = EnsembleView::from_bytes(data.as_bytes()).unwrap();
        let debug = format!("{:?}", view);
        assert!(debug.contains("EnsembleView"));
        assert!(debug.contains("n_trees"));
    }
}