1use forestfire_data::{
19 BinnedColumnKind, MAX_NUMERIC_BINS, NumericBins, TableAccess, numeric_bin_boundaries,
20 numeric_missing_bin,
21};
22#[cfg(feature = "polars")]
23use polars::prelude::{Column, DataFrame, DataType, IdxSize, LazyFrame};
24use rayon::ThreadPoolBuilder;
25use rayon::prelude::*;
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use std::collections::{BTreeMap, BTreeSet};
29use std::error::Error;
30use std::fmt::{Display, Formatter};
31use std::sync::Arc;
32use wide::{u16x8, u32x8};
33
34mod boosting;
35mod bootstrap;
36mod compiled_artifact;
37mod forest;
38mod inference_input;
39mod introspection;
40pub mod ir;
41mod model_api;
42mod optimized_runtime;
43mod runtime_planning;
44mod sampling;
45mod training;
46pub mod tree;
47
48pub use boosting::BoostingError;
49pub use boosting::GradientBoostedTrees;
50pub use compiled_artifact::CompiledArtifactError;
51pub use forest::RandomForest;
52pub use introspection::IntrospectionError;
53pub use introspection::PredictionHistogramEntry;
54pub use introspection::PredictionValueStats;
55pub use introspection::TreeStructureSummary;
56pub use ir::IrError;
57pub use ir::ModelPackageIr;
58pub use model_api::OptimizedModel;
59pub use tree::classifier::DecisionTreeAlgorithm;
60pub use tree::classifier::DecisionTreeClassifier;
61pub use tree::classifier::DecisionTreeError;
62pub use tree::classifier::DecisionTreeOptions;
63pub use tree::classifier::train_c45;
64pub use tree::classifier::train_cart;
65pub use tree::classifier::train_id3;
66pub use tree::classifier::train_oblivious;
67pub use tree::classifier::train_randomized;
68pub use tree::regressor::DecisionTreeRegressor;
69pub use tree::regressor::RegressionTreeAlgorithm;
70pub use tree::regressor::RegressionTreeError;
71pub use tree::regressor::RegressionTreeOptions;
72pub use tree::regressor::train_cart_regressor;
73pub use tree::regressor::train_oblivious_regressor;
74pub use tree::regressor::train_randomized_regressor;
75#[cfg(feature = "polars")]
76const LAZYFRAME_PREDICT_BATCH_ROWS: usize = 10_000;
77pub(crate) use inference_input::ColumnMajorBinnedMatrix;
78pub(crate) use inference_input::CompactBinnedColumn;
79pub(crate) use inference_input::InferenceTable;
80pub(crate) use inference_input::ProjectedTableView;
81#[cfg(feature = "polars")]
82pub(crate) use inference_input::polars_named_columns;
83pub(crate) use introspection::prediction_value_stats;
84pub(crate) use introspection::tree_structure_summary;
85pub(crate) use optimized_runtime::InferenceExecutor;
86pub(crate) use optimized_runtime::OBLIVIOUS_SIMD_LANES;
87pub(crate) use optimized_runtime::OptimizedBinaryClassifierNode;
88pub(crate) use optimized_runtime::OptimizedBinaryRegressorNode;
89pub(crate) use optimized_runtime::OptimizedClassifierNode;
90pub(crate) use optimized_runtime::OptimizedRuntime;
91pub(crate) use optimized_runtime::PARALLEL_INFERENCE_CHUNK_ROWS;
92pub(crate) use optimized_runtime::STANDARD_BATCH_INFERENCE_CHUNK_ROWS;
93pub(crate) use optimized_runtime::resolve_inference_thread_count;
94pub(crate) use runtime_planning::build_feature_index_map;
95pub(crate) use runtime_planning::build_feature_projection;
96pub(crate) use runtime_planning::model_used_feature_indices;
97pub(crate) use runtime_planning::ordered_ensemble_indices;
98pub(crate) use runtime_planning::remap_feature_index;
99
/// Which model family `TrainConfig` should produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainAlgorithm {
    /// A single decision tree (classifier or regressor, per `Task`).
    Dt,
    /// A random forest ensemble.
    Rf,
    /// Gradient boosted trees.
    Gbm,
}
109
/// Split-quality criterion used when growing trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Criterion {
    /// Defer to a task/tree-type-dependent default chosen by the trainer.
    Auto,
    /// Gini impurity (classification).
    Gini,
    /// Entropy / information gain (classification).
    Entropy,
    /// Mean-based regression split scoring — exact semantics live in the training code.
    Mean,
    /// Median-based regression split scoring — exact semantics live in the training code.
    Median,
    /// Second-order scoring — presumably gradient/hessian based; confirm in the boosting module.
    SecondOrder,
}
125
/// The learning task a model is trained for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    /// Predict a continuous value.
    Regression,
    /// Predict a class label.
    Classification,
}
133
/// Tree-construction algorithm to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeType {
    /// ID3 (see `train_id3`).
    Id3,
    /// C4.5 (see `train_c45`).
    C45,
    /// CART (see `train_cart` / `train_cart_regressor`).
    Cart,
    /// Randomized splits (see `train_randomized`).
    Randomized,
    /// Oblivious (symmetric) trees (see `train_oblivious`).
    Oblivious,
}
147
/// Policy for how many candidate features to consider per split.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxFeatures {
    /// Task-dependent default: `Sqrt` for classification, `Third` for regression.
    Auto,
    /// Consider every feature.
    All,
    /// `floor(sqrt(feature_count))`, at least 1.
    Sqrt,
    /// `feature_count / 3`, at least 1.
    Third,
    /// An explicit count, clamped to `[1, feature_count]`.
    Count(usize),
}
161
162impl MaxFeatures {
163 pub fn resolve(self, task: Task, feature_count: usize) -> usize {
164 match self {
165 MaxFeatures::Auto => match task {
166 Task::Classification => MaxFeatures::Sqrt.resolve(task, feature_count),
167 Task::Regression => MaxFeatures::Third.resolve(task, feature_count),
168 },
169 MaxFeatures::All => feature_count.max(1),
170 MaxFeatures::Sqrt => ((feature_count as f64).sqrt().floor() as usize).max(1),
171 MaxFeatures::Third => (feature_count / 3).max(1),
172 MaxFeatures::Count(count) => count.min(feature_count).max(1),
173 }
174 }
175}
176
/// Serialized tag describing how an input feature is interpreted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputFeatureKind {
    /// Continuous numeric feature.
    Numeric,
    /// Binary feature.
    Binary,
}
185
/// Upper boundary of one numeric histogram bin.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinBoundary {
    /// Index of the bin this boundary belongs to.
    pub bin: u16,
    /// Upper edge of the bin — inclusive/exclusive convention is defined by
    /// `numeric_bin_boundaries` in `forestfire_data`; confirm there.
    pub upper_bound: f64,
}
193
/// Per-feature preprocessing captured at training time and replayed at inference.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeaturePreprocessing {
    /// Numeric feature: histogram bin boundaries plus the bin reserved for missing values.
    Numeric {
        bin_boundaries: Vec<NumericBinBoundary>,
        missing_bin: u16,
    },
    /// Binary feature: passed through without binning.
    Binary,
}
205
/// How missing values are routed at tree splits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MissingValueStrategy {
    /// Rule-of-thumb routing (the crate-wide default).
    Heuristic,
    /// "Optimal" routing — exact semantics live in the training code; confirm there.
    Optimal,
}
211
/// Missing-value strategy applied globally or overridden per feature index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MissingValueStrategyConfig {
    /// One strategy for every feature.
    Global(MissingValueStrategy),
    /// Per-feature overrides keyed by feature index; unlisted features fall
    /// back to `Heuristic` (see `resolve_for_feature_count`).
    PerFeature(BTreeMap<usize, MissingValueStrategy>),
}
217
218impl MissingValueStrategyConfig {
219 pub fn heuristic() -> Self {
220 Self::Global(MissingValueStrategy::Heuristic)
221 }
222
223 pub fn optimal() -> Self {
224 Self::Global(MissingValueStrategy::Optimal)
225 }
226
227 pub fn resolve_for_feature_count(
228 &self,
229 feature_count: usize,
230 ) -> Result<Vec<MissingValueStrategy>, TrainError> {
231 match self {
232 MissingValueStrategyConfig::Global(strategy) => Ok(vec![*strategy; feature_count]),
233 MissingValueStrategyConfig::PerFeature(strategies) => {
234 let mut resolved = vec![MissingValueStrategy::Heuristic; feature_count];
235 for (&feature_index, &strategy) in strategies {
236 if feature_index >= feature_count {
237 return Err(TrainError::InvalidMissingValueStrategyFeature {
238 feature_index,
239 feature_count,
240 });
241 }
242 resolved[feature_index] = strategy;
243 }
244 Ok(resolved)
245 }
246 }
247 }
248}
249
/// Unified configuration for training any supported model family.
///
/// Only the fields relevant to the chosen `algorithm`/`task` are consulted by
/// the corresponding trainer; `None` means "use the trainer's default".
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Model family: single tree, random forest, or gradient boosting.
    pub algorithm: TrainAlgorithm,
    /// Regression or classification.
    pub task: Task,
    /// Tree-construction algorithm.
    pub tree_type: TreeType,
    /// Split-quality criterion; `Auto` defers to task/tree-type defaults.
    pub criterion: Criterion,
    /// Maximum tree depth; `None` uses the trainer default.
    pub max_depth: Option<usize>,
    /// Minimum samples required to attempt a split; `None` uses the trainer default.
    pub min_samples_split: Option<usize>,
    /// Minimum samples per leaf; `None` uses the trainer default.
    pub min_samples_leaf: Option<usize>,
    /// Physical cores to use for training; `None` lets the trainer decide.
    pub physical_cores: Option<usize>,
    /// Number of trees for ensemble algorithms; `None` uses the trainer default.
    pub n_trees: Option<usize>,
    /// Candidate features considered per split.
    pub max_features: MaxFeatures,
    /// RNG seed for reproducible training; `None` for nondeterministic.
    pub seed: Option<u64>,
    /// Whether to compute out-of-bag estimates (forests).
    pub compute_oob: bool,
    /// Boosting learning rate; `None` uses the trainer default.
    pub learning_rate: Option<f64>,
    /// Whether to bootstrap-sample rows per tree.
    pub bootstrap: bool,
    /// Gradient-based sampling fraction (top slice) — presumably GOSS-style;
    /// confirm in the boosting/sampling modules.
    pub top_gradient_fraction: Option<f64>,
    /// Gradient-based sampling fraction (remaining slice) — see note above.
    pub other_gradient_fraction: Option<f64>,
    /// Missing-value routing, global or per feature.
    pub missing_value_strategy: MissingValueStrategyConfig,
    /// Numeric histogram binning override; `None` uses the data-layer default.
    pub histogram_bins: Option<NumericBins>,
}
298
impl Default for TrainConfig {
    /// Baseline configuration: a single CART regression tree with all tuning
    /// knobs deferred to trainer defaults (`None`), no bootstrapping, no OOB
    /// scoring, and heuristic missing-value handling.
    fn default() -> Self {
        Self {
            algorithm: TrainAlgorithm::Dt,
            task: Task::Regression,
            tree_type: TreeType::Cart,
            criterion: Criterion::Auto,
            max_depth: None,
            min_samples_split: None,
            min_samples_leaf: None,
            physical_cores: None,
            n_trees: None,
            max_features: MaxFeatures::Auto,
            seed: None,
            compute_oob: false,
            learning_rate: None,
            bootstrap: false,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
            missing_value_strategy: MissingValueStrategyConfig::heuristic(),
            histogram_bins: None,
        }
    }
}
323
/// A trained model of any supported family, as produced by training.
#[derive(Debug, Clone)]
pub enum Model {
    /// Single classification tree.
    DecisionTreeClassifier(DecisionTreeClassifier),
    /// Single regression tree.
    DecisionTreeRegressor(DecisionTreeRegressor),
    /// Random forest (classification or regression — see `RandomForest::task`).
    RandomForest(RandomForest),
    /// Gradient boosted ensemble (classification or regression).
    GradientBoostedTrees(GradientBoostedTrees),
}
336
/// Errors produced while validating a `TrainConfig` or training a model.
#[derive(Debug)]
pub enum TrainError {
    /// Wrapped error from the classification-tree trainer.
    DecisionTree(DecisionTreeError),
    /// Wrapped error from the regression-tree trainer.
    RegressionTree(RegressionTreeError),
    /// Wrapped error from the boosting trainer.
    Boosting(BoostingError),
    /// More physical cores requested than the machine provides.
    InvalidPhysicalCoreCount {
        requested: usize,
        available: usize,
    },
    /// Rayon thread-pool construction failed (message from the builder).
    ThreadPoolBuildFailed(String),
    /// The task/tree-type/criterion combination has no trainer.
    UnsupportedConfiguration {
        task: Task,
        tree_type: TreeType,
        criterion: Criterion,
    },
    /// `max_depth` was 0.
    InvalidMaxDepth(usize),
    /// `min_samples_split` was 0.
    InvalidMinSamplesSplit(usize),
    /// `min_samples_leaf` was 0.
    InvalidMinSamplesLeaf(usize),
    /// Ensemble requested with zero trees.
    InvalidTreeCount(usize),
    /// `MaxFeatures::Count` was 0.
    InvalidMaxFeatures(usize),
    /// A per-feature missing-value override referenced a feature index
    /// outside the training table.
    InvalidMissingValueStrategyFeature {
        feature_index: usize,
        feature_count: usize,
    },
}
362
363impl Display for TrainError {
364 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
365 match self {
366 TrainError::DecisionTree(err) => err.fmt(f),
367 TrainError::RegressionTree(err) => err.fmt(f),
368 TrainError::Boosting(err) => err.fmt(f),
369 TrainError::InvalidPhysicalCoreCount {
370 requested,
371 available,
372 } => write!(
373 f,
374 "Requested {} physical cores, but the available physical core count is {}.",
375 requested, available
376 ),
377 TrainError::ThreadPoolBuildFailed(message) => {
378 write!(f, "Failed to build training thread pool: {}.", message)
379 }
380 TrainError::UnsupportedConfiguration {
381 task,
382 tree_type,
383 criterion,
384 } => write!(
385 f,
386 "Unsupported training configuration: task={:?}, tree_type={:?}, criterion={:?}.",
387 task, tree_type, criterion
388 ),
389 TrainError::InvalidMaxDepth(value) => {
390 write!(f, "max_depth must be at least 1. Received {}.", value)
391 }
392 TrainError::InvalidMinSamplesSplit(value) => {
393 write!(
394 f,
395 "min_samples_split must be at least 1. Received {}.",
396 value
397 )
398 }
399 TrainError::InvalidMinSamplesLeaf(value) => {
400 write!(
401 f,
402 "min_samples_leaf must be at least 1. Received {}.",
403 value
404 )
405 }
406 TrainError::InvalidTreeCount(n_trees) => {
407 write!(
408 f,
409 "Random forest requires at least one tree. Received {}.",
410 n_trees
411 )
412 }
413 TrainError::InvalidMaxFeatures(count) => {
414 write!(
415 f,
416 "max_features must be at least 1 when provided as an integer. Received {}.",
417 count
418 )
419 }
420 TrainError::InvalidMissingValueStrategyFeature {
421 feature_index,
422 feature_count,
423 } => write!(
424 f,
425 "missing_value_strategy references feature {}, but the training table only has {} features.",
426 feature_index, feature_count
427 ),
428 }
429 }
430}
431
432impl Error for TrainError {}
433
/// Errors produced while validating inference input or running prediction.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictError {
    /// `predict_proba` was called on a regression model.
    ProbabilityPredictionRequiresClassification,
    /// A row-major input had a row with the wrong number of columns.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// Input feature count does not match what the model expects.
    FeatureCountMismatch {
        expected: usize,
        actual: usize,
    },
    /// A named column's length disagrees with the other columns.
    ColumnLengthMismatch {
        feature: String,
        expected: usize,
        actual: usize,
    },
    /// A feature the model requires is absent from the input.
    MissingFeature(String),
    /// The input supplies a feature the model does not know.
    UnexpectedFeature(String),
    /// A binary feature received a value other than its two valid values.
    InvalidBinaryValue {
        feature_index: usize,
        row_index: usize,
        value: f64,
    },
    /// A null/missing value appeared where one is not allowed.
    NullValue {
        feature: String,
        row_index: usize,
    },
    /// A column has a dtype the inference path cannot convert.
    UnsupportedFeatureType {
        feature: String,
        dtype: String,
    },
    /// Stringified error bubbled up from Polars (polars feature only).
    Polars(String),
}
468
469impl Display for PredictError {
470 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
471 match self {
472 PredictError::ProbabilityPredictionRequiresClassification => write!(
473 f,
474 "predict_proba is only available for classification models."
475 ),
476 PredictError::RaggedRows {
477 row,
478 expected,
479 actual,
480 } => write!(
481 f,
482 "Ragged inference row at index {}: expected {} columns, found {}.",
483 row, expected, actual
484 ),
485 PredictError::FeatureCountMismatch { expected, actual } => write!(
486 f,
487 "Inference input has {} features, but the model expects {}.",
488 actual, expected
489 ),
490 PredictError::ColumnLengthMismatch {
491 feature,
492 expected,
493 actual,
494 } => write!(
495 f,
496 "Feature '{}' has {} values, expected {}.",
497 feature, actual, expected
498 ),
499 PredictError::MissingFeature(feature) => {
500 write!(f, "Missing required feature '{}'.", feature)
501 }
502 PredictError::UnexpectedFeature(feature) => {
503 write!(f, "Unexpected feature '{}'.", feature)
504 }
505 PredictError::InvalidBinaryValue {
506 feature_index,
507 row_index,
508 value,
509 } => write!(
510 f,
511 "Feature {} at row {} must be binary for inference, found {}.",
512 feature_index, row_index, value
513 ),
514 PredictError::NullValue { feature, row_index } => write!(
515 f,
516 "Feature '{}' contains a null value at row {}.",
517 feature, row_index
518 ),
519 PredictError::UnsupportedFeatureType { feature, dtype } => write!(
520 f,
521 "Feature '{}' has unsupported dtype '{}'.",
522 feature, dtype
523 ),
524 PredictError::Polars(message) => write!(f, "Polars inference failed: {}.", message),
525 }
526 }
527}
528
529impl Error for PredictError {}
530
#[cfg(feature = "polars")]
impl From<polars::error::PolarsError> for PredictError {
    /// Stringifies the Polars error so `PredictError` stays cloneable and
    /// dependency-light.
    fn from(value: polars::error::PolarsError) -> Self {
        PredictError::Polars(value.to_string())
    }
}
537
/// Errors produced while preparing a model for optimized inference.
#[derive(Debug)]
pub enum OptimizeError {
    /// More physical cores requested than the machine provides.
    InvalidPhysicalCoreCount { requested: usize, available: usize },
    /// Rayon thread-pool construction failed (message from the builder).
    ThreadPoolBuildFailed(String),
    /// The model family has no optimized runtime.
    UnsupportedModelType(&'static str),
}
544
545impl Display for OptimizeError {
546 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
547 match self {
548 OptimizeError::InvalidPhysicalCoreCount {
549 requested,
550 available,
551 } => write!(
552 f,
553 "Requested {} physical cores, but the available physical core count is {}.",
554 requested, available
555 ),
556 OptimizeError::ThreadPoolBuildFailed(message) => {
557 write!(f, "Failed to build inference thread pool: {}.", message)
558 }
559 OptimizeError::UnsupportedModelType(model_type) => {
560 write!(
561 f,
562 "Optimized inference is not supported for model type '{}'.",
563 model_type
564 )
565 }
566 }
567 }
568}
569
570impl Error for OptimizeError {}
571
/// Lightweight wrapper around a worker-thread count that decides whether a
/// code path should take its parallel branch.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Parallelism {
    thread_count: usize,
}

