// forestfire_core/lib.rs
1//! ForestFire core model, training, inference, and interchange layer.
2//!
3//! The crate is organized around a few stable abstractions:
4//!
5//! - [`forestfire_data::TableAccess`] is the common data boundary for both
6//!   training and inference.
7//! - [`TrainConfig`] is the normalized configuration surface shared by the Rust
8//!   and Python APIs.
9//! - [`Model`] is the semantic model view used for exact prediction,
10//!   serialization, and introspection.
11//! - [`OptimizedModel`] is a lowered runtime view used when prediction speed
12//!   matters more than preserving the original tree layout.
13//!
14//! Keeping the semantic model and the runtime model separate is deliberate. It
15//! makes export and introspection straightforward while still allowing the
16//! optimized path to use layouts that are awkward to serialize directly.
17
18use forestfire_data::{
19    BinnedColumnKind, MAX_NUMERIC_BINS, NumericBins, TableAccess, numeric_bin_boundaries,
20    numeric_missing_bin,
21};
22#[cfg(feature = "polars")]
23use polars::prelude::{Column, DataFrame, DataType, IdxSize, LazyFrame};
24use rayon::ThreadPoolBuilder;
25use rayon::prelude::*;
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use std::collections::{BTreeMap, BTreeSet};
29use std::error::Error;
30use std::fmt::{Display, Formatter};
31use std::sync::Arc;
32use wide::{u16x8, u32x8};
33
34mod boosting;
35mod bootstrap;
36mod compiled_artifact;
37mod forest;
38mod inference_input;
39mod introspection;
40pub mod ir;
41mod model_api;
42mod optimized_runtime;
43mod runtime_planning;
44mod sampling;
45mod training;
46pub mod tree;
47
48pub use boosting::BoostingError;
49pub use boosting::GradientBoostedTrees;
50pub use compiled_artifact::CompiledArtifactError;
51pub use forest::RandomForest;
52pub use introspection::IntrospectionError;
53pub use introspection::PredictionHistogramEntry;
54pub use introspection::PredictionValueStats;
55pub use introspection::TreeStructureSummary;
56pub use ir::IrError;
57pub use ir::ModelPackageIr;
58pub use model_api::OptimizedModel;
59pub use tree::classifier::DecisionTreeAlgorithm;
60pub use tree::classifier::DecisionTreeClassifier;
61pub use tree::classifier::DecisionTreeError;
62pub use tree::classifier::DecisionTreeOptions;
63pub use tree::classifier::train_c45;
64pub use tree::classifier::train_cart;
65pub use tree::classifier::train_id3;
66pub use tree::classifier::train_oblivious;
67pub use tree::classifier::train_randomized;
68pub use tree::regressor::DecisionTreeRegressor;
69pub use tree::regressor::RegressionTreeAlgorithm;
70pub use tree::regressor::RegressionTreeError;
71pub use tree::regressor::RegressionTreeOptions;
72pub use tree::regressor::train_cart_regressor;
73pub use tree::regressor::train_oblivious_regressor;
74pub use tree::regressor::train_randomized_regressor;
75#[cfg(feature = "polars")]
76const LAZYFRAME_PREDICT_BATCH_ROWS: usize = 10_000;
77pub(crate) use inference_input::ColumnMajorBinnedMatrix;
78pub(crate) use inference_input::CompactBinnedColumn;
79pub(crate) use inference_input::InferenceTable;
80pub(crate) use inference_input::ProjectedTableView;
81#[cfg(feature = "polars")]
82pub(crate) use inference_input::polars_named_columns;
83pub(crate) use introspection::prediction_value_stats;
84pub(crate) use introspection::tree_structure_summary;
85pub(crate) use optimized_runtime::InferenceExecutor;
86pub(crate) use optimized_runtime::OBLIVIOUS_SIMD_LANES;
87pub(crate) use optimized_runtime::OptimizedBinaryClassifierNode;
88pub(crate) use optimized_runtime::OptimizedBinaryRegressorNode;
89pub(crate) use optimized_runtime::OptimizedClassifierNode;
90pub(crate) use optimized_runtime::OptimizedRuntime;
91pub(crate) use optimized_runtime::PARALLEL_INFERENCE_CHUNK_ROWS;
92pub(crate) use optimized_runtime::STANDARD_BATCH_INFERENCE_CHUNK_ROWS;
93pub(crate) use optimized_runtime::resolve_inference_thread_count;
94pub(crate) use runtime_planning::build_feature_index_map;
95pub(crate) use runtime_planning::build_feature_projection;
96pub(crate) use runtime_planning::model_used_feature_indices;
97pub(crate) use runtime_planning::ordered_ensemble_indices;
98pub(crate) use runtime_planning::remap_feature_index;
99
/// High-level training family selected through [`TrainConfig::algorithm`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainAlgorithm {
    /// Train a single decision tree.
    Dt,
    /// Train a bootstrap-aggregated random forest.
    Rf,
    /// Train a second-order gradient-boosted ensemble.
    Gbm,
}
109
/// Split-quality criterion selected through [`TrainConfig::criterion`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Criterion {
    /// Let training choose the appropriate criterion for the requested setup.
    Auto,
    /// Gini impurity for classification.
    Gini,
    /// Entropy / information gain for classification.
    Entropy,
    /// Mean-based regression criterion.
    Mean,
    /// Median-based regression criterion.
    Median,
    /// Internal second-order criterion used by gradient boosting.
    SecondOrder,
}
125
/// Learning task selected through [`TrainConfig::task`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    /// Predict a continuous numeric value.
    Regression,
    /// Predict one label from a finite set.
    Classification,
}
133
/// Tree learner selected through [`TrainConfig::tree_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeType {
    /// Multiway information-gain tree.
    Id3,
    /// C4.5-style multiway tree.
    C45,
    /// Standard binary threshold tree.
    Cart,
    /// CART-style tree with randomized candidate selection.
    Randomized,
    /// Symmetric tree where every level shares the same split.
    Oblivious,
}
147
/// Per-split feature subsampling strategy; resolved to a concrete count by
/// [`MaxFeatures::resolve`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxFeatures {
    /// Task-aware default: `sqrt` for classification, `third` for regression.
    Auto,
    /// Use all features at each split.
    All,
    /// Use `floor(sqrt(feature_count))` features.
    Sqrt,
    /// Use roughly one third of the features.
    Third,
    /// Use exactly this many features, capped to the available count.
    Count(usize),
}
161
162impl MaxFeatures {
163    pub fn resolve(self, task: Task, feature_count: usize) -> usize {
164        match self {
165            MaxFeatures::Auto => match task {
166                Task::Classification => MaxFeatures::Sqrt.resolve(task, feature_count),
167                Task::Regression => MaxFeatures::Third.resolve(task, feature_count),
168            },
169            MaxFeatures::All => feature_count.max(1),
170            MaxFeatures::Sqrt => ((feature_count as f64).sqrt().floor() as usize).max(1),
171            MaxFeatures::Third => (feature_count / 3).max(1),
172            MaxFeatures::Count(count) => count.min(feature_count).max(1),
173        }
174    }
175}
176
/// Acceptance window used when skipping canaries during split selection
/// (see [`TrainConfig::canary_filter`]).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CanaryFilter {
    /// Require the chosen real feature to appear within the top `n` scored splits.
    TopN(usize),
    /// Require the chosen real feature to appear within the top fraction of scored splits.
    TopFraction(f64),
}
184
impl Default for CanaryFilter {
    /// Default to the tightest window: the real feature must rank first.
    fn default() -> Self {
        Self::TopN(1)
    }
}
190
/// Kind of an input feature as seen by preprocessing and inference.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputFeatureKind {
    /// Numeric features are compared through their binned representation.
    Numeric,
    /// Binary features stay boolean all the way through the pipeline.
    Binary,
}
199
/// One numeric bin together with its inclusive raw-value upper bound.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinBoundary {
    /// Bin identifier in the preprocessed feature space.
    pub bin: u16,
    /// Largest raw floating-point value that still belongs to this bin.
    pub upper_bound: f64,
}
207
/// Per-feature preprocessing metadata captured from the training table
/// (see `capture_feature_preprocessing`). Serialized with a `kind` tag so
/// exported models can rebuild the exact binning used at training time.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeaturePreprocessing {
    /// Numeric features are represented by explicit bin boundaries.
    Numeric {
        /// Bin/upper-bound pairs produced by `numeric_bin_boundaries`.
        bin_boundaries: Vec<NumericBinBoundary>,
        /// Bin identifier reserved for missing values.
        missing_bin: u16,
    },
    /// Binary features do not require numeric bin boundaries.
    Binary,
}
219
/// How missing-value routing is evaluated during split search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MissingValueStrategy {
    /// Heuristic routing; also the fill-in for features without an explicit
    /// per-feature override (see `resolve_for_feature_count`).
    Heuristic,
    /// Exhaustive routing evaluation — NOTE(review): presumably scores both
    /// branch directions per split; confirm in the tree-training code.
    Optimal,
}
225
/// Missing-value strategy applied globally or overridden per feature index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MissingValueStrategyConfig {
    /// One strategy for every feature.
    Global(MissingValueStrategy),
    /// Per-feature overrides keyed by feature index; unlisted features fall
    /// back to [`MissingValueStrategy::Heuristic`].
    PerFeature(BTreeMap<usize, MissingValueStrategy>),
}
231
232impl MissingValueStrategyConfig {
233    pub fn heuristic() -> Self {
234        Self::Global(MissingValueStrategy::Heuristic)
235    }
236
237    pub fn optimal() -> Self {
238        Self::Global(MissingValueStrategy::Optimal)
239    }
240
241    pub fn resolve_for_feature_count(
242        &self,
243        feature_count: usize,
244    ) -> Result<Vec<MissingValueStrategy>, TrainError> {
245        match self {
246            MissingValueStrategyConfig::Global(strategy) => Ok(vec![*strategy; feature_count]),
247            MissingValueStrategyConfig::PerFeature(strategies) => {
248                let mut resolved = vec![MissingValueStrategy::Heuristic; feature_count];
249                for (&feature_index, &strategy) in strategies {
250                    if feature_index >= feature_count {
251                        return Err(TrainError::InvalidMissingValueStrategyFeature {
252                            feature_index,
253                            feature_count,
254                        });
255                    }
256                    resolved[feature_index] = strategy;
257                }
258                Ok(resolved)
259            }
260        }
261    }
262}
263
/// Unified training configuration shared by the Rust and Python entry points.
///
/// The crate keeps one normalized config type so the binding layer only has to
/// perform input validation and type conversion; all semantic decisions happen
/// from this one structure downward.
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// High-level training family.
    pub algorithm: TrainAlgorithm,
    /// Regression or classification.
    pub task: Task,
    /// Tree learner used by the selected algorithm family.
    pub tree_type: TreeType,
    /// Split criterion. [`Criterion::Auto`] is resolved by the trainer.
    pub criterion: Criterion,
    /// Maximum tree depth. Must be at least 1 when set
    /// (see [`TrainError::InvalidMaxDepth`]).
    pub max_depth: Option<usize>,
    /// Smallest node size that is still allowed to split. Must be at least 1
    /// when set (see [`TrainError::InvalidMinSamplesSplit`]).
    pub min_samples_split: Option<usize>,
    /// Minimum child size after a split. Must be at least 1 when set
    /// (see [`TrainError::InvalidMinSamplesLeaf`]).
    pub min_samples_leaf: Option<usize>,
    /// Optional cap on training-side rayon threads. Rejected when it exceeds
    /// the available cores (see [`TrainError::InvalidPhysicalCoreCount`]).
    pub physical_cores: Option<usize>,
    /// Number of trees for ensemble algorithms. Must be at least 1 when set
    /// (see [`TrainError::InvalidTreeCount`]).
    pub n_trees: Option<usize>,
    /// Feature subsampling strategy.
    pub max_features: MaxFeatures,
    /// Seed used for reproducible sampling and randomized splits.
    pub seed: Option<u64>,
    /// Window used when skipping canaries during split selection; validated
    /// by [`CanaryFilter::validate`] before training.
    pub canary_filter: CanaryFilter,
    /// Whether random forests should compute out-of-bag metrics.
    pub compute_oob: bool,
    /// Gradient boosting shrinkage factor.
    pub learning_rate: Option<f64>,
    /// Whether gradient boosting should bootstrap rows before gradient sampling.
    pub bootstrap: bool,
    /// Fraction of largest-gradient rows always kept by GOSS sampling.
    pub top_gradient_fraction: Option<f64>,
    /// Fraction of the remaining rows randomly retained by GOSS sampling.
    pub other_gradient_fraction: Option<f64>,
    /// Strategy used to evaluate missing-value routing during split search.
    pub missing_value_strategy: MissingValueStrategyConfig,
    /// Optional numeric histogram bin configuration for training-time split search.
    ///
    /// `None` preserves the incoming table's existing numeric bins. `Some(...)`
    /// rebuilds the numeric training view at the requested resolution before
    /// fitting, while leaving the caller's source table unchanged.
    pub histogram_bins: Option<NumericBins>,
}
314
impl Default for TrainConfig {
    /// The simplest setup: a single CART regression tree with every tunable
    /// left unset so the trainer resolves its own defaults.
    fn default() -> Self {
        Self {
            algorithm: TrainAlgorithm::Dt,
            task: Task::Regression,
            tree_type: TreeType::Cart,
            criterion: Criterion::Auto,
            max_depth: None,
            min_samples_split: None,
            min_samples_leaf: None,
            physical_cores: None,
            n_trees: None,
            max_features: MaxFeatures::Auto,
            seed: None,
            canary_filter: CanaryFilter::default(),
            compute_oob: false,
            learning_rate: None,
            bootstrap: false,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
            missing_value_strategy: MissingValueStrategyConfig::heuristic(),
            histogram_bins: None,
        }
    }
}
340
/// Top-level semantic model enum.
///
/// This type stays close to the learned structure rather than the fastest
/// possible runtime layout. That is what makes it suitable for introspection,
/// serialization, and exact behavior parity across bindings.
#[derive(Debug, Clone)]
pub enum Model {
    /// A single classification tree.
    DecisionTreeClassifier(DecisionTreeClassifier),
    /// A single regression tree.
    DecisionTreeRegressor(DecisionTreeRegressor),
    /// A bagged tree ensemble (classification or regression).
    RandomForest(RandomForest),
    /// A gradient-boosted tree ensemble (classification or regression).
    GradientBoostedTrees(GradientBoostedTrees),
}
353
/// Errors surfaced while validating configuration or fitting a model.
#[derive(Debug)]
pub enum TrainError {
    /// A classification tree failed to train.
    DecisionTree(DecisionTreeError),
    /// A regression tree failed to train.
    RegressionTree(RegressionTreeError),
    /// The gradient-boosting driver failed.
    Boosting(BoostingError),
    /// `physical_cores` exceeded the machine's available core count.
    InvalidPhysicalCoreCount {
        requested: usize,
        available: usize,
    },
    /// Building the training thread pool failed; payload is the builder's message.
    ThreadPoolBuildFailed(String),
    /// No trainer exists for this task/tree-type/criterion combination.
    UnsupportedConfiguration {
        task: Task,
        tree_type: TreeType,
        criterion: Criterion,
    },
    /// `max_depth` must be at least 1; payload is the rejected value.
    InvalidMaxDepth(usize),
    /// `min_samples_split` must be at least 1; payload is the rejected value.
    InvalidMinSamplesSplit(usize),
    /// `min_samples_leaf` must be at least 1; payload is the rejected value.
    InvalidMinSamplesLeaf(usize),
    /// Ensemble training requires at least one tree; payload is the rejected count.
    InvalidTreeCount(usize),
    /// Integer `max_features` must be at least 1; payload is the rejected count.
    InvalidMaxFeatures(usize),
    /// `CanaryFilter::TopN` must be at least 1; payload is the rejected count.
    InvalidCanaryFilterTopN(usize),
    /// `CanaryFilter::TopFraction` must be finite and in `(0, 1]`.
    InvalidCanaryFilterTopFraction(f64),
    /// A per-feature missing-value override referenced an out-of-range feature.
    InvalidMissingValueStrategyFeature {
        feature_index: usize,
        feature_count: usize,
    },
}
381
impl Display for TrainError {
    /// Render each error as a complete, user-facing sentence.
    ///
    /// The wrapped `DecisionTree`/`RegressionTree`/`Boosting` variants defer
    /// to the inner error's own `Display` wording.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            TrainError::DecisionTree(err) => err.fmt(f),
            TrainError::RegressionTree(err) => err.fmt(f),
            TrainError::Boosting(err) => err.fmt(f),
            TrainError::InvalidPhysicalCoreCount {
                requested,
                available,
            } => write!(
                f,
                "Requested {} physical cores, but the available physical core count is {}.",
                requested, available
            ),
            TrainError::ThreadPoolBuildFailed(message) => {
                write!(f, "Failed to build training thread pool: {}.", message)
            }
            TrainError::UnsupportedConfiguration {
                task,
                tree_type,
                criterion,
            } => write!(
                f,
                "Unsupported training configuration: task={:?}, tree_type={:?}, criterion={:?}.",
                task, tree_type, criterion
            ),
            TrainError::InvalidMaxDepth(value) => {
                write!(f, "max_depth must be at least 1. Received {}.", value)
            }
            TrainError::InvalidMinSamplesSplit(value) => {
                write!(
                    f,
                    "min_samples_split must be at least 1. Received {}.",
                    value
                )
            }
            TrainError::InvalidMinSamplesLeaf(value) => {
                write!(
                    f,
                    "min_samples_leaf must be at least 1. Received {}.",
                    value
                )
            }
            TrainError::InvalidTreeCount(n_trees) => {
                write!(
                    f,
                    "Random forest requires at least one tree. Received {}.",
                    n_trees
                )
            }
            TrainError::InvalidMaxFeatures(count) => {
                write!(
                    f,
                    "max_features must be at least 1 when provided as an integer. Received {}.",
                    count
                )
            }
            TrainError::InvalidCanaryFilterTopN(count) => {
                write!(
                    f,
                    "canary_filter top-n must be at least 1. Received {}.",
                    count
                )
            }
            TrainError::InvalidCanaryFilterTopFraction(fraction) => {
                write!(
                    f,
                    "canary_filter top fraction must be finite and in (0, 1]. Received {}.",
                    fraction
                )
            }
            TrainError::InvalidMissingValueStrategyFeature {
                feature_index,
                feature_count,
            } => write!(
                f,
                "missing_value_strategy references feature {}, but the training table only has {} features.",
                feature_index, feature_count
            ),
        }
    }
}

