1use forestfire_data::{
19 BinnedColumnKind, MAX_NUMERIC_BINS, NumericBins, TableAccess, numeric_bin_boundaries,
20 numeric_missing_bin,
21};
22#[cfg(feature = "polars")]
23use polars::prelude::{Column, DataFrame, DataType, IdxSize, LazyFrame};
24use rayon::ThreadPoolBuilder;
25use rayon::prelude::*;
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use std::collections::{BTreeMap, BTreeSet};
29use std::error::Error;
30use std::fmt::{Display, Formatter};
31use std::sync::Arc;
32use wide::{u16x8, u32x8};
33
34mod boosting;
35mod bootstrap;
36mod compiled_artifact;
37mod forest;
38mod inference_input;
39mod introspection;
40pub mod ir;
41mod model_api;
42mod optimized_runtime;
43mod runtime_planning;
44mod sampling;
45mod training;
46pub mod tree;
47
48pub use boosting::BoostingError;
49pub use boosting::GradientBoostedTrees;
50pub use compiled_artifact::CompiledArtifactError;
51pub use forest::RandomForest;
52pub use introspection::IntrospectionError;
53pub use introspection::PredictionHistogramEntry;
54pub use introspection::PredictionValueStats;
55pub use introspection::TreeStructureSummary;
56pub use ir::IrError;
57pub use ir::ModelPackageIr;
58pub use model_api::OptimizedModel;
59pub use tree::classifier::DecisionTreeAlgorithm;
60pub use tree::classifier::DecisionTreeClassifier;
61pub use tree::classifier::DecisionTreeError;
62pub use tree::classifier::DecisionTreeOptions;
63pub use tree::classifier::train_c45;
64pub use tree::classifier::train_cart;
65pub use tree::classifier::train_id3;
66pub use tree::classifier::train_oblivious;
67pub use tree::classifier::train_randomized;
68pub use tree::regressor::DecisionTreeRegressor;
69pub use tree::regressor::RegressionTreeAlgorithm;
70pub use tree::regressor::RegressionTreeError;
71pub use tree::regressor::RegressionTreeOptions;
72pub use tree::regressor::train_cart_regressor;
73pub use tree::regressor::train_oblivious_regressor;
74pub use tree::regressor::train_randomized_regressor;
#[cfg(feature = "polars")]
/// Row chunk size for LazyFrame-based prediction (polars feature only).
/// NOTE(review): presumably consumed by the polars prediction path, which is
/// not visible in this chunk — confirm at the use site.
const LAZYFRAME_PREDICT_BATCH_ROWS: usize = 10_000;
77pub(crate) use inference_input::ColumnMajorBinnedMatrix;
78pub(crate) use inference_input::CompactBinnedColumn;
79pub(crate) use inference_input::InferenceTable;
80pub(crate) use inference_input::ProjectedTableView;
81#[cfg(feature = "polars")]
82pub(crate) use inference_input::polars_named_columns;
83pub(crate) use introspection::prediction_value_stats;
84pub(crate) use introspection::tree_structure_summary;
85pub(crate) use optimized_runtime::InferenceExecutor;
86pub(crate) use optimized_runtime::OBLIVIOUS_SIMD_LANES;
87pub(crate) use optimized_runtime::OptimizedBinaryClassifierNode;
88pub(crate) use optimized_runtime::OptimizedBinaryRegressorNode;
89pub(crate) use optimized_runtime::OptimizedClassifierNode;
90pub(crate) use optimized_runtime::OptimizedRuntime;
91pub(crate) use optimized_runtime::PARALLEL_INFERENCE_CHUNK_ROWS;
92pub(crate) use optimized_runtime::STANDARD_BATCH_INFERENCE_CHUNK_ROWS;
93pub(crate) use optimized_runtime::resolve_inference_thread_count;
94pub(crate) use runtime_planning::build_feature_index_map;
95pub(crate) use runtime_planning::build_feature_projection;
96pub(crate) use runtime_planning::model_used_feature_indices;
97pub(crate) use runtime_planning::ordered_ensemble_indices;
98pub(crate) use runtime_planning::remap_feature_index;
99
/// Model family selected for training (see [`Model`] for the outputs).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainAlgorithm {
    /// Single decision tree (classifier or regressor, per [`Task`]).
    Dt,
    /// Random forest ensemble.
    Rf,
    /// Gradient-boosted trees.
    Gbm,
}
109
/// Split-quality criterion used when growing trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Criterion {
    /// Let the library pick a criterion for the configured task/tree type.
    Auto,
    /// Gini impurity (classification).
    Gini,
    /// Entropy / information gain (classification).
    Entropy,
    /// Mean-based criterion (regression).
    Mean,
    /// Median-based criterion (regression).
    Median,
    /// Second-order criterion — presumably the gradient/hessian objective
    /// used by boosting; confirm in the `boosting` module.
    SecondOrder,
}
125
/// Learning task the model is trained for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    /// Predict a continuous value.
    Regression,
    /// Predict a class label.
    Classification,
}
133
/// Tree-growing algorithm for the base learners (each has a matching
/// `train_*` entry point re-exported from `tree::classifier`/`tree::regressor`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeType {
    /// ID3 (multiway splits, classification).
    Id3,
    /// C4.5.
    C45,
    /// CART (binary splits).
    Cart,
    /// Randomized split selection.
    Randomized,
    /// Oblivious trees: one (feature, threshold) pair per level.
    Oblivious,
}
147
/// Per-split feature-subsampling policy; see [`MaxFeatures::resolve`] for the
/// concrete count each variant yields.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxFeatures {
    /// Task-dependent default: `Sqrt` for classification, `Third` for regression.
    Auto,
    /// Consider every feature at each split.
    All,
    /// `floor(sqrt(feature_count))`, at least 1.
    Sqrt,
    /// `feature_count / 3`, at least 1.
    Third,
    /// An explicit count, capped at `feature_count` and floored at 1.
    Count(usize),
}
161
162impl MaxFeatures {
163 pub fn resolve(self, task: Task, feature_count: usize) -> usize {
164 match self {
165 MaxFeatures::Auto => match task {
166 Task::Classification => MaxFeatures::Sqrt.resolve(task, feature_count),
167 Task::Regression => MaxFeatures::Third.resolve(task, feature_count),
168 },
169 MaxFeatures::All => feature_count.max(1),
170 MaxFeatures::Sqrt => ((feature_count as f64).sqrt().floor() as usize).max(1),
171 MaxFeatures::Third => (feature_count / 3).max(1),
172 MaxFeatures::Count(count) => count.min(feature_count).max(1),
173 }
174 }
175}
176
/// Selects how many top candidates the "canary" mechanism keeps
/// (see `validate` and `selection_size` below for the exact semantics).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CanaryFilter {
    /// Keep the top `n` candidates; must be at least 1.
    TopN(usize),
    /// Keep the top fraction of candidates; must be finite and in (0, 1].
    TopFraction(f64),
}
184
185impl Default for CanaryFilter {
186 fn default() -> Self {
187 Self::TopN(1)
188 }
189}
190
/// Kind of raw input feature, as serialized in model metadata
/// (snake_case: `numeric` / `binary`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputFeatureKind {
    /// Continuous feature, binned via histogram boundaries.
    Numeric,
    /// Two-valued feature, used as-is.
    Binary,
}
199
/// One histogram bin of a numeric feature: values up to `upper_bound`
/// fall into bin `bin` (pairs produced by `numeric_bin_boundaries`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinBoundary {
    /// Bin index.
    pub bin: u16,
    /// Inclusive/exclusive semantics defined by `numeric_bin_boundaries`
    /// in `forestfire_data` — see that crate for the exact contract.
    pub upper_bound: f64,
}
207
/// Per-feature preprocessing captured at training time so inference can
/// reproduce the same binning (serialized with a `kind` tag).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeaturePreprocessing {
    /// Numeric feature with its histogram boundaries and the bin reserved
    /// for missing values.
    Numeric {
        bin_boundaries: Vec<NumericBinBoundary>,
        missing_bin: u16,
    },
    /// Binary feature; no binning metadata required.
    Binary,
}
219
/// How missing values are routed during training.
/// NOTE(review): the concrete behavior of each variant is implemented in the
/// training modules, not visible here — confirm there before documenting further.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MissingValueStrategy {
    /// Cheap rule-of-thumb routing (the default).
    Heuristic,
    /// Search for the best missing-value direction.
    Optimal,
}
225
/// Missing-value strategy configuration: one strategy for every feature, or
/// per-feature overrides (unlisted features fall back to `Heuristic`; see
/// `resolve_for_feature_count`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MissingValueStrategyConfig {
    /// Apply the same strategy to all features.
    Global(MissingValueStrategy),
    /// Override by feature index; indices must be < the table's feature count.
    PerFeature(BTreeMap<usize, MissingValueStrategy>),
}
231
232impl MissingValueStrategyConfig {
233 pub fn heuristic() -> Self {
234 Self::Global(MissingValueStrategy::Heuristic)
235 }
236
237 pub fn optimal() -> Self {
238 Self::Global(MissingValueStrategy::Optimal)
239 }
240
241 pub fn resolve_for_feature_count(
242 &self,
243 feature_count: usize,
244 ) -> Result<Vec<MissingValueStrategy>, TrainError> {
245 match self {
246 MissingValueStrategyConfig::Global(strategy) => Ok(vec![*strategy; feature_count]),
247 MissingValueStrategyConfig::PerFeature(strategies) => {
248 let mut resolved = vec![MissingValueStrategy::Heuristic; feature_count];
249 for (&feature_index, &strategy) in strategies {
250 if feature_index >= feature_count {
251 return Err(TrainError::InvalidMissingValueStrategyFeature {
252 feature_index,
253 feature_count,
254 });
255 }
256 resolved[feature_index] = strategy;
257 }
258 Ok(resolved)
259 }
260 }
261 }
262}
263
/// Unified training configuration consumed by every algorithm in the crate.
///
/// Fields that do not apply to the selected [`TrainAlgorithm`] are presumably
/// ignored by that algorithm (e.g. `learning_rate` for a plain decision tree)
/// — confirm in the `training` module.
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Which model family to train.
    pub algorithm: TrainAlgorithm,
    /// Regression or classification.
    pub task: Task,
    /// Tree-growing algorithm for the base learners.
    pub tree_type: TreeType,
    /// Split-quality criterion; `Auto` defers to library defaults.
    pub criterion: Criterion,
    /// Maximum tree depth; rejected when 0 (see [`TrainError::InvalidMaxDepth`]).
    pub max_depth: Option<usize>,
    /// Minimum samples to split a node; rejected when 0.
    pub min_samples_split: Option<usize>,
    /// Minimum samples per leaf; rejected when 0.
    pub min_samples_leaf: Option<usize>,
    /// Requested physical core count for training; validated against the
    /// machine (see [`TrainError::InvalidPhysicalCoreCount`]).
    pub physical_cores: Option<usize>,
    /// Ensemble size; rejected when 0 (see [`TrainError::InvalidTreeCount`]).
    pub n_trees: Option<usize>,
    /// Per-split feature subsampling policy.
    pub max_features: MaxFeatures,
    /// RNG seed; `None` leaves seeding to the implementation.
    pub seed: Option<u64>,
    /// Canary candidate selection; validated via [`CanaryFilter`].
    pub canary_filter: CanaryFilter,
    /// Whether to compute out-of-bag estimates.
    pub compute_oob: bool,
    /// Boosting shrinkage rate.
    pub learning_rate: Option<f64>,
    /// Whether ensemble members train on bootstrap resamples.
    pub bootstrap: bool,
    /// Fraction of largest-gradient rows kept during boosting subsampling
    /// — assumed GOSS-style sampling; confirm in `boosting`/`sampling`.
    pub top_gradient_fraction: Option<f64>,
    /// Fraction of the remaining rows sampled during boosting subsampling
    /// — assumed GOSS-style sampling; confirm in `boosting`/`sampling`.
    pub other_gradient_fraction: Option<f64>,
    /// Missing-value handling, globally or per feature.
    pub missing_value_strategy: MissingValueStrategyConfig,
    /// Numeric histogram bin budget; `None` uses the library default.
    pub histogram_bins: Option<NumericBins>,
}
314
315impl Default for TrainConfig {
316 fn default() -> Self {
317 Self {
318 algorithm: TrainAlgorithm::Dt,
319 task: Task::Regression,
320 tree_type: TreeType::Cart,
321 criterion: Criterion::Auto,
322 max_depth: None,
323 min_samples_split: None,
324 min_samples_leaf: None,
325 physical_cores: None,
326 n_trees: None,
327 max_features: MaxFeatures::Auto,
328 seed: None,
329 canary_filter: CanaryFilter::default(),
330 compute_oob: false,
331 learning_rate: None,
332 bootstrap: false,
333 top_gradient_fraction: None,
334 other_gradient_fraction: None,
335 missing_value_strategy: MissingValueStrategyConfig::heuristic(),
336 histogram_bins: None,
337 }
338 }
339}
340
/// A trained model of any supported family.
#[derive(Debug, Clone)]
pub enum Model {
    /// Single classification tree.
    DecisionTreeClassifier(DecisionTreeClassifier),
    /// Single regression tree.
    DecisionTreeRegressor(DecisionTreeRegressor),
    /// Bagged ensemble of trees.
    RandomForest(RandomForest),
    /// Boosted ensemble of trees.
    GradientBoostedTrees(GradientBoostedTrees),
}
353
/// Errors produced while validating configuration or training a model.
#[derive(Debug)]
pub enum TrainError {
    /// Wrapped error from decision-tree classifier training.
    DecisionTree(DecisionTreeError),
    /// Wrapped error from regression-tree training.
    RegressionTree(RegressionTreeError),
    /// Wrapped error from gradient boosting.
    Boosting(BoostingError),
    /// More physical cores requested than are available.
    InvalidPhysicalCoreCount {
        requested: usize,
        available: usize,
    },
    /// Training thread-pool construction failed; payload is the message.
    ThreadPoolBuildFailed(String),
    /// The (task, tree_type, criterion) combination has no implementation.
    UnsupportedConfiguration {
        task: Task,
        tree_type: TreeType,
        criterion: Criterion,
    },
    /// `max_depth` failed the "at least 1" check; payload is the rejected value.
    InvalidMaxDepth(usize),
    /// `min_samples_split` failed the "at least 1" check.
    InvalidMinSamplesSplit(usize),
    /// `min_samples_leaf` failed the "at least 1" check.
    InvalidMinSamplesLeaf(usize),
    /// An ensemble was requested with zero trees.
    InvalidTreeCount(usize),
    /// An explicit `max_features` count failed the "at least 1" check.
    InvalidMaxFeatures(usize),
    /// `CanaryFilter::TopN` was 0.
    InvalidCanaryFilterTopN(usize),
    /// `CanaryFilter::TopFraction` was non-finite or outside (0, 1].
    InvalidCanaryFilterTopFraction(f64),
    /// A per-feature missing-value strategy referenced a feature index at or
    /// beyond the training table's feature count.
    InvalidMissingValueStrategyFeature {
        feature_index: usize,
        feature_count: usize,
    },
}
381
382impl Display for TrainError {
383 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
384 match self {
385 TrainError::DecisionTree(err) => err.fmt(f),
386 TrainError::RegressionTree(err) => err.fmt(f),
387 TrainError::Boosting(err) => err.fmt(f),
388 TrainError::InvalidPhysicalCoreCount {
389 requested,
390 available,
391 } => write!(
392 f,
393 "Requested {} physical cores, but the available physical core count is {}.",
394 requested, available
395 ),
396 TrainError::ThreadPoolBuildFailed(message) => {
397 write!(f, "Failed to build training thread pool: {}.", message)
398 }
399 TrainError::UnsupportedConfiguration {
400 task,
401 tree_type,
402 criterion,
403 } => write!(
404 f,
405 "Unsupported training configuration: task={:?}, tree_type={:?}, criterion={:?}.",
406 task, tree_type, criterion
407 ),
408 TrainError::InvalidMaxDepth(value) => {
409 write!(f, "max_depth must be at least 1. Received {}.", value)
410 }
411 TrainError::InvalidMinSamplesSplit(value) => {
412 write!(
413 f,
414 "min_samples_split must be at least 1. Received {}.",
415 value
416 )
417 }
418 TrainError::InvalidMinSamplesLeaf(value) => {
419 write!(
420 f,
421 "min_samples_leaf must be at least 1. Received {}.",
422 value
423 )
424 }
425 TrainError::InvalidTreeCount(n_trees) => {
426 write!(
427 f,
428 "Random forest requires at least one tree. Received {}.",
429 n_trees
430 )
431 }
432 TrainError::InvalidMaxFeatures(count) => {
433 write!(
434 f,
435 "max_features must be at least 1 when provided as an integer. Received {}.",
436 count
437 )
438 }
439 TrainError::InvalidCanaryFilterTopN(count) => {
440 write!(
441 f,
442 "canary_filter top-n must be at least 1. Received {}.",
443 count
444 )
445 }
446 TrainError::InvalidCanaryFilterTopFraction(fraction) => {
447 write!(
448 f,
449 "canary_filter top fraction must be finite and in (0, 1]. Received {}.",
450 fraction
451 )
452 }
453 TrainError::InvalidMissingValueStrategyFeature {
454 feature_index,
455 feature_count,
456 } => write!(
457 f,
458 "missing_value_strategy references feature {}, but the training table only has {} features.",
459 feature_index, feature_count
460 ),
461 }
462 }
463}
464
// Marker impl: lets `TrainError` participate in `dyn Error` chains.
impl Error for TrainError {}
466
467impl CanaryFilter {
468 pub(crate) fn validate(self) -> Result<(), TrainError> {
469 match self {
470 Self::TopN(0) => Err(TrainError::InvalidCanaryFilterTopN(0)),
471 Self::TopN(_) => Ok(()),
472 Self::TopFraction(fraction)
473 if !fraction.is_finite() || fraction <= 0.0 || fraction > 1.0 =>
474 {
475 Err(TrainError::InvalidCanaryFilterTopFraction(fraction))
476 }
477 Self::TopFraction(_) => Ok(()),
478 }
479 }
480
481 pub(crate) fn selection_size(self, candidate_count: usize) -> usize {
482 if candidate_count == 0 {
483 return 0;
484 }
485
486 match self {
487 Self::TopN(count) => count.clamp(1, candidate_count),
488 Self::TopFraction(fraction) => {
489 ((fraction * candidate_count as f64).ceil() as usize).clamp(1, candidate_count)
490 }
491 }
492 }
493}
494
/// Errors produced while preparing inputs for, or running, inference.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictError {
    /// `predict_proba` was called on a regressor runtime.
    ProbabilityPredictionRequiresClassification,
    /// A row of row-major input had the wrong number of columns.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// Input feature count differs from what the model was trained on.
    FeatureCountMismatch {
        expected: usize,
        actual: usize,
    },
    /// A named column's length differs from the other columns.
    ColumnLengthMismatch {
        feature: String,
        expected: usize,
        actual: usize,
    },
    /// A feature the model requires was not supplied.
    MissingFeature(String),
    /// A supplied feature is not known to the model.
    UnexpectedFeature(String),
    /// A value for a binary-typed feature was neither of its two levels.
    InvalidBinaryValue {
        feature_index: usize,
        row_index: usize,
        value: f64,
    },
    /// A null was encountered where a value is required.
    NullValue {
        feature: String,
        row_index: usize,
    },
    /// A column has a dtype inference cannot consume.
    UnsupportedFeatureType {
        feature: String,
        dtype: String,
    },
    /// Error bubbled up from polars, flattened to its message.
    Polars(String),
}
529
530impl Display for PredictError {
531 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
532 match self {
533 PredictError::ProbabilityPredictionRequiresClassification => write!(
534 f,
535 "predict_proba is only available for classification models."
536 ),
537 PredictError::RaggedRows {
538 row,
539 expected,
540 actual,
541 } => write!(
542 f,
543 "Ragged inference row at index {}: expected {} columns, found {}.",
544 row, expected, actual
545 ),
546 PredictError::FeatureCountMismatch { expected, actual } => write!(
547 f,
548 "Inference input has {} features, but the model expects {}.",
549 actual, expected
550 ),
551 PredictError::ColumnLengthMismatch {
552 feature,
553 expected,
554 actual,
555 } => write!(
556 f,
557 "Feature '{}' has {} values, expected {}.",
558 feature, actual, expected
559 ),
560 PredictError::MissingFeature(feature) => {
561 write!(f, "Missing required feature '{}'.", feature)
562 }
563 PredictError::UnexpectedFeature(feature) => {
564 write!(f, "Unexpected feature '{}'.", feature)
565 }
566 PredictError::InvalidBinaryValue {
567 feature_index,
568 row_index,
569 value,
570 } => write!(
571 f,
572 "Feature {} at row {} must be binary for inference, found {}.",
573 feature_index, row_index, value
574 ),
575 PredictError::NullValue { feature, row_index } => write!(
576 f,
577 "Feature '{}' contains a null value at row {}.",
578 feature, row_index
579 ),
580 PredictError::UnsupportedFeatureType { feature, dtype } => write!(
581 f,
582 "Feature '{}' has unsupported dtype '{}'.",
583 feature, dtype
584 ),
585 PredictError::Polars(message) => write!(f, "Polars inference failed: {}.", message),
586 }
587 }
588}
589
// Marker impl: lets `PredictError` participate in `dyn Error` chains.
impl Error for PredictError {}
591
#[cfg(feature = "polars")]
impl From<polars::error::PolarsError> for PredictError {
    /// Flatten a polars error into its rendered message; the original error
    /// value is not retained.
    fn from(source: polars::error::PolarsError) -> Self {
        Self::Polars(source.to_string())
    }
}
598
/// Errors produced while preparing a model for optimized inference.
#[derive(Debug)]
pub enum OptimizeError {
    /// More physical cores requested than are available.
    InvalidPhysicalCoreCount { requested: usize, available: usize },
    /// Inference thread-pool construction failed; payload is the message.
    ThreadPoolBuildFailed(String),
    /// The named model type has no optimized runtime.
    UnsupportedModelType(&'static str),
}
605
606impl Display for OptimizeError {
607 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
608 match self {
609 OptimizeError::InvalidPhysicalCoreCount {
610 requested,
611 available,
612 } => write!(
613 f,
614 "Requested {} physical cores, but the available physical core count is {}.",
615 requested, available
616 ),
617 OptimizeError::ThreadPoolBuildFailed(message) => {
618 write!(f, "Failed to build inference thread pool: {}.", message)
619 }
620 OptimizeError::UnsupportedModelType(model_type) => {
621 write!(
622 f,
623 "Optimized inference is not supported for model type '{}'.",
624 model_type
625 )
626 }
627 }
628 }
629}
630
// Marker impl: lets `OptimizeError` participate in `dyn Error` chains.
impl Error for OptimizeError {}
632
/// Internal knob describing how many worker threads a code path may use.
/// Parallel branches are taken only when `enabled()` reports more than one
/// thread.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Parallelism {
    // Always >= 1 via the constructors below.
    thread_count: usize,
}
637
638impl Parallelism {
639 pub(crate) fn sequential() -> Self {
640 Self { thread_count: 1 }
641 }
642
643 #[cfg(test)]
644 pub(crate) fn with_threads(thread_count: usize) -> Self {
645 Self {
646 thread_count: thread_count.max(1),
647 }
648 }
649
650 pub(crate) fn enabled(self) -> bool {
651 self.thread_count > 1
652 }
653}
654
655pub(crate) fn capture_feature_preprocessing(table: &dyn TableAccess) -> Vec<FeaturePreprocessing> {
656 (0..table.n_features())
657 .map(|feature_index| {
658 if table.is_binary_feature(feature_index) {
659 FeaturePreprocessing::Binary
660 } else {
661 let values = (0..table.n_rows())
662 .map(|row_index| table.feature_value(feature_index, row_index))
663 .collect::<Vec<_>>();
664 FeaturePreprocessing::Numeric {
665 bin_boundaries: numeric_bin_boundaries(
666 &values,
667 NumericBins::Fixed(table.numeric_bin_cap()),
668 )
669 .into_iter()
670 .map(|(bin, upper_bound)| NumericBinBoundary { bin, upper_bound })
671 .collect(),
672 missing_bin: numeric_missing_bin(NumericBins::Fixed(table.numeric_bin_cap())),
673 }
674 }
675 })
676 .collect()
677}
678
/// Decide whether missing-value routing applies to `feature_index`.
///
/// `None` means "no restriction" (every feature participates); `Some(set)`
/// restricts missing-value handling to exactly the features in the set.
fn missing_feature_enabled(
    feature_index: usize,
    missing_features: Option<&BTreeSet<usize>>,
) -> bool {
    match missing_features {
        None => true,
        Some(features) => features.contains(&feature_index),
    }
}
685
686fn optimized_missing_bin(
687 preprocessing: &[FeaturePreprocessing],
688 feature_index: usize,
689 missing_features: Option<&BTreeSet<usize>>,
690) -> Option<u16> {
691 if !missing_feature_enabled(feature_index, missing_features) {
692 return None;
693 }
694
695 match preprocessing.get(feature_index) {
696 Some(FeaturePreprocessing::Binary) => Some(forestfire_data::BINARY_MISSING_BIN),
697 Some(FeaturePreprocessing::Numeric { missing_bin, .. }) => Some(*missing_bin),
698 None => None,
699 }
700}
701
702impl OptimizedRuntime {
    /// Whether this runtime has a dedicated column-major batch kernel.
    /// `StandardClassifier` is notably absent from the list and therefore
    /// falls back to row-by-row evaluation.
    fn supports_batch_matrix(&self) -> bool {
        matches!(
            self,
            OptimizedRuntime::BinaryClassifier { .. }
                | OptimizedRuntime::BinaryRegressor { .. }
                | OptimizedRuntime::ObliviousClassifier { .. }
                | OptimizedRuntime::ObliviousRegressor { .. }
                | OptimizedRuntime::ForestClassifier { .. }
                | OptimizedRuntime::ForestRegressor { .. }
                | OptimizedRuntime::BoostedBinaryClassifier { .. }
                | OptimizedRuntime::BoostedRegressor { .. }
        )
    }
716
717 fn should_use_batch_matrix(&self, n_rows: usize) -> bool {
718 n_rows > 1 && self.supports_batch_matrix()
719 }
720
    /// Lower any trained [`Model`] into an [`OptimizedRuntime`].
    ///
    /// Feature indices are remapped through `feature_index_map`, and
    /// `missing_features` (when `Some`) restricts which features keep their
    /// missing-value routing. Ensemble members are reordered via
    /// `ordered_ensemble_indices` before lowering; for boosted models the
    /// per-tree weights are permuted in lockstep with the trees.
    fn from_model(
        model: &Model,
        feature_index_map: &[usize],
        missing_features: Option<&BTreeSet<usize>>,
    ) -> Self {
        match model {
            Model::DecisionTreeClassifier(classifier) => {
                Self::from_classifier(classifier, feature_index_map, missing_features)
            }
            Model::DecisionTreeRegressor(regressor) => {
                Self::from_regressor(regressor, feature_index_map, missing_features)
            }
            Model::RandomForest(forest) => match forest.task() {
                Task::Classification => {
                    // Fix the member order once, then lower each tree recursively.
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestClassifier {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &forest.trees()[tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        class_labels: forest
                            .class_labels()
                            .expect("classification forest stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestRegressor {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &forest.trees()[tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                    }
                }
            },
            Model::GradientBoostedTrees(model) => match model.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedBinaryClassifier {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &model.trees()[*tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        // Weights must follow the same permutation as the trees.
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                        class_labels: model
                            .class_labels()
                            .expect("classification boosting stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedRegressor {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &model.trees()[*tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        // Weights must follow the same permutation as the trees.
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                    }
                }
            },
        }
    }
815
    /// Lower a trained classification tree into its optimized runtime layout.
    ///
    /// Standard trees whose nodes are all binary splits take a dedicated
    /// flat `BinaryClassifier` layout; mixed standard trees become
    /// `StandardClassifier` nodes; oblivious trees keep one
    /// (feature, threshold) pair per level. Feature indices are remapped via
    /// `feature_index_map`, and missing-value routing is kept only for
    /// features enabled by `missing_features`.
    fn from_classifier(
        classifier: &DecisionTreeClassifier,
        feature_index_map: &[usize],
        missing_features: Option<&BTreeSet<usize>>,
    ) -> Self {
        match classifier.structure() {
            tree::classifier::TreeStructure::Standard { nodes, root } => {
                // Fast path: purely binary trees get the compact layout.
                if classifier_nodes_are_binary_only(nodes) {
                    return Self::BinaryClassifier {
                        nodes: build_binary_classifier_layout(
                            nodes,
                            *root,
                            classifier.class_labels(),
                            feature_index_map,
                            classifier.feature_preprocessing(),
                            missing_features,
                        ),
                        class_labels: classifier.class_labels().to_vec(),
                    };
                }

                let optimized_nodes = nodes
                    .iter()
                    .map(|node| match node {
                        // Leaves store normalized class probabilities directly.
                        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
                            OptimizedClassifierNode::Leaf(normalized_probabilities_from_counts(
                                class_counts,
                            ))
                        }
                        tree::classifier::TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => OptimizedClassifierNode::Binary {
                            feature_index: remap_feature_index(*feature_index, feature_index_map),
                            threshold_bin: *threshold_bin,
                            children: [*left_child, *right_child],
                            missing_bin: optimized_missing_bin(
                                classifier.feature_preprocessing(),
                                *feature_index,
                                missing_features,
                            ),
                            // Missing values either follow a fixed child ...
                            missing_child: if missing_feature_enabled(
                                *feature_index,
                                missing_features,
                            ) {
                                match missing_direction {
                                    tree::shared::MissingBranchDirection::Left => Some(*left_child),
                                    tree::shared::MissingBranchDirection::Right => {
                                        Some(*right_child)
                                    }
                                    tree::shared::MissingBranchDirection::Node => None,
                                }
                            } else {
                                None
                            },
                            // ... or, for `Node` direction, resolve at this node
                            // using its own class distribution.
                            missing_probabilities: if missing_feature_enabled(
                                *feature_index,
                                missing_features,
                            ) && matches!(
                                missing_direction,
                                tree::shared::MissingBranchDirection::Node
                            ) {
                                Some(normalized_probabilities_from_counts(class_counts))
                            } else {
                                None
                            },
                        },
                        tree::classifier::TreeNode::MultiwaySplit {
                            feature_index,
                            class_counts,
                            branches,
                            missing_child,
                            ..
                        } => {
                            // Dense bin -> child table; bins without a branch
                            // keep the `usize::MAX` sentinel.
                            let max_bin_index = branches
                                .iter()
                                .map(|(bin, _)| usize::from(*bin))
                                .max()
                                .unwrap_or(0);
                            let mut child_lookup = vec![usize::MAX; max_bin_index + 1];
                            for (bin, child_index) in branches {
                                child_lookup[usize::from(*bin)] = *child_index;
                            }
                            OptimizedClassifierNode::Multiway {
                                feature_index: remap_feature_index(
                                    *feature_index,
                                    feature_index_map,
                                ),
                                child_lookup,
                                max_bin_index,
                                missing_bin: optimized_missing_bin(
                                    classifier.feature_preprocessing(),
                                    *feature_index,
                                    missing_features,
                                ),
                                missing_child: if missing_feature_enabled(
                                    *feature_index,
                                    missing_features,
                                ) {
                                    *missing_child
                                } else {
                                    None
                                },
                                // Used when a bin has no branch to follow.
                                fallback_probabilities: normalized_probabilities_from_counts(
                                    class_counts,
                                ),
                            }
                        }
                    })
                    .collect();

                Self::StandardClassifier {
                    nodes: optimized_nodes,
                    root: *root,
                    class_labels: classifier.class_labels().to_vec(),
                }
            }
            tree::classifier::TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => Self::ObliviousClassifier {
                feature_indices: splits
                    .iter()
                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
                    .collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_class_counts
                    .iter()
                    .map(|class_counts| normalized_probabilities_from_counts(class_counts))
                    .collect(),
                class_labels: classifier.class_labels().to_vec(),
            },
        }
    }
956
957 fn from_regressor(
958 regressor: &DecisionTreeRegressor,
959 feature_index_map: &[usize],
960 missing_features: Option<&BTreeSet<usize>>,
961 ) -> Self {
962 match regressor.structure() {
963 tree::regressor::RegressionTreeStructure::Standard { nodes, root } => {
964 Self::BinaryRegressor {
965 nodes: build_binary_regressor_layout(
966 nodes,
967 *root,
968 feature_index_map,
969 regressor.feature_preprocessing(),
970 missing_features,
971 ),
972 }
973 }
974 tree::regressor::RegressionTreeStructure::Oblivious {
975 splits,
976 leaf_values,
977 ..
978 } => Self::ObliviousRegressor {
979 feature_indices: splits
980 .iter()
981 .map(|split| remap_feature_index(split.feature_index, feature_index_map))
982 .collect(),
983 threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
984 leaf_values: leaf_values.clone(),
985 },
986 }
987 }
988
    /// Predict a single table row as a scalar value.
    ///
    /// Classifier runtimes compute class probabilities first and return the
    /// label chosen by `class_label_from_probabilities`. Regressor runtimes
    /// traverse their node layouts directly: forests average member
    /// predictions uniformly, boosted models add weighted member predictions
    /// to the base score.
    #[inline(always)]
    fn predict_table_row(&self, table: &dyn TableAccess, row_index: usize) -> f64 {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                let probabilities = self
                    .predict_proba_table_row(table, row_index)
                    .expect("classifier runtime supports probability prediction");
                class_label_from_probabilities(&probabilities, self.class_labels())
            }
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            ),
            // Uniform average over members. Relies on the forest containing at
            // least one tree (enforced at training time via InvalidTreeCount).
            OptimizedRuntime::ForestRegressor { trees } => {
                trees
                    .iter()
                    .map(|tree| tree.predict_table_row(table, row_index))
                    .sum::<f64>()
                    / trees.len() as f64
            }
            // base_score + sum(weight_i * tree_i(row)).
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>()
            }
        }
    }