impl Parallelism {
    /// Single-threaded execution.
    pub(crate) fn sequential() -> Self {
        Parallelism { thread_count: 1 }
    }

    /// Test helper: a parallelism level clamped to at least one thread.
    #[cfg(test)]
    pub(crate) fn with_threads(thread_count: usize) -> Self {
        let clamped = if thread_count == 0 { 1 } else { thread_count };
        Parallelism {
            thread_count: clamped,
        }
    }

    /// True when more than one worker thread is configured.
    pub(crate) fn enabled(self) -> bool {
        !matches!(self.thread_count, 0 | 1)
    }
}
593
594pub(crate) fn capture_feature_preprocessing(table: &dyn TableAccess) -> Vec<FeaturePreprocessing> {
595 (0..table.n_features())
596 .map(|feature_index| {
597 if table.is_binary_feature(feature_index) {
598 FeaturePreprocessing::Binary
599 } else {
600 let values = (0..table.n_rows())
601 .map(|row_index| table.feature_value(feature_index, row_index))
602 .collect::<Vec<_>>();
603 FeaturePreprocessing::Numeric {
604 bin_boundaries: numeric_bin_boundaries(
605 &values,
606 NumericBins::Fixed(table.numeric_bin_cap()),
607 )
608 .into_iter()
609 .map(|(bin, upper_bound)| NumericBinBoundary { bin, upper_bound })
610 .collect(),
611 missing_bin: numeric_missing_bin(NumericBins::Fixed(table.numeric_bin_cap())),
612 }
613 }
614 })
615 .collect()
616}
617
/// Whether missing-value handling applies to `feature_index`: true when no
/// restriction set was provided, or when the set contains the feature.
fn missing_feature_enabled(
    feature_index: usize,
    missing_features: Option<&BTreeSet<usize>>,
) -> bool {
    match missing_features {
        None => true,
        Some(features) => features.contains(&feature_index),
    }
}
624
625fn optimized_missing_bin(
626 preprocessing: &[FeaturePreprocessing],
627 feature_index: usize,
628 missing_features: Option<&BTreeSet<usize>>,
629) -> Option<u16> {
630 if !missing_feature_enabled(feature_index, missing_features) {
631 return None;
632 }
633
634 match preprocessing.get(feature_index) {
635 Some(FeaturePreprocessing::Binary) => Some(forestfire_data::BINARY_MISSING_BIN),
636 Some(FeaturePreprocessing::Numeric { missing_bin, .. }) => Some(*missing_bin),
637 None => None,
638 }
639}
640
641impl OptimizedRuntime {
642 fn supports_batch_matrix(&self) -> bool {
643 matches!(
644 self,
645 OptimizedRuntime::BinaryClassifier { .. }
646 | OptimizedRuntime::BinaryRegressor { .. }
647 | OptimizedRuntime::ObliviousClassifier { .. }
648 | OptimizedRuntime::ObliviousRegressor { .. }
649 | OptimizedRuntime::ForestClassifier { .. }
650 | OptimizedRuntime::ForestRegressor { .. }
651 | OptimizedRuntime::BoostedBinaryClassifier { .. }
652 | OptimizedRuntime::BoostedRegressor { .. }
653 )
654 }
655
656 fn should_use_batch_matrix(&self, n_rows: usize) -> bool {
657 n_rows > 1 && self.supports_batch_matrix()
658 }
659
    /// Converts a trained `Model` into its optimized runtime representation.
    ///
    /// `feature_index_map` remaps model feature indices onto the inference
    /// layout via `remap_feature_index`; `missing_features` restricts
    /// missing-value routing to the listed features (`None` enables it for
    /// every feature). Ensembles recurse into this function per tree, in the
    /// order fixed by `ordered_ensemble_indices`.
    fn from_model(
        model: &Model,
        feature_index_map: &[usize],
        missing_features: Option<&BTreeSet<usize>>,
    ) -> Self {
        match model {
            Model::DecisionTreeClassifier(classifier) => {
                Self::from_classifier(classifier, feature_index_map, missing_features)
            }
            Model::DecisionTreeRegressor(regressor) => {
                Self::from_regressor(regressor, feature_index_map, missing_features)
            }
            Model::RandomForest(forest) => match forest.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestClassifier {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &forest.trees()[tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        class_labels: forest
                            .class_labels()
                            .expect("classification forest stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestRegressor {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &forest.trees()[tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                    }
                }
            },
            Model::GradientBoostedTrees(model) => match model.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedBinaryClassifier {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &model.trees()[*tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        // Weights are reordered in lockstep with the trees so
                        // tree/weight pairs stay aligned after reordering.
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                        class_labels: model
                            .class_labels()
                            .expect("classification boosting stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedRegressor {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &model.trees()[*tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                    }
                }
            },
        }
    }
754
    /// Converts a trained classification tree into its optimized runtime form.
    ///
    /// Standard trees with only binary splits take the fast
    /// `BinaryClassifier` layout; standard trees with multiway splits are
    /// flattened into `OptimizedClassifierNode`s; oblivious trees keep their
    /// per-level splits plus per-leaf probability vectors.
    fn from_classifier(
        classifier: &DecisionTreeClassifier,
        feature_index_map: &[usize],
        missing_features: Option<&BTreeSet<usize>>,
    ) -> Self {
        match classifier.structure() {
            tree::classifier::TreeStructure::Standard { nodes, root } => {
                // Fast path: purely binary trees get the dedicated layout.
                if classifier_nodes_are_binary_only(nodes) {
                    return Self::BinaryClassifier {
                        nodes: build_binary_classifier_layout(
                            nodes,
                            *root,
                            classifier.class_labels(),
                            feature_index_map,
                            classifier.feature_preprocessing(),
                            missing_features,
                        ),
                        class_labels: classifier.class_labels().to_vec(),
                    };
                }

                let optimized_nodes = nodes
                    .iter()
                    .map(|node| match node {
                        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
                            OptimizedClassifierNode::Leaf(normalized_probabilities_from_counts(
                                class_counts,
                            ))
                        }
                        tree::classifier::TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => OptimizedClassifierNode::Binary {
                            feature_index: remap_feature_index(*feature_index, feature_index_map),
                            threshold_bin: *threshold_bin,
                            children: [*left_child, *right_child],
                            missing_bin: optimized_missing_bin(
                                classifier.feature_preprocessing(),
                                *feature_index,
                                missing_features,
                            ),
                            // Missing values route to a child only when the
                            // feature has missing handling enabled; `Node`
                            // direction means "stop here" (see below).
                            missing_child: if missing_feature_enabled(
                                *feature_index,
                                missing_features,
                            ) {
                                match missing_direction {
                                    tree::shared::MissingBranchDirection::Left => Some(*left_child),
                                    tree::shared::MissingBranchDirection::Right => {
                                        Some(*right_child)
                                    }
                                    tree::shared::MissingBranchDirection::Node => None,
                                }
                            } else {
                                None
                            },
                            // `Node` direction answers with this node's own
                            // class distribution instead of descending.
                            missing_probabilities: if missing_feature_enabled(
                                *feature_index,
                                missing_features,
                            ) && matches!(
                                missing_direction,
                                tree::shared::MissingBranchDirection::Node
                            ) {
                                Some(normalized_probabilities_from_counts(class_counts))
                            } else {
                                None
                            },
                        },
                        tree::classifier::TreeNode::MultiwaySplit {
                            feature_index,
                            class_counts,
                            branches,
                            missing_child,
                            ..
                        } => {
                            // Build a dense bin -> child table; absent bins
                            // keep usize::MAX as a "no branch" sentinel.
                            let max_bin_index = branches
                                .iter()
                                .map(|(bin, _)| usize::from(*bin))
                                .max()
                                .unwrap_or(0);
                            let mut child_lookup = vec![usize::MAX; max_bin_index + 1];
                            for (bin, child_index) in branches {
                                child_lookup[usize::from(*bin)] = *child_index;
                            }
                            OptimizedClassifierNode::Multiway {
                                feature_index: remap_feature_index(
                                    *feature_index,
                                    feature_index_map,
                                ),
                                child_lookup,
                                max_bin_index,
                                missing_bin: optimized_missing_bin(
                                    classifier.feature_preprocessing(),
                                    *feature_index,
                                    missing_features,
                                ),
                                missing_child: if missing_feature_enabled(
                                    *feature_index,
                                    missing_features,
                                ) {
                                    *missing_child
                                } else {
                                    None
                                },
                                fallback_probabilities: normalized_probabilities_from_counts(
                                    class_counts,
                                ),
                            }
                        }
                    })
                    .collect();

                Self::StandardClassifier {
                    nodes: optimized_nodes,
                    root: *root,
                    class_labels: classifier.class_labels().to_vec(),
                }
            }
            tree::classifier::TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => Self::ObliviousClassifier {
                feature_indices: splits
                    .iter()
                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
                    .collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_class_counts
                    .iter()
                    .map(|class_counts| normalized_probabilities_from_counts(class_counts))
                    .collect(),
                class_labels: classifier.class_labels().to_vec(),
            },
        }
    }
895
    /// Converts a trained regression tree into its optimized runtime form.
    ///
    /// Standard trees are flattened into the binary-regressor node layout;
    /// oblivious trees keep their per-level splits and cloned leaf values.
    fn from_regressor(
        regressor: &DecisionTreeRegressor,
        feature_index_map: &[usize],
        missing_features: Option<&BTreeSet<usize>>,
    ) -> Self {
        match regressor.structure() {
            tree::regressor::RegressionTreeStructure::Standard { nodes, root } => {
                Self::BinaryRegressor {
                    nodes: build_binary_regressor_layout(
                        nodes,
                        *root,
                        feature_index_map,
                        regressor.feature_preprocessing(),
                        missing_features,
                    ),
                }
            }
            tree::regressor::RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                ..
            } => Self::ObliviousRegressor {
                feature_indices: splits
                    .iter()
                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
                    .collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_values.clone(),
            },
        }
    }
927
    /// Predicts a single value for one row of a `TableAccess`.
    ///
    /// Classifier variants compute class probabilities and map them to a
    /// label via `class_label_from_probabilities`; forest regressors average
    /// per-tree outputs; boosted regressors add base score plus weighted
    /// per-tree outputs.
    #[inline(always)]
    fn predict_table_row(&self, table: &dyn TableAccess, row_index: usize) -> f64 {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                let probabilities = self
                    .predict_proba_table_row(table, row_index)
                    .expect("classifier runtime supports probability prediction");
                class_label_from_probabilities(&probabilities, self.class_labels())
            }
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            ),
            OptimizedRuntime::ForestRegressor { trees } => {
                trees
                    .iter()
                    .map(|tree| tree.predict_table_row(table, row_index))
                    .sum::<f64>()
                    / trees.len() as f64
            }
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>()
            }
        }
    }