// Marker impl: messages come from `Display`; the default `source()` (None) is used.
impl Error for TrainError {}
466
467impl CanaryFilter {
468    pub(crate) fn validate(self) -> Result<(), TrainError> {
469        match self {
470            Self::TopN(0) => Err(TrainError::InvalidCanaryFilterTopN(0)),
471            Self::TopN(_) => Ok(()),
472            Self::TopFraction(fraction)
473                if !fraction.is_finite() || fraction <= 0.0 || fraction > 1.0 =>
474            {
475                Err(TrainError::InvalidCanaryFilterTopFraction(fraction))
476            }
477            Self::TopFraction(_) => Ok(()),
478        }
479    }
480
481    pub(crate) fn selection_size(self, candidate_count: usize) -> usize {
482        if candidate_count == 0 {
483            return 0;
484        }
485
486        match self {
487            Self::TopN(count) => count.clamp(1, candidate_count),
488            Self::TopFraction(fraction) => {
489                ((fraction * candidate_count as f64).ceil() as usize).clamp(1, candidate_count)
490            }
491        }
492    }
493}
494
/// Errors raised while validating or executing a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictError {
    /// `predict_proba` was called on a non-classification model.
    ProbabilityPredictionRequiresClassification,
    /// A row-major input contained a row with the wrong number of columns.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// The input's feature count differs from what the model expects.
    FeatureCountMismatch {
        expected: usize,
        actual: usize,
    },
    /// A named column's value count differs from the expected row count.
    ColumnLengthMismatch {
        feature: String,
        expected: usize,
        actual: usize,
    },
    /// A feature the model requires is absent from the input.
    MissingFeature(String),
    /// The input supplied a feature the model does not know.
    UnexpectedFeature(String),
    /// A supposedly binary feature held a non-binary value.
    InvalidBinaryValue {
        feature_index: usize,
        row_index: usize,
        value: f64,
    },
    /// A null was found where a concrete value is required.
    NullValue {
        feature: String,
        row_index: usize,
    },
    /// A column's dtype cannot be consumed by inference.
    UnsupportedFeatureType {
        feature: String,
        dtype: String,
    },
    /// A polars operation failed; payload is the stringified polars error.
    Polars(String),
}
529
impl Display for PredictError {
    /// Render each inference error as a complete, user-facing sentence.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            PredictError::ProbabilityPredictionRequiresClassification => write!(
                f,
                "predict_proba is only available for classification models."
            ),
            PredictError::RaggedRows {
                row,
                expected,
                actual,
            } => write!(
                f,
                "Ragged inference row at index {}: expected {} columns, found {}.",
                row, expected, actual
            ),
            // Note: message intentionally reports `actual` before `expected`.
            PredictError::FeatureCountMismatch { expected, actual } => write!(
                f,
                "Inference input has {} features, but the model expects {}.",
                actual, expected
            ),
            PredictError::ColumnLengthMismatch {
                feature,
                expected,
                actual,
            } => write!(
                f,
                "Feature '{}' has {} values, expected {}.",
                feature, actual, expected
            ),
            PredictError::MissingFeature(feature) => {
                write!(f, "Missing required feature '{}'.", feature)
            }
            PredictError::UnexpectedFeature(feature) => {
                write!(f, "Unexpected feature '{}'.", feature)
            }
            PredictError::InvalidBinaryValue {
                feature_index,
                row_index,
                value,
            } => write!(
                f,
                "Feature {} at row {} must be binary for inference, found {}.",
                feature_index, row_index, value
            ),
            PredictError::NullValue { feature, row_index } => write!(
                f,
                "Feature '{}' contains a null value at row {}.",
                feature, row_index
            ),
            PredictError::UnsupportedFeatureType { feature, dtype } => write!(
                f,
                "Feature '{}' has unsupported dtype '{}'.",
                feature, dtype
            ),
            PredictError::Polars(message) => write!(f, "Polars inference failed: {}.", message),
        }
    }
}

