// forestfire_core/lib.rs
1//! ForestFire core model, training, inference, and interchange layer.
2//!
3//! The crate is organized around a few stable abstractions:
4//!
5//! - [`forestfire_data::TableAccess`] is the common data boundary for both
6//!   training and inference.
7//! - [`TrainConfig`] is the normalized configuration surface shared by the Rust
8//!   and Python APIs.
9//! - [`Model`] is the semantic model view used for exact prediction,
10//!   serialization, and introspection.
11//! - [`OptimizedModel`] is a lowered runtime view used when prediction speed
12//!   matters more than preserving the original tree layout.
13//!
14//! Keeping the semantic model and the runtime model separate is deliberate. It
15//! makes export and introspection straightforward while still allowing the
16//! optimized path to use layouts that are awkward to serialize directly.
17
18use forestfire_data::{
19    BinnedColumnKind, MAX_NUMERIC_BINS, NumericBins, TableAccess, numeric_bin_boundaries,
20};
21#[cfg(feature = "polars")]
22use polars::prelude::{Column, DataFrame, DataType, IdxSize, LazyFrame};
23use rayon::ThreadPoolBuilder;
24use rayon::prelude::*;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::collections::BTreeMap;
28use std::error::Error;
29use std::fmt::{Display, Formatter};
30use std::sync::Arc;
31use wide::{u16x8, u32x8};
32
33mod boosting;
34mod bootstrap;
35mod compiled_artifact;
36mod forest;
37mod inference_input;
38mod introspection;
39pub mod ir;
40mod model_api;
41mod optimized_runtime;
42mod runtime_planning;
43mod sampling;
44mod training;
45pub mod tree;
46
47pub use boosting::BoostingError;
48pub use boosting::GradientBoostedTrees;
49pub use compiled_artifact::CompiledArtifactError;
50pub use forest::RandomForest;
51pub use introspection::IntrospectionError;
52pub use introspection::PredictionHistogramEntry;
53pub use introspection::PredictionValueStats;
54pub use introspection::TreeStructureSummary;
55pub use ir::IrError;
56pub use ir::ModelPackageIr;
57pub use model_api::OptimizedModel;
58pub use tree::classifier::DecisionTreeAlgorithm;
59pub use tree::classifier::DecisionTreeClassifier;
60pub use tree::classifier::DecisionTreeError;
61pub use tree::classifier::DecisionTreeOptions;
62pub use tree::classifier::train_c45;
63pub use tree::classifier::train_cart;
64pub use tree::classifier::train_id3;
65pub use tree::classifier::train_oblivious;
66pub use tree::classifier::train_randomized;
67pub use tree::regressor::DecisionTreeRegressor;
68pub use tree::regressor::RegressionTreeAlgorithm;
69pub use tree::regressor::RegressionTreeError;
70pub use tree::regressor::RegressionTreeOptions;
71pub use tree::regressor::train_cart_regressor;
72pub use tree::regressor::train_oblivious_regressor;
73pub use tree::regressor::train_randomized_regressor;
/// Row-chunk size used when streaming predictions over a polars `LazyFrame`.
#[cfg(feature = "polars")]
const LAZYFRAME_PREDICT_BATCH_ROWS: usize = 10_000;
76pub(crate) use inference_input::ColumnMajorBinnedMatrix;
77pub(crate) use inference_input::CompactBinnedColumn;
78pub(crate) use inference_input::InferenceTable;
79pub(crate) use inference_input::ProjectedTableView;
80#[cfg(feature = "polars")]
81pub(crate) use inference_input::polars_named_columns;
82pub(crate) use introspection::prediction_value_stats;
83pub(crate) use introspection::tree_structure_summary;
84pub(crate) use optimized_runtime::InferenceExecutor;
85pub(crate) use optimized_runtime::OBLIVIOUS_SIMD_LANES;
86pub(crate) use optimized_runtime::OptimizedBinaryClassifierNode;
87pub(crate) use optimized_runtime::OptimizedBinaryRegressorNode;
88pub(crate) use optimized_runtime::OptimizedClassifierNode;
89pub(crate) use optimized_runtime::OptimizedRuntime;
90pub(crate) use optimized_runtime::PARALLEL_INFERENCE_CHUNK_ROWS;
91pub(crate) use optimized_runtime::STANDARD_BATCH_INFERENCE_CHUNK_ROWS;
92pub(crate) use optimized_runtime::resolve_inference_thread_count;
93pub(crate) use runtime_planning::build_feature_index_map;
94pub(crate) use runtime_planning::build_feature_projection;
95pub(crate) use runtime_planning::model_used_feature_indices;
96pub(crate) use runtime_planning::ordered_ensemble_indices;
97pub(crate) use runtime_planning::remap_feature_index;
98
/// High-level training family selected by [`TrainConfig::algorithm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainAlgorithm {
    /// Train a single decision tree.
    Dt,
    /// Train a bootstrap-aggregated random forest.
    Rf,
    /// Train a second-order gradient-boosted ensemble.
    Gbm,
}
108
/// Split-quality criterion evaluated while growing trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Criterion {
    /// Let training choose the appropriate criterion for the requested setup.
    Auto,
    /// Gini impurity for classification.
    Gini,
    /// Entropy / information gain for classification.
    Entropy,
    /// Mean-based regression criterion.
    Mean,
    /// Median-based regression criterion.
    Median,
    /// Internal second-order criterion used by gradient boosting.
    SecondOrder,
}
124
/// Learning task addressed by a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    /// Predict a continuous numeric value.
    Regression,
    /// Predict one label from a finite set.
    Classification,
}

/// Concrete tree learner used inside the selected algorithm family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeType {
    /// Multiway information-gain tree.
    Id3,
    /// C4.5-style multiway tree.
    C45,
    /// Standard binary threshold tree.
    Cart,
    /// CART-style tree with randomized candidate selection.
    Randomized,
    /// Symmetric tree where every level shares the same split.
    Oblivious,
}

/// Strategy for subsampling candidate features at each split.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxFeatures {
    /// Task-aware default: `sqrt` for classification, `third` for regression.
    Auto,
    /// Use all features at each split.
    All,
    /// Use `floor(sqrt(feature_count))` features.
    Sqrt,
    /// Use roughly one third of the features.
    Third,
    /// Use exactly this many features, capped to the available count.
    Count(usize),
}

