1use crate::error::FormatError;
13use crate::packed::TreeEntry;
14use crate::packed_i16::{PackedNodeI16, QuantizedEnsembleHeader};
15use crate::traverse_i16;
16
/// Zero-copy view over an i16-quantized tree-ensemble binary blob.
///
/// Every field borrows directly from the byte buffer handed to
/// `QuantizedEnsembleView::from_bytes`; nothing is copied, so the view is
/// `Copy` and as cheap to pass around as a couple of slices.
#[derive(Clone, Copy)]
pub struct QuantizedEnsembleView<'a> {
    // Fixed-size header at the start of the buffer (magic, version, counts,
    // base prediction).
    header: &'a QuantizedEnsembleHeader,
    // Divisor applied to the summed integer leaf values to recover the f32
    // prediction (stored right after the header).
    leaf_scale: f32,
    // Per-feature quantization scales; length == header.n_features.
    feature_scales: &'a [f32],
    // One entry per tree: node count plus byte offset into the node array.
    tree_table: &'a [TreeEntry],
    // All trees' nodes, concatenated; each tree indexes into its own range.
    nodes: &'a [PackedNodeI16],
}
35
36impl<'a> QuantizedEnsembleView<'a> {
37 pub fn from_bytes(data: &'a [u8]) -> Result<Self, FormatError> {
62 use core::mem::{align_of, size_of};
63
64 let header_size = size_of::<QuantizedEnsembleHeader>();
65 if data.len() < header_size {
66 return Err(FormatError::Truncated);
67 }
68
69 if (data.as_ptr() as usize) % align_of::<QuantizedEnsembleHeader>() != 0 {
71 return Err(FormatError::Unaligned);
72 }
73
74 let header = unsafe { &*(data.as_ptr() as *const QuantizedEnsembleHeader) };
76
77 if header.magic != QuantizedEnsembleHeader::MAGIC {
78 return Err(FormatError::BadMagic);
79 }
80 if header.version != QuantizedEnsembleHeader::VERSION {
81 return Err(FormatError::UnsupportedVersion);
82 }
83
84 let n_trees = header.n_trees as usize;
85 let n_features = header.n_features as usize;
86
87 let leaf_scale_offset = header_size;
89 let leaf_scale_size = size_of::<f32>();
90 let feature_scales_offset = leaf_scale_offset + leaf_scale_size;
91 let feature_scales_size = n_features
92 .checked_mul(size_of::<f32>())
93 .ok_or(FormatError::Truncated)?;
94
95 let tree_table_offset = feature_scales_offset
96 .checked_add(feature_scales_size)
97 .ok_or(FormatError::Truncated)?;
98 let tree_table_size = n_trees
99 .checked_mul(size_of::<TreeEntry>())
100 .ok_or(FormatError::Truncated)?;
101
102 let nodes_base_offset = tree_table_offset
103 .checked_add(tree_table_size)
104 .ok_or(FormatError::Truncated)?;
105
106 if data.len() < nodes_base_offset {
108 return Err(FormatError::Truncated);
109 }
110
111 let leaf_scale_ptr = unsafe { data.as_ptr().add(leaf_scale_offset) } as *const f32;
115 let leaf_scale = unsafe { *leaf_scale_ptr };
116
117 let feature_scales_ptr = unsafe { data.as_ptr().add(feature_scales_offset) } as *const f32;
120 let feature_scales = unsafe { core::slice::from_raw_parts(feature_scales_ptr, n_features) };
121
122 let tree_table_ptr = unsafe { data.as_ptr().add(tree_table_offset) } as *const TreeEntry;
127 let tree_table = unsafe { core::slice::from_raw_parts(tree_table_ptr, n_trees) };
128
129 let mut total_nodes: usize = 0;
131 for entry in tree_table {
132 total_nodes = total_nodes
133 .checked_add(entry.n_nodes as usize)
134 .ok_or(FormatError::Truncated)?;
135 }
136
137 let nodes_size = total_nodes
138 .checked_mul(size_of::<PackedNodeI16>())
139 .ok_or(FormatError::Truncated)?;
140 let total_required = nodes_base_offset
141 .checked_add(nodes_size)
142 .ok_or(FormatError::Truncated)?;
143 if data.len() < total_required {
144 return Err(FormatError::Truncated);
145 }
146
147 for entry in tree_table {
149 let node_byte_offset = entry.offset as usize;
150 if node_byte_offset % size_of::<PackedNodeI16>() != 0 {
151 return Err(FormatError::MisalignedTreeOffset);
152 }
153 let tree_bytes = (entry.n_nodes as usize)
154 .checked_mul(size_of::<PackedNodeI16>())
155 .ok_or(FormatError::Truncated)?;
156 let tree_end = node_byte_offset
157 .checked_add(tree_bytes)
158 .ok_or(FormatError::Truncated)?;
159 if tree_end > nodes_size {
160 return Err(FormatError::Truncated);
161 }
162 }
163
164 let nodes_ptr = unsafe { data.as_ptr().add(nodes_base_offset) } as *const PackedNodeI16;
167 let nodes = unsafe { core::slice::from_raw_parts(nodes_ptr, total_nodes) };
168
169 for entry in tree_table {
171 let tree_node_offset = entry.offset as usize / size_of::<PackedNodeI16>();
172 let tree_n_nodes = entry.n_nodes as usize;
173
174 for local_idx in 0..tree_n_nodes {
175 let global_idx = tree_node_offset + local_idx;
176 let node = &nodes[global_idx];
177
178 if !node.is_leaf() {
179 let left = node.left_child() as usize;
180 let right = node.right_child() as usize;
181
182 if left >= tree_n_nodes || right >= tree_n_nodes {
184 return Err(FormatError::InvalidNodeIndex);
185 }
186
187 if n_features > 0 && node.feature_idx() as usize >= n_features {
188 return Err(FormatError::InvalidFeatureIndex);
189 }
190 }
191 }
192 }
193
194 Ok(Self {
195 header,
196 leaf_scale,
197 feature_scales,
198 tree_table,
199 nodes,
200 })
201 }
202
203 pub fn predict(&self, features: &[f32]) -> f32 {
218 debug_assert!(
219 features.len() >= self.header.n_features as usize,
220 "predict: features.len() ({}) < n_features ({})",
221 features.len(),
222 self.header.n_features
223 );
224
225 let mut leaf_sum: i32 = 0;
226 for entry in self.tree_table {
227 let start = entry.offset as usize / core::mem::size_of::<PackedNodeI16>();
228 let end = start + entry.n_nodes as usize;
229 let tree_nodes = &self.nodes[start..end];
230 leaf_sum +=
231 traverse_i16::predict_tree_i16_inline(tree_nodes, features, self.feature_scales)
232 as i32;
233 }
234
235 self.header.base_prediction + (leaf_sum as f32) / self.leaf_scale
236 }
237
238 pub fn predict_prequantized(&self, features_i16: &[i16]) -> f32 {
252 debug_assert!(
253 features_i16.len() >= self.header.n_features as usize,
254 "predict_prequantized: features_i16.len() ({}) < n_features ({})",
255 features_i16.len(),
256 self.header.n_features
257 );
258
259 let mut leaf_sum: i32 = 0;
260 for entry in self.tree_table {
261 let start = entry.offset as usize / core::mem::size_of::<PackedNodeI16>();
262 let end = start + entry.n_nodes as usize;
263 let tree_nodes = &self.nodes[start..end];
264 leaf_sum += traverse_i16::predict_tree_i16(tree_nodes, features_i16) as i32;
265 }
266
267 self.header.base_prediction + (leaf_sum as f32) / self.leaf_scale
268 }
269
270 pub fn predict_batch(&self, samples: &[&[f32]], out: &mut [f32]) {
280 assert!(out.len() >= samples.len());
281
282 for (i, &s) in samples.iter().enumerate() {
286 out[i] = self.predict(s);
287 }
288 }
289
290 pub fn predict_batch_prequantized(&self, samples: &[&[i16]], out: &mut [f32]) {
300 assert!(out.len() >= samples.len());
301
302 let n = samples.len();
303 let mut i = 0;
304
305 while i + 4 <= n {
307 let batch = [samples[i], samples[i + 1], samples[i + 2], samples[i + 3]];
308 let mut sums = [0i32; 4];
309 for entry in self.tree_table {
310 let start = entry.offset as usize / core::mem::size_of::<PackedNodeI16>();
311 let end = start + entry.n_nodes as usize;
312 let tree_nodes = &self.nodes[start..end];
313 let preds = traverse_i16::predict_tree_i16_x4(tree_nodes, batch);
314 for j in 0..4 {
315 sums[j] += preds[j] as i32;
316 }
317 }
318 for j in 0..4 {
319 out[i + j] = self.header.base_prediction + (sums[j] as f32) / self.leaf_scale;
320 }
321 i += 4;
322 }
323
324 while i < n {
326 out[i] = self.predict_prequantized(samples[i]);
327 i += 1;
328 }
329 }
330
331 #[inline]
333 pub fn n_trees(&self) -> u16 {
334 self.header.n_trees
335 }
336
337 #[inline]
339 pub fn n_features(&self) -> u16 {
340 self.header.n_features
341 }
342
343 #[inline]
345 pub fn base_prediction(&self) -> f32 {
346 self.header.base_prediction
347 }
348
349 #[inline]
351 pub fn leaf_scale(&self) -> f32 {
352 self.leaf_scale
353 }
354
355 #[inline]
357 pub fn feature_scales(&self) -> &[f32] {
358 self.feature_scales
359 }
360
361 #[inline]
363 pub fn total_nodes(&self) -> usize {
364 self.nodes.len()
365 }
366}
367
368impl<'a> core::fmt::Debug for QuantizedEnsembleView<'a> {
369 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
370 f.debug_struct("QuantizedEnsembleView")
371 .field("n_trees", &self.n_trees())
372 .field("n_features", &self.n_features())
373 .field("base_prediction", &self.base_prediction())
374 .field("leaf_scale", &self.leaf_scale())
375 .field("total_nodes", &self.total_nodes())
376 .finish()
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed::TreeEntry;
    use crate::packed_i16::{PackedNodeI16, QuantizedEnsembleHeader};
    use alloc::{format, vec, vec::Vec};
    use core::mem::size_of;

    /// Reinterprets any sized value as its raw bytes (native endianness),
    /// used to serialize header/entry/node structs into a test buffer.
    fn as_bytes<T: Sized>(val: &T) -> &[u8] {
        // SAFETY: `val` is a valid reference and exactly size_of::<T>() bytes
        // starting at it are readable.
        unsafe { core::slice::from_raw_parts(val as *const T as *const u8, size_of::<T>()) }
    }

    /// Builds a minimal valid blob: one tree consisting of a single leaf,
    /// one feature with scale 1.0.
    fn build_single_leaf_binary(leaf_value: i16, base: f32, leaf_scale: f32) -> Vec<u8> {
        let header = QuantizedEnsembleHeader {
            magic: QuantizedEnsembleHeader::MAGIC,
            version: QuantizedEnsembleHeader::VERSION,
            n_trees: 1,
            n_features: 1,
            _reserved: 0,
            base_prediction: base,
        };
        let feature_scale: f32 = 1.0;
        let entry = TreeEntry {
            n_nodes: 1,
            offset: 0,
        };
        let node = PackedNodeI16::leaf(leaf_value);

        // Append the sections in the exact on-disk order from_bytes expects:
        // header, leaf_scale, feature scales, tree table, nodes.
        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(&header));
        buf.extend_from_slice(as_bytes(&leaf_scale));
        buf.extend_from_slice(as_bytes(&feature_scale));
        buf.extend_from_slice(as_bytes(&entry));
        buf.extend_from_slice(as_bytes(&node));
        buf
    }

    /// Builds a blob with one tree: a root split on feature 0 (quantized
    /// threshold 500, i.e. 5.0 at scale 100) and two leaves (-100 / +100).
    fn build_one_split_binary() -> Vec<u8> {
        let header = QuantizedEnsembleHeader {
            magic: QuantizedEnsembleHeader::MAGIC,
            version: QuantizedEnsembleHeader::VERSION,
            n_trees: 1,
            n_features: 2,
            _reserved: 0,
            base_prediction: 0.0,
        };
        let leaf_scale: f32 = 100.0;
        let feature_scales: [f32; 2] = [100.0, 100.0];
        let entry = TreeEntry {
            n_nodes: 3,
            offset: 0,
        };
        let nodes = [
            // Root split: feature 0, threshold 500, children at 1 and 2.
            PackedNodeI16::split(500, 0, 1, 2),
            PackedNodeI16::leaf(-100),
            PackedNodeI16::leaf(100),
        ];

        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(&header));
        buf.extend_from_slice(as_bytes(&leaf_scale));
        for s in &feature_scales {
            buf.extend_from_slice(as_bytes(s));
        }
        buf.extend_from_slice(as_bytes(&entry));
        for n in &nodes {
            buf.extend_from_slice(as_bytes(n));
        }
        buf
    }

    /// Builds a blob with two trees: the one-split tree above plus a
    /// single-leaf tree (value 50), base prediction 1.0.
    fn build_two_tree_binary() -> Vec<u8> {
        let header = QuantizedEnsembleHeader {
            magic: QuantizedEnsembleHeader::MAGIC,
            version: QuantizedEnsembleHeader::VERSION,
            n_trees: 2,
            n_features: 2,
            _reserved: 0,
            base_prediction: 1.0,
        };
        let leaf_scale: f32 = 100.0;
        let feature_scales: [f32; 2] = [100.0, 100.0];
        let entries = [
            TreeEntry {
                n_nodes: 3,
                offset: 0,
            },
            // Second tree starts right after the first tree's 3 nodes.
            TreeEntry {
                n_nodes: 1,
                offset: 3 * size_of::<PackedNodeI16>() as u32,
            },
        ];
        let nodes = [
            PackedNodeI16::split(500, 0, 1, 2),
            PackedNodeI16::leaf(-100),
            PackedNodeI16::leaf(100),
            // Tree 2: a single leaf worth 50.
            PackedNodeI16::leaf(50),
        ];

        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(&header));
        buf.extend_from_slice(as_bytes(&leaf_scale));
        for s in &feature_scales {
            buf.extend_from_slice(as_bytes(s));
        }
        for e in &entries {
            buf.extend_from_slice(as_bytes(e));
        }
        for n in &nodes {
            buf.extend_from_slice(as_bytes(n));
        }
        buf
    }

    #[test]
    fn parse_single_leaf_i16() {
        let buf = build_single_leaf_binary(42, 0.0, 100.0);
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        assert_eq!(view.n_trees(), 1);
        assert_eq!(view.n_features(), 1);
        assert_eq!(view.total_nodes(), 1);
        assert_eq!(view.leaf_scale(), 100.0);
    }

    #[test]
    fn predict_single_leaf_i16() {
        // base 10.0 + leaf 42 / leaf_scale 100.0 = 10.42
        let buf = build_single_leaf_binary(42, 10.0, 100.0);
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        let pred = view.predict(&[0.0]);
        assert!((pred - 10.42).abs() < 1e-5, "expected 10.42, got {}", pred);
    }

    #[test]
    fn predict_one_split_left_i16() {
        let buf = build_one_split_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        // 3.0 quantizes to 300 < threshold 500 -> left leaf -100 / 100 = -1.0
        let pred = view.predict(&[3.0, 0.0]);
        assert!((pred - (-1.0)).abs() < 1e-5, "expected -1.0, got {}", pred);
    }

    #[test]
    fn predict_one_split_right_i16() {
        let buf = build_one_split_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        // 7.0 quantizes to 700 >= threshold 500 -> right leaf 100 / 100 = 1.0
        let pred = view.predict(&[7.0, 0.0]);
        assert!((pred - 1.0).abs() < 1e-5, "expected 1.0, got {}", pred);
    }

    #[test]
    fn predict_two_trees_i16() {
        let buf = build_two_tree_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        // base 1.0 + (tree1 left -100 + tree2 leaf 50) / 100 = 0.5
        let pred = view.predict(&[3.0, 0.0]);
        assert!((pred - 0.5).abs() < 1e-5, "expected 0.5, got {}", pred);
    }

    #[test]
    fn predict_prequantized_matches_predict() {
        let buf = build_one_split_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();

        // 3.0 at feature scale 100 pre-quantizes to 300.
        let pred_inline = view.predict(&[3.0, 0.0]);
        let pred_preq = view.predict_prequantized(&[300, 0]);
        assert!(
            (pred_inline - pred_preq).abs() < 1e-5,
            "left: inline={}, prequantized={}",
            pred_inline,
            pred_preq
        );

        // 7.0 at feature scale 100 pre-quantizes to 700.
        let pred_inline = view.predict(&[7.0, 0.0]);
        let pred_preq = view.predict_prequantized(&[700, 0]);
        assert!(
            (pred_inline - pred_preq).abs() < 1e-5,
            "right: inline={}, prequantized={}",
            pred_inline,
            pred_preq
        );
    }

    #[test]
    fn bad_magic_rejected_i16() {
        let mut buf = build_single_leaf_binary(0, 0.0, 100.0);
        // Corrupt the first magic byte.
        buf[0] = 0xFF;
        assert_eq!(
            QuantizedEnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::BadMagic
        );
    }

    #[test]
    fn truncated_rejected_i16() {
        let buf = build_single_leaf_binary(0, 0.0, 100.0);
        // Four bytes cannot hold even the header.
        assert_eq!(
            QuantizedEnsembleView::from_bytes(&buf[..4]).unwrap_err(),
            FormatError::Truncated
        );
    }

    #[test]
    fn debug_format_i16() {
        let buf = build_single_leaf_binary(0, 0.0, 100.0);
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        let debug = format!("{:?}", view);
        assert!(
            debug.contains("QuantizedEnsembleView"),
            "missing struct name in debug: {}",
            debug
        );
        assert!(
            debug.contains("n_trees"),
            "missing n_trees in debug: {}",
            debug
        );
        assert!(
            debug.contains("leaf_scale"),
            "missing leaf_scale in debug: {}",
            debug
        );
    }

    #[test]
    fn predict_batch_matches_single_i16() {
        let buf = build_two_tree_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();

        // Five samples so both the batched and tail code paths are exercised.
        let samples: Vec<&[f32]> = vec![
            &[3.0, 0.0],
            &[7.0, 0.0],
            &[5.0, 0.0],
            &[0.0, 0.0],
            &[10.0, 0.0],
        ];
        let mut out = vec![0.0f32; 5];
        view.predict_batch(&samples, &mut out);

        for (i, &s) in samples.iter().enumerate() {
            let expected = view.predict(s);
            assert!(
                (out[i] - expected).abs() < 1e-6,
                "batch[{}] = {}, expected {}",
                i,
                out[i],
                expected
            );
        }
    }

    #[test]
    fn bad_version_rejected_i16() {
        let mut buf = build_single_leaf_binary(0, 0.0, 100.0);
        // NOTE: assumes `version` is a u16 at byte offset 4 of the header —
        // overwrite both bytes with an unsupported value.
        buf[4] = 99;
        buf[5] = 0;
        assert_eq!(
            QuantizedEnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::UnsupportedVersion
        );
    }

    #[test]
    fn invalid_child_index_rejected_i16() {
        let header = QuantizedEnsembleHeader {
            magic: QuantizedEnsembleHeader::MAGIC,
            version: QuantizedEnsembleHeader::VERSION,
            n_trees: 1,
            n_features: 2,
            _reserved: 0,
            base_prediction: 0.0,
        };
        let leaf_scale: f32 = 100.0;
        let feature_scales: [f32; 2] = [100.0, 100.0];
        let entry = TreeEntry {
            n_nodes: 3,
            offset: 0,
        };
        let nodes = [
            // Right child index 99 is out of range for a 3-node tree.
            PackedNodeI16::split(500, 0, 1, 99),
            PackedNodeI16::leaf(-100),
            PackedNodeI16::leaf(100),
        ];

        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(&header));
        buf.extend_from_slice(as_bytes(&leaf_scale));
        for s in &feature_scales {
            buf.extend_from_slice(as_bytes(s));
        }
        buf.extend_from_slice(as_bytes(&entry));
        for n in &nodes {
            buf.extend_from_slice(as_bytes(n));
        }

        assert_eq!(
            QuantizedEnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::InvalidNodeIndex
        );
    }
}