// Marker impl: messages come from `Display`; the default `source()` (None) is used.
impl Error for PredictError {}
591
#[cfg(feature = "polars")]
impl From<polars::error::PolarsError> for PredictError {
    /// Store the polars error as a plain string so [`PredictError`] itself
    /// stays independent of polars types outside this feature gate.
    fn from(value: polars::error::PolarsError) -> Self {
        PredictError::Polars(value.to_string())
    }
}
598
/// Errors raised while lowering a model into its optimized runtime form.
#[derive(Debug)]
pub enum OptimizeError {
    /// Requested inference thread count exceeded the available physical cores.
    InvalidPhysicalCoreCount { requested: usize, available: usize },
    /// Building the inference thread pool failed; payload is the builder's message.
    ThreadPoolBuildFailed(String),
    /// The named model type has no optimized-inference lowering.
    UnsupportedModelType(&'static str),
}
605
impl Display for OptimizeError {
    /// Render each optimization error as a complete, user-facing sentence.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            OptimizeError::InvalidPhysicalCoreCount {
                requested,
                available,
            } => write!(
                f,
                "Requested {} physical cores, but the available physical core count is {}.",
                requested, available
            ),
            OptimizeError::ThreadPoolBuildFailed(message) => {
                write!(f, "Failed to build inference thread pool: {}.", message)
            }
            OptimizeError::UnsupportedModelType(model_type) => {
                write!(
                    f,
                    "Optimized inference is not supported for model type '{}'.",
                    model_type
                )
            }
        }
    }
}

// Marker impl: messages come from `Display`; the default `source()` (None) is used.
impl Error for OptimizeError {}
632
/// Internal knob describing how many threads a training or inference pass
/// may use. A count of 1 means fully sequential execution.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Parallelism {
    // Worker-thread budget; always at least 1.
    thread_count: usize,
}

impl Parallelism {
    /// A single-threaded configuration.
    pub(crate) fn sequential() -> Self {
        Parallelism { thread_count: 1 }
    }

    /// Test-only constructor; the count is clamped up to at least one thread.
    #[cfg(test)]
    pub(crate) fn with_threads(thread_count: usize) -> Self {
        Parallelism {
            thread_count: thread_count.max(1),
        }
    }