impl MaxFeatures {
    /// Resolve this strategy into a concrete candidate-feature count.
    ///
    /// The result is clamped to at least 1, even for degenerate inputs, so the
    /// split search never receives an empty candidate set.
    pub fn resolve(self, task: Task, feature_count: usize) -> usize {
        let raw = match self {
            MaxFeatures::Auto => {
                // `Auto` defers to a task-specific concrete strategy.
                let concrete = match task {
                    Task::Classification => MaxFeatures::Sqrt,
                    Task::Regression => MaxFeatures::Third,
                };
                return concrete.resolve(task, feature_count);
            }
            MaxFeatures::All => feature_count,
            MaxFeatures::Sqrt => (feature_count as f64).sqrt().floor() as usize,
            MaxFeatures::Third => feature_count / 3,
            MaxFeatures::Count(count) => count.min(feature_count),
        };
        raw.max(1)
    }
}
175
/// Kind of an input feature as seen by the inference layer.
///
/// Serialized with `snake_case` variant names for interchange stability.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputFeatureKind {
    /// Numeric features are compared through their binned representation.
    Numeric,
    /// Binary features stay boolean all the way through the pipeline.
    Binary,
}
184
/// One `(bin, upper bound)` step of a numeric feature's binning function.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinBoundary {
    /// Bin identifier in the preprocessed feature space.
    pub bin: u16,
    /// Largest raw floating-point value that still belongs to this bin.
    pub upper_bound: f64,
}
192
/// Per-feature preprocessing state captured at training time so that
/// inference can bin raw values identically. Serialized with a `kind` tag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeaturePreprocessing {
    /// Numeric features are represented by explicit bin boundaries.
    Numeric {
        bin_boundaries: Vec<NumericBinBoundary>,
    },
    /// Binary features do not require numeric bin boundaries.
    Binary,
}
203
/// Unified training configuration shared by the Rust and Python entry points.
///
/// The crate keeps one normalized config type so the binding layer only has to
/// perform input validation and type conversion; all semantic decisions happen
/// from this one structure downward.
///
/// `None` fields mean "let the trainer choose"; the resolved defaults live in
/// the training module.
#[derive(Debug, Clone, Copy)]
pub struct TrainConfig {
    /// High-level training family.
    pub algorithm: TrainAlgorithm,
    /// Regression or classification.
    pub task: Task,
    /// Tree learner used by the selected algorithm family.
    pub tree_type: TreeType,
    /// Split criterion. [`Criterion::Auto`] is resolved by the trainer.
    pub criterion: Criterion,
    /// Maximum tree depth.
    pub max_depth: Option<usize>,
    /// Smallest node size that is still allowed to split.
    pub min_samples_split: Option<usize>,
    /// Minimum child size after a split.
    pub min_samples_leaf: Option<usize>,
    /// Optional cap on training-side rayon threads.
    pub physical_cores: Option<usize>,
    /// Number of trees for ensemble algorithms.
    pub n_trees: Option<usize>,
    /// Feature subsampling strategy.
    pub max_features: MaxFeatures,
    /// Seed used for reproducible sampling and randomized splits.
    pub seed: Option<u64>,
    /// Whether random forests should compute out-of-bag metrics.
    pub compute_oob: bool,
    /// Gradient boosting shrinkage factor.
    pub learning_rate: Option<f64>,
    /// Whether gradient boosting should bootstrap rows before gradient sampling.
    pub bootstrap: bool,
    /// Fraction of largest-gradient rows always kept by GOSS sampling.
    /// (Presumably in (0, 1]; validation happens downstream — confirm there.)
    pub top_gradient_fraction: Option<f64>,
    /// Fraction of the remaining rows randomly retained by GOSS sampling.
    /// (Presumably in (0, 1]; validation happens downstream — confirm there.)
    pub other_gradient_fraction: Option<f64>,
}
244
245impl Default for TrainConfig {
246    fn default() -> Self {
247        Self {
248            algorithm: TrainAlgorithm::Dt,
249            task: Task::Regression,
250            tree_type: TreeType::Cart,
251            criterion: Criterion::Auto,
252            max_depth: None,
253            min_samples_split: None,
254            min_samples_leaf: None,
255            physical_cores: None,
256            n_trees: None,
257            max_features: MaxFeatures::Auto,
258            seed: None,
259            compute_oob: false,
260            learning_rate: None,
261            bootstrap: false,
262            top_gradient_fraction: None,
263            other_gradient_fraction: None,
264        }
265    }
266}
267
/// Top-level semantic model enum.
///
/// This type stays close to the learned structure rather than the fastest
/// possible runtime layout. That is what makes it suitable for introspection,
/// serialization, and exact behavior parity across bindings.
#[derive(Debug, Clone)]
pub enum Model {
    /// A single classification tree.
    DecisionTreeClassifier(DecisionTreeClassifier),
    /// A single regression tree.
    DecisionTreeRegressor(DecisionTreeRegressor),
    /// A bagged ensemble; its task is carried by the forest itself.
    RandomForest(RandomForest),
    /// A boosted ensemble; its task is carried by the model itself.
    GradientBoostedTrees(GradientBoostedTrees),
}
280
/// Errors produced while validating configuration or running training.
#[derive(Debug)]
pub enum TrainError {
    /// Wrapped classification-tree error.
    DecisionTree(DecisionTreeError),
    /// Wrapped regression-tree error.
    RegressionTree(RegressionTreeError),
    /// Wrapped gradient-boosting error.
    Boosting(BoostingError),
    /// More physical cores were requested than the machine exposes.
    InvalidPhysicalCoreCount {
        requested: usize,
        available: usize,
    },
    /// The training thread pool could not be built; message kept as text.
    ThreadPoolBuildFailed(String),
    /// The (task, tree_type, criterion) combination has no trainer.
    UnsupportedConfiguration {
        task: Task,
        tree_type: TreeType,
        criterion: Criterion,
    },
    /// `max_depth` must be at least 1.
    InvalidMaxDepth(usize),
    /// `min_samples_split` must be at least 1.
    InvalidMinSamplesSplit(usize),
    /// `min_samples_leaf` must be at least 1.
    InvalidMinSamplesLeaf(usize),
    /// Ensembles need at least one tree.
    InvalidTreeCount(usize),
    /// Integer `max_features` must be at least 1.
    InvalidMaxFeatures(usize),
}
302
303impl Display for TrainError {
304    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
305        match self {
306            TrainError::DecisionTree(err) => err.fmt(f),
307            TrainError::RegressionTree(err) => err.fmt(f),
308            TrainError::Boosting(err) => err.fmt(f),
309            TrainError::InvalidPhysicalCoreCount {
310                requested,
311                available,
312            } => write!(
313                f,
314                "Requested {} physical cores, but the available physical core count is {}.",
315                requested, available
316            ),
317            TrainError::ThreadPoolBuildFailed(message) => {
318                write!(f, "Failed to build training thread pool: {}.", message)
319            }
320            TrainError::UnsupportedConfiguration {
321                task,
322                tree_type,
323                criterion,
324            } => write!(
325                f,
326                "Unsupported training configuration: task={:?}, tree_type={:?}, criterion={:?}.",
327                task, tree_type, criterion
328            ),
329            TrainError::InvalidMaxDepth(value) => {
330                write!(f, "max_depth must be at least 1. Received {}.", value)
331            }
332            TrainError::InvalidMinSamplesSplit(value) => {
333                write!(
334                    f,
335                    "min_samples_split must be at least 1. Received {}.",
336                    value
337                )
338            }
339            TrainError::InvalidMinSamplesLeaf(value) => {
340                write!(
341                    f,
342                    "min_samples_leaf must be at least 1. Received {}.",
343                    value
344                )
345            }
346            TrainError::InvalidTreeCount(n_trees) => {
347                write!(
348                    f,
349                    "Random forest requires at least one tree. Received {}.",
350                    n_trees
351                )
352            }
353            TrainError::InvalidMaxFeatures(count) => {
354                write!(
355                    f,
356                    "max_features must be at least 1 when provided as an integer. Received {}.",
357                    count
358                )
359            }
360        }
361    }
362}
363
// Marker impl: Display + Debug above make TrainError a full std error.
impl Error for TrainError {}
365
/// Errors surfaced while preparing inference input or running prediction.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictError {
    /// `predict_proba` was called on a non-classification model.
    ProbabilityPredictionRequiresClassification,
    /// A row-major input had rows of differing widths.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// The input feature count does not match the model's expectation.
    FeatureCountMismatch {
        expected: usize,
        actual: usize,
    },
    /// A named column had the wrong number of values.
    ColumnLengthMismatch {
        feature: String,
        expected: usize,
        actual: usize,
    },
    /// A feature the model requires was not provided.
    MissingFeature(String),
    /// A feature was provided that the model does not know about.
    UnexpectedFeature(String),
    /// A binary feature received a non-binary value.
    InvalidBinaryValue {
        feature_index: usize,
        row_index: usize,
        value: f64,
    },
    /// A null value appeared where a concrete value is required.
    NullValue {
        feature: String,
        row_index: usize,
    },
    /// A column dtype the inference layer cannot consume.
    UnsupportedFeatureType {
        feature: String,
        dtype: String,
    },
    /// An error bubbled up from the polars integration, kept as plain text.
    Polars(String),
}

