Skip to main content

forestfire_core/
lib.rs

1use forestfire_data::{
2    BinnedColumnKind, MAX_NUMERIC_BINS, NumericBins, TableAccess, numeric_bin_boundaries,
3};
4#[cfg(feature = "polars")]
5use polars::prelude::{Column, DataFrame, DataType, IdxSize, LazyFrame};
6use rayon::ThreadPoolBuilder;
7use rayon::prelude::*;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11use std::error::Error;
12use std::fmt::{Display, Formatter};
13use std::sync::Arc;
14use wide::{u16x8, u32x8};
15
16mod boosting;
17mod bootstrap;
18mod forest;
19pub mod ir;
20mod sampling;
21mod training;
22pub mod tree;
23
24pub use boosting::BoostingError;
25pub use boosting::GradientBoostedTrees;
26pub use forest::RandomForest;
27pub use ir::IrError;
28pub use ir::ModelPackageIr;
29pub use tree::classifier::DecisionTreeAlgorithm;
30pub use tree::classifier::DecisionTreeClassifier;
31pub use tree::classifier::DecisionTreeError;
32pub use tree::classifier::DecisionTreeOptions;
33pub use tree::classifier::train_c45;
34pub use tree::classifier::train_cart;
35pub use tree::classifier::train_id3;
36pub use tree::classifier::train_oblivious;
37pub use tree::classifier::train_randomized;
38pub use tree::regressor::DecisionTreeRegressor;
39pub use tree::regressor::RegressionTreeAlgorithm;
40pub use tree::regressor::RegressionTreeError;
41pub use tree::regressor::RegressionTreeOptions;
42pub use tree::regressor::train_cart_regressor;
43pub use tree::regressor::train_oblivious_regressor;
44pub use tree::regressor::train_randomized_regressor;
45
/// Row-count threshold that gates the parallel batch-inference path.
const PARALLEL_INFERENCE_ROW_THRESHOLD: usize = 256;
/// Rows handed to each parallel inference work item.
const PARALLEL_INFERENCE_CHUNK_ROWS: usize = 256;
/// Chunk size for the standard (non-parallel) batch inference loop.
const STANDARD_BATCH_INFERENCE_CHUNK_ROWS: usize = 4096;
/// Lane count of the SIMD vectors used by the oblivious-tree kernels
/// (matches the 8-lane `u16x8`/`u32x8` types imported from `wide`).
const OBLIVIOUS_SIMD_LANES: usize = 8;
/// Rows per batch when predicting over a polars `LazyFrame`.
#[cfg(feature = "polars")]
const LAZYFRAME_PREDICT_BATCH_ROWS: usize = 10_000;
/// Magic bytes identifying a compiled-artifact blob ("FFCA").
const COMPILED_ARTIFACT_MAGIC: [u8; 4] = *b"FFCA";
/// Serialization format version of compiled artifacts.
const COMPILED_ARTIFACT_VERSION: u16 = 1;
/// Backend tag for artifacts compiled for CPU execution.
const COMPILED_ARTIFACT_BACKEND_CPU: u16 = 1;
/// Artifact header length in bytes: 4 (magic) + 2 (version) + 2 (backend).
const COMPILED_ARTIFACT_HEADER_LEN: usize = 8;
56
/// Model family produced by a training run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainAlgorithm {
    /// Single decision tree.
    Dt,
    /// Random forest ensemble.
    Rf,
    /// Gradient-boosted tree ensemble.
    Gbm,
}
63
/// Split-quality criterion used while growing trees.
///
/// Which criteria are valid for which task/tree-type combinations is
/// enforced by training (see `TrainError::UnsupportedConfiguration`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Criterion {
    /// Let training choose a criterion appropriate for the task.
    Auto,
    /// Gini impurity.
    Gini,
    /// Entropy (information gain).
    Entropy,
    /// Mean-based criterion.
    Mean,
    /// Median-based criterion.
    Median,
    /// Second-order (gradient/hessian) criterion.
    SecondOrder,
}
73
/// Learning task the model is trained for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Task {
    /// Predict a continuous value.
    Regression,
    /// Predict a class label.
    Classification,
}
79
/// Tree-construction algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeType {
    /// ID3 (see `train_id3`).
    Id3,
    /// C4.5 (see `train_c45`).
    C45,
    /// CART (see `train_cart` / `train_cart_regressor`).
    Cart,
    /// Randomized split selection (see `train_randomized`).
    Randomized,
    /// Oblivious tree: level-wise representation where splits are inspected
    /// per level rather than per node (see `IntrospectionError::NotANodeTree`).
    Oblivious,
}
88
/// Policy for how many candidate features each split may consider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxFeatures {
    /// Task-dependent default: `Sqrt` for classification, `Third` for regression.
    Auto,
    /// Consider every feature.
    All,
    /// floor(sqrt(feature_count)).
    Sqrt,
    /// feature_count / 3.
    Third,
    /// An explicit count, capped at the available feature count.
    Count(usize),
}
97
98impl MaxFeatures {
99    pub fn resolve(self, task: Task, feature_count: usize) -> usize {
100        match self {
101            MaxFeatures::Auto => match task {
102                Task::Classification => MaxFeatures::Sqrt.resolve(task, feature_count),
103                Task::Regression => MaxFeatures::Third.resolve(task, feature_count),
104            },
105            MaxFeatures::All => feature_count.max(1),
106            MaxFeatures::Sqrt => ((feature_count as f64).sqrt().floor() as usize).max(1),
107            MaxFeatures::Third => (feature_count / 3).max(1),
108            MaxFeatures::Count(count) => count.min(feature_count).max(1),
109        }
110    }
111}
112
/// Kind of a model input feature (serialized in snake_case).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InputFeatureKind {
    /// Continuous-valued feature.
    Numeric,
    /// Feature restricted to the values 0 and 1.
    Binary,
}
119
/// One numeric-bin boundary: values at or below `upper_bound` (and above the
/// preceding boundary) map to `bin` — the comparison is inclusive, matching
/// `infer_numeric_bin`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinBoundary {
    /// Bin index assigned to values falling under this boundary.
    pub bin: u16,
    /// Inclusive upper edge of the bin.
    pub upper_bound: f64,
}
125
/// Per-feature transform captured at training time and replayed at inference.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeaturePreprocessing {
    /// Numeric feature, binned via its boundary list
    /// (assumed ascending — see `infer_numeric_bin`).
    Numeric {
        bin_boundaries: Vec<NumericBinBoundary>,
    },
    /// Binary feature; inference values must be exactly 0.0 or 1.0.
    Binary,
}
134
/// User-facing knobs for a single training run.
///
/// `None` / `Auto` fields are resolved to task-appropriate defaults by the
/// trainer; invalid values are reported via `TrainError`.
#[derive(Debug, Clone, Copy)]
pub struct TrainConfig {
    /// Which model family to train (single tree, forest, or GBM).
    pub algorithm: TrainAlgorithm,
    /// Regression or classification.
    pub task: Task,
    /// Tree-construction algorithm.
    pub tree_type: TreeType,
    /// Split-quality criterion; `Auto` defers to the trainer.
    pub criterion: Criterion,
    /// Maximum tree depth; must be >= 1 when set.
    pub max_depth: Option<usize>,
    /// Minimum samples required to split a node; must be >= 1 when set.
    pub min_samples_split: Option<usize>,
    /// Minimum samples required in a leaf; must be >= 1 when set.
    pub min_samples_leaf: Option<usize>,
    /// Physical cores to use; validated against the machine's core count.
    pub physical_cores: Option<usize>,
    /// Ensemble size for forests/GBMs; must be >= 1 when set.
    pub n_trees: Option<usize>,
    /// Per-split feature subsampling policy.
    pub max_features: MaxFeatures,
    /// RNG seed for reproducible runs.
    pub seed: Option<u64>,
    /// Whether to compute out-of-bag estimates.
    pub compute_oob: bool,
    /// Boosting shrinkage rate.
    pub learning_rate: Option<f64>,
    /// Whether to bootstrap-sample rows per tree.
    pub bootstrap: bool,
    /// Fraction of highest-gradient rows kept during gradient-based sampling.
    /// NOTE(review): presumably GOSS-style sampling — confirm in `boosting`.
    pub top_gradient_fraction: Option<f64>,
    /// Fraction of the remaining rows sampled during gradient-based sampling.
    pub other_gradient_fraction: Option<f64>,
}
154
155impl Default for TrainConfig {
156    fn default() -> Self {
157        Self {
158            algorithm: TrainAlgorithm::Dt,
159            task: Task::Regression,
160            tree_type: TreeType::Cart,
161            criterion: Criterion::Auto,
162            max_depth: None,
163            min_samples_split: None,
164            min_samples_leaf: None,
165            physical_cores: None,
166            n_trees: None,
167            max_features: MaxFeatures::Auto,
168            seed: None,
169            compute_oob: false,
170            learning_rate: None,
171            bootstrap: false,
172            top_gradient_fraction: None,
173            other_gradient_fraction: None,
174        }
175    }
176}
177
/// A trained model of any supported family, ready for inference.
#[derive(Debug, Clone)]
pub enum Model {
    /// Single classification tree.
    DecisionTreeClassifier(DecisionTreeClassifier),
    /// Single regression tree.
    DecisionTreeRegressor(DecisionTreeRegressor),
    /// Bagged tree ensemble.
    RandomForest(RandomForest),
    /// Boosted tree ensemble.
    GradientBoostedTrees(GradientBoostedTrees),
}
185
/// Errors surfaced while validating a `TrainConfig` or running training.
#[derive(Debug)]
pub enum TrainError {
    /// Wrapped classification-tree training error.
    DecisionTree(DecisionTreeError),
    /// Wrapped regression-tree training error.
    RegressionTree(RegressionTreeError),
    /// Wrapped gradient-boosting training error.
    Boosting(BoostingError),
    /// More physical cores requested than the machine provides.
    InvalidPhysicalCoreCount {
        requested: usize,
        available: usize,
    },
    /// Thread-pool construction failed; payload is the builder's error text.
    ThreadPoolBuildFailed(String),
    /// The task / tree-type / criterion combination has no implementation.
    UnsupportedConfiguration {
        task: Task,
        tree_type: TreeType,
        criterion: Criterion,
    },
    /// `max_depth` below 1.
    InvalidMaxDepth(usize),
    /// `min_samples_split` below 1.
    InvalidMinSamplesSplit(usize),
    /// `min_samples_leaf` below 1.
    InvalidMinSamplesLeaf(usize),
    /// Forest configured with zero trees.
    InvalidTreeCount(usize),
    /// Explicit `MaxFeatures::Count` below 1.
    InvalidMaxFeatures(usize),
}
207
208impl Display for TrainError {
209    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
210        match self {
211            TrainError::DecisionTree(err) => err.fmt(f),
212            TrainError::RegressionTree(err) => err.fmt(f),
213            TrainError::Boosting(err) => err.fmt(f),
214            TrainError::InvalidPhysicalCoreCount {
215                requested,
216                available,
217            } => write!(
218                f,
219                "Requested {} physical cores, but the available physical core count is {}.",
220                requested, available
221            ),
222            TrainError::ThreadPoolBuildFailed(message) => {
223                write!(f, "Failed to build training thread pool: {}.", message)
224            }
225            TrainError::UnsupportedConfiguration {
226                task,
227                tree_type,
228                criterion,
229            } => write!(
230                f,
231                "Unsupported training configuration: task={:?}, tree_type={:?}, criterion={:?}.",
232                task, tree_type, criterion
233            ),
234            TrainError::InvalidMaxDepth(value) => {
235                write!(f, "max_depth must be at least 1. Received {}.", value)
236            }
237            TrainError::InvalidMinSamplesSplit(value) => {
238                write!(
239                    f,
240                    "min_samples_split must be at least 1. Received {}.",
241                    value
242                )
243            }
244            TrainError::InvalidMinSamplesLeaf(value) => {
245                write!(
246                    f,
247                    "min_samples_leaf must be at least 1. Received {}.",
248                    value
249                )
250            }
251            TrainError::InvalidTreeCount(n_trees) => {
252                write!(
253                    f,
254                    "Random forest requires at least one tree. Received {}.",
255                    n_trees
256                )
257            }
258            TrainError::InvalidMaxFeatures(count) => {
259                write!(
260                    f,
261                    "max_features must be at least 1 when provided as an integer. Received {}.",
262                    count
263                )
264            }
265        }
266    }
267}
268
269impl Error for TrainError {}
270
/// Errors surfaced while validating or running inference input.
#[derive(Debug, Clone, PartialEq)]
pub enum PredictError {
    /// `predict_proba` was called on a non-classification model.
    ProbabilityPredictionRequiresClassification,
    /// A row-major input row had the wrong number of columns.
    RaggedRows {
        row: usize,
        expected: usize,
        actual: usize,
    },
    /// The input's feature count does not match the model's.
    FeatureCountMismatch {
        expected: usize,
        actual: usize,
    },
    /// A column's length differs from the other columns'.
    ColumnLengthMismatch {
        feature: String,
        expected: usize,
        actual: usize,
    },
    /// A required feature column was absent from the input.
    MissingFeature(String),
    /// The input contained a column the model does not know.
    UnexpectedFeature(String),
    /// A binary feature held a value other than 0.0 or 1.0.
    InvalidBinaryValue {
        feature_index: usize,
        row_index: usize,
        value: f64,
    },
    /// A feature column contained a null at the given row.
    NullValue {
        feature: String,
        row_index: usize,
    },
    /// A feature column has a dtype inference cannot consume.
    UnsupportedFeatureType {
        feature: String,
        dtype: String,
    },
    /// Error propagated from polars (stringified to keep this type `Clone`).
    Polars(String),
}
305
/// Shape statistics for a single trained tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeStructureSummary {
    /// Human-readable name of the tree's representation
    /// (node tree vs oblivious levels — see `IntrospectionError`).
    pub representation: String,
    /// Total node count.
    pub node_count: usize,
    /// Count of internal (non-leaf) nodes.
    pub internal_node_count: usize,
    /// Count of leaves.
    pub leaf_count: usize,
    /// Realized depth of the tree.
    pub actual_depth: usize,
    /// Shortest root-to-leaf path length.
    pub shortest_path: usize,
    /// Longest root-to-leaf path length.
    pub longest_path: usize,
    /// Mean root-to-leaf path length.
    pub average_path: f64,
}
317
/// Summary statistics over a batch of model predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionValueStats {
    /// Number of predictions summarized.
    pub count: usize,
    /// Number of distinct prediction values.
    pub unique_count: usize,
    /// Smallest prediction.
    pub min: f64,
    /// Largest prediction.
    pub max: f64,
    /// Arithmetic mean of the predictions.
    pub mean: f64,
    /// Standard deviation of the predictions.
    pub std_dev: f64,
    /// Per-value occurrence counts.
    pub histogram: Vec<PredictionHistogramEntry>,
}
328
/// One histogram bucket: a prediction value and how often it occurred.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionHistogramEntry {
    /// The prediction value this bucket counts.
    pub prediction: f64,
    /// Number of occurrences of `prediction`.
    pub count: usize,
}
334
/// Errors raised while inspecting a trained model's internal structure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntrospectionError {
    /// Tree index beyond the model's tree count.
    TreeIndexOutOfBounds { requested: usize, available: usize },
    /// Node index beyond the tree's node count.
    NodeIndexOutOfBounds { requested: usize, available: usize },
    /// Level index beyond the (oblivious) tree's level count.
    LevelIndexOutOfBounds { requested: usize, available: usize },
    /// Leaf index beyond the tree's leaf count.
    LeafIndexOutOfBounds { requested: usize, available: usize },
    /// Node-style inspection attempted on an oblivious-level tree.
    NotANodeTree,
    /// Level/leaf-style inspection attempted on a node tree.
    NotAnObliviousTree,
}
344
345impl Display for IntrospectionError {
346    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
347        match self {
348            IntrospectionError::TreeIndexOutOfBounds {
349                requested,
350                available,
351            } => write!(
352                f,
353                "Tree index {} is out of bounds for model with {} trees.",
354                requested, available
355            ),
356            IntrospectionError::NodeIndexOutOfBounds {
357                requested,
358                available,
359            } => write!(
360                f,
361                "Node index {} is out of bounds for tree with {} nodes.",
362                requested, available
363            ),
364            IntrospectionError::LevelIndexOutOfBounds {
365                requested,
366                available,
367            } => write!(
368                f,
369                "Level index {} is out of bounds for tree with {} levels.",
370                requested, available
371            ),
372            IntrospectionError::LeafIndexOutOfBounds {
373                requested,
374                available,
375            } => write!(
376                f,
377                "Leaf index {} is out of bounds for tree with {} leaves.",
378                requested, available
379            ),
380            IntrospectionError::NotANodeTree => write!(
381                f,
382                "This tree uses oblivious-level representation; inspect levels or leaves instead."
383            ),
384            IntrospectionError::NotAnObliviousTree => write!(
385                f,
386                "This tree uses node-tree representation; inspect nodes instead."
387            ),
388        }
389    }
390}
391
392impl Error for IntrospectionError {}
393
394impl Display for PredictError {
395    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
396        match self {
397            PredictError::ProbabilityPredictionRequiresClassification => write!(
398                f,
399                "predict_proba is only available for classification models."
400            ),
401            PredictError::RaggedRows {
402                row,
403                expected,
404                actual,
405            } => write!(
406                f,
407                "Ragged inference row at index {}: expected {} columns, found {}.",
408                row, expected, actual
409            ),
410            PredictError::FeatureCountMismatch { expected, actual } => write!(
411                f,
412                "Inference input has {} features, but the model expects {}.",
413                actual, expected
414            ),
415            PredictError::ColumnLengthMismatch {
416                feature,
417                expected,
418                actual,
419            } => write!(
420                f,
421                "Feature '{}' has {} values, expected {}.",
422                feature, actual, expected
423            ),
424            PredictError::MissingFeature(feature) => {
425                write!(f, "Missing required feature '{}'.", feature)
426            }
427            PredictError::UnexpectedFeature(feature) => {
428                write!(f, "Unexpected feature '{}'.", feature)
429            }
430            PredictError::InvalidBinaryValue {
431                feature_index,
432                row_index,
433                value,
434            } => write!(
435                f,
436                "Feature {} at row {} must be binary for inference, found {}.",
437                feature_index, row_index, value
438            ),
439            PredictError::NullValue { feature, row_index } => write!(
440                f,
441                "Feature '{}' contains a null value at row {}.",
442                feature, row_index
443            ),
444            PredictError::UnsupportedFeatureType { feature, dtype } => write!(
445                f,
446                "Feature '{}' has unsupported dtype '{}'.",
447                feature, dtype
448            ),
449            PredictError::Polars(message) => write!(f, "Polars inference failed: {}.", message),
450        }
451    }
452}
453
454impl Error for PredictError {}
455
#[cfg(feature = "polars")]
impl From<polars::error::PolarsError> for PredictError {
    /// Wraps a polars failure as its rendered message, keeping
    /// `PredictError` cheap to clone and compare.
    fn from(value: polars::error::PolarsError) -> Self {
        Self::Polars(value.to_string())
    }
}
462
/// Errors raised while preparing a model for optimized inference.
#[derive(Debug)]
pub enum OptimizeError {
    /// More physical cores requested than the machine provides.
    InvalidPhysicalCoreCount { requested: usize, available: usize },
    /// Inference thread-pool construction failed; payload is the error text.
    ThreadPoolBuildFailed(String),
    /// The model type has no optimized-inference implementation.
    UnsupportedModelType(&'static str),
}
469
470impl Display for OptimizeError {
471    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
472        match self {
473            OptimizeError::InvalidPhysicalCoreCount {
474                requested,
475                available,
476            } => write!(
477                f,
478                "Requested {} physical cores, but the available physical core count is {}.",
479                requested, available
480            ),
481            OptimizeError::ThreadPoolBuildFailed(message) => {
482                write!(f, "Failed to build inference thread pool: {}.", message)
483            }
484            OptimizeError::UnsupportedModelType(model_type) => {
485                write!(
486                    f,
487                    "Optimized inference is not supported for model type '{}'.",
488                    model_type
489                )
490            }
491        }
492    }
493}
494
495impl Error for OptimizeError {}
496
/// Errors raised while encoding or decoding a compiled model artifact.
#[derive(Debug)]
pub enum CompiledArtifactError {
    /// Blob shorter than the fixed header (`COMPILED_ARTIFACT_HEADER_LEN`).
    ArtifactTooShort { actual: usize, minimum: usize },
    /// Magic bytes did not match `COMPILED_ARTIFACT_MAGIC`.
    InvalidMagic([u8; 4]),
    /// Unknown artifact format version.
    UnsupportedVersion(u16),
    /// Unknown backend tag (only CPU is currently defined).
    UnsupportedBackend(u16),
    /// Serialization failure while writing the artifact body.
    Encode(String),
    /// Deserialization failure while reading the artifact body.
    Decode(String),
    /// The decoded IR failed semantic validation.
    InvalidSemanticModel(IrError),
    /// The decoded model could not be turned into an optimized runtime.
    InvalidRuntime(OptimizeError),
}
508
509impl Display for CompiledArtifactError {
510    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
511        match self {
512            CompiledArtifactError::ArtifactTooShort { actual, minimum } => write!(
513                f,
514                "Compiled artifact is too short: expected at least {} bytes, found {}.",
515                minimum, actual
516            ),
517            CompiledArtifactError::InvalidMagic(magic) => {
518                write!(f, "Compiled artifact has invalid magic bytes: {:?}.", magic)
519            }
520            CompiledArtifactError::UnsupportedVersion(version) => {
521                write!(f, "Unsupported compiled artifact version: {}.", version)
522            }
523            CompiledArtifactError::UnsupportedBackend(backend) => {
524                write!(f, "Unsupported compiled artifact backend: {}.", backend)
525            }
526            CompiledArtifactError::Encode(message) => {
527                write!(f, "Failed to encode compiled artifact: {}.", message)
528            }
529            CompiledArtifactError::Decode(message) => {
530                write!(f, "Failed to decode compiled artifact: {}.", message)
531            }
532            CompiledArtifactError::InvalidSemanticModel(err) => err.fmt(f),
533            CompiledArtifactError::InvalidRuntime(err) => err.fmt(f),
534        }
535    }
536}
537
538impl Error for CompiledArtifactError {}
539
/// Crate-internal description of how much parallelism a run may use.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Parallelism {
    // Worker thread count; 1 means fully sequential.
    thread_count: usize,
}