    /// True when more than one worker thread may run.
    pub(crate) fn enabled(self) -> bool {
        self.thread_count >= 2
    }
}
654
655pub(crate) fn capture_feature_preprocessing(table: &dyn TableAccess) -> Vec<FeaturePreprocessing> {
656    (0..table.n_features())
657        .map(|feature_index| {
658            if table.is_binary_feature(feature_index) {
659                FeaturePreprocessing::Binary
660            } else {
661                let values = (0..table.n_rows())
662                    .map(|row_index| table.feature_value(feature_index, row_index))
663                    .collect::<Vec<_>>();
664                FeaturePreprocessing::Numeric {
665                    bin_boundaries: numeric_bin_boundaries(
666                        &values,
667                        NumericBins::Fixed(table.numeric_bin_cap()),
668                    )
669                    .into_iter()
670                    .map(|(bin, upper_bound)| NumericBinBoundary { bin, upper_bound })
671                    .collect(),
672                    missing_bin: numeric_missing_bin(NumericBins::Fixed(table.numeric_bin_cap())),
673                }
674            }
675        })
676        .collect()
677}
678
/// Whether missing-value handling applies to `feature_index`.
///
/// `None` means no restriction (every feature is enabled); `Some(set)`
/// enables only the listed feature indices.
fn missing_feature_enabled(
    feature_index: usize,
    missing_features: Option<&BTreeSet<usize>>,
) -> bool {
    match missing_features {
        None => true,
        Some(features) => features.contains(&feature_index),
    }
}
685
686fn optimized_missing_bin(
687    preprocessing: &[FeaturePreprocessing],
688    feature_index: usize,
689    missing_features: Option<&BTreeSet<usize>>,
690) -> Option<u16> {
691    if !missing_feature_enabled(feature_index, missing_features) {
692        return None;
693    }
694
695    match preprocessing.get(feature_index) {
696        Some(FeaturePreprocessing::Binary) => Some(forestfire_data::BINARY_MISSING_BIN),
697        Some(FeaturePreprocessing::Numeric { missing_bin, .. }) => Some(*missing_bin),
698        None => None,
699    }
700}
701
702impl OptimizedRuntime {
703    fn supports_batch_matrix(&self) -> bool {
704        matches!(
705            self,
706            OptimizedRuntime::BinaryClassifier { .. }
707                | OptimizedRuntime::BinaryRegressor { .. }
708                | OptimizedRuntime::ObliviousClassifier { .. }
709                | OptimizedRuntime::ObliviousRegressor { .. }
710                | OptimizedRuntime::ForestClassifier { .. }
711                | OptimizedRuntime::ForestRegressor { .. }
712                | OptimizedRuntime::BoostedBinaryClassifier { .. }
713                | OptimizedRuntime::BoostedRegressor { .. }
714        )
715    }
716
717    fn should_use_batch_matrix(&self, n_rows: usize) -> bool {
718        n_rows > 1 && self.supports_batch_matrix()
719    }
720
    /// Lower a semantic [`Model`] into its optimized runtime counterpart.
    ///
    /// `feature_index_map` remaps model feature indices into the projected
    /// inference feature space; `missing_features`, when `Some`, restricts
    /// missing-value handling to the listed feature indices. Ensembles
    /// recurse into this function for each member tree.
    fn from_model(
        model: &Model,
        feature_index_map: &[usize],
        missing_features: Option<&BTreeSet<usize>>,
    ) -> Self {
        match model {
            Model::DecisionTreeClassifier(classifier) => {
                Self::from_classifier(classifier, feature_index_map, missing_features)
            }
            Model::DecisionTreeRegressor(regressor) => {
                Self::from_regressor(regressor, feature_index_map, missing_features)
            }
            // Forest trees are emitted in the order chosen by
            // `ordered_ensemble_indices` — NOTE(review): presumably a
            // runtime-friendly ordering; see `runtime_planning`.
            Model::RandomForest(forest) => match forest.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestClassifier {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &forest.trees()[tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        // Invariant: a classification forest always carries labels.
                        class_labels: forest
                            .class_labels()
                            .expect("classification forest stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(forest.trees());
                    Self::ForestRegressor {
                        trees: tree_order
                            .into_iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &forest.trees()[tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                    }
                }
            },
            // Boosted ensembles additionally carry per-tree weights and a base
            // score; `tree_order` is borrowed (`.iter()`) because it indexes
            // both `trees()` and `tree_weights()`.
            Model::GradientBoostedTrees(model) => match model.task() {
                Task::Classification => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedBinaryClassifier {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &model.trees()[*tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        // Weights are reordered in lockstep with the trees.
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                        // Invariant: classification boosting always carries labels.
                        class_labels: model
                            .class_labels()
                            .expect("classification boosting stores class labels"),
                    }
                }
                Task::Regression => {
                    let tree_order = ordered_ensemble_indices(model.trees());
                    Self::BoostedRegressor {
                        trees: tree_order
                            .iter()
                            .map(|tree_index| {
                                Self::from_model(
                                    &model.trees()[*tree_index],
                                    feature_index_map,
                                    missing_features,
                                )
                            })
                            .collect(),
                        tree_weights: tree_order
                            .iter()
                            .map(|tree_index| model.tree_weights()[*tree_index])
                            .collect(),
                        base_score: model.base_score(),
                    }
                }
            },
        }
    }
815
    /// Lowers a semantic [`DecisionTreeClassifier`] into an optimized runtime
    /// variant.
    ///
    /// Layout selection:
    /// - standard trees whose nodes are binary-only are flattened into the
    ///   compact jump-based `BinaryClassifier` layout;
    /// - other standard trees become an index-addressed `StandardClassifier`;
    /// - oblivious trees become parallel per-level split arrays.
    ///
    /// `feature_index_map` remaps training-time feature indices to the
    /// inference input layout; `missing_features` (when `Some`) restricts
    /// missing-value routing to the listed original feature indices.
    fn from_classifier(
        classifier: &DecisionTreeClassifier,
        feature_index_map: &[usize],
        missing_features: Option<&BTreeSet<usize>>,
    ) -> Self {
        match classifier.structure() {
            tree::classifier::TreeStructure::Standard { nodes, root } => {
                // Fast path: binary-only trees get the fallthrough/jump layout.
                if classifier_nodes_are_binary_only(nodes) {
                    return Self::BinaryClassifier {
                        nodes: build_binary_classifier_layout(
                            nodes,
                            *root,
                            classifier.class_labels(),
                            feature_index_map,
                            classifier.feature_preprocessing(),
                            missing_features,
                        ),
                        class_labels: classifier.class_labels().to_vec(),
                    };
                }

                let optimized_nodes = nodes
                    .iter()
                    .map(|node| match node {
                        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
                            OptimizedClassifierNode::Leaf(normalized_probabilities_from_counts(
                                class_counts,
                            ))
                        }
                        tree::classifier::TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            missing_direction,
                            left_child,
                            right_child,
                            class_counts,
                            ..
                        } => OptimizedClassifierNode::Binary {
                            feature_index: remap_feature_index(*feature_index, feature_index_map),
                            threshold_bin: *threshold_bin,
                            children: [*left_child, *right_child],
                            // Bin value that encodes "missing" for this feature,
                            // or `None` when missing routing is disabled for it.
                            missing_bin: optimized_missing_bin(
                                classifier.feature_preprocessing(),
                                *feature_index,
                                missing_features,
                            ),
                            // `Left`/`Right` directions resolve to a concrete
                            // child index; `Node` means "answer at this split"
                            // and is handled via `missing_probabilities` below.
                            missing_child: if missing_feature_enabled(
                                *feature_index,
                                missing_features,
                            ) {
                                match missing_direction {
                                    tree::shared::MissingBranchDirection::Left => Some(*left_child),
                                    tree::shared::MissingBranchDirection::Right => {
                                        Some(*right_child)
                                    }
                                    tree::shared::MissingBranchDirection::Node => None,
                                }
                            } else {
                                None
                            },
                            // Only the `Node` direction answers missing rows
                            // directly from this split's own class counts.
                            missing_probabilities: if missing_feature_enabled(
                                *feature_index,
                                missing_features,
                            ) && matches!(
                                missing_direction,
                                tree::shared::MissingBranchDirection::Node
                            ) {
                                Some(normalized_probabilities_from_counts(class_counts))
                            } else {
                                None
                            },
                        },
                        tree::classifier::TreeNode::MultiwaySplit {
                            feature_index,
                            class_counts,
                            branches,
                            missing_child,
                            ..
                        } => {
                            // Densify the sparse (bin -> child) branch list into
                            // a direct lookup table; bins without a branch keep
                            // the `usize::MAX` sentinel and fall back to this
                            // node's own distribution at prediction time.
                            let max_bin_index = branches
                                .iter()
                                .map(|(bin, _)| usize::from(*bin))
                                .max()
                                .unwrap_or(0);
                            let mut child_lookup = vec![usize::MAX; max_bin_index + 1];
                            for (bin, child_index) in branches {
                                child_lookup[usize::from(*bin)] = *child_index;
                            }
                            OptimizedClassifierNode::Multiway {
                                feature_index: remap_feature_index(
                                    *feature_index,
                                    feature_index_map,
                                ),
                                child_lookup,
                                max_bin_index,
                                missing_bin: optimized_missing_bin(
                                    classifier.feature_preprocessing(),
                                    *feature_index,
                                    missing_features,
                                ),
                                missing_child: if missing_feature_enabled(
                                    *feature_index,
                                    missing_features,
                                ) {
                                    *missing_child
                                } else {
                                    None
                                },
                                fallback_probabilities: normalized_probabilities_from_counts(
                                    class_counts,
                                ),
                            }
                        }
                    })
                    .collect();

                Self::StandardClassifier {
                    nodes: optimized_nodes,
                    root: *root,
                    class_labels: classifier.class_labels().to_vec(),
                }
            }
            // Oblivious trees store one split per level; lowering keeps the
            // per-level arrays parallel and normalizes leaf counts into
            // probability vectors.
            tree::classifier::TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => Self::ObliviousClassifier {
                feature_indices: splits
                    .iter()
                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
                    .collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_class_counts
                    .iter()
                    .map(|class_counts| normalized_probabilities_from_counts(class_counts))
                    .collect(),
                class_labels: classifier.class_labels().to_vec(),
            },
        }
    }
956
957    fn from_regressor(
958        regressor: &DecisionTreeRegressor,
959        feature_index_map: &[usize],
960        missing_features: Option<&BTreeSet<usize>>,
961    ) -> Self {
962        match regressor.structure() {
963            tree::regressor::RegressionTreeStructure::Standard { nodes, root } => {
964                Self::BinaryRegressor {
965                    nodes: build_binary_regressor_layout(
966                        nodes,
967                        *root,
968                        feature_index_map,
969                        regressor.feature_preprocessing(),
970                        missing_features,
971                    ),
972                }
973            }
974            tree::regressor::RegressionTreeStructure::Oblivious {
975                splits,
976                leaf_values,
977                ..
978            } => Self::ObliviousRegressor {
979                feature_indices: splits
980                    .iter()
981                    .map(|split| remap_feature_index(split.feature_index, feature_index_map))
982                    .collect(),
983                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
984                leaf_values: leaf_values.clone(),
985            },
986        }
987    }
988
989    #[inline(always)]
990    fn predict_table_row(&self, table: &dyn TableAccess, row_index: usize) -> f64 {
991        match self {
992            OptimizedRuntime::BinaryClassifier { .. }
993            | OptimizedRuntime::StandardClassifier { .. }
994            | OptimizedRuntime::ObliviousClassifier { .. }
995            | OptimizedRuntime::ForestClassifier { .. }
996            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
997                let probabilities = self
998                    .predict_proba_table_row(table, row_index)
999                    .expect("classifier runtime supports probability prediction");
1000                class_label_from_probabilities(&probabilities, self.class_labels())
1001            }
1002            OptimizedRuntime::BinaryRegressor { nodes } => {
1003                predict_binary_regressor_row(nodes, |feature_index| {
1004                    table.binned_value(feature_index, row_index)
1005                })
1006            }
1007            OptimizedRuntime::ObliviousRegressor {
1008                feature_indices,
1009                threshold_bins,
1010                leaf_values,
1011            } => predict_oblivious_row(
1012                feature_indices,
1013                threshold_bins,
1014                leaf_values,
1015                |feature_index| table.binned_value(feature_index, row_index),
1016            ),
1017            OptimizedRuntime::ForestRegressor { trees } => {
1018                trees
1019                    .iter()
1020                    .map(|tree| tree.predict_table_row(table, row_index))
1021                    .sum::<f64>()
1022                    / trees.len() as f64
1023            }
1024            OptimizedRuntime::BoostedRegressor {
1025                trees,
1026                tree_weights,
1027                base_score,
1028            } => {
1029                *base_score
1030                    + trees
1031                        .iter()
1032                        .zip(tree_weights.iter().copied())
1033                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
1034                        .sum::<f64>()
1035            }
1036        }
1037    }
1038
    /// Predicts per-class probabilities for a single row of `table`.
    ///
    /// Returns `Err(ProbabilityPredictionRequiresClassification)` for every
    /// regressor runtime.
    #[inline(always)]
    fn predict_proba_table_row(
        &self,
        table: &dyn TableAccess,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            )
            .to_vec()),
            // Forest: element-wise average of the per-tree probability vectors.
            // NOTE(review): indexes `trees[0]`, so an empty forest would panic —
            // presumably training guarantees at least one tree; confirm.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals = trees[0].predict_proba_table_row(table, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_table_row(table, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            // Boosted binary: weighted raw-score sum mapped through the
            // sigmoid; the output pair is [1 - positive, positive].
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1107
1108    fn predict_proba_table(
1109        &self,
1110        table: &dyn TableAccess,
1111        executor: &InferenceExecutor,
1112    ) -> Result<Vec<Vec<f64>>, PredictError> {
1113        match self {
1114            OptimizedRuntime::BinaryClassifier { .. }
1115            | OptimizedRuntime::StandardClassifier { .. }
1116            | OptimizedRuntime::ObliviousClassifier { .. }
1117            | OptimizedRuntime::ForestClassifier { .. }
1118            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
1119                if self.should_use_batch_matrix(table.n_rows()) {
1120                    let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
1121                    self.predict_proba_column_major_matrix(&matrix, executor)
1122                } else {
1123                    (0..table.n_rows())
1124                        .map(|row_index| self.predict_proba_table_row(table, row_index))
1125                        .collect()
1126                }
1127            }
1128            OptimizedRuntime::BinaryRegressor { .. }
1129            | OptimizedRuntime::ObliviousRegressor { .. }
1130            | OptimizedRuntime::ForestRegressor { .. }
1131            | OptimizedRuntime::BoostedRegressor { .. } => {
1132                Err(PredictError::ProbabilityPredictionRequiresClassification)
1133            }
1134        }
1135    }
1136
    /// Predicts one value per row of an already-binned column-major matrix.
    ///
    /// Classifiers batch-predict probabilities and reduce each row to the
    /// winning class label; regressors run their batch kernels, and ensemble
    /// regressors combine the per-tree batch outputs.
    fn predict_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Vec<f64> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => self
                .predict_proba_column_major_matrix(matrix, executor)
                .expect("classifier runtime supports probability prediction")
                .into_iter()
                .map(|row| class_label_from_probabilities(&row, self.class_labels()))
                .collect(),
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_column_major_matrix(nodes, matrix, executor)
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            ),
            // Forest: element-wise average of the per-tree predictions.
            // NOTE(review): indexes `trees[0]` — assumes a non-empty forest;
            // confirm training guarantees this.
            OptimizedRuntime::ForestRegressor { trees } => {
                let mut totals = trees[0].predict_column_major_matrix(matrix, executor);
                for tree in &trees[1..] {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for total in &mut totals {
                    *total /= tree_count;
                }
                totals
            }
            // Boosted: base score plus the weighted sum of per-tree outputs,
            // accumulated row-wise.
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                let mut totals = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += weight * value;
                    }
                }
                totals
            }
        }
    }
