1use forestfire_data::{
19 BinnedColumnKind, MAX_NUMERIC_BINS, NumericBins, TableAccess, numeric_bin_boundaries,
20};
21#[cfg(feature = "polars")]
22use polars::prelude::{Column, DataFrame, DataType, IdxSize, LazyFrame};
23use rayon::ThreadPoolBuilder;
24use rayon::prelude::*;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::collections::BTreeMap;
28use std::error::Error;
29use std::fmt::{Display, Formatter};
30use std::sync::Arc;
31use wide::{u16x8, u32x8};
32
33mod boosting;
34mod bootstrap;
35mod compiled_artifact;
36mod forest;
37mod inference_input;
38mod introspection;
39pub mod ir;
40mod model_api;
41mod optimized_runtime;
42mod runtime_planning;
43mod sampling;
44mod training;
45pub mod tree;
46
47pub use boosting::BoostingError;
48pub use boosting::GradientBoostedTrees;
49pub use compiled_artifact::CompiledArtifactError;
50pub use forest::RandomForest;
51pub use introspection::IntrospectionError;
52pub use introspection::PredictionHistogramEntry;
53pub use introspection::PredictionValueStats;
54pub use introspection::TreeStructureSummary;
55pub use ir::IrError;
56pub use ir::ModelPackageIr;
57pub use model_api::OptimizedModel;
58pub use tree::classifier::DecisionTreeAlgorithm;
59pub use tree::classifier::DecisionTreeClassifier;
60pub use tree::classifier::DecisionTreeError;
61pub use tree::classifier::DecisionTreeOptions;
62pub use tree::classifier::train_c45;
63pub use tree::classifier::train_cart;
64pub use tree::classifier::train_id3;
65pub use tree::classifier::train_oblivious;
66pub use tree::classifier::train_randomized;
67pub use tree::regressor::DecisionTreeRegressor;
68pub use tree::regressor::RegressionTreeAlgorithm;
69pub use tree::regressor::RegressionTreeError;
70pub use tree::regressor::RegressionTreeOptions;
71pub use tree::regressor::train_cart_regressor;
72pub use tree::regressor::train_oblivious_regressor;
73pub use tree::regressor::train_randomized_regressor;
#[cfg(feature = "polars")]
// Row-batch size for chunked prediction over a polars LazyFrame.
// NOTE(review): the consumer of this constant is not visible in this chunk —
// confirm it drives the LazyFrame predict path before changing the value.
const LAZYFRAME_PREDICT_BATCH_ROWS: usize = 10_000;
76pub(crate) use inference_input::ColumnMajorBinnedMatrix;
77pub(crate) use inference_input::CompactBinnedColumn;
78pub(crate) use inference_input::InferenceTable;
79pub(crate) use inference_input::ProjectedTableView;
80#[cfg(feature = "polars")]
81pub(crate) use inference_input::polars_named_columns;
82pub(crate) use introspection::prediction_value_stats;
83pub(crate) use introspection::tree_structure_summary;
84pub(crate) use optimized_runtime::InferenceExecutor;
85pub(crate) use optimized_runtime::OBLIVIOUS_SIMD_LANES;
86pub(crate) use optimized_runtime::OptimizedBinaryClassifierNode;
87pub(crate) use optimized_runtime::OptimizedBinaryRegressorNode;
88pub(crate) use optimized_runtime::OptimizedClassifierNode;
89pub(crate) use optimized_runtime::OptimizedRuntime;
90pub(crate) use optimized_runtime::PARALLEL_INFERENCE_CHUNK_ROWS;
91pub(crate) use optimized_runtime::STANDARD_BATCH_INFERENCE_CHUNK_ROWS;
92pub(crate) use optimized_runtime::resolve_inference_thread_count;
93pub(crate) use runtime_planning::build_feature_index_map;
94pub(crate) use runtime_planning::build_feature_projection;
95pub(crate) use runtime_planning::model_used_feature_indices;
96pub(crate) use runtime_planning::ordered_ensemble_indices;
97pub(crate) use runtime_planning::remap_feature_index;
98
/// Estimator family selected by [`TrainConfig::algorithm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainAlgorithm {
    /// A single decision tree (classifier or regressor, depending on `Task`).
    Dt,
    /// A random forest ensemble.
    Rf,
    /// Gradient-boosted trees.
    Gbm,
}
108
/// Split-quality criterion used while growing trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Criterion {
    /// Let the trainer pick a criterion appropriate for the task/tree type.
    Auto,
    /// Gini impurity (classification splits).
    Gini,
    /// Information gain / entropy (classification splits).
    Entropy,
    /// Mean-based scoring — presumably for regression splits; confirm in the
    /// training module.
    Mean,
    /// Median-based scoring — presumably for regression splits; confirm in the
    /// training module.
    Median,
    /// Second-order scoring — presumably gradient/hessian based (boosting);
    /// confirm in the boosting module.
    SecondOrder,
}
124
/// Whether the model predicts a continuous value or a class label.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    /// Predict a continuous target.
    Regression,
    /// Predict a class label.
    Classification,
}
132
/// Tree-growing strategy; mirrors the `train_*` constructors re-exported from
/// the `tree` module (`train_id3`, `train_c45`, `train_cart`,
/// `train_randomized`, `train_oblivious`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeType {
    /// ID3-style trees.
    Id3,
    /// C4.5-style trees.
    C45,
    /// CART-style binary trees.
    Cart,
    /// Randomized split selection.
    Randomized,
    /// Oblivious trees (one split per level).
    Oblivious,
}
146
/// How many candidate features each split may consider.
/// See [`MaxFeatures::resolve`] for the exact arithmetic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxFeatures {
    /// Task-dependent default: `Sqrt` for classification, `Third` for regression.
    Auto,
    /// Consider every feature.
    All,
    /// `floor(sqrt(feature_count))`, at least 1.
    Sqrt,
    /// `feature_count / 3`, at least 1.
    Third,
    /// An explicit count, clamped into `[1, feature_count]` when resolved.
    Count(usize),
}
160
161impl MaxFeatures {
162 pub fn resolve(self, task: Task, feature_count: usize) -> usize {
163 match self {
164 MaxFeatures::Auto => match task {
165 Task::Classification => MaxFeatures::Sqrt.resolve(task, feature_count),
166 Task::Regression => MaxFeatures::Third.resolve(task, feature_count),
167 },
168 MaxFeatures::All => feature_count.max(1),
169 MaxFeatures::Sqrt => ((feature_count as f64).sqrt().floor() as usize).max(1),
170 MaxFeatures::Third => (feature_count / 3).max(1),
171 MaxFeatures::Count(count) => count.min(feature_count).max(1),
172 }
173 }
174}
175
/// Kind of an input feature column, serialized in snake_case
/// (`"numeric"` / `"binary"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputFeatureKind {
    /// A binned numeric feature.
    Numeric,
    /// A 0/1 feature (see `PredictError::InvalidBinaryValue`).
    Binary,
}
184
/// One `(bin, upper_bound)` pair produced by `numeric_bin_boundaries` and
/// captured by [`capture_feature_preprocessing`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinBoundary {
    /// Bin index.
    pub bin: u16,
    /// Boundary value for the bin — inclusive/exclusive semantics are defined
    /// by `numeric_bin_boundaries` in `forestfire_data`; confirm there.
    pub upper_bound: f64,
}
192
/// Per-feature preprocessing captured at training time so inference can
/// reproduce the same binning. Serialized as an externally tagged enum with a
/// `"kind"` discriminator in snake_case.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeaturePreprocessing {
    /// Numeric feature: keep the bin boundaries used during training.
    Numeric {
        bin_boundaries: Vec<NumericBinBoundary>,
    },
    /// Binary feature: no boundaries needed.
    Binary,
}
203
204#[derive(Debug, Clone, Copy)]
210pub struct TrainConfig {
211 pub algorithm: TrainAlgorithm,
213 pub task: Task,
215 pub tree_type: TreeType,
217 pub criterion: Criterion,
219 pub max_depth: Option<usize>,
221 pub min_samples_split: Option<usize>,
223 pub min_samples_leaf: Option<usize>,
225 pub physical_cores: Option<usize>,
227 pub n_trees: Option<usize>,
229 pub max_features: MaxFeatures,
231 pub seed: Option<u64>,
233 pub compute_oob: bool,
235 pub learning_rate: Option<f64>,
237 pub bootstrap: bool,
239 pub top_gradient_fraction: Option<f64>,
241 pub other_gradient_fraction: Option<f64>,
243}
244
245impl Default for TrainConfig {
246 fn default() -> Self {
247 Self {
248 algorithm: TrainAlgorithm::Dt,
249 task: Task::Regression,
250 tree_type: TreeType::Cart,
251 criterion: Criterion::Auto,
252 max_depth: None,
253 min_samples_split: None,
254 min_samples_leaf: None,
255 physical_cores: None,
256 n_trees: None,
257 max_features: MaxFeatures::Auto,
258 seed: None,
259 compute_oob: false,
260 learning_rate: None,
261 bootstrap: false,
262 top_gradient_fraction: None,
263 other_gradient_fraction: None,
264 }
265 }
266}
267
/// A trained model of any supported estimator type; produced by training and
/// consumed by the optimized-runtime lowering (`OptimizedRuntime::from_model`).
#[derive(Debug, Clone)]
pub enum Model {
    DecisionTreeClassifier(DecisionTreeClassifier),
    DecisionTreeRegressor(DecisionTreeRegressor),
    RandomForest(RandomForest),
    GradientBoostedTrees(GradientBoostedTrees),
}
280
/// Errors surfaced by training; wraps the per-estimator error types and adds
/// configuration-validation failures.
#[derive(Debug)]
pub enum TrainError {
    /// Propagated from decision-tree classifier training.
    DecisionTree(DecisionTreeError),
    /// Propagated from regression-tree training.
    RegressionTree(RegressionTreeError),
    /// Propagated from gradient-boosting training.
    Boosting(BoostingError),
    /// `physical_cores` exceeded the detected core count.
    InvalidPhysicalCoreCount {
        requested: usize,
        available: usize,
    },
    /// Rayon thread-pool construction failed; holds the underlying message.
    ThreadPoolBuildFailed(String),
    /// The task/tree-type/criterion combination has no training path.
    UnsupportedConfiguration {
        task: Task,
        tree_type: TreeType,
        criterion: Criterion,
    },
    /// `max_depth` was set to 0.
    InvalidMaxDepth(usize),
    /// `min_samples_split` was set to 0.
    InvalidMinSamplesSplit(usize),
    /// `min_samples_leaf` was set to 0.
    InvalidMinSamplesLeaf(usize),
    /// `n_trees` was set to 0 for an ensemble.
    InvalidTreeCount(usize),
    /// `MaxFeatures::Count` was given 0.
    InvalidMaxFeatures(usize),
}
302
303impl Display for TrainError {
304 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
305 match self {
306 TrainError::DecisionTree(err) => err.fmt(f),
307 TrainError::RegressionTree(err) => err.fmt(f),
308 TrainError::Boosting(err) => err.fmt(f),
309 TrainError::InvalidPhysicalCoreCount {
310 requested,
311 available,
312 } => write!(
313 f,
314 "Requested {} physical cores, but the available physical core count is {}.",
315 requested, available
316 ),
317 TrainError::ThreadPoolBuildFailed(message) => {
318 write!(f, "Failed to build training thread pool: {}.", message)
319 }
320 TrainError::UnsupportedConfiguration {
321 task,
322 tree_type,
323 criterion,
324 } => write!(
325 f,
326 "Unsupported training configuration: task={:?}, tree_type={:?}, criterion={:?}.",
327 task, tree_type, criterion
328 ),
329 TrainError::InvalidMaxDepth(value) => {
330 write!(f, "max_depth must be at least 1. Received {}.", value)
331 }
332 TrainError::InvalidMinSamplesSplit(value) => {
333 write!(
334 f,
335 "min_samples_split must be at least 1. Received {}.",
336 value
337 )
338 }
339 TrainError::InvalidMinSamplesLeaf(value) => {
340 write!(
341 f,
342 "min_samples_leaf must be at least 1. Received {}.",
343 value
344 )
345 }
346 TrainError::InvalidTreeCount(n_trees) => {
347 write!(
348 f,
349 "Random forest requires at least one tree. Received {}.",
350 n_trees
351 )
352 }
353 TrainError::InvalidMaxFeatures(count) => {
354 write!(
355 f,
356 "max_features must be at least 1 when provided as an integer. Received {}.",
357 count
358 )
359 }
360 }
361 }
362}
363
364impl Error for TrainError {}
365
/// Errors surfaced while validating or running inference input.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictError {
    /// `predict_proba` was called on a regression runtime.
    ProbabilityPredictionRequiresClassification,
    /// A row-major input row had the wrong number of columns.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// Input feature count does not match the model's.
    FeatureCountMismatch {
        expected: usize,
        actual: usize,
    },
    /// A named column's length disagrees with the other columns.
    ColumnLengthMismatch {
        feature: String,
        expected: usize,
        actual: usize,
    },
    /// A feature the model requires is absent from the input.
    MissingFeature(String),
    /// The input contains a feature the model does not know.
    UnexpectedFeature(String),
    /// A binary feature held a value other than 0/1.
    InvalidBinaryValue {
        feature_index: usize,
        row_index: usize,
        value: f64,
    },
    /// A null was found where a value is required.
    NullValue {
        feature: String,
        row_index: usize,
    },
    /// A column's dtype cannot be consumed for inference.
    UnsupportedFeatureType {
        feature: String,
        dtype: String,
    },
    /// Wrapped Polars error message (polars feature; see the `From` impl below).
    Polars(String),
}
400
401impl Display for PredictError {
402 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
403 match self {
404 PredictError::ProbabilityPredictionRequiresClassification => write!(
405 f,
406 "predict_proba is only available for classification models."
407 ),
408 PredictError::RaggedRows {
409 row,
410 expected,
411 actual,
412 } => write!(
413 f,
414 "Ragged inference row at index {}: expected {} columns, found {}.",
415 row, expected, actual
416 ),
417 PredictError::FeatureCountMismatch { expected, actual } => write!(
418 f,
419 "Inference input has {} features, but the model expects {}.",
420 actual, expected
421 ),
422 PredictError::ColumnLengthMismatch {
423 feature,
424 expected,
425 actual,
426 } => write!(
427 f,
428 "Feature '{}' has {} values, expected {}.",
429 feature, actual, expected
430 ),
431 PredictError::MissingFeature(feature) => {
432 write!(f, "Missing required feature '{}'.", feature)
433 }
434 PredictError::UnexpectedFeature(feature) => {
435 write!(f, "Unexpected feature '{}'.", feature)
436 }
437 PredictError::InvalidBinaryValue {
438 feature_index,
439 row_index,
440 value,
441 } => write!(
442 f,
443 "Feature {} at row {} must be binary for inference, found {}.",
444 feature_index, row_index, value
445 ),
446 PredictError::NullValue { feature, row_index } => write!(
447 f,
448 "Feature '{}' contains a null value at row {}.",
449 feature, row_index
450 ),
451 PredictError::UnsupportedFeatureType { feature, dtype } => write!(
452 f,
453 "Feature '{}' has unsupported dtype '{}'.",
454 feature, dtype
455 ),
456 PredictError::Polars(message) => write!(f, "Polars inference failed: {}.", message),
457 }
458 }
459}
460
461impl Error for PredictError {}
462
#[cfg(feature = "polars")]
impl From<polars::error::PolarsError> for PredictError {
    /// Flattens a Polars error to its message string so `PredictError` can
    /// stay `Clone + PartialEq` and independent of the optional dependency.
    fn from(value: polars::error::PolarsError) -> Self {
        PredictError::Polars(value.to_string())
    }
}
469
/// Errors surfaced when building an optimized inference runtime.
#[derive(Debug)]
pub enum OptimizeError {
    /// Requested more physical cores than are available.
    InvalidPhysicalCoreCount { requested: usize, available: usize },
    /// Inference thread-pool construction failed; holds the underlying message.
    ThreadPoolBuildFailed(String),
    /// The given model type has no optimized runtime.
    UnsupportedModelType(&'static str),
}
476
477impl Display for OptimizeError {
478 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
479 match self {
480 OptimizeError::InvalidPhysicalCoreCount {
481 requested,
482 available,
483 } => write!(
484 f,
485 "Requested {} physical cores, but the available physical core count is {}.",
486 requested, available
487 ),
488 OptimizeError::ThreadPoolBuildFailed(message) => {
489 write!(f, "Failed to build inference thread pool: {}.", message)
490 }
491 OptimizeError::UnsupportedModelType(model_type) => {
492 write!(
493 f,
494 "Optimized inference is not supported for model type '{}'.",
495 model_type
496 )
497 }
498 }
499 }
500}
501
502impl Error for OptimizeError {}
503
/// Crate-internal wrapper around a worker-thread count; parallel code paths
/// are taken only when [`Parallelism::enabled`] reports more than one thread.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Parallelism {
    thread_count: usize,
}