impl Parallelism {
    /// A single-threaded configuration.
    pub(crate) fn sequential() -> Self {
        Self { thread_count: 1 }
    }

    /// True when more than one worker thread is available.
    pub(crate) fn enabled(self) -> bool {
        self.thread_count > 1
    }
}
554
555pub(crate) fn capture_feature_preprocessing(table: &dyn TableAccess) -> Vec<FeaturePreprocessing> {
556    (0..table.n_features())
557        .map(|feature_index| {
558            if table.is_binary_feature(feature_index) {
559                FeaturePreprocessing::Binary
560            } else {
561                let values = (0..table.n_rows())
562                    .map(|row_index| table.feature_value(feature_index, row_index))
563                    .collect::<Vec<_>>();
564                FeaturePreprocessing::Numeric {
565                    bin_boundaries: numeric_bin_boundaries(
566                        &values,
567                        NumericBins::Fixed(table.numeric_bin_cap()),
568                    )
569                    .into_iter()
570                    .map(|(bin, upper_bound)| NumericBinBoundary { bin, upper_bound })
571                    .collect(),
572                }
573            }
574        })
575        .collect()
576}
577
/// Dense raw values for one inference feature, stored by training-time kind.
#[derive(Debug, Clone)]
enum InferenceFeatureColumn {
    /// Raw numeric values.
    Numeric(Vec<f64>),
    /// Values already validated as exactly 0.0/1.0 and stored as booleans.
    Binary(Vec<bool>),
}
583
/// Binned companion of `InferenceFeatureColumn` for the same feature.
#[derive(Debug, Clone)]
enum InferenceBinnedColumn {
    /// Bin indices produced by `infer_numeric_bin`.
    Numeric(Vec<u16>),
    /// Binary features pass through unbinned.
    Binary(Vec<bool>),
}
589
/// A binned column narrowed to `u8` storage when every bin index fits,
/// falling back to `u16` otherwise (see `compact_binned_column`).
#[derive(Debug, Clone)]
enum CompactBinnedColumn {
    U8(Vec<u8>),
    U16(Vec<u16>),
}
595
596impl CompactBinnedColumn {
597    #[inline(always)]
598    fn value_at(&self, row_index: usize) -> u16 {
599        match self {
600            CompactBinnedColumn::U8(values) => u16::from(values[row_index]),
601            CompactBinnedColumn::U16(values) => values[row_index],
602        }
603    }
604
605    #[inline(always)]
606    fn slice_u8(&self, start: usize, len: usize) -> Option<&[u8]> {
607        match self {
608            CompactBinnedColumn::U8(values) => Some(&values[start..start + len]),
609            CompactBinnedColumn::U16(_) => None,
610        }
611    }
612
613    #[inline(always)]
614    fn slice_u16(&self, start: usize, len: usize) -> Option<&[u16]> {
615        match self {
616            CompactBinnedColumn::U8(_) => None,
617            CompactBinnedColumn::U16(values) => Some(&values[start..start + len]),
618        }
619    }
620}
621
/// Column-major inference input holding both the raw feature values and
/// their binned form, validated against the model's `FeaturePreprocessing`.
#[derive(Debug, Clone)]
pub(crate) struct InferenceTable {
    // Raw values, one entry per model feature, in model order.
    feature_columns: Vec<InferenceFeatureColumn>,
    // Binned companion of `feature_columns` (same order and length).
    binned_feature_columns: Vec<InferenceBinnedColumn>,
    // Row count shared by every column.
    n_rows: usize,
}
628
629impl InferenceTable {
630    pub(crate) fn from_rows(
631        rows: Vec<Vec<f64>>,
632        preprocessing: &[FeaturePreprocessing],
633    ) -> Result<Self, PredictError> {
634        let expected = preprocessing.len();
635        if let Some((row_index, actual)) = rows
636            .iter()
637            .enumerate()
638            .find_map(|(row_index, row)| (row.len() != expected).then_some((row_index, row.len())))
639        {
640            return Err(PredictError::RaggedRows {
641                row: row_index,
642                expected,
643                actual,
644            });
645        }
646
647        let columns = (0..expected)
648            .map(|feature_index| {
649                rows.iter()
650                    .map(|row| row[feature_index])
651                    .collect::<Vec<_>>()
652            })
653            .collect::<Vec<_>>();
654
655        Self::from_columns(columns, preprocessing)
656    }
657
658    pub(crate) fn from_named_columns(
659        columns: BTreeMap<String, Vec<f64>>,
660        preprocessing: &[FeaturePreprocessing],
661    ) -> Result<Self, PredictError> {
662        let expected = preprocessing.len();
663        if columns.len() != expected {
664            for feature_index in 0..expected {
665                let name = format!("f{}", feature_index);
666                if !columns.contains_key(&name) {
667                    return Err(PredictError::MissingFeature(name));
668                }
669            }
670            if let Some(unexpected) = columns.keys().find(|name| {
671                name.strip_prefix('f')
672                    .and_then(|idx| idx.parse::<usize>().ok())
673                    .is_none_or(|idx| idx >= expected)
674            }) {
675                return Err(PredictError::UnexpectedFeature(unexpected.clone()));
676            }
677        }
678
679        let n_rows = columns.values().next().map_or(0, Vec::len);
680        let ordered = (0..expected)
681            .map(|feature_index| {
682                let feature_name = format!("f{}", feature_index);
683                let values = columns
684                    .get(&feature_name)
685                    .ok_or_else(|| PredictError::MissingFeature(feature_name.clone()))?;
686                if values.len() != n_rows {
687                    return Err(PredictError::ColumnLengthMismatch {
688                        feature: feature_name,
689                        expected: n_rows,
690                        actual: values.len(),
691                    });
692                }
693                Ok(values.clone())
694            })
695            .collect::<Result<Vec<_>, _>>()?;
696
697        Self::from_columns(ordered, preprocessing)
698    }
699
700    pub(crate) fn from_sparse_binary_columns(
701        n_rows: usize,
702        n_features: usize,
703        columns: Vec<Vec<usize>>,
704        preprocessing: &[FeaturePreprocessing],
705    ) -> Result<Self, PredictError> {
706        if n_features != preprocessing.len() {
707            return Err(PredictError::FeatureCountMismatch {
708                expected: preprocessing.len(),
709                actual: n_features,
710            });
711        }
712
713        let mut dense_columns = Vec::with_capacity(n_features);
714        for (feature_index, row_indices) in columns.into_iter().enumerate() {
715            match preprocessing.get(feature_index) {
716                Some(FeaturePreprocessing::Binary) => {
717                    let mut values = vec![false; n_rows];
718                    for row_index in row_indices {
719                        if row_index >= n_rows {
720                            return Err(PredictError::ColumnLengthMismatch {
721                                feature: format!("f{}", feature_index),
722                                expected: n_rows,
723                                actual: row_index + 1,
724                            });
725                        }
726                        values[row_index] = true;
727                    }
728                    dense_columns.push(values.into_iter().map(f64::from).collect());
729                }
730                Some(FeaturePreprocessing::Numeric { .. }) => {
731                    return Err(PredictError::InvalidBinaryValue {
732                        feature_index,
733                        row_index: 0,
734                        value: 1.0,
735                    });
736                }
737                None => unreachable!("validated feature count"),
738            }
739        }
740
741        Self::from_columns(dense_columns, preprocessing)
742    }
743
744    fn from_columns(
745        columns: Vec<Vec<f64>>,
746        preprocessing: &[FeaturePreprocessing],
747    ) -> Result<Self, PredictError> {
748        if columns.len() != preprocessing.len() {
749            return Err(PredictError::FeatureCountMismatch {
750                expected: preprocessing.len(),
751                actual: columns.len(),
752            });
753        }
754
755        let n_rows = columns.first().map_or(0, Vec::len);
756        let mut feature_columns = Vec::with_capacity(columns.len());
757        let mut binned_feature_columns = Vec::with_capacity(columns.len());
758
759        for (feature_index, (column, feature_preprocessing)) in
760            columns.into_iter().zip(preprocessing.iter()).enumerate()
761        {
762            if column.len() != n_rows {
763                return Err(PredictError::ColumnLengthMismatch {
764                    feature: format!("f{}", feature_index),
765                    expected: n_rows,
766                    actual: column.len(),
767                });
768            }
769            match feature_preprocessing {
770                FeaturePreprocessing::Binary => {
771                    let values = column
772                        .into_iter()
773                        .enumerate()
774                        .map(|(row_index, value)| match value {
775                            v if v.total_cmp(&0.0).is_eq() => Ok(false),
776                            v if v.total_cmp(&1.0).is_eq() => Ok(true),
777                            v => Err(PredictError::InvalidBinaryValue {
778                                feature_index,
779                                row_index,
780                                value: v,
781                            }),
782                        })
783                        .collect::<Result<Vec<_>, _>>()?;
784                    feature_columns.push(InferenceFeatureColumn::Binary(values.clone()));
785                    binned_feature_columns.push(InferenceBinnedColumn::Binary(values));
786                }
787                FeaturePreprocessing::Numeric { bin_boundaries } => {
788                    let bins = column
789                        .iter()
790                        .map(|value| infer_numeric_bin(*value, bin_boundaries))
791                        .collect();
792                    feature_columns.push(InferenceFeatureColumn::Numeric(column));
793                    binned_feature_columns.push(InferenceBinnedColumn::Numeric(bins));
794                }
795            }
796        }
797
798        Ok(Self {
799            feature_columns,
800            binned_feature_columns,
801            n_rows,
802        })
803    }
804
805    pub(crate) fn to_column_major_binned_matrix(&self) -> ColumnMajorBinnedMatrix {
806        let n_features = self.feature_columns.len();
807        let columns = (0..n_features)
808            .map(
809                |feature_index| match &self.binned_feature_columns[feature_index] {
810                    InferenceBinnedColumn::Numeric(values) => compact_binned_column(values),
811                    InferenceBinnedColumn::Binary(values) => CompactBinnedColumn::U8(
812                        values.iter().map(|value| u8::from(*value)).collect(),
813                    ),
814                },
815            )
816            .collect();
817
818        ColumnMajorBinnedMatrix {
819            n_rows: self.n_rows,
820            columns,
821        }
822    }
823}
824
/// Column-major snapshot of binned features in compact storage, used by
/// the optimized inference kernels.
#[derive(Debug, Clone)]
struct ColumnMajorBinnedMatrix {
    // Row count shared by every column.
    n_rows: usize,
    // One compact column per feature, in feature order.
    columns: Vec<CompactBinnedColumn>,
}
830
831impl ColumnMajorBinnedMatrix {
832    fn from_table_access(table: &dyn TableAccess) -> Self {
833        let columns = (0..table.n_features())
834            .map(|feature_index| {
835                if table.is_binary_binned_feature(feature_index) {
836                    CompactBinnedColumn::U8(
837                        (0..table.n_rows())
838                            .map(|row_index| {
839                                u8::from(
840                                    table
841                                        .binned_boolean_value(feature_index, row_index)
842                                        .unwrap_or(false),
843                                )
844                            })
845                            .collect(),
846                    )
847                } else {
848                    compact_binned_column(
849                        &(0..table.n_rows())
850                            .map(|row_index| table.binned_value(feature_index, row_index))
851                            .collect::<Vec<_>>(),
852                    )
853                }
854            })
855            .collect();
856
857        Self {
858            n_rows: table.n_rows(),
859            columns,
860        }
861    }
862
863    #[inline(always)]
864    fn column(&self, feature_index: usize) -> &CompactBinnedColumn {
865        &self.columns[feature_index]
866    }
867}
868
869fn infer_numeric_bin(value: f64, boundaries: &[NumericBinBoundary]) -> u16 {
870    boundaries
871        .iter()
872        .find(|boundary| value <= boundary.upper_bound)
873        .map_or_else(
874            || boundaries.last().map_or(0, |boundary| boundary.bin),
875            |boundary| boundary.bin,
876        )
877}
878
879fn compact_binned_column(values: &[u16]) -> CompactBinnedColumn {
880    if values.iter().all(|value| *value <= u16::from(u8::MAX)) {
881        CompactBinnedColumn::U8(values.iter().map(|value| *value as u8).collect())
882    } else {
883        CompactBinnedColumn::U16(values.to_vec())
884    }
885}
886
/// Exposes an `InferenceTable` through the `TableAccess` trait so inference
/// can reuse the table-driven code paths shared with training.
impl TableAccess for InferenceTable {
    fn n_rows(&self) -> usize {
        self.n_rows
    }

