1use forestfire_data::{
2 BinnedColumnKind, MAX_NUMERIC_BINS, NumericBins, TableAccess, numeric_bin_boundaries,
3};
4#[cfg(feature = "polars")]
5use polars::prelude::{Column, DataFrame, DataType, IdxSize, LazyFrame};
6use rayon::ThreadPoolBuilder;
7use rayon::prelude::*;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11use std::error::Error;
12use std::fmt::{Display, Formatter};
13use std::sync::Arc;
14use wide::{u16x8, u32x8};
15
16mod boosting;
17mod bootstrap;
18mod forest;
19pub mod ir;
20mod sampling;
21mod training;
22pub mod tree;
23
24pub use boosting::BoostingError;
25pub use boosting::GradientBoostedTrees;
26pub use forest::RandomForest;
27pub use ir::IrError;
28pub use ir::ModelPackageIr;
29pub use tree::classifier::DecisionTreeAlgorithm;
30pub use tree::classifier::DecisionTreeClassifier;
31pub use tree::classifier::DecisionTreeError;
32pub use tree::classifier::DecisionTreeOptions;
33pub use tree::classifier::train_c45;
34pub use tree::classifier::train_cart;
35pub use tree::classifier::train_id3;
36pub use tree::classifier::train_oblivious;
37pub use tree::classifier::train_randomized;
38pub use tree::regressor::DecisionTreeRegressor;
39pub use tree::regressor::RegressionTreeAlgorithm;
40pub use tree::regressor::RegressionTreeError;
41pub use tree::regressor::RegressionTreeOptions;
42pub use tree::regressor::train_cart_regressor;
43pub use tree::regressor::train_oblivious_regressor;
44pub use tree::regressor::train_randomized_regressor;
45
// Minimum row count before per-row inference is spread across worker threads.
const PARALLEL_INFERENCE_ROW_THRESHOLD: usize = 256;
// Rows handed to each worker chunk during parallel inference.
const PARALLEL_INFERENCE_CHUNK_ROWS: usize = 256;
// Chunk size used by the batched standard-tree inference path.
const STANDARD_BATCH_INFERENCE_CHUNK_ROWS: usize = 4096;
// Lane width for oblivious-tree SIMD kernels; matches the 8-lane
// `u16x8`/`u32x8` vectors imported from `wide`.
const OBLIVIOUS_SIMD_LANES: usize = 8;
// Rows collected per LazyFrame slice during streaming prediction.
#[cfg(feature = "polars")]
const LAZYFRAME_PREDICT_BATCH_ROWS: usize = 10_000;
// Magic bytes identifying a compiled-artifact blob ("FFCA").
const COMPILED_ARTIFACT_MAGIC: [u8; 4] = *b"FFCA";
// Compiled-artifact format version accepted by the decoder.
const COMPILED_ARTIFACT_VERSION: u16 = 1;
// Backend tag stored in artifacts produced for the CPU runtime.
const COMPILED_ARTIFACT_BACKEND_CPU: u16 = 1;
// Header length in bytes — presumably magic (4) + version (2) + backend (2).
const COMPILED_ARTIFACT_HEADER_LEN: usize = 8;
56
/// Top-level model family a training run produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainAlgorithm {
    /// Single decision tree.
    Dt,
    /// Random forest.
    Rf,
    /// Gradient-boosted machine.
    Gbm,
}
63
/// Split-quality criterion requested by the caller.
///
/// `Auto` lets the trainer pick a default appropriate for the task; the
/// remaining variants force a specific impurity/loss measure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Criterion {
    Auto,
    Gini,
    Entropy,
    Mean,
    Median,
    SecondOrder,
}
73
/// Prediction task the model is trained for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    Regression,
    Classification,
}

/// Tree-construction family used by the trainers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeType {
    Id3,
    C45,
    Cart,
    Randomized,
    Oblivious,
}

/// Policy for how many candidate features each split may consider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxFeatures {
    Auto,
    All,
    Sqrt,
    Third,
    Count(usize),
}

impl MaxFeatures {
    /// Resolves the policy to a concrete feature count, never below 1.
    ///
    /// `Auto` delegates to `Sqrt` for classification and `Third` for
    /// regression; `Count` is clamped to the available feature count.
    pub fn resolve(self, task: Task, feature_count: usize) -> usize {
        let chosen = match self {
            MaxFeatures::Auto => {
                let delegate = match task {
                    Task::Classification => MaxFeatures::Sqrt,
                    Task::Regression => MaxFeatures::Third,
                };
                return delegate.resolve(task, feature_count);
            }
            MaxFeatures::All => feature_count,
            // Truncating cast equals floor() for the non-negative sqrt result.
            MaxFeatures::Sqrt => (feature_count as f64).sqrt() as usize,
            MaxFeatures::Third => feature_count / 3,
            MaxFeatures::Count(count) => count.min(feature_count),
        };
        chosen.max(1)
    }
}
112
/// Kind of value a model input feature carries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputFeatureKind {
    /// Continuous value, discretized through trained bin boundaries.
    Numeric,
    /// Strict 0/1 value.
    Binary,
}
119
/// One entry of a numeric feature's bin table: values up to `upper_bound`
/// (inclusive) map to bin id `bin`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinBoundary {
    /// Bin identifier assigned at training time.
    pub bin: u16,
    /// Inclusive upper bound of raw values belonging to this bin.
    pub upper_bound: f64,
}
125
/// Per-feature preprocessing recorded at training time and replayed at
/// inference time to reproduce the trained binning.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeaturePreprocessing {
    /// Numeric feature with its ordered bin-boundary table.
    Numeric {
        bin_boundaries: Vec<NumericBinBoundary>,
    },
    /// Binary feature; values must be exactly 0.0 or 1.0.
    Binary,
}
134
/// Caller-facing configuration for a training run.
///
/// `None`/`Auto` fields are resolved to algorithm-specific defaults by the
/// trainers; not every field applies to every algorithm (e.g. the gradient
/// fractions are boosting-only knobs).
#[derive(Debug, Clone, Copy)]
pub struct TrainConfig {
    /// Model family to train (single tree, forest, or boosting).
    pub algorithm: TrainAlgorithm,
    /// Regression or classification.
    pub task: Task,
    /// Tree-construction variant.
    pub tree_type: TreeType,
    /// Split criterion; `Auto` lets the trainer choose.
    pub criterion: Criterion,
    /// Optional depth cap (must be >= 1 when set).
    pub max_depth: Option<usize>,
    /// Minimum samples required to attempt a split.
    pub min_samples_split: Option<usize>,
    /// Minimum samples required in a leaf.
    pub min_samples_leaf: Option<usize>,
    /// Worker thread cap; `None` uses a default.
    pub physical_cores: Option<usize>,
    /// Ensemble size for forests/boosting.
    pub n_trees: Option<usize>,
    /// Per-split feature subsampling policy.
    pub max_features: MaxFeatures,
    /// RNG seed for reproducibility.
    pub seed: Option<u64>,
    /// Whether to compute out-of-bag estimates.
    pub compute_oob: bool,
    /// Boosting shrinkage factor.
    pub learning_rate: Option<f64>,
    /// Whether forests draw bootstrap samples.
    pub bootstrap: bool,
    /// Gradient-based sampling: fraction of top-gradient rows kept.
    pub top_gradient_fraction: Option<f64>,
    /// Gradient-based sampling: fraction of the remaining rows kept.
    pub other_gradient_fraction: Option<f64>,
}
154
impl Default for TrainConfig {
    /// Baseline configuration: a single CART regression tree with automatic
    /// criterion/feature selection and every optional knob left unset.
    fn default() -> Self {
        Self {
            algorithm: TrainAlgorithm::Dt,
            task: Task::Regression,
            tree_type: TreeType::Cart,
            criterion: Criterion::Auto,
            max_depth: None,
            min_samples_split: None,
            min_samples_leaf: None,
            physical_cores: None,
            n_trees: None,
            max_features: MaxFeatures::Auto,
            seed: None,
            compute_oob: false,
            learning_rate: None,
            bootstrap: false,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        }
    }
}
177
/// Any trained model this crate can hold, dispatched over at inference time.
#[derive(Debug, Clone)]
pub enum Model {
    DecisionTreeClassifier(DecisionTreeClassifier),
    DecisionTreeRegressor(DecisionTreeRegressor),
    RandomForest(RandomForest),
    GradientBoostedTrees(GradientBoostedTrees),
}
185
/// Errors surfaced while validating a [`TrainConfig`] or running a trainer.
#[derive(Debug)]
pub enum TrainError {
    /// Wrapped classifier-tree error.
    DecisionTree(DecisionTreeError),
    /// Wrapped regression-tree error.
    RegressionTree(RegressionTreeError),
    /// Wrapped boosting error.
    Boosting(BoostingError),
    /// More cores requested than the machine exposes.
    InvalidPhysicalCoreCount {
        requested: usize,
        available: usize,
    },
    /// Rayon pool construction failed; message from the builder.
    ThreadPoolBuildFailed(String),
    /// The (task, tree_type, criterion) triple has no trainer.
    UnsupportedConfiguration {
        task: Task,
        tree_type: TreeType,
        criterion: Criterion,
    },
    // The following carry the rejected value for the message.
    InvalidMaxDepth(usize),
    InvalidMinSamplesSplit(usize),
    InvalidMinSamplesLeaf(usize),
    InvalidTreeCount(usize),
    InvalidMaxFeatures(usize),
}
207
208impl Display for TrainError {
209 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
210 match self {
211 TrainError::DecisionTree(err) => err.fmt(f),
212 TrainError::RegressionTree(err) => err.fmt(f),
213 TrainError::Boosting(err) => err.fmt(f),
214 TrainError::InvalidPhysicalCoreCount {
215 requested,
216 available,
217 } => write!(
218 f,
219 "Requested {} physical cores, but the available physical core count is {}.",
220 requested, available
221 ),
222 TrainError::ThreadPoolBuildFailed(message) => {
223 write!(f, "Failed to build training thread pool: {}.", message)
224 }
225 TrainError::UnsupportedConfiguration {
226 task,
227 tree_type,
228 criterion,
229 } => write!(
230 f,
231 "Unsupported training configuration: task={:?}, tree_type={:?}, criterion={:?}.",
232 task, tree_type, criterion
233 ),
234 TrainError::InvalidMaxDepth(value) => {
235 write!(f, "max_depth must be at least 1. Received {}.", value)
236 }
237 TrainError::InvalidMinSamplesSplit(value) => {
238 write!(
239 f,
240 "min_samples_split must be at least 1. Received {}.",
241 value
242 )
243 }
244 TrainError::InvalidMinSamplesLeaf(value) => {
245 write!(
246 f,
247 "min_samples_leaf must be at least 1. Received {}.",
248 value
249 )
250 }
251 TrainError::InvalidTreeCount(n_trees) => {
252 write!(
253 f,
254 "Random forest requires at least one tree. Received {}.",
255 n_trees
256 )
257 }
258 TrainError::InvalidMaxFeatures(count) => {
259 write!(
260 f,
261 "max_features must be at least 1 when provided as an integer. Received {}.",
262 count
263 )
264 }
265 }
266 }
267}
268
269impl Error for TrainError {}
270
/// Errors surfaced while validating or running inference input.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictError {
    /// `predict_proba` called on a non-classification model.
    ProbabilityPredictionRequiresClassification,
    /// A row's length differs from the trained feature count.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// Input column count differs from the trained feature count.
    FeatureCountMismatch {
        expected: usize,
        actual: usize,
    },
    /// A named column's length differs from the other columns.
    ColumnLengthMismatch {
        feature: String,
        expected: usize,
        actual: usize,
    },
    /// A required named feature (`f0`, `f1`, ...) was absent.
    MissingFeature(String),
    /// An input column does not belong to the trained schema.
    UnexpectedFeature(String),
    /// A binary feature received a value other than 0.0/1.0.
    InvalidBinaryValue {
        feature_index: usize,
        row_index: usize,
        value: f64,
    },
    /// A null was found where a value is required (polars path).
    NullValue {
        feature: String,
        row_index: usize,
    },
    /// A column dtype cannot be converted for inference (polars path).
    UnsupportedFeatureType {
        feature: String,
        dtype: String,
    },
    /// Polars-level failure, flattened to its display string.
    Polars(String),
}
305
/// Shape statistics describing a single trained tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeStructureSummary {
    /// Tree storage form (e.g. node-tree vs oblivious levels).
    pub representation: String,
    pub node_count: usize,
    pub internal_node_count: usize,
    pub leaf_count: usize,
    /// Depth actually reached (may be below any configured cap).
    pub actual_depth: usize,
    /// Shortest root-to-leaf path length.
    pub shortest_path: usize,
    /// Longest root-to-leaf path length.
    pub longest_path: usize,
    /// Mean root-to-leaf path length.
    pub average_path: f64,
}
317
/// Summary statistics over a batch of predicted values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionValueStats {
    pub count: usize,
    /// Number of distinct predicted values.
    pub unique_count: usize,
    pub min: f64,
    pub max: f64,
    pub mean: f64,
    pub std_dev: f64,
    /// Per-distinct-value counts (see [`PredictionHistogramEntry`]).
    pub histogram: Vec<PredictionHistogramEntry>,
}
328
/// One histogram bucket: a predicted value and how often it occurred.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionHistogramEntry {
    pub prediction: f64,
    pub count: usize,
}
334
/// Errors raised while inspecting trees, nodes, levels, or leaves of a model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntrospectionError {
    TreeIndexOutOfBounds { requested: usize, available: usize },
    NodeIndexOutOfBounds { requested: usize, available: usize },
    LevelIndexOutOfBounds { requested: usize, available: usize },
    LeafIndexOutOfBounds { requested: usize, available: usize },
    /// Node-level access attempted on an oblivious-level tree.
    NotANodeTree,
    /// Level/leaf access attempted on a node-representation tree.
    NotAnObliviousTree,
}

