Skip to main content

forestfire_core/
lib.rs

1//! ForestFire core model, training, inference, and interchange layer.
2//!
3//! The crate is organized around a few stable abstractions:
4//!
5//! - [`forestfire_data::TableAccess`] is the common data boundary for both
6//!   training and inference.
7//! - [`TrainConfig`] is the normalized configuration surface shared by the Rust
8//!   and Python APIs.
9//! - [`Model`] is the semantic model view used for exact prediction,
10//!   serialization, and introspection.
11//! - [`OptimizedModel`] is a lowered runtime view used when prediction speed
12//!   matters more than preserving the original tree layout.
13//!
14//! Keeping the semantic model and the runtime model separate is deliberate. It
15//! makes export and introspection straightforward while still allowing the
16//! optimized path to use layouts that are awkward to serialize directly.
17
18use forestfire_data::{
19    BinnedColumnKind, MAX_NUMERIC_BINS, NumericBins, TableAccess, numeric_bin_boundaries,
20    numeric_missing_bin,
21};
22#[cfg(feature = "polars")]
23use polars::prelude::{Column, DataFrame, DataType, IdxSize, LazyFrame};
24use rayon::ThreadPoolBuilder;
25use rayon::prelude::*;
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use std::collections::{BTreeMap, BTreeSet};
29use std::error::Error;
30use std::fmt::{Display, Formatter};
31use std::sync::Arc;
32use wide::{u16x8, u32x8};
33
34mod boosting;
35mod bootstrap;
36mod compiled_artifact;
37mod forest;
38mod inference_input;
39mod introspection;
40pub mod ir;
41mod model_api;
42mod optimized_runtime;
43mod runtime_planning;
44mod sampling;
45mod training;
46pub mod tree;
47
48pub use boosting::BoostingError;
49pub use boosting::GradientBoostedTrees;
50pub use compiled_artifact::CompiledArtifactError;
51pub use forest::RandomForest;
52pub use introspection::IntrospectionError;
53pub use introspection::PredictionHistogramEntry;
54pub use introspection::PredictionValueStats;
55pub use introspection::TreeStructureSummary;
56pub use ir::IrError;
57pub use ir::ModelPackageIr;
58pub use model_api::OptimizedModel;
59pub use tree::classifier::DecisionTreeAlgorithm;
60pub use tree::classifier::DecisionTreeClassifier;
61pub use tree::classifier::DecisionTreeError;
62pub use tree::classifier::DecisionTreeOptions;
63pub use tree::classifier::train_c45;
64pub use tree::classifier::train_cart;
65pub use tree::classifier::train_id3;
66pub use tree::classifier::train_oblivious;
67pub use tree::classifier::train_randomized;
68pub use tree::regressor::DecisionTreeRegressor;
69pub use tree::regressor::RegressionTreeAlgorithm;
70pub use tree::regressor::RegressionTreeError;
71pub use tree::regressor::RegressionTreeOptions;
72pub use tree::regressor::train_cart_regressor;
73pub use tree::regressor::train_oblivious_regressor;
74pub use tree::regressor::train_randomized_regressor;
/// Row batch size, presumably applied when predicting from a polars
/// `LazyFrame` (inferred from the name; the consumer is not in this part of
/// the file — confirm there). Only compiled with the `polars` feature.
#[cfg(feature = "polars")]
const LAZYFRAME_PREDICT_BATCH_ROWS: usize = 10_000;
77pub(crate) use inference_input::ColumnMajorBinnedMatrix;
78pub(crate) use inference_input::CompactBinnedColumn;
79pub(crate) use inference_input::InferenceTable;
80pub(crate) use inference_input::ProjectedTableView;
81#[cfg(feature = "polars")]
82pub(crate) use inference_input::polars_named_columns;
83pub(crate) use introspection::prediction_value_stats;
84pub(crate) use introspection::tree_structure_summary;
85pub(crate) use optimized_runtime::InferenceExecutor;
86pub(crate) use optimized_runtime::OBLIVIOUS_SIMD_LANES;
87pub(crate) use optimized_runtime::OptimizedBinaryClassifierNode;
88pub(crate) use optimized_runtime::OptimizedBinaryRegressorNode;
89pub(crate) use optimized_runtime::OptimizedClassifierNode;
90pub(crate) use optimized_runtime::OptimizedRuntime;
91pub(crate) use optimized_runtime::PARALLEL_INFERENCE_CHUNK_ROWS;
92pub(crate) use optimized_runtime::STANDARD_BATCH_INFERENCE_CHUNK_ROWS;
93pub(crate) use optimized_runtime::resolve_inference_thread_count;
94pub(crate) use runtime_planning::build_feature_index_map;
95pub(crate) use runtime_planning::build_feature_projection;
96pub(crate) use runtime_planning::model_used_feature_indices;
97pub(crate) use runtime_planning::ordered_ensemble_indices;
98pub(crate) use runtime_planning::remap_feature_index;
99
/// High-level training family, selected through [`TrainConfig::algorithm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainAlgorithm {
    /// Train a single decision tree.
    Dt,
    /// Train a bootstrap-aggregated random forest.
    Rf,
    /// Train a second-order gradient-boosted ensemble.
    Gbm,
}

/// Split criterion, selected through [`TrainConfig::criterion`].
///
/// The variants target different tasks: `Gini` and `Entropy` are
/// classification criteria, `Mean` and `Median` are regression criteria, and
/// `SecondOrder` is internal to gradient boosting. `Auto` leaves the choice to
/// the trainer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Criterion {
    /// Let training choose the appropriate criterion for the requested setup.
    Auto,
    /// Gini impurity for classification.
    Gini,
    /// Entropy / information gain for classification.
    Entropy,
    /// Mean-based regression criterion.
    Mean,
    /// Median-based regression criterion.
    Median,
    /// Internal second-order criterion used by gradient boosting.
    SecondOrder,
}
125
/// Learning task: continuous targets or discrete labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    /// Predict a continuous numeric value.
    Regression,
    /// Predict one label from a finite set.
    Classification,
}

/// Tree learner used inside the selected [`TrainAlgorithm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeType {
    /// Multiway information-gain tree.
    Id3,
    /// C4.5-style multiway tree.
    C45,
    /// Standard binary threshold tree.
    Cart,
    /// CART-style tree with randomized candidate selection.
    Randomized,
    /// Symmetric tree where every level shares the same split.
    Oblivious,
}

/// Per-split feature subsampling strategy.
///
/// Convert it to a concrete feature count with [`MaxFeatures::resolve`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxFeatures {
    /// Task-aware default: `sqrt` for classification, `third` for regression.
    Auto,
    /// Use all features at each split.
    All,
    /// Use `floor(sqrt(feature_count))` features.
    Sqrt,
    /// Use roughly one third of the features.
    Third,
    /// Use exactly this many features, capped to the available count.
    Count(usize),
}

impl MaxFeatures {
    /// Turns the strategy into the number of features to consider per split.
    ///
    /// `Auto` behaves like `Sqrt` for classification and like `Third` for
    /// regression. `Count(n)` is capped at `feature_count`. Every result is
    /// clamped to at least 1, including when `feature_count` is 0.
    pub fn resolve(self, task: Task, feature_count: usize) -> usize {
        let sqrt_count = (feature_count as f64).sqrt().floor() as usize;
        let third_count = feature_count / 3;

        let unclamped = match self {
            MaxFeatures::Auto => match task {
                Task::Classification => sqrt_count,
                Task::Regression => third_count,
            },
            MaxFeatures::All => feature_count,
            MaxFeatures::Sqrt => sqrt_count,
            MaxFeatures::Third => third_count,
            MaxFeatures::Count(requested) => requested.min(feature_count),
        };

        // Never hand a split search an empty feature subset.
        unclamped.max(1)
    }
}
176
/// Declared kind of an input feature.
///
/// Serialized in snake_case, i.e. as `"numeric"` or `"binary"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputFeatureKind {
    /// Numeric features are compared through their binned representation.
    Numeric,
    /// Binary features stay boolean all the way through the pipeline.
    Binary,
}

// One entry of a numeric feature's bin table: a bin id plus the inclusive
// upper bound of that bin in raw feature units. `capture_feature_preprocessing`
// builds these from the `(bin, upper_bound)` pairs of `numeric_bin_boundaries`.
// Plain `//` comments on purpose: this type derives `JsonSchema`, and schemars
// turns `///` doc comments into schema descriptions.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinBoundary {
    /// Bin identifier in the preprocessed feature space.
    pub bin: u16,
    /// Largest raw floating-point value that still belongs to this bin.
    pub upper_bound: f64,
}