    fn n_features(&self) -> usize {
        self.feature_columns.len()
    }

    // Inference tables never carry canary columns.
    fn canaries(&self) -> usize {
        0
    }

    // Binning was already fixed at training time; report the global cap.
    fn numeric_bin_cap(&self) -> usize {
        MAX_NUMERIC_BINS
    }

    // Raw and binned columns are built pairwise, so the counts match.
    fn binned_feature_count(&self) -> usize {
        self.binned_feature_columns.len()
    }

    // Binary values are surfaced as 0.0/1.0 to satisfy the f64 interface.
    fn feature_value(&self, feature_index: usize, row_index: usize) -> f64 {
        match &self.feature_columns[feature_index] {
            InferenceFeatureColumn::Numeric(values) => values[row_index],
            InferenceFeatureColumn::Binary(values) => f64::from(u8::from(values[row_index])),
        }
    }

    fn is_binary_feature(&self, index: usize) -> bool {
        matches!(
            self.feature_columns[index],
            InferenceFeatureColumn::Binary(_)
        )
    }

    // Binary columns widen their bool to 0/1 when read as a bin index.
    fn binned_value(&self, feature_index: usize, row_index: usize) -> u16 {
        match &self.binned_feature_columns[feature_index] {
            InferenceBinnedColumn::Numeric(values) => values[row_index],
            InferenceBinnedColumn::Binary(values) => u16::from(values[row_index]),
        }
    }

    // `Some` only for binary columns; numeric columns have no boolean view.
    fn binned_boolean_value(&self, feature_index: usize, row_index: usize) -> Option<bool> {
        match &self.binned_feature_columns[feature_index] {
            InferenceBinnedColumn::Numeric(_) => None,
            InferenceBinnedColumn::Binary(values) => Some(values[row_index]),
        }
    }

    // Every binned column maps 1:1 to a real source feature here — there
    // are no derived or canary columns at inference time.
    fn binned_column_kind(&self, index: usize) -> BinnedColumnKind {
        BinnedColumnKind::Real {
            source_index: index,
        }
    }

    fn is_binary_binned_feature(&self, index: usize) -> bool {
        matches!(
            self.binned_feature_columns[index],
            InferenceBinnedColumn::Binary(_)
        )
    }

    // No labels exist at inference time; a constant satisfies the trait.
    fn target_value(&self, _row_index: usize) -> f64 {
        0.0
    }
}
953
/// Flattened, inference-only form of a trained model. Each variant stores
/// exactly the data its prediction kernel needs; ensemble variants nest
/// per-tree runtimes recursively.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum OptimizedRuntime {
    /// Standard classifier tree containing only binary splits, flattened by
    /// `build_binary_classifier_layout` into a linear node array.
    BinaryClassifier {
        nodes: Vec<OptimizedBinaryClassifierNode>,
        class_labels: Vec<f64>,
    },
    /// Standard classifier tree that may contain multiway splits; traversal
    /// starts at `root`.
    StandardClassifier {
        nodes: Vec<OptimizedClassifierNode>,
        root: usize,
        class_labels: Vec<f64>,
    },
    /// Oblivious classifier: one (feature, threshold_bin) split per level
    /// (parallel vectors) plus precomputed per-leaf class probabilities.
    ObliviousClassifier {
        feature_indices: Vec<usize>,
        threshold_bins: Vec<u16>,
        leaf_values: Vec<Vec<f64>>,
        class_labels: Vec<f64>,
    },
    /// Standard regressor tree flattened into a linear node array.
    BinaryRegressor {
        nodes: Vec<OptimizedBinaryRegressorNode>,
    },
    /// Oblivious regressor: per-level splits with scalar leaf values.
    ObliviousRegressor {
        feature_indices: Vec<usize>,
        threshold_bins: Vec<u16>,
        leaf_values: Vec<f64>,
    },
    /// Forest of classifier trees; predictions average per-class
    /// probabilities across trees.
    ForestClassifier {
        trees: Vec<OptimizedRuntime>,
        class_labels: Vec<f64>,
    },
    /// Forest of regressor trees; predictions average tree outputs.
    ForestRegressor {
        trees: Vec<OptimizedRuntime>,
    },
    /// Boosted binary classifier: sigmoid of `base_score` plus the
    /// weight-scaled sum of tree raw scores.
    BoostedBinaryClassifier {
        trees: Vec<OptimizedRuntime>,
        tree_weights: Vec<f64>,
        base_score: f64,
        class_labels: Vec<f64>,
    },
    /// Boosted regressor: `base_score` plus the weight-scaled sum of tree
    /// predictions.
    BoostedRegressor {
        trees: Vec<OptimizedRuntime>,
        tree_weights: Vec<f64>,
        base_score: f64,
    },
}
998
/// Node of a `StandardClassifier` runtime tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum OptimizedClassifierNode {
    /// Terminal node holding normalized class probabilities.
    Leaf(Vec<f64>),
    /// Two-way split on a binned feature; `children` holds node-array
    /// indices ([left, right] from the source tree).
    Binary {
        feature_index: usize,
        threshold_bin: u16,
        children: [usize; 2],
    },
    /// Multiway split. `child_lookup[bin]` is the child node index for bins
    /// up to `max_bin_index`; entries with no branch are filled with
    /// `usize::MAX` at construction. `fallback_probabilities` is this node's
    /// own normalized class distribution.
    Multiway {
        feature_index: usize,
        child_lookup: Vec<usize>,
        max_bin_index: usize,
        fallback_probabilities: Vec<f64>,
    },
}
1014
/// Node of the flattened binary-classifier layout produced by
/// `build_binary_classifier_layout`. A branch carries one explicit jump
/// target (`jump_index`) and the comparison side that takes it
/// (`jump_if_greater`); exact traversal semantics live in the predict
/// kernel (`predict_binary_classifier_probabilities_row`).
#[derive(Debug, Clone, Serialize, Deserialize)]
enum OptimizedBinaryClassifierNode {
    /// Terminal node holding normalized class probabilities.
    Leaf(Vec<f64>),
    Branch {
        feature_index: usize,
        threshold_bin: u16,
        jump_index: usize,
        jump_if_greater: bool,
    },
}
1025
/// Node of the flattened binary-regressor layout produced by
/// `build_binary_regressor_layout`; mirrors `OptimizedBinaryClassifierNode`
/// but with a scalar leaf value. Traversal semantics live in
/// `predict_binary_regressor_row`.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum OptimizedBinaryRegressorNode {
    /// Terminal node holding the predicted value.
    Leaf(f64),
    Branch {
        feature_index: usize,
        threshold_bin: u16,
        jump_index: usize,
        jump_if_greater: bool,
    },
}
1036
/// Runs inference work either serially or on a dedicated rayon pool.
#[derive(Debug, Clone)]
struct InferenceExecutor {
    // Worker-thread count; 1 means strictly serial execution.
    thread_count: usize,
    // Present only when `thread_count > 1` (see `InferenceExecutor::new`).
    pool: Option<Arc<rayon::ThreadPool>>,
}
1042
/// CBOR-encoded body of a compiled artifact: the portable semantic IR plus
/// the already-lowered runtime, so loading can skip re-lowering the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CompiledArtifactPayload {
    semantic_ir: ModelPackageIr,
    runtime: OptimizedRuntime,
}
1048
impl InferenceExecutor {
    /// Builds an executor for `thread_count` threads. A dedicated rayon pool
    /// is created only when more than one thread is requested; the serial
    /// executor carries no pool.
    ///
    /// Errors with `OptimizeError::ThreadPoolBuildFailed` if rayon cannot
    /// build the pool.
    fn new(thread_count: usize) -> Result<Self, OptimizeError> {
        let pool = if thread_count > 1 {
            Some(Arc::new(
                ThreadPoolBuilder::new()
                    .num_threads(thread_count)
                    .build()
                    .map_err(|err| OptimizeError::ThreadPoolBuildFailed(err.to_string()))?,
            ))
        } else {
            None
        };

        Ok(Self { thread_count, pool })
    }

    /// Maps `predict_row` over `0..n_rows`. Runs in parallel on the pool
    /// only when the executor is multi-threaded and the batch is large
    /// enough to amortize scheduling overhead.
    fn predict_rows<F>(&self, n_rows: usize, predict_row: F) -> Vec<f64>
    where
        F: Fn(usize) -> f64 + Sync + Send,
    {
        if self.thread_count == 1 || n_rows < PARALLEL_INFERENCE_ROW_THRESHOLD {
            return (0..n_rows).map(predict_row).collect();
        }

        // `new` guarantees the pool exists whenever thread_count > 1.
        self.pool
            .as_ref()
            .expect("thread pool exists when parallel inference is enabled")
            .install(|| (0..n_rows).into_par_iter().map(predict_row).collect())
    }

    /// Fills `outputs` in chunks of `chunk_rows`; `fill_chunk` receives each
    /// chunk's starting row index plus the mutable chunk slice. Serial for
    /// small outputs or single-threaded executors, otherwise fans out over
    /// the rayon pool.
    fn fill_chunks<F>(&self, outputs: &mut [f64], chunk_rows: usize, fill_chunk: F)
    where
        F: Fn(usize, &mut [f64]) + Sync + Send,
    {
        if self.thread_count == 1 || outputs.len() < PARALLEL_INFERENCE_ROW_THRESHOLD {
            for (chunk_index, chunk) in outputs.chunks_mut(chunk_rows).enumerate() {
                fill_chunk(chunk_index * chunk_rows, chunk);
            }
            return;
        }

        self.pool
            .as_ref()
            .expect("thread pool exists when parallel inference is enabled")
            .install(|| {
                outputs
                    .par_chunks_mut(chunk_rows)
                    .enumerate()
                    .for_each(|(chunk_index, chunk)| fill_chunk(chunk_index * chunk_rows, chunk));
            });
    }
}
1101
/// A trained model paired with its lowered inference runtime and an
/// execution strategy for batch prediction.
#[derive(Debug, Clone)]
pub struct OptimizedModel {
    // Kept for metadata/introspection delegation and IR export.
    source_model: Model,
    // Flattened prediction structures (see `OptimizedRuntime`).
    runtime: OptimizedRuntime,
    // Serial or pooled execution of per-row / per-chunk work.
    executor: InferenceExecutor,
}
1108
impl OptimizedModel {
    /// Lowers `source_model` into its runtime form and sizes the inference
    /// executor; `physical_cores` overrides automatic thread-count detection.
    fn new(source_model: Model, physical_cores: Option<usize>) -> Result<Self, OptimizeError> {
        let thread_count = resolve_inference_thread_count(physical_cores)?;
        let runtime = OptimizedRuntime::from_model(&source_model);
        let executor = InferenceExecutor::new(thread_count)?;

        Ok(Self {
            source_model,
            runtime,
            executor,
        })
    }