impl Display for IntrospectionError {
    /// Renders a human-readable description of the introspection failure.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            IntrospectionError::TreeIndexOutOfBounds { requested, available } => write!(
                f,
                "Tree index {requested} is out of bounds for model with {available} trees."
            ),
            IntrospectionError::NodeIndexOutOfBounds { requested, available } => write!(
                f,
                "Node index {requested} is out of bounds for tree with {available} nodes."
            ),
            IntrospectionError::LevelIndexOutOfBounds { requested, available } => write!(
                f,
                "Level index {requested} is out of bounds for tree with {available} levels."
            ),
            IntrospectionError::LeafIndexOutOfBounds { requested, available } => write!(
                f,
                "Leaf index {requested} is out of bounds for tree with {available} leaves."
            ),
            IntrospectionError::NotANodeTree => write!(
                f,
                "This tree uses oblivious-level representation; inspect levels or leaves instead."
            ),
            IntrospectionError::NotAnObliviousTree => write!(
                f,
                "This tree uses node-tree representation; inspect nodes instead."
            ),
        }
    }
}

impl Error for IntrospectionError {}
393
394impl Display for PredictError {
395 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
396 match self {
397 PredictError::ProbabilityPredictionRequiresClassification => write!(
398 f,
399 "predict_proba is only available for classification models."
400 ),
401 PredictError::RaggedRows {
402 row,
403 expected,
404 actual,
405 } => write!(
406 f,
407 "Ragged inference row at index {}: expected {} columns, found {}.",
408 row, expected, actual
409 ),
410 PredictError::FeatureCountMismatch { expected, actual } => write!(
411 f,
412 "Inference input has {} features, but the model expects {}.",
413 actual, expected
414 ),
415 PredictError::ColumnLengthMismatch {
416 feature,
417 expected,
418 actual,
419 } => write!(
420 f,
421 "Feature '{}' has {} values, expected {}.",
422 feature, actual, expected
423 ),
424 PredictError::MissingFeature(feature) => {
425 write!(f, "Missing required feature '{}'.", feature)
426 }
427 PredictError::UnexpectedFeature(feature) => {
428 write!(f, "Unexpected feature '{}'.", feature)
429 }
430 PredictError::InvalidBinaryValue {
431 feature_index,
432 row_index,
433 value,
434 } => write!(
435 f,
436 "Feature {} at row {} must be binary for inference, found {}.",
437 feature_index, row_index, value
438 ),
439 PredictError::NullValue { feature, row_index } => write!(
440 f,
441 "Feature '{}' contains a null value at row {}.",
442 feature, row_index
443 ),
444 PredictError::UnsupportedFeatureType { feature, dtype } => write!(
445 f,
446 "Feature '{}' has unsupported dtype '{}'.",
447 feature, dtype
448 ),
449 PredictError::Polars(message) => write!(f, "Polars inference failed: {}.", message),
450 }
451 }
452}
453
454impl Error for PredictError {}
455
#[cfg(feature = "polars")]
impl From<polars::error::PolarsError> for PredictError {
    /// Flattens a Polars error into its display string so `PredictError`
    /// stays cloneable and comparable without holding a Polars type.
    fn from(value: polars::error::PolarsError) -> Self {
        PredictError::Polars(value.to_string())
    }
}
462
/// Errors raised while preparing a model for optimized inference.
#[derive(Debug)]
pub enum OptimizeError {
    /// More cores requested than the machine exposes.
    InvalidPhysicalCoreCount { requested: usize, available: usize },
    /// Rayon pool construction failed; message from the builder.
    ThreadPoolBuildFailed(String),
    /// The given model type has no optimized runtime.
    UnsupportedModelType(&'static str),
}

impl Display for OptimizeError {
    /// Renders a human-readable description of the optimization failure.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            OptimizeError::InvalidPhysicalCoreCount { requested, available } => write!(
                f,
                "Requested {requested} physical cores, but the available physical core count is {available}."
            ),
            OptimizeError::ThreadPoolBuildFailed(message) => {
                write!(f, "Failed to build inference thread pool: {message}.")
            }
            OptimizeError::UnsupportedModelType(model_type) => write!(
                f,
                "Optimized inference is not supported for model type '{model_type}'."
            ),
        }
    }
}

impl Error for OptimizeError {}
496
/// Errors raised while encoding or decoding a compiled inference artifact.
#[derive(Debug)]
pub enum CompiledArtifactError {
    /// Blob shorter than the fixed header.
    ArtifactTooShort { actual: usize, minimum: usize },
    /// Header magic did not match `COMPILED_ARTIFACT_MAGIC`.
    InvalidMagic([u8; 4]),
    /// Header version not equal to the supported version.
    UnsupportedVersion(u16),
    /// Header backend tag not recognized.
    UnsupportedBackend(u16),
    /// Payload serialization failure.
    Encode(String),
    /// Payload deserialization failure.
    Decode(String),
    /// The embedded semantic IR failed validation.
    InvalidSemanticModel(IrError),
    /// The embedded runtime could not be constructed.
    InvalidRuntime(OptimizeError),
}
508
509impl Display for CompiledArtifactError {
510 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
511 match self {
512 CompiledArtifactError::ArtifactTooShort { actual, minimum } => write!(
513 f,
514 "Compiled artifact is too short: expected at least {} bytes, found {}.",
515 minimum, actual
516 ),
517 CompiledArtifactError::InvalidMagic(magic) => {
518 write!(f, "Compiled artifact has invalid magic bytes: {:?}.", magic)
519 }
520 CompiledArtifactError::UnsupportedVersion(version) => {
521 write!(f, "Unsupported compiled artifact version: {}.", version)
522 }
523 CompiledArtifactError::UnsupportedBackend(backend) => {
524 write!(f, "Unsupported compiled artifact backend: {}.", backend)
525 }
526 CompiledArtifactError::Encode(message) => {
527 write!(f, "Failed to encode compiled artifact: {}.", message)
528 }
529 CompiledArtifactError::Decode(message) => {
530 write!(f, "Failed to decode compiled artifact: {}.", message)
531 }
532 CompiledArtifactError::InvalidSemanticModel(err) => err.fmt(f),
533 CompiledArtifactError::InvalidRuntime(err) => err.fmt(f),
534 }
535 }
536}
537
538impl Error for CompiledArtifactError {}
539
/// Degree of parallelism selected for a run.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Parallelism {
    // Worker thread count; 1 means everything runs on the caller's thread.
    thread_count: usize,
}

impl Parallelism {
    /// Single-threaded setting: parallel paths are bypassed entirely.
    pub(crate) fn sequential() -> Self {
        Parallelism { thread_count: 1 }
    }

    /// Whether more than one worker thread is configured.
    pub(crate) fn enabled(self) -> bool {
        self.thread_count > 1
    }
}
554
555pub(crate) fn capture_feature_preprocessing(table: &dyn TableAccess) -> Vec<FeaturePreprocessing> {
556 (0..table.n_features())
557 .map(|feature_index| {
558 if table.is_binary_feature(feature_index) {
559 FeaturePreprocessing::Binary
560 } else {
561 let values = (0..table.n_rows())
562 .map(|row_index| table.feature_value(feature_index, row_index))
563 .collect::<Vec<_>>();
564 FeaturePreprocessing::Numeric {
565 bin_boundaries: numeric_bin_boundaries(
566 &values,
567 NumericBins::Fixed(table.numeric_bin_cap()),
568 )
569 .into_iter()
570 .map(|(bin, upper_bound)| NumericBinBoundary { bin, upper_bound })
571 .collect(),
572 }
573 }
574 })
575 .collect()
576}
577
/// Raw (un-binned) inference column, stored per trained feature kind.
#[derive(Debug, Clone)]
enum InferenceFeatureColumn {
    Numeric(Vec<f64>),
    Binary(Vec<bool>),
}
583
/// Binned inference column: numeric bin ids, or the binary values as-is.
#[derive(Debug, Clone)]
enum InferenceBinnedColumn {
    Numeric(Vec<u16>),
    Binary(Vec<bool>),
}
589
/// Binned column stored at the narrowest width that fits its bin ids.
#[derive(Debug, Clone)]
enum CompactBinnedColumn {
    U8(Vec<u8>),
    U16(Vec<u16>),
}

impl CompactBinnedColumn {
    /// Bin id at `row_index`, widened to `u16` regardless of storage width.
    #[inline(always)]
    fn value_at(&self, row_index: usize) -> u16 {
        match self {
            Self::U8(bins) => u16::from(bins[row_index]),
            Self::U16(bins) => bins[row_index],
        }
    }

    /// Borrows `len` rows from `start` when stored as `u8`; `None` otherwise.
    #[inline(always)]
    fn slice_u8(&self, start: usize, len: usize) -> Option<&[u8]> {
        if let Self::U8(bins) = self {
            Some(&bins[start..start + len])
        } else {
            None
        }
    }

    /// Borrows `len` rows from `start` when stored as `u16`; `None` otherwise.
    #[inline(always)]
    fn slice_u16(&self, start: usize, len: usize) -> Option<&[u16]> {
        if let Self::U16(bins) = self {
            Some(&bins[start..start + len])
        } else {
            None
        }
    }
}
621
/// Column-major inference input holding both the raw values and the binned
/// representation derived from the trained [`FeaturePreprocessing`].
#[derive(Debug, Clone)]
pub(crate) struct InferenceTable {
    feature_columns: Vec<InferenceFeatureColumn>,
    binned_feature_columns: Vec<InferenceBinnedColumn>,
    n_rows: usize,
}
628
629impl InferenceTable {
630 pub(crate) fn from_rows(
631 rows: Vec<Vec<f64>>,
632 preprocessing: &[FeaturePreprocessing],
633 ) -> Result<Self, PredictError> {
634 let expected = preprocessing.len();
635 if let Some((row_index, actual)) = rows
636 .iter()
637 .enumerate()
638 .find_map(|(row_index, row)| (row.len() != expected).then_some((row_index, row.len())))
639 {
640 return Err(PredictError::RaggedRows {
641 row: row_index,
642 expected,
643 actual,
644 });
645 }
646
647 let columns = (0..expected)
648 .map(|feature_index| {
649 rows.iter()
650 .map(|row| row[feature_index])
651 .collect::<Vec<_>>()
652 })
653 .collect::<Vec<_>>();
654
655 Self::from_columns(columns, preprocessing)
656 }
657
658 pub(crate) fn from_named_columns(
659 columns: BTreeMap<String, Vec<f64>>,
660 preprocessing: &[FeaturePreprocessing],
661 ) -> Result<Self, PredictError> {
662 let expected = preprocessing.len();
663 if columns.len() != expected {
664 for feature_index in 0..expected {
665 let name = format!("f{}", feature_index);
666 if !columns.contains_key(&name) {
667 return Err(PredictError::MissingFeature(name));
668 }
669 }
670 if let Some(unexpected) = columns.keys().find(|name| {
671 name.strip_prefix('f')
672 .and_then(|idx| idx.parse::<usize>().ok())
673 .is_none_or(|idx| idx >= expected)
674 }) {
675 return Err(PredictError::UnexpectedFeature(unexpected.clone()));
676 }
677 }
678
679 let n_rows = columns.values().next().map_or(0, Vec::len);
680 let ordered = (0..expected)
681 .map(|feature_index| {
682 let feature_name = format!("f{}", feature_index);
683 let values = columns
684 .get(&feature_name)
685 .ok_or_else(|| PredictError::MissingFeature(feature_name.clone()))?;
686 if values.len() != n_rows {
687 return Err(PredictError::ColumnLengthMismatch {
688 feature: feature_name,
689 expected: n_rows,
690 actual: values.len(),
691 });
692 }
693 Ok(values.clone())
694 })
695 .collect::<Result<Vec<_>, _>>()?;
696
697 Self::from_columns(ordered, preprocessing)
698 }
699
700 pub(crate) fn from_sparse_binary_columns(
701 n_rows: usize,
702 n_features: usize,
703 columns: Vec<Vec<usize>>,
704 preprocessing: &[FeaturePreprocessing],
705 ) -> Result<Self, PredictError> {
706 if n_features != preprocessing.len() {
707 return Err(PredictError::FeatureCountMismatch {
708 expected: preprocessing.len(),
709 actual: n_features,
710 });
711 }
712
713 let mut dense_columns = Vec::with_capacity(n_features);
714 for (feature_index, row_indices) in columns.into_iter().enumerate() {
715 match preprocessing.get(feature_index) {
716 Some(FeaturePreprocessing::Binary) => {
717 let mut values = vec![false; n_rows];
718 for row_index in row_indices {
719 if row_index >= n_rows {
720 return Err(PredictError::ColumnLengthMismatch {
721 feature: format!("f{}", feature_index),
722 expected: n_rows,
723 actual: row_index + 1,
724 });
725 }
726 values[row_index] = true;
727 }
728 dense_columns.push(values.into_iter().map(f64::from).collect());
729 }
730 Some(FeaturePreprocessing::Numeric { .. }) => {
731 return Err(PredictError::InvalidBinaryValue {
732 feature_index,
733 row_index: 0,
734 value: 1.0,
735 });
736 }
737 None => unreachable!("validated feature count"),
738 }
739 }
740
741 Self::from_columns(dense_columns, preprocessing)
742 }
743
744 fn from_columns(
745 columns: Vec<Vec<f64>>,
746 preprocessing: &[FeaturePreprocessing],
747 ) -> Result<Self, PredictError> {
748 if columns.len() != preprocessing.len() {
749 return Err(PredictError::FeatureCountMismatch {
750 expected: preprocessing.len(),
751 actual: columns.len(),
752 });
753 }
754
755 let n_rows = columns.first().map_or(0, Vec::len);
756 let mut feature_columns = Vec::with_capacity(columns.len());
757 let mut binned_feature_columns = Vec::with_capacity(columns.len());
758
759 for (feature_index, (column, feature_preprocessing)) in
760 columns.into_iter().zip(preprocessing.iter()).enumerate()
761 {
762 if column.len() != n_rows {
763 return Err(PredictError::ColumnLengthMismatch {
764 feature: format!("f{}", feature_index),
765 expected: n_rows,
766 actual: column.len(),
767 });
768 }
769 match feature_preprocessing {
770 FeaturePreprocessing::Binary => {
771 let values = column
772 .into_iter()
773 .enumerate()
774 .map(|(row_index, value)| match value {
775 v if v.total_cmp(&0.0).is_eq() => Ok(false),
776 v if v.total_cmp(&1.0).is_eq() => Ok(true),
777 v => Err(PredictError::InvalidBinaryValue {
778 feature_index,
779 row_index,
780 value: v,
781 }),
782 })
783 .collect::<Result<Vec<_>, _>>()?;
784 feature_columns.push(InferenceFeatureColumn::Binary(values.clone()));
785 binned_feature_columns.push(InferenceBinnedColumn::Binary(values));
786 }
787 FeaturePreprocessing::Numeric { bin_boundaries } => {
788 let bins = column
789 .iter()
790 .map(|value| infer_numeric_bin(*value, bin_boundaries))
791 .collect();
792 feature_columns.push(InferenceFeatureColumn::Numeric(column));
793 binned_feature_columns.push(InferenceBinnedColumn::Numeric(bins));
794 }
795 }
796 }
797
798 Ok(Self {
799 feature_columns,
800 binned_feature_columns,
801 n_rows,
802 })
803 }
804
805 pub(crate) fn to_column_major_binned_matrix(&self) -> ColumnMajorBinnedMatrix {
806 let n_features = self.feature_columns.len();
807 let columns = (0..n_features)
808 .map(
809 |feature_index| match &self.binned_feature_columns[feature_index] {
810 InferenceBinnedColumn::Numeric(values) => compact_binned_column(values),
811 InferenceBinnedColumn::Binary(values) => CompactBinnedColumn::U8(
812 values.iter().map(|value| u8::from(*value)).collect(),
813 ),
814 },
815 )
816 .collect();
817
818 ColumnMajorBinnedMatrix {
819 n_rows: self.n_rows,
820 columns,
821 }
822 }
823}
824
/// Column-major matrix of compact binned feature values used by the batch
/// inference kernels.
#[derive(Debug, Clone)]
struct ColumnMajorBinnedMatrix {
    n_rows: usize,
    columns: Vec<CompactBinnedColumn>,
}
830
impl ColumnMajorBinnedMatrix {
    /// Materializes every binned column of `table` into compact storage.
    ///
    /// Binary binned features become 0/1 `u8` columns (a missing boolean
    /// defaults to `false`); numeric features are narrowed to `u8` when all
    /// bin ids fit, otherwise kept as `u16`.
    fn from_table_access(table: &dyn TableAccess) -> Self {
        let columns = (0..table.n_features())
            .map(|feature_index| {
                if table.is_binary_binned_feature(feature_index) {
                    CompactBinnedColumn::U8(
                        (0..table.n_rows())
                            .map(|row_index| {
                                u8::from(
                                    table
                                        .binned_boolean_value(feature_index, row_index)
                                        .unwrap_or(false),
                                )
                            })
                            .collect(),
                    )
                } else {
                    compact_binned_column(
                        &(0..table.n_rows())
                            .map(|row_index| table.binned_value(feature_index, row_index))
                            .collect::<Vec<_>>(),
                    )
                }
            })
            .collect();

        Self {
            n_rows: table.n_rows(),
            columns,
        }
    }