/// Per-feature preprocessing metadata.
///
/// Serialized as an internally tagged object whose `kind` field is
/// `"numeric"` or `"binary"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeaturePreprocessing {
    /// Numeric features are represented by explicit bin boundaries.
    Numeric {
        /// Bin table for the feature.
        bin_boundaries: Vec<NumericBinBoundary>,
        /// Bin id used for missing values; the optimized runtime reads it back
        /// when missing-value handling is enabled for this feature.
        missing_bin: u16,
    },
    /// Binary features do not require numeric bin boundaries.
    Binary,
}
205
206#[derive(Debug, Clone, Copy, PartialEq, Eq)]
207pub enum MissingValueStrategy {
208    Heuristic,
209    Optimal,
210}
211
212#[derive(Debug, Clone, PartialEq, Eq)]
213pub enum MissingValueStrategyConfig {
214    Global(MissingValueStrategy),
215    PerFeature(BTreeMap<usize, MissingValueStrategy>),
216}
217
218impl MissingValueStrategyConfig {
219    pub fn heuristic() -> Self {
220        Self::Global(MissingValueStrategy::Heuristic)
221    }
222
223    pub fn optimal() -> Self {
224        Self::Global(MissingValueStrategy::Optimal)
225    }
226
227    pub fn resolve_for_feature_count(
228        &self,
229        feature_count: usize,
230    ) -> Result<Vec<MissingValueStrategy>, TrainError> {
231        match self {
232            MissingValueStrategyConfig::Global(strategy) => Ok(vec![*strategy; feature_count]),
233            MissingValueStrategyConfig::PerFeature(strategies) => {
234                let mut resolved = vec![MissingValueStrategy::Heuristic; feature_count];
235                for (&feature_index, &strategy) in strategies {
236                    if feature_index >= feature_count {
237                        return Err(TrainError::InvalidMissingValueStrategyFeature {
238                            feature_index,
239                            feature_count,
240                        });
241                    }
242                    resolved[feature_index] = strategy;
243                }
244                Ok(resolved)
245            }
246        }
247    }
248}
249
250/// Unified training configuration shared by the Rust and Python entry points.
251///
252/// The crate keeps one normalized config type so the binding layer only has to
253/// perform input validation and type conversion; all semantic decisions happen
254/// from this one structure downward.
255#[derive(Debug, Clone)]
256pub struct TrainConfig {
257    /// High-level training family.
258    pub algorithm: TrainAlgorithm,
259    /// Regression or classification.
260    pub task: Task,
261    /// Tree learner used by the selected algorithm family.
262    pub tree_type: TreeType,
263    /// Split criterion. [`Criterion::Auto`] is resolved by the trainer.
264    pub criterion: Criterion,
265    /// Maximum tree depth.
266    pub max_depth: Option<usize>,
267    /// Smallest node size that is still allowed to split.
268    pub min_samples_split: Option<usize>,
269    /// Minimum child size after a split.
270    pub min_samples_leaf: Option<usize>,
271    /// Optional cap on training-side rayon threads.
272    pub physical_cores: Option<usize>,
273    /// Number of trees for ensemble algorithms.
274    pub n_trees: Option<usize>,
275    /// Feature subsampling strategy.
276    pub max_features: MaxFeatures,
277    /// Seed used for reproducible sampling and randomized splits.
278    pub seed: Option<u64>,
279    /// Whether random forests should compute out-of-bag metrics.
280    pub compute_oob: bool,
281    /// Gradient boosting shrinkage factor.
282    pub learning_rate: Option<f64>,
283    /// Whether gradient boosting should bootstrap rows before gradient sampling.
284    pub bootstrap: bool,
285    /// Fraction of largest-gradient rows always kept by GOSS sampling.
286    pub top_gradient_fraction: Option<f64>,
287    /// Fraction of the remaining rows randomly retained by GOSS sampling.
288    pub other_gradient_fraction: Option<f64>,
289    /// Strategy used to evaluate missing-value routing during split search.
290    pub missing_value_strategy: MissingValueStrategyConfig,
291    /// Optional numeric histogram bin configuration for training-time split search.
292    ///
293    /// `None` preserves the incoming table's existing numeric bins. `Some(...)`
294    /// rebuilds the numeric training view at the requested resolution before
295    /// fitting, while leaving the caller's source table unchanged.
296    pub histogram_bins: Option<NumericBins>,
297}
298
299impl Default for TrainConfig {
300    fn default() -> Self {
301        Self {
302            algorithm: TrainAlgorithm::Dt,
303            task: Task::Regression,
304            tree_type: TreeType::Cart,
305            criterion: Criterion::Auto,
306            max_depth: None,
307            min_samples_split: None,
308            min_samples_leaf: None,
309            physical_cores: None,
310            n_trees: None,
311            max_features: MaxFeatures::Auto,
312            seed: None,
313            compute_oob: false,
314            learning_rate: None,
315            bootstrap: false,
316            top_gradient_fraction: None,
317            other_gradient_fraction: None,
318            missing_value_strategy: MissingValueStrategyConfig::heuristic(),
319            histogram_bins: None,
320        }
321    }
322}
323
/// Top-level semantic model enum.
///
/// This type stays close to the learned structure rather than the fastest
/// possible runtime layout. That is what makes it suitable for introspection,
/// serialization, and exact behavior parity across bindings.
#[derive(Debug, Clone)]
pub enum Model {
    /// A single classification tree.
    DecisionTreeClassifier(DecisionTreeClassifier),
    /// A single regression tree.
    DecisionTreeRegressor(DecisionTreeRegressor),
    /// A random forest; its own `task()` says whether it classifies or regresses.
    RandomForest(RandomForest),
    /// A gradient-boosted ensemble; its own `task()` says whether it classifies
    /// or regresses.
    GradientBoostedTrees(GradientBoostedTrees),
}
336
337#[derive(Debug)]
338pub enum TrainError {
339    DecisionTree(DecisionTreeError),
340    RegressionTree(RegressionTreeError),
341    Boosting(BoostingError),
342    InvalidPhysicalCoreCount {
343        requested: usize,
344        available: usize,
345    },
346    ThreadPoolBuildFailed(String),
347    UnsupportedConfiguration {
348        task: Task,
349        tree_type: TreeType,
350        criterion: Criterion,
351    },
352    InvalidMaxDepth(usize),
353    InvalidMinSamplesSplit(usize),
354    InvalidMinSamplesLeaf(usize),
355    InvalidTreeCount(usize),
356    InvalidMaxFeatures(usize),
357    InvalidMissingValueStrategyFeature {
358        feature_index: usize,
359        feature_count: usize,
360    },
361}
362
363impl Display for TrainError {
364    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
365        match self {
366            TrainError::DecisionTree(err) => err.fmt(f),
367            TrainError::RegressionTree(err) => err.fmt(f),
368            TrainError::Boosting(err) => err.fmt(f),
369            TrainError::InvalidPhysicalCoreCount {
370                requested,
371                available,
372            } => write!(
373                f,
374                "Requested {} physical cores, but the available physical core count is {}.",
375                requested, available
376            ),
377            TrainError::ThreadPoolBuildFailed(message) => {
378                write!(f, "Failed to build training thread pool: {}.", message)
379            }
380            TrainError::UnsupportedConfiguration {
381                task,
382                tree_type,
383                criterion,
384            } => write!(
385                f,
386                "Unsupported training configuration: task={:?}, tree_type={:?}, criterion={:?}.",
387                task, tree_type, criterion
388            ),
389            TrainError::InvalidMaxDepth(value) => {
390                write!(f, "max_depth must be at least 1. Received {}.", value)
391            }
392            TrainError::InvalidMinSamplesSplit(value) => {
393                write!(
394                    f,
395                    "min_samples_split must be at least 1. Received {}.",
396                    value
397                )
398            }
399            TrainError::InvalidMinSamplesLeaf(value) => {
400                write!(
401                    f,
402                    "min_samples_leaf must be at least 1. Received {}.",
403                    value
404                )
405            }
406            TrainError::InvalidTreeCount(n_trees) => {
407                write!(
408                    f,
409                    "Random forest requires at least one tree. Received {}.",
410                    n_trees
411                )
412            }
413            TrainError::InvalidMaxFeatures(count) => {
414                write!(
415                    f,
416                    "max_features must be at least 1 when provided as an integer. Received {}.",
417                    count
418                )
419            }
420            TrainError::InvalidMissingValueStrategyFeature {
421                feature_index,
422                feature_count,
423            } => write!(
424                f,
425                "missing_value_strategy references feature {}, but the training table only has {} features.",
426                feature_index, feature_count
427            ),
428        }
429    }
430}
431
432impl Error for TrainError {}
433
/// Errors reported while validating inference input or running prediction.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictError {
    /// `predict_proba` was requested from a model that is not a classifier.
    ProbabilityPredictionRequiresClassification,
    /// Input row `row` has `actual` columns instead of `expected`.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// The input supplies `actual` features but the model expects `expected`.
    FeatureCountMismatch { expected: usize, actual: usize },
    /// Column `feature` has `actual` values instead of `expected`.
    ColumnLengthMismatch {
        feature: String,
        expected: usize,
        actual: usize,
    },
    /// A feature the model requires is absent from the input.
    MissingFeature(String),
    /// The input contains a feature the model does not know.
    UnexpectedFeature(String),
    /// A binary feature received a value that is not binary.
    InvalidBinaryValue {
        feature_index: usize,
        row_index: usize,
        value: f64,
    },
    /// A feature contains a null at `row_index`.
    NullValue { feature: String, row_index: usize },
    /// A feature column has a dtype inference cannot consume.
    UnsupportedFeatureType { feature: String, dtype: String },
    /// A polars operation failed; carries the rendered polars error.
    Polars(String),
}