1038
    /// Class probabilities for a single table row.
    ///
    /// # Errors
    ///
    /// Returns [`PredictError::ProbabilityPredictionRequiresClassification`]
    /// for regressor runtimes.
    #[inline(always)]
    fn predict_proba_table_row(
        &self,
        table: &dyn TableAccess,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            )
            .to_vec()),
            // Forests average member probabilities uniformly. Relies on at
            // least one member tree (enforced at training time).
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals = trees[0].predict_proba_table_row(table, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_table_row(table, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            // Boosting produces one raw score, squashed through a sigmoid
            // into [P(negative), P(positive)].
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1107
1108 fn predict_proba_table(
1109 &self,
1110 table: &dyn TableAccess,
1111 executor: &InferenceExecutor,
1112 ) -> Result<Vec<Vec<f64>>, PredictError> {
1113 match self {
1114 OptimizedRuntime::BinaryClassifier { .. }
1115 | OptimizedRuntime::StandardClassifier { .. }
1116 | OptimizedRuntime::ObliviousClassifier { .. }
1117 | OptimizedRuntime::ForestClassifier { .. }
1118 | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
1119 if self.should_use_batch_matrix(table.n_rows()) {
1120 let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
1121 self.predict_proba_column_major_matrix(&matrix, executor)
1122 } else {
1123 (0..table.n_rows())
1124 .map(|row_index| self.predict_proba_table_row(table, row_index))
1125 .collect()
1126 }
1127 }
1128 OptimizedRuntime::BinaryRegressor { .. }
1129 | OptimizedRuntime::ObliviousRegressor { .. }
1130 | OptimizedRuntime::ForestRegressor { .. }
1131 | OptimizedRuntime::BoostedRegressor { .. } => {
1132 Err(PredictError::ProbabilityPredictionRequiresClassification)
1133 }
1134 }
1135 }
1136
    /// Predict scalar values for every row of a column-major binned matrix.
    ///
    /// Classifier runtimes derive labels from batched probabilities;
    /// regressor runtimes use their dedicated batch kernels. Forests average
    /// member outputs uniformly; boosted models accumulate weighted member
    /// outputs on top of the base score.
    fn predict_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Vec<f64> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => self
                .predict_proba_column_major_matrix(matrix, executor)
                .expect("classifier runtime supports probability prediction")
                .into_iter()
                .map(|row| class_label_from_probabilities(&row, self.class_labels()))
                .collect(),
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_column_major_matrix(nodes, matrix, executor)
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            ),
            // Uniform average over members. Relies on at least one member
            // tree (enforced at training time via InvalidTreeCount).
            OptimizedRuntime::ForestRegressor { trees } => {
                let mut totals = trees[0].predict_column_major_matrix(matrix, executor);
                for tree in &trees[1..] {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for total in &mut totals {
                    *total /= tree_count;
                }
                totals
            }
            // base_score + sum(weight_i * tree_i(row)) per row.
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                let mut totals = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += weight * value;
                    }
                }
                totals
            }
        }
    }