impl Parallelism {
    /// Single-threaded execution (parallelism disabled).
    pub(crate) fn sequential() -> Self {
        Parallelism { thread_count: 1 }
    }

    /// Test-only constructor; a requested count of 0 is bumped to 1.
    #[cfg(test)]
    pub(crate) fn with_threads(thread_count: usize) -> Self {
        let thread_count = thread_count.max(1);
        Parallelism { thread_count }
    }

    /// True when more than one worker thread is configured.
    pub(crate) fn enabled(self) -> bool {
        1 < self.thread_count
    }
}
525
526pub(crate) fn capture_feature_preprocessing(table: &dyn TableAccess) -> Vec<FeaturePreprocessing> {
527 (0..table.n_features())
528 .map(|feature_index| {
529 if table.is_binary_feature(feature_index) {
530 FeaturePreprocessing::Binary
531 } else {
532 let values = (0..table.n_rows())
533 .map(|row_index| table.feature_value(feature_index, row_index))
534 .collect::<Vec<_>>();
535 FeaturePreprocessing::Numeric {
536 bin_boundaries: numeric_bin_boundaries(
537 &values,
538 NumericBins::Fixed(table.numeric_bin_cap()),
539 )
540 .into_iter()
541 .map(|(bin, upper_bound)| NumericBinBoundary { bin, upper_bound })
542 .collect(),
543 }
544 }
545 })
546 .collect()
547}
548
// Lowering of trained models into the optimized inference runtime, plus the
// row-wise and column-major batch prediction entry points.
impl OptimizedRuntime {
    /// Runtimes that have a dedicated column-major batch kernel.
    /// `StandardClassifier` is intentionally absent: it is always evaluated
    /// row by row (see `predict_proba_column_major_matrix`).
    fn supports_batch_matrix(&self) -> bool {
        matches!(
            self,
            OptimizedRuntime::BinaryClassifier { .. }
                | OptimizedRuntime::BinaryRegressor { .. }
                | OptimizedRuntime::ObliviousClassifier { .. }
                | OptimizedRuntime::ObliviousRegressor { .. }
                | OptimizedRuntime::ForestClassifier { .. }
                | OptimizedRuntime::ForestRegressor { .. }
                | OptimizedRuntime::BoostedBinaryClassifier { .. }
                | OptimizedRuntime::BoostedRegressor { .. }
        )
    }