impl Display for PredictError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictError::ProbabilityPredictionRequiresClassification => {
                f.write_str("predict_proba is only available for classification models.")
            }
            PredictError::RaggedRows {
                row,
                expected,
                actual,
            } => write!(
                f,
                "Ragged inference row at index {row}: expected {expected} columns, found {actual}."
            ),
            PredictError::FeatureCountMismatch { expected, actual } => write!(
                f,
                "Inference input has {actual} features, but the model expects {expected}."
            ),
            PredictError::ColumnLengthMismatch {
                feature,
                expected,
                actual,
            } => write!(
                f,
                "Feature '{feature}' has {actual} values, expected {expected}."
            ),
            PredictError::MissingFeature(name) => {
                write!(f, "Missing required feature '{name}'.")
            }
            PredictError::UnexpectedFeature(name) => write!(f, "Unexpected feature '{name}'."),
            PredictError::InvalidBinaryValue {
                feature_index,
                row_index,
                value,
            } => write!(
                f,
                "Feature {feature_index} at row {row_index} must be binary for inference, found {value}."
            ),
            PredictError::NullValue { feature, row_index } => write!(
                f,
                "Feature '{feature}' contains a null value at row {row_index}."
            ),
            PredictError::UnsupportedFeatureType { feature, dtype } => write!(
                f,
                "Feature '{feature}' has unsupported dtype '{dtype}'."
            ),
            PredictError::Polars(message) => write!(f, "Polars inference failed: {message}."),
        }
    }
}

impl Error for PredictError {}
530
#[cfg(feature = "polars")]
impl From<polars::error::PolarsError> for PredictError {
    /// Wraps a polars failure, keeping only its rendered message.
    fn from(err: polars::error::PolarsError) -> Self {
        Self::Polars(err.to_string())
    }
}
537
/// Errors reported while preparing optimized inference.
#[derive(Debug)]
pub enum OptimizeError {
    /// More physical cores were requested than are available.
    InvalidPhysicalCoreCount { requested: usize, available: usize },
    /// The inference thread pool could not be built; carries the underlying message.
    ThreadPoolBuildFailed(String),
    /// Optimized inference does not support the named model type.
    UnsupportedModelType(&'static str),
}

impl Display for OptimizeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            OptimizeError::InvalidPhysicalCoreCount {
                requested,
                available,
            } => write!(
                f,
                "Requested {requested} physical cores, but the available physical core count is {available}."
            ),
            OptimizeError::ThreadPoolBuildFailed(message) => {
                write!(f, "Failed to build inference thread pool: {message}.")
            }
            OptimizeError::UnsupportedModelType(model_type) => write!(
                f,
                "Optimized inference is not supported for model type '{model_type}'."
            ),
        }
    }
}

impl Error for OptimizeError {}
571
/// Crate-internal thread budget.
///
/// A budget of one thread means sequential execution; `enabled` reports
/// whether more than one thread is available.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Parallelism {
    thread_count: usize,
}

impl Parallelism {
    /// A single-thread budget.
    pub(crate) fn sequential() -> Self {
        Parallelism { thread_count: 1 }
    }

    /// Test-only constructor; a request for zero threads becomes one thread.
    #[cfg(test)]
    pub(crate) fn with_threads(thread_count: usize) -> Self {
        let thread_count = if thread_count == 0 { 1 } else { thread_count };
        Parallelism { thread_count }
    }