1197
    /// Predicts per-class probabilities for every row of a column-major
    /// binned matrix.
    ///
    /// Returns `Err(ProbabilityPredictionRequiresClassification)` for every
    /// regressor runtime.
    fn predict_proba_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => {
                Ok(predict_binary_classifier_probabilities_column_major_matrix(
                    nodes, matrix, executor,
                ))
            }
            // No batch kernel for the standard layout: each row is traversed
            // individually against the node-indexed tree.
            OptimizedRuntime::StandardClassifier { .. } => Ok((0..matrix.n_rows)
                .map(|row_index| {
                    self.predict_proba_binned_row_from_columns(matrix, row_index)
                        .expect("classifier runtime supports probability prediction")
                })
                .collect()),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            )),
            // Forest: element-wise average of the per-tree probability rows.
            // NOTE(review): indexes `trees[0]` — assumes a non-empty forest.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals = trees[0].predict_proba_column_major_matrix(matrix, executor)?;
                for tree in &trees[1..] {
                    let rows = tree.predict_proba_column_major_matrix(matrix, executor)?;
                    for (row_totals, row_values) in totals.iter_mut().zip(rows) {
                        for (total, value) in row_totals.iter_mut().zip(row_values) {
                            *total += value;
                        }
                    }
                }
                let tree_count = trees.len() as f64;
                for row in &mut totals {
                    for value in row {
                        *value /= tree_count;
                    }
                }
                Ok(totals)
            }
            // Boosted binary: accumulate weighted raw scores per row, then map
            // each through the sigmoid into a [1 - positive, positive] pair.
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let mut raw_scores = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (raw_score, value) in raw_scores.iter_mut().zip(values) {
                        *raw_score += weight * value;
                    }
                }
                Ok(raw_scores
                    .into_iter()
                    .map(|raw_score| {
                        let positive = sigmoid(raw_score);
                        vec![1.0 - positive, positive]
                    })
                    .collect())
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1274
1275    fn class_labels(&self) -> &[f64] {
1276        match self {
1277            OptimizedRuntime::BinaryClassifier { class_labels, .. }
1278            | OptimizedRuntime::StandardClassifier { class_labels, .. }
1279            | OptimizedRuntime::ObliviousClassifier { class_labels, .. }
1280            | OptimizedRuntime::ForestClassifier { class_labels, .. }
1281            | OptimizedRuntime::BoostedBinaryClassifier { class_labels, .. } => class_labels,
1282            _ => &[],
1283        }
1284    }
1285
1286    #[inline(always)]
1287    fn predict_binned_row_from_columns(
1288        &self,
1289        matrix: &ColumnMajorBinnedMatrix,
1290        row_index: usize,
1291    ) -> f64 {
1292        match self {
1293            OptimizedRuntime::BinaryRegressor { nodes } => {
1294                predict_binary_regressor_row(nodes, |feature_index| {
1295                    matrix.column(feature_index).value_at(row_index)
1296                })
1297            }
1298            OptimizedRuntime::ObliviousRegressor {
1299                feature_indices,
1300                threshold_bins,
1301                leaf_values,
1302            } => predict_oblivious_row(
1303                feature_indices,
1304                threshold_bins,
1305                leaf_values,
1306                |feature_index| matrix.column(feature_index).value_at(row_index),
1307            ),
1308            OptimizedRuntime::BoostedRegressor {
1309                trees,
1310                tree_weights,
1311                base_score,
1312            } => {
1313                *base_score
1314                    + trees
1315                        .iter()
1316                        .zip(tree_weights.iter().copied())
1317                        .map(|(tree, weight)| {
1318                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
1319                        })
1320                        .sum::<f64>()
1321            }
1322            _ => self.predict_column_major_matrix(
1323                matrix,
1324                &InferenceExecutor::new(1).expect("inference executor"),
1325            )[row_index],
1326        }
1327    }
1328
    /// Predicts per-class probabilities for one row of a column-major binned
    /// matrix; the column-reading twin of [`Self::predict_proba_table_row`].
    ///
    /// Returns `Err(ProbabilityPredictionRequiresClassification)` for every
    /// regressor runtime.
    #[inline(always)]
    fn predict_proba_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            )
            .to_vec()),
            // Forest: element-wise average of the per-tree probability vectors.
            // NOTE(review): indexes `trees[0]` — assumes a non-empty forest.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals =
                    trees[0].predict_proba_binned_row_from_columns(matrix, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_binned_row_from_columns(matrix, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            // Boosted binary: weighted raw-score sum mapped through the
            // sigmoid into a [1 - positive, positive] pair.
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1400}
1401
/// Walks a standard (node-indexed) classifier tree for one row and returns a
/// borrowed probability vector from the node that answers it.
///
/// `bin_at` supplies the row's binned value for a feature index. At each
/// split, missing routing (observed bin equal to the node's `missing_bin`)
/// takes priority over the ordinary comparison: a stored
/// `missing_probabilities` vector answers directly, a `missing_child`
/// redirects the walk, and with neither present the normal threshold/lookup
/// logic applies.
#[inline(always)]
fn predict_standard_classifier_probabilities_row<F>(
    nodes: &[OptimizedClassifierNode],
    root: usize,
    bin_at: F,
) -> &[f64]
where
    F: Fn(usize) -> u16,
{
    let mut node_index = root;
    loop {
        match &nodes[node_index] {
            OptimizedClassifierNode::Leaf(value) => return value,
            OptimizedClassifierNode::Binary {
                feature_index,
                threshold_bin,
                children,
                missing_bin,
                missing_child,
                missing_probabilities,
            } => {
                let bin = bin_at(*feature_index);
                if missing_bin.is_some_and(|expected| expected == bin) {
                    if let Some(probabilities) = missing_probabilities {
                        return probabilities;
                    }
                    if let Some(child_index) = missing_child {
                        node_index = *child_index;
                        continue;
                    }
                    // Neither answer nor redirect stored: fall through to the
                    // ordinary threshold comparison below.
                }
                let go_right = usize::from(bin > *threshold_bin);
                node_index = children[go_right];
            }
            OptimizedClassifierNode::Multiway {
                feature_index,
                child_lookup,
                max_bin_index,
                missing_bin,
                missing_child,
                fallback_probabilities,
            } => {
                let bin_value = bin_at(*feature_index);
                if missing_bin.is_some_and(|expected| expected == bin_value) {
                    if let Some(child_index) = missing_child {
                        node_index = *child_index;
                        continue;
                    }
                    return fallback_probabilities;
                }
                // Bins without a recorded branch — out of table range or the
                // `usize::MAX` sentinel — answer with this node's own
                // fallback distribution.
                let bin = usize::from(bin_value);
                if bin > *max_bin_index {
                    return fallback_probabilities;
                }
                let child_index = child_lookup[bin];
                if child_index == usize::MAX {
                    return fallback_probabilities;
                }
                node_index = child_index;
            }
        }
    }
}
1465
/// Walks the flattened (fallthrough/jump) binary-classifier layout for one
/// row and returns a borrowed probability vector from the leaf reached.
///
/// The walk starts at node 0. `bin_at` supplies the row's binned value for a
/// feature index. Each branch stores one explicit `jump_index`; the other
/// child is simply the next node in the array (`node_index + 1`), and
/// `jump_if_greater` selects which comparison outcome jumps.
#[inline(always)]
fn predict_binary_classifier_probabilities_row<F>(
    nodes: &[OptimizedBinaryClassifierNode],
    bin_at: F,
) -> &[f64]
where
    F: Fn(usize) -> u16,
{
    let mut node_index = 0usize;
    loop {
        match &nodes[node_index] {
            OptimizedBinaryClassifierNode::Leaf(value) => return value,
            OptimizedBinaryClassifierNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_probabilities,
            } => {
                let bin = bin_at(*feature_index);
                // Missing routing takes priority over the threshold test:
                // either answer from stored probabilities, redirect the walk,
                // or (with neither present) fall through to the comparison.
                if missing_bin.is_some_and(|expected| expected == bin) {
                    if let Some(probabilities) = missing_probabilities {
                        return probabilities;
                    }
                    if let Some(jump_index) = missing_jump_index {
                        node_index = *jump_index;
                        continue;
                    }
                }
                let go_right = bin > *threshold_bin;
                node_index = if go_right == *jump_if_greater {
                    *jump_index
                } else {
                    node_index + 1
                };
            }
        }
    }
}
1507
1508#[inline(always)]
1509fn predict_binary_regressor_row<F>(nodes: &[OptimizedBinaryRegressorNode], bin_at: F) -> f64
1510where
1511    F: Fn(usize) -> u16,
1512{
1513    let mut node_index = 0usize;
1514    loop {
1515        match &nodes[node_index] {
1516            OptimizedBinaryRegressorNode::Leaf(value) => return *value,
1517            OptimizedBinaryRegressorNode::Branch {
1518                feature_index,
1519                threshold_bin,
1520                jump_index,
1521                jump_if_greater,
1522                missing_bin,
1523                missing_jump_index,
1524                missing_value,
1525            } => {
1526                let bin = bin_at(*feature_index);
1527                if missing_bin.is_some_and(|expected| expected == bin) {
1528                    if let Some(value) = missing_value {
1529                        return *value;
1530                    }
1531                    if let Some(jump_index) = missing_jump_index {
1532                        node_index = *jump_index;
1533                        continue;
1534                    }
1535                }
1536                let go_right = bin > *threshold_bin;
1537                node_index = if go_right == *jump_if_greater {
1538                    *jump_index
1539                } else {
1540                    node_index + 1
1541                };
1542            }
1543        }
1544    }
1545}
1546
1547fn predict_binary_classifier_probabilities_column_major_matrix(
1548    nodes: &[OptimizedBinaryClassifierNode],
1549    matrix: &ColumnMajorBinnedMatrix,
1550    _executor: &InferenceExecutor,
1551) -> Vec<Vec<f64>> {
1552    if binary_classifier_nodes_require_rowwise_missing(nodes) {
1553        return (0..matrix.n_rows)
1554            .map(|row_index| {
1555                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1556                    matrix.column(feature_index).value_at(row_index)
1557                })
1558                .to_vec()
1559            })
1560            .collect();
1561    }
1562    (0..matrix.n_rows)
1563        .map(|row_index| {
1564            predict_binary_classifier_probabilities_row(nodes, |feature_index| {
1565                matrix.column(feature_index).value_at(row_index)
1566            })
1567            .to_vec()
1568        })
1569        .collect()
1570}
1571
1572fn predict_binary_regressor_column_major_matrix(
1573    nodes: &[OptimizedBinaryRegressorNode],
1574    matrix: &ColumnMajorBinnedMatrix,
1575    executor: &InferenceExecutor,
1576) -> Vec<f64> {
1577    if binary_regressor_nodes_require_rowwise_missing(nodes) {
1578        return (0..matrix.n_rows)
1579            .map(|row_index| {
1580                predict_binary_regressor_row(nodes, |feature_index| {
1581                    matrix.column(feature_index).value_at(row_index)
1582                })
1583            })
1584            .collect();
1585    }
1586    let mut outputs = vec![0.0; matrix.n_rows];
1587    executor.fill_chunks(
1588        &mut outputs,
1589        STANDARD_BATCH_INFERENCE_CHUNK_ROWS,
1590        |start_row, chunk| predict_binary_regressor_chunk(nodes, matrix, start_row, chunk),
1591    );
1592    outputs
1593}
1594
/// Evaluates a flattened binary regression tree for one contiguous chunk of
/// rows, writing each row's predicted value into `output`.
///
/// Instead of walking the tree once per row, rows are routed in batches: an
/// explicit work stack of `(node_index, start, end)` entries tracks which
/// slice of `row_indices` is currently at which node. At every branch the
/// active rows are partitioned in place (swap-to-tail loop) into the
/// fallthrough group and the jump group, so each tree node is visited at most
/// once per chunk.
///
/// `start_row` is the chunk's offset into `matrix`; `output[i]` corresponds
/// to matrix row `start_row + i`. Missing-value routing is not handled here —
/// callers gate on `binary_regressor_nodes_require_rowwise_missing` first.
fn predict_binary_regressor_chunk(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) {
    // row_indices[position] maps a partition slot back to its output offset.
    let mut row_indices: Vec<usize> = (0..output.len()).collect();
    let mut stack = vec![(0usize, 0usize, output.len())];

    while let Some((node_index, start, end)) = stack.pop() {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => {
                // Every row routed to this leaf receives the leaf value.
                for position in start..end {
                    output[row_indices[position]] = *value;
                }
            }
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
                ..
            } => {
                let fallthrough_index = node_index + 1;
                // Degenerate branch: both outcomes lead to the next node, so
                // no partitioning of the row range is needed.
                if *jump_index == fallthrough_index {
                    stack.push((fallthrough_index, start, end));
                    continue;
                }

                let column = matrix.column(*feature_index);
                let mut partition = start;
                let mut jump_start = end;
                match column {
                    // Fast path: compare raw u8 bins directly, skipping the
                    // per-element widening done by `value_at`.
                    CompactBinnedColumn::U8(values) if *threshold_bin <= u16::from(u8::MAX) => {
                        let threshold = *threshold_bin as u8;
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = values[start_row + row_offset] > threshold;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                // Jump-bound rows collect at the tail of the range.
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                    _ => {
                        // Generic path: widen each bin through `value_at`.
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = column.value_at(start_row + row_offset) > *threshold_bin;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                }

                // Descend (via the explicit stack) into whichever sides of the
                // partition are non-empty.
                if jump_start < end {
                    stack.push((*jump_index, jump_start, end));
                }
                if start < jump_start {
                    stack.push((fallthrough_index, start, jump_start));
                }
            }
        }
    }
}
1667
1668fn binary_classifier_nodes_require_rowwise_missing(
1669    nodes: &[OptimizedBinaryClassifierNode],
1670) -> bool {
1671    nodes.iter().any(|node| match node {
1672        OptimizedBinaryClassifierNode::Leaf(_) => false,
1673        OptimizedBinaryClassifierNode::Branch {
1674            missing_bin,
1675            missing_jump_index,
1676            missing_probabilities,
1677            ..
1678        } => {
1679            missing_bin.is_some() || missing_jump_index.is_some() || missing_probabilities.is_some()
1680        }
1681    })
1682}
1683
1684fn binary_regressor_nodes_require_rowwise_missing(nodes: &[OptimizedBinaryRegressorNode]) -> bool {
1685    nodes.iter().any(|node| match node {
1686        OptimizedBinaryRegressorNode::Leaf(_) => false,
1687        OptimizedBinaryRegressorNode::Branch {
1688            missing_bin,
1689            missing_jump_index,
1690            missing_value,
1691            ..
1692        } => missing_bin.is_some() || missing_jump_index.is_some() || missing_value.is_some(),
1693    })
1694}
1695
/// Descends an oblivious (symmetric) tree for a single row and returns the
/// selected leaf value.
///
/// Each tree level contributes one bit to the leaf index: 1 when the row's
/// bin for that level's feature exceeds the level threshold, 0 otherwise. The
/// accumulated bits index directly into the dense leaf table.
#[inline(always)]
fn predict_oblivious_row<F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    bin_at: F,
) -> f64
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index]
}
1713
/// Descends an oblivious tree for a single row and returns a borrowed
/// reference to the per-class probability vector at the selected leaf.
///
/// Classification counterpart of `predict_oblivious_row`: each level
/// contributes one bit (1 when the row's bin exceeds that level's threshold)
/// and the accumulated bits index the leaf table.
#[inline(always)]
fn predict_oblivious_probabilities_row<'a, F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &'a [Vec<f64>],
    bin_at: F,
) -> &'a [f64]
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature_index, &threshold_bin)| {
            (index << 1) | usize::from(bin_at(feature_index) > threshold_bin)
        });
    leaf_values[leaf_index].as_slice()
}
1731
/// Converts raw per-class sample counts into probabilities that sum to one.
///
/// An all-zero (or empty) count vector yields all-zero probabilities instead
/// of dividing by zero.
fn normalized_probabilities_from_counts(class_counts: &[usize]) -> Vec<f64> {
    let total: usize = class_counts.iter().sum();
    match total {
        0 => vec![0.0; class_counts.len()],
        _ => {
            let denominator = total as f64;
            class_counts
                .iter()
                .map(|&count| count as f64 / denominator)
                .collect()
        }
    }
}
1743
/// Maps a probability row to the label of its most probable class.
///
/// Ties are broken toward the lowest class index (conventional argmax
/// behavior). NaN values follow `f64::total_cmp`'s total order.
///
/// # Panics
///
/// Panics when `probabilities` is empty.
fn class_label_from_probabilities(probabilities: &[f64], class_labels: &[f64]) -> f64 {
    let (first, rest) = probabilities
        .split_first()
        .expect("classification probability row is non-empty");
    let mut best_index = 0usize;
    let mut best_probability = *first;
    for (offset, &probability) in rest.iter().enumerate() {
        // Strictly-greater keeps the earliest class on ties.
        if probability.total_cmp(&best_probability).is_gt() {
            best_index = offset + 1;
            best_probability = probability;
        }
    }
    class_labels[best_index]
}
1757
/// Numerically stable logistic function `1 / (1 + e^-x)`.
///
/// The exponential is always evaluated on a non-positive argument so large
/// magnitudes never overflow; the sign of `value` selects the algebraically
/// equivalent form.
#[inline(always)]
fn sigmoid(value: f64) -> f64 {
    let exp = (-value.abs()).exp();
    if value >= 0.0 {
        1.0 / (1.0 + exp)
    } else {
        exp / (1.0 + exp)
    }
}
1768
1769fn classifier_nodes_are_binary_only(nodes: &[tree::classifier::TreeNode]) -> bool {
1770    nodes.iter().all(|node| {
1771        matches!(
1772            node,
1773            tree::classifier::TreeNode::Leaf { .. }
1774                | tree::classifier::TreeNode::BinarySplit { .. }
1775        )
1776    })
1777}
1778
1779fn classifier_node_sample_count(nodes: &[tree::classifier::TreeNode], node_index: usize) -> usize {
1780    match &nodes[node_index] {
1781        tree::classifier::TreeNode::Leaf { sample_count, .. }
1782        | tree::classifier::TreeNode::BinarySplit { sample_count, .. }
1783        | tree::classifier::TreeNode::MultiwaySplit { sample_count, .. } => *sample_count,
1784    }
1785}
1786
1787fn build_binary_classifier_layout(
1788    nodes: &[tree::classifier::TreeNode],
1789    root: usize,
1790    _class_labels: &[f64],
1791    feature_index_map: &[usize],
1792    preprocessing: &[FeaturePreprocessing],
1793    missing_features: Option<&BTreeSet<usize>>,
1794) -> Vec<OptimizedBinaryClassifierNode> {
1795    let mut layout = Vec::with_capacity(nodes.len());
1796    append_binary_classifier_node(
1797        nodes,
1798        root,
1799        &mut layout,
1800        feature_index_map,
1801        preprocessing,
1802        missing_features,
1803    );
1804    layout
1805}
1806
/// Recursively appends the classifier subtree rooted at `node_index` to
/// `layout` in jump-threaded order, returning the layout index of the node
/// just appended.
///
/// Layout invariant: the child with more training samples (the "fallthrough"
/// child) is emitted immediately after its parent, so the common case at
/// prediction time is falling through to `current_index + 1`. The other child
/// is reached via an explicit `jump_index`, and `jump_if_greater` records
/// whether the jump is taken when `bin > threshold_bin` (true when the right
/// child is the jump target).
fn append_binary_classifier_node(
    nodes: &[tree::classifier::TreeNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryClassifierNode>,
    feature_index_map: &[usize],
    preprocessing: &[FeaturePreprocessing],
    missing_features: Option<&BTreeSet<usize>>,
) -> usize {
    // Reserve this node's slot before recursing so children land after it;
    // the placeholder leaf is overwritten once the children are emitted.
    let current_index = layout.len();
    layout.push(OptimizedBinaryClassifierNode::Leaf(Vec::new()));

    match &nodes[node_index] {
        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
            layout[current_index] = OptimizedBinaryClassifierNode::Leaf(
                normalized_probabilities_from_counts(class_counts),
            );
        }
        tree::classifier::TreeNode::BinarySplit {
            feature_index,
            threshold_bin,
            missing_direction,
            left_child,
            right_child,
            class_counts,
            ..
        } => {
            // Pick the more-populated child as the fallthrough so the hotter
            // path avoids the jump.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = classifier_node_sample_count(nodes, *left_child);
                let right_count = classifier_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            // Emit the fallthrough subtree directly after this node.
            let fallthrough_index = append_binary_classifier_node(
                nodes,
                fallthrough_child,
                layout,
                feature_index_map,
                preprocessing,
                missing_features,
            );
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // Emit the jump subtree only when it is a distinct node.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_classifier_node(
                    nodes,
                    jump_child,
                    layout,
                    feature_index_map,
                    preprocessing,
                    missing_features,
                )
            };

            // Missing-value routing: resolve the semantic left/right direction
            // into a concrete layout index, or materialize node-level
            // probabilities when missing values should stop at this node.
            let missing_bin =
                optimized_missing_bin(preprocessing, *feature_index, missing_features);
            let (missing_jump_index, missing_probabilities) =
                if missing_feature_enabled(*feature_index, missing_features) {
                    match missing_direction {
                        tree::shared::MissingBranchDirection::Left => (
                            Some(if *left_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Right => (
                            Some(if *right_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Node => (
                            None,
                            Some(normalized_probabilities_from_counts(class_counts)),
                        ),
                    }
                } else {
                    (None, None)
                };

            layout[current_index] = OptimizedBinaryClassifierNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_probabilities,
            };
        }
        tree::classifier::TreeNode::MultiwaySplit { .. } => {
            unreachable!("multiway nodes are filtered out before binary layout construction");
        }
    }

    current_index
}
1914
1915fn regressor_node_sample_count(
1916    nodes: &[tree::regressor::RegressionNode],
1917    node_index: usize,
1918) -> usize {
1919    match &nodes[node_index] {
1920        tree::regressor::RegressionNode::Leaf { sample_count, .. }
1921        | tree::regressor::RegressionNode::BinarySplit { sample_count, .. } => *sample_count,
1922    }
1923}
1924
1925fn build_binary_regressor_layout(
1926    nodes: &[tree::regressor::RegressionNode],
1927    root: usize,
1928    feature_index_map: &[usize],
1929    preprocessing: &[FeaturePreprocessing],
1930    missing_features: Option<&BTreeSet<usize>>,
1931) -> Vec<OptimizedBinaryRegressorNode> {
1932    let mut layout = Vec::with_capacity(nodes.len());
1933    append_binary_regressor_node(
1934        nodes,
1935        root,
1936        &mut layout,
1937        feature_index_map,
1938        preprocessing,
1939        missing_features,
1940    );
1941    layout
1942}
1943
/// Recursively appends the regression subtree rooted at `node_index` to
/// `layout` in jump-threaded order, returning the layout index of the node
/// just appended.
///
/// Mirrors `append_binary_classifier_node`: the more-populated child becomes
/// the fallthrough at `current_index + 1`, the other child is reached through
/// `jump_index`, and `jump_if_greater` records whether the jump is taken when
/// `bin > threshold_bin`.
fn append_binary_regressor_node(
    nodes: &[tree::regressor::RegressionNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryRegressorNode>,
    feature_index_map: &[usize],
    preprocessing: &[FeaturePreprocessing],
    missing_features: Option<&BTreeSet<usize>>,
) -> usize {
    // Reserve this node's slot before recursing so children land after it;
    // the placeholder leaf is overwritten once the children are emitted.
    let current_index = layout.len();
    layout.push(OptimizedBinaryRegressorNode::Leaf(0.0));

    match &nodes[node_index] {
        tree::regressor::RegressionNode::Leaf { value, .. } => {
            layout[current_index] = OptimizedBinaryRegressorNode::Leaf(*value);
        }
        tree::regressor::RegressionNode::BinarySplit {
            feature_index,
            threshold_bin,
            missing_direction,
            missing_value,
            left_child,
            right_child,
            ..
        } => {
            // Pick the more-populated child as the fallthrough so the hotter
            // path avoids the jump.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = regressor_node_sample_count(nodes, *left_child);
                let right_count = regressor_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            // Emit the fallthrough subtree directly after this node.
            let fallthrough_index = append_binary_regressor_node(
                nodes,
                fallthrough_child,
                layout,
                feature_index_map,
                preprocessing,
                missing_features,
            );
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // Emit the jump subtree only when it is a distinct node.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_regressor_node(
                    nodes,
                    jump_child,
                    layout,
                    feature_index_map,
                    preprocessing,
                    missing_features,
                )
            };

            // Missing-value routing: resolve the semantic left/right direction
            // into a concrete layout index, or carry the node-level missing
            // value when missing rows should stop at this node.
            let missing_bin =
                optimized_missing_bin(preprocessing, *feature_index, missing_features);
            let (missing_jump_index, missing_value) =
                if missing_feature_enabled(*feature_index, missing_features) {
                    match missing_direction {
                        tree::shared::MissingBranchDirection::Left => (
                            Some(if *left_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Right => (
                            Some(if *right_child == fallthrough_child {
                                fallthrough_index
                            } else {
                                jump_index
                            }),
                            None,
                        ),
                        tree::shared::MissingBranchDirection::Node => (None, Some(*missing_value)),
                    }
                } else {
                    (None, None)
                };

            layout[current_index] = OptimizedBinaryRegressorNode::Branch {
                feature_index: remap_feature_index(*feature_index, feature_index_map),
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
                missing_bin,
                missing_jump_index,
                missing_value,
            };
        }
    }

    current_index
}
2043
2044fn predict_oblivious_column_major_matrix(
2045    feature_indices: &[usize],
2046    threshold_bins: &[u16],
2047    leaf_values: &[f64],
2048    matrix: &ColumnMajorBinnedMatrix,
2049    executor: &InferenceExecutor,
2050) -> Vec<f64> {
2051    let mut outputs = vec![0.0; matrix.n_rows];
2052    executor.fill_chunks(
2053        &mut outputs,
2054        PARALLEL_INFERENCE_CHUNK_ROWS,
2055        |start_row, chunk| {
2056            predict_oblivious_chunk(
2057                feature_indices,
2058                threshold_bins,
2059                leaf_values,
2060                matrix,
2061                start_row,
2062                chunk,
2063            )
2064        },
2065    );
2066    outputs
2067}
2068
2069fn predict_oblivious_probabilities_column_major_matrix(
2070    feature_indices: &[usize],
2071    threshold_bins: &[u16],
2072    leaf_values: &[Vec<f64>],
2073    matrix: &ColumnMajorBinnedMatrix,
2074    _executor: &InferenceExecutor,
2075) -> Vec<Vec<f64>> {
2076    (0..matrix.n_rows)
2077        .map(|row_index| {
2078            predict_oblivious_probabilities_row(
2079                feature_indices,
2080                threshold_bins,
2081                leaf_values,
2082                |feature_index| matrix.column(feature_index).value_at(row_index),
2083            )
2084            .to_vec()
2085        })
2086        .collect()
2087}
2088
2089fn predict_oblivious_chunk(
2090    feature_indices: &[usize],
2091    threshold_bins: &[u16],
2092    leaf_values: &[f64],
2093    matrix: &ColumnMajorBinnedMatrix,
2094    start_row: usize,
2095    output: &mut [f64],
2096) {
2097    let processed = simd_predict_oblivious_chunk(
2098        feature_indices,
2099        threshold_bins,
2100        leaf_values,
2101        matrix,
2102        start_row,
2103        output,
2104    );
2105
2106    for (offset, out) in output.iter_mut().enumerate().skip(processed) {
2107        let row_index = start_row + offset;
2108        *out = predict_oblivious_row(
2109            feature_indices,
2110            threshold_bins,
2111            leaf_values,
2112            |feature_index| matrix.column(feature_index).value_at(row_index),
2113        );
2114    }
2115}
2116
/// Evaluates an oblivious tree over as many complete SIMD groups of rows as
/// fit in `output`, returning the number of rows processed (always a multiple
/// of `OBLIVIOUS_SIMD_LANES`). The caller handles the scalar remainder.
///
/// For each group, leaf indices for all lanes are built simultaneously: every
/// tree level contributes one bit per lane (1 when that lane's bin exceeds
/// the level threshold), and the final per-lane index selects from
/// `leaf_values`.
fn simd_predict_oblivious_chunk(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) -> usize {
    let mut processed = 0usize;
    let ones = u32x8::splat(1);

    while processed + OBLIVIOUS_SIMD_LANES <= output.len() {
        let base_row = start_row + processed;
        // One accumulating leaf index per lane.
        let mut leaf_indices = u32x8::splat(0);

        for (&feature_index, &threshold_bin) in feature_indices.iter().zip(threshold_bins) {
            let column = matrix.column(feature_index);
            // Load this level's bins for all lanes, widening u8 columns
            // lane-by-lane and u16 columns via the vector conversion.
            let bins = if let Some(lanes) = column.slice_u8(base_row, OBLIVIOUS_SIMD_LANES) {
                let lanes: [u8; OBLIVIOUS_SIMD_LANES] = lanes
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::new([
                    u32::from(lanes[0]),
                    u32::from(lanes[1]),
                    u32::from(lanes[2]),
                    u32::from(lanes[3]),
                    u32::from(lanes[4]),
                    u32::from(lanes[5]),
                    u32::from(lanes[6]),
                    u32::from(lanes[7]),
                ])
            } else {
                let lanes: [u16; OBLIVIOUS_SIMD_LANES] = column
                    .slice_u16(base_row, OBLIVIOUS_SIMD_LANES)
                    .expect("column is u16 when not u8")
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::from(u16x8::new(lanes))
            };
            let threshold = u32x8::splat(u32::from(threshold_bin));
            // cmp_gt yields all-ones per matching lane; masking with `ones`
            // reduces that to the single descent bit for this level.
            let bit = bins.cmp_gt(threshold) & ones;
            leaf_indices = (leaf_indices << 1) | bit;
        }

        // Scalar gather: look up each lane's leaf value.
        let lane_indices = leaf_indices.to_array();
        for lane in 0..OBLIVIOUS_SIMD_LANES {
            output[processed + lane] =
                leaf_values[usize::try_from(lane_indices[lane]).expect("leaf index fits usize")];
        }
        processed += OBLIVIOUS_SIMD_LANES;
    }

    processed
}
2171
/// Trains a model on `train_set` using the normalized [`TrainConfig`].
///
/// Public entry point that delegates to the internal `training` module.
///
/// # Errors
///
/// Returns a [`TrainError`] when training cannot be completed.
pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
    training::train(train_set, config)
}
2175
2176#[cfg(test)]
2177mod tests;