977
    /// Predicts class probabilities for one row of a `TableAccess`.
    ///
    /// Forests average per-tree probability vectors; boosted binary
    /// classifiers accumulate a raw score and squash it through `sigmoid`
    /// into `[P(neg), P(pos)]`. Regressor variants return
    /// `ProbabilityPredictionRequiresClassification`.
    #[inline(always)]
    fn predict_proba_table_row(
        &self,
        table: &dyn TableAccess,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            )
            .to_vec()),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                // NOTE(review): indexing `trees[0]` assumes a non-empty
                // ensemble — confirm construction guarantees at least one tree.
                let mut totals = trees[0].predict_proba_table_row(table, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_table_row(table, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1046
    /// Batch probability prediction over a `TableAccess`.
    ///
    /// Multi-row inputs on batch-capable variants go through the column-major
    /// matrix path; everything else falls back to per-row evaluation.
    /// Regressor variants return `ProbabilityPredictionRequiresClassification`.
    fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                if self.should_use_batch_matrix(table.n_rows()) {
                    let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
                    self.predict_proba_column_major_matrix(&matrix, executor)
                } else {
                    (0..table.n_rows())
                        .map(|row_index| self.predict_proba_table_row(table, row_index))
                        .collect()
                }
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1075
    /// Batch value prediction over a column-major binned matrix.
    ///
    /// Classifier variants predict via probabilities and map each row to a
    /// class label; forest regressors average per-tree batch outputs; boosted
    /// regressors accumulate base score plus weighted per-tree outputs.
    fn predict_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Vec<f64> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => self
                .predict_proba_column_major_matrix(matrix, executor)
                .expect("classifier runtime supports probability prediction")
                .into_iter()
                .map(|row| class_label_from_probabilities(&row, self.class_labels()))
                .collect(),
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_column_major_matrix(nodes, matrix, executor)
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            ),
            OptimizedRuntime::ForestRegressor { trees } => {
                // NOTE(review): indexing `trees[0]` assumes a non-empty
                // ensemble — confirm construction guarantees at least one tree.
                let mut totals = trees[0].predict_column_major_matrix(matrix, executor);
                for tree in &trees[1..] {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for total in &mut totals {
                    *total /= tree_count;
                }
                totals
            }
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                let mut totals = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += weight * value;
                    }
                }
                totals
            }
        }
    }