    /// Borrows the compact column for `feature_index`.
    #[inline(always)]
    fn column(&self, feature_index: usize) -> &CompactBinnedColumn {
        &self.columns[feature_index]
    }
}
868
869fn infer_numeric_bin(value: f64, boundaries: &[NumericBinBoundary]) -> u16 {
870 boundaries
871 .iter()
872 .find(|boundary| value <= boundary.upper_bound)
873 .map_or_else(
874 || boundaries.last().map_or(0, |boundary| boundary.bin),
875 |boundary| boundary.bin,
876 )
877}
878
879fn compact_binned_column(values: &[u16]) -> CompactBinnedColumn {
880 if values.iter().all(|value| *value <= u16::from(u8::MAX)) {
881 CompactBinnedColumn::U8(values.iter().map(|value| *value as u8).collect())
882 } else {
883 CompactBinnedColumn::U16(values.to_vec())
884 }
885}
886
impl TableAccess for InferenceTable {
    fn n_rows(&self) -> usize {
        self.n_rows
    }

    fn n_features(&self) -> usize {
        self.feature_columns.len()
    }

    // Inference tables carry no canary columns.
    fn canaries(&self) -> usize {
        0
    }

    // Bin cap is fixed; actual binning was decided at training time.
    fn numeric_bin_cap(&self) -> usize {
        MAX_NUMERIC_BINS
    }

    fn binned_feature_count(&self) -> usize {
        self.binned_feature_columns.len()
    }

    /// Raw feature value; binary columns are widened to 0.0/1.0.
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        match &self.feature_columns[feature_index] {
            InferenceFeatureColumn::Numeric(values) => values[row_index],
            InferenceFeatureColumn::Binary(values) => f64::from(u8::from(values[row_index])),
        }
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        matches!(
            self.feature_columns[index],
            InferenceFeatureColumn::Binary(_)
        )
    }

    /// Binned value; binary columns report 0/1 as bin ids.
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        match &self.binned_feature_columns[feature_index] {
            InferenceBinnedColumn::Numeric(values) => values[row_index],
            InferenceBinnedColumn::Binary(values) => u16::from(values[row_index]),
        }
    }

    /// Boolean view of a binned value; `None` for numeric columns.
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        match &self.binned_feature_columns[feature_index] {
            InferenceBinnedColumn::Numeric(_) => None,
            InferenceBinnedColumn::Binary(values) => Some(values[row_index]),
        }
    }

    // Every binned column maps 1:1 to a real source feature here.
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
        BinnedColumnKind::Real {
            source_index: index,
        }
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_feature_columns[index],
            InferenceBinnedColumn::Binary(_)
        )
    }

    // Inference input has no targets; a fixed 0.0 satisfies the trait.
    fn target_value(&self, _row_index: usize) -> f64 {
        0.0
    }
}
953
/// Flattened, inference-only representation of a trained model.
///
/// Tree variants store their node/level tables directly; ensemble variants
/// nest per-tree runtimes. Serialized as part of compiled artifacts.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum OptimizedRuntime {
    /// Classifier tree with binary splits in jump-table form.
    BinaryClassifier {
        nodes: Vec<OptimizedBinaryClassifierNode>,
        class_labels: Vec<f64>,
    },
    /// Classifier tree that may contain multiway splits; `root` indexes `nodes`.
    StandardClassifier {
        nodes: Vec<OptimizedClassifierNode>,
        root: usize,
        class_labels: Vec<f64>,
    },
    /// Oblivious classifier: one (feature, threshold) pair per level,
    /// leaf probability vectors indexed by the level-decision bit pattern.
    ObliviousClassifier {
        feature_indices: Vec<usize>,
        threshold_bins: Vec<u16>,
        leaf_values: Vec<Vec<f64>>,
        class_labels: Vec<f64>,
    },
    /// Regressor tree with binary splits in jump-table form.
    BinaryRegressor {
        nodes: Vec<OptimizedBinaryRegressorNode>,
    },
    /// Oblivious regressor: per-level splits and scalar leaf values.
    ObliviousRegressor {
        feature_indices: Vec<usize>,
        threshold_bins: Vec<u16>,
        leaf_values: Vec<f64>,
    },
    /// Forest of classifier trees.
    ForestClassifier {
        trees: Vec<OptimizedRuntime>,
        class_labels: Vec<f64>,
    },
    /// Forest of regressor trees.
    ForestRegressor {
        trees: Vec<OptimizedRuntime>,
    },
    /// Boosted classifier ensemble with per-tree weights and a base score.
    BoostedBinaryClassifier {
        trees: Vec<OptimizedRuntime>,
        tree_weights: Vec<f64>,
        base_score: f64,
        class_labels: Vec<f64>,
    },
    /// Boosted regressor ensemble with per-tree weights and a base score.
    BoostedRegressor {
        trees: Vec<OptimizedRuntime>,
        tree_weights: Vec<f64>,
        base_score: f64,
    },
}
998
/// Node of a standard (possibly multiway) optimized classifier tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum OptimizedClassifierNode {
    /// Class-probability vector.
    Leaf(Vec<f64>),
    /// Binary split on a bin threshold; `children` holds the two node indices.
    Binary {
        feature_index: usize,
        threshold_bin: u16,
        children: [usize; 2],
    },
    /// Multiway split: `child_lookup` maps bin id to a child node index;
    /// bins beyond `max_bin_index` fall back to `fallback_probabilities`.
    Multiway {
        feature_index: usize,
        child_lookup: Vec<usize>,
        max_bin_index: usize,
        fallback_probabilities: Vec<f64>,
    },
}
1014
/// Node of a binary optimized classifier tree laid out for sequential
/// traversal: one child is the next node in the array, the other is reached
/// via `jump_index` when the comparison outcome matches `jump_if_greater`.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum OptimizedBinaryClassifierNode {
    /// Class-probability vector.
    Leaf(Vec<f64>),
    Branch {
        feature_index: usize,
        threshold_bin: u16,
        jump_index: usize,
        jump_if_greater: bool,
    },
}
1025
/// Regressor counterpart of [`OptimizedBinaryClassifierNode`]: identical
/// jump-style layout, but leaves carry a single predicted value.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum OptimizedBinaryRegressorNode {
    /// Predicted value.
    Leaf(f64),
    Branch {
        feature_index: usize,
        threshold_bin: u16,
        jump_index: usize,
        jump_if_greater: bool,
    },
}
1036
/// Runs inference work either inline or on a dedicated rayon pool.
#[derive(Debug, Clone)]
struct InferenceExecutor {
    // Configured worker count; 1 disables the pool entirely.
    thread_count: usize,
    // Present only when thread_count > 1; shared so clones reuse the pool.
    pool: Option<Arc<rayon::ThreadPool>>,
}
1042
/// Serialized body of a compiled artifact: the portable semantic IR plus the
/// ready-to-run optimized runtime.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CompiledArtifactPayload {
    semantic_ir: ModelPackageIr,
    runtime: OptimizedRuntime,
}
1048
impl InferenceExecutor {
    /// Builds an executor, constructing a dedicated rayon pool only when
    /// more than one thread is requested.
    fn new(thread_count: usize) -> Result<Self, OptimizeError> {
        let pool = if thread_count > 1 {
            Some(Arc::new(
                ThreadPoolBuilder::new()
                    .num_threads(thread_count)
                    .build()
                    .map_err(|err| OptimizeError::ThreadPoolBuildFailed(err.to_string()))?,
            ))
        } else {
            None
        };

        Ok(Self { thread_count, pool })
    }

    /// Computes one prediction per row; runs sequentially for small inputs
    /// or single-threaded executors, otherwise in parallel. Output order
    /// matches row order in both paths.
    fn predict_rows<F>(&self, n_rows: usize, predict_row: F) -> Vec<f64>
    where
        F: Fn(usize) -> f64 + Sync + Send,
    {
        if self.thread_count == 1 || n_rows < PARALLEL_INFERENCE_ROW_THRESHOLD {
            return (0..n_rows).map(predict_row).collect();
        }

        self.pool
            .as_ref()
            .expect("thread pool exists when parallel inference is enabled")
            .install(|| (0..n_rows).into_par_iter().map(predict_row).collect())
    }