    /// `true` when the budget allows more than one thread.
    pub(crate) fn enabled(self) -> bool {
        !matches!(self.thread_count, 0 | 1)
    }
}
593
594pub(crate) fn capture_feature_preprocessing(table: &dyn TableAccess) -> Vec<FeaturePreprocessing> {
595    (0..table.n_features())
596        .map(|feature_index| {
597            if table.is_binary_feature(feature_index) {
598                FeaturePreprocessing::Binary
599            } else {
600                let values = (0..table.n_rows())
601                    .map(|row_index| table.feature_value(feature_index, row_index))
602                    .collect::<Vec<_>>();
603                FeaturePreprocessing::Numeric {
604                    bin_boundaries: numeric_bin_boundaries(
605                        &values,
606                        NumericBins::Fixed(table.numeric_bin_cap()),
607                    )
608                    .into_iter()
609                    .map(|(bin, upper_bound)| NumericBinBoundary { bin, upper_bound })
610                    .collect(),
611                    missing_bin: numeric_missing_bin(NumericBins::Fixed(table.numeric_bin_cap())),
612                }
613            }
614        })
615        .collect()
616}
617
/// Reports whether the optimized runtime should emit missing-value handling
/// for `feature_index`.
///
/// `None` means there is no restriction, so every feature is enabled.
/// `Some(set)` enables only the feature indices contained in `set`.
fn missing_feature_enabled(
    feature_index: usize,
    missing_features: Option<&BTreeSet<usize>>,
) -> bool {
    match missing_features {
        Some(enabled) => enabled.contains(&feature_index),
        None => true,
    }
}
624
625fn optimized_missing_bin(
626    preprocessing: &[FeaturePreprocessing],
627    feature_index: usize,
628    missing_features: Option<&BTreeSet<usize>>,
629) -> Option<u16> {
630    if !missing_feature_enabled(feature_index, missing_features) {
631        return None;
632    }
633
634    match preprocessing.get(feature_index) {
635        Some(FeaturePreprocessing::Binary) => Some(forestfire_data::BINARY_MISSING_BIN),
636        Some(FeaturePreprocessing::Numeric { missing_bin, .. }) => Some(*missing_bin),
637        None => None,
638    }
639}
640
641impl OptimizedRuntime {
642    fn supports_batch_matrix(&self) -> bool {
643        matches!(
644            self,
645            OptimizedRuntime::BinaryClassifier { .. }
646                | OptimizedRuntime::BinaryRegressor { .. }
647                | OptimizedRuntime::ObliviousClassifier { .. }
648                | OptimizedRuntime::ObliviousRegressor { .. }
649                | OptimizedRuntime::ForestClassifier { .. }
650                | OptimizedRuntime::ForestRegressor { .. }
651                | OptimizedRuntime::BoostedBinaryClassifier { .. }
652                | OptimizedRuntime::BoostedRegressor { .. }
653        )
654    }
655
656    fn should_use_batch_matrix(&self, n_rows: usize) -> bool {
657        n_rows > 1 && self.supports_batch_matrix()
658    }
659
660    fn from_model(
661        model: &Model,
662        feature_index_map: &[usize],
663        missing_features: Option<&BTreeSet<usize>>,
664    ) -> Self {
665        match model {
666            Model::DecisionTreeClassifier(classifier) => {
667                Self::from_classifier(classifier, feature_index_map, missing_features)
668            }
669            Model::DecisionTreeRegressor(regressor) => {
670                Self::from_regressor(regressor, feature_index_map, missing_features)
671            }
672            Model::RandomForest(forest) => match forest.task() {
673                Task::Classification => {
674                    let tree_order = ordered_ensemble_indices(forest.trees());
675                    Self::ForestClassifier {
676                        trees: tree_order
677                            .into_iter()
678                            .map(|tree_index| {
679                                Self::from_model(
680                                    &forest.trees()[tree_index],
681                                    feature_index_map,
682                                    missing_features,
683                                )
684                            })
685                            .collect(),
686                        class_labels: forest
687                            .class_labels()
688                            .expect("classification forest stores class labels"),
689                    }
690                }
691                Task::Regression => {
692                    let tree_order = ordered_ensemble_indices(forest.trees());
693                    Self::ForestRegressor {
694                        trees: tree_order
695                            .into_iter()
696                            .map(|tree_index| {
697                                Self::from_model(
698                                    &forest.trees()[tree_index],
699                                    feature_index_map,
700                                    missing_features,
701                                )
702                            })
703                            .collect(),
704                    }
705                }
706            },
707            Model::GradientBoostedTrees(model) => match model.task() {
708                Task::Classification => {
709                    let tree_order = ordered_ensemble_indices(model.trees());
710                    Self::BoostedBinaryClassifier {
711                        trees: tree_order
712                            .iter()
713                            .map(|tree_index| {
714                                Self::from_model(
715                                    &model.trees()[*tree_index],
716                                    feature_index_map,
717                                    missing_features,
718                                )
719                            })
720                            .collect(),
721                        tree_weights: tree_order
722                            .iter()
723                            .map(|tree_index| model.tree_weights()[*tree_index])
724                            .collect(),
725                        base_score: model.base_score(),
726                        class_labels: model
727                            .class_labels()
728                            .expect("classification boosting stores class labels"),
729                    }
730                }
731                Task::Regression => {
732                    let tree_order = ordered_ensemble_indices(model.trees());
733                    Self::BoostedRegressor {
734                        trees: tree_order
735                            .iter()
736                            .map(|tree_index| {
737                                Self::from_model(
738                                    &model.trees()[*tree_index],
739                                    feature_index_map,
740                                    missing_features,
741                                )
742                            })
743                            .collect(),
744                        tree_weights: tree_order
745                            .iter()
746                            .map(|tree_index| model.tree_weights()[*tree_index])
747                            .collect(),
748                        base_score: model.base_score(),
749                    }
750                }
751            },
752        }
753    }
754
755    fn from_classifier(
756        classifier: &DecisionTreeClassifier,
757        feature_index_map: &[usize],
758        missing_features: Option<&BTreeSet<usize>>,
759    ) -> Self {
760        match classifier.structure() {
761            tree::classifier::TreeStructure::Standard { nodes, root } => {
762                if classifier_nodes_are_binary_only(nodes) {
763                    return Self::BinaryClassifier {
764                        nodes: build_binary_classifier_layout(
765                            nodes,
766                            *root,
767                            classifier.class_labels(),
768                            feature_index_map,
769                            classifier.feature_preprocessing(),
770                            missing_features,
771                        ),
772                        class_labels: classifier.class_labels().to_vec(),
773                    };
774                }
775
776                let optimized_nodes = nodes
777                    .iter()
778                    .map(|node| match node {
779                        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
780                            OptimizedClassifierNode::Leaf(normalized_probabilities_from_counts(
781                                class_counts,
782                            ))
783                        }
784                        tree::classifier::TreeNode::BinarySplit {
785                            feature_index,
786                            threshold_bin,
787                            missing_direction,
788                            left_child,
789                            right_child,
790                            class_counts,
791                            ..
792                        } => OptimizedClassifierNode::Binary {
793                            feature_index: remap_feature_index(*feature_index, feature_index_map),
794                            threshold_bin: *threshold_bin,
795                            children: [*left_child, *right_child],
796                            missing_bin: optimized_missing_bin(
797                                classifier.feature_preprocessing(),
798                                *feature_index,
799                                missing_features,
800                            ),
801                            missing_child: if missing_feature_enabled(
802                                *feature_index,
803                                missing_features,
804                            ) {
805                                match missing_direction {
806                                    tree::shared::MissingBranchDirection::Left => Some(*left_child),
807                                    tree::shared::MissingBranchDirection::Right => {
808                                        Some(*right_child)
809                                    }
810                                    tree::shared::MissingBranchDirection::Node => None,
811                                }
812                            } else {
813                                None
814                            },
815                            missing_probabilities: if missing_feature_enabled(
816                                *feature_index,
817                                missing_features,
818                            ) && matches!(
819                                missing_direction,
820                                tree::shared::MissingBranchDirection::Node
821                            ) {
822                                Some(normalized_probabilities_from_counts(class_counts))
823                            } else {
824                                None
825                            },
826                        },
827                        tree::classifier::TreeNode::MultiwaySplit {
828                            feature_index,
829                            class_counts,
830                            branches,
831                            missing_child,
832                            ..
833                        } => {
834                            let max_bin_index = branches
835                                .iter()
836                                .map(|(bin, _)| usize::from(*bin))
837                                .max()
838                                .unwrap_or(0);
839                            let mut child_lookup = vec![usize::MAX; max_bin_index + 1];
840                            for (bin, child_index) in branches {
841                                child_lookup[usize::from(*bin)] = *child_index;
842                            }
843                            OptimizedClassifierNode::Multiway {
844                                feature_index: remap_feature_index(
845                                    *feature_index,
846                                    feature_index_map,
847                                ),
848                                child_lookup,
849                                max_bin_index,
850                                missing_bin: optimized_missing_bin(
851                                    classifier.feature_preprocessing(),
852                                    *feature_index,
853                                    missing_features,
854                                ),
855                                missing_child: if missing_feature_enabled(
856                                    *feature_index,
857                                    missing_features,
858                                ) {
859                                    *missing_child
860                                } else {
861                                    None
862                                },
863                                fallback_probabilities: normalized_probabilities_from_counts(
864                                    class_counts,
865                                ),
866                            }
867                        }
868                    })
869                    .collect();
870
871                Self::StandardClassifier {
872                    nodes: optimized_nodes,
873                    root: *root,
874                    class_labels: classifier.class_labels().to_vec(),
875                }
876            }
877            tree::classifier::TreeStructure::Oblivious {
878                splits,
879                leaf_class_counts,
880                ..
881            } => Self::ObliviousClassifier {
882                feature_indices: splits
883                    .iter()
884                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
885                    .collect(),
886                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
887                leaf_values: leaf_class_counts
888                    .iter()
889                    .map(|class_counts| normalized_probabilities_from_counts(class_counts))
890                    .collect(),
891                class_labels: classifier.class_labels().to_vec(),
892            },
893        }
894    }
895
896    fn from_regressor(
897        regressor: &DecisionTreeRegressor,
898        feature_index_map: &[usize],
899        missing_features: Option<&BTreeSet<usize>>,
900    ) -> Self {
901        match regressor.structure() {
902            tree::regressor::RegressionTreeStructure::Standard { nodes, root } => {
903                Self::BinaryRegressor {
904                    nodes: build_binary_regressor_layout(
905                        nodes,
906                        *root,
907                        feature_index_map,
908                        regressor.feature_preprocessing(),
909                        missing_features,
910                    ),
911                }
912            }
913            tree::regressor::RegressionTreeStructure::Oblivious {
914                splits,
915                leaf_values,
916                ..
917            } => Self::ObliviousRegressor {
918                feature_indices: splits
919                    .iter()
920                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
921                    .collect(),
922                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
923                leaf_values: leaf_values.clone(),
924            },
925        }
926    }
927
928    #[inline(always)]
929    fn predict_table_row(&self, table: &dyn TableAccess, row_index: usize) -> f64 {
930        match self {
931            OptimizedRuntime::BinaryClassifier { .. }
932            | OptimizedRuntime::StandardClassifier { .. }
933            | OptimizedRuntime::ObliviousClassifier { .. }
934            | OptimizedRuntime::ForestClassifier { .. }
935            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
936                let probabilities = self
937                    .predict_proba_table_row(table, row_index)
938                    .expect("classifier runtime supports probability prediction");
939                class_label_from_probabilities(&probabilities, self.class_labels())
940            }
941            OptimizedRuntime::BinaryRegressor { nodes } => {
942                predict_binary_regressor_row(nodes, |feature_index| {
943                    table.binned_value(feature_index, row_index)
944                })
945            }
946            OptimizedRuntime::ObliviousRegressor {
947                feature_indices,
948                threshold_bins,
949                leaf_values,
950            } => predict_oblivious_row(
951                feature_indices,
952                threshold_bins,
953                leaf_values,
954                |feature_index| table.binned_value(feature_index, row_index),
955            ),
956            OptimizedRuntime::ForestRegressor { trees } => {
957                trees
958                    .iter()
959                    .map(|tree| tree.predict_table_row(table, row_index))
960                    .sum::<f64>()
961                    / trees.len() as f64
962            }
963            OptimizedRuntime::BoostedRegressor {
964                trees,
965                tree_weights,
966                base_score,
967            } => {
968                *base_score
969                    + trees
970                        .iter()
971                        .zip(tree_weights.iter().copied())
972                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
973                        .sum::<f64>()
974            }
975        }
976    }
977
978    #[inline(always)]
979    fn predict_proba_table_row(
980        &self,
981        table: &dyn TableAccess,
982        row_index: usize,
983    ) -> Result<Vec<f64>, PredictError> {
984        match self {
985            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
986                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
987                    table.binned_value(feature_index, row_index)
988                })
989                .to_vec(),
990            ),
991            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
992                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
993                    table.binned_value(feature_index, row_index)
994                })
995                .to_vec(),
996            ),
997            OptimizedRuntime::ObliviousClassifier {
998                feature_indices,
999                threshold_bins,
1000                leaf_values,
1001                ..
1002            } => Ok(predict_oblivious_probabilities_row(
1003                feature_indices,
1004                threshold_bins,
1005                leaf_values,
1006                |feature_index| table.binned_value(feature_index, row_index),
1007            )
1008            .to_vec()),
1009            OptimizedRuntime::ForestClassifier { trees, .. } => {
1010                let mut totals = trees[0].predict_proba_table_row(table, row_index)?;
1011                for tree in &trees[1..] {
1012                    let row = tree.predict_proba_table_row(table, row_index)?;
1013                    for (total, value) in totals.iter_mut().zip(row) {
1014                        *total += value;
1015                    }
1016                }
1017                let tree_count = trees.len() as f64;
1018                for value in &mut totals {
1019                    *value /= tree_count;
1020                }
1021                Ok(totals)
1022            }
1023            OptimizedRuntime::BoostedBinaryClassifier {
1024                trees,
1025                tree_weights,
1026                base_score,
1027                ..
1028            } => {
1029                let raw_score = *base_score
1030                    + trees
1031                        .iter()
1032                        .zip(tree_weights.iter().copied())
1033                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
1034                        .sum::<f64>();
1035                let positive = sigmoid(raw_score);
1036                Ok(vec![1.0 - positive, positive])
1037            }
1038            OptimizedRuntime::BinaryRegressor { .. }
1039            | OptimizedRuntime::ObliviousRegressor { .. }
1040            | OptimizedRuntime::ForestRegressor { .. }
1041            | OptimizedRuntime::BoostedRegressor { .. } => {
1042                Err(PredictError::ProbabilityPredictionRequiresClassification)
1043            }
1044        }
1045    }
1046
1047    fn predict_proba_table(
1048        &self,
1049        table: &dyn TableAccess,
1050        executor: &InferenceExecutor,
1051    ) -> Result<Vec<Vec<f64>>, PredictError> {
1052        match self {
1053            OptimizedRuntime::BinaryClassifier { .. }
1054            | OptimizedRuntime::StandardClassifier { .. }
1055            | OptimizedRuntime::ObliviousClassifier { .. }
1056            | OptimizedRuntime::ForestClassifier { .. }
1057            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
1058                if self.should_use_batch_matrix(table.n_rows()) {
1059                    let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
1060                    self.predict_proba_column_major_matrix(&matrix, executor)
1061                } else {
1062                    (0..table.n_rows())
1063                        .map(|row_index| self.predict_proba_table_row(table, row_index))
1064                        .collect()
1065                }
1066            }
1067            OptimizedRuntime::BinaryRegressor { .. }
1068            | OptimizedRuntime::ObliviousRegressor { .. }
1069            | OptimizedRuntime::ForestRegressor { .. }
1070            | OptimizedRuntime::BoostedRegressor { .. } => {
1071                Err(PredictError::ProbabilityPredictionRequiresClassification)
1072            }
1073        }
1074    }
1075
1076    fn predict_column_major_matrix(
1077        &self,
1078        matrix: &ColumnMajorBinnedMatrix,
1079        executor: &InferenceExecutor,
1080    ) -> Vec<f64> {
1081        match self {
1082            OptimizedRuntime::BinaryClassifier { .. }
1083            | OptimizedRuntime::StandardClassifier { .. }
1084            | OptimizedRuntime::ObliviousClassifier { .. }
1085            | OptimizedRuntime::ForestClassifier { .. }
1086            | OptimizedRuntime::BoostedBinaryClassifier { .. } => self
1087                .predict_proba_column_major_matrix(matrix, executor)
1088                .expect("classifier runtime supports probability prediction")
1089                .into_iter()
1090                .map(|row| class_label_from_probabilities(&row, self.class_labels()))
1091                .collect(),
1092            OptimizedRuntime::BinaryRegressor { nodes } => {
1093                predict_binary_regressor_column_major_matrix(nodes, matrix, executor)
1094            }
1095            OptimizedRuntime::ObliviousRegressor {
1096                feature_indices,
1097                threshold_bins,
1098                leaf_values,
1099            } => predict_oblivious_column_major_matrix(
1100                feature_indices,
1101                threshold_bins,
1102                leaf_values,
1103                matrix,
1104                executor,
1105            ),
1106            OptimizedRuntime::ForestRegressor { trees } => {
1107                let mut totals = trees[0].predict_column_major_matrix(matrix, executor);
1108                for tree in &trees[1..] {
1109                    let values = tree.predict_column_major_matrix(matrix, executor);
1110                    for (total, value) in totals.iter_mut().zip(values) {
1111                        *total += value;
1112                    }
1113                }
1114                let tree_count = trees.len() as f64;
1115                for total in &mut totals {
1116                    *total /= tree_count;
1117                }
1118                totals
1119            }
1120            OptimizedRuntime::BoostedRegressor {
1121                trees,
1122                tree_weights,
1123                base_score,
1124            } => {
1125                let mut totals = vec![*base_score; matrix.n_rows];
1126                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
1127                    let values = tree.predict_column_major_matrix(matrix, executor);
1128                    for (total, value) in totals.iter_mut().zip(values) {
1129                        *total += weight * value;
1130                    }
1131                }
1132                totals
1133            }
1134        }
1135    }
1136
1137    fn predict_proba_column_major_matrix(
1138        &self,
1139        matrix: &ColumnMajorBinnedMatrix,
1140        executor: &InferenceExecutor,
1141    ) -> Result<Vec<Vec<f64>>, PredictError> {
1142        match self {
1143            OptimizedRuntime::BinaryClassifier { nodes, .. } => {
1144                Ok(predict_binary_classifier_probabilities_column_major_matrix(
1145                    nodes, matrix, executor,
1146                ))
1147            }
1148            OptimizedRuntime::StandardClassifier { .. } => Ok((0..matrix.n_rows)
1149                .map(|row_index| {
1150                    self.predict_proba_binned_row_from_columns(matrix, row_index)
1151                        .expect("classifier runtime supports probability prediction")
1152                })
1153                .collect()),
1154            OptimizedRuntime::ObliviousClassifier {
1155                feature_indices,
1156                threshold_bins,
1157                leaf_values,
1158                ..
1159            } => Ok(predict_oblivious_probabilities_column_major_matrix(
1160                feature_indices,
1161                threshold_bins,
1162                leaf_values,
1163                matrix,
1164                executor,
1165            )),
1166            OptimizedRuntime::ForestClassifier { trees, .. } => {
1167                let mut totals = trees[0].predict_proba_column_major_matrix(matrix, executor)?;
1168                for tree in &trees[1..] {
1169                    let rows = tree.predict_proba_column_major_matrix(matrix, executor)?;
1170                    for (row_totals, row_values) in totals.iter_mut().zip(rows) {
1171                        for (total, value) in row_totals.iter_mut().zip(row_values) {
1172                            *total += value;
1173                        }
1174                    }
1175                }
1176                let tree_count = trees.len() as f64;
1177                for row in &mut totals {
1178                    for value in row {
1179                        *value /= tree_count;
1180                    }
1181                }
1182                Ok(totals)
1183            }
1184            OptimizedRuntime::BoostedBinaryClassifier {
1185                trees,
1186                tree_weights,
1187                base_score,
1188                ..
1189            } => {
1190                let mut raw_scores = vec![*base_score; matrix.n_rows];
1191                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
1192                    let values = tree.predict_column_major_matrix(matrix, executor);
1193                    for (raw_score, value) in raw_scores.iter_mut().zip(values) {
1194                        *raw_score += weight * value;
1195                    }
1196                }
1197                Ok(raw_scores
1198                    .into_iter()
1199                    .map(|raw_score| {
1200                        let positive = sigmoid(raw_score);
1201                        vec![1.0 - positive, positive]
1202                    })
1203                    .collect())
1204            }
1205            OptimizedRuntime::BinaryRegressor { .. }
1206            | OptimizedRuntime::ObliviousRegressor { .. }
1207            | OptimizedRuntime::ForestRegressor { .. }
1208            | OptimizedRuntime::BoostedRegressor { .. } => {
1209                Err(PredictError::ProbabilityPredictionRequiresClassification)
1210            }
1211        }
1212    }
1213
1214    fn class_labels(&self) -> &[f64] {
1215        match self {
1216            OptimizedRuntime::BinaryClassifier { class_labels, .. }
1217            | OptimizedRuntime::StandardClassifier { class_labels, .. }
1218            | OptimizedRuntime::ObliviousClassifier { class_labels, .. }
1219            | OptimizedRuntime::ForestClassifier { class_labels, .. }
1220            | OptimizedRuntime::BoostedBinaryClassifier { class_labels, .. } => class_labels,
1221            _ => &[],
1222        }
1223    }
1224
1225    #[inline(always)]
1226    fn predict_binned_row_from_columns(
1227        &self,
1228        matrix: &ColumnMajorBinnedMatrix,
1229        row_index: usize,
1230    ) -> f64 {
1231        match self {
1232            OptimizedRuntime::BinaryRegressor { nodes } => {
1233                predict_binary_regressor_row(nodes, |feature_index| {
1234                    matrix.column(feature_index).value_at(row_index)
1235                })
1236            }
1237            OptimizedRuntime::ObliviousRegressor {
1238                feature_indices,
1239                threshold_bins,
1240                leaf_values,
1241            } => predict_oblivious_row(
1242                feature_indices,
1243                threshold_bins,
1244                leaf_values,
1245                |feature_index| matrix.column(feature_index).value_at(row_index),
1246            ),
1247            OptimizedRuntime::BoostedRegressor {
1248                trees,
1249                tree_weights,
1250                base_score,
1251            } => {
1252                *base_score
1253                    + trees
1254                        .iter()
1255                        .zip(tree_weights.iter().copied())
1256                        .map(|(tree, weight)| {
1257                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
1258                        })
1259                        .sum::<f64>()
1260            }
1261            _ => self.predict_column_major_matrix(
1262                matrix,
1263                &InferenceExecutor::new(1).expect("inference executor"),
1264            )[row_index],
1265        }
1266    }
1267
1268    #[inline(always)]
1269    fn predict_proba_binned_row_from_columns(
1270        &self,
1271        matrix: &ColumnMajorBinnedMatrix,
1272        row_index: usize,
1273    ) -> Result<Vec<f64>, PredictError> {
1274        match self {
1275            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
1276                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1277                    matrix.column(feature_index).value_at(row_index)
1278                })
1279                .to_vec(),
1280            ),
1281            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
1282                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
1283                    matrix.column(feature_index).value_at(row_index)
1284                })
1285                .to_vec(),
1286            ),
1287            OptimizedRuntime::ObliviousClassifier {
1288                feature_indices,
1289                threshold_bins,
1290                leaf_values,
1291                ..
1292            } => Ok(predict_oblivious_probabilities_row(
1293                feature_indices,
1294                threshold_bins,
1295                leaf_values,
1296                |feature_index| matrix.column(feature_index).value_at(row_index),
1297            )
1298            .to_vec()),
1299            OptimizedRuntime::ForestClassifier { trees, .. } => {
1300                let mut totals =
1301                    trees[0].predict_proba_binned_row_from_columns(matrix, row_index)?;
1302                for tree in &trees[1..] {
1303                    let row = tree.predict_proba_binned_row_from_columns(matrix, row_index)?;
1304                    for (total, value) in totals.iter_mut().zip(row) {
1305                        *total += value;
1306                    }
1307                }
1308                let tree_count = trees.len() as f64;
1309                for value in &mut totals {
1310                    *value /= tree_count;
1311                }
1312                Ok(totals)
1313            }
1314            OptimizedRuntime::BoostedBinaryClassifier {
1315                trees,
1316                tree_weights,
1317                base_score,
1318                ..
1319            } => {
1320                let raw_score = *base_score
1321                    + trees
1322                        .iter()
1323                        .zip(tree_weights.iter().copied())
1324                        .map(|(tree, weight)| {
1325                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
1326                        })
1327                        .sum::<f64>();
1328                let positive = sigmoid(raw_score);
1329                Ok(vec![1.0 - positive, positive])
1330            }
1331            OptimizedRuntime::BinaryRegressor { .. }
1332            | OptimizedRuntime::ObliviousRegressor { .. }
1333            | OptimizedRuntime::ForestRegressor { .. }
1334            | OptimizedRuntime::BoostedRegressor { .. } => {
1335                Err(PredictError::ProbabilityPredictionRequiresClassification)
1336            }
1337        }
1338    }
1339}
1340
1341#[inline(always)]
1342fn predict_standard_classifier_probabilities_row<F>(
1343    nodes: &[OptimizedClassifierNode],
1344    root: usize,
1345    bin_at: F,
1346) -> &[f64]
1347where
1348    F: Fn(usize) -> u16,
1349{
1350    let mut node_index = root;
1351    loop {
1352        match &nodes[node_index] {
1353            OptimizedClassifierNode::Leaf(value) => return value,
1354            OptimizedClassifierNode::Binary {
1355                feature_index,
1356                threshold_bin,
1357                children,
1358                missing_bin,
1359                missing_child,
1360                missing_probabilities,
1361            } => {
1362                let bin = bin_at(*feature_index);
1363                if missing_bin.is_some_and(|expected| expected == bin) {
1364                    if let Some(probabilities) = missing_probabilities {
1365                        return probabilities;
1366                    }
1367                    if let Some(child_index) = missing_child {
1368                        node_index = *child_index;
1369                        continue;
1370                    }
1371                }
1372                let go_right = usize::from(bin > *threshold_bin);
1373                node_index = children[go_right];
1374            }
1375            OptimizedClassifierNode::Multiway {
1376                feature_index,
1377                child_lookup,
1378                max_bin_index,
1379                missing_bin,
1380                missing_child,
1381                fallback_probabilities,
1382            } => {
1383                let bin_value = bin_at(*feature_index);
1384                if missing_bin.is_some_and(|expected| expected == bin_value) {
1385                    if let Some(child_index) = missing_child {
1386                        node_index = *child_index;
1387                        continue;
1388                    }
1389                    return fallback_probabilities;
1390                }
1391                let bin = usize::from(bin_value);
1392                if bin > *max_bin_index {
1393                    return fallback_probabilities;
1394                }
1395                let child_index = child_lookup[bin];
1396                if child_index == usize::MAX {
1397                    return fallback_probabilities;
1398                }
1399                node_index = child_index;
1400            }
1401        }
1402    }
1403}
1404
1405#[inline(always)]
1406fn predict_binary_classifier_probabilities_row<F>(
1407    nodes: &[OptimizedBinaryClassifierNode],
1408    bin_at: F,
1409) -> &[f64]
1410where
1411    F: Fn(usize) -> u16,
1412{
1413    let mut node_index = 0usize;
1414    loop {
1415        match &nodes[node_index] {
1416            OptimizedBinaryClassifierNode::Leaf(value) => return value,
1417            OptimizedBinaryClassifierNode::Branch {
1418                feature_index,
1419                threshold_bin,
1420                jump_index,
1421                jump_if_greater,
1422                missing_bin,
1423                missing_jump_index,
1424                missing_probabilities,
1425            } => {
1426                let bin = bin_at(*feature_index);
1427                if missing_bin.is_some_and(|expected| expected == bin) {
1428                    if let Some(probabilities) = missing_probabilities {
1429                        return probabilities;
1430                    }
1431                    if let Some(jump_index) = missing_jump_index {
1432                        node_index = *jump_index;
1433                        continue;
1434                    }
1435                }
1436                let go_right = bin > *threshold_bin;
1437                node_index = if go_right == *jump_if_greater {
1438                    *jump_index
1439                } else {
1440                    node_index + 1
1441                };
1442            }
1443        }
1444    }
1445}
1446
1447#[inline(always)]
1448fn predict_binary_regressor_row<F>(nodes: &[OptimizedBinaryRegressorNode], bin_at: F) -> f64
1449where
1450    F: Fn(usize) -> u16,
1451{
1452    let mut node_index = 0usize;
1453    loop {
1454        match &nodes[node_index] {
1455            OptimizedBinaryRegressorNode::Leaf(value) => return *value,
1456            OptimizedBinaryRegressorNode::Branch {
1457                feature_index,
1458                threshold_bin,
1459                jump_index,
1460                jump_if_greater,
1461                missing_bin,
1462                missing_jump_index,
1463                missing_value,
1464            } => {
1465                let bin = bin_at(*feature_index);
1466                if missing_bin.is_some_and(|expected| expected == bin) {
1467                    if let Some(value) = missing_value {
1468                        return *value;
1469                    }
1470                    if let Some(jump_index) = missing_jump_index {
1471                        node_index = *jump_index;
1472                        continue;
1473                    }
1474                }
1475                let go_right = bin > *threshold_bin;
1476                node_index = if go_right == *jump_if_greater {
1477                    *jump_index
1478                } else {
1479                    node_index + 1
1480                };
1481            }
1482        }
1483    }
1484}
1485
1486fn predict_binary_classifier_probabilities_column_major_matrix(
1487    nodes: &[OptimizedBinaryClassifierNode],
1488    matrix: &ColumnMajorBinnedMatrix,
1489    _executor: &InferenceExecutor,
1490) -> Vec<Vec<f64>> {
1491    if binary_classifier_nodes_require_rowwise_missing(nodes) {
1492        return (0..matrix.n_rows)
1493            .map(|row_index| {
1494                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1495                    matrix.column(feature_index).value_at(row_index)
1496                })
1497                .to_vec()
1498            })
1499            .collect();
1500    }
1501    (0..matrix.n_rows)
1502        .map(|row_index| {
1503            predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1504                matrix.column(feature_index).value_at(row_index)
1505            })
1506            .to_vec()
1507        })
1508        .collect()
1509}
1510
1511fn predict_binary_regressor_column_major_matrix(
1512    nodes: &[OptimizedBinaryRegressorNode],
1513    matrix: &ColumnMajorBinnedMatrix,
1514    executor: &InferenceExecutor,
1515) -> Vec<f64> {
1516    if binary_regressor_nodes_require_rowwise_missing(nodes) {
1517        return (0..matrix.n_rows)
1518            .map(|row_index| {
1519                predict_binary_regressor_row(nodes, |feature_index| {
1520                    matrix.column(feature_index).value_at(row_index)
1521                })
1522            })
1523            .collect();
1524    }
1525    let mut outputs = vec![0.0; matrix.n_rows];
1526    executor.fill_chunks(
1527        &mut outputs,
1528        STANDARD_BATCH_INFERENCE_CHUNK_ROWS,
1529        |start_row, chunk| predict_binary_regressor_chunk(nodes, matrix, start_row, chunk),
1530    );
1531    outputs
1532}
1533
/// Fills `output` with regression predictions for the rows
/// `start_row..start_row + output.len()` of `matrix`.
///
/// Instead of walking the tree once per row, the chunk's row offsets are pushed down
/// the tree together. `row_indices` holds chunk-relative row offsets, and each stack
/// entry `(node_index, start, end)` means "the rows in `row_indices[start..end]` have
/// reached `node_index`". At a branch, that range is partitioned in place: rows that
/// take the jump are swapped to the tail `[jump_start, end)`, and the rest stay in
/// `[start, jump_start)` and continue to the fall-through node `node_index + 1`.
///
/// Missing-value fields of branches are ignored here (matched with `..`).
/// `predict_binary_regressor_column_major_matrix` only takes this path when
/// `binary_regressor_nodes_require_rowwise_missing` returns `false`.
fn predict_binary_regressor_chunk(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) {
    // Chunk-relative row offsets; reordered in place as ranges are partitioned.
    let mut row_indices: Vec<usize> = (0..output.len()).collect();
    // Work list of (node, range start, range end) over `row_indices`; starts at the root.
    let mut stack = vec![(0usize, 0usize, output.len())];

    while let Some((node_index, start, end)) = stack.pop() {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => {
                // Every row that reached this leaf gets its value.
                for position in start..end {
                    output[row_indices[position]] = *value;
                }
            }
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
                ..
            } => {
                let fallthrough_index = node_index + 1;
                // Both outcomes lead to the same node, so the comparison is irrelevant:
                // forward the whole range without partitioning.
                if *jump_index == fallthrough_index {
                    stack.push((fallthrough_index, start, end));
                    continue;
                }

                let column = matrix.column(*feature_index);
                // Invariant during partitioning: [start, partition) stay on the
                // fall-through side, [jump_start, end) take the jump, and
                // [partition, jump_start) are not yet classified.
                let mut partition = start;
                let mut jump_start = end;
                match column {
                    // Fast path: compare raw u8 bins directly when the threshold fits in u8.
                    CompactBinnedColumn::U8(values) if *threshold_bin <= u16::from(u8::MAX) => {
                        let threshold = *threshold_bin as u8;
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = values[start_row + row_offset] > threshold;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                    // Generic path: read each bin through `value_at` and compare as u16.
                    _ => {
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = column.value_at(start_row + row_offset) > *threshold_bin;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                }

                // Only schedule non-empty ranges.
                if jump_start < end {
                    stack.push((*jump_index, jump_start, end));
                }
                if start < jump_start {
                    stack.push((fallthrough_index, start, jump_start));
                }
            }
        }
    }
}
1606
1607fn binary_classifier_nodes_require_rowwise_missing(
1608    nodes: &[OptimizedBinaryClassifierNode],
1609) -> bool {
1610    nodes.iter().any(|node| match node {
1611        OptimizedBinaryClassifierNode::Leaf(_) => false,
1612        OptimizedBinaryClassifierNode::Branch {
1613            missing_bin,
1614            missing_jump_index,
1615            missing_probabilities,
1616            ..
1617        } => {
1618            missing_bin.is_some() || missing_jump_index.is_some() || missing_probabilities.is_some()
1619        }
1620    })
1621}
1622
1623fn binary_regressor_nodes_require_rowwise_missing(nodes: &[OptimizedBinaryRegressorNode]) -> bool {
1624    nodes.iter().any(|node| match node {
1625        OptimizedBinaryRegressorNode::Leaf(_) => false,
1626        OptimizedBinaryRegressorNode::Branch {
1627            missing_bin,
1628            missing_jump_index,
1629            missing_value,
1630            ..
1631        } => missing_bin.is_some() || missing_jump_index.is_some() || missing_value.is_some(),
1632    })
1633}
1634
/// Evaluates an oblivious (symmetric) tree for a single row.
///
/// Each level contributes one bit to the leaf index: `1` when the row's bin for
/// that level's feature is strictly greater than the level threshold, else `0`.
/// The first level becomes the most significant bit. `bin_at` maps a feature
/// index to the row's bin for that feature.
#[inline(always)]
fn predict_oblivious_row<F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    bin_at: F,
) -> f64
where
    F: Fn(usize) -> u16,
{
    let leaf_slot = feature_indices.iter().zip(threshold_bins).fold(
        0usize,
        |slot, (&feature, &threshold)| (slot << 1) | usize::from(bin_at(feature) > threshold),
    );
    leaf_values[leaf_slot]
}
1652
/// Evaluates an oblivious tree for a single row and borrows the class
/// probability vector stored at the selected leaf.
///
/// The leaf index is built one bit per level, most significant bit first. A
/// level contributes `1` when `bin_at(feature) > threshold`.
#[inline(always)]
fn predict_oblivious_probabilities_row<'a, F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &'a [Vec<f64>],
    bin_at: F,
) -> &'a [f64]
where
    F: Fn(usize) -> u16,
{
    let mut slot = 0usize;
    for (feature_index, threshold_bin) in feature_indices
        .iter()
        .copied()
        .zip(threshold_bins.iter().copied())
    {
        let branch_bit = usize::from(bin_at(feature_index) > threshold_bin);
        slot = (slot << 1) | branch_bit;
    }
    &leaf_values[slot]
}
1670
/// Converts per-class sample counts into probabilities that sum to one.
///
/// When every count is zero (or there are no classes), an all-zero vector of
/// the same length is returned instead of dividing by zero.
fn normalized_probabilities_from_counts(class_counts: &[usize]) -> Vec<f64> {
    let total: usize = class_counts.iter().sum();
    match total {
        0 => vec![0.0; class_counts.len()],
        _ => {
            let denominator = total as f64;
            class_counts
                .iter()
                .map(|&count| count as f64 / denominator)
                .collect()
        }
    }
}
1682
/// Returns the class label whose probability is largest.
///
/// Probabilities are ordered with `f64::total_cmp`. On ties, the earliest
/// index wins.
///
/// # Panics
///
/// Panics if `probabilities` is empty, or if the winning index is out of range
/// for `class_labels`.
fn class_label_from_probabilities(probabilities: &[f64], class_labels: &[f64]) -> f64 {
    let mut candidates = probabilities.iter().copied().enumerate();
    let (mut best_index, mut best_value) = candidates
        .next()
        .expect("classification probability row is non-empty");
    for (index, value) in candidates {
        // Strictly greater only, so an earlier index keeps a tie.
        if value.total_cmp(&best_value).is_gt() {
            best_index = index;
            best_value = value;
        }
    }
    class_labels[best_index]
}
1696
/// Numerically stable logistic function `1 / (1 + e^-x)`.
///
/// `exp` is only ever evaluated on a non-positive argument, so it cannot
/// overflow for large-magnitude inputs.
#[inline(always)]
fn sigmoid(value: f64) -> f64 {
    if value >= 0.0 {
        return 1.0 / (1.0 + (-value).exp());
    }
    let decay = value.exp();
    decay / (1.0 + decay)
}
1707
1708fn classifier_nodes_are_binary_only(nodes: &[tree::classifier::TreeNode]) -> bool {
1709    nodes.iter().all(|node| {
1710        matches!(
1711            node,
1712            tree::classifier::TreeNode::Leaf { .. }
1713                | tree::classifier::TreeNode::BinarySplit { .. }
1714        )
1715    })
1716}
1717
1718fn classifier_node_sample_count(nodes: &[tree::classifier::TreeNode], node_index: usize) -> usize {
1719    match &nodes[node_index] {
1720        tree::classifier::TreeNode::Leaf { sample_count, .. }
1721        | tree::classifier::TreeNode::BinarySplit { sample_count, .. }
1722        | tree::classifier::TreeNode::MultiwaySplit { sample_count, .. } => *sample_count,
1723    }
1724}
1725
1726fn build_binary_classifier_layout(
1727    nodes: &[tree::classifier::TreeNode],
1728    root: usize,
1729    _class_labels: &[f64],
1730    feature_index_map: &[usize],
1731    preprocessing: &[FeaturePreprocessing],
1732    missing_features: Option<&BTreeSet<usize>>,
1733) -> Vec<OptimizedBinaryClassifierNode> {
1734    let mut layout = Vec::with_capacity(nodes.len());
1735    append_binary_classifier_node(
1736        nodes,
1737        root,
1738        &mut layout,
1739        feature_index_map,
1740        preprocessing,
1741        missing_features,
1742    );
1743    layout
1744}
1745
/// Appends the subtree rooted at `nodes[node_index]` to `layout` in pre-order
/// and returns the layout index of that subtree's root.
///
/// For a binary split, the child with the larger (or equal) sample count becomes
/// the "fallthrough" child and is placed immediately after its parent. The other
/// child is appended afterwards and referenced through `jump_index`.
/// `jump_if_greater` is `true` when the jump child is the right child. When both
/// children are the same node, it is appended once and `jump_index` equals the
/// fallthrough index.
///
/// Missing-value routing is emitted only when `missing_feature_enabled` accepts
/// the split feature.
///
/// This function recurses once per appended child, so call-stack depth grows
/// with tree depth.
///
/// # Panics
///
/// Panics via `unreachable!` on a `MultiwaySplit` node. Per that message,
/// multiway nodes are expected to be filtered out before this is called.
fn append_binary_classifier_node(
    nodes: &[tree::classifier::TreeNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryClassifierNode>,
    feature_index_map: &[usize],
    preprocessing: &[FeaturePreprocessing],
    missing_features: Option<&BTreeSet<usize>>,
) -> usize {
    // Reserve this node's slot with a placeholder so children are appended after
    // it; the slot is overwritten below once the children's indices are known.
    let current_index = layout.len();
    layout.push(OptimizedBinaryClassifierNode::Leaf(Vec::new()));

    match &nodes[node_index] {
        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
            layout[current_index] = OptimizedBinaryClassifierNode::Leaf(
                normalized_probabilities_from_counts(class_counts),
            );
        }
        tree::classifier::TreeNode::BinarySplit {
            feature_index,
            threshold_bin,
            missing_direction,
            left_child,
            right_child,
            class_counts,
            ..
        } => {
            // Choose which child follows the parent directly (fallthrough) and
            // which is reached by a jump. The more populated child falls through;
            // left wins ties.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = classifier_node_sample_count(nodes, *left_child);
                let right_count = classifier_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index = append_binary_classifier_node(
                nodes,
                fallthrough_child,
                layout,
                feature_index_map,
                preprocessing,
                missing_features,
            );
            // The runtime relies on the fallthrough child sitting at `parent + 1`.
            debug_assert_eq!(fallthrough_index, current_index + 1);
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_classifier_node(
                    nodes,
                    jump_child,
                    layout,
                    feature_index_map,
                    preprocessing,
                    missing_features,
                )
            };

            // NOTE(review): presumably the bin that encodes "missing" for this
            // feature, or `None`; see `optimized_missing_bin` to confirm.
            let missing_bin =
                optimized_missing_bin(preprocessing, *feature_index, missing_features);
            let (missing_jump_index, missing_probabilities) =
                if missing_feature_enabled(*feature_index, missing_features) {
                    match missing_direction {
                        // Missing values go to whichever layout slot holds the
                        // left child.
                        tree::shared::MissingBranchDirection::Left => (
                            Some(if *left_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        // Missing values go to whichever layout slot holds the
                        // right child.
                        tree::shared::MissingBranchDirection::Right => (
                            Some(if *right_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        // No jump target: store this node's own class
                        // distribution for missing values.
                        tree::shared::MissingBranchDirection::Node => (
                            None,
                            Some(normalized_probabilities_from_counts(class_counts)),
                        ),
                    }
                } else {
                    (None, None)
                };

            layout[current_index] = OptimizedBinaryClassifierNode::Branch {
                // Translate the tree's feature index through `feature_index_map`.
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_probabilities,
            };
        }
        tree::classifier::TreeNode::MultiwaySplit { .. } => {
            unreachable!("multiway nodes are filtered out before binary layout construction");
        }
    }

    current_index
}
1853
1854fn regressor_node_sample_count(
1855    nodes: &[tree::regressor::RegressionNode],
1856    node_index: usize,
1857) -> usize {
1858    match &nodes[node_index] {
1859        tree::regressor::RegressionNode::Leaf { sample_count, .. }
1860        | tree::regressor::RegressionNode::BinarySplit { sample_count, .. } => *sample_count,
1861    }
1862}
1863
1864fn build_binary_regressor_layout(
1865    nodes: &[tree::regressor::RegressionNode],
1866    root: usize,
1867    feature_index_map: &[usize],
1868    preprocessing: &[FeaturePreprocessing],
1869    missing_features: Option<&BTreeSet<usize>>,
1870) -> Vec<OptimizedBinaryRegressorNode> {
1871    let mut layout = Vec::with_capacity(nodes.len());
1872    append_binary_regressor_node(
1873        nodes,
1874        root,
1875        &mut layout,
1876        feature_index_map,
1877        preprocessing,
1878        missing_features,
1879    );
1880    layout
1881}
1882
/// Appends the subtree rooted at `nodes[node_index]` to `layout` in pre-order
/// and returns the layout index of that subtree's root.
///
/// For a binary split, the child with the larger (or equal) sample count becomes
/// the "fallthrough" child and is placed immediately after its parent. The other
/// child is appended afterwards and referenced through `jump_index`.
/// `jump_if_greater` is `true` when the jump child is the right child. When both
/// children are the same node, it is appended once and `jump_index` equals the
/// fallthrough index.
///
/// Missing-value routing is emitted only when `missing_feature_enabled` accepts
/// the split feature.
///
/// This function recurses once per appended child, so call-stack depth grows
/// with tree depth.
fn append_binary_regressor_node(
    nodes: &[tree::regressor::RegressionNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryRegressorNode>,
    feature_index_map: &[usize],
    preprocessing: &[FeaturePreprocessing],
    missing_features: Option<&BTreeSet<usize>>,
) -> usize {
    // Reserve this node's slot with a placeholder so children are appended after
    // it; the slot is overwritten below once the children's indices are known.
    let current_index = layout.len();
    layout.push(OptimizedBinaryRegressorNode::Leaf(0.0));

    match &nodes[node_index] {
        tree::regressor::RegressionNode::Leaf { value, .. } => {
            layout[current_index] = OptimizedBinaryRegressorNode::Leaf(*value);
        }
        tree::regressor::RegressionNode::BinarySplit {
            feature_index,
            threshold_bin,
            missing_direction,
            missing_value,
            left_child,
            right_child,
            ..
        } => {
            // Choose which child follows the parent directly (fallthrough) and
            // which is reached by a jump. The more populated child falls through;
            // left wins ties.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = regressor_node_sample_count(nodes, *left_child);
                let right_count = regressor_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index = append_binary_regressor_node(
                nodes,
                fallthrough_child,
                layout,
                feature_index_map,
                preprocessing,
                missing_features,
            );
            // The runtime relies on the fallthrough child sitting at `parent + 1`.
            debug_assert_eq!(fallthrough_index, current_index + 1);
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_regressor_node(
                    nodes,
                    jump_child,
                    layout,
                    feature_index_map,
                    preprocessing,
                    missing_features,
                )
            };

            // NOTE(review): presumably the bin that encodes "missing" for this
            // feature, or `None`; see `optimized_missing_bin` to confirm.
            let missing_bin =
                optimized_missing_bin(preprocessing, *feature_index, missing_features);
            let (missing_jump_index, missing_value) =
                if missing_feature_enabled(*feature_index, missing_features) {
                    match missing_direction {
                        // Missing values go to whichever layout slot holds the
                        // left child.
                        tree::shared::MissingBranchDirection::Left => (
                            Some(if *left_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        // Missing values go to whichever layout slot holds the
                        // right child.
                        tree::shared::MissingBranchDirection::Right => (
                            Some(if *right_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        // No jump target: store the split's own `missing_value`.
                        tree::shared::MissingBranchDirection::Node => (None, Some(*missing_value)),
                    }
                } else {
                    (None, None)
                };

            layout[current_index] = OptimizedBinaryRegressorNode::Branch {
                // Translate the tree's feature index through `feature_index_map`.
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_value,
            };
        }
    }

    current_index
}
1982
1983fn predict_oblivious_column_major_matrix(
1984    feature_indices: &[usize],
1985    threshold_bins: &[u16],
1986    leaf_values: &[f64],
1987    matrix: &ColumnMajorBinnedMatrix,
1988    executor: &InferenceExecutor,
1989) -> Vec<f64> {
1990    let mut outputs = vec![0.0; matrix.n_rows];
1991    executor.fill_chunks(
1992        &mut outputs,
1993        PARALLEL_INFERENCE_CHUNK_ROWS,
1994        |start_row, chunk| {
1995            predict_oblivious_chunk(
1996                feature_indices,
1997                threshold_bins,
1998                leaf_values,
1999                matrix,
2000                start_row,
2001                chunk,
2002            )
2003        },
2004    );
2005    outputs
2006}
2007
2008fn predict_oblivious_probabilities_column_major_matrix(
2009    feature_indices: &[usize],
2010    threshold_bins: &[u16],
2011    leaf_values: &[Vec<f64>],
2012    matrix: &ColumnMajorBinnedMatrix,
2013    _executor: &InferenceExecutor,
2014) -> Vec<Vec<f64>> {
2015    (0..matrix.n_rows)
2016        .map(|row_index| {
2017            predict_oblivious_probabilities_row(
2018                feature_indices,
2019                threshold_bins,
2020                leaf_values,
2021                |feature_index| matrix.column(feature_index).value_at(row_index),
2022            )
2023            .to_vec()
2024        })
2025        .collect()
2026}
2027
2028fn predict_oblivious_chunk(
2029    feature_indices: &[usize],
2030    threshold_bins: &[u16],
2031    leaf_values: &[f64],
2032    matrix: &ColumnMajorBinnedMatrix,
2033    start_row: usize,
2034    output: &mut [f64],
2035) {
2036    let processed = simd_predict_oblivious_chunk(
2037        feature_indices,
2038        threshold_bins,
2039        leaf_values,
2040        matrix,
2041        start_row,
2042        output,
2043    );
2044
2045    for (offset, out) in output.iter_mut().enumerate().skip(processed) {
2046        let row_index = start_row + offset;
2047        *out = predict_oblivious_row(
2048            feature_indices,
2049            threshold_bins,
2050            leaf_values,
2051            |feature_index| matrix.column(feature_index).value_at(row_index),
2052        );
2053    }
2054}
2055
2056fn simd_predict_oblivious_chunk(
2057    feature_indices: &[usize],
2058    threshold_bins: &[u16],
2059    leaf_values: &[f64],
2060    matrix: &ColumnMajorBinnedMatrix,
2061    start_row: usize,
2062    output: &mut [f64],
2063) -> usize {
2064    let mut processed = 0usize;
2065    let ones = u32x8::splat(1);
2066
2067    while processed + OBLIVIOUS_SIMD_LANES <= output.len() {
2068        let base_row = start_row + processed;
2069        let mut leaf_indices = u32x8::splat(0);
2070
2071        for (&feature_index, &threshold_bin) in feature_indices.iter().zip(threshold_bins) {
2072            let column = matrix.column(feature_index);
2073            let bins = if let Some(lanes) = column.slice_u8(base_row, OBLIVIOUS_SIMD_LANES) {
2074                let lanes: [u8; OBLIVIOUS_SIMD_LANES] = lanes
2075                    .try_into()
2076                    .expect("lane width matches the fixed SIMD width");
2077                u32x8::new([
2078                    u32::from(lanes[0]),
2079                    u32::from(lanes[1]),
2080                    u32::from(lanes[2]),
2081                    u32::from(lanes[3]),
2082                    u32::from(lanes[4]),
2083                    u32::from(lanes[5]),
2084                    u32::from(lanes[6]),
2085                    u32::from(lanes[7]),
2086                ])
2087            } else {
2088                let lanes: [u16; OBLIVIOUS_SIMD_LANES] = column
2089                    .slice_u16(base_row, OBLIVIOUS_SIMD_LANES)
2090                    .expect("column is u16 when not u8")
2091                    .try_into()
2092                    .expect("lane width matches the fixed SIMD width");
2093                u32x8::from(u16x8::new(lanes))
2094            };
2095            let threshold = u32x8::splat(u32::from(threshold_bin));
2096            let bit = bins.cmp_gt(threshold) & ones;
2097            leaf_indices = (leaf_indices << 1) | bit;
2098        }
2099
2100        let lane_indices = leaf_indices.to_array();
2101        for lane in 0..OBLIVIOUS_SIMD_LANES {
2102            output[processed + lane] =
2103                leaf_values[usize::try_from(lane_indices[lane]).expect("leaf index fits usize")];
2104        }
2105        processed += OBLIVIOUS_SIMD_LANES;
2106    }
2107
2108    processed
2109}
2110
2111pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
2112    training::train(train_set, config)
2113}
2114
2115#[cfg(test)]
2116mod tests;