1136
    /// Batch probability prediction over a column-major binned matrix.
    ///
    /// `StandardClassifier` has no batch kernel and falls back to per-row
    /// traversal; forests average per-tree probability rows; boosted binary
    /// classifiers accumulate raw scores and squash each through `sigmoid`.
    /// Regressor variants return
    /// `ProbabilityPredictionRequiresClassification`.
    fn predict_proba_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => {
                Ok(predict_binary_classifier_probabilities_column_major_matrix(
                    nodes, matrix, executor,
                ))
            }
            OptimizedRuntime::StandardClassifier { .. } => Ok((0..matrix.n_rows)
                .map(|row_index| {
                    self.predict_proba_binned_row_from_columns(matrix, row_index)
                        .expect("classifier runtime supports probability prediction")
                })
                .collect()),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            )),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                // NOTE(review): indexing `trees[0]` assumes a non-empty
                // ensemble — confirm construction guarantees at least one tree.
                let mut totals = trees[0].predict_proba_column_major_matrix(matrix, executor)?;
                for tree in &trees[1..] {
                    let rows = tree.predict_proba_column_major_matrix(matrix, executor)?;
                    for (row_totals, row_values) in totals.iter_mut().zip(rows) {
                        for (total, value) in row_totals.iter_mut().zip(row_values) {
                            *total += value;
                        }
                    }
                }
                let tree_count = trees.len() as f64;
                for row in &mut totals {
                    for value in row {
                        *value /= tree_count;
                    }
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let mut raw_scores = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (raw_score, value) in raw_scores.iter_mut().zip(values) {
                        *raw_score += weight * value;
                    }
                }
                Ok(raw_scores
                    .into_iter()
                    .map(|raw_score| {
                        let positive = sigmoid(raw_score);
                        vec![1.0 - positive, positive]
                    })
                    .collect())
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1213
1214 fn class_labels(&self) -> &[f64] {
1215 match self {
1216 OptimizedRuntime::BinaryClassifier { class_labels, .. }
1217 | OptimizedRuntime::StandardClassifier { class_labels, .. }
1218 | OptimizedRuntime::ObliviousClassifier { class_labels, .. }
1219 | OptimizedRuntime::ForestClassifier { class_labels, .. }
1220 | OptimizedRuntime::BoostedBinaryClassifier { class_labels, .. } => class_labels,
1221 _ => &[],
1222 }
1223 }
1224
    /// Predicts the regression value for one binned row of `matrix`.
    ///
    /// Binary, oblivious, and boosted regressors have dedicated scalar
    /// kernels. Every other variant falls back to predicting the entire
    /// matrix on a fresh single-threaded executor and indexing out one row —
    /// that path costs O(n_rows) per call, so callers looping over rows
    /// should prefer the batched entry points.
    #[inline(always)]
    fn predict_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> f64 {
        match self {
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            ),
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                // base_score + Σ weight_i * tree_i(row).
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>()
            }
            // Slow fallback: predict every row, keep one (see doc comment).
            _ => self.predict_column_major_matrix(
                matrix,
                &InferenceExecutor::new(1).expect("inference executor"),
            )[row_index],
        }
    }