1197
    /// Predict class probabilities for every row of a column-major binned
    /// matrix, using batch kernels where available.
    ///
    /// `StandardClassifier` has no batch kernel and is evaluated row by row.
    ///
    /// # Errors
    ///
    /// Returns [`PredictError::ProbabilityPredictionRequiresClassification`]
    /// for regressor runtimes.
    fn predict_proba_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => {
                Ok(predict_binary_classifier_probabilities_column_major_matrix(
                    nodes, matrix, executor,
                ))
            }
            // No batch kernel: fall back to per-row traversal.
            OptimizedRuntime::StandardClassifier { .. } => Ok((0..matrix.n_rows)
                .map(|row_index| {
                    self.predict_proba_binned_row_from_columns(matrix, row_index)
                        .expect("classifier runtime supports probability prediction")
                })
                .collect()),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            )),
            // Forests average member probabilities uniformly. Relies on at
            // least one member tree (enforced at training time).
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals = trees[0].predict_proba_column_major_matrix(matrix, executor)?;
                for tree in &trees[1..] {
                    let rows = tree.predict_proba_column_major_matrix(matrix, executor)?;
                    for (row_totals, row_values) in totals.iter_mut().zip(rows) {
                        for (total, value) in row_totals.iter_mut().zip(row_values) {
                            *total += value;
                        }
                    }
                }
                let tree_count = trees.len() as f64;
                for row in &mut totals {
                    for value in row {
                        *value /= tree_count;
                    }
                }
                Ok(totals)
            }
            // Accumulate one raw score per row, then squash through a sigmoid
            // into [P(negative), P(positive)].
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let mut raw_scores = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (raw_score, value) in raw_scores.iter_mut().zip(values) {
                        *raw_score += weight * value;
                    }
                }
                Ok(raw_scores
                    .into_iter()
                    .map(|raw_score| {
                        let positive = sigmoid(raw_score);
                        vec![1.0 - positive, positive]
                    })
                    .collect())
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1274
1275 fn class_labels(&self) -> &[f64] {
1276 match self {
1277 OptimizedRuntime::BinaryClassifier { class_labels, .. }
1278 | OptimizedRuntime::StandardClassifier { class_labels, .. }
1279 | OptimizedRuntime::ObliviousClassifier { class_labels, .. }
1280 | OptimizedRuntime::ForestClassifier { class_labels, .. }
1281 | OptimizedRuntime::BoostedBinaryClassifier { class_labels, .. } => class_labels,
1282 _ => &[],
1283 }
1284 }
1285
    /// Predicts a single scalar value for one row of a column-major binned
    /// matrix.
    ///
    /// Binary, oblivious and boosted regressors have true single-row paths;
    /// every other runtime (including `ForestRegressor` and all classifiers)
    /// hits the catch-all arm below.
    #[inline(always)]
    fn predict_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> f64 {
        match self {
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            ),
            // Weighted sum of per-tree single-row predictions on top of the
            // base score.
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>()
            }
            // Fallback: predict the WHOLE matrix with a fresh single-threaded
            // executor and index out one row.
            // NOTE(review): this is O(n_rows) work plus an executor build per
            // call — acceptable only if this arm is rarely taken; confirm.
            _ => self.predict_column_major_matrix(
                matrix,
                &InferenceExecutor::new(1).expect("inference executor"),
            )[row_index],
        }
    }