    /// The batch-matrix path is only taken for multi-row inputs; a single row
    /// goes through the per-row predictors directly.
    fn should_use_batch_matrix(&self, n_rows: usize) -> bool {
        n_rows > 1 && self.supports_batch_matrix()
    }

    /// Recursively lowers a trained `Model` into the runtime representation,
    /// remapping every feature index through `feature_index_map`. Ensemble
    /// members are emitted in the order chosen by `ordered_ensemble_indices`;
    /// boosted tree weights are reordered in lockstep with their trees.
    fn from_model(model: &Model, feature_index_map: &[usize]) -> Self {
        match model {
            Model::DecisionTreeClassifier(classifier) => {
                Self::from_classifier(classifier, feature_index_map)
            }
            Model::DecisionTreeRegressor(regressor) => {
                Self::from_regressor(regressor, feature_index_map)
            }
            Model::RandomForest(forest) => match forest.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestClassifier {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(&forest.trees()[tree_index], feature_index_map)
                            })
                            .collect(),
                        class_labels: forest
                            .class_labels()
                            .expect("classification forest stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestRegressor {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(&forest.trees()[tree_index], feature_index_map)
                            })
                            .collect(),
                    }
                }
            },
            Model::GradientBoostedTrees(model) => match model.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedBinaryClassifier {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(&model.trees()[*tree_index], feature_index_map)
                            })
                            .collect(),
                        // Weights must follow the same reordering as the trees.
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                        class_labels: model
                            .class_labels()
                            .expect("classification boosting stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedRegressor {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(&model.trees()[*tree_index], feature_index_map)
                            })
                            .collect(),
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                    }
                }
            },
        }
    }

    /// Lowers a single decision-tree classifier. Standard trees made of only
    /// binary splits get the flattened `BinaryClassifier` layout; other
    /// standard trees get a node-per-node translation; oblivious trees become
    /// parallel per-level arrays.
    fn from_classifier(classifier: &DecisionTreeClassifier, feature_index_map: &[usize]) -> Self {
        match classifier.structure() {
            tree::classifier::TreeStructure::Standard { nodes, root } => {
                if classifier_nodes_are_binary_only(nodes) {
                    return Self::BinaryClassifier {
                        nodes: build_binary_classifier_layout(
                            nodes,
                            *root,
                            classifier.class_labels(),
                            feature_index_map,
                        ),
                        class_labels: classifier.class_labels().to_vec(),
                    };
                }

                let optimized_nodes = nodes
                    .iter()
                    .map(|node| match node {
                        // Leaves store normalized class probabilities instead
                        // of raw counts.
                        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
                            OptimizedClassifierNode::Leaf(normalized_probabilities_from_counts(
                                class_counts,
                            ))
                        }
                        tree::classifier::TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => OptimizedClassifierNode::Binary {
                            feature_index: remap_feature_index(*feature_index, feature_index_map),
                            threshold_bin: *threshold_bin,
                            children: [*left_child, *right_child],
                        },
                        tree::classifier::TreeNode::MultiwaySplit {
                            feature_index,
                            class_counts,
                            branches,
                            ..
                        } => {
                            // Build a dense bin -> child lookup table;
                            // usize::MAX marks bins with no branch, which fall
                            // back to this node's own probabilities at
                            // prediction time.
                            let max_bin_index = branches
                                .iter()
                                .map(|(bin, _)| usize::from(*bin))
                                .max()
                                .unwrap_or(0);
                            let mut child_lookup = vec![usize::MAX; max_bin_index + 1];
                            for (bin, child_index) in branches {
                                child_lookup[usize::from(*bin)] = *child_index;
                            }
                            OptimizedClassifierNode::Multiway {
                                feature_index: remap_feature_index(
                                    *feature_index,
                                    feature_index_map,
                                ),
                                child_lookup,
                                max_bin_index,
                                fallback_probabilities: normalized_probabilities_from_counts(
                                    class_counts,
                                ),
                            }
                        }
                    })
                    .collect();

                Self::StandardClassifier {
                    nodes: optimized_nodes,
                    root: *root,
                    class_labels: classifier.class_labels().to_vec(),
                }
            }
            tree::classifier::TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => Self::ObliviousClassifier {
                feature_indices: splits
                    .iter()
                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
                    .collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_class_counts
                    .iter()
                    .map(|class_counts| normalized_probabilities_from_counts(class_counts))
                    .collect(),
                class_labels: classifier.class_labels().to_vec(),
            },
        }
    }

    /// Lowers a single regression tree: standard trees get the flattened
    /// binary layout, oblivious trees become parallel per-level arrays.
    fn from_regressor(regressor: &DecisionTreeRegressor, feature_index_map: &[usize]) -> Self {
        match regressor.structure() {
            tree::regressor::RegressionTreeStructure::Standard { nodes, root } => {
                Self::BinaryRegressor {
                    nodes: build_binary_regressor_layout(nodes, *root, feature_index_map),
                }
            }
            tree::regressor::RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                ..
            } => Self::ObliviousRegressor {
                feature_indices: splits
                    .iter()
                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
                    .collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_values.clone(),
            },
        }
    }

    /// Predicts a single value for one row of `table`. Classifier runtimes
    /// compute probabilities and convert them to a label via
    /// `class_label_from_probabilities`; forest regressors average their
    /// trees; boosted regressors add the weighted tree sum to the base score.
    #[inline(always)]
    fn predict_table_row(&self, table: &dyn TableAccess, row_index: usize) -> f64 {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                let probabilities = self
                    .predict_proba_table_row(table, row_index)
                    .expect("classifier runtime supports probability prediction");
                class_label_from_probabilities(&probabilities, self.class_labels())
            }
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            ),
            OptimizedRuntime::ForestRegressor { trees } => {
                // Unweighted average over member trees.
                trees
                    .iter()
                    .map(|tree| tree.predict_table_row(table, row_index))
                    .sum::<f64>()
                    / trees.len() as f64
            }
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>()
            }
        }
    }

    /// Predicts class probabilities for one row of `table`. Regressor
    /// runtimes return `ProbabilityPredictionRequiresClassification`.
    /// Forest probabilities are the per-class average over member trees;
    /// boosted binary classifiers map the weighted raw score through a
    /// sigmoid to `[P(neg), P(pos)]`.
    #[inline(always)]
    fn predict_proba_table_row(
        &self,
        table: &dyn TableAccess,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            )
            .to_vec()),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                // Accumulate into the first tree's probabilities, then divide.
                let mut totals = trees[0].predict_proba_table_row(table, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_table_row(table, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }

    /// Predicts class probabilities for every row of `table`, transposing to
    /// a column-major binned matrix when the batch path applies (multi-row
    /// input on a batch-capable runtime), otherwise looping per row.
    fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                if self.should_use_batch_matrix(table.n_rows()) {
                    let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
                    self.predict_proba_column_major_matrix(&matrix, executor)
                } else {
                    (0..table.n_rows())
                        .map(|row_index| self.predict_proba_table_row(table, row_index))
                        .collect()
                }
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }

    /// Batch value prediction over a column-major binned matrix. Classifier
    /// runtimes derive labels from batch probabilities; regressor runtimes
    /// dispatch to their dedicated batch kernels or combine member trees.
    fn predict_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Vec<f64> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => self
                .predict_proba_column_major_matrix(matrix, executor)
                .expect("classifier runtime supports probability prediction")
                .into_iter()
                .map(|row| class_label_from_probabilities(&row, self.class_labels()))
                .collect(),
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_column_major_matrix(nodes, matrix, executor)
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            ),
            OptimizedRuntime::ForestRegressor { trees } => {
                // Per-row average of member-tree batch predictions.
                let mut totals = trees[0].predict_column_major_matrix(matrix, executor);
                for tree in &trees[1..] {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for total in &mut totals {
                    *total /= tree_count;
                }
                totals
            }
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                // Start every row at the base score, then add weighted trees.
                let mut totals = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += weight * value;
                    }
                }
                totals
            }
        }
    }

    /// Batch probability prediction over a column-major binned matrix.
    /// `StandardClassifier` has no batch kernel and is evaluated row by row;
    /// regressor runtimes return an error.
    fn predict_proba_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => {
                Ok(predict_binary_classifier_probabilities_column_major_matrix(
                    nodes, matrix, executor,
                ))
            }
            OptimizedRuntime::StandardClassifier { .. } => Ok((0..matrix.n_rows)
                .map(|row_index| {
                    self.predict_proba_binned_row_from_columns(matrix, row_index)
                        .expect("classifier runtime supports probability prediction")
                })
                .collect()),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            )),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                // Element-wise accumulation across trees, then divide.
                let mut totals = trees[0].predict_proba_column_major_matrix(matrix, executor)?;
                for tree in &trees[1..] {
                    let rows = tree.predict_proba_column_major_matrix(matrix, executor)?;
                    for (row_totals, row_values) in totals.iter_mut().zip(rows) {
                        for (total, value) in row_totals.iter_mut().zip(row_values) {
                            *total += value;
                        }
                    }
                }
                let tree_count = trees.len() as f64;
                for row in &mut totals {
                    for value in row {
                        *value /= tree_count;
                    }
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let mut raw_scores = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (raw_score, value) in raw_scores.iter_mut().zip(values) {
                        *raw_score += weight * value;
                    }
                }
                // Sigmoid maps each raw score to [P(negative), P(positive)].
                Ok(raw_scores
                    .into_iter()
                    .map(|raw_score| {
                        let positive = sigmoid(raw_score);
                        vec![1.0 - positive, positive]
                    })
                    .collect())
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }

    /// Class labels for classifier runtimes; empty slice for regressors.
    fn class_labels(&self) -> &[f64] {
        match self {
            OptimizedRuntime::BinaryClassifier { class_labels, .. }
            | OptimizedRuntime::StandardClassifier { class_labels, .. }
            | OptimizedRuntime::ObliviousClassifier { class_labels, .. }
            | OptimizedRuntime::ForestClassifier { class_labels, .. }
            | OptimizedRuntime::BoostedBinaryClassifier { class_labels, .. } => class_labels,
            _ => &[],
        }
    }

    /// Predicts a single value for one row of a column-major binned matrix.
    #[inline(always)]
    fn predict_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> f64 {
        match self {
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            ),
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>()
            }
            // NOTE(review): this fallback predicts the ENTIRE matrix with a
            // fresh single-threaded executor and then indexes one row — fine
            // only if this arm is never hit in a hot per-row loop; confirm
            // which variants actually reach it from the call sites.
            _ => self.predict_column_major_matrix(
                matrix,
                &InferenceExecutor::new(1).expect("inference executor"),
            )[row_index],
        }
    }

    /// Predicts class probabilities for one row of a column-major binned
    /// matrix; mirrors `predict_proba_table_row` but reads bins from matrix
    /// columns instead of a `TableAccess`.
    #[inline(always)]
    fn predict_proba_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            )
            .to_vec()),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals =
                    trees[0].predict_proba_binned_row_from_columns(matrix, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_binned_row_from_columns(matrix, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
}
1166
1167#[inline(always)]
1168fn predict_standard_classifier_probabilities_row<F>(
1169 nodes: &[OptimizedClassifierNode],
1170 root: usize,
1171 bin_at: F,
1172) -> &[f64]
1173where
1174 F: Fn(usize) -> u16,
1175{
1176 let mut node_index = root;
1177 loop {
1178 match &nodes[node_index] {
1179 OptimizedClassifierNode::Leaf(value) => return value,
1180 OptimizedClassifierNode::Binary {
1181 feature_index,
1182 threshold_bin,
1183 children,
1184 } => {
1185 let go_right = usize::from(bin_at(*feature_index) > *threshold_bin);
1186 node_index = children[go_right];
1187 }
1188 OptimizedClassifierNode::Multiway {
1189 feature_index,
1190 child_lookup,
1191 max_bin_index,
1192 fallback_probabilities,
1193 } => {
1194 let bin = usize::from(bin_at(*feature_index));
1195 if bin > *max_bin_index {
1196 return fallback_probabilities;
1197 }
1198 let child_index = child_lookup[bin];
1199 if child_index == usize::MAX {
1200 return fallback_probabilities;
1201 }
1202 node_index = child_index;
1203 }
1204 }
1205 }
1206}
1207
1208#[inline(always)]
1209fn predict_binary_classifier_probabilities_row<F>(
1210 nodes: &[OptimizedBinaryClassifierNode],
1211 bin_at: F,
1212) -> &[f64]
1213where
1214 F: Fn(usize) -> u16,
1215{
1216 let mut node_index = 0usize;
1217 loop {
1218 match &nodes[node_index] {
1219 OptimizedBinaryClassifierNode::Leaf(value) => return value,
1220 OptimizedBinaryClassifierNode::Branch {
1221 feature_index,
1222 threshold_bin,
1223 jump_index,
1224 jump_if_greater,
1225 } => {
1226 let go_right = bin_at(*feature_index) > *threshold_bin;
1227 node_index = if go_right == *jump_if_greater {
1228 *jump_index
1229 } else {
1230 node_index + 1
1231 };
1232 }
1233 }
1234 }
1235}
1236
1237#[inline(always)]
1238fn predict_binary_regressor_row<F>(nodes: &[OptimizedBinaryRegressorNode], bin_at: F) -> f64
1239where
1240 F: Fn(usize) -> u16,
1241{
1242 let mut node_index = 0usize;
1243 loop {
1244 match &nodes[node_index] {
1245 OptimizedBinaryRegressorNode::Leaf(value) => return *value,
1246 OptimizedBinaryRegressorNode::Branch {
1247 feature_index,
1248 threshold_bin,
1249 jump_index,
1250 jump_if_greater,
1251 } => {
1252 let go_right = bin_at(*feature_index) > *threshold_bin;
1253 node_index = if go_right == *jump_if_greater {
1254 *jump_index
1255 } else {
1256 node_index + 1
1257 };
1258 }
1259 }
1260 }
1261}
1262
1263fn predict_binary_classifier_probabilities_column_major_matrix(
1264 nodes: &[OptimizedBinaryClassifierNode],
1265 matrix: &ColumnMajorBinnedMatrix,
1266 _executor: &InferenceExecutor,
1267) -> Vec<Vec<f64>> {
1268 (0..matrix.n_rows)
1269 .map(|row_index| {
1270 predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1271 matrix.column(feature_index).value_at(row_index)
1272 })
1273 .to_vec()
1274 })
1275 .collect()
1276}
1277
1278fn predict_binary_regressor_column_major_matrix(
1279 nodes: &[OptimizedBinaryRegressorNode],
1280 matrix: &ColumnMajorBinnedMatrix,
1281 executor: &InferenceExecutor,
1282) -> Vec<f64> {
1283 let mut outputs = vec![0.0; matrix.n_rows];
1284 executor.fill_chunks(
1285 &mut outputs,
1286 STANDARD_BATCH_INFERENCE_CHUNK_ROWS,
1287 |start_row, chunk| predict_binary_regressor_chunk(nodes, matrix, start_row, chunk),
1288 );
1289 outputs
1290}
1291
/// Fills `output` with predictions for rows `start_row..start_row + output.len()`
/// of `matrix`, descending the tree block-wise rather than row-by-row.
///
/// All rows in the chunk traverse the tree together: `row_indices` holds a
/// permutation of chunk-local row offsets, and at every branch the active
/// range is partitioned in place into a fallthrough group (front) and a jump
/// group (back), so each feature column is read in one contiguous pass.
fn predict_binary_regressor_chunk(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) {
    // Chunk-local row offsets; reordered in place as the traversal descends.
    let mut row_indices: Vec<usize> = (0..output.len()).collect();
    // Work stack of (node index, start, end) half-open ranges into row_indices.
    let mut stack = vec![(0usize, 0usize, output.len())];

    while let Some((node_index, start, end)) = stack.pop() {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => {
                // Every row in this range reached the same leaf.
                for position in start..end {
                    output[row_indices[position]] = *value;
                }
            }
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
            } => {
                let fallthrough_index = node_index + 1;
                // Degenerate branch: both edges lead to the next slot, so no
                // partitioning is required.
                if *jump_index == fallthrough_index {
                    stack.push((fallthrough_index, start, end));
                    continue;
                }

                let column = matrix.column(*feature_index);
                // Two-pointer partition: rows taking the jump edge are swapped
                // to the back of the range; fallthrough rows stay at the front.
                let mut partition = start;
                let mut jump_start = end;
                match column {
                    // Fast path: compare the raw u8 bins directly when the
                    // threshold fits in u8, bypassing the generic accessor.
                    CompactBinnedColumn::U8(values) if *threshold_bin <= u16::from(u8::MAX) => {
                        let threshold = *threshold_bin as u8;
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = values[start_row + row_offset] > threshold;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                    _ => {
                        // Generic path for wider bin storage.
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = column.value_at(start_row + row_offset) > *threshold_bin;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                }

                // Recurse only into non-empty partitions.
                if jump_start < end {
                    stack.push((*jump_index, jump_start, end));
                }
                if start < jump_start {
                    stack.push((fallthrough_index, start, jump_start));
                }
            }
        }
    }
}
1363
#[inline(always)]
fn predict_oblivious_row<F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    bin_at: F,
) -> f64
where
    F: Fn(usize) -> u16,
{
    // An oblivious tree applies one (feature, threshold) test per level, so a
    // row's leaf is simply the bit pattern of its comparison outcomes,
    // most-significant level first.
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index]
}
1381
#[inline(always)]
fn predict_oblivious_probabilities_row<'a, F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &'a [Vec<f64>],
    bin_at: F,
) -> &'a [f64]
where
    F: Fn(usize) -> u16,
{
    // Identical level-by-level bit packing as predict_oblivious_row, but each
    // leaf holds a per-class probability vector, returned by reference.
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index].as_slice()
}
1399
fn normalized_probabilities_from_counts(class_counts: &[usize]) -> Vec<f64> {
    // Turn raw per-class counts into a probability distribution. An empty
    // node (zero total) maps to an all-zero vector rather than NaNs.
    let total: usize = class_counts.iter().sum();
    if total == 0 {
        vec![0.0; class_counts.len()]
    } else {
        let denominator = total as f64;
        class_counts
            .iter()
            .map(|&count| count as f64 / denominator)
            .collect()
    }
}
1411
fn class_label_from_probabilities(probabilities: &[f64], class_labels: &[f64]) -> f64 {
    // Argmax over the probability row, with ties resolved toward the lowest
    // class index: a later class must compare strictly greater to win.
    // `total_cmp` imposes a total order over floats, matching the comparator
    // used elsewhere in this module.
    assert!(
        !probabilities.is_empty(),
        "classification probability row is non-empty"
    );
    let mut best_index = 0usize;
    for (index, probability) in probabilities.iter().enumerate().skip(1) {
        if probability.total_cmp(&probabilities[best_index]).is_gt() {
            best_index = index;
        }
    }
    class_labels[best_index]
}
1425
#[inline(always)]
fn sigmoid(value: f64) -> f64 {
    // Numerically stable logistic function: always exponentiate a
    // non-positive argument so `exp` cannot overflow to infinity.
    if value < 0.0 {
        let exp_value = value.exp();
        exp_value / (1.0 + exp_value)
    } else {
        1.0 / (1.0 + (-value).exp())
    }
}
1436
1437fn classifier_nodes_are_binary_only(nodes: &[tree::classifier::TreeNode]) -> bool {
1438 nodes.iter().all(|node| {
1439 matches!(
1440 node,
1441 tree::classifier::TreeNode::Leaf { .. }
1442 | tree::classifier::TreeNode::BinarySplit { .. }
1443 )
1444 })
1445}
1446
1447fn classifier_node_sample_count(nodes: &[tree::classifier::TreeNode], node_index: usize) -> usize {
1448 match &nodes[node_index] {
1449 tree::classifier::TreeNode::Leaf { sample_count, .. }
1450 | tree::classifier::TreeNode::BinarySplit { sample_count, .. }
1451 | tree::classifier::TreeNode::MultiwaySplit { sample_count, .. } => *sample_count,
1452 }
1453}
1454
1455fn build_binary_classifier_layout(
1456 nodes: &[tree::classifier::TreeNode],
1457 root: usize,
1458 _class_labels: &[f64],
1459 feature_index_map: &[usize],
1460) -> Vec<OptimizedBinaryClassifierNode> {
1461 let mut layout = Vec::with_capacity(nodes.len());
1462 append_binary_classifier_node(nodes, root, &mut layout, feature_index_map);
1463 layout
1464}
1465
/// Recursively appends the subtree rooted at `node_index` to `layout` and
/// returns the slot the node was written to.
///
/// Layout invariant: a branch's fallthrough child is emitted at the very next
/// slot (`current_index + 1`), so each branch only needs an explicit index for
/// its jump target. The child with more training samples is chosen as the
/// fallthrough so the more common path avoids a jump.
fn append_binary_classifier_node(
    nodes: &[tree::classifier::TreeNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryClassifierNode>,
    feature_index_map: &[usize],
) -> usize {
    // Reserve this node's slot first with a placeholder; it is overwritten
    // once the children (and thus the jump index) have been emitted.
    let current_index = layout.len();
    layout.push(OptimizedBinaryClassifierNode::Leaf(Vec::new()));

    match &nodes[node_index] {
        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
            layout[current_index] = OptimizedBinaryClassifierNode::Leaf(
                normalized_probabilities_from_counts(class_counts),
            );
        }
        tree::classifier::TreeNode::BinarySplit {
            feature_index,
            threshold_bin,
            left_child,
            right_child,
            ..
        } => {
            // `jump_if_greater` records which comparison outcome takes the
            // jump edge at runtime (the right child corresponds to
            // bin > threshold in the traversal code).
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = classifier_node_sample_count(nodes, *left_child);
                let right_count = classifier_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index =
                append_binary_classifier_node(nodes, fallthrough_child, layout, feature_index_map);
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // If both children are the same node, reuse the fallthrough slot
            // instead of emitting the subtree twice.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_classifier_node(nodes, jump_child, layout, feature_index_map)
            };

            layout[current_index] = OptimizedBinaryClassifierNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
            };
        }
        tree::classifier::TreeNode::MultiwaySplit { .. } => {
            // Callers check classifier_nodes_are_binary_only before building
            // this layout, so a multiway node here is a programming error.
            unreachable!("multiway nodes are filtered out before binary layout construction");
        }
    }

    current_index
}
1523
1524fn regressor_node_sample_count(
1525 nodes: &[tree::regressor::RegressionNode],
1526 node_index: usize,
1527) -> usize {
1528 match &nodes[node_index] {
1529 tree::regressor::RegressionNode::Leaf { sample_count, .. }
1530 | tree::regressor::RegressionNode::BinarySplit { sample_count, .. } => *sample_count,
1531 }
1532}
1533
1534fn build_binary_regressor_layout(
1535 nodes: &[tree::regressor::RegressionNode],
1536 root: usize,
1537 feature_index_map: &[usize],
1538) -> Vec<OptimizedBinaryRegressorNode> {
1539 let mut layout = Vec::with_capacity(nodes.len());
1540 append_binary_regressor_node(nodes, root, &mut layout, feature_index_map);
1541 layout
1542}
1543
/// Recursively appends the regression subtree rooted at `node_index` to
/// `layout` and returns the slot the node was written to.
///
/// Layout invariant: a branch's fallthrough child occupies the very next slot
/// (`current_index + 1`), so only the jump target needs an explicit index.
/// The child with more training samples becomes the fallthrough so the more
/// common path avoids a jump.
fn append_binary_regressor_node(
    nodes: &[tree::regressor::RegressionNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryRegressorNode>,
    feature_index_map: &[usize],
) -> usize {
    // Reserve this node's slot with a placeholder; it is overwritten once the
    // children (and thus the jump index) have been emitted.
    let current_index = layout.len();
    layout.push(OptimizedBinaryRegressorNode::Leaf(0.0));

    match &nodes[node_index] {
        tree::regressor::RegressionNode::Leaf { value, .. } => {
            layout[current_index] = OptimizedBinaryRegressorNode::Leaf(*value);
        }
        tree::regressor::RegressionNode::BinarySplit {
            feature_index,
            threshold_bin,
            left_child,
            right_child,
            ..
        } => {
            // `jump_if_greater` records which comparison outcome takes the
            // jump edge at runtime (the right child corresponds to
            // bin > threshold in the traversal code).
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = regressor_node_sample_count(nodes, *left_child);
                let right_count = regressor_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index =
                append_binary_regressor_node(nodes, fallthrough_child, layout, feature_index_map);
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // If both children are the same node, reuse the fallthrough slot
            // instead of emitting the subtree twice.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_regressor_node(nodes, jump_child, layout, feature_index_map)
            };

            layout[current_index] = OptimizedBinaryRegressorNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
            };
        }
    }

    current_index
}
1596
1597fn predict_oblivious_column_major_matrix(
1598 feature_indices: &[usize],
1599 threshold_bins: &[u16],
1600 leaf_values: &[f64],
1601 matrix: &ColumnMajorBinnedMatrix,
1602 executor: &InferenceExecutor,
1603) -> Vec<f64> {
1604 let mut outputs = vec![0.0; matrix.n_rows];
1605 executor.fill_chunks(
1606 &mut outputs,
1607 PARALLEL_INFERENCE_CHUNK_ROWS,
1608 |start_row, chunk| {
1609 predict_oblivious_chunk(
1610 feature_indices,
1611 threshold_bins,
1612 leaf_values,
1613 matrix,
1614 start_row,
1615 chunk,
1616 )
1617 },
1618 );
1619 outputs
1620}
1621
1622fn predict_oblivious_probabilities_column_major_matrix(
1623 feature_indices: &[usize],
1624 threshold_bins: &[u16],
1625 leaf_values: &[Vec<f64>],
1626 matrix: &ColumnMajorBinnedMatrix,
1627 _executor: &InferenceExecutor,
1628) -> Vec<Vec<f64>> {
1629 (0..matrix.n_rows)
1630 .map(|row_index| {
1631 predict_oblivious_probabilities_row(
1632 feature_indices,
1633 threshold_bins,
1634 leaf_values,
1635 |feature_index| matrix.column(feature_index).value_at(row_index),
1636 )
1637 .to_vec()
1638 })
1639 .collect()
1640}
1641
1642fn predict_oblivious_chunk(
1643 feature_indices: &[usize],
1644 threshold_bins: &[u16],
1645 leaf_values: &[f64],
1646 matrix: &ColumnMajorBinnedMatrix,
1647 start_row: usize,
1648 output: &mut [f64],
1649) {
1650 let processed = simd_predict_oblivious_chunk(
1651 feature_indices,
1652 threshold_bins,
1653 leaf_values,
1654 matrix,
1655 start_row,
1656 output,
1657 );
1658
1659 for (offset, out) in output.iter_mut().enumerate().skip(processed) {
1660 let row_index = start_row + offset;
1661 *out = predict_oblivious_row(
1662 feature_indices,
1663 threshold_bins,
1664 leaf_values,
1665 |feature_index| matrix.column(feature_index).value_at(row_index),
1666 );
1667 }
1668}
1669
/// SIMD prefix of an oblivious-tree chunk: scores rows in groups of
/// `OBLIVIOUS_SIMD_LANES` (8, fixed by the `u32x8`/`u16x8` lane widths) and
/// returns how many rows were written; the caller handles the scalar tail.
fn simd_predict_oblivious_chunk(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) -> usize {
    let mut processed = 0usize;
    let ones = u32x8::splat(1);

    // Only full lane groups are handled here; any partial tail is left for
    // the caller's scalar fallback.
    while processed + OBLIVIOUS_SIMD_LANES <= output.len() {
        let base_row = start_row + processed;
        let mut leaf_indices = u32x8::splat(0);

        for (&feature_index, &threshold_bin) in feature_indices.iter().zip(threshold_bins) {
            let column = matrix.column(feature_index);
            // Widen 8 consecutive bins to u32 lanes from whichever backing
            // storage the column uses (u8 fast path, u16 otherwise).
            let bins = if let Some(lanes) = column.slice_u8(base_row, OBLIVIOUS_SIMD_LANES) {
                let lanes: [u8; OBLIVIOUS_SIMD_LANES] = lanes
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::new([
                    u32::from(lanes[0]),
                    u32::from(lanes[1]),
                    u32::from(lanes[2]),
                    u32::from(lanes[3]),
                    u32::from(lanes[4]),
                    u32::from(lanes[5]),
                    u32::from(lanes[6]),
                    u32::from(lanes[7]),
                ])
            } else {
                let lanes: [u16; OBLIVIOUS_SIMD_LANES] = column
                    .slice_u16(base_row, OBLIVIOUS_SIMD_LANES)
                    .expect("column is u16 when not u8")
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::from(u16x8::new(lanes))
            };
            let threshold = u32x8::splat(u32::from(threshold_bin));
            // The comparison mask ANDed with 1 yields a per-lane direction
            // bit, shifted into the leaf index exactly like the scalar path.
            let bit = bins.cmp_gt(threshold) & ones;
            leaf_indices = (leaf_indices << 1) | bit;
        }

        // Gather each lane's leaf value into the output.
        let lane_indices = leaf_indices.to_array();
        for lane in 0..OBLIVIOUS_SIMD_LANES {
            output[processed + lane] =
                leaf_values[usize::try_from(lane_indices[lane]).expect("leaf index fits usize")];
        }
        processed += OBLIVIOUS_SIMD_LANES;
    }

    processed
}
1724
/// Trains a model on `train_set` according to `config`.
///
/// Thin public entry point that delegates to the crate's `training` module.
///
/// # Errors
///
/// Returns a `TrainError` if training fails.
pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
    training::train(train_set, config)
}
1728
1729#[cfg(test)]
1730mod tests;