    /// Predicts one value per row of `table`. Uses the column-major batch
    /// kernel when the runtime supports it and there is more than one row;
    /// otherwise evaluates row by row (possibly in parallel).
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        if self.runtime.should_use_batch_matrix(table.n_rows()) {
            let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
            return self.predict_column_major_binned_matrix(&matrix);
        }

        self.executor.predict_rows(table.n_rows(), |row_index| {
            self.runtime.predict_table_row(table, row_index)
        })
    }

    /// Predicts from row-major `f64` rows, applying the model's stored
    /// feature preprocessing first.
    pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
        let table = InferenceTable::from_rows(rows, self.source_model.feature_preprocessing())?;
        if self.runtime.should_use_batch_matrix(table.n_rows()) {
            let matrix = table.to_column_major_binned_matrix();
            Ok(self.predict_column_major_binned_matrix(&matrix))
        } else {
            Ok(self.predict_table(&table))
        }
    }

    /// Predicts from named columns keyed by feature name.
    pub fn predict_named_columns(
        &self,
        columns: BTreeMap<String, Vec<f64>>,
    ) -> Result<Vec<f64>, PredictError> {
        let table =
            InferenceTable::from_named_columns(columns, self.source_model.feature_preprocessing())?;
        if self.runtime.should_use_batch_matrix(table.n_rows()) {
            let matrix = table.to_column_major_binned_matrix();
            Ok(self.predict_column_major_binned_matrix(&matrix))
        } else {
            Ok(self.predict_table(&table))
        }
    }

    /// Per-row class-probability vectors. Errors for regression runtimes.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        self.runtime.predict_proba_table(table, &self.executor)
    }

    /// Probability variant of [`Self::predict_rows`].
    pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_rows(rows, self.source_model.feature_preprocessing())?;
        self.predict_proba_table(&table)
    }

    /// Probability variant of [`Self::predict_named_columns`].
    pub fn predict_proba_named_columns(
        &self,
        columns: BTreeMap<String, Vec<f64>>,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        let table =
            InferenceTable::from_named_columns(columns, self.source_model.feature_preprocessing())?;
        self.predict_proba_table(&table)
    }

    /// Probability prediction from sparse binary input: `columns[f]` lists
    /// the row indices where feature `f` is set.
    pub fn predict_proba_sparse_binary_columns(
        &self,
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_sparse_binary_columns(
            n_rows,
            n_features,
            columns,
            self.source_model.feature_preprocessing(),
        )?;
        self.predict_proba_table(&table)
    }

    /// Value prediction from sparse binary input (same layout as
    /// [`Self::predict_proba_sparse_binary_columns`]).
    pub fn predict_sparse_binary_columns(
        &self,
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
    ) -> Result<Vec<f64>, PredictError> {
        let table = InferenceTable::from_sparse_binary_columns(
            n_rows,
            n_features,
            columns,
            self.source_model.feature_preprocessing(),
        )?;
        if self.runtime.should_use_batch_matrix(table.n_rows()) {
            let matrix = table.to_column_major_binned_matrix();
            Ok(self.predict_column_major_binned_matrix(&matrix))
        } else {
            Ok(self.predict_table(&table))
        }
    }

    /// Predicts directly from a polars `DataFrame` by extracting its columns
    /// as named `f64` vectors.
    #[cfg(feature = "polars")]
    pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
        let columns = polars_named_columns(df)?;
        self.predict_named_columns(columns)
    }

    /// Streams a polars `LazyFrame` in slices of `LAZYFRAME_PREDICT_BATCH_ROWS`
    /// rows, predicting each materialized batch, so the full frame is never
    /// collected at once.
    #[cfg(feature = "polars")]
    pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
        let mut predictions = Vec::new();
        let mut offset = 0i64;
        loop {
            let batch = lf
                .clone()
                .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
                .collect()?;
            let height = batch.height();
            if height == 0 {
                break;
            }
            predictions.extend(self.predict_polars_dataframe(&batch)?);
            // A short batch means the source is exhausted; stop without
            // issuing another (empty) slice.
            if height < LAZYFRAME_PREDICT_BATCH_ROWS {
                break;
            }
            offset += height as i64;
        }
        Ok(predictions)
    }

    // --- Metadata accessors, all delegated to the source model. ---

    pub fn algorithm(&self) -> TrainAlgorithm {
        self.source_model.algorithm()
    }

    pub fn task(&self) -> Task {
        self.source_model.task()
    }

    pub fn criterion(&self) -> Criterion {
        self.source_model.criterion()
    }

    pub fn tree_type(&self) -> TreeType {
        self.source_model.tree_type()
    }

    pub fn mean_value(&self) -> Option<f64> {
        self.source_model.mean_value()
    }

    pub fn canaries(&self) -> usize {
        self.source_model.canaries()
    }

    pub fn max_depth(&self) -> Option<usize> {
        self.source_model.max_depth()
    }

    pub fn min_samples_split(&self) -> Option<usize> {
        self.source_model.min_samples_split()
    }

    pub fn min_samples_leaf(&self) -> Option<usize> {
        self.source_model.min_samples_leaf()
    }

    pub fn n_trees(&self) -> Option<usize> {
        self.source_model.n_trees()
    }

    pub fn max_features(&self) -> Option<usize> {
        self.source_model.max_features()
    }

    pub fn seed(&self) -> Option<u64> {
        self.source_model.seed()
    }

    pub fn compute_oob(&self) -> bool {
        self.source_model.compute_oob()
    }

    pub fn oob_score(&self) -> Option<f64> {
        self.source_model.oob_score()
    }

    pub fn learning_rate(&self) -> Option<f64> {
        self.source_model.learning_rate()
    }

    pub fn bootstrap(&self) -> bool {
        self.source_model.bootstrap()
    }

    pub fn top_gradient_fraction(&self) -> Option<f64> {
        self.source_model.top_gradient_fraction()
    }

    pub fn other_gradient_fraction(&self) -> Option<f64> {
        self.source_model.other_gradient_fraction()
    }

    // --- Tree introspection, delegated to the source model. ---

    pub fn tree_count(&self) -> usize {
        self.source_model.tree_count()
    }

    pub fn tree_structure(
        &self,
        tree_index: usize,
    ) -> Result<TreeStructureSummary, IntrospectionError> {
        self.source_model.tree_structure(tree_index)
    }

    pub fn tree_prediction_stats(
        &self,
        tree_index: usize,
    ) -> Result<PredictionValueStats, IntrospectionError> {
        self.source_model.tree_prediction_stats(tree_index)
    }

    pub fn tree_node(
        &self,
        tree_index: usize,
        node_index: usize,
    ) -> Result<ir::NodeTreeNode, IntrospectionError> {
        self.source_model.tree_node(tree_index, node_index)
    }

    pub fn tree_level(
        &self,
        tree_index: usize,
        level_index: usize,
    ) -> Result<ir::ObliviousLevel, IntrospectionError> {
        self.source_model.tree_level(tree_index, level_index)
    }

    pub fn tree_leaf(
        &self,
        tree_index: usize,
        leaf_index: usize,
    ) -> Result<ir::IndexedLeaf, IntrospectionError> {
        self.source_model.tree_leaf(tree_index, leaf_index)
    }

    // --- IR export and serialization, delegated to the source model. ---

    pub fn to_ir(&self) -> ModelPackageIr {
        self.source_model.to_ir()
    }

    pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
        self.source_model.to_ir_json()
    }

    pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
        self.source_model.to_ir_json_pretty()
    }

    pub fn serialize(&self) -> Result<String, serde_json::Error> {
        self.source_model.serialize()
    }

    pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
        self.source_model.serialize_pretty()
    }

    /// Encodes the model as a compiled artifact: a fixed header (magic,
    /// little-endian format version, backend id) followed by a CBOR payload
    /// containing the semantic IR and the lowered runtime.
    pub fn serialize_compiled(&self) -> Result<Vec<u8>, CompiledArtifactError> {
        let payload = CompiledArtifactPayload {
            semantic_ir: self.source_model.to_ir(),
            runtime: self.runtime.clone(),
        };
        let mut payload_bytes = Vec::new();
        ciborium::into_writer(&payload, &mut payload_bytes)
            .map_err(|err| CompiledArtifactError::Encode(err.to_string()))?;
        let mut bytes = Vec::with_capacity(COMPILED_ARTIFACT_HEADER_LEN + payload_bytes.len());
        bytes.extend_from_slice(&COMPILED_ARTIFACT_MAGIC);
        bytes.extend_from_slice(&COMPILED_ARTIFACT_VERSION.to_le_bytes());
        bytes.extend_from_slice(&COMPILED_ARTIFACT_BACKEND_CPU.to_le_bytes());
        bytes.extend_from_slice(&payload_bytes);
        Ok(bytes)
    }

    /// Reverses [`Self::serialize_compiled`]: validates header length, magic,
    /// version, and backend, decodes the CBOR payload, rebuilds the source
    /// model from its semantic IR, and reuses the stored runtime as-is.
    pub fn deserialize_compiled(
        serialized: &[u8],
        physical_cores: Option<usize>,
    ) -> Result<Self, CompiledArtifactError> {
        if serialized.len() < COMPILED_ARTIFACT_HEADER_LEN {
            return Err(CompiledArtifactError::ArtifactTooShort {
                actual: serialized.len(),
                minimum: COMPILED_ARTIFACT_HEADER_LEN,
            });
        }

        let magic = [serialized[0], serialized[1], serialized[2], serialized[3]];
        if magic != COMPILED_ARTIFACT_MAGIC {
            return Err(CompiledArtifactError::InvalidMagic(magic));
        }

        let version = u16::from_le_bytes([serialized[4], serialized[5]]);
        if version != COMPILED_ARTIFACT_VERSION {
            return Err(CompiledArtifactError::UnsupportedVersion(version));
        }

        let backend = u16::from_le_bytes([serialized[6], serialized[7]]);
        if backend != COMPILED_ARTIFACT_BACKEND_CPU {
            return Err(CompiledArtifactError::UnsupportedBackend(backend));
        }

        let payload: CompiledArtifactPayload = ciborium::from_reader(std::io::Cursor::new(
            &serialized[COMPILED_ARTIFACT_HEADER_LEN..],
        ))
        .map_err(|err| CompiledArtifactError::Decode(err.to_string()))?;
        let source_model = ir::model_from_ir(payload.semantic_ir)
            .map_err(CompiledArtifactError::InvalidSemanticModel)?;
        let thread_count = resolve_inference_thread_count(physical_cores)
            .map_err(CompiledArtifactError::InvalidRuntime)?;
        let executor =
            InferenceExecutor::new(thread_count).map_err(CompiledArtifactError::InvalidRuntime)?;

        Ok(Self {
            source_model,
            runtime: payload.runtime,
            executor,
        })
    }

    /// Runs the runtime's column-major batch kernel over a prepared matrix.
    fn predict_column_major_binned_matrix(&self, matrix: &ColumnMajorBinnedMatrix) -> Vec<f64> {
        self.runtime
            .predict_column_major_matrix(matrix, &self.executor)
    }
}
1440
1441impl OptimizedRuntime {
1442    fn supports_batch_matrix(&self) -> bool {
1443        matches!(
1444            self,
1445            OptimizedRuntime::BinaryClassifier { .. }
1446                | OptimizedRuntime::BinaryRegressor { .. }
1447                | OptimizedRuntime::ObliviousClassifier { .. }
1448                | OptimizedRuntime::ObliviousRegressor { .. }
1449                | OptimizedRuntime::ForestClassifier { .. }
1450                | OptimizedRuntime::ForestRegressor { .. }
1451                | OptimizedRuntime::BoostedBinaryClassifier { .. }
1452                | OptimizedRuntime::BoostedRegressor { .. }
1453        )
1454    }
1455
1456    fn should_use_batch_matrix(&self, n_rows: usize) -> bool {
1457        n_rows > 1 && self.supports_batch_matrix()
1458    }
1459
    /// Recursively lowers a trained `Model` into its runtime form; ensemble
    /// members are lowered tree by tree. Panics (via `expect`) if a
    /// classification ensemble is missing its stored class labels.
    fn from_model(model: &Model) -> Self {
        match model {
            Model::DecisionTreeClassifier(classifier) => Self::from_classifier(classifier),
            Model::DecisionTreeRegressor(regressor) => Self::from_regressor(regressor),
            Model::RandomForest(forest) => match forest.task() {
                Task::Classification => Self::ForestClassifier {
                    trees: forest.trees().iter().map(Self::from_model).collect(),
                    class_labels: forest
                        .class_labels()
                        .expect("classification forest stores class labels"),
                },
                Task::Regression => Self::ForestRegressor {
                    trees: forest.trees().iter().map(Self::from_model).collect(),
                },
            },
            Model::GradientBoostedTrees(model) => match model.task() {
                Task::Classification => Self::BoostedBinaryClassifier {
                    trees: model.trees().iter().map(Self::from_model).collect(),
                    tree_weights: model.tree_weights().to_vec(),
                    base_score: model.base_score(),
                    class_labels: model
                        .class_labels()
                        .expect("classification boosting stores class labels"),
                },
                Task::Regression => Self::BoostedRegressor {
                    trees: model.trees().iter().map(Self::from_model).collect(),
                    tree_weights: model.tree_weights().to_vec(),
                    base_score: model.base_score(),
                },
            },
        }
    }
1492
    /// Lowers a trained classifier tree into its fastest runtime form:
    /// standard trees with only binary splits use the flattened
    /// `BinaryClassifier` layout; standard trees with multiway splits keep
    /// an explicit node array (`StandardClassifier`) with per-bin child
    /// lookup tables; oblivious trees become parallel per-level split
    /// vectors with precomputed leaf probabilities.
    fn from_classifier(classifier: &DecisionTreeClassifier) -> Self {
        match classifier.structure() {
            tree::classifier::TreeStructure::Standard { nodes, root } => {
                if classifier_nodes_are_binary_only(nodes) {
                    return Self::BinaryClassifier {
                        nodes: build_binary_classifier_layout(
                            nodes,
                            *root,
                            classifier.class_labels(),
                        ),
                        class_labels: classifier.class_labels().to_vec(),
                    };
                }

                let optimized_nodes = nodes
                    .iter()
                    .map(|node| match node {
                        // Leaves precompute their normalized probabilities
                        // so inference never renormalizes counts.
                        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
                            OptimizedClassifierNode::Leaf(normalized_probabilities_from_counts(
                                class_counts,
                            ))
                        }
                        tree::classifier::TreeNode::BinarySplit {
                            feature_index,
                            threshold_bin,
                            left_child,
                            right_child,
                            ..
                        } => OptimizedClassifierNode::Binary {
                            feature_index: *feature_index,
                            threshold_bin: *threshold_bin,
                            children: [*left_child, *right_child],
                        },
                        tree::classifier::TreeNode::MultiwaySplit {
                            feature_index,
                            class_counts,
                            branches,
                            ..
                        } => {
                            // Dense bin -> child table sized by the largest
                            // branch bin; bins without a branch stay at
                            // usize::MAX.
                            let max_bin_index = branches
                                .iter()
                                .map(|(bin, _)| usize::from(*bin))
                                .max()
                                .unwrap_or(0);
                            let mut child_lookup = vec![usize::MAX; max_bin_index + 1];
                            for (bin, child_index) in branches {
                                child_lookup[usize::from(*bin)] = *child_index;
                            }
                            OptimizedClassifierNode::Multiway {
                                feature_index: *feature_index,
                                child_lookup,
                                max_bin_index,
                                // The node's own class distribution, kept for
                                // bins with no mapped child.
                                fallback_probabilities: normalized_probabilities_from_counts(
                                    class_counts,
                                ),
                            }
                        }
                    })
                    .collect();

                Self::StandardClassifier {
                    nodes: optimized_nodes,
                    root: *root,
                    class_labels: classifier.class_labels().to_vec(),
                }
            }
            tree::classifier::TreeStructure::Oblivious {
                splits,
                leaf_class_counts,
                ..
            } => Self::ObliviousClassifier {
                feature_indices: splits.iter().map(|split| split.feature_index).collect(),
                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
                leaf_values: leaf_class_counts
                    .iter()
                    .map(|class_counts| normalized_probabilities_from_counts(class_counts))
                    .collect(),
                class_labels: classifier.class_labels().to_vec(),
            },
        }
    }