    /// Fills `outputs` chunk-by-chunk; `fill_chunk` receives the absolute
    /// starting row of each chunk. Sequential below the parallel threshold,
    /// otherwise chunks are processed on the pool.
    fn fill_chunks<F>(&self, outputs: &mut [f64], chunk_rows: usize, fill_chunk: F)
    where
        F: Fn(usize, &mut [f64]) + Sync + Send,
    {
        if self.thread_count == 1 || outputs.len() < PARALLEL_INFERENCE_ROW_THRESHOLD {
            for (chunk_index, chunk) in outputs.chunks_mut(chunk_rows).enumerate() {
                fill_chunk(chunk_index * chunk_rows, chunk);
            }
            return;
        }

        self.pool
            .as_ref()
            .expect("thread pool exists when parallel inference is enabled")
            .install(|| {
                outputs
                    .par_chunks_mut(chunk_rows)
                    .enumerate()
                    .for_each(|(chunk_index, chunk)| fill_chunk(chunk_index * chunk_rows, chunk));
            });
    }
}
1101
/// A model prepared for fast inference: the original model (kept for its
/// preprocessing metadata), its flattened runtime, and an executor.
#[derive(Debug, Clone)]
pub struct OptimizedModel {
    source_model: Model,
    runtime: OptimizedRuntime,
    executor: InferenceExecutor,
}
1108
1109impl OptimizedModel {
1110 fn new(source_model: Model, physical_cores: Option<usize>) -> Result<Self, OptimizeError> {
1111 let thread_count = resolve_inference_thread_count(physical_cores)?;
1112 let runtime = OptimizedRuntime::from_model(&source_model);
1113 let executor = InferenceExecutor::new(thread_count)?;
1114
1115 Ok(Self {
1116 source_model,
1117 runtime,
1118 executor,
1119 })
1120 }
1121
1122 pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
1123 if self.runtime.should_use_batch_matrix(table.n_rows()) {
1124 let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
1125 return self.predict_column_major_binned_matrix(&matrix);
1126 }
1127
1128 self.executor.predict_rows(table.n_rows(), |row_index| {
1129 self.runtime.predict_table_row(table, row_index)
1130 })
1131 }
1132
1133 pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
1134 let table = InferenceTable::from_rows(rows, self.source_model.feature_preprocessing())?;
1135 if self.runtime.should_use_batch_matrix(table.n_rows()) {
1136 let matrix = table.to_column_major_binned_matrix();
1137 Ok(self.predict_column_major_binned_matrix(&matrix))
1138 } else {
1139 Ok(self.predict_table(&table))
1140 }
1141 }
1142
1143 pub fn predict_named_columns(
1144 &self,
1145 columns: BTreeMap<String, Vec<f64>>,
1146 ) -> Result<Vec<f64>, PredictError> {
1147 let table =
1148 InferenceTable::from_named_columns(columns, self.source_model.feature_preprocessing())?;
1149 if self.runtime.should_use_batch_matrix(table.n_rows()) {
1150 let matrix = table.to_column_major_binned_matrix();
1151 Ok(self.predict_column_major_binned_matrix(&matrix))
1152 } else {
1153 Ok(self.predict_table(&table))
1154 }
1155 }
1156
    /// Returns one probability vector per row of `table`.
    ///
    /// Fails with `ProbabilityPredictionRequiresClassification` when the
    /// underlying runtime is a regressor (see
    /// `OptimizedRuntime::predict_proba_table`).
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        self.runtime.predict_proba_table(table, &self.executor)
    }

    /// Probability prediction over dense row-major input.
    pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_rows(rows, self.source_model.feature_preprocessing())?;
        self.predict_proba_table(&table)
    }

    /// Probability prediction over columns keyed by feature name.
    pub fn predict_proba_named_columns(
        &self,
        columns: BTreeMap<String, Vec<f64>>,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        let table =
            InferenceTable::from_named_columns(columns, self.source_model.feature_preprocessing())?;
        self.predict_proba_table(&table)
    }
1177
    /// Probability prediction over a sparse binary design matrix.
    ///
    /// NOTE(review): presumably `columns[f]` lists the row indices where
    /// feature `f` is set — confirm against
    /// `InferenceTable::from_sparse_binary_columns`.
    pub fn predict_proba_sparse_binary_columns(
        &self,
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_sparse_binary_columns(
            n_rows,
            n_features,
            columns,
            self.source_model.feature_preprocessing(),
        )?;
        self.predict_proba_table(&table)
    }
1192
1193 pub fn predict_sparse_binary_columns(
1194 &self,
1195 n_rows: usize,
1196 n_features: usize,
1197 columns: Vec<Vec<usize>>,
1198 ) -> Result<Vec<f64>, PredictError> {
1199 let table = InferenceTable::from_sparse_binary_columns(
1200 n_rows,
1201 n_features,
1202 columns,
1203 self.source_model.feature_preprocessing(),
1204 )?;
1205 if self.runtime.should_use_batch_matrix(table.n_rows()) {
1206 let matrix = table.to_column_major_binned_matrix();
1207 Ok(self.predict_column_major_binned_matrix(&matrix))
1208 } else {
1209 Ok(self.predict_table(&table))
1210 }
1211 }
1212
    /// Predicts over an eager polars `DataFrame` by extracting its columns as
    /// named `f64` vectors and delegating to [`Self::predict_named_columns`].
    #[cfg(feature = "polars")]
    pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
        let columns = polars_named_columns(df)?;
        self.predict_named_columns(columns)
    }
1218
1219 #[cfg(feature = "polars")]
1220 pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
1221 let mut predictions = Vec::new();
1222 let mut offset = 0i64;
1223 loop {
1224 let batch = lf
1225 .clone()
1226 .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
1227 .collect()?;
1228 let height = batch.height();
1229 if height == 0 {
1230 break;
1231 }
1232 predictions.extend(self.predict_polars_dataframe(&batch)?);
1233 if height < LAZYFRAME_PREDICT_BATCH_ROWS {
1234 break;
1235 }
1236 offset += height as i64;
1237 }
1238 Ok(predictions)
1239 }
1240
    // --- Model metadata accessors: all delegate to the semantic source model. ---

    /// Training algorithm that produced the source model.
    pub fn algorithm(&self) -> TrainAlgorithm {
        self.source_model.algorithm()
    }

    /// Classification or regression.
    pub fn task(&self) -> Task {
        self.source_model.task()
    }

    /// Split criterion used during training.
    pub fn criterion(&self) -> Criterion {
        self.source_model.criterion()
    }

    /// Structural kind of the trained tree(s).
    pub fn tree_type(&self) -> TreeType {
        self.source_model.tree_type()
    }

    /// Mean target value, when the source model records one.
    pub fn mean_value(&self) -> Option<f64> {
        self.source_model.mean_value()
    }

    /// Number of canaries recorded on the source model.
    pub fn canaries(&self) -> usize {
        self.source_model.canaries()
    }

    /// Configured maximum tree depth, if bounded.
    pub fn max_depth(&self) -> Option<usize> {
        self.source_model.max_depth()
    }

    /// Minimum samples required to split a node, if configured.
    pub fn min_samples_split(&self) -> Option<usize> {
        self.source_model.min_samples_split()
    }

    /// Minimum samples required in a leaf, if configured.
    pub fn min_samples_leaf(&self) -> Option<usize> {
        self.source_model.min_samples_leaf()
    }

    /// Configured ensemble size, for ensemble models.
    pub fn n_trees(&self) -> Option<usize> {
        self.source_model.n_trees()
    }

    /// Feature-subsampling limit per split, if configured.
    pub fn max_features(&self) -> Option<usize> {
        self.source_model.max_features()
    }

    /// RNG seed used during training, if recorded.
    pub fn seed(&self) -> Option<u64> {
        self.source_model.seed()
    }

    /// Whether out-of-bag scoring was requested at training time.
    pub fn compute_oob(&self) -> bool {
        self.source_model.compute_oob()
    }

    /// Out-of-bag score, when it was computed.
    pub fn oob_score(&self) -> Option<f64> {
        self.source_model.oob_score()
    }

    /// Boosting learning rate, for boosted models.
    pub fn learning_rate(&self) -> Option<f64> {
        self.source_model.learning_rate()
    }

    /// Whether bootstrap sampling was used.
    pub fn bootstrap(&self) -> bool {
        self.source_model.bootstrap()
    }

    /// GOSS-style top-gradient retention fraction, if recorded.
    pub fn top_gradient_fraction(&self) -> Option<f64> {
        self.source_model.top_gradient_fraction()
    }

    /// GOSS-style remaining-gradient sampling fraction, if recorded.
    pub fn other_gradient_fraction(&self) -> Option<f64> {
        self.source_model.other_gradient_fraction()
    }

    /// Number of trees actually stored in the model.
    pub fn tree_count(&self) -> usize {
        self.source_model.tree_count()
    }
1316
    // --- Tree introspection: delegate to the semantic model; all return
    // `IntrospectionError` for out-of-range indices or mismatched tree kinds. ---

    /// Structural summary (depth, node counts, …) of tree `tree_index`.
    pub fn tree_structure(
        &self,
        tree_index: usize,
    ) -> Result<TreeStructureSummary, IntrospectionError> {
        self.source_model.tree_structure(tree_index)
    }

    /// Statistics over the prediction values of tree `tree_index`.
    pub fn tree_prediction_stats(
        &self,
        tree_index: usize,
    ) -> Result<PredictionValueStats, IntrospectionError> {
        self.source_model.tree_prediction_stats(tree_index)
    }

    /// A single node of a node-based tree.
    pub fn tree_node(
        &self,
        tree_index: usize,
        node_index: usize,
    ) -> Result<ir::NodeTreeNode, IntrospectionError> {
        self.source_model.tree_node(tree_index, node_index)
    }

    /// A single level of an oblivious tree.
    pub fn tree_level(
        &self,
        tree_index: usize,
        level_index: usize,
    ) -> Result<ir::ObliviousLevel, IntrospectionError> {
        self.source_model.tree_level(tree_index, level_index)
    }

    /// A single leaf of a tree, by leaf index.
    pub fn tree_leaf(
        &self,
        tree_index: usize,
        leaf_index: usize,
    ) -> Result<ir::IndexedLeaf, IntrospectionError> {
        self.source_model.tree_leaf(tree_index, leaf_index)
    }
1354
    // --- Serialization: all delegate to the semantic source model. ---

    /// Semantic intermediate representation of the model.
    pub fn to_ir(&self) -> ModelPackageIr {
        self.source_model.to_ir()
    }

    /// Compact JSON encoding of the IR.
    pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
        self.source_model.to_ir_json()
    }

    /// Pretty-printed JSON encoding of the IR.
    pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
        self.source_model.to_ir_json_pretty()
    }

    /// Compact JSON encoding of the source model itself.
    pub fn serialize(&self) -> Result<String, serde_json::Error> {
        self.source_model.serialize()
    }

    /// Pretty-printed JSON encoding of the source model itself.
    pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
        self.source_model.serialize_pretty()
    }
1374
1375 pub fn serialize_compiled(&self) -> Result<Vec<u8>, CompiledArtifactError> {
1376 let payload = CompiledArtifactPayload {
1377 semantic_ir: self.source_model.to_ir(),
1378 runtime: self.runtime.clone(),
1379 };
1380 let mut payload_bytes = Vec::new();
1381 ciborium::into_writer(&payload, &mut payload_bytes)
1382 .map_err(|err| CompiledArtifactError::Encode(err.to_string()))?;
1383 let mut bytes = Vec::with_capacity(COMPILED_ARTIFACT_HEADER_LEN + payload_bytes.len());
1384 bytes.extend_from_slice(&COMPILED_ARTIFACT_MAGIC);
1385 bytes.extend_from_slice(&COMPILED_ARTIFACT_VERSION.to_le_bytes());
1386 bytes.extend_from_slice(&COMPILED_ARTIFACT_BACKEND_CPU.to_le_bytes());
1387 bytes.extend_from_slice(&payload_bytes);
1388 Ok(bytes)
1389 }
1390
1391 pub fn deserialize_compiled(
1392 serialized: &[u8],
1393 physical_cores: Option<usize>,
1394 ) -> Result<Self, CompiledArtifactError> {
1395 if serialized.len() < COMPILED_ARTIFACT_HEADER_LEN {
1396 return Err(CompiledArtifactError::ArtifactTooShort {
1397 actual: serialized.len(),
1398 minimum: COMPILED_ARTIFACT_HEADER_LEN,
1399 });
1400 }
1401
1402 let magic = [serialized[0], serialized[1], serialized[2], serialized[3]];
1403 if magic != COMPILED_ARTIFACT_MAGIC {
1404 return Err(CompiledArtifactError::InvalidMagic(magic));
1405 }
1406
1407 let version = u16::from_le_bytes([serialized[4], serialized[5]]);
1408 if version != COMPILED_ARTIFACT_VERSION {
1409 return Err(CompiledArtifactError::UnsupportedVersion(version));
1410 }
1411
1412 let backend = u16::from_le_bytes([serialized[6], serialized[7]]);
1413 if backend != COMPILED_ARTIFACT_BACKEND_CPU {
1414 return Err(CompiledArtifactError::UnsupportedBackend(backend));
1415 }
1416
1417 let payload: CompiledArtifactPayload = ciborium::from_reader(std::io::Cursor::new(
1418 &serialized[COMPILED_ARTIFACT_HEADER_LEN..],
1419 ))
1420 .map_err(|err| CompiledArtifactError::Decode(err.to_string()))?;
1421 let source_model = ir::model_from_ir(payload.semantic_ir)
1422 .map_err(CompiledArtifactError::InvalidSemanticModel)?;
1423 let thread_count = resolve_inference_thread_count(physical_cores)
1424 .map_err(CompiledArtifactError::InvalidRuntime)?;
1425 let executor =
1426 InferenceExecutor::new(thread_count).map_err(CompiledArtifactError::InvalidRuntime)?;
1427
1428 Ok(Self {
1429 source_model,
1430 runtime: payload.runtime,
1431 executor,
1432 })
1433 }
1434
    /// Shared batch entry point: runs the runtime's column-major kernel with
    /// this model's executor.
    fn predict_column_major_binned_matrix(&self, matrix: &ColumnMajorBinnedMatrix) -> Vec<f64> {
        self.runtime
            .predict_column_major_matrix(matrix, &self.executor)
    }
1439}
1440
impl OptimizedRuntime {
    /// True for runtime variants that have a column-major batch kernel.
    /// `StandardClassifier` is the only variant excluded: its multiway nodes
    /// are walked row by row instead.
    fn supports_batch_matrix(&self) -> bool {
        matches!(
            self,
            OptimizedRuntime::BinaryClassifier { .. }
                | OptimizedRuntime::BinaryRegressor { .. }
                | OptimizedRuntime::ObliviousClassifier { .. }
                | OptimizedRuntime::ObliviousRegressor { .. }
                | OptimizedRuntime::ForestClassifier { .. }
                | OptimizedRuntime::ForestRegressor { .. }
                | OptimizedRuntime::BoostedBinaryClassifier { .. }
                | OptimizedRuntime::BoostedRegressor { .. }
        )
    }