1328
    /// Predicts the per-class probability vector for one row of a
    /// column-major binned matrix.
    ///
    /// Returns `Err(PredictError::ProbabilityPredictionRequiresClassification)`
    /// for the regressor runtimes.
    #[inline(always)]
    fn predict_proba_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            )
            .to_vec()),
            // Forest: unweighted average of the member trees' probability
            // rows. NOTE(review): `trees[0]` assumes a non-empty forest —
            // confirm the builder guarantees this.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals =
                    trees[0].predict_proba_binned_row_from_columns(matrix, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_binned_row_from_columns(matrix, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            // Boosted binary: weighted raw score through the sigmoid into
            // [P(negative), P(positive)].
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1400}
1401
/// Walks a standard classifier tree from `root` and returns a borrowed slice
/// of class probabilities for one sample.
///
/// `bin_at` supplies the sample's bin for a given feature index. Missing-value
/// handling: when a node records a `missing_bin` and the observed bin equals
/// it, the node either returns its stored probabilities, redirects to a
/// dedicated missing child, or — if neither is set (binary nodes only) —
/// falls through to the ordinary threshold comparison.
#[inline(always)]
fn predict_standard_classifier_probabilities_row<F>(
    nodes: &[OptimizedClassifierNode],
    root: usize,
    bin_at: F,
) -> &[f64]
where
    F: Fn(usize) -> u16,
{
    let mut node_index = root;
    loop {
        match &nodes[node_index] {
            OptimizedClassifierNode::Leaf(value) => return value,
            OptimizedClassifierNode::Binary {
                feature_index,
                threshold_bin,
                children,
                missing_bin,
                missing_child,
                missing_probabilities,
            } => {
                let bin = bin_at(*feature_index);
                if missing_bin.is_some_and(|expected| expected == bin) {
                    if let Some(probabilities) = missing_probabilities {
                        return probabilities;
                    }
                    if let Some(child_index) = missing_child {
                        node_index = *child_index;
                        continue;
                    }
                    // Neither set: treat the missing bin like a regular bin.
                }
                // children[0] = left (bin <= threshold), children[1] = right.
                let go_right = usize::from(bin > *threshold_bin);
                node_index = children[go_right];
            }
            OptimizedClassifierNode::Multiway {
                feature_index,
                child_lookup,
                max_bin_index,
                missing_bin,
                missing_child,
                fallback_probabilities,
            } => {
                let bin_value = bin_at(*feature_index);
                if missing_bin.is_some_and(|expected| expected == bin_value) {
                    if let Some(child_index) = missing_child {
                        node_index = *child_index;
                        continue;
                    }
                    // Missing bin with no dedicated child: use the fallback.
                    return fallback_probabilities;
                }
                let bin = usize::from(bin_value);
                // Bins beyond the lookup table were never seen in training.
                if bin > *max_bin_index {
                    return fallback_probabilities;
                }
                // usize::MAX is the "no child for this bin" sentinel.
                let child_index = child_lookup[bin];
                if child_index == usize::MAX {
                    return fallback_probabilities;
                }
                node_index = child_index;
            }
        }
    }
}
1465
/// Traverses a jump-threaded binary classifier layout for one sample and
/// returns a borrowed probability slice.
///
/// Layout invariant: one child of each branch is stored directly after it
/// (reached by `node_index + 1`); the other is reached via `jump_index`.
/// `jump_if_greater` records which comparison outcome takes the jump edge.
#[inline(always)]
fn predict_binary_classifier_probabilities_row<F>(
    nodes: &[OptimizedBinaryClassifierNode],
    bin_at: F,
) -> &[f64]
where
    F: Fn(usize) -> u16,
{
    // The root is always at slot 0 in this layout.
    let mut node_index = 0usize;
    loop {
        match &nodes[node_index] {
            OptimizedBinaryClassifierNode::Leaf(value) => return value,
            OptimizedBinaryClassifierNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_probabilities,
            } => {
                let bin = bin_at(*feature_index);
                if missing_bin.is_some_and(|expected| expected == bin) {
                    if let Some(probabilities) = missing_probabilities {
                        return probabilities;
                    }
                    if let Some(jump_index) = missing_jump_index {
                        node_index = *jump_index;
                        continue;
                    }
                    // Neither set: route the missing bin like a regular bin.
                }
                let go_right = bin > *threshold_bin;
                node_index = if go_right == *jump_if_greater {
                    *jump_index
                } else {
                    node_index + 1
                };
            }
        }
    }
}
1507
/// Traverses a jump-threaded binary regressor layout for one sample and
/// returns the reached leaf value.
///
/// Same layout convention as the classifier variant: the fallthrough child
/// sits at `node_index + 1`, the other child at `jump_index`, and
/// `jump_if_greater` records which comparison outcome takes the jump edge.
#[inline(always)]
fn predict_binary_regressor_row<F>(nodes: &[OptimizedBinaryRegressorNode], bin_at: F) -> f64
where
    F: Fn(usize) -> u16,
{
    // Root lives at slot 0.
    let mut node_index = 0usize;
    loop {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => return *value,
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_value,
            } => {
                let bin = bin_at(*feature_index);
                if missing_bin.is_some_and(|expected| expected == bin) {
                    if let Some(value) = missing_value {
                        return *value;
                    }
                    if let Some(jump_index) = missing_jump_index {
                        node_index = *jump_index;
                        continue;
                    }
                    // Neither set: route the missing bin like a regular bin.
                }
                let go_right = bin > *threshold_bin;
                node_index = if go_right == *jump_if_greater {
                    *jump_index
                } else {
                    node_index + 1
                };
            }
        }
    }
}
1546
1547fn predict_binary_classifier_probabilities_column_major_matrix(
1548 nodes: &[OptimizedBinaryClassifierNode],
1549 matrix: &ColumnMajorBinnedMatrix,
1550 _executor: &InferenceExecutor,
1551) -> Vec<Vec<f64>> {
1552 if binary_classifier_nodes_require_rowwise_missing(nodes) {
1553 return (0..matrix.n_rows)
1554 .map(|row_index| {
1555 predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1556 matrix.column(feature_index).value_at(row_index)
1557 })
1558 .to_vec()
1559 })
1560 .collect();
1561 }
1562 (0..matrix.n_rows)
1563 .map(|row_index| {
1564 predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1565 matrix.column(feature_index).value_at(row_index)
1566 })
1567 .to_vec()
1568 })
1569 .collect()
1570}
1571
/// Predicts a value for every row of a column-major binned matrix with a
/// jump-threaded binary regressor.
///
/// When any node carries missing-value metadata, the chunked kernel below is
/// unsafe to use (it matches branches with `..` and ignores that metadata),
/// so rows are traversed one at a time instead.
fn predict_binary_regressor_column_major_matrix(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    executor: &InferenceExecutor,
) -> Vec<f64> {
    if binary_regressor_nodes_require_rowwise_missing(nodes) {
        return (0..matrix.n_rows)
            .map(|row_index| {
                predict_binary_regressor_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
            })
            .collect();
    }
    // Fast path: partition-based chunk kernel, parallelized by the executor.
    let mut outputs = vec![0.0; matrix.n_rows];
    executor.fill_chunks(
        &mut outputs,
        STANDARD_BATCH_INFERENCE_CHUNK_ROWS,
        |start_row, chunk| predict_binary_regressor_chunk(nodes, matrix, start_row, chunk),
    );
    outputs
}
1594
/// Fills `output` with predictions for the rows `start_row..start_row + len`
/// using a node-at-a-time (rather than row-at-a-time) traversal.
///
/// `row_indices` is a permutation of chunk-local row offsets. For each branch
/// node popped off the stack, its row range is partitioned in place into
/// rows that take the fallthrough edge and rows that take the jump edge; each
/// sub-range is then pushed with its child node. Leaves write their value to
/// every row in their range. Within a partition the row order is not stable,
/// which is fine because results land at `output[row_indices[position]]`.
///
/// Missing-value metadata on branches is deliberately ignored (`..` pattern);
/// the caller only uses this kernel when no node carries such metadata.
fn predict_binary_regressor_chunk(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) {
    let mut row_indices: Vec<usize> = (0..output.len()).collect();
    // (node index, range start, range end) — root covers the whole chunk.
    let mut stack = vec![(0usize, 0usize, output.len())];

    while let Some((node_index, start, end)) = stack.pop() {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => {
                for position in start..end {
                    output[row_indices[position]] = *value;
                }
            }
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
                ..
            } => {
                let fallthrough_index = node_index + 1;
                // Both edges lead to the same node: no partitioning needed.
                if *jump_index == fallthrough_index {
                    stack.push((fallthrough_index, start, end));
                    continue;
                }

                let column = matrix.column(*feature_index);
                let mut partition = start;
                let mut jump_start = end;
                match column {
                    // Specialized u8 path avoids the per-value widening in
                    // `value_at`; jump rows are swapped to the range's tail.
                    CompactBinnedColumn::U8(values) if *threshold_bin <= u16::from(u8::MAX) => {
                        let threshold = *threshold_bin as u8;
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = values[start_row + row_offset] > threshold;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                    _ => {
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = column.value_at(start_row + row_offset) > *threshold_bin;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                }

                // Push only non-empty sub-ranges.
                if jump_start < end {
                    stack.push((*jump_index, jump_start, end));
                }
                if start < jump_start {
                    stack.push((fallthrough_index, start, jump_start));
                }
            }
        }
    }
}
1667
1668fn binary_classifier_nodes_require_rowwise_missing(
1669 nodes: &[OptimizedBinaryClassifierNode],
1670) -> bool {
1671 nodes.iter().any(|node| match node {
1672 OptimizedBinaryClassifierNode::Leaf(_) => false,
1673 OptimizedBinaryClassifierNode::Branch {
1674 missing_bin,
1675 missing_jump_index,
1676 missing_probabilities,
1677 ..
1678 } => {
1679 missing_bin.is_some() || missing_jump_index.is_some() || missing_probabilities.is_some()
1680 }
1681 })
1682}
1683
1684fn binary_regressor_nodes_require_rowwise_missing(nodes: &[OptimizedBinaryRegressorNode]) -> bool {
1685 nodes.iter().any(|node| match node {
1686 OptimizedBinaryRegressorNode::Leaf(_) => false,
1687 OptimizedBinaryRegressorNode::Branch {
1688 missing_bin,
1689 missing_jump_index,
1690 missing_value,
1691 ..
1692 } => missing_bin.is_some() || missing_jump_index.is_some() || missing_value.is_some(),
1693 })
1694}
1695
/// Routes one sample through an oblivious tree and returns its leaf value.
///
/// Each level contributes one bit to the leaf index — 1 when the sample's bin
/// for that level's feature exceeds the level threshold — so with `d` levels
/// the index selects among the `2^d` entries of `leaf_values`.
#[inline(always)]
fn predict_oblivious_row<F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    bin_at: F,
) -> f64
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index]
}
1713
/// Routes one sample through an oblivious classifier and returns a borrowed
/// slice of class probabilities from the reached leaf.
///
/// Same bit-per-level indexing as [`predict_oblivious_row`].
#[inline(always)]
fn predict_oblivious_probabilities_row<'a, F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &'a [Vec<f64>],
    bin_at: F,
) -> &'a [f64]
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index].as_slice()
}
1731
/// Converts raw per-class sample counts into probabilities that sum to 1.
///
/// An all-zero (or empty) count vector yields all-zero probabilities rather
/// than dividing by zero.
fn normalized_probabilities_from_counts(class_counts: &[usize]) -> Vec<f64> {
    match class_counts.iter().sum::<usize>() {
        0 => class_counts.iter().map(|_| 0.0).collect(),
        total => {
            let denominator = total as f64;
            class_counts
                .iter()
                .map(|&count| count as f64 / denominator)
                .collect()
        }
    }
}
1743
/// Maps a probability row to the label of its most probable class.
///
/// Ties are broken toward the lowest class index; comparison uses
/// `f64::total_cmp`, so the scan is well-defined even for NaN entries.
///
/// # Panics
/// Panics when `probabilities` is empty.
fn class_label_from_probabilities(probabilities: &[f64], class_labels: &[f64]) -> f64 {
    let mut best: Option<(usize, f64)> = None;
    for (index, &probability) in probabilities.iter().enumerate() {
        // Replace only on strict improvement so the earliest index wins ties.
        let improves = match best {
            None => true,
            Some((_, best_probability)) => {
                probability.total_cmp(&best_probability) == std::cmp::Ordering::Greater
            }
        };
        if improves {
            best = Some((index, probability));
        }
    }
    let (best_index, _) = best.expect("classification probability row is non-empty");
    class_labels[best_index]
}
1757
/// Numerically stable logistic function: 1 / (1 + e^(-value)).
///
/// The exponent argument is always non-positive (`-|value|`), so `exp` never
/// overflows; the two branches are algebraically equal forms of the sigmoid
/// and together cover the whole real line.
#[inline(always)]
fn sigmoid(value: f64) -> f64 {
    let exp_neg_abs = (-value.abs()).exp();
    if value >= 0.0 {
        1.0 / (1.0 + exp_neg_abs)
    } else {
        exp_neg_abs / (1.0 + exp_neg_abs)
    }
}
1768
1769fn classifier_nodes_are_binary_only(nodes: &[tree::classifier::TreeNode]) -> bool {
1770 nodes.iter().all(|node| {
1771 matches!(
1772 node,
1773 tree::classifier::TreeNode::Leaf { .. }
1774 | tree::classifier::TreeNode::BinarySplit { .. }
1775 )
1776 })
1777}
1778
1779fn classifier_node_sample_count(nodes: &[tree::classifier::TreeNode], node_index: usize) -> usize {
1780 match &nodes[node_index] {
1781 tree::classifier::TreeNode::Leaf { sample_count, .. }
1782 | tree::classifier::TreeNode::BinarySplit { sample_count, .. }
1783 | tree::classifier::TreeNode::MultiwaySplit { sample_count, .. } => *sample_count,
1784 }
1785}
1786
1787fn build_binary_classifier_layout(
1788 nodes: &[tree::classifier::TreeNode],
1789 root: usize,
1790 _class_labels: &[f64],
1791 feature_index_map: &[usize],
1792 preprocessing: &[FeaturePreprocessing],
1793 missing_features: Option<&BTreeSet<usize>>,
1794) -> Vec<OptimizedBinaryClassifierNode> {
1795 let mut layout = Vec::with_capacity(nodes.len());
1796 append_binary_classifier_node(
1797 nodes,
1798 root,
1799 &mut layout,
1800 feature_index_map,
1801 preprocessing,
1802 missing_features,
1803 );
1804 layout
1805}
1806
/// Recursively emits the subtree rooted at `node_index` into `layout` and
/// returns the slot index the node was written to.
///
/// The child with more training samples becomes the "fallthrough" child and
/// is emitted immediately after its parent (the `debug_assert_eq!` checks
/// this adjacency), presumably so the more common path is the cheap
/// `node_index + 1` step at inference time; the other child is reached via
/// `jump_index` with `jump_if_greater` recording which comparison outcome
/// jumps. Missing-value routing is resolved here from `missing_direction`:
/// Left/Right map to the corresponding emitted slot, Node stores the split's
/// own normalized class distribution instead.
fn append_binary_classifier_node(
    nodes: &[tree::classifier::TreeNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryClassifierNode>,
    feature_index_map: &[usize],
    preprocessing: &[FeaturePreprocessing],
    missing_features: Option<&BTreeSet<usize>>,
) -> usize {
    // Reserve this node's slot with a placeholder; it is overwritten below
    // once the children (and therefore jump targets) are known.
    let current_index = layout.len();
    layout.push(OptimizedBinaryClassifierNode::Leaf(Vec::new()));

    match &nodes[node_index] {
        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
            layout[current_index] = OptimizedBinaryClassifierNode::Leaf(
                normalized_probabilities_from_counts(class_counts),
            );
        }
        tree::classifier::TreeNode::BinarySplit {
            feature_index,
            threshold_bin,
            missing_direction,
            left_child,
            right_child,
            class_counts,
            ..
        } => {
            // jump_if_greater encodes which side is the jump child: when the
            // left (<= threshold) child falls through, "greater" jumps.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = classifier_node_sample_count(nodes, *left_child);
                let right_count = classifier_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index = append_binary_classifier_node(
                nodes,
                fallthrough_child,
                layout,
                feature_index_map,
                preprocessing,
                missing_features,
            );
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // Degenerate split (both children identical): reuse the slot.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_classifier_node(
                    nodes,
                    jump_child,
                    layout,
                    feature_index_map,
                    preprocessing,
                    missing_features,
                )
            };

            let missing_bin =
                optimized_missing_bin(preprocessing, *feature_index, missing_features);
            let (missing_jump_index, missing_probabilities) =
                if missing_feature_enabled(*feature_index, missing_features) {
                    match missing_direction {
                        // Route missing values to whichever slot the left
                        // child ended up in.
                        tree::shared::MissingBranchDirection::Left => (
                            Some(if *left_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Right => (
                            Some(if *right_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        // Stop at this node: answer with its own distribution.
                        tree::shared::MissingBranchDirection::Node => (
                            None,
                            Some(normalized_probabilities_from_counts(class_counts)),
                        ),
                    }
                } else {
                    (None, None)
                };

            layout[current_index] = OptimizedBinaryClassifierNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_probabilities,
            };
        }
        tree::classifier::TreeNode::MultiwaySplit { .. } => {
            unreachable!("multiway nodes are filtered out before binary layout construction");
        }
    }

    current_index
}
1914
1915fn regressor_node_sample_count(
1916 nodes: &[tree::regressor::RegressionNode],
1917 node_index: usize,
1918) -> usize {
1919 match &nodes[node_index] {
1920 tree::regressor::RegressionNode::Leaf { sample_count, .. }
1921 | tree::regressor::RegressionNode::BinarySplit { sample_count, .. } => *sample_count,
1922 }
1923}
1924
1925fn build_binary_regressor_layout(
1926 nodes: &[tree::regressor::RegressionNode],
1927 root: usize,
1928 feature_index_map: &[usize],
1929 preprocessing: &[FeaturePreprocessing],
1930 missing_features: Option<&BTreeSet<usize>>,
1931) -> Vec<OptimizedBinaryRegressorNode> {
1932 let mut layout = Vec::with_capacity(nodes.len());
1933 append_binary_regressor_node(
1934 nodes,
1935 root,
1936 &mut layout,
1937 feature_index_map,
1938 preprocessing,
1939 missing_features,
1940 );
1941 layout
1942}
1943
/// Recursively emits the regression subtree rooted at `node_index` into
/// `layout` and returns the slot index the node was written to.
///
/// Mirrors `append_binary_classifier_node`: the child with more samples is
/// emitted right after its parent as the fallthrough (verified by the
/// `debug_assert_eq!`), the other is reached via `jump_index`, and
/// `jump_if_greater` records which comparison outcome jumps. For missing
/// values, Left/Right directions map to an emitted slot while Node stores the
/// split's own `missing_value` directly.
fn append_binary_regressor_node(
    nodes: &[tree::regressor::RegressionNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryRegressorNode>,
    feature_index_map: &[usize],
    preprocessing: &[FeaturePreprocessing],
    missing_features: Option<&BTreeSet<usize>>,
) -> usize {
    // Reserve this node's slot with a placeholder leaf; it is overwritten
    // once the children (and therefore jump targets) are known.
    let current_index = layout.len();
    layout.push(OptimizedBinaryRegressorNode::Leaf(0.0));

    match &nodes[node_index] {
        tree::regressor::RegressionNode::Leaf { value, .. } => {
            layout[current_index] = OptimizedBinaryRegressorNode::Leaf(*value);
        }
        tree::regressor::RegressionNode::BinarySplit {
            feature_index,
            threshold_bin,
            missing_direction,
            missing_value,
            left_child,
            right_child,
            ..
        } => {
            // When the left (<= threshold) child falls through, "greater"
            // takes the jump edge.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = regressor_node_sample_count(nodes, *left_child);
                let right_count = regressor_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index = append_binary_regressor_node(
                nodes,
                fallthrough_child,
                layout,
                feature_index_map,
                preprocessing,
                missing_features,
            );
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // Degenerate split (both children identical): reuse the slot.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_regressor_node(
                    nodes,
                    jump_child,
                    layout,
                    feature_index_map,
                    preprocessing,
                    missing_features,
                )
            };

            let missing_bin =
                optimized_missing_bin(preprocessing, *feature_index, missing_features);
            // NOTE: the tuple's `missing_value` shadows the pattern binding of
            // the same name above; the Node arm reads the tree node's value.
            let (missing_jump_index, missing_value) =
                if missing_feature_enabled(*feature_index, missing_features) {
                    match missing_direction {
                        tree::shared::MissingBranchDirection::Left => (
                            Some(if *left_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Right => (
                            Some(if *right_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Node => (None, Some(*missing_value)),
                    }
                } else {
                    (None, None)
                };

            layout[current_index] = OptimizedBinaryRegressorNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_value,
            };
        }
    }

    current_index
}
2043
2044fn predict_oblivious_column_major_matrix(
2045 feature_indices: &[usize],
2046 threshold_bins: &[u16],
2047 leaf_values: &[f64],
2048 matrix: &ColumnMajorBinnedMatrix,
2049 executor: &InferenceExecutor,
2050) -> Vec<f64> {
2051 let mut outputs = vec![0.0; matrix.n_rows];
2052 executor.fill_chunks(
2053 &mut outputs,
2054 PARALLEL_INFERENCE_CHUNK_ROWS,
2055 |start_row, chunk| {
2056 predict_oblivious_chunk(
2057 feature_indices,
2058 threshold_bins,
2059 leaf_values,
2060 matrix,
2061 start_row,
2062 chunk,
2063 )
2064 },
2065 );
2066 outputs
2067}
2068
2069fn predict_oblivious_probabilities_column_major_matrix(
2070 feature_indices: &[usize],
2071 threshold_bins: &[u16],
2072 leaf_values: &[Vec<f64>],
2073 matrix: &ColumnMajorBinnedMatrix,
2074 _executor: &InferenceExecutor,
2075) -> Vec<Vec<f64>> {
2076 (0..matrix.n_rows)
2077 .map(|row_index| {
2078 predict_oblivious_probabilities_row(
2079 feature_indices,
2080 threshold_bins,
2081 leaf_values,
2082 |feature_index| matrix.column(feature_index).value_at(row_index),
2083 )
2084 .to_vec()
2085 })
2086 .collect()
2087}
2088
2089fn predict_oblivious_chunk(
2090 feature_indices: &[usize],
2091 threshold_bins: &[u16],
2092 leaf_values: &[f64],
2093 matrix: &ColumnMajorBinnedMatrix,
2094 start_row: usize,
2095 output: &mut [f64],
2096) {
2097 let processed = simd_predict_oblivious_chunk(
2098 feature_indices,
2099 threshold_bins,
2100 leaf_values,
2101 matrix,
2102 start_row,
2103 output,
2104 );
2105
2106 for (offset, out) in output.iter_mut().enumerate().skip(processed) {
2107 let row_index = start_row + offset;
2108 *out = predict_oblivious_row(
2109 feature_indices,
2110 threshold_bins,
2111 leaf_values,
2112 |feature_index| matrix.column(feature_index).value_at(row_index),
2113 );
2114 }
2115}
2116
/// Vectorized oblivious-tree prediction over groups of `OBLIVIOUS_SIMD_LANES`
/// rows; returns how many leading rows of `output` were filled (the caller
/// handles the remainder scalar-wise).
///
/// Per tree level, the eight rows' bins are loaded into `u32x8` lanes (via a
/// u8 fast path when the column stores u8, otherwise via u16), compared
/// against the level threshold, and the resulting 0/1 mask bit is shifted
/// into each lane's leaf index — the SIMD analogue of
/// `predict_oblivious_row`'s bit-per-level loop.
///
/// NOTE(review): the `[u8; OBLIVIOUS_SIMD_LANES]` conversions and the eight
/// explicit lane reads require `OBLIVIOUS_SIMD_LANES == 8` to match `u32x8`;
/// the constant is defined outside this view — confirm.
fn simd_predict_oblivious_chunk(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) -> usize {
    let mut processed = 0usize;
    let ones = u32x8::splat(1);

    // Only full lane groups are processed here.
    while processed + OBLIVIOUS_SIMD_LANES <= output.len() {
        let base_row = start_row + processed;
        let mut leaf_indices = u32x8::splat(0);

        for (&feature_index, &threshold_bin) in feature_indices.iter().zip(threshold_bins) {
            let column = matrix.column(feature_index);
            // Prefer the u8 storage path; fall back to u16 storage.
            let bins = if let Some(lanes) = column.slice_u8(base_row, OBLIVIOUS_SIMD_LANES) {
                let lanes: [u8; OBLIVIOUS_SIMD_LANES] = lanes
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::new([
                    u32::from(lanes[0]),
                    u32::from(lanes[1]),
                    u32::from(lanes[2]),
                    u32::from(lanes[3]),
                    u32::from(lanes[4]),
                    u32::from(lanes[5]),
                    u32::from(lanes[6]),
                    u32::from(lanes[7]),
                ])
            } else {
                let lanes: [u16; OBLIVIOUS_SIMD_LANES] = column
                    .slice_u16(base_row, OBLIVIOUS_SIMD_LANES)
                    .expect("column is u16 when not u8")
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::from(u16x8::new(lanes))
            };
            let threshold = u32x8::splat(u32::from(threshold_bin));
            // cmp_gt yields all-ones per matching lane; mask to a single bit.
            let bit = bins.cmp_gt(threshold) & ones;
            leaf_indices = (leaf_indices << 1) | bit;
        }

        // Gather the leaf value for each lane's accumulated index.
        let lane_indices = leaf_indices.to_array();
        for lane in 0..OBLIVIOUS_SIMD_LANES {
            output[processed + lane] =
                leaf_values[usize::try_from(lane_indices[lane]).expect("leaf index fits usize")];
        }
        processed += OBLIVIOUS_SIMD_LANES;
    }

    processed
}
2171
/// Trains a model on `train_set` according to `config`.
///
/// Thin public entry point that delegates to the `training` module's
/// pipeline; returns the trained `Model` or the `TrainError` it produced.
pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
    training::train(train_set, config)
}
2175
2176#[cfg(test)]
2177mod tests;