1574
1575    fn from_regressor(regressor: &DecisionTreeRegressor) -> Self {
1576        match regressor.structure() {
1577            tree::regressor::RegressionTreeStructure::Standard { nodes, root } => {
1578                Self::BinaryRegressor {
1579                    nodes: build_binary_regressor_layout(nodes, *root),
1580                }
1581            }
1582            tree::regressor::RegressionTreeStructure::Oblivious {
1583                splits,
1584                leaf_values,
1585                ..
1586            } => Self::ObliviousRegressor {
1587                feature_indices: splits.iter().map(|split| split.feature_index).collect(),
1588                threshold_bins: splits.iter().map(|split| split.threshold_bin).collect(),
1589                leaf_values: leaf_values.clone(),
1590            },
1591        }
1592    }
1593
    /// Predicts one value for `row_index` of `table`. Classifier runtimes
    /// derive the class label from the probability vector; regressor
    /// runtimes evaluate their kernels directly, averaging across forest
    /// trees or weight-summing onto the base score for boosting.
    ///
    /// NOTE(review): `ForestRegressor` divides by `trees.len()` — assumes at
    /// least one tree (an empty forest would yield NaN); confirm training
    /// guarantees this invariant.
    #[inline(always)]
    fn predict_table_row(&self, table: &dyn TableAccess, row_index: usize) -> f64 {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                let probabilities = self
                    .predict_proba_table_row(table, row_index)
                    .expect("classifier runtime supports probability prediction");
                class_label_from_probabilities(&probabilities, self.class_labels())
            }
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            ),
            OptimizedRuntime::ForestRegressor { trees } => {
                trees
                    .iter()
                    .map(|tree| tree.predict_table_row(table, row_index))
                    .sum::<f64>()
                    / trees.len() as f64
            }
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>()
            }
        }
    }