1267
    /// Computes the class-probability vector for a single binned row.
    ///
    /// Classifier variants delegate to their scalar row kernels; forests
    /// average the member trees and boosted models apply a sigmoid to the
    /// weighted raw score. Regressor variants return
    /// [`PredictError::ProbabilityPredictionRequiresClassification`].
    #[inline(always)]
    fn predict_proba_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            )
            .to_vec()),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                // Sum per-class probabilities over all trees, then divide by
                // the tree count to form the ensemble average.
                let mut totals =
                    trees[0].predict_proba_binned_row_from_columns(matrix, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_binned_row_from_columns(matrix, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                // raw = base_score + Σ weight_i * tree_i(row); probability of
                // the positive class is sigmoid(raw).
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1339}
1340
/// Walks a standard (binary + multiway) classifier tree for one row and
/// returns a borrowed probability slice from the reached leaf.
///
/// `bin_at` maps a feature index to the row's bin for that feature.
#[inline(always)]
fn predict_standard_classifier_probabilities_row<F>(
    nodes: &[OptimizedClassifierNode],
    root: usize,
    bin_at: F,
) -> &[f64]
where
    F: Fn(usize) -> u16,
{
    let mut node_index = root;
    loop {
        match &nodes[node_index] {
            OptimizedClassifierNode::Leaf(value) => return value,
            OptimizedClassifierNode::Binary {
                feature_index,
                threshold_bin,
                children,
                missing_bin,
                missing_child,
                missing_probabilities,
            } => {
                let bin = bin_at(*feature_index);
                // A missing bin either resolves at this node (stored
                // probabilities) or redirects to a dedicated child; otherwise
                // fall through to the ordinary threshold comparison.
                if missing_bin.is_some_and(|expected| expected == bin) {
                    if let Some(probabilities) = missing_probabilities {
                        return probabilities;
                    }
                    if let Some(child_index) = missing_child {
                        node_index = *child_index;
                        continue;
                    }
                }
                // children[0] = left (bin <= threshold), children[1] = right.
                let go_right = usize::from(bin > *threshold_bin);
                node_index = children[go_right];
            }
            OptimizedClassifierNode::Multiway {
                feature_index,
                child_lookup,
                max_bin_index,
                missing_bin,
                missing_child,
                fallback_probabilities,
            } => {
                let bin_value = bin_at(*feature_index);
                if missing_bin.is_some_and(|expected| expected == bin_value) {
                    if let Some(child_index) = missing_child {
                        node_index = *child_index;
                        continue;
                    }
                    return fallback_probabilities;
                }
                let bin = usize::from(bin_value);
                // Bins outside the lookup table, or mapped to the usize::MAX
                // sentinel, resolve to the node's fallback distribution.
                if bin > *max_bin_index {
                    return fallback_probabilities;
                }
                let child_index = child_lookup[bin];
                if child_index == usize::MAX {
                    return fallback_probabilities;
                }
                node_index = child_index;
            }
        }
    }
}
1404
1405#[inline(always)]
1406fn predict_binary_classifier_probabilities_row<F>(
1407 nodes: &[OptimizedBinaryClassifierNode],
1408 bin_at: F,
1409) -> &[f64]
1410where
1411 F: Fn(usize) -> u16,
1412{
1413 let mut node_index = 0usize;
1414 loop {
1415 match &nodes[node_index] {
1416 OptimizedBinaryClassifierNode::Leaf(value) => return value,
1417 OptimizedBinaryClassifierNode::Branch {
1418 feature_index,
1419 threshold_bin,
1420 jump_index,
1421 jump_if_greater,
1422 missing_bin,
1423 missing_jump_index,
1424 missing_probabilities,
1425 } => {
1426 let bin = bin_at(*feature_index);
1427 if missing_bin.is_some_and(|expected| expected == bin) {
1428 if let Some(probabilities) = missing_probabilities {
1429 return probabilities;
1430 }
1431 if let Some(jump_index) = missing_jump_index {
1432 node_index = *jump_index;
1433 continue;
1434 }
1435 }
1436 let go_right = bin > *threshold_bin;
1437 node_index = if go_right == *jump_if_greater {
1438 *jump_index
1439 } else {
1440 node_index + 1
1441 };
1442 }
1443 }
1444 }
1445}
1446
1447#[inline(always)]
1448fn predict_binary_regressor_row<F>(nodes: &[OptimizedBinaryRegressorNode], bin_at: F) -> f64
1449where
1450 F: Fn(usize) -> u16,
1451{
1452 let mut node_index = 0usize;
1453 loop {
1454 match &nodes[node_index] {
1455 OptimizedBinaryRegressorNode::Leaf(value) => return *value,
1456 OptimizedBinaryRegressorNode::Branch {
1457 feature_index,
1458 threshold_bin,
1459 jump_index,
1460 jump_if_greater,
1461 missing_bin,
1462 missing_jump_index,
1463 missing_value,
1464 } => {
1465 let bin = bin_at(*feature_index);
1466 if missing_bin.is_some_and(|expected| expected == bin) {
1467 if let Some(value) = missing_value {
1468 return *value;
1469 }
1470 if let Some(jump_index) = missing_jump_index {
1471 node_index = *jump_index;
1472 continue;
1473 }
1474 }
1475 let go_right = bin > *threshold_bin;
1476 node_index = if go_right == *jump_if_greater {
1477 *jump_index
1478 } else {
1479 node_index + 1
1480 };
1481 }
1482 }
1483 }
1484}
1485
1486fn predict_binary_classifier_probabilities_column_major_matrix(
1487 nodes: &[OptimizedBinaryClassifierNode],
1488 matrix: &ColumnMajorBinnedMatrix,
1489 _executor: &InferenceExecutor,
1490) -> Vec<Vec<f64>> {
1491 if binary_classifier_nodes_require_rowwise_missing(nodes) {
1492 return (0..matrix.n_rows)
1493 .map(|row_index| {
1494 predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1495 matrix.column(feature_index).value_at(row_index)
1496 })
1497 .to_vec()
1498 })
1499 .collect();
1500 }
1501 (0..matrix.n_rows)
1502 .map(|row_index| {
1503 predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1504 matrix.column(feature_index).value_at(row_index)
1505 })
1506 .to_vec()
1507 })
1508 .collect()
1509}
1510
1511fn predict_binary_regressor_column_major_matrix(
1512 nodes: &[OptimizedBinaryRegressorNode],
1513 matrix: &ColumnMajorBinnedMatrix,
1514 executor: &InferenceExecutor,
1515) -> Vec<f64> {
1516 if binary_regressor_nodes_require_rowwise_missing(nodes) {
1517 return (0..matrix.n_rows)
1518 .map(|row_index| {
1519 predict_binary_regressor_row(nodes, |feature_index| {
1520 matrix.column(feature_index).value_at(row_index)
1521 })
1522 })
1523 .collect();
1524 }
1525 let mut outputs = vec![0.0; matrix.n_rows];
1526 executor.fill_chunks(
1527 &mut outputs,
1528 STANDARD_BATCH_INFERENCE_CHUNK_ROWS,
1529 |start_row, chunk| predict_binary_regressor_chunk(nodes, matrix, start_row, chunk),
1530 );
1531 outputs
1532}
1533
/// Evaluates a binary regressor over a chunk of rows with a batched,
/// partition-based traversal instead of one root-to-leaf walk per row.
///
/// `row_indices` holds chunk-relative row offsets; for each branch node the
/// active index range is partitioned in place into "fallthrough" and "jump"
/// groups, and both non-empty sub-ranges are pushed onto an explicit stack.
/// Leaves then write their value to every row in their range. Missing-value
/// metadata is ignored here: the caller routes trees that need row-wise
/// missing handling to the scalar path instead
/// (see `binary_regressor_nodes_require_rowwise_missing`).
fn predict_binary_regressor_chunk(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) {
    let mut row_indices: Vec<usize> = (0..output.len()).collect();
    let mut stack = vec![(0usize, 0usize, output.len())];

    while let Some((node_index, start, end)) = stack.pop() {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => {
                for position in start..end {
                    output[row_indices[position]] = *value;
                }
            }
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
                ..
            } => {
                let fallthrough_index = node_index + 1;
                // Degenerate branch whose two targets coincide: skip the
                // partition entirely and keep the whole range together.
                if *jump_index == fallthrough_index {
                    stack.push((fallthrough_index, start, end));
                    continue;
                }

                let column = matrix.column(*feature_index);
                let mut partition = start;
                let mut jump_start = end;
                match column {
                    // Fast path for u8 columns: the guard guarantees the
                    // threshold fits in u8, so the `as u8` cast is lossless.
                    CompactBinnedColumn::U8(values) if *threshold_bin <= u16::from(u8::MAX) => {
                        let threshold = *threshold_bin as u8;
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = values[start_row + row_offset] > threshold;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                // Move jump-bound rows to the tail of the range.
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                    _ => {
                        // Generic path using per-element column access.
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = column.value_at(start_row + row_offset) > *threshold_bin;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                }

                // Recurse (via the stack) into the non-empty sub-ranges.
                if jump_start < end {
                    stack.push((*jump_index, jump_start, end));
                }
                if start < jump_start {
                    stack.push((fallthrough_index, start, jump_start));
                }
            }
        }
    }
}
1606
1607fn binary_classifier_nodes_require_rowwise_missing(
1608 nodes: &[OptimizedBinaryClassifierNode],
1609) -> bool {
1610 nodes.iter().any(|node| match node {
1611 OptimizedBinaryClassifierNode::Leaf(_) => false,
1612 OptimizedBinaryClassifierNode::Branch {
1613 missing_bin,
1614 missing_jump_index,
1615 missing_probabilities,
1616 ..
1617 } => {
1618 missing_bin.is_some() || missing_jump_index.is_some() || missing_probabilities.is_some()
1619 }
1620 })
1621}
1622
1623fn binary_regressor_nodes_require_rowwise_missing(nodes: &[OptimizedBinaryRegressorNode]) -> bool {
1624 nodes.iter().any(|node| match node {
1625 OptimizedBinaryRegressorNode::Leaf(_) => false,
1626 OptimizedBinaryRegressorNode::Branch {
1627 missing_bin,
1628 missing_jump_index,
1629 missing_value,
1630 ..
1631 } => missing_bin.is_some() || missing_jump_index.is_some() || missing_value.is_some(),
1632 })
1633}
1634
/// Evaluates one row of an oblivious tree.
///
/// Each level contributes one bit to the leaf index: 1 when the row's bin
/// exceeds that level's threshold, 0 otherwise; the bits are packed
/// most-significant level first.
#[inline(always)]
fn predict_oblivious_row<F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    bin_at: F,
) -> f64
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index]
}
1652
/// Evaluates one row of an oblivious classifier and borrows the selected
/// leaf's probability vector.
///
/// Leaf indexing is identical to [`predict_oblivious_row`]: one bit per tree
/// level, set when the bin exceeds the level's threshold.
#[inline(always)]
fn predict_oblivious_probabilities_row<'a, F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &'a [Vec<f64>],
    bin_at: F,
) -> &'a [f64]
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index].as_slice()
}
1670
/// Normalizes a class-count histogram into probabilities.
///
/// An all-zero (or empty) histogram normalizes to all zeros rather than
/// dividing by zero.
fn normalized_probabilities_from_counts(class_counts: &[usize]) -> Vec<f64> {
    match class_counts.iter().sum::<usize>() {
        0 => vec![0.0; class_counts.len()],
        total => {
            let denominator = total as f64;
            class_counts
                .iter()
                .map(|&count| count as f64 / denominator)
                .collect()
        }
    }
}
1682
/// Maps a probability row to its class label via argmax.
///
/// Ties resolve to the lowest class index (first-wins argmax). NaN values
/// are ordered by `total_cmp`, so the function never panics on NaN input.
///
/// # Panics
///
/// Panics when `probabilities` is empty.
fn class_label_from_probabilities(probabilities: &[f64], class_labels: &[f64]) -> f64 {
    let (best_index, _) = probabilities
        .iter()
        .copied()
        .enumerate()
        .min_by(|(left_index, left), (right_index, right)| {
            // Descending by probability, ascending by index on ties.
            right
                .total_cmp(left)
                .then_with(|| left_index.cmp(right_index))
        })
        .expect("classification probability row is non-empty");
    class_labels[best_index]
}
1696
/// Numerically stable logistic function.
///
/// Always exponentiates a non-positive argument so the intermediate can
/// never overflow to infinity; both branches yield bit-identical results to
/// the textbook 1 / (1 + e^-x).
#[inline(always)]
fn sigmoid(value: f64) -> f64 {
    let exp_neg_abs = (-value.abs()).exp();
    if value >= 0.0 {
        1.0 / (1.0 + exp_neg_abs)
    } else {
        exp_neg_abs / (1.0 + exp_neg_abs)
    }
}
1707
1708fn classifier_nodes_are_binary_only(nodes: &[tree::classifier::TreeNode]) -> bool {
1709 nodes.iter().all(|node| {
1710 matches!(
1711 node,
1712 tree::classifier::TreeNode::Leaf { .. }
1713 | tree::classifier::TreeNode::BinarySplit { .. }
1714 )
1715 })
1716}
1717
1718fn classifier_node_sample_count(nodes: &[tree::classifier::TreeNode], node_index: usize) -> usize {
1719 match &nodes[node_index] {
1720 tree::classifier::TreeNode::Leaf { sample_count, .. }
1721 | tree::classifier::TreeNode::BinarySplit { sample_count, .. }
1722 | tree::classifier::TreeNode::MultiwaySplit { sample_count, .. } => *sample_count,
1723 }
1724}
1725
1726fn build_binary_classifier_layout(
1727 nodes: &[tree::classifier::TreeNode],
1728 root: usize,
1729 _class_labels: &[f64],
1730 feature_index_map: &[usize],
1731 preprocessing: &[FeaturePreprocessing],
1732 missing_features: Option<&BTreeSet<usize>>,
1733) -> Vec<OptimizedBinaryClassifierNode> {
1734 let mut layout = Vec::with_capacity(nodes.len());
1735 append_binary_classifier_node(
1736 nodes,
1737 root,
1738 &mut layout,
1739 feature_index_map,
1740 preprocessing,
1741 missing_features,
1742 );
1743 layout
1744}
1745
/// Recursively appends the subtree rooted at `node_index` to `layout` in a
/// jump-threaded depth-first order and returns the emitted node's position.
///
/// The child with more training samples becomes the "fallthrough" child and
/// is laid out immediately after its parent (asserted below), so the common
/// path advances linearly through memory; the other child is reached via
/// `jump_index`. `jump_if_greater` records which comparison outcome takes
/// the jump. Missing-value handling is resolved here into either a concrete
/// redirect index or stored node-level probabilities.
fn append_binary_classifier_node(
    nodes: &[tree::classifier::TreeNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryClassifierNode>,
    feature_index_map: &[usize],
    preprocessing: &[FeaturePreprocessing],
    missing_features: Option<&BTreeSet<usize>>,
) -> usize {
    // Reserve this node's slot first; it is overwritten once the children
    // have been emitted and their indices are known.
    let current_index = layout.len();
    layout.push(OptimizedBinaryClassifierNode::Leaf(Vec::new()));

    match &nodes[node_index] {
        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
            layout[current_index] = OptimizedBinaryClassifierNode::Leaf(
                normalized_probabilities_from_counts(class_counts),
            );
        }
        tree::classifier::TreeNode::BinarySplit {
            feature_index,
            threshold_bin,
            missing_direction,
            left_child,
            right_child,
            class_counts,
            ..
        } => {
            // Pick the larger child as the fallthrough target; ties (and the
            // degenerate equal-children case) prefer the left child.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = classifier_node_sample_count(nodes, *left_child);
                let right_count = classifier_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index = append_binary_classifier_node(
                nodes,
                fallthrough_child,
                layout,
                feature_index_map,
                preprocessing,
                missing_features,
            );
            // Layout invariant: the fallthrough child sits right after its
            // parent, which is what the traversal's `node_index + 1` assumes.
            debug_assert_eq!(fallthrough_index, current_index + 1);
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_classifier_node(
                    nodes,
                    jump_child,
                    layout,
                    feature_index_map,
                    preprocessing,
                    missing_features,
                )
            };

            let missing_bin =
                optimized_missing_bin(preprocessing, *feature_index, missing_features);
            // Translate the training-time missing direction into layout
            // indices: Left/Right redirect to the matching child's slot,
            // Node stores this node's own class distribution instead.
            let (missing_jump_index, missing_probabilities) =
                if missing_feature_enabled(*feature_index, missing_features) {
                    match missing_direction {
                        tree::shared::MissingBranchDirection::Left => (
                            Some(if *left_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Right => (
                            Some(if *right_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Node => (
                            None,
                            Some(normalized_probabilities_from_counts(class_counts)),
                        ),
                    }
                } else {
                    (None, None)
                };

            layout[current_index] = OptimizedBinaryClassifierNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_probabilities,
            };
        }
        tree::classifier::TreeNode::MultiwaySplit { .. } => {
            unreachable!("multiway nodes are filtered out before binary layout construction");
        }
    }

    current_index
}
1853
1854fn regressor_node_sample_count(
1855 nodes: &[tree::regressor::RegressionNode],
1856 node_index: usize,
1857) -> usize {
1858 match &nodes[node_index] {
1859 tree::regressor::RegressionNode::Leaf { sample_count, .. }
1860 | tree::regressor::RegressionNode::BinarySplit { sample_count, .. } => *sample_count,
1861 }
1862}
1863
1864fn build_binary_regressor_layout(
1865 nodes: &[tree::regressor::RegressionNode],
1866 root: usize,
1867 feature_index_map: &[usize],
1868 preprocessing: &[FeaturePreprocessing],
1869 missing_features: Option<&BTreeSet<usize>>,
1870) -> Vec<OptimizedBinaryRegressorNode> {
1871 let mut layout = Vec::with_capacity(nodes.len());
1872 append_binary_regressor_node(
1873 nodes,
1874 root,
1875 &mut layout,
1876 feature_index_map,
1877 preprocessing,
1878 missing_features,
1879 );
1880 layout
1881}
1882
/// Recursively appends the regression subtree rooted at `node_index` to
/// `layout` in a jump-threaded depth-first order and returns the emitted
/// node's position.
///
/// Mirrors `append_binary_classifier_node`: the child with more training
/// samples is placed immediately after its parent (asserted below) as the
/// fallthrough target; the other child is reached via `jump_index`, with
/// `jump_if_greater` recording which comparison outcome takes the jump.
fn append_binary_regressor_node(
    nodes: &[tree::regressor::RegressionNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryRegressorNode>,
    feature_index_map: &[usize],
    preprocessing: &[FeaturePreprocessing],
    missing_features: Option<&BTreeSet<usize>>,
) -> usize {
    // Reserve this node's slot first; it is overwritten once the children
    // have been emitted and their indices are known.
    let current_index = layout.len();
    layout.push(OptimizedBinaryRegressorNode::Leaf(0.0));

    match &nodes[node_index] {
        tree::regressor::RegressionNode::Leaf { value, .. } => {
            layout[current_index] = OptimizedBinaryRegressorNode::Leaf(*value);
        }
        tree::regressor::RegressionNode::BinarySplit {
            feature_index,
            threshold_bin,
            missing_direction,
            missing_value,
            left_child,
            right_child,
            ..
        } => {
            // Pick the larger child as the fallthrough target; ties (and the
            // degenerate equal-children case) prefer the left child.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = regressor_node_sample_count(nodes, *left_child);
                let right_count = regressor_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index = append_binary_regressor_node(
                nodes,
                fallthrough_child,
                layout,
                feature_index_map,
                preprocessing,
                missing_features,
            );
            // Layout invariant: the fallthrough child sits right after its
            // parent, which is what the traversal's `node_index + 1` assumes.
            debug_assert_eq!(fallthrough_index, current_index + 1);
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_regressor_node(
                    nodes,
                    jump_child,
                    layout,
                    feature_index_map,
                    preprocessing,
                    missing_features,
                )
            };

            let missing_bin =
                optimized_missing_bin(preprocessing, *feature_index, missing_features);
            // Translate the training-time missing direction into layout
            // indices: Left/Right redirect to the matching child's slot,
            // Node stores the split's own missing value instead.
            let (missing_jump_index, missing_value) =
                if missing_feature_enabled(*feature_index, missing_features) {
                    match missing_direction {
                        tree::shared::MissingBranchDirection::Left => (
                            Some(if *left_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Right => (
                            Some(if *right_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Node => (None, Some(*missing_value)),
                    }
                } else {
                    (None, None)
                };

            layout[current_index] = OptimizedBinaryRegressorNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_value,
            };
        }
    }

    current_index
}
1982
1983fn predict_oblivious_column_major_matrix(
1984 feature_indices: &[usize],
1985 threshold_bins: &[u16],
1986 leaf_values: &[f64],
1987 matrix: &ColumnMajorBinnedMatrix,
1988 executor: &InferenceExecutor,
1989) -> Vec<f64> {
1990 let mut outputs = vec![0.0; matrix.n_rows];
1991 executor.fill_chunks(
1992 &mut outputs,
1993 PARALLEL_INFERENCE_CHUNK_ROWS,
1994 |start_row, chunk| {
1995 predict_oblivious_chunk(
1996 feature_indices,
1997 threshold_bins,
1998 leaf_values,
1999 matrix,
2000 start_row,
2001 chunk,
2002 )
2003 },
2004 );
2005 outputs
2006}
2007
2008fn predict_oblivious_probabilities_column_major_matrix(
2009 feature_indices: &[usize],
2010 threshold_bins: &[u16],
2011 leaf_values: &[Vec<f64>],
2012 matrix: &ColumnMajorBinnedMatrix,
2013 _executor: &InferenceExecutor,
2014) -> Vec<Vec<f64>> {
2015 (0..matrix.n_rows)
2016 .map(|row_index| {
2017 predict_oblivious_probabilities_row(
2018 feature_indices,
2019 threshold_bins,
2020 leaf_values,
2021 |feature_index| matrix.column(feature_index).value_at(row_index),
2022 )
2023 .to_vec()
2024 })
2025 .collect()
2026}
2027
2028fn predict_oblivious_chunk(
2029 feature_indices: &[usize],
2030 threshold_bins: &[u16],
2031 leaf_values: &[f64],
2032 matrix: &ColumnMajorBinnedMatrix,
2033 start_row: usize,
2034 output: &mut [f64],
2035) {
2036 let processed = simd_predict_oblivious_chunk(
2037 feature_indices,
2038 threshold_bins,
2039 leaf_values,
2040 matrix,
2041 start_row,
2042 output,
2043 );
2044
2045 for (offset, out) in output.iter_mut().enumerate().skip(processed) {
2046 let row_index = start_row + offset;
2047 *out = predict_oblivious_row(
2048 feature_indices,
2049 threshold_bins,
2050 leaf_values,
2051 |feature_index| matrix.column(feature_index).value_at(row_index),
2052 );
2053 }
2054}
2055
/// SIMD kernel for oblivious-tree evaluation with scalar leaves.
///
/// Rows are consumed in groups of `OBLIVIOUS_SIMD_LANES`; each tree level
/// contributes one bit per lane to the packed leaf indices in a `u32x8`.
/// Returns the number of rows written — always a multiple of the lane count;
/// the caller finishes any remaining tail rows with the scalar path.
fn simd_predict_oblivious_chunk(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) -> usize {
    let mut processed = 0usize;
    let ones = u32x8::splat(1);

    while processed + OBLIVIOUS_SIMD_LANES <= output.len() {
        let base_row = start_row + processed;
        let mut leaf_indices = u32x8::splat(0);

        for (&feature_index, &threshold_bin) in feature_indices.iter().zip(threshold_bins) {
            let column = matrix.column(feature_index);
            // Widen this level's bins to u32 lanes, preferring the u8 column
            // layout and falling back to u16 columns.
            let bins = if let Some(lanes) = column.slice_u8(base_row, OBLIVIOUS_SIMD_LANES) {
                let lanes: [u8; OBLIVIOUS_SIMD_LANES] = lanes
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::new([
                    u32::from(lanes[0]),
                    u32::from(lanes[1]),
                    u32::from(lanes[2]),
                    u32::from(lanes[3]),
                    u32::from(lanes[4]),
                    u32::from(lanes[5]),
                    u32::from(lanes[6]),
                    u32::from(lanes[7]),
                ])
            } else {
                let lanes: [u16; OBLIVIOUS_SIMD_LANES] = column
                    .slice_u16(base_row, OBLIVIOUS_SIMD_LANES)
                    .expect("column is u16 when not u8")
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::from(u16x8::new(lanes))
            };
            let threshold = u32x8::splat(u32::from(threshold_bin));
            // cmp_gt yields an all-ones lane mask; AND with 1 turns it into
            // the per-lane 0/1 "go right" bit.
            let bit = bins.cmp_gt(threshold) & ones;
            leaf_indices = (leaf_indices << 1) | bit;
        }

        // Gather the selected leaf values back into the output chunk.
        let lane_indices = leaf_indices.to_array();
        for lane in 0..OBLIVIOUS_SIMD_LANES {
            output[processed + lane] =
                leaf_values[usize::try_from(lane_indices[lane]).expect("leaf index fits usize")];
        }
        processed += OBLIVIOUS_SIMD_LANES;
    }

    processed
}
2110
/// Trains a model on `train_set` according to `config`.
///
/// Thin public facade over [`training::train`]; see that module for the
/// actual training pipeline.
///
/// # Errors
///
/// Propagates any [`TrainError`] produced by the training module.
pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
    training::train(train_set, config)
}
2114
2115#[cfg(test)]
2116mod tests;