impl Display for PredictError {
    /// Render a human-readable description of the prediction failure.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictError::ProbabilityPredictionRequiresClassification => write!(
                f,
                "predict_proba is only available for classification models."
            ),
            PredictError::RaggedRows { row, expected, actual } => write!(
                f,
                "Ragged inference row at index {row}: expected {expected} columns, found {actual}."
            ),
            PredictError::FeatureCountMismatch { expected, actual } => write!(
                f,
                "Inference input has {actual} features, but the model expects {expected}."
            ),
            PredictError::ColumnLengthMismatch { feature, expected, actual } => write!(
                f,
                "Feature '{feature}' has {actual} values, expected {expected}."
            ),
            PredictError::MissingFeature(feature) => {
                write!(f, "Missing required feature '{feature}'.")
            }
            PredictError::UnexpectedFeature(feature) => {
                write!(f, "Unexpected feature '{feature}'.")
            }
            PredictError::InvalidBinaryValue { feature_index, row_index, value } => write!(
                f,
                "Feature {feature_index} at row {row_index} must be binary for inference, found {value}."
            ),
            PredictError::NullValue { feature, row_index } => write!(
                f,
                "Feature '{feature}' contains a null value at row {row_index}."
            ),
            PredictError::UnsupportedFeatureType { feature, dtype } => write!(
                f,
                "Feature '{feature}' has unsupported dtype '{dtype}'."
            ),
            PredictError::Polars(message) => write!(f, "Polars inference failed: {message}."),
        }
    }
}

// Marker impl: Display + Debug above make PredictError a full std error.
impl Error for PredictError {}
462
#[cfg(feature = "polars")]
impl From<polars::error::PolarsError> for PredictError {
    /// Stringify polars errors so `PredictError` carries no polars types.
    fn from(value: polars::error::PolarsError) -> Self {
        Self::Polars(value.to_string())
    }
}
469
/// Errors raised while lowering a semantic model into its optimized runtime.
#[derive(Debug)]
pub enum OptimizeError {
    /// More physical cores were requested than the machine exposes.
    InvalidPhysicalCoreCount { requested: usize, available: usize },
    /// The inference thread pool could not be built; message kept as text.
    ThreadPoolBuildFailed(String),
    /// The model family has no optimized inference path.
    UnsupportedModelType(&'static str),
}

impl Display for OptimizeError {
    /// Render a human-readable description of the optimization failure.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            OptimizeError::InvalidPhysicalCoreCount { requested, available } => write!(
                f,
                "Requested {requested} physical cores, but the available physical core count is {available}."
            ),
            OptimizeError::ThreadPoolBuildFailed(message) => {
                write!(f, "Failed to build inference thread pool: {message}.")
            }
            OptimizeError::UnsupportedModelType(model_type) => write!(
                f,
                "Optimized inference is not supported for model type '{model_type}'."
            ),
        }
    }
}

// Marker impl: Display + Debug above make OptimizeError a full std error.
impl Error for OptimizeError {}
503
/// Internal knob describing how many rayon threads a phase may use.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Parallelism {
    thread_count: usize,
}

impl Parallelism {
    /// Single-threaded execution: parallel code paths stay disabled.
    pub(crate) fn sequential() -> Self {
        Parallelism { thread_count: 1 }
    }

    /// Test-only constructor; the count is clamped to at least one thread.
    #[cfg(test)]
    pub(crate) fn with_threads(thread_count: usize) -> Self {
        let thread_count = thread_count.max(1);
        Parallelism { thread_count }
    }