    /// Batch rebinning only pays off for more than one row, and only on
    /// variants that actually have a batch kernel.
    fn should_use_batch_matrix(&self, n_rows: usize) -> bool {
        n_rows > 1 && self.supports_batch_matrix()
    }

    /// Lowers a semantic `Model` into its optimized runtime representation.
    /// Ensembles are lowered recursively, one runtime per member tree.
    fn from_model(model: &Model) -> Self {
        match model {
            Model::DecisionTreeClassifier(classifier) => Self::from_classifier(classifier),
            Model::DecisionTreeRegressor(regressor) => Self::from_regressor(regressor),
            Model::RandomForest(forest) => match forest.task() {
                Task::Classification => Self::ForestClassifier {
                    trees: forest.trees().iter().map(Self::from_model).collect(),
                    class_labels: forest
                        .class_labels()
                        .expect("classification forest stores class labels"),
                },
                Task::Regression => Self::ForestRegressor {
                    trees: forest.trees().iter().map(Self::from_model).collect(),
                },
            },
            Model::GradientBoostedTrees(model) => match model.task() {
                Task::Classification => Self::BoostedBinaryClassifier {
                    trees: model.trees().iter().map(Self::from_model).collect(),
                    tree_weights: model.tree_weights().to_vec(),
                    base_score: model.base_score(),
                    class_labels: model
                        .class_labels()
                        .expect("classification boosting stores class labels"),
                },
                Task::Regression => Self::BoostedRegressor {
                    trees: model.trees().iter().map(Self::from_model).collect(),
                    tree_weights: model.tree_weights().to_vec(),
                    base_score: model.base_score(),
                },
            },
        }
    }