1643
    /// Computes the class-probability vector for one row. Forest runtimes
    /// average per-class probabilities over their trees; boosted binary
    /// classifiers push `base_score` plus the weighted tree scores through a
    /// sigmoid and return `[1 - p, p]`. Regressor runtimes return
    /// `PredictError::ProbabilityPredictionRequiresClassification`.
    ///
    /// NOTE(review): `ForestClassifier` indexes `trees[0]` — assumes at least
    /// one tree; confirm training guarantees this invariant.
    #[inline(always)]
    fn predict_proba_table_row(
        &self,
        table: &dyn TableAccess,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    table.binned_value(feature_index, row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| table.binned_value(feature_index, row_index),
            )
            .to_vec()),
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                // Accumulate per-class probabilities, then divide once.
                let mut totals = trees[0].predict_proba_table_row(table, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_table_row(table, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| weight * tree.predict_table_row(table, row_index))
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1712
    /// Batch probability prediction: classifier runtimes use the column-major
    /// kernel when worthwhile and otherwise evaluate row by row (serially);
    /// regressor runtimes error.
    fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => {
                if self.should_use_batch_matrix(table.n_rows()) {
                    let matrix = ColumnMajorBinnedMatrix::from_table_access(table);
                    self.predict_proba_column_major_matrix(&matrix, executor)
                } else {
                    (0..table.n_rows())
                        .map(|row_index| self.predict_proba_table_row(table, row_index))
                        .collect()
                }
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1741
    /// Predicts one value per row of `matrix`.
    ///
    /// Classifier runtimes compute probability rows first and map each row to
    /// the most likely class label; regressor runtimes dispatch to their
    /// batch-prediction kernels.
    fn predict_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Vec<f64> {
        match self {
            // All classifier shapes: probabilities -> argmax class label.
            OptimizedRuntime::BinaryClassifier { .. }
            | OptimizedRuntime::StandardClassifier { .. }
            | OptimizedRuntime::ObliviousClassifier { .. }
            | OptimizedRuntime::ForestClassifier { .. }
            | OptimizedRuntime::BoostedBinaryClassifier { .. } => self
                .predict_proba_column_major_matrix(matrix, executor)
                .expect("classifier runtime supports probability prediction")
                .into_iter()
                .map(|row| class_label_from_probabilities(&row, self.class_labels()))
                .collect(),
            OptimizedRuntime::BinaryRegressor { nodes } => {
                predict_binary_regressor_column_major_matrix(nodes, matrix, executor)
            }
            OptimizedRuntime::ObliviousRegressor {
                feature_indices,
                threshold_bins,
                leaf_values,
            } => predict_oblivious_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            ),
            // Forest: unweighted mean of the member trees' predictions.
            // NOTE(review): assumes `trees` is non-empty — `trees[0]` panics otherwise.
            OptimizedRuntime::ForestRegressor { trees } => {
                let mut totals = trees[0].predict_column_major_matrix(matrix, executor);
                for tree in &trees[1..] {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for total in &mut totals {
                    *total /= tree_count;
                }
                totals
            }
            // Boosted ensemble: base score plus the weighted sum of tree outputs.
            OptimizedRuntime::BoostedRegressor {
                trees,
                tree_weights,
                base_score,
            } => {
                let mut totals = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (total, value) in totals.iter_mut().zip(values) {
                        *total += weight * value;
                    }
                }
                totals
            }
        }
    }
1802
    /// Predicts class-probability rows for every row of `matrix`.
    ///
    /// Errors with `ProbabilityPredictionRequiresClassification` for all
    /// regressor runtimes.
    fn predict_proba_column_major_matrix(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        executor: &InferenceExecutor,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => {
                Ok(predict_binary_classifier_probabilities_column_major_matrix(
                    nodes, matrix, executor,
                ))
            }
            // Standard (possibly multiway) trees have no batch kernel; walk
            // the tree once per row instead.
            OptimizedRuntime::StandardClassifier { .. } => Ok((0..matrix.n_rows)
                .map(|row_index| {
                    self.predict_proba_binned_row_from_columns(matrix, row_index)
                        .expect("classifier runtime supports probability prediction")
                })
                .collect()),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_column_major_matrix(
                feature_indices,
                threshold_bins,
                leaf_values,
                matrix,
                executor,
            )),
            // Forest: unweighted mean of the member trees' probability rows.
            // NOTE(review): assumes `trees` is non-empty — `trees[0]` panics otherwise.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals = trees[0].predict_proba_column_major_matrix(matrix, executor)?;
                for tree in &trees[1..] {
                    let rows = tree.predict_proba_column_major_matrix(matrix, executor)?;
                    for (row_totals, row_values) in totals.iter_mut().zip(rows) {
                        for (total, value) in row_totals.iter_mut().zip(row_values) {
                            *total += value;
                        }
                    }
                }
                let tree_count = trees.len() as f64;
                for row in &mut totals {
                    for value in row {
                        *value /= tree_count;
                    }
                }
                Ok(totals)
            }
            // Boosted binary: accumulate weighted raw scores per row, then
            // squash with the logistic function; rows are [P(neg), P(pos)].
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let mut raw_scores = vec![*base_score; matrix.n_rows];
                for (tree, weight) in trees.iter().zip(tree_weights.iter().copied()) {
                    let values = tree.predict_column_major_matrix(matrix, executor);
                    for (raw_score, value) in raw_scores.iter_mut().zip(values) {
                        *raw_score += weight * value;
                    }
                }
                Ok(raw_scores
                    .into_iter()
                    .map(|raw_score| {
                        let positive = sigmoid(raw_score);
                        vec![1.0 - positive, positive]
                    })
                    .collect())
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
1879
1880    fn class_labels(&self) -> &[f64] {
1881        match self {
1882            OptimizedRuntime::BinaryClassifier { class_labels, .. }
1883            | OptimizedRuntime::StandardClassifier { class_labels, .. }
1884            | OptimizedRuntime::ObliviousClassifier { class_labels, .. }
1885            | OptimizedRuntime::ForestClassifier { class_labels, .. }
1886            | OptimizedRuntime::BoostedBinaryClassifier { class_labels, .. } => class_labels,
1887            _ => &[],
1888        }
1889    }
1890
1891    #[inline(always)]
1892    fn predict_binned_row_from_columns(
1893        &self,
1894        matrix: &ColumnMajorBinnedMatrix,
1895        row_index: usize,
1896    ) -> f64 {
1897        match self {
1898            OptimizedRuntime::BinaryRegressor { nodes } => {
1899                predict_binary_regressor_row(nodes, |feature_index| {
1900                    matrix.column(feature_index).value_at(row_index)
1901                })
1902            }
1903            OptimizedRuntime::ObliviousRegressor {
1904                feature_indices,
1905                threshold_bins,
1906                leaf_values,
1907            } => predict_oblivious_row(
1908                feature_indices,
1909                threshold_bins,
1910                leaf_values,
1911                |feature_index| matrix.column(feature_index).value_at(row_index),
1912            ),
1913            OptimizedRuntime::BoostedRegressor {
1914                trees,
1915                tree_weights,
1916                base_score,
1917            } => {
1918                *base_score
1919                    + trees
1920                        .iter()
1921                        .zip(tree_weights.iter().copied())
1922                        .map(|(tree, weight)| {
1923                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
1924                        })
1925                        .sum::<f64>()
1926            }
1927            _ => self.predict_column_major_matrix(
1928                matrix,
1929                &InferenceExecutor::new(1).expect("inference executor"),
1930            )[row_index],
1931        }
1932    }
1933
    /// Predicts class probabilities for a single row of a column-major
    /// binned matrix.
    ///
    /// Errors with `ProbabilityPredictionRequiresClassification` for all
    /// regressor runtimes.
    #[inline(always)]
    fn predict_proba_binned_row_from_columns(
        &self,
        matrix: &ColumnMajorBinnedMatrix,
        row_index: usize,
    ) -> Result<Vec<f64>, PredictError> {
        match self {
            OptimizedRuntime::BinaryClassifier { nodes, .. } => Ok(
                predict_binary_classifier_probabilities_row(nodes, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::StandardClassifier { nodes, root, .. } => Ok(
                predict_standard_classifier_probabilities_row(nodes, *root, |feature_index| {
                    matrix.column(feature_index).value_at(row_index)
                })
                .to_vec(),
            ),
            OptimizedRuntime::ObliviousClassifier {
                feature_indices,
                threshold_bins,
                leaf_values,
                ..
            } => Ok(predict_oblivious_probabilities_row(
                feature_indices,
                threshold_bins,
                leaf_values,
                |feature_index| matrix.column(feature_index).value_at(row_index),
            )
            .to_vec()),
            // Forest: unweighted mean of the member trees' probability rows.
            // NOTE(review): assumes `trees` is non-empty — `trees[0]` panics otherwise.
            OptimizedRuntime::ForestClassifier { trees, .. } => {
                let mut totals =
                    trees[0].predict_proba_binned_row_from_columns(matrix, row_index)?;
                for tree in &trees[1..] {
                    let row = tree.predict_proba_binned_row_from_columns(matrix, row_index)?;
                    for (total, value) in totals.iter_mut().zip(row) {
                        *total += value;
                    }
                }
                let tree_count = trees.len() as f64;
                for value in &mut totals {
                    *value /= tree_count;
                }
                Ok(totals)
            }
            // Boosted binary: weighted raw score squashed by the logistic
            // function; the returned row is [P(negative), P(positive)].
            OptimizedRuntime::BoostedBinaryClassifier {
                trees,
                tree_weights,
                base_score,
                ..
            } => {
                let raw_score = *base_score
                    + trees
                        .iter()
                        .zip(tree_weights.iter().copied())
                        .map(|(tree, weight)| {
                            weight * tree.predict_binned_row_from_columns(matrix, row_index)
                        })
                        .sum::<f64>();
                let positive = sigmoid(raw_score);
                Ok(vec![1.0 - positive, positive])
            }
            OptimizedRuntime::BinaryRegressor { .. }
            | OptimizedRuntime::ObliviousRegressor { .. }
            | OptimizedRuntime::ForestRegressor { .. }
            | OptimizedRuntime::BoostedRegressor { .. } => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }
2005}
2006
2007#[inline(always)]
2008fn predict_standard_classifier_probabilities_row<F>(
2009    nodes: &[OptimizedClassifierNode],
2010    root: usize,
2011    bin_at: F,
2012) -> &[f64]
2013where
2014    F: Fn(usize) -> u16,
2015{
2016    let mut node_index = root;
2017    loop {
2018        match &nodes[node_index] {
2019            OptimizedClassifierNode::Leaf(value) => return value,
2020            OptimizedClassifierNode::Binary {
2021                feature_index,
2022                threshold_bin,
2023                children,
2024            } => {
2025                let go_right = usize::from(bin_at(*feature_index) > *threshold_bin);
2026                node_index = children[go_right];
2027            }
2028            OptimizedClassifierNode::Multiway {
2029                feature_index,
2030                child_lookup,
2031                max_bin_index,
2032                fallback_probabilities,
2033            } => {
2034                let bin = usize::from(bin_at(*feature_index));
2035                if bin > *max_bin_index {
2036                    return fallback_probabilities;
2037                }
2038                let child_index = child_lookup[bin];
2039                if child_index == usize::MAX {
2040                    return fallback_probabilities;
2041                }
2042                node_index = child_index;
2043            }
2044        }
2045    }
2046}
2047
2048#[inline(always)]
2049fn predict_binary_classifier_probabilities_row<F>(
2050    nodes: &[OptimizedBinaryClassifierNode],
2051    bin_at: F,
2052) -> &[f64]
2053where
2054    F: Fn(usize) -> u16,
2055{
2056    let mut node_index = 0usize;
2057    loop {
2058        match &nodes[node_index] {
2059            OptimizedBinaryClassifierNode::Leaf(value) => return value,
2060            OptimizedBinaryClassifierNode::Branch {
2061                feature_index,
2062                threshold_bin,
2063                jump_index,
2064                jump_if_greater,
2065            } => {
2066                let go_right = bin_at(*feature_index) > *threshold_bin;
2067                node_index = if go_right == *jump_if_greater {
2068                    *jump_index
2069                } else {
2070                    node_index + 1
2071                };
2072            }
2073        }
2074    }
2075}
2076
2077#[inline(always)]
2078fn predict_binary_regressor_row<F>(nodes: &[OptimizedBinaryRegressorNode], bin_at: F) -> f64
2079where
2080    F: Fn(usize) -> u16,
2081{
2082    let mut node_index = 0usize;
2083    loop {
2084        match &nodes[node_index] {
2085            OptimizedBinaryRegressorNode::Leaf(value) => return *value,
2086            OptimizedBinaryRegressorNode::Branch {
2087                feature_index,
2088                threshold_bin,
2089                jump_index,
2090                jump_if_greater,
2091            } => {
2092                let go_right = bin_at(*feature_index) > *threshold_bin;
2093                node_index = if go_right == *jump_if_greater {
2094                    *jump_index
2095                } else {
2096                    node_index + 1
2097                };
2098            }
2099        }
2100    }
2101}
2102
2103fn predict_binary_classifier_probabilities_column_major_matrix(
2104    nodes: &[OptimizedBinaryClassifierNode],
2105    matrix: &ColumnMajorBinnedMatrix,
2106    _executor: &InferenceExecutor,
2107) -> Vec<Vec<f64>> {
2108    (0..matrix.n_rows)
2109        .map(|row_index| {
2110            predict_binary_classifier_probabilities_row(nodes, |feature_index| {
2111                matrix.column(feature_index).value_at(row_index)
2112            })
2113            .to_vec()
2114        })
2115        .collect()
2116}
2117
2118fn predict_binary_regressor_column_major_matrix(
2119    nodes: &[OptimizedBinaryRegressorNode],
2120    matrix: &ColumnMajorBinnedMatrix,
2121    executor: &InferenceExecutor,
2122) -> Vec<f64> {
2123    let mut outputs = vec![0.0; matrix.n_rows];
2124    executor.fill_chunks(
2125        &mut outputs,
2126        STANDARD_BATCH_INFERENCE_CHUNK_ROWS,
2127        |start_row, chunk| predict_binary_regressor_chunk(nodes, matrix, start_row, chunk),
2128    );
2129    outputs
2130}
2131
/// Predicts a chunk of rows for a flattened binary regression tree.
///
/// Instead of walking the tree once per row, this walks the tree once and
/// partitions the chunk's rows at each branch: `row_indices` is reordered in
/// place so every pending node covers a contiguous `[start, end)` span of
/// row offsets, keeping column reads clustered per node.
fn predict_binary_regressor_chunk(
    nodes: &[OptimizedBinaryRegressorNode],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) {
    // Row offsets within this chunk; reordered as branches split the set.
    let mut row_indices: Vec<usize> = (0..output.len()).collect();
    // Work stack of (node index, span start, span end) still to process.
    let mut stack = vec![(0usize, 0usize, output.len())];

    while let Some((node_index, start, end)) = stack.pop() {
        match &nodes[node_index] {
            OptimizedBinaryRegressorNode::Leaf(value) => {
                // Every row in this span reached the same leaf.
                for position in start..end {
                    output[row_indices[position]] = *value;
                }
            }
            OptimizedBinaryRegressorNode::Branch {
                feature_index,
                threshold_bin,
                jump_index,
                jump_if_greater,
            } => {
                let fallthrough_index = node_index + 1;
                // Degenerate branch whose jump target is the fallthrough slot:
                // all rows continue to the same node, no partition needed.
                if *jump_index == fallthrough_index {
                    stack.push((fallthrough_index, start, end));
                    continue;
                }

                // Partition the span: rows taking the jump edge are swapped to
                // the tail, fallthrough rows stay at the front.
                let column = matrix.column(*feature_index);
                let mut partition = start;
                let mut jump_start = end;
                match column {
                    // Fast path: compare raw u8 bins without per-value widening.
                    CompactBinnedColumn::U8(values) if *threshold_bin <= u16::from(u8::MAX) => {
                        let threshold = *threshold_bin as u8;
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = values[start_row + row_offset] > threshold;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                    _ => {
                        while partition < jump_start {
                            let row_offset = row_indices[partition];
                            let go_right = column.value_at(start_row + row_offset) > *threshold_bin;
                            let goes_jump = go_right == *jump_if_greater;
                            if goes_jump {
                                jump_start -= 1;
                                row_indices.swap(partition, jump_start);
                            } else {
                                partition += 1;
                            }
                        }
                    }
                }

                // Only schedule non-empty spans for the two children.
                if jump_start < end {
                    stack.push((*jump_index, jump_start, end));
                }
                if start < jump_start {
                    stack.push((fallthrough_index, start, jump_start));
                }
            }
        }
    }
}
2203
/// Evaluates an oblivious (symmetric) tree for one row.
///
/// Every level contributes one bit (1 = bin exceeded the level's threshold),
/// so the traversed path spells out the leaf index directly.
#[inline(always)]
fn predict_oblivious_row<F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    bin_at: F,
) -> f64
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature, &threshold)| {
            (index << 1) | usize::from(bin_at(feature) > threshold)
        });
    leaf_values[leaf_index]
}
2221
/// Evaluates an oblivious (symmetric) classifier tree for one row and
/// returns the reached leaf's probability slice.
///
/// Every level contributes one bit (1 = bin exceeded the level's threshold),
/// so the traversed path spells out the leaf index directly.
#[inline(always)]
fn predict_oblivious_probabilities_row<'a, F>(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &'a [Vec<f64>],
    bin_at: F,
) -> &'a [f64]
where
    F: Fn(usize) -> u16,
{
    let leaf_index = feature_indices
        .iter()
        .zip(threshold_bins)
        .fold(0usize, |index, (&feature, &threshold)| {
            (index << 1) | usize::from(bin_at(feature) > threshold)
        });
    leaf_values[leaf_index].as_slice()
}
2239
/// Converts per-class sample counts into a probability distribution.
///
/// An all-zero (or empty) count vector yields all zeros rather than NaNs.
fn normalized_probabilities_from_counts(class_counts: &[usize]) -> Vec<f64> {
    let total: usize = class_counts.iter().sum();
    if total == 0 {
        vec![0.0; class_counts.len()]
    } else {
        let denominator = total as f64;
        class_counts
            .iter()
            .map(|&count| count as f64 / denominator)
            .collect()
    }
}
2251
/// Maps a probability row to its class label.
///
/// Chooses the index of the largest probability under `f64::total_cmp`
/// (so NaNs order deterministically); ties keep the earliest class index.
///
/// Panics if `probabilities` is empty.
fn class_label_from_probabilities(probabilities: &[f64], class_labels: &[f64]) -> f64 {
    assert!(
        !probabilities.is_empty(),
        "classification probability row is non-empty"
    );
    let mut best_index = 0usize;
    for (index, probability) in probabilities.iter().enumerate().skip(1) {
        // Strictly-greater comparison keeps the earliest index on ties.
        if probability.total_cmp(&probabilities[best_index]).is_gt() {
            best_index = index;
        }
    }
    class_labels[best_index]
}
2265
/// Numerically stable logistic function.
///
/// Branches on the sign so `exp` only ever sees a non-positive argument,
/// which avoids overflow for large |value|.
#[inline(always)]
fn sigmoid(value: f64) -> f64 {
    let exp_neg_abs = (-value.abs()).exp();
    if value >= 0.0 {
        1.0 / (1.0 + exp_neg_abs)
    } else {
        exp_neg_abs / (1.0 + exp_neg_abs)
    }
}
2276
2277fn classifier_nodes_are_binary_only(nodes: &[tree::classifier::TreeNode]) -> bool {
2278    nodes.iter().all(|node| {
2279        matches!(
2280            node,
2281            tree::classifier::TreeNode::Leaf { .. }
2282                | tree::classifier::TreeNode::BinarySplit { .. }
2283        )
2284    })
2285}
2286
2287fn classifier_node_sample_count(nodes: &[tree::classifier::TreeNode], node_index: usize) -> usize {
2288    match &nodes[node_index] {
2289        tree::classifier::TreeNode::Leaf { sample_count, .. }
2290        | tree::classifier::TreeNode::BinarySplit { sample_count, .. }
2291        | tree::classifier::TreeNode::MultiwaySplit { sample_count, .. } => *sample_count,
2292    }
2293}
2294
2295fn build_binary_classifier_layout(
2296    nodes: &[tree::classifier::TreeNode],
2297    root: usize,
2298    _class_labels: &[f64],
2299) -> Vec<OptimizedBinaryClassifierNode> {
2300    let mut layout = Vec::with_capacity(nodes.len());
2301    append_binary_classifier_node(nodes, root, &mut layout);
2302    layout
2303}
2304
/// Appends the subtree rooted at `node_index` to the flat classifier layout
/// and returns the slot the node was written to.
///
/// Layout invariant: the "fallthrough" child is emitted immediately after
/// its parent (`current_index + 1`) so the common path is a straight-line
/// scan of the node array; the other child is reached via an explicit
/// `jump_index`. The child that saw more training samples becomes the
/// fallthrough, and `jump_if_greater` records which comparison outcome takes
/// the jump edge.
fn append_binary_classifier_node(
    nodes: &[tree::classifier::TreeNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryClassifierNode>,
) -> usize {
    // Reserve this node's slot before recursing so children land after it;
    // the placeholder leaf is overwritten below.
    let current_index = layout.len();
    layout.push(OptimizedBinaryClassifierNode::Leaf(Vec::new()));

    match &nodes[node_index] {
        tree::classifier::TreeNode::Leaf { class_counts, .. } => {
            layout[current_index] = OptimizedBinaryClassifierNode::Leaf(
                normalized_probabilities_from_counts(class_counts),
            );
        }
        tree::classifier::TreeNode::BinarySplit {
            feature_index,
            threshold_bin,
            left_child,
            right_child,
            ..
        } => {
            // Pick the likelier child (by sample count) as the fallthrough.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = classifier_node_sample_count(nodes, *left_child);
                let right_count = classifier_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index = append_binary_classifier_node(nodes, fallthrough_child, layout);
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // A degenerate split with identical children reuses the
            // fallthrough slot instead of emitting the subtree twice.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_classifier_node(nodes, jump_child, layout)
            };

            layout[current_index] = OptimizedBinaryClassifierNode::Branch {
                feature_index: *feature_index,
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
            };
        }
        tree::classifier::TreeNode::MultiwaySplit { .. } => {
            unreachable!("multiway nodes are filtered out before binary layout construction");
        }
    }

    current_index
}
2360
2361fn regressor_node_sample_count(
2362    nodes: &[tree::regressor::RegressionNode],
2363    node_index: usize,
2364) -> usize {
2365    match &nodes[node_index] {
2366        tree::regressor::RegressionNode::Leaf { sample_count, .. }
2367        | tree::regressor::RegressionNode::BinarySplit { sample_count, .. } => *sample_count,
2368    }
2369}
2370
2371fn build_binary_regressor_layout(
2372    nodes: &[tree::regressor::RegressionNode],
2373    root: usize,
2374) -> Vec<OptimizedBinaryRegressorNode> {
2375    let mut layout = Vec::with_capacity(nodes.len());
2376    append_binary_regressor_node(nodes, root, &mut layout);
2377    layout
2378}
2379
/// Appends the subtree rooted at `node_index` to the flat regressor layout
/// and returns the slot the node was written to.
///
/// Layout invariant mirrors the classifier builder: the "fallthrough" child
/// is emitted immediately after its parent (`current_index + 1`); the other
/// child is reached via `jump_index`. The child with more training samples
/// becomes the fallthrough, and `jump_if_greater` records which comparison
/// outcome takes the jump edge.
fn append_binary_regressor_node(
    nodes: &[tree::regressor::RegressionNode],
    node_index: usize,
    layout: &mut Vec<OptimizedBinaryRegressorNode>,
) -> usize {
    // Reserve this node's slot before recursing so children land after it;
    // the placeholder leaf is overwritten below.
    let current_index = layout.len();
    layout.push(OptimizedBinaryRegressorNode::Leaf(0.0));

    match &nodes[node_index] {
        tree::regressor::RegressionNode::Leaf { value, .. } => {
            layout[current_index] = OptimizedBinaryRegressorNode::Leaf(*value);
        }
        tree::regressor::RegressionNode::BinarySplit {
            feature_index,
            threshold_bin,
            left_child,
            right_child,
            ..
        } => {
            // Pick the likelier child (by sample count) as the fallthrough.
            let (fallthrough_child, jump_child, jump_if_greater) = if left_child == right_child {
                (*left_child, *left_child, true)
            } else {
                let left_count = regressor_node_sample_count(nodes, *left_child);
                let right_count = regressor_node_sample_count(nodes, *right_child);
                if left_count >= right_count {
                    (*left_child, *right_child, true)
                } else {
                    (*right_child, *left_child, false)
                }
            };

            let fallthrough_index = append_binary_regressor_node(nodes, fallthrough_child, layout);
            debug_assert_eq!(fallthrough_index, current_index + 1);
            // A degenerate split with identical children reuses the
            // fallthrough slot instead of emitting the subtree twice.
            let jump_index = if jump_child == fallthrough_child {
                fallthrough_index
            } else {
                append_binary_regressor_node(nodes, jump_child, layout)
            };

            layout[current_index] = OptimizedBinaryRegressorNode::Branch {
                feature_index: *feature_index,
                threshold_bin: *threshold_bin,
                jump_index,
                jump_if_greater,
            };
        }
    }

    current_index
}
2430
2431fn predict_oblivious_column_major_matrix(
2432    feature_indices: &[usize],
2433    threshold_bins: &[u16],
2434    leaf_values: &[f64],
2435    matrix: &ColumnMajorBinnedMatrix,
2436    executor: &InferenceExecutor,
2437) -> Vec<f64> {
2438    let mut outputs = vec![0.0; matrix.n_rows];
2439    executor.fill_chunks(
2440        &mut outputs,
2441        PARALLEL_INFERENCE_CHUNK_ROWS,
2442        |start_row, chunk| {
2443            predict_oblivious_chunk(
2444                feature_indices,
2445                threshold_bins,
2446                leaf_values,
2447                matrix,
2448                start_row,
2449                chunk,
2450            )
2451        },
2452    );
2453    outputs
2454}
2455
2456fn predict_oblivious_probabilities_column_major_matrix(
2457    feature_indices: &[usize],
2458    threshold_bins: &[u16],
2459    leaf_values: &[Vec<f64>],
2460    matrix: &ColumnMajorBinnedMatrix,
2461    _executor: &InferenceExecutor,
2462) -> Vec<Vec<f64>> {
2463    (0..matrix.n_rows)
2464        .map(|row_index| {
2465            predict_oblivious_probabilities_row(
2466                feature_indices,
2467                threshold_bins,
2468                leaf_values,
2469                |feature_index| matrix.column(feature_index).value_at(row_index),
2470            )
2471            .to_vec()
2472        })
2473        .collect()
2474}
2475
2476fn predict_oblivious_chunk(
2477    feature_indices: &[usize],
2478    threshold_bins: &[u16],
2479    leaf_values: &[f64],
2480    matrix: &ColumnMajorBinnedMatrix,
2481    start_row: usize,
2482    output: &mut [f64],
2483) {
2484    let processed = simd_predict_oblivious_chunk(
2485        feature_indices,
2486        threshold_bins,
2487        leaf_values,
2488        matrix,
2489        start_row,
2490        output,
2491    );
2492
2493    for (offset, out) in output.iter_mut().enumerate().skip(processed) {
2494        let row_index = start_row + offset;
2495        *out = predict_oblivious_row(
2496            feature_indices,
2497            threshold_bins,
2498            leaf_values,
2499            |feature_index| matrix.column(feature_index).value_at(row_index),
2500        );
2501    }
2502}
2503
/// SIMD kernel for oblivious-tree prediction over a chunk of rows.
///
/// Processes rows in groups of `OBLIVIOUS_SIMD_LANES` (8): for each tree
/// level it loads eight consecutive bin values, compares them against the
/// level's threshold, and shifts the resulting 0/1 bits into eight parallel
/// leaf indices. Returns how many rows at the front of `output` were
/// handled; the caller computes the scalar remainder.
fn simd_predict_oblivious_chunk(
    feature_indices: &[usize],
    threshold_bins: &[u16],
    leaf_values: &[f64],
    matrix: &ColumnMajorBinnedMatrix,
    start_row: usize,
    output: &mut [f64],
) -> usize {
    let mut processed = 0usize;
    let ones = u32x8::splat(1);

    while processed + OBLIVIOUS_SIMD_LANES <= output.len() {
        let base_row = start_row + processed;
        let mut leaf_indices = u32x8::splat(0);

        for (&feature_index, &threshold_bin) in feature_indices.iter().zip(threshold_bins) {
            let column = matrix.column(feature_index);
            // Widen the eight bin values for this level to u32 lanes; u8-backed
            // columns get a dedicated path, otherwise read as u16.
            let bins = if let Some(lanes) = column.slice_u8(base_row, OBLIVIOUS_SIMD_LANES) {
                let lanes: [u8; OBLIVIOUS_SIMD_LANES] = lanes
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::new([
                    u32::from(lanes[0]),
                    u32::from(lanes[1]),
                    u32::from(lanes[2]),
                    u32::from(lanes[3]),
                    u32::from(lanes[4]),
                    u32::from(lanes[5]),
                    u32::from(lanes[6]),
                    u32::from(lanes[7]),
                ])
            } else {
                let lanes: [u16; OBLIVIOUS_SIMD_LANES] = column
                    .slice_u16(base_row, OBLIVIOUS_SIMD_LANES)
                    .expect("column is u16 when not u8")
                    .try_into()
                    .expect("lane width matches the fixed SIMD width");
                u32x8::from(u16x8::new(lanes))
            };
            let threshold = u32x8::splat(u32::from(threshold_bin));
            // cmp_gt yields an all-ones lane mask; `& ones` reduces it to the
            // single 0/1 bit appended to each lane's leaf index.
            let bit = bins.cmp_gt(threshold) & ones;
            leaf_indices = (leaf_indices << 1) | bit;
        }

        // Scatter the eight leaf lookups back to scalar outputs.
        let lane_indices = leaf_indices.to_array();
        for lane in 0..OBLIVIOUS_SIMD_LANES {
            output[processed + lane] =
                leaf_values[usize::try_from(lane_indices[lane]).expect("leaf index fits usize")];
        }
        processed += OBLIVIOUS_SIMD_LANES;
    }

    processed
}
2558
/// Trains a model of the kind selected by `config` on `train_set`.
///
/// This is the crate's single public training entry point; it delegates
/// directly to the internal `training` module, which dispatches on the
/// configured algorithm.
///
/// # Errors
/// Returns whatever `TrainError` the underlying trainer reports.
pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
    training::train(train_set, config)
}
2562
2563fn resolve_inference_thread_count(physical_cores: Option<usize>) -> Result<usize, OptimizeError> {
2564    let available = num_cpus::get_physical().max(1);
2565    let requested = physical_cores.unwrap_or(available);
2566
2567    if requested == 0 {
2568        return Err(OptimizeError::InvalidPhysicalCoreCount {
2569            requested,
2570            available,
2571        });
2572    }
2573
2574    Ok(requested.min(available))
2575}
2576
impl Model {
    /// Predicts one value per row of `table` (class label or regression
    /// value), dispatching to whichever model variant this enum wraps.
    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
        match self {
            Model::DecisionTreeClassifier(model) => model.predict_table(table),
            Model::DecisionTreeRegressor(model) => model.predict_table(table),
            Model::RandomForest(model) => model.predict_table(table),
            Model::GradientBoostedTrees(model) => model.predict_table(table),
        }
    }

    /// Predicts from row-major input, first wrapping it in an
    /// `InferenceTable` built with this model's feature preprocessing.
    pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
        let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
        Ok(self.predict_table(&table))
    }

    /// Predicts per-class probabilities for every row of `table`.
    ///
    /// # Errors
    /// Regressor models have no class probabilities and return
    /// `PredictError::ProbabilityPredictionRequiresClassification`.
    pub fn predict_proba_table(
        &self,
        table: &dyn TableAccess,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        match self {
            Model::DecisionTreeClassifier(model) => Ok(model.predict_proba_table(table)),
            Model::RandomForest(model) => model.predict_proba_table(table),
            Model::GradientBoostedTrees(model) => model.predict_proba_table(table),
            Model::DecisionTreeRegressor(_) => {
                Err(PredictError::ProbabilityPredictionRequiresClassification)
            }
        }
    }

    /// Probability counterpart of [`Model::predict_rows`].
    pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
        self.predict_proba_table(&table)
    }

    /// Predicts from column-major input keyed by feature name.
    pub fn predict_named_columns(
        &self,
        columns: BTreeMap<String, Vec<f64>>,
    ) -> Result<Vec<f64>, PredictError> {
        let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
        Ok(self.predict_table(&table))
    }

    /// Probability counterpart of [`Model::predict_named_columns`].
    pub fn predict_proba_named_columns(
        &self,
        columns: BTreeMap<String, Vec<f64>>,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
        self.predict_proba_table(&table)
    }

    /// Predicts from a sparse binary encoding of the feature matrix.
    ///
    /// NOTE(review): `columns` presumably holds, per feature, the row indices
    /// where that feature is set — confirm against
    /// `InferenceTable::from_sparse_binary_columns`.
    pub fn predict_sparse_binary_columns(
        &self,
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
    ) -> Result<Vec<f64>, PredictError> {
        let table = InferenceTable::from_sparse_binary_columns(
            n_rows,
            n_features,
            columns,
            self.feature_preprocessing(),
        )?;
        Ok(self.predict_table(&table))
    }

    /// Probability counterpart of [`Model::predict_sparse_binary_columns`].
    pub fn predict_proba_sparse_binary_columns(
        &self,
        n_rows: usize,
        n_features: usize,
        columns: Vec<Vec<usize>>,
    ) -> Result<Vec<Vec<f64>>, PredictError> {
        let table = InferenceTable::from_sparse_binary_columns(
            n_rows,
            n_features,
            columns,
            self.feature_preprocessing(),
        )?;
        self.predict_proba_table(&table)
    }

    /// Predicts from an eager polars `DataFrame` by converting each column to
    /// an owned `f64` vector first.
    #[cfg(feature = "polars")]
    pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
        let columns = polars_named_columns(df)?;
        self.predict_named_columns(columns)
    }

    /// Predicts from a `LazyFrame` by collecting and scoring it in slices of
    /// `LAZYFRAME_PREDICT_BATCH_ROWS` rows, so the whole frame is never
    /// materialized at once.
    #[cfg(feature = "polars")]
    pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
        let mut predictions = Vec::new();
        let mut offset = 0i64;
        loop {
            let batch = lf
                .clone()
                .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
                .collect()?;
            let height = batch.height();
            if height == 0 {
                break;
            }
            predictions.extend(self.predict_polars_dataframe(&batch)?);
            // A short batch means the frame is exhausted; avoid one extra
            // empty collect.
            if height < LAZYFRAME_PREDICT_BATCH_ROWS {
                break;
            }
            offset += height as i64;
        }
        Ok(predictions)
    }

    /// Which training algorithm family produced this model.
    pub fn algorithm(&self) -> TrainAlgorithm {
        match self {
            Model::DecisionTreeClassifier(_) | Model::DecisionTreeRegressor(_) => {
                TrainAlgorithm::Dt
            }
            Model::RandomForest(_) => TrainAlgorithm::Rf,
            Model::GradientBoostedTrees(_) => TrainAlgorithm::Gbm,
        }
    }

    /// Whether this model solves classification or regression. Single trees
    /// are determined by their variant; ensembles report their own task.
    pub fn task(&self) -> Task {
        match self {
            Model::DecisionTreeRegressor(_) => Task::Regression,
            Model::DecisionTreeClassifier(_) => Task::Classification,
            Model::RandomForest(model) => model.task(),
            Model::GradientBoostedTrees(model) => model.task(),
        }
    }

    /// The split criterion the model was trained with.
    pub fn criterion(&self) -> Criterion {
        match self {
            Model::DecisionTreeClassifier(model) => model.criterion(),
            Model::DecisionTreeRegressor(model) => model.criterion(),
            Model::RandomForest(model) => model.criterion(),
            Model::GradientBoostedTrees(model) => model.criterion(),
        }
    }

    /// The concrete tree-construction algorithm, mapped onto the public
    /// `TreeType` enum.
    pub fn tree_type(&self) -> TreeType {
        match self {
            Model::DecisionTreeClassifier(model) => match model.algorithm() {
                DecisionTreeAlgorithm::Id3 => TreeType::Id3,
                DecisionTreeAlgorithm::C45 => TreeType::C45,
                DecisionTreeAlgorithm::Cart => TreeType::Cart,
                DecisionTreeAlgorithm::Randomized => TreeType::Randomized,
                DecisionTreeAlgorithm::Oblivious => TreeType::Oblivious,
            },
            Model::DecisionTreeRegressor(model) => match model.algorithm() {
                RegressionTreeAlgorithm::Cart => TreeType::Cart,
                RegressionTreeAlgorithm::Randomized => TreeType::Randomized,
                RegressionTreeAlgorithm::Oblivious => TreeType::Oblivious,
            },
            Model::RandomForest(model) => model.tree_type(),
            Model::GradientBoostedTrees(model) => model.tree_type(),
        }
    }

    /// Always `None`: no current variant stores a mean value. The exhaustive
    /// match keeps this honest if a variant is added later.
    pub fn mean_value(&self) -> Option<f64> {
        match self {
            Model::DecisionTreeClassifier(_)
            | Model::DecisionTreeRegressor(_)
            | Model::RandomForest(_)
            | Model::GradientBoostedTrees(_) => None,
        }
    }

    /// Number of canary features recorded at training time.
    pub fn canaries(&self) -> usize {
        self.training_metadata().canaries
    }

    /// Configured depth limit, if any.
    pub fn max_depth(&self) -> Option<usize> {
        self.training_metadata().max_depth
    }

    /// Configured minimum samples required to split a node, if any.
    pub fn min_samples_split(&self) -> Option<usize> {
        self.training_metadata().min_samples_split
    }

    /// Configured minimum samples required in a leaf, if any.
    pub fn min_samples_leaf(&self) -> Option<usize> {
        self.training_metadata().min_samples_leaf
    }

    /// Configured ensemble size, if this model is an ensemble.
    pub fn n_trees(&self) -> Option<usize> {
        self.training_metadata().n_trees
    }

    /// Configured per-split feature subsample size, if any.
    pub fn max_features(&self) -> Option<usize> {
        self.training_metadata().max_features
    }

    /// RNG seed used during training, if one was recorded.
    pub fn seed(&self) -> Option<u64> {
        self.training_metadata().seed
    }

    /// Whether out-of-bag scoring was requested at training time.
    pub fn compute_oob(&self) -> bool {
        self.training_metadata().compute_oob
    }

    /// Out-of-bag score, if it was computed.
    pub fn oob_score(&self) -> Option<f64> {
        self.training_metadata().oob_score
    }

    /// Boosting learning rate, if applicable.
    pub fn learning_rate(&self) -> Option<f64> {
        self.training_metadata().learning_rate
    }

    /// Whether bootstrap sampling was used; unrecorded metadata counts as
    /// "no bootstrap".
    pub fn bootstrap(&self) -> bool {
        self.training_metadata().bootstrap.unwrap_or(false)
    }

    /// GOSS-style top-gradient sampling fraction, if recorded.
    pub fn top_gradient_fraction(&self) -> Option<f64> {
        self.training_metadata().top_gradient_fraction
    }

    /// GOSS-style remaining-gradient sampling fraction, if recorded.
    pub fn other_gradient_fraction(&self) -> Option<f64> {
        self.training_metadata().other_gradient_fraction
    }

    /// Number of trees in the model.
    ///
    /// NOTE(review): this converts the whole model to IR just to count —
    /// cheap for small models, but worth knowing for very large ensembles.
    pub fn tree_count(&self) -> usize {
        self.to_ir().model.trees.len()
    }

    /// Shape statistics (node counts, path lengths) for one tree.
    pub fn tree_structure(
        &self,
        tree_index: usize,
    ) -> Result<TreeStructureSummary, IntrospectionError> {
        tree_structure_summary(self.tree_definition(tree_index)?)
    }

    /// Distribution statistics over one tree's leaf prediction values.
    pub fn tree_prediction_stats(
        &self,
        tree_index: usize,
    ) -> Result<PredictionValueStats, IntrospectionError> {
        prediction_value_stats(self.tree_definition(tree_index)?)
    }

    /// Returns the `node_index`-th node (by position in the IR node list) of
    /// a node-structured tree.
    ///
    /// # Errors
    /// Fails if the tree is oblivious, or either index is out of bounds.
    pub fn tree_node(
        &self,
        tree_index: usize,
        node_index: usize,
    ) -> Result<ir::NodeTreeNode, IntrospectionError> {
        match self.tree_definition(tree_index)? {
            ir::TreeDefinition::NodeTree { nodes, .. } => {
                let available = nodes.len();
                nodes
                    .into_iter()
                    .nth(node_index)
                    .ok_or(IntrospectionError::NodeIndexOutOfBounds {
                        requested: node_index,
                        available,
                    })
            }
            ir::TreeDefinition::ObliviousLevels { .. } => Err(IntrospectionError::NotANodeTree),
        }
    }

    /// Returns the `level_index`-th level of an oblivious tree.
    ///
    /// # Errors
    /// Fails if the tree is node-structured, or either index is out of bounds.
    pub fn tree_level(
        &self,
        tree_index: usize,
        level_index: usize,
    ) -> Result<ir::ObliviousLevel, IntrospectionError> {
        match self.tree_definition(tree_index)? {
            ir::TreeDefinition::ObliviousLevels { levels, .. } => {
                let available = levels.len();
                levels.into_iter().nth(level_index).ok_or(
                    IntrospectionError::LevelIndexOutOfBounds {
                        requested: level_index,
                        available,
                    },
                )
            }
            ir::TreeDefinition::NodeTree { .. } => Err(IntrospectionError::NotAnObliviousTree),
        }
    }

    /// Returns the `leaf_index`-th leaf of one tree.
    ///
    /// Oblivious trees index into their leaf table directly; node trees first
    /// collect their `Leaf` nodes (in IR node order) and index into that list,
    /// re-packing each as an `IndexedLeaf` keyed by the original `node_id`.
    pub fn tree_leaf(
        &self,
        tree_index: usize,
        leaf_index: usize,
    ) -> Result<ir::IndexedLeaf, IntrospectionError> {
        match self.tree_definition(tree_index)? {
            ir::TreeDefinition::ObliviousLevels { leaves, .. } => {
                let available = leaves.len();
                leaves
                    .into_iter()
                    .nth(leaf_index)
                    .ok_or(IntrospectionError::LeafIndexOutOfBounds {
                        requested: leaf_index,
                        available,
                    })
            }
            ir::TreeDefinition::NodeTree { nodes, .. } => {
                let leaves = nodes
                    .into_iter()
                    .filter_map(|node| match node {
                        ir::NodeTreeNode::Leaf {
                            node_id,
                            leaf,
                            stats,
                            ..
                        } => Some(ir::IndexedLeaf {
                            leaf_index: node_id,
                            leaf,
                            stats: ir::NodeStats {
                                sample_count: stats.sample_count,
                                impurity: stats.impurity,
                                gain: stats.gain,
                                class_counts: stats.class_counts,
                                variance: stats.variance,
                            },
                        }),
                        _ => None,
                    })
                    .collect::<Vec<_>>();
                let available = leaves.len();
                leaves
                    .into_iter()
                    .nth(leaf_index)
                    .ok_or(IntrospectionError::LeafIndexOutOfBounds {
                        requested: leaf_index,
                        available,
                    })
            }
        }
    }

    /// Converts this model to its serializable IR form.
    pub fn to_ir(&self) -> ModelPackageIr {
        ir::model_to_ir(self)
    }

    /// Serializes the IR form as compact JSON.
    pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string(&self.to_ir())
    }

    /// Serializes the IR form as pretty-printed JSON.
    pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(&self.to_ir())
    }

    /// Alias for [`Model::to_ir_json`].
    pub fn serialize(&self) -> Result<String, serde_json::Error> {
        self.to_ir_json()
    }

    /// Alias for [`Model::to_ir_json_pretty`].
    pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
        self.to_ir_json_pretty()
    }

    /// Builds an inference-optimized copy of this model, optionally capped to
    /// `physical_cores` threads.
    pub fn optimize_inference(
        &self,
        physical_cores: Option<usize>,
    ) -> Result<OptimizedModel, OptimizeError> {
        OptimizedModel::new(self.clone(), physical_cores)
    }

    /// JSON schema of the IR serialization format.
    pub fn json_schema() -> schemars::schema::RootSchema {
        ModelPackageIr::json_schema()
    }

    /// JSON schema rendered as compact JSON.
    pub fn json_schema_json() -> Result<String, IrError> {
        ModelPackageIr::json_schema_json()
    }

    /// JSON schema rendered as pretty-printed JSON.
    pub fn json_schema_json_pretty() -> Result<String, IrError> {
        ModelPackageIr::json_schema_json_pretty()
    }

    /// Reconstructs a model from JSON produced by [`Model::serialize`].
    ///
    /// # Errors
    /// `IrError::Json` on malformed JSON; other `IrError`s if the IR cannot
    /// be turned back into a model.
    pub fn deserialize(serialized: &str) -> Result<Self, IrError> {
        let ir: ModelPackageIr =
            serde_json::from_str(serialized).map_err(|err| IrError::Json(err.to_string()))?;
        ir::model_from_ir(ir)
    }

    /// Number of input features the model expects.
    pub(crate) fn num_features(&self) -> usize {
        match self {
            Model::DecisionTreeClassifier(model) => model.num_features(),
            Model::DecisionTreeRegressor(model) => model.num_features(),
            Model::RandomForest(model) => model.num_features(),
            Model::GradientBoostedTrees(model) => model.num_features(),
        }
    }

    /// Per-feature preprocessing recorded at training time, used when
    /// building `InferenceTable`s.
    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
        match self {
            Model::DecisionTreeClassifier(model) => model.feature_preprocessing(),
            Model::DecisionTreeRegressor(model) => model.feature_preprocessing(),
            Model::RandomForest(model) => model.feature_preprocessing(),
            Model::GradientBoostedTrees(model) => model.feature_preprocessing(),
        }
    }

    /// Class labels for classifiers; `None` for regressors (and for
    /// ensembles whose inner model reports none).
    pub(crate) fn class_labels(&self) -> Option<Vec<f64>> {
        match self {
            Model::DecisionTreeClassifier(model) => Some(model.class_labels().to_vec()),
            Model::RandomForest(model) => model.class_labels(),
            Model::GradientBoostedTrees(model) => model.class_labels(),
            Model::DecisionTreeRegressor(_) => None,
        }
    }

    /// Training-time hyperparameters and results, as recorded by the variant.
    pub(crate) fn training_metadata(&self) -> ir::TrainingMetadata {
        match self {
            Model::DecisionTreeClassifier(model) => model.training_metadata(),
            Model::DecisionTreeRegressor(model) => model.training_metadata(),
            Model::RandomForest(model) => model.training_metadata(),
            Model::GradientBoostedTrees(model) => model.training_metadata(),
        }
    }

    /// Fetches the `tree_index`-th tree's IR definition.
    ///
    /// Rebuilds the full IR on every call; introspection helpers above all
    /// funnel through here.
    fn tree_definition(&self, tree_index: usize) -> Result<ir::TreeDefinition, IntrospectionError> {
        let trees = self.to_ir().model.trees;
        let available = trees.len();
        trees
            .into_iter()
            .nth(tree_index)
            .ok_or(IntrospectionError::TreeIndexOutOfBounds {
                requested: tree_index,
                available,
            })
    }
}
2994
2995fn tree_structure_summary(
2996    tree: ir::TreeDefinition,
2997) -> Result<TreeStructureSummary, IntrospectionError> {
2998    match tree {
2999        ir::TreeDefinition::NodeTree {
3000            root_node_id,
3001            nodes,
3002            ..
3003        } => {
3004            let node_map = nodes
3005                .iter()
3006                .cloned()
3007                .map(|node| match &node {
3008                    ir::NodeTreeNode::Leaf { node_id, .. }
3009                    | ir::NodeTreeNode::BinaryBranch { node_id, .. }
3010                    | ir::NodeTreeNode::MultiwayBranch { node_id, .. } => (*node_id, node),
3011                })
3012                .collect::<BTreeMap<_, _>>();
3013            let mut leaf_depths = Vec::new();
3014            collect_leaf_depths(&node_map, root_node_id, &mut leaf_depths)?;
3015            let internal_node_count = nodes
3016                .iter()
3017                .filter(|node| !matches!(node, ir::NodeTreeNode::Leaf { .. }))
3018                .count();
3019            let leaf_count = leaf_depths.len();
3020            let shortest_path = *leaf_depths.iter().min().unwrap_or(&0);
3021            let longest_path = *leaf_depths.iter().max().unwrap_or(&0);
3022            let average_path = if leaf_depths.is_empty() {
3023                0.0
3024            } else {
3025                leaf_depths.iter().sum::<usize>() as f64 / leaf_depths.len() as f64
3026            };
3027            Ok(TreeStructureSummary {
3028                representation: "node_tree".to_string(),
3029                node_count: internal_node_count + leaf_count,
3030                internal_node_count,
3031                leaf_count,
3032                actual_depth: longest_path,
3033                shortest_path,
3034                longest_path,
3035                average_path,
3036            })
3037        }
3038        ir::TreeDefinition::ObliviousLevels { depth, leaves, .. } => Ok(TreeStructureSummary {
3039            representation: "oblivious_levels".to_string(),
3040            node_count: ((1usize << depth) - 1) + leaves.len(),
3041            internal_node_count: (1usize << depth) - 1,
3042            leaf_count: leaves.len(),
3043            actual_depth: depth,
3044            shortest_path: depth,
3045            longest_path: depth,
3046            average_path: depth as f64,
3047        }),
3048    }
3049}
3050
3051fn collect_leaf_depths(
3052    nodes: &BTreeMap<usize, ir::NodeTreeNode>,
3053    node_id: usize,
3054    output: &mut Vec<usize>,
3055) -> Result<(), IntrospectionError> {
3056    match nodes
3057        .get(&node_id)
3058        .ok_or(IntrospectionError::NodeIndexOutOfBounds {
3059            requested: node_id,
3060            available: nodes.len(),
3061        })? {
3062        ir::NodeTreeNode::Leaf { depth, .. } => output.push(*depth),
3063        ir::NodeTreeNode::BinaryBranch {
3064            depth: _, children, ..
3065        } => {
3066            collect_leaf_depths(nodes, children.left, output)?;
3067            collect_leaf_depths(nodes, children.right, output)?;
3068        }
3069        ir::NodeTreeNode::MultiwayBranch {
3070            depth,
3071            branches,
3072            unmatched_leaf: _,
3073            ..
3074        } => {
3075            output.push(depth + 1);
3076            for branch in branches {
3077                collect_leaf_depths(nodes, branch.child, output)?;
3078            }
3079        }
3080    }
3081    Ok(())
3082}
3083
3084fn prediction_value_stats(
3085    tree: ir::TreeDefinition,
3086) -> Result<PredictionValueStats, IntrospectionError> {
3087    let predictions = match tree {
3088        ir::TreeDefinition::NodeTree { nodes, .. } => nodes
3089            .into_iter()
3090            .flat_map(|node| match node {
3091                ir::NodeTreeNode::Leaf { leaf, .. } => vec![leaf_payload_value(&leaf)],
3092                ir::NodeTreeNode::MultiwayBranch { unmatched_leaf, .. } => {
3093                    vec![leaf_payload_value(&unmatched_leaf)]
3094                }
3095                ir::NodeTreeNode::BinaryBranch { .. } => Vec::new(),
3096            })
3097            .collect::<Vec<_>>(),
3098        ir::TreeDefinition::ObliviousLevels { leaves, .. } => leaves
3099            .into_iter()
3100            .map(|leaf| leaf_payload_value(&leaf.leaf))
3101            .collect::<Vec<_>>(),
3102    };
3103
3104    let count = predictions.len();
3105    let min = predictions
3106        .iter()
3107        .copied()
3108        .min_by(f64::total_cmp)
3109        .unwrap_or(0.0);
3110    let max = predictions
3111        .iter()
3112        .copied()
3113        .max_by(f64::total_cmp)
3114        .unwrap_or(0.0);
3115    let mean = if count == 0 {
3116        0.0
3117    } else {
3118        predictions.iter().sum::<f64>() / count as f64
3119    };
3120    let std_dev = if count == 0 {
3121        0.0
3122    } else {
3123        let variance = predictions
3124            .iter()
3125            .map(|value| (*value - mean).powi(2))
3126            .sum::<f64>()
3127            / count as f64;
3128        variance.sqrt()
3129    };
3130    let mut histogram = BTreeMap::<String, usize>::new();
3131    for prediction in &predictions {
3132        *histogram.entry(prediction.to_string()).or_insert(0) += 1;
3133    }
3134    let histogram = histogram
3135        .into_iter()
3136        .map(|(prediction, count)| PredictionHistogramEntry {
3137            prediction: prediction
3138                .parse::<f64>()
3139                .expect("histogram keys are numeric"),
3140            count,
3141        })
3142        .collect::<Vec<_>>();
3143
3144    Ok(PredictionValueStats {
3145        count,
3146        unique_count: histogram.len(),
3147        min,
3148        max,
3149        mean,
3150        std_dev,
3151        histogram,
3152    })
3153}
3154
3155fn leaf_payload_value(leaf: &ir::LeafPayload) -> f64 {
3156    match leaf {
3157        ir::LeafPayload::RegressionValue { value } => *value,
3158        ir::LeafPayload::ClassIndex { class_value, .. } => *class_value,
3159    }
3160}
3161
#[cfg(feature = "polars")]
/// Converts every dataframe column into an owned `name -> f64 values` pair,
/// stopping at the first column that fails conversion.
fn polars_named_columns(df: &DataFrame) -> Result<BTreeMap<String, Vec<f64>>, PredictError> {
    let mut columns = BTreeMap::new();
    for column in df.get_columns() {
        columns.insert(column.name().to_string(), polars_column_values(column)?);
    }
    Ok(columns)
}
3172
#[cfg(feature = "polars")]
/// Extracts one polars column as dense `f64` values.
///
/// Booleans map to 0.0/1.0, `Float64` is taken as-is, and other numeric
/// dtypes are widened via a cast to `Float64`. Any null cell or
/// non-numeric dtype is an error.
fn polars_column_values(column: &Column) -> Result<Vec<f64>, PredictError> {
    let name = column.name().to_string();
    let series = column.as_materialized_series();

    // Shared constructor for the "null cell at this row" error.
    let null_at = |row_index: usize| PredictError::NullValue {
        feature: name.clone(),
        row_index,
    };

    match series.dtype() {
        DataType::Boolean => series
            .bool()?
            .into_iter()
            .enumerate()
            .map(|(row_index, value)| match value {
                Some(value) => Ok(f64::from(u8::from(value))),
                None => Err(null_at(row_index)),
            })
            .collect(),
        // Already f64: only null-check each cell.
        DataType::Float64 => series
            .f64()?
            .into_iter()
            .enumerate()
            .map(|(row_index, value)| value.ok_or_else(|| null_at(row_index)))
            .collect(),
        // Remaining numeric dtypes go through a widening cast first.
        DataType::Float32
        | DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64 => series
            .cast(&DataType::Float64)?
            .f64()?
            .into_iter()
            .enumerate()
            .map(|(row_index, value)| value.ok_or_else(|| null_at(row_index)))
            .collect(),
        // Strings, temporals, nested types, etc. are not model features.
        dtype => Err(PredictError::UnsupportedFeatureType {
            feature: name.clone(),
            dtype: dtype.to_string(),
        }),
    }
}
3230
3231#[cfg(test)]
3232mod tests;