    /// True only when more than one thread is available.
    pub(crate) fn enabled(self) -> bool {
        !matches!(self.thread_count, 0 | 1)
    }
}
525
526pub(crate) fn capture_feature_preprocessing(table: &dyn TableAccess) -> Vec<FeaturePreprocessing> {
527    (0..table.n_features())
528        .map(|feature_index| {
529            if table.is_binary_feature(feature_index) {
530                FeaturePreprocessing::Binary
531            } else {
532                let values = (0..table.n_rows())
533                    .map(|row_index| table.feature_value(feature_index, row_index))
534                    .collect::<Vec<_>>();
535                FeaturePreprocessing::Numeric {
536                    bin_boundaries: numeric_bin_boundaries(
537                        &values,
538                        NumericBins::Fixed(table.numeric_bin_cap()),
539                    )
540                    .into_iter()
541                    .map(|(bin, upper_bound)| NumericBinBoundary { bin, upper_bound })
542                    .collect(),
543                }
544            }
545        })
546        .collect()
547}
548
549impl OptimizedRuntime {
    /// True when this runtime variant can be fed through the column-major
    /// batch-matrix prediction path. Of the variants matched in this file,
    /// `StandardClassifier` (multiway splits) is the one left out.
    fn supports_batch_matrix(&self) -> bool {
        matches!(
            self,
            OptimizedRuntime::BinaryClassifier { .. }
                | OptimizedRuntime::BinaryRegressor { .. }
                | OptimizedRuntime::ObliviousClassifier { .. }
                | OptimizedRuntime::ObliviousRegressor { .. }
                | OptimizedRuntime::ForestClassifier { .. }
                | OptimizedRuntime::ForestRegressor { .. }
                | OptimizedRuntime::BoostedBinaryClassifier { .. }
                | OptimizedRuntime::BoostedRegressor { .. }
        )
    }
563
564    fn should_use_batch_matrix(&self, n_rows: usize) -> bool {
565        n_rows > 1 && self.supports_batch_matrix()
566    }
567
    /// Lower a semantic [`Model`] into the matching runtime representation,
    /// remapping every stored feature index through `feature_index_map`.
    ///
    /// Ensemble members are visited in the order produced by
    /// `ordered_ensemble_indices` rather than their stored order; the rationale
    /// lives in `runtime_planning` — presumably layout/scheduling related,
    /// confirm there. For boosted models the per-tree weights are reordered in
    /// lockstep so each tree keeps its own weight.
    fn from_model(model: &Model, feature_index_map: &[usize]) -> Self {
        match model {
            // Single trees lower directly via the dedicated helpers.
            Model::DecisionTreeClassifier(classifier) => {
                Self::from_classifier(classifier, feature_index_map)
            }
            Model::DecisionTreeRegressor(regressor) => {
                Self::from_regressor(regressor, feature_index_map)
            }
            Model::RandomForest(forest) => match forest.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestClassifier {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(&forest.trees()[tree_index], feature_index_map)
                            })
                            .collect(),
                        // Invariant: classification forests always carry labels.
                        class_labels: forest
                            .class_labels()
                            .expect("classification forest stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestRegressor {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(&forest.trees()[tree_index], feature_index_map)
                            })
                            .collect(),
                    }
                }
            },
            Model::GradientBoostedTrees(model) => match model.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedBinaryClassifier {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(&model.trees()[*tree_index], feature_index_map)
                            })
                            .collect(),
                        // Weights are permuted with the same order as the trees.
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                        // Invariant: classification boosting always carries labels.
                        class_labels: model
                            .class_labels()
                            .expect("classification boosting stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedRegressor {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(&model.trees()[*tree_index], feature_index_map)
                            })
                            .collect(),
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                    }
                }
            },
        }
    }
642
    /// Lower a semantic classification tree into its runtime layout.
    ///
    /// Three outcomes are possible:
    /// - a standard tree whose splits are all binary takes the dedicated
    ///   `BinaryClassifier` layout built by `build_binary_classifier_layout`;
    /// - any other standard tree becomes a `StandardClassifier` with per-node
    ///   lowering (leaves become normalized probability vectors, multiway
    ///   splits become dense bin-to-child lookup tables);
    /// - an oblivious tree becomes parallel split arrays plus a leaf table.
    fn from_classifier(classifier: &DecisionTreeClassifier, feature_index_map: &[usize]) -> Self {
        match classifier.structure() {
            tree::classifier::TreeStructure::Standard { nodes, root } => {
                // Fast path: purely binary trees get the compact layout.
                if classifier_nodes_are_binary_only(nodes) {
                    return Self::BinaryClassifier {
                        nodes: build_binary_classifier_layout(
                            nodes,
                            *root,
                            classifier.class_labels(),
                            feature_index_map,
                        ),
                        class_labels: classifier.class_labels().to_vec(),
                    };
                }

                let optimized_nodes = nodes
                    .iter()
                    .map(|node| match node {
                        // Leaves are stored as normalized class probabilities.
                        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
                            OptimizedClassifierNode::Leaf(normalized_probabilities_from_counts(
                                class_counts,
                            ))
                        }
                        tree::classifier::TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => OptimizedClassifierNode::Binary {
                            feature_index: remap_feature_index(*feature_index, feature_index_map),
                            threshold_bin: *threshold_bin,
                            children: [*left_child, *right_child],
                        },
                        tree::classifier::TreeNode::MultiwaySplit {
                            feature_index,
                            class_counts,
                            branches,
                            ..
                        } => {
                            // Build a dense bin -> child table sized by the
                            // largest branch bin seen at this node.
                            let max_bin_index = branches
                                .iter()
                                .map(|(bin, _)| usize::from(*bin))
                                .max()
                                .unwrap_or(0);
                            // usize::MAX marks bins with no trained branch;
                            // presumably the prediction path checks for it and
                            // falls back to `fallback_probabilities` — confirm
                            // in optimized_runtime.
                            let mut child_lookup = vec![usize::MAX; max_bin_index + 1];
                            for (bin, child_index) in branches {
                                child_lookup[usize::from(*bin)] = *child_index;
                            }
                            OptimizedClassifierNode::Multiway {
                                feature_index: remap_feature_index(
                                    *feature_index,
                                    feature_index_map,
                                ),
                                child_lookup,
                                max_bin_index,
                                fallback_probabilities: normalized_probabilities_from_counts(
                                    class_counts,
                                ),
                            }
                        }
                    })
                    .collect();

                Self::StandardClassifier {
                    nodes: optimized_nodes,
                    root: *root,
                    class_labels: classifier.class_labels().to_vec(),
                }
            }
            // Oblivious trees share one split per level, so they lower to
            // parallel (feature, threshold) arrays and a dense leaf table of
            // probability vectors.
            tree::classifier::TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => Self::ObliviousClassifier {
                feature_indices: splits
                    .iter()
                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
                    .collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_class_counts
                    .iter()
                    .map(|class_counts| normalized_probabilities_from_counts(class_counts))
                    .collect(),
                class_labels: classifier.class_labels().to_vec(),
            },
        }
    }
731
    /// Lower a semantic regression tree into its runtime layout, remapping
    /// stored feature indices through `feature_index_map`.
    fn from_regressor(regressor: &DecisionTreeRegressor, feature_index_map: &[usize]) -> Self {
        match regressor.structure() {
            // Standard trees are flattened via build_binary_regressor_layout.
            tree::regressor::RegressionTreeStructure::Standard { nodes, root } => {
                Self::BinaryRegressor {
                    nodes: build_binary_regressor_layout(nodes, *root, feature_index_map),
                }
            }
            // Oblivious trees share one split per level, so they lower to
            // parallel (feature, threshold) arrays plus a dense leaf table.
            tree::regressor::RegressionTreeStructure::Oblivious {
                splits,
                leaf_values,
                ..
            } => Self::ObliviousRegressor {
                feature_indices: splits
                    .iter()
                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
                    .collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_values.clone(),
            },
        }
    }