    /// Lowers a classifier tree:
    /// - standard trees whose nodes are all leaf/binary get the flattened
    ///   `BinaryClassifier` layout,
    /// - standard trees with multiway splits fall back to
    ///   `StandardClassifier` with a per-node bin lookup,
    /// - oblivious trees get the level-wise split arrays.
    fn from_classifier(classifier: &DecisionTreeClassifier) -> Self {
        match classifier.structure() {
            tree::classifier::TreeStructure::Standard { nodes, root } => {
                if classifier_nodes_are_binary_only(nodes) {
                    return Self::BinaryClassifier {
                        nodes: build_binary_classifier_layout(
                            nodes,
                            *root,
                            classifier.class_labels(),
                        ),
                        class_labels: classifier.class_labels().to_vec(),
                    };
                }

                let optimized_nodes = nodes
                    .iter()
                    .map(|node| match node {
                        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
                            // Leaf counts are normalized to probabilities once, at
                            // lowering time, so prediction is a pure lookup.
                            OptimizedClassifierNode::Leaf(normalized_probabilities_from_counts(
                                class_counts,
                            ))
                        }
                        tree::classifier::TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => OptimizedClassifierNode::Binary {
                            feature_index: *feature_index,
                            threshold_bin: *threshold_bin,
                            children: [*left_child, *right_child],
                        },
                        tree::classifier::TreeNode::MultiwaySplit {
                            feature_index,
                            class_counts,
                            branches,
                            ..
                        } => {
                            // Dense bin -> child table; unmapped slots stay
                            // usize::MAX and route to the fallback distribution
                            // at predict time.
                            let max_bin_index = branches
                                .iter()
                                .map(|(bin, _)| usize::from(*bin))
                                .max()
                                .unwrap_or(0);
                            let mut child_lookup = vec![usize::MAX; max_bin_index + 1];
                            for (bin, child_index) in branches {
                                child_lookup[usize::from(*bin)] = *child_index;
                            }
                            OptimizedClassifierNode::Multiway {
                                feature_index: *feature_index,
                                child_lookup,
                                max_bin_index,
                                fallback_probabilities: normalized_probabilities_from_counts(
                                    class_counts,
                                ),
                            }
                        }
                    })
                    .collect();

                Self::StandardClassifier {
                    nodes: optimized_nodes,
                    root: *root,
                    class_labels: classifier.class_labels().to_vec(),
                }
            }
            tree::classifier::TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => Self::ObliviousClassifier {
                feature_indices: splits.iter().map(|split| split.feature_index).collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_class_counts
                    .iter()
                    .map(|class_counts| normalized_probabilities_from_counts(class_counts))
                    .collect(),
                class_labels: classifier.class_labels().to_vec(),
            },
        }
    }

    /// Lowers a regressor tree into either the flattened binary layout or the
    /// oblivious level-wise layout.
    fn from_regressor(regressor: &DecisionTreeRegressor) -> Self {
        match regressor.structure() {
            tree::regressor::RegressionTreeStructure::Standard { nodes, root } => {
                Self::BinaryRegressor {
                    nodes: build_binary_regressor_layout(nodes, *root),
                }
            }
            tree::regressor::RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                ..
            } => Self::ObliviousRegressor {
                feature_indices: splits.iter().map(|split| split.feature_index).collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_values.clone(),
            },
        }
    }

    /// Predicts a single table row. Classifier variants predict probabilities
    /// and collapse them to the arg-max class label; regressor variants return
    /// the raw value (mean over forest members, weighted sum plus base score
    /// for boosted models).
    #[inline(always)]
    fn predict_table_row(&self, table: &dyn TableAccess, row_index: usize) -> f64 {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                let probabilities = self
                    .predict_proba_table_row(table, row_index)
                    .expect("classifier runtime supports probability prediction");
                class_label_from_probabilities(&probabilities, self.class_labels())
            }
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            ),
            OptimizedRuntime::ForestRegressor { trees } => {
                trees
                    .iter()
                    .map(|tree| tree.predict_table_row(table, row_index))
                    .sum::<f64>()
                    / trees.len() as f64
            }
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>()
            }
        }
    }

    /// Probability prediction for a single row. Regressor variants return
    /// `ProbabilityPredictionRequiresClassification`. Forests average member
    /// probabilities; boosted binary classifiers squash the weighted raw score
    /// through a sigmoid into `[P(neg), P(pos)]`.
    #[inline(always)]
    fn predict_proba_table_row(
        &self,
        table: &dyn TableAccess,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            )
            .to_vec()),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                // NOTE(review): assumes at least one member tree; `trees[0]`
                // panics on an empty forest.
                let mut totals = trees[0].predict_proba_table_row(table, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_table_row(table, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }

    /// Whole-table probability prediction; picks the batch kernel when
    /// `should_use_batch_matrix` allows, otherwise iterates rows serially.
    fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                if self.should_use_batch_matrix(table.n_rows()) {
                    let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
                    self.predict_proba_column_major_matrix(&matrix, executor)
                } else {
                    (0..table.n_rows())
                        .map(|row_index| self.predict_proba_table_row(table, row_index))
                        .collect()
                }
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }

    /// Batch prediction over a column-major binned matrix. Classifier variants
    /// go through the probability kernel and collapse to labels; regressor
    /// variants use their dedicated batch kernels.
    fn predict_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Vec<f64> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => self
                .predict_proba_column_major_matrix(matrix, executor)
                .expect("classifier runtime supports probability prediction")
                .into_iter()
                .map(|row| class_label_from_probabilities(&row, self.class_labels()))
                .collect(),
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_column_major_matrix(nodes, matrix, executor)
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            ),
            OptimizedRuntime::ForestRegressor { trees } => {
                // Accumulate member predictions, then divide once for the mean.
                let mut totals = trees[0].predict_column_major_matrix(matrix, executor);
                for tree in &trees[1..] {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for total in &mut totals {
                    *total /= tree_count;
                }
                totals
            }
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                let mut totals = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += weight * value;
                    }
                }
                totals
            }
        }
    }

    /// Batch probability prediction over a column-major binned matrix.
    fn predict_proba_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => {
                Ok(predict_binary_classifier_probabilities_column_major_matrix(
                    nodes, matrix, executor,
                ))
            }
            // StandardClassifier has no batch kernel (see
            // `supports_batch_matrix`), so it iterates rows here.
            OptimizedRuntime::StandardClassifier { .. } => Ok((0..matrix.n_rows)
                .map(|row_index| {
                    self.predict_proba_binned_row_from_columns(matrix, row_index)
                        .expect("classifier runtime supports probability prediction")
                })
                .collect()),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            )),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals = trees[0].predict_proba_column_major_matrix(matrix, executor)?;
                for tree in &trees[1..] {
                    let rows = tree.predict_proba_column_major_matrix(matrix, executor)?;
                    for (row_totals, row_values) in totals.iter_mut().zip(rows) {
                        for (total, value) in row_totals.iter_mut().zip(row_values) {
                            *total += value;
                        }
                    }
                }
                let tree_count = trees.len() as f64;
                for row in &mut totals {
                    for value in row {
                        *value /= tree_count;
                    }
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let mut raw_scores = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (raw_score, value) in raw_scores.iter_mut().zip(values) {
                        *raw_score += weight * value;
                    }
                }
                Ok(raw_scores
                    .into_iter()
                    .map(|raw_score| {
                        let positive = sigmoid(raw_score);
                        vec![1.0 - positive, positive]
                    })
                    .collect())
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }

    /// Class labels for classifier variants; empty slice for regressors.
    fn class_labels(&self) -> &[f64] {
        match self {
            OptimizedRuntime::BinaryClassifier { class_labels, .. }
            | OptimizedRuntime::StandardClassifier { class_labels, .. }
            | OptimizedRuntime::ObliviousClassifier { class_labels, .. }
            | OptimizedRuntime::ForestClassifier { class_labels, .. }
            | OptimizedRuntime::BoostedBinaryClassifier { class_labels, .. } => class_labels,
            _ => &[],
        }
    }

    /// Single-row prediction reading bins straight from a column-major matrix.
    #[inline(always)]
    fn predict_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> f64 {
        match self {
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            ),
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>()
            }
            // NOTE(review): this fallback builds a one-thread executor and
            // predicts the ENTIRE matrix just to read one row — O(n_rows) per
            // call. Fine if unreachable in practice; confirm which variants
            // can land here.
            _ => self.predict_column_major_matrix(
                matrix,
                &InferenceExecutor::new(1).expect("inference executor"),
            )[row_index],
        }
    }

    /// Single-row probability prediction reading bins from a column-major
    /// matrix; mirrors `predict_proba_table_row` variant by variant.
    #[inline(always)]
    fn predict_proba_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            )
            .to_vec()),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals =
                    trees[0].predict_proba_binned_row_from_columns(matrix, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_binned_row_from_columns(matrix, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
}
2006
2007#[inline(always)]
2008fn predict_standard_classifier_probabilities_row<F>(
2009 nodes: &[OptimizedClassifierNode],
2010 root: usize,
2011 bin_at: F,
2012) -> &[f64]
2013where
2014 F: Fn(usize) -> u16,
2015{
2016 let mut node_index = root;
2017 loop {
2018 match &nodes[node_index] {
2019 OptimizedClassifierNode::Leaf(value) => return value,
2020 OptimizedClassifierNode::Binary {
2021 feature_index,
2022 threshold_bin,
2023 children,
2024 } => {
2025 let go_right = usize::from(bin_at(*feature_index) > *threshold_bin);
2026 node_index = children[go_right];
2027 }
2028 OptimizedClassifierNode::Multiway {
2029 feature_index,
2030 child_lookup,
2031 max_bin_index,
2032 fallback_probabilities,
2033 } => {
2034 let bin = usize::from(bin_at(*feature_index));
2035 if bin > *max_bin_index {
2036 return fallback_probabilities;
2037 }
2038 let child_index = child_lookup[bin];
2039 if child_index == usize::MAX {
2040 return fallback_probabilities;
2041 }
2042 node_index = child_index;
2043 }
2044 }
2045 }
2046}
2047
2048#[inline(always)]
2049fn predict_binary_classifier_probabilities_row<F>(
2050 nodes: &[OptimizedBinaryClassifierNode],
2051 bin_at: F,
2052) -> &[f64]
2053where
2054 F: Fn(usize) -> u16,
2055{
2056 let mut node_index = 0usize;
2057 loop {
2058 match &nodes[node_index] {
2059 OptimizedBinaryClassifierNode::Leaf(value) => return value,
2060 OptimizedBinaryClassifierNode::Branch {
2061 feature_index,
2062 threshold_bin,
2063 jump_index,
2064 jump_if_greater,
2065 } => {
2066 let go_right = bin_at(*feature_index) > *threshold_bin;
2067 node_index = if go_right == *jump_if_greater {
2068 *jump_index
2069 } else {
2070 node_index + 1
2071 };
2072 }
2073 }
2074 }
2075}
2076
2077#[inline(always)]
2078fn predict_binary_regressor_row<F>(nodes: &[OptimizedBinaryRegressorNode], bin_at: F) -> f64
2079where
2080 F: Fn(usize) -> u16,
2081{
2082 let mut node_index = 0usize;
2083 loop {
2084 match &nodes[node_index] {
2085 OptimizedBinaryRegressorNode::Leaf(value) => return *value,
2086 OptimizedBinaryRegressorNode::Branch {
2087 feature_index,
2088 threshold_bin,
2089 jump_index,
2090 jump_if_greater,
2091 } => {
2092 let go_right = bin_at(*feature_index) > *threshold_bin;
2093 node_index = if go_right == *jump_if_greater {
2094 *jump_index
2095 } else {
2096 node_index + 1
2097 };
2098 }
2099 }
2100 }
2101}
2102
2103fn predict_binary_classifier_probabilities_column_major_matrix(
2104 nodes: &[OptimizedBinaryClassifierNode],
2105 matrix: &ColumnMajorBinnedMatrix,
2106 _executor: &InferenceExecutor,
2107) -> Vec<Vec<f64>> {
2108 (0..matrix.n_rows)
2109 .map(|row_index| {
2110 predict_binary_classifier_probabilities_row(nodes, |feature_index| {
2111 matrix.column(feature_index).value_at(row_index)
2112 })
2113 .to_vec()
2114 })
2115 .collect()
2116}
2117
/// Batch regression over a column-major matrix: the output buffer is split
/// into fixed-size chunks and the executor fills each chunk via
/// `predict_binary_regressor_chunk` (scheduling is the executor's concern).
fn predict_binary_regressor_column_major_matrix(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    executor: &InferenceExecutor,
) -> Vec<f64> {
    let mut outputs = vec![0.0; matrix.n_rows];
    executor.fill_chunks(
        &mut outputs,
        STANDARD_BATCH_INFERENCE_CHUNK_ROWS,
        |start_row, chunk| predict_binary_regressor_chunk(nodes, matrix, start_row, chunk),
    );
    outputs
}
2131
/// Fills `output` with predictions for rows `start_row..start_row + output.len()`.
///
/// Instead of walking the tree once per row, this routes the whole chunk
/// through the tree at once: `row_indices` holds chunk-relative row offsets,
/// and an explicit stack of `(node, start, end)` spans does a DFS where each
/// branch node partitions its span in place (swap-based, like a quicksort
/// partition) into fall-through rows and jump rows.
fn predict_binary_regressor_chunk(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) {
    let mut row_indices: Vec<usize> = (0..output.len()).collect();
    let mut stack = vec![(0usize, 0usize, output.len())];

    while let Some((node_index, start, end)) = stack.pop() {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => {
                // Every row that reached this leaf gets its value.
                for position in start..end {
                    output[row_indices[position]] = *value;
                }
            }
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
            } => {
                let fallthrough_index = node_index + 1;
                // Degenerate branch: both outcomes lead to the same node, so
                // no partitioning is needed.
                if *jump_index == fallthrough_index {
                    stack.push((fallthrough_index, start, end));
                    continue;
                }

                let column = matrix.column(*feature_index);
                let mut partition = start;
                let mut jump_start = end;
                match column {
                    // Fast path: u8-backed column with a threshold that fits in
                    // u8 — compare raw bytes without the dispatch in value_at.
                    CompactBinnedColumn::U8(values) if *threshold_bin <= u16::from(u8::MAX) => {
                        let threshold = *threshold_bin as u8;
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = values[start_row + row_offset] > threshold;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                // Jump rows collect at the tail of the span.
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                    _ => {
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = column.value_at(start_row + row_offset) > *threshold_bin;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                }

                // Recurse (via the explicit stack) into the non-empty halves.
                if jump_start < end {
                    stack.push((*jump_index, jump_start, end));
                }
                if start < jump_start {
                    stack.push((fallthrough_index, start, jump_start));
                }
            }
        }
    }
}
2203
#[inline(always)]
/// Evaluates one row of an oblivious tree.
///
/// Each level contributes one bit to the leaf index — 1 when the row's bin
/// exceeds that level's threshold, 0 otherwise — most significant level first.
fn predict_oblivious_row<F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    bin_at: F,
) -> f64
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index]
}
2221
#[inline(always)]
/// Evaluates one row of an oblivious classifier tree and borrows the reached
/// leaf's probability vector. Leaf indexing is identical to
/// `predict_oblivious_row`: one bit per level, 1 when bin > threshold.
fn predict_oblivious_probabilities_row<'a, F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &'a [Vec<f64>],
    bin_at: F,
) -> &'a [f64]
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index].as_slice()
}
2239
/// Converts per-class sample counts into probabilities that sum to one.
///
/// An all-zero count vector yields all-zero probabilities instead of NaN.
fn normalized_probabilities_from_counts(class_counts: &[usize]) -> Vec<f64> {
    let total: usize = class_counts.iter().sum();
    if total == 0 {
        return vec![0.0; class_counts.len()];
    }

    let denominator = total as f64;
    class_counts
        .iter()
        .map(|&count| count as f64 / denominator)
        .collect()
}
2251
/// Maps a probability row to the class label with the highest probability.
///
/// Ties resolve to the lowest class index. Comparison uses `total_cmp`, so
/// NaN values follow IEEE 754 total ordering rather than poisoning the scan.
///
/// # Panics
///
/// Panics if `probabilities` is empty.
fn class_label_from_probabilities(probabilities: &[f64], class_labels: &[f64]) -> f64 {
    let mut best_value = *probabilities
        .first()
        .expect("classification probability row is non-empty");
    let mut best_index = 0usize;
    for (index, &value) in probabilities.iter().enumerate().skip(1) {
        // Strictly-greater comparison keeps the earliest maximum on ties.
        if value.total_cmp(&best_value) == std::cmp::Ordering::Greater {
            best_value = value;
            best_index = index;
        }
    }
    class_labels[best_index]
}
2265
/// Numerically stable logistic function `1 / (1 + e^-x)`.
///
/// The branch keeps the argument passed to `exp` non-positive in both cases,
/// so large `|x|` cannot overflow to infinity and produce NaN.
#[inline(always)]
fn sigmoid(value: f64) -> f64 {
    let is_negative = value < 0.0;
    let exp = if is_negative {
        value.exp()
    } else {
        (-value).exp()
    };
    if is_negative {
        exp / (1.0 + exp)
    } else {
        1.0 / (1.0 + exp)
    }
}
2276
2277fn classifier_nodes_are_binary_only(nodes: &[tree::classifier::TreeNode]) -> bool {
2278 nodes.iter().all(|node| {
2279 matches!(
2280 node,
2281 tree::classifier::TreeNode::Leaf { .. }
2282 | tree::classifier::TreeNode::BinarySplit { .. }
2283 )
2284 })
2285}
2286
2287fn classifier_node_sample_count(nodes: &[tree::classifier::TreeNode], node_index: usize) -> usize {
2288 match &nodes[node_index] {
2289 tree::classifier::TreeNode::Leaf { sample_count, .. }
2290 | tree::classifier::TreeNode::BinarySplit { sample_count, .. }
2291 | tree::classifier::TreeNode::MultiwaySplit { sample_count, .. } => *sample_count,
2292 }
2293}
2294
2295fn build_binary_classifier_layout(
2296 nodes: &[tree::classifier::TreeNode],
2297 root: usize,
2298 _class_labels: &[f64],
2299) -> Vec<OptimizedBinaryClassifierNode> {
2300 let mut layout = Vec::with_capacity(nodes.len());
2301 append_binary_classifier_node(nodes, root, &mut layout);
2302 layout
2303}
2304
/// Recursively appends the binary subtree rooted at `node_index` to `layout`,
/// returning the layout slot the node was written to.
///
/// The emitted form is a branch-and-jump layout: a branch's "fallthrough"
/// child is written immediately after the branch itself (checked by the
/// `debug_assert_eq!` below), so each branch only stores the index of its
/// "jump" child. The child that saw more training samples becomes the
/// fallthrough, keeping the statistically common path sequential in memory.
fn append_binary_classifier_node(
    nodes: &[tree::classifier::TreeNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryClassifierNode>,
) -> usize {
    // Reserve this node's slot with a placeholder; it is overwritten once the
    // children (and therefore the jump target) are known.
    let current_index = layout.len();
    layout.push(OptimizedBinaryClassifierNode::Leaf(Vec::new()));

    match &nodes[node_index] {
        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
            layout[current_index] = OptimizedBinaryClassifierNode::Leaf(
                normalized_probabilities_from_counts(class_counts),
            );
        }
        tree::classifier::TreeNode::BinarySplit {
            feature_index,
            threshold_bin,
            left_child,
            right_child,
            ..
        } => {
            // Choose the heavier child as the fallthrough. `jump_if_greater`
            // records which comparison outcome (bin > threshold) takes the
            // jump edge: true when the right child is the jump target.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = classifier_node_sample_count(nodes, *left_child);
                let right_count = classifier_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            // Fallthrough child must land directly after this node.
            let fallthrough_index = append_binary_classifier_node(nodes, fallthrough_child, layout);
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // Degenerate split with identical children reuses the same slot.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_classifier_node(nodes, jump_child, layout)
            };

            layout[current_index] = OptimizedBinaryClassifierNode::Branch {
                feature_index: *feature_index,
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
            };
        }
        tree::classifier::TreeNode::MultiwaySplit { .. } => {
            unreachable!("multiway nodes are filtered out before binary layout construction");
        }
    }

    current_index
}
2360
2361fn regressor_node_sample_count(
2362 nodes: &[tree::regressor::RegressionNode],
2363 node_index: usize,
2364) -> usize {
2365 match &nodes[node_index] {
2366 tree::regressor::RegressionNode::Leaf { sample_count, .. }
2367 | tree::regressor::RegressionNode::BinarySplit { sample_count, .. } => *sample_count,
2368 }
2369}
2370
2371fn build_binary_regressor_layout(
2372 nodes: &[tree::regressor::RegressionNode],
2373 root: usize,
2374) -> Vec<OptimizedBinaryRegressorNode> {
2375 let mut layout = Vec::with_capacity(nodes.len());
2376 append_binary_regressor_node(nodes, root, &mut layout);
2377 layout
2378}
2379
/// Recursively appends the regression subtree rooted at `node_index` to
/// `layout`, returning the layout slot the node was written to.
///
/// Mirrors `append_binary_classifier_node`: the fallthrough child is written
/// immediately after its parent (checked by the `debug_assert_eq!` below), so
/// a branch only needs to store its jump child's index. The child with more
/// training samples becomes the fallthrough.
fn append_binary_regressor_node(
    nodes: &[tree::regressor::RegressionNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryRegressorNode>,
) -> usize {
    // Reserve this node's slot with a placeholder; it is overwritten once the
    // children (and therefore the jump target) are known.
    let current_index = layout.len();
    layout.push(OptimizedBinaryRegressorNode::Leaf(0.0));

    match &nodes[node_index] {
        tree::regressor::RegressionNode::Leaf { value, .. } => {
            layout[current_index] = OptimizedBinaryRegressorNode::Leaf(*value);
        }
        tree::regressor::RegressionNode::BinarySplit {
            feature_index,
            threshold_bin,
            left_child,
            right_child,
            ..
        } => {
            // Choose the heavier child as the fallthrough. `jump_if_greater`
            // records which comparison outcome (bin > threshold) takes the
            // jump edge: true when the right child is the jump target.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = regressor_node_sample_count(nodes, *left_child);
                let right_count = regressor_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            // Fallthrough child must land directly after this node.
            let fallthrough_index = append_binary_regressor_node(nodes, fallthrough_child, layout);
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // Degenerate split with identical children reuses the same slot.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_regressor_node(nodes, jump_child, layout)
            };

            layout[current_index] = OptimizedBinaryRegressorNode::Branch {
                feature_index: *feature_index,
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
            };
        }
    }

    current_index
}
2430
2431fn predict_oblivious_column_major_matrix(
2432 feature_indices: &[usize],
2433 threshold_bins: &[u16],
2434 leaf_values: &[f64],
2435 matrix: &ColumnMajorBinnedMatrix,
2436 executor: &InferenceExecutor,
2437) -> Vec<f64> {
2438 let mut outputs = vec![0.0; matrix.n_rows];
2439 executor.fill_chunks(
2440 &mut outputs,
2441 PARALLEL_INFERENCE_CHUNK_ROWS,
2442 |start_row, chunk| {
2443 predict_oblivious_chunk(
2444 feature_indices,
2445 threshold_bins,
2446 leaf_values,
2447 matrix,
2448 start_row,
2449 chunk,
2450 )
2451 },
2452 );
2453 outputs
2454}
2455
2456fn predict_oblivious_probabilities_column_major_matrix(
2457 feature_indices: &[usize],
2458 threshold_bins: &[u16],
2459 leaf_values: &[Vec<f64>],
2460 matrix: &ColumnMajorBinnedMatrix,
2461 _executor: &InferenceExecutor,
2462) -> Vec<Vec<f64>> {
2463 (0..matrix.n_rows)
2464 .map(|row_index| {
2465 predict_oblivious_probabilities_row(
2466 feature_indices,
2467 threshold_bins,
2468 leaf_values,
2469 |feature_index| matrix.column(feature_index).value_at(row_index),
2470 )
2471 .to_vec()
2472 })
2473 .collect()
2474}
2475
2476fn predict_oblivious_chunk(
2477 feature_indices: &[usize],
2478 threshold_bins: &[u16],
2479 leaf_values: &[f64],
2480 matrix: &ColumnMajorBinnedMatrix,
2481 start_row: usize,
2482 output: &mut [f64],
2483) {
2484 let processed = simd_predict_oblivious_chunk(
2485 feature_indices,
2486 threshold_bins,
2487 leaf_values,
2488 matrix,
2489 start_row,
2490 output,
2491 );
2492
2493 for (offset, out) in output.iter_mut().enumerate().skip(processed) {
2494 let row_index = start_row + offset;
2495 *out = predict_oblivious_row(
2496 feature_indices,
2497 threshold_bins,
2498 leaf_values,
2499 |feature_index| matrix.column(feature_index).value_at(row_index),
2500 );
2501 }
2502}
2503
/// SIMD fast path for oblivious-tree prediction over a chunk of rows.
///
/// Processes rows in groups of `OBLIVIOUS_SIMD_LANES` (8) and returns how
/// many rows were handled; the caller finishes the tail with the scalar
/// path. The return value is always a multiple of 8 and a prefix of
/// `output`.
fn simd_predict_oblivious_chunk(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) -> usize {
    let mut processed = 0usize;
    let ones = u32x8::splat(1);

    while processed + OBLIVIOUS_SIMD_LANES <= output.len() {
        let base_row = start_row + processed;
        let mut leaf_indices = u32x8::splat(0);

        // One oblivious level per (feature, threshold) pair: each level
        // shifts one comparison bit into all eight lanes' leaf indices.
        for (&feature_index, &threshold_bin) in feature_indices.iter().zip(threshold_bins) {
            let column = matrix.column(feature_index);
            // Prefer the u8 storage when the column offers it; otherwise the
            // column must expose u16 bins (see the `expect` below).
            let bins = if let Some(lanes) = column.slice_u8(base_row, OBLIVIOUS_SIMD_LANES) {
                let lanes: [u8; OBLIVIOUS_SIMD_LANES] = lanes
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::new([
                    u32::from(lanes[0]),
                    u32::from(lanes[1]),
                    u32::from(lanes[2]),
                    u32::from(lanes[3]),
                    u32::from(lanes[4]),
                    u32::from(lanes[5]),
                    u32::from(lanes[6]),
                    u32::from(lanes[7]),
                ])
            } else {
                let lanes: [u16; OBLIVIOUS_SIMD_LANES] = column
                    .slice_u16(base_row, OBLIVIOUS_SIMD_LANES)
                    .expect("column is u16 when not u8")
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::from(u16x8::new(lanes))
            };
            let threshold = u32x8::splat(u32::from(threshold_bin));
            // `cmp_gt` yields an all-ones lane mask; AND with 1 turns it
            // into the 0/1 traversal bit for each lane.
            let bit = bins.cmp_gt(threshold) & ones;
            leaf_indices = (leaf_indices << 1) | bit;
        }

        // Gather the selected leaf value for each of the eight lanes.
        let lane_indices = leaf_indices.to_array();
        for lane in 0..OBLIVIOUS_SIMD_LANES {
            output[processed + lane] =
                leaf_values[usize::try_from(lane_indices[lane]).expect("leaf index fits usize")];
        }
        processed += OBLIVIOUS_SIMD_LANES;
    }

    processed
}
2558
/// Trains a model on `train_set` using the settings in `config`.
///
/// Public entry point for training; delegates to the private `training`
/// module, which selects the concrete algorithm from the config.
pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
    training::train(train_set, config)
}
2562
2563fn resolve_inference_thread_count(physical_cores: Option<usize>) -> Result<usize, OptimizeError> {
2564 let available = num_cpus::get_physical().max(1);
2565 let requested = physical_cores.unwrap_or(available);
2566
2567 if requested == 0 {
2568 return Err(OptimizeError::InvalidPhysicalCoreCount {
2569 requested,
2570 available,
2571 });
2572 }
2573
2574 Ok(requested.min(available))
2575}
2576
impl Model {
    /// Predicts one output value per row of `table`, delegating to the
    /// wrapped model variant.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self {
            Model::DecisionTreeClassifier(model) => model.predict_table(table),
            Model::DecisionTreeRegressor(model) => model.predict_table(table),
            Model::RandomForest(model) => model.predict_table(table),
            Model::GradientBoostedTrees(model) => model.predict_table(table),
        }
    }

    /// Predicts from dense row-major `f64` rows, applying the model's
    /// feature preprocessing to build an inference table first.
    pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
        let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
        Ok(self.predict_table(&table))
    }

    /// Predicts per-class probabilities for each row of `table`.
    ///
    /// Returns `ProbabilityPredictionRequiresClassification` for regressors.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            Model::DecisionTreeClassifier(model) => Ok(model.predict_proba_table(table)),
            Model::RandomForest(model) => model.predict_proba_table(table),
            Model::GradientBoostedTrees(model) => model.predict_proba_table(table),
            Model::DecisionTreeRegressor(_) => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }

    /// Probability variant of [`Model::predict_rows`].
    pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
        self.predict_proba_table(&table)
    }

    /// Predicts from columnar input keyed by feature name.
    pub fn predict_named_columns(
        &self,
        columns: BTreeMap<String, Vec<f64>>,
    ) -> Result<Vec<f64>, PredictError> {
        let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
        Ok(self.predict_table(&table))
    }

    /// Probability variant of [`Model::predict_named_columns`].
    pub fn predict_proba_named_columns(
        &self,
        columns: BTreeMap<String, Vec<f64>>,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
        self.predict_proba_table(&table)
    }

    /// Predicts from sparse binary input: `columns[f]` lists the row indices
    /// where feature `f` is set.
    pub fn predict_sparse_binary_columns(
        &self,
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
    ) -> Result<Vec<f64>, PredictError> {
        let table = InferenceTable::from_sparse_binary_columns(
            n_rows,
            n_features,
            columns,
            self.feature_preprocessing(),
        )?;
        Ok(self.predict_table(&table))
    }

    /// Probability variant of [`Model::predict_sparse_binary_columns`].
    pub fn predict_proba_sparse_binary_columns(
        &self,
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_sparse_binary_columns(
            n_rows,
            n_features,
            columns,
            self.feature_preprocessing(),
        )?;
        self.predict_proba_table(&table)
    }

    /// Predicts from a Polars `DataFrame` by converting its columns to named
    /// `f64` columns first.
    #[cfg(feature = "polars")]
    pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
        let columns = polars_named_columns(df)?;
        self.predict_named_columns(columns)
    }

    /// Predicts from a Polars `LazyFrame`, collecting and scoring it in
    /// slices of `LAZYFRAME_PREDICT_BATCH_ROWS` rows to bound memory use.
    #[cfg(feature = "polars")]
    pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
        let mut predictions = Vec::new();
        let mut offset = 0i64;
        loop {
            // Materialize the next batch; an empty batch means we ran off
            // the end of the frame.
            let batch = lf
                .clone()
                .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
                .collect()?;
            let height = batch.height();
            if height == 0 {
                break;
            }
            predictions.extend(self.predict_polars_dataframe(&batch)?);
            // A short batch is necessarily the last one.
            if height < LAZYFRAME_PREDICT_BATCH_ROWS {
                break;
            }
            offset += height as i64;
        }
        Ok(predictions)
    }

    /// Training algorithm family of this model.
    pub fn algorithm(&self) -> TrainAlgorithm {
        match self {
            Model::DecisionTreeClassifier(_) | Model::DecisionTreeRegressor(_) => {
                TrainAlgorithm::Dt
            }
            Model::RandomForest(_) => TrainAlgorithm::Rf,
            Model::GradientBoostedTrees(_) => TrainAlgorithm::Gbm,
        }
    }

    /// Whether this model performs classification or regression.
    pub fn task(&self) -> Task {
        match self {
            Model::DecisionTreeRegressor(_) => Task::Regression,
            Model::DecisionTreeClassifier(_) => Task::Classification,
            Model::RandomForest(model) => model.task(),
            Model::GradientBoostedTrees(model) => model.task(),
        }
    }

    /// Split criterion used during training.
    pub fn criterion(&self) -> Criterion {
        match self {
            Model::DecisionTreeClassifier(model) => model.criterion(),
            Model::DecisionTreeRegressor(model) => model.criterion(),
            Model::RandomForest(model) => model.criterion(),
            Model::GradientBoostedTrees(model) => model.criterion(),
        }
    }

    /// Tree-construction algorithm used by the underlying trees.
    pub fn tree_type(&self) -> TreeType {
        match self {
            Model::DecisionTreeClassifier(model) => match model.algorithm() {
                DecisionTreeAlgorithm::Id3 => TreeType::Id3,
                DecisionTreeAlgorithm::C45 => TreeType::C45,
                DecisionTreeAlgorithm::Cart => TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => TreeType::Oblivious,
            },
            Model::DecisionTreeRegressor(model) => match model.algorithm() {
                RegressionTreeAlgorithm::Cart => TreeType::Cart,
                RegressionTreeAlgorithm::Randomized => TreeType::Randomized,
                RegressionTreeAlgorithm::Oblivious => TreeType::Oblivious,
            },
            Model::RandomForest(model) => model.tree_type(),
            Model::GradientBoostedTrees(model) => model.tree_type(),
        }
    }

    /// Mean target value, when the model records one.
    ///
    /// Currently `None` for every variant; the exhaustive match keeps this
    /// in sync if a variant that stores a mean is added later.
    pub fn mean_value(&self) -> Option<f64> {
        match self {
            Model::DecisionTreeClassifier(_)
            | Model::DecisionTreeRegressor(_)
            | Model::RandomForest(_)
            | Model::GradientBoostedTrees(_) => None,
        }
    }

    /// Number of canary features configured at training time.
    pub fn canaries(&self) -> usize {
        self.training_metadata().canaries
    }

    /// Configured maximum tree depth, if any.
    pub fn max_depth(&self) -> Option<usize> {
        self.training_metadata().max_depth
    }

    /// Configured minimum samples required to split a node, if any.
    pub fn min_samples_split(&self) -> Option<usize> {
        self.training_metadata().min_samples_split
    }

    /// Configured minimum samples required in a leaf, if any.
    pub fn min_samples_leaf(&self) -> Option<usize> {
        self.training_metadata().min_samples_leaf
    }

    /// Configured number of trees, if the model is an ensemble.
    pub fn n_trees(&self) -> Option<usize> {
        self.training_metadata().n_trees
    }

    /// Configured per-split feature subsample size, if any.
    pub fn max_features(&self) -> Option<usize> {
        self.training_metadata().max_features
    }

    /// Random seed used during training, if one was recorded.
    pub fn seed(&self) -> Option<u64> {
        self.training_metadata().seed
    }

    /// Whether out-of-bag scoring was requested at training time.
    pub fn compute_oob(&self) -> bool {
        self.training_metadata().compute_oob
    }

    /// Out-of-bag score recorded at training time, if computed.
    pub fn oob_score(&self) -> Option<f64> {
        self.training_metadata().oob_score
    }

    /// Boosting learning rate, if applicable.
    pub fn learning_rate(&self) -> Option<f64> {
        self.training_metadata().learning_rate
    }

    /// Whether bootstrap sampling was used; defaults to `false` when the
    /// metadata does not record it.
    pub fn bootstrap(&self) -> bool {
        self.training_metadata().bootstrap.unwrap_or(false)
    }

    /// GOSS top-gradient sample fraction, if applicable.
    pub fn top_gradient_fraction(&self) -> Option<f64> {
        self.training_metadata().top_gradient_fraction
    }

    /// GOSS other-gradient sample fraction, if applicable.
    pub fn other_gradient_fraction(&self) -> Option<f64> {
        self.training_metadata().other_gradient_fraction
    }

    /// Number of trees in the model, counted via the IR export.
    pub fn tree_count(&self) -> usize {
        self.to_ir().model.trees.len()
    }

    /// Structural summary (node counts, depths) of the tree at `tree_index`.
    pub fn tree_structure(
        &self,
        tree_index: usize,
    ) -> Result<TreeStructureSummary, IntrospectionError> {
        tree_structure_summary(self.tree_definition(tree_index)?)
    }

    /// Statistics over the leaf prediction values of the tree at
    /// `tree_index`.
    pub fn tree_prediction_stats(
        &self,
        tree_index: usize,
    ) -> Result<PredictionValueStats, IntrospectionError> {
        prediction_value_stats(self.tree_definition(tree_index)?)
    }

    /// Returns node `node_index` of a node-based tree; errors for oblivious
    /// trees or an out-of-range index.
    pub fn tree_node(
        &self,
        tree_index: usize,
        node_index: usize,
    ) -> Result<ir::NodeTreeNode, IntrospectionError> {
        match self.tree_definition(tree_index)? {
            ir::TreeDefinition::NodeTree { nodes, .. } => {
                let available = nodes.len();
                nodes
                    .into_iter()
                    .nth(node_index)
                    .ok_or(IntrospectionError::NodeIndexOutOfBounds {
                        requested: node_index,
                        available,
                    })
            }
            ir::TreeDefinition::ObliviousLevels { .. } => Err(IntrospectionError::NotANodeTree),
        }
    }

    /// Returns level `level_index` of an oblivious tree; errors for
    /// node-based trees or an out-of-range index.
    pub fn tree_level(
        &self,
        tree_index: usize,
        level_index: usize,
    ) -> Result<ir::ObliviousLevel, IntrospectionError> {
        match self.tree_definition(tree_index)? {
            ir::TreeDefinition::ObliviousLevels { levels, .. } => {
                let available = levels.len();
                levels.into_iter().nth(level_index).ok_or(
                    IntrospectionError::LevelIndexOutOfBounds {
                        requested: level_index,
                        available,
                    },
                )
            }
            ir::TreeDefinition::NodeTree { .. } => Err(IntrospectionError::NotAnObliviousTree),
        }
    }

    /// Returns leaf `leaf_index` of the tree at `tree_index`.
    ///
    /// For node-based trees the leaves are enumerated in node order; for
    /// oblivious trees the stored leaf list is indexed directly.
    pub fn tree_leaf(
        &self,
        tree_index: usize,
        leaf_index: usize,
    ) -> Result<ir::IndexedLeaf, IntrospectionError> {
        match self.tree_definition(tree_index)? {
            ir::TreeDefinition::ObliviousLevels { leaves, .. } => {
                let available = leaves.len();
                leaves
                    .into_iter()
                    .nth(leaf_index)
                    .ok_or(IntrospectionError::LeafIndexOutOfBounds {
                        requested: leaf_index,
                        available,
                    })
            }
            ir::TreeDefinition::NodeTree { nodes, .. } => {
                // Collect leaf nodes into indexed form before selecting.
                let leaves = nodes
                    .into_iter()
                    .filter_map(|node| match node {
                        ir::NodeTreeNode::Leaf {
                            node_id,
                            leaf,
                            stats,
                            ..
                        } => Some(ir::IndexedLeaf {
                            leaf_index: node_id,
                            leaf,
                            stats: ir::NodeStats {
                                sample_count: stats.sample_count,
                                impurity: stats.impurity,
                                gain: stats.gain,
                                class_counts: stats.class_counts,
                                variance: stats.variance,
                            },
                        }),
                        _ => None,
                    })
                    .collect::<Vec<_>>();
                let available = leaves.len();
                leaves
                    .into_iter()
                    .nth(leaf_index)
                    .ok_or(IntrospectionError::LeafIndexOutOfBounds {
                        requested: leaf_index,
                        available,
                    })
            }
        }
    }

    /// Exports the model to its intermediate representation.
    pub fn to_ir(&self) -> ModelPackageIr {
        ir::model_to_ir(self)
    }

    /// Serializes the IR to compact JSON.
    pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string(&self.to_ir())
    }

    /// Serializes the IR to pretty-printed JSON.
    pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(&self.to_ir())
    }

    /// Alias for [`Model::to_ir_json`].
    pub fn serialize(&self) -> Result<String, serde_json::Error> {
        self.to_ir_json()
    }

    /// Alias for [`Model::to_ir_json_pretty`].
    pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
        self.to_ir_json_pretty()
    }

    /// Builds an inference-optimized clone of this model.
    ///
    /// `physical_cores` limits the worker-thread count; `None` uses all
    /// physical cores.
    pub fn optimize_inference(
        &self,
        physical_cores: Option<usize>,
    ) -> Result<OptimizedModel, OptimizeError> {
        OptimizedModel::new(self.clone(), physical_cores)
    }

    /// JSON Schema describing the serialized IR format.
    pub fn json_schema() -> schemars::schema::RootSchema {
        ModelPackageIr::json_schema()
    }

    /// JSON Schema for the IR, rendered as compact JSON.
    pub fn json_schema_json() -> Result<String, IrError> {
        ModelPackageIr::json_schema_json()
    }

    /// JSON Schema for the IR, rendered as pretty-printed JSON.
    pub fn json_schema_json_pretty() -> Result<String, IrError> {
        ModelPackageIr::json_schema_json_pretty()
    }

    /// Reconstructs a model from JSON produced by [`Model::serialize`].
    pub fn deserialize(serialized: &str) -> Result<Self, IrError> {
        let ir: ModelPackageIr =
            serde_json::from_str(serialized).map_err(|err| IrError::Json(err.to_string()))?;
        ir::model_from_ir(ir)
    }

    /// Number of input features the model expects.
    pub(crate) fn num_features(&self) -> usize {
        match self {
            Model::DecisionTreeClassifier(model) => model.num_features(),
            Model::DecisionTreeRegressor(model) => model.num_features(),
            Model::RandomForest(model) => model.num_features(),
            Model::GradientBoostedTrees(model) => model.num_features(),
        }
    }

    /// Per-feature preprocessing (binning) recorded at training time.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        match self {
            Model::DecisionTreeClassifier(model) => model.feature_preprocessing(),
            Model::DecisionTreeRegressor(model) => model.feature_preprocessing(),
            Model::RandomForest(model) => model.feature_preprocessing(),
            Model::GradientBoostedTrees(model) => model.feature_preprocessing(),
        }
    }

    /// Class labels for classification models; `None` for regressors.
    pub(crate) fn class_labels(&self) -> Option<Vec<f64>> {
        match self {
            Model::DecisionTreeClassifier(model) => Some(model.class_labels().to_vec()),
            Model::RandomForest(model) => model.class_labels(),
            Model::GradientBoostedTrees(model) => model.class_labels(),
            Model::DecisionTreeRegressor(_) => None,
        }
    }

    /// Hyperparameters and statistics recorded at training time.
    pub(crate) fn training_metadata(&self) -> ir::TrainingMetadata {
        match self {
            Model::DecisionTreeClassifier(model) => model.training_metadata(),
            Model::DecisionTreeRegressor(model) => model.training_metadata(),
            Model::RandomForest(model) => model.training_metadata(),
            Model::GradientBoostedTrees(model) => model.training_metadata(),
        }
    }

    /// Extracts the IR definition of the tree at `tree_index`.
    ///
    /// Builds the full IR export on every call, so repeated introspection
    /// calls pay the conversion cost each time.
    fn tree_definition(&self, tree_index: usize) -> Result<ir::TreeDefinition, IntrospectionError> {
        let trees = self.to_ir().model.trees;
        let available = trees.len();
        trees
            .into_iter()
            .nth(tree_index)
            .ok_or(IntrospectionError::TreeIndexOutOfBounds {
                requested: tree_index,
                available,
            })
    }
}
2994
2995fn tree_structure_summary(
2996 tree: ir::TreeDefinition,
2997) -> Result<TreeStructureSummary, IntrospectionError> {
2998 match tree {
2999 ir::TreeDefinition::NodeTree {
3000 root_node_id,
3001 nodes,
3002 ..
3003 } => {
3004 let node_map = nodes
3005 .iter()
3006 .cloned()
3007 .map(|node| match &node {
3008 ir::NodeTreeNode::Leaf { node_id, .. }
3009 | ir::NodeTreeNode::BinaryBranch { node_id, .. }
3010 | ir::NodeTreeNode::MultiwayBranch { node_id, .. } => (*node_id, node),
3011 })
3012 .collect::<BTreeMap<_, _>>();
3013 let mut leaf_depths = Vec::new();
3014 collect_leaf_depths(&node_map, root_node_id, &mut leaf_depths)?;
3015 let internal_node_count = nodes
3016 .iter()
3017 .filter(|node| !matches!(node, ir::NodeTreeNode::Leaf { .. }))
3018 .count();
3019 let leaf_count = leaf_depths.len();
3020 let shortest_path = *leaf_depths.iter().min().unwrap_or(&0);
3021 let longest_path = *leaf_depths.iter().max().unwrap_or(&0);
3022 let average_path = if leaf_depths.is_empty() {
3023 0.0
3024 } else {
3025 leaf_depths.iter().sum::<usize>() as f64 / leaf_depths.len() as f64
3026 };
3027 Ok(TreeStructureSummary {
3028 representation: "node_tree".to_string(),
3029 node_count: internal_node_count + leaf_count,
3030 internal_node_count,
3031 leaf_count,
3032 actual_depth: longest_path,
3033 shortest_path,
3034 longest_path,
3035 average_path,
3036 })
3037 }
3038 ir::TreeDefinition::ObliviousLevels { depth, leaves, .. } => Ok(TreeStructureSummary {
3039 representation: "oblivious_levels".to_string(),
3040 node_count: ((1usize << depth) - 1) + leaves.len(),
3041 internal_node_count: (1usize << depth) - 1,
3042 leaf_count: leaves.len(),
3043 actual_depth: depth,
3044 shortest_path: depth,
3045 longest_path: depth,
3046 average_path: depth as f64,
3047 }),
3048 }
3049}
3050
3051fn collect_leaf_depths(
3052 nodes: &BTreeMap<usize, ir::NodeTreeNode>,
3053 node_id: usize,
3054 output: &mut Vec<usize>,
3055) -> Result<(), IntrospectionError> {
3056 match nodes
3057 .get(&node_id)
3058 .ok_or(IntrospectionError::NodeIndexOutOfBounds {
3059 requested: node_id,
3060 available: nodes.len(),
3061 })? {
3062 ir::NodeTreeNode::Leaf { depth, .. } => output.push(*depth),
3063 ir::NodeTreeNode::BinaryBranch {
3064 depth: _, children, ..
3065 } => {
3066 collect_leaf_depths(nodes, children.left, output)?;
3067 collect_leaf_depths(nodes, children.right, output)?;
3068 }
3069 ir::NodeTreeNode::MultiwayBranch {
3070 depth,
3071 branches,
3072 unmatched_leaf: _,
3073 ..
3074 } => {
3075 output.push(depth + 1);
3076 for branch in branches {
3077 collect_leaf_depths(nodes, branch.child, output)?;
3078 }
3079 }
3080 }
3081 Ok(())
3082}
3083
3084fn prediction_value_stats(
3085 tree: ir::TreeDefinition,
3086) -> Result<PredictionValueStats, IntrospectionError> {
3087 let predictions = match tree {
3088 ir::TreeDefinition::NodeTree { nodes, .. } => nodes
3089 .into_iter()
3090 .flat_map(|node| match node {
3091 ir::NodeTreeNode::Leaf { leaf, .. } => vec![leaf_payload_value(&leaf)],
3092 ir::NodeTreeNode::MultiwayBranch { unmatched_leaf, .. } => {
3093 vec![leaf_payload_value(&unmatched_leaf)]
3094 }
3095 ir::NodeTreeNode::BinaryBranch { .. } => Vec::new(),
3096 })
3097 .collect::<Vec<_>>(),
3098 ir::TreeDefinition::ObliviousLevels { leaves, .. } => leaves
3099 .into_iter()
3100 .map(|leaf| leaf_payload_value(&leaf.leaf))
3101 .collect::<Vec<_>>(),
3102 };
3103
3104 let count = predictions.len();
3105 let min = predictions
3106 .iter()
3107 .copied()
3108 .min_by(f64::total_cmp)
3109 .unwrap_or(0.0);
3110 let max = predictions
3111 .iter()
3112 .copied()
3113 .max_by(f64::total_cmp)
3114 .unwrap_or(0.0);
3115 let mean = if count == 0 {
3116 0.0
3117 } else {
3118 predictions.iter().sum::<f64>() / count as f64
3119 };
3120 let std_dev = if count == 0 {
3121 0.0
3122 } else {
3123 let variance = predictions
3124 .iter()
3125 .map(|value| (*value - mean).powi(2))
3126 .sum::<f64>()
3127 / count as f64;
3128 variance.sqrt()
3129 };
3130 let mut histogram = BTreeMap::<String, usize>::new();
3131 for prediction in &predictions {
3132 *histogram.entry(prediction.to_string()).or_insert(0) += 1;
3133 }
3134 let histogram = histogram
3135 .into_iter()
3136 .map(|(prediction, count)| PredictionHistogramEntry {
3137 prediction: prediction
3138 .parse::<f64>()
3139 .expect("histogram keys are numeric"),
3140 count,
3141 })
3142 .collect::<Vec<_>>();
3143
3144 Ok(PredictionValueStats {
3145 count,
3146 unique_count: histogram.len(),
3147 min,
3148 max,
3149 mean,
3150 std_dev,
3151 histogram,
3152 })
3153}
3154
3155fn leaf_payload_value(leaf: &ir::LeafPayload) -> f64 {
3156 match leaf {
3157 ir::LeafPayload::RegressionValue { value } => *value,
3158 ir::LeafPayload::ClassIndex { class_value, .. } => *class_value,
3159 }
3160}
3161
/// Converts every column of a Polars `DataFrame` into a name-keyed map of
/// `f64` vectors, failing on nulls or unsupported dtypes.
#[cfg(feature = "polars")]
fn polars_named_columns(df: &DataFrame) -> Result<BTreeMap<String, Vec<f64>>, PredictError> {
    let mut columns = BTreeMap::new();
    for column in df.get_columns() {
        let values = polars_column_values(column)?;
        columns.insert(column.name().to_string(), values);
    }
    Ok(columns)
}
3172
/// Extracts a Polars column as a dense `Vec<f64>`.
///
/// Booleans become 0.0/1.0, `Float64` is read directly, and the remaining
/// numeric dtypes are cast to `Float64` first. Any null value yields
/// `PredictError::NullValue` with the feature name and row index; any other
/// dtype yields `PredictError::UnsupportedFeatureType`.
#[cfg(feature = "polars")]
fn polars_column_values(column: &Column) -> Result<Vec<f64>, PredictError> {
    let name = column.name().to_string();
    let series = column.as_materialized_series();
    match series.dtype() {
        DataType::Boolean => series
            .bool()?
            .into_iter()
            .enumerate()
            .map(|(row_index, value)| {
                value
                    .map(|value| f64::from(u8::from(value)))
                    .ok_or_else(|| PredictError::NullValue {
                        feature: name.clone(),
                        row_index,
                    })
            })
            .collect(),
        DataType::Float64 => series
            .f64()?
            .into_iter()
            .enumerate()
            .map(|(row_index, value)| {
                value.ok_or_else(|| PredictError::NullValue {
                    feature: name.clone(),
                    row_index,
                })
            })
            .collect(),
        // All other supported numeric types are widened to f64 via a cast.
        DataType::Float32
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64 => {
            let casted = series.cast(&DataType::Float64)?;
            casted
                .f64()?
                .into_iter()
                .enumerate()
                .map(|(row_index, value)| {
                    value.ok_or_else(|| PredictError::NullValue {
                        feature: name.clone(),
                        row_index,
                    })
                })
                .collect()
        }
        dtype => Err(PredictError::UnsupportedFeatureType {
            feature: name,
            dtype: dtype.to_string(),
        }),
    }
}
3230
3231#[cfg(test)]
3232mod tests;