753
    /// Predict a single table row, returning either a regression value or an
    /// encoded class label as `f64`.
    ///
    /// Classifier variants route through `predict_proba_table_row` and then
    /// pick a label via `class_label_from_probabilities`; regressor variants
    /// walk their node layouts directly on binned feature values.
    // `#[inline(always)]` kept from the original — this is the per-row hot
    // path, presumably inlined into the batch loops; confirm with a profile
    // before changing.
    #[inline(always)]
    fn predict_table_row(&self, table: &dyn TableAccess, row_index: usize) -> f64 {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                // Every classifier variant supports probability prediction, so
                // this expect encodes an internal invariant, not an input error.
                let probabilities = self
                    .predict_proba_table_row(table, row_index)
                    .expect("classifier runtime supports probability prediction");
                class_label_from_probabilities(&probabilities, self.class_labels())
            }
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            ),
            // Forest: unweighted mean of member predictions. Assumes at least
            // one tree; an empty forest would divide by zero (NaN).
            OptimizedRuntime::ForestRegressor { trees } => {
                trees
                    .iter()
                    .map(|tree| tree.predict_table_row(table, row_index))
                    .sum::<f64>()
                    / trees.len() as f64
            }
            // Boosting: base score plus the weighted sum of member outputs.
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>()
            }
        }
    }
803
    /// Computes the per-class probability row for `row_index` of `table`.
    ///
    /// Returns [`PredictError::ProbabilityPredictionRequiresClassification`]
    /// for every regressor variant.
    #[inline(always)]
    fn predict_proba_table_row(
        &self,
        table: &dyn TableAccess,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            )
            .to_vec()),
            // Forest: element-wise mean of the member trees' probability rows.
            // NOTE(review): `trees[0]` assumes a non-empty ensemble — confirm
            // construction guarantees this.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals = trees[0].predict_proba_table_row(table, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_table_row(table, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            // Boosted binary: member trees emit raw scores; the weighted sum
            // is squashed through the logistic function into [P(0), P(1)].
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
872
    /// Predicts probability rows for every row of `table`.
    ///
    /// For sufficiently large inputs (per `should_use_batch_matrix`) the table
    /// is first materialized into a column-major binned matrix so the batch
    /// kernels can run; otherwise each row is evaluated directly against the
    /// table to avoid the materialization cost.
    fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                if self.should_use_batch_matrix(table.n_rows()) {
                    let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
                    self.predict_proba_column_major_matrix(&matrix, executor)
                } else {
                    // Per-row path; collect() short-circuits on the first error.
                    (0..table.n_rows())
                        .map(|row_index| self.predict_proba_table_row(table, row_index))
                        .collect()
                }
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
901
    /// Predicts one value per matrix row using the batch (column-major) kernels.
    ///
    /// Classifiers reuse the batched probability path and reduce each row to a
    /// class label; regressors dispatch to their specialized batch kernels.
    fn predict_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Vec<f64> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => self
                .predict_proba_column_major_matrix(matrix, executor)
                .expect("classifier runtime supports probability prediction")
                .into_iter()
                .map(|row| class_label_from_probabilities(&row, self.class_labels()))
                .collect(),
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_column_major_matrix(nodes, matrix, executor)
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            ),
            // Forest: unweighted mean of member-tree batch predictions.
            // NOTE(review): `trees[0]` assumes a non-empty ensemble.
            OptimizedRuntime::ForestRegressor { trees } => {
                let mut totals = trees[0].predict_column_major_matrix(matrix, executor);
                for tree in &trees[1..] {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for total in &mut totals {
                    *total /= tree_count;
                }
                totals
            }
            // Boosting: accumulate weighted member outputs onto the base score.
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                let mut totals = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += weight * value;
                    }
                }
                totals
            }
        }
    }
962
    /// Computes probability rows for every matrix row via the batch kernels.
    ///
    /// Regressor variants fail with
    /// [`PredictError::ProbabilityPredictionRequiresClassification`].
    fn predict_proba_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => {
                Ok(predict_binary_classifier_probabilities_column_major_matrix(
                    nodes, matrix, executor,
                ))
            }
            // Standard classifiers have no batch kernel; evaluate row by row.
            OptimizedRuntime::StandardClassifier { .. } => Ok((0..matrix.n_rows)
                .map(|row_index| {
                    self.predict_proba_binned_row_from_columns(matrix, row_index)
                        .expect("classifier runtime supports probability prediction")
                })
                .collect()),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            )),
            // Forest: element-wise mean of every member tree's probability rows.
            // NOTE(review): `trees[0]` assumes a non-empty ensemble.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals = trees[0].predict_proba_column_major_matrix(matrix, executor)?;
                for tree in &trees[1..] {
                    let rows = tree.predict_proba_column_major_matrix(matrix, executor)?;
                    for (row_totals, row_values) in totals.iter_mut().zip(rows) {
                        for (total, value) in row_totals.iter_mut().zip(row_values) {
                            *total += value;
                        }
                    }
                }
                let tree_count = trees.len() as f64;
                for row in &mut totals {
                    for value in row {
                        *value /= tree_count;
                    }
                }
                Ok(totals)
            }
            // Boosted binary: accumulate raw scores per row, then squash each
            // through the logistic function into [P(0), P(1)].
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let mut raw_scores = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (raw_score, value) in raw_scores.iter_mut().zip(values) {
                        *raw_score += weight * value;
                    }
                }
                Ok(raw_scores
                    .into_iter()
                    .map(|raw_score| {
                        let positive = sigmoid(raw_score);
                        vec![1.0 - positive, positive]
                    })
                    .collect())
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1039
1040    fn class_labels(&self) -> &[f64] {
1041        match self {
1042            OptimizedRuntime::BinaryClassifier { class_labels, .. }
1043            | OptimizedRuntime::StandardClassifier { class_labels, .. }
1044            | OptimizedRuntime::ObliviousClassifier { class_labels, .. }
1045            | OptimizedRuntime::ForestClassifier { class_labels, .. }
1046            | OptimizedRuntime::BoostedBinaryClassifier { class_labels, .. } => class_labels,
1047            _ => &[],
1048        }
1049    }
1050
    /// Predicts a single value for one row of a column-major binned matrix.
    ///
    /// Only the regressor layouts with cheap per-row evaluation are handled
    /// directly; everything else falls through to the batch path.
    #[inline(always)]
    fn predict_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> f64 {
        match self {
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            ),
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>()
            }
            // NOTE(review): this fallback predicts the ENTIRE matrix with a
            // fresh single-threaded executor just to index one row — O(n_rows)
            // per call, quadratic if a caller invokes it per row. Confirm the
            // fallback variants (e.g. ForestRegressor) cannot reach this on a
            // per-row hot path.
            _ => self.predict_column_major_matrix(
                matrix,
                &InferenceExecutor::new(1).expect("inference executor"),
            )[row_index],
        }
    }
1093
    /// Computes the probability row for one row of a column-major binned
    /// matrix. Mirrors [`Self::predict_proba_table_row`] but reads bins from
    /// the materialized matrix instead of a `TableAccess`.
    #[inline(always)]
    fn predict_proba_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            )
            .to_vec()),
            // Forest: element-wise mean of the member trees' probability rows.
            // NOTE(review): `trees[0]` assumes a non-empty ensemble.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals =
                    trees[0].predict_proba_binned_row_from_columns(matrix, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_binned_row_from_columns(matrix, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            // Boosted binary: weighted raw-score sum squashed through sigmoid.
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1165}
1166
1167#[inline(always)]
1168fn predict_standard_classifier_probabilities_row<F>(
1169    nodes: &[OptimizedClassifierNode],
1170    root: usize,
1171    bin_at: F,
1172) -> &[f64]
1173where
1174    F: Fn(usize) -> u16,
1175{
1176    let mut node_index = root;
1177    loop {
1178        match &nodes[node_index] {
1179            OptimizedClassifierNode::Leaf(value) => return value,
1180            OptimizedClassifierNode::Binary {
1181                feature_index,
1182                threshold_bin,
1183                children,
1184            } => {
1185                let go_right = usize::from(bin_at(*feature_index) > *threshold_bin);
1186                node_index = children[go_right];
1187            }
1188            OptimizedClassifierNode::Multiway {
1189                feature_index,
1190                child_lookup,
1191                max_bin_index,
1192                fallback_probabilities,
1193            } => {
1194                let bin = usize::from(bin_at(*feature_index));
1195                if bin > *max_bin_index {
1196                    return fallback_probabilities;
1197                }
1198                let child_index = child_lookup[bin];
1199                if child_index == usize::MAX {
1200                    return fallback_probabilities;
1201                }
1202                node_index = child_index;
1203            }
1204        }
1205    }
1206}
1207
1208#[inline(always)]
1209fn predict_binary_classifier_probabilities_row<F>(
1210    nodes: &[OptimizedBinaryClassifierNode],
1211    bin_at: F,
1212) -> &[f64]
1213where
1214    F: Fn(usize) -> u16,
1215{
1216    let mut node_index = 0usize;
1217    loop {
1218        match &nodes[node_index] {
1219            OptimizedBinaryClassifierNode::Leaf(value) => return value,
1220            OptimizedBinaryClassifierNode::Branch {
1221                feature_index,
1222                threshold_bin,
1223                jump_index,
1224                jump_if_greater,
1225            } => {
1226                let go_right = bin_at(*feature_index) > *threshold_bin;
1227                node_index = if go_right == *jump_if_greater {
1228                    *jump_index
1229                } else {
1230                    node_index + 1
1231                };
1232            }
1233        }
1234    }
1235}
1236
1237#[inline(always)]
1238fn predict_binary_regressor_row<F>(nodes: &[OptimizedBinaryRegressorNode], bin_at: F) -> f64
1239where
1240    F: Fn(usize) -> u16,
1241{
1242    let mut node_index = 0usize;
1243    loop {
1244        match &nodes[node_index] {
1245            OptimizedBinaryRegressorNode::Leaf(value) => return *value,
1246            OptimizedBinaryRegressorNode::Branch {
1247                feature_index,
1248                threshold_bin,
1249                jump_index,
1250                jump_if_greater,
1251            } => {
1252                let go_right = bin_at(*feature_index) > *threshold_bin;
1253                node_index = if go_right == *jump_if_greater {
1254                    *jump_index
1255                } else {
1256                    node_index + 1
1257                };
1258            }
1259        }
1260    }
1261}
1262
1263fn predict_binary_classifier_probabilities_column_major_matrix(
1264    nodes: &[OptimizedBinaryClassifierNode],
1265    matrix: &ColumnMajorBinnedMatrix,
1266    _executor: &InferenceExecutor,
1267) -> Vec<Vec<f64>> {
1268    (0..matrix.n_rows)
1269        .map(|row_index| {
1270            predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1271                matrix.column(feature_index).value_at(row_index)
1272            })
1273            .to_vec()
1274        })
1275        .collect()
1276}
1277
1278fn predict_binary_regressor_column_major_matrix(
1279    nodes: &[OptimizedBinaryRegressorNode],
1280    matrix: &ColumnMajorBinnedMatrix,
1281    executor: &InferenceExecutor,
1282) -> Vec<f64> {
1283    let mut outputs = vec![0.0; matrix.n_rows];
1284    executor.fill_chunks(
1285        &mut outputs,
1286        STANDARD_BATCH_INFERENCE_CHUNK_ROWS,
1287        |start_row, chunk| predict_binary_regressor_chunk(nodes, matrix, start_row, chunk),
1288    );
1289    outputs
1290}
1291
/// Evaluates a binary regressor layout over one chunk of matrix rows,
/// writing results into `output` (output[i] corresponds to matrix row
/// `start_row + i`).
///
/// Rather than walking the tree once per row, this walks each reachable node
/// once and partitions the chunk's rows between the branch's two successors,
/// so every node's column data is read contiguously for all rows that reach
/// it. `row_indices` is a permutation of chunk-local offsets; the explicit
/// stack holds (node, start, end) regions of that permutation still to visit.
fn predict_binary_regressor_chunk(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) {
    let mut row_indices: Vec<usize> = (0..output.len()).collect();
    let mut stack = vec![(0usize, 0usize, output.len())];

    while let Some((node_index, start, end)) = stack.pop() {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => {
                // Every row in this region reached the same leaf.
                for position in start..end {
                    output[row_indices[position]] = *value;
                }
            }
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
            } => {
                let fallthrough_index = node_index + 1;
                // Degenerate branch whose jump target IS the fallthrough:
                // no partition needed, all rows continue to the same node.
                if *jump_index == fallthrough_index {
                    stack.push((fallthrough_index, start, end));
                    continue;
                }

                // Partition [start, end) in place: rows that take the jump
                // are swapped to the tail, fallthrough rows stay at the head.
                let column = matrix.column(*feature_index);
                let mut partition = start;
                let mut jump_start = end;
                match column {
                    // Fast path: u8-backed column with a threshold that fits
                    // in u8, avoiding the generic value_at dispatch per row.
                    CompactBinnedColumn::U8(values) if *threshold_bin <= u16::from(u8::MAX) => {
                        let threshold = *threshold_bin as u8;
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = values[start_row + row_offset] > threshold;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                    _ => {
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = column.value_at(start_row + row_offset) > *threshold_bin;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                }

                // Only push non-empty regions.
                if jump_start < end {
                    stack.push((*jump_index, jump_start, end));
                }
                if start < jump_start {
                    stack.push((fallthrough_index, start, jump_start));
                }
            }
        }
    }
}
1363
/// Evaluates an oblivious regression tree for one row.
///
/// Every level of an oblivious tree applies the same split to all paths, so
/// the leaf index is simply the bit string of per-level comparison outcomes
/// (most significant bit = first level).
#[inline(always)]
fn predict_oblivious_row<F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    bin_at: F,
) -> f64
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins.iter())
        .fold(0usize, |index, (&feature, &threshold)| {
            (index << 1) | usize::from(bin_at(feature) > threshold)
        });
    leaf_values[leaf_index]
}
1381
/// Evaluates an oblivious classification tree for one row, returning a
/// borrowed per-class probability slice.
///
/// Uses the same bit-string leaf addressing as [`predict_oblivious_row`],
/// but each leaf stores a probability vector instead of a scalar.
#[inline(always)]
fn predict_oblivious_probabilities_row<'a, F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &'a [Vec<f64>],
    bin_at: F,
) -> &'a [f64]
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins.iter())
        .fold(0usize, |index, (&feature, &threshold)| {
            (index << 1) | usize::from(bin_at(feature) > threshold)
        });
    leaf_values[leaf_index].as_slice()
}
1399
/// Converts raw per-class sample counts into a probability distribution.
///
/// An all-zero count vector yields all zeros instead of dividing by zero.
fn normalized_probabilities_from_counts(class_counts: &[usize]) -> Vec<f64> {
    match class_counts.iter().sum::<usize>() {
        0 => vec![0.0; class_counts.len()],
        total => {
            let denominator = total as f64;
            class_counts
                .iter()
                .map(|&count| count as f64 / denominator)
                .collect()
        }
    }
}
1411
/// Maps a probability row to its class label: the class with the highest
/// probability wins, and exact ties resolve to the lowest class index.
///
/// Panics if `probabilities` is empty.
fn class_label_from_probabilities(probabilities: &[f64], class_labels: &[f64]) -> f64 {
    assert!(
        !probabilities.is_empty(),
        "classification probability row is non-empty"
    );
    let mut best_index = 0usize;
    for (index, probability) in probabilities.iter().enumerate().skip(1) {
        // `total_cmp` gives a total order even with NaNs present; requiring a
        // strict `Greater` keeps the earliest index on ties.
        if probability.total_cmp(&probabilities[best_index]) == std::cmp::Ordering::Greater {
            best_index = index;
        }
    }
    class_labels[best_index]
}
1425
/// Numerically stable logistic function.
///
/// The exponential argument is always non-positive (`-|value|`), so the
/// computation never overflows for large-magnitude inputs; the two branches
/// are algebraically identical to `1 / (1 + e^-x)`.
#[inline(always)]
fn sigmoid(value: f64) -> f64 {
    let bounded_exp = (-value.abs()).exp();
    if value >= 0.0 {
        1.0 / (1.0 + bounded_exp)
    } else {
        bounded_exp / (1.0 + bounded_exp)
    }
}
1436
1437fn classifier_nodes_are_binary_only(nodes: &[tree::classifier::TreeNode]) -> bool {
1438    nodes.iter().all(|node| {
1439        matches!(
1440            node,
1441            tree::classifier::TreeNode::Leaf { .. }
1442                | tree::classifier::TreeNode::BinarySplit { .. }
1443        )
1444    })
1445}
1446
1447fn classifier_node_sample_count(nodes: &[tree::classifier::TreeNode], node_index: usize) -> usize {
1448    match &nodes[node_index] {
1449        tree::classifier::TreeNode::Leaf { sample_count, .. }
1450        | tree::classifier::TreeNode::BinarySplit { sample_count, .. }
1451        | tree::classifier::TreeNode::MultiwaySplit { sample_count, .. } => *sample_count,
1452    }
1453}
1454
1455fn build_binary_classifier_layout(
1456    nodes: &[tree::classifier::TreeNode],
1457    root: usize,
1458    _class_labels: &[f64],
1459    feature_index_map: &[usize],
1460) -> Vec<OptimizedBinaryClassifierNode> {
1461    let mut layout = Vec::with_capacity(nodes.len());
1462    append_binary_classifier_node(nodes, root, &mut layout, feature_index_map);
1463    layout
1464}
1465
/// Recursively emits `node_index` (and its subtree) into the flat layout,
/// returning the slot the node was placed in.
///
/// Layout invariant: the heavier child is emitted immediately after its
/// parent (the "fallthrough"), so the common path is a sequential scan; the
/// lighter child is reached via an explicit `jump_index`. A placeholder leaf
/// is pushed first so the slot index is reserved before recursing.
fn append_binary_classifier_node(
    nodes: &[tree::classifier::TreeNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryClassifierNode>,
    feature_index_map: &[usize],
) -> usize {
    let current_index = layout.len();
    // Placeholder; overwritten below once the node kind is known.
    layout.push(OptimizedBinaryClassifierNode::Leaf(Vec::new()));

    match &nodes[node_index] {
        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
            layout[current_index] = OptimizedBinaryClassifierNode::Leaf(
                normalized_probabilities_from_counts(class_counts),
            );
        }
        tree::classifier::TreeNode::BinarySplit {
            feature_index,
            threshold_bin,
            left_child,
            right_child,
            ..
        } => {
            // Pick fallthrough vs jump child: the child with more training
            // samples goes inline so the hot path avoids the jump.
            // `jump_if_greater` records which comparison outcome jumps.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = classifier_node_sample_count(nodes, *left_child);
                let right_count = classifier_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index =
                append_binary_classifier_node(nodes, fallthrough_child, layout, feature_index_map);
            // The fallthrough child must occupy the very next slot.
            debug_assert_eq!(fallthrough_index, current_index + 1);
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_classifier_node(nodes, jump_child, layout, feature_index_map)
            };

            layout[current_index] = OptimizedBinaryClassifierNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
            };
        }
        tree::classifier::TreeNode::MultiwaySplit { .. } => {
            unreachable!("multiway nodes are filtered out before binary layout construction");
        }
    }

    current_index
}
1523
1524fn regressor_node_sample_count(
1525    nodes: &[tree::regressor::RegressionNode],
1526    node_index: usize,
1527) -> usize {
1528    match &nodes[node_index] {
1529        tree::regressor::RegressionNode::Leaf { sample_count, .. }
1530        | tree::regressor::RegressionNode::BinarySplit { sample_count, .. } => *sample_count,
1531    }
1532}
1533
1534fn build_binary_regressor_layout(
1535    nodes: &[tree::regressor::RegressionNode],
1536    root: usize,
1537    feature_index_map: &[usize],
1538) -> Vec<OptimizedBinaryRegressorNode> {
1539    let mut layout = Vec::with_capacity(nodes.len());
1540    append_binary_regressor_node(nodes, root, &mut layout, feature_index_map);
1541    layout
1542}
1543
/// Recursively emits `node_index` (and its subtree) into the flat layout,
/// returning the slot the node was placed in.
///
/// Same scheme as `append_binary_classifier_node`: the heavier child is
/// emitted immediately after its parent (fallthrough) and the lighter one is
/// reached via `jump_index`; a placeholder leaf reserves the slot first.
fn append_binary_regressor_node(
    nodes: &[tree::regressor::RegressionNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryRegressorNode>,
    feature_index_map: &[usize],
) -> usize {
    let current_index = layout.len();
    // Placeholder; overwritten below once the node kind is known.
    layout.push(OptimizedBinaryRegressorNode::Leaf(0.0));

    match &nodes[node_index] {
        tree::regressor::RegressionNode::Leaf { value, .. } => {
            layout[current_index] = OptimizedBinaryRegressorNode::Leaf(*value);
        }
        tree::regressor::RegressionNode::BinarySplit {
            feature_index,
            threshold_bin,
            left_child,
            right_child,
            ..
        } => {
            // The child with more training samples goes inline so the hot
            // path avoids the jump; `jump_if_greater` records which
            // comparison outcome jumps.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = regressor_node_sample_count(nodes, *left_child);
                let right_count = regressor_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index =
                append_binary_regressor_node(nodes, fallthrough_child, layout, feature_index_map);
            // The fallthrough child must occupy the very next slot.
            debug_assert_eq!(fallthrough_index, current_index + 1);
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_regressor_node(nodes, jump_child, layout, feature_index_map)
            };

            layout[current_index] = OptimizedBinaryRegressorNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
            };
        }
    }

    current_index
}
1596
1597fn predict_oblivious_column_major_matrix(
1598    feature_indices: &[usize],
1599    threshold_bins: &[u16],
1600    leaf_values: &[f64],
1601    matrix: &ColumnMajorBinnedMatrix,
1602    executor: &InferenceExecutor,
1603) -> Vec<f64> {
1604    let mut outputs = vec![0.0; matrix.n_rows];
1605    executor.fill_chunks(
1606        &mut outputs,
1607        PARALLEL_INFERENCE_CHUNK_ROWS,
1608        |start_row, chunk| {
1609            predict_oblivious_chunk(
1610                feature_indices,
1611                threshold_bins,
1612                leaf_values,
1613                matrix,
1614                start_row,
1615                chunk,
1616            )
1617        },
1618    );
1619    outputs
1620}
1621
1622fn predict_oblivious_probabilities_column_major_matrix(
1623    feature_indices: &[usize],
1624    threshold_bins: &[u16],
1625    leaf_values: &[Vec<f64>],
1626    matrix: &ColumnMajorBinnedMatrix,
1627    _executor: &InferenceExecutor,
1628) -> Vec<Vec<f64>> {
1629    (0..matrix.n_rows)
1630        .map(|row_index| {
1631            predict_oblivious_probabilities_row(
1632                feature_indices,
1633                threshold_bins,
1634                leaf_values,
1635                |feature_index| matrix.column(feature_index).value_at(row_index),
1636            )
1637            .to_vec()
1638        })
1639        .collect()
1640}
1641
1642fn predict_oblivious_chunk(
1643    feature_indices: &[usize],
1644    threshold_bins: &[u16],
1645    leaf_values: &[f64],
1646    matrix: &ColumnMajorBinnedMatrix,
1647    start_row: usize,
1648    output: &mut [f64],
1649) {
1650    let processed = simd_predict_oblivious_chunk(
1651        feature_indices,
1652        threshold_bins,
1653        leaf_values,
1654        matrix,
1655        start_row,
1656        output,
1657    );
1658
1659    for (offset, out) in output.iter_mut().enumerate().skip(processed) {
1660        let row_index = start_row + offset;
1661        *out = predict_oblivious_row(
1662            feature_indices,
1663            threshold_bins,
1664            leaf_values,
1665            |feature_index| matrix.column(feature_index).value_at(row_index),
1666        );
1667    }
1668}
1669
1670fn simd_predict_oblivious_chunk(
1671    feature_indices: &[usize],
1672    threshold_bins: &[u16],
1673    leaf_values: &[f64],
1674    matrix: &ColumnMajorBinnedMatrix,
1675    start_row: usize,
1676    output: &mut [f64],
1677) -> usize {
1678    let mut processed = 0usize;
1679    let ones = u32x8::splat(1);
1680
1681    while processed + OBLIVIOUS_SIMD_LANES <= output.len() {
1682        let base_row = start_row + processed;
1683        let mut leaf_indices = u32x8::splat(0);
1684
1685        for (&feature_index, &threshold_bin) in feature_indices.iter().zip(threshold_bins) {
1686            let column = matrix.column(feature_index);
1687            let bins = if let Some(lanes) = column.slice_u8(base_row, OBLIVIOUS_SIMD_LANES) {
1688                let lanes: [u8; OBLIVIOUS_SIMD_LANES] = lanes
1689                    .try_into()
1690                    .expect("lane width matches the fixed SIMD width");
1691                u32x8::new([
1692                    u32::from(lanes[0]),
1693                    u32::from(lanes[1]),
1694                    u32::from(lanes[2]),
1695                    u32::from(lanes[3]),
1696                    u32::from(lanes[4]),
1697                    u32::from(lanes[5]),
1698                    u32::from(lanes[6]),
1699                    u32::from(lanes[7]),
1700                ])
1701            } else {
1702                let lanes: [u16; OBLIVIOUS_SIMD_LANES] = column
1703                    .slice_u16(base_row, OBLIVIOUS_SIMD_LANES)
1704                    .expect("column is u16 when not u8")
1705                    .try_into()
1706                    .expect("lane width matches the fixed SIMD width");
1707                u32x8::from(u16x8::new(lanes))
1708            };
1709            let threshold = u32x8::splat(u32::from(threshold_bin));
1710            let bit = bins.cmp_gt(threshold) & ones;
1711            leaf_indices = (leaf_indices << 1) | bit;
1712        }
1713
1714        let lane_indices = leaf_indices.to_array();
1715        for lane in 0..OBLIVIOUS_SIMD_LANES {
1716            output[processed + lane] =
1717                leaf_values[usize::try_from(lane_indices[lane]).expect("leaf index fits usize")];
1718        }
1719        processed += OBLIVIOUS_SIMD_LANES;
1720    }
1721
1722    processed
1723}
1724
/// Trains a [`Model`] on `train_set` using the supplied configuration.
///
/// This is the public entry point; it delegates unchanged to the internal
/// `training::train` routine.
///
/// # Errors
///
/// Returns a [`TrainError`] when training fails, propagated as-is from the
/// internal training routine.
pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
    training::train(train_set, config)
}
1728
// Unit tests live in the sibling `tests` module, compiled only for test
// builds.
#[cfg(test)]
mod tests;