Skip to main content

forestfire_core/
model_api.rs

1use super::*;
2use std::collections::BTreeSet;
3
4/// Runtime-lowered model used for faster inference.
5///
6/// The optimized model keeps a copy of the source [`Model`] so it can preserve
7/// serialization and introspection behavior even after the runtime has been
8/// lowered into lookup-table-friendly structures.
9#[derive(Debug, Clone)]
10pub struct OptimizedModel {
11    pub(crate) source_model: Model,
12    pub(crate) runtime: OptimizedRuntime,
13    pub(crate) executor: InferenceExecutor,
14    pub(crate) feature_projection: Vec<usize>,
15}
16
17impl OptimizedModel {
18    pub(crate) fn new(
19        source_model: Model,
20        physical_cores: Option<usize>,
21        missing_features: Option<&[usize]>,
22    ) -> Result<Self, OptimizeError> {
23        let thread_count = resolve_inference_thread_count(physical_cores)?;
24        let feature_projection = build_feature_projection(&source_model);
25        let feature_index_map =
26            build_feature_index_map(source_model.num_features(), &feature_projection);
27        let missing_feature_set =
28            missing_features.map(|features| features.iter().copied().collect::<BTreeSet<_>>());
29        let runtime = OptimizedRuntime::from_model(
30            &source_model,
31            &feature_index_map,
32            missing_feature_set.as_ref(),
33        );
34        let executor = InferenceExecutor::new(thread_count)?;
35
36        Ok(Self {
37            source_model,
38            runtime,
39            executor,
40            feature_projection,
41        })
42    }
43
44    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
45        let projected = ProjectedTableView::new(table, &self.feature_projection);
46        if self.runtime.should_use_batch_matrix(table.n_rows()) {
47            let matrix = ColumnMajorBinnedMatrix::from_table_access_projected(
48                table,
49                &self.feature_projection,
50            );
51            return self.predict_column_major_binned_matrix(&matrix);
52        }
53
54        self.executor.predict_rows(projected.n_rows(), |row_index| {
55            self.runtime.predict_table_row(&projected, row_index)
56        })
57    }
58
59    pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
60        let table = InferenceTable::from_rows_projected(
61            rows,
62            self.source_model.feature_preprocessing(),
63            &self.feature_projection,
64        )?;
65        if self.runtime.should_use_batch_matrix(table.n_rows()) {
66            let matrix = table.to_column_major_binned_matrix();
67            Ok(self.predict_column_major_binned_matrix(&matrix))
68        } else {
69            Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
70                self.runtime.predict_table_row(&table, row_index)
71            }))
72        }
73    }
74
75    pub fn predict_named_columns(
76        &self,
77        columns: BTreeMap<String, Vec<f64>>,
78    ) -> Result<Vec<f64>, PredictError> {
79        let table = InferenceTable::from_named_columns_projected(
80            columns,
81            self.source_model.feature_preprocessing(),
82            &self.feature_projection,
83        )?;
84        if self.runtime.should_use_batch_matrix(table.n_rows()) {
85            let matrix = table.to_column_major_binned_matrix();
86            Ok(self.predict_column_major_binned_matrix(&matrix))
87        } else {
88            Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
89                self.runtime.predict_table_row(&table, row_index)
90            }))
91        }
92    }
93
94    pub fn predict_proba_table(
95        &self,
96        table: &dyn TableAccess,
97    ) -> Result<Vec<Vec<f64>>, PredictError> {
98        let projected = ProjectedTableView::new(table, &self.feature_projection);
99        self.runtime.predict_proba_table(&projected, &self.executor)
100    }
101
102    pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
103        let table = InferenceTable::from_rows_projected(
104            rows,
105            self.source_model.feature_preprocessing(),
106            &self.feature_projection,
107        )?;
108        self.runtime.predict_proba_table(&table, &self.executor)
109    }
110
111    pub fn predict_proba_named_columns(
112        &self,
113        columns: BTreeMap<String, Vec<f64>>,
114    ) -> Result<Vec<Vec<f64>>, PredictError> {
115        let table = InferenceTable::from_named_columns_projected(
116            columns,
117            self.source_model.feature_preprocessing(),
118            &self.feature_projection,
119        )?;
120        self.runtime.predict_proba_table(&table, &self.executor)
121    }
122
123    pub fn predict_proba_sparse_binary_columns(
124        &self,
125        n_rows: usize,
126        n_features: usize,
127        columns: Vec<Vec<usize>>,
128    ) -> Result<Vec<Vec<f64>>, PredictError> {
129        let table = InferenceTable::from_sparse_binary_columns_projected(
130            n_rows,
131            n_features,
132            columns,
133            self.source_model.feature_preprocessing(),
134            &self.feature_projection,
135        )?;
136        self.runtime.predict_proba_table(&table, &self.executor)
137    }
138
139    pub fn predict_sparse_binary_columns(
140        &self,
141        n_rows: usize,
142        n_features: usize,
143        columns: Vec<Vec<usize>>,
144    ) -> Result<Vec<f64>, PredictError> {
145        let table = InferenceTable::from_sparse_binary_columns_projected(
146            n_rows,
147            n_features,
148            columns,
149            self.source_model.feature_preprocessing(),
150            &self.feature_projection,
151        )?;
152        if self.runtime.should_use_batch_matrix(table.n_rows()) {
153            let matrix = table.to_column_major_binned_matrix();
154            Ok(self.predict_column_major_binned_matrix(&matrix))
155        } else {
156            Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
157                self.runtime.predict_table_row(&table, row_index)
158            }))
159        }
160    }
161
162    #[cfg(feature = "polars")]
163    pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
164        let columns = polars_named_columns(df)?;
165        self.predict_named_columns(columns)
166    }
167
168    #[cfg(feature = "polars")]
169    pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
170        let mut predictions = Vec::new();
171        let mut offset = 0i64;
172        loop {
173            let batch = lf
174                .clone()
175                .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
176                .collect()?;
177            let height = batch.height();
178            if height == 0 {
179                break;
180            }
181            predictions.extend(self.predict_polars_dataframe(&batch)?);
182            if height < LAZYFRAME_PREDICT_BATCH_ROWS {
183                break;
184            }
185            offset += height as i64;
186        }
187        Ok(predictions)
188    }
189
190    pub fn algorithm(&self) -> TrainAlgorithm {
191        self.source_model.algorithm()
192    }
193
194    pub fn task(&self) -> Task {
195        self.source_model.task()
196    }
197
198    pub fn criterion(&self) -> Criterion {
199        self.source_model.criterion()
200    }
201
202    pub fn tree_type(&self) -> TreeType {
203        self.source_model.tree_type()
204    }
205
206    pub fn mean_value(&self) -> Option<f64> {
207        self.source_model.mean_value()
208    }
209
210    pub fn canaries(&self) -> usize {
211        self.source_model.canaries()
212    }
213
214    pub fn max_depth(&self) -> Option<usize> {
215        self.source_model.max_depth()
216    }
217
218    pub fn min_samples_split(&self) -> Option<usize> {
219        self.source_model.min_samples_split()
220    }
221
222    pub fn min_samples_leaf(&self) -> Option<usize> {
223        self.source_model.min_samples_leaf()
224    }
225
226    pub fn n_trees(&self) -> Option<usize> {
227        self.source_model.n_trees()
228    }
229
230    pub fn max_features(&self) -> Option<usize> {
231        self.source_model.max_features()
232    }
233
234    pub fn seed(&self) -> Option<u64> {
235        self.source_model.seed()
236    }
237
238    pub fn compute_oob(&self) -> bool {
239        self.source_model.compute_oob()
240    }
241
242    pub fn oob_score(&self) -> Option<f64> {
243        self.source_model.oob_score()
244    }
245
246    pub fn learning_rate(&self) -> Option<f64> {
247        self.source_model.learning_rate()
248    }
249
250    pub fn bootstrap(&self) -> bool {
251        self.source_model.bootstrap()
252    }
253
254    pub fn top_gradient_fraction(&self) -> Option<f64> {
255        self.source_model.top_gradient_fraction()
256    }
257
258    pub fn other_gradient_fraction(&self) -> Option<f64> {
259        self.source_model.other_gradient_fraction()
260    }
261
262    pub fn tree_count(&self) -> usize {
263        self.source_model.tree_count()
264    }
265
266    pub fn used_feature_indices(&self) -> Vec<usize> {
267        self.feature_projection.clone()
268    }
269
270    pub fn used_feature_count(&self) -> usize {
271        self.feature_projection.len()
272    }
273
274    pub fn tree_structure(
275        &self,
276        tree_index: usize,
277    ) -> Result<TreeStructureSummary, IntrospectionError> {
278        self.source_model.tree_structure(tree_index)
279    }
280
281    pub fn tree_prediction_stats(
282        &self,
283        tree_index: usize,
284    ) -> Result<PredictionValueStats, IntrospectionError> {
285        self.source_model.tree_prediction_stats(tree_index)
286    }
287
288    pub fn tree_node(
289        &self,
290        tree_index: usize,
291        node_index: usize,
292    ) -> Result<ir::NodeTreeNode, IntrospectionError> {
293        self.source_model.tree_node(tree_index, node_index)
294    }
295
296    pub fn tree_level(
297        &self,
298        tree_index: usize,
299        level_index: usize,
300    ) -> Result<ir::ObliviousLevel, IntrospectionError> {
301        self.source_model.tree_level(tree_index, level_index)
302    }
303
304    pub fn tree_leaf(
305        &self,
306        tree_index: usize,
307        leaf_index: usize,
308    ) -> Result<ir::IndexedLeaf, IntrospectionError> {
309        self.source_model.tree_leaf(tree_index, leaf_index)
310    }
311
312    pub fn to_ir(&self) -> ModelPackageIr {
313        self.source_model.to_ir()
314    }
315
316    pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
317        self.source_model.to_ir_json()
318    }
319
320    pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
321        self.source_model.to_ir_json_pretty()
322    }
323
324    pub fn serialize(&self) -> Result<String, serde_json::Error> {
325        self.source_model.serialize()
326    }
327
328    pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
329        self.source_model.serialize_pretty()
330    }
331
332    pub(crate) fn predict_column_major_binned_matrix(
333        &self,
334        matrix: &ColumnMajorBinnedMatrix,
335    ) -> Vec<f64> {
336        self.runtime
337            .predict_column_major_matrix(matrix, &self.executor)
338    }
339}
340
341impl Model {
342    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
343        match self {
344            Model::DecisionTreeClassifier(model) => model.predict_table(table),
345            Model::DecisionTreeRegressor(model) => model.predict_table(table),
346            Model::RandomForest(model) => model.predict_table(table),
347            Model::GradientBoostedTrees(model) => model.predict_table(table),
348        }
349    }
350
351    pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
352        let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
353        Ok(self.predict_table(&table))
354    }
355
356    pub fn predict_proba_table(
357        &self,
358        table: &dyn TableAccess,
359    ) -> Result<Vec<Vec<f64>>, PredictError> {
360        match self {
361            Model::DecisionTreeClassifier(model) => Ok(model.predict_proba_table(table)),
362            Model::RandomForest(model) => model.predict_proba_table(table),
363            Model::GradientBoostedTrees(model) => model.predict_proba_table(table),
364            Model::DecisionTreeRegressor(_) => {
365                Err(PredictError::ProbabilityPredictionRequiresClassification)
366            }
367        }
368    }
369
370    pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
371        let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
372        self.predict_proba_table(&table)
373    }
374
375    pub fn predict_named_columns(
376        &self,
377        columns: BTreeMap<String, Vec<f64>>,
378    ) -> Result<Vec<f64>, PredictError> {
379        let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
380        Ok(self.predict_table(&table))
381    }
382
383    pub fn predict_proba_named_columns(
384        &self,
385        columns: BTreeMap<String, Vec<f64>>,
386    ) -> Result<Vec<Vec<f64>>, PredictError> {
387        let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
388        self.predict_proba_table(&table)
389    }
390
391    pub fn predict_sparse_binary_columns(
392        &self,
393        n_rows: usize,
394        n_features: usize,
395        columns: Vec<Vec<usize>>,
396    ) -> Result<Vec<f64>, PredictError> {
397        let table = InferenceTable::from_sparse_binary_columns(
398            n_rows,
399            n_features,
400            columns,
401            self.feature_preprocessing(),
402        )?;
403        Ok(self.predict_table(&table))
404    }
405
406    pub fn predict_proba_sparse_binary_columns(
407        &self,
408        n_rows: usize,
409        n_features: usize,
410        columns: Vec<Vec<usize>>,
411    ) -> Result<Vec<Vec<f64>>, PredictError> {
412        let table = InferenceTable::from_sparse_binary_columns(
413            n_rows,
414            n_features,
415            columns,
416            self.feature_preprocessing(),
417        )?;
418        self.predict_proba_table(&table)
419    }
420
421    #[cfg(feature = "polars")]
422    pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
423        let columns = polars_named_columns(df)?;
424        self.predict_named_columns(columns)
425    }
426
427    #[cfg(feature = "polars")]
428    pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
429        let mut predictions = Vec::new();
430        let mut offset = 0i64;
431        loop {
432            let batch = lf
433                .clone()
434                .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
435                .collect()?;
436            let height = batch.height();
437            if height == 0 {
438                break;
439            }
440            predictions.extend(self.predict_polars_dataframe(&batch)?);
441            if height < LAZYFRAME_PREDICT_BATCH_ROWS {
442                break;
443            }
444            offset += height as i64;
445        }
446        Ok(predictions)
447    }
448
449    pub fn algorithm(&self) -> TrainAlgorithm {
450        match self {
451            Model::DecisionTreeClassifier(_) | Model::DecisionTreeRegressor(_) => {
452                TrainAlgorithm::Dt
453            }
454            Model::RandomForest(_) => TrainAlgorithm::Rf,
455            Model::GradientBoostedTrees(_) => TrainAlgorithm::Gbm,
456        }
457    }
458
459    pub fn task(&self) -> Task {
460        match self {
461            Model::DecisionTreeRegressor(_) => Task::Regression,
462            Model::DecisionTreeClassifier(_) => Task::Classification,
463            Model::RandomForest(model) => model.task(),
464            Model::GradientBoostedTrees(model) => model.task(),
465        }
466    }
467
468    pub fn criterion(&self) -> Criterion {
469        match self {
470            Model::DecisionTreeClassifier(model) => model.criterion(),
471            Model::DecisionTreeRegressor(model) => model.criterion(),
472            Model::RandomForest(model) => model.criterion(),
473            Model::GradientBoostedTrees(model) => model.criterion(),
474        }
475    }
476
477    pub fn tree_type(&self) -> TreeType {
478        match self {
479            Model::DecisionTreeClassifier(model) => match model.algorithm() {
480                DecisionTreeAlgorithm::Id3 => TreeType::Id3,
481                DecisionTreeAlgorithm::C45 => TreeType::C45,
482                DecisionTreeAlgorithm::Cart => TreeType::Cart,
483                DecisionTreeAlgorithm::Randomized => TreeType::Randomized,
484                DecisionTreeAlgorithm::Oblivious => TreeType::Oblivious,
485            },
486            Model::DecisionTreeRegressor(model) => match model.algorithm() {
487                RegressionTreeAlgorithm::Cart => TreeType::Cart,
488                RegressionTreeAlgorithm::Randomized => TreeType::Randomized,
489                RegressionTreeAlgorithm::Oblivious => TreeType::Oblivious,
490            },
491            Model::RandomForest(model) => model.tree_type(),
492            Model::GradientBoostedTrees(model) => model.tree_type(),
493        }
494    }
495
496    pub fn mean_value(&self) -> Option<f64> {
497        match self {
498            Model::DecisionTreeClassifier(_)
499            | Model::DecisionTreeRegressor(_)
500            | Model::RandomForest(_)
501            | Model::GradientBoostedTrees(_) => None,
502        }
503    }
504
505    pub fn canaries(&self) -> usize {
506        self.training_metadata().canaries
507    }
508
509    pub fn max_depth(&self) -> Option<usize> {
510        self.training_metadata().max_depth
511    }
512
513    pub fn min_samples_split(&self) -> Option<usize> {
514        self.training_metadata().min_samples_split
515    }
516
517    pub fn min_samples_leaf(&self) -> Option<usize> {
518        self.training_metadata().min_samples_leaf
519    }
520
521    pub fn n_trees(&self) -> Option<usize> {
522        self.training_metadata().n_trees
523    }
524
525    pub fn max_features(&self) -> Option<usize> {
526        self.training_metadata().max_features
527    }
528
529    pub fn seed(&self) -> Option<u64> {
530        self.training_metadata().seed
531    }
532
533    pub fn compute_oob(&self) -> bool {
534        self.training_metadata().compute_oob
535    }
536
537    pub fn oob_score(&self) -> Option<f64> {
538        self.training_metadata().oob_score
539    }
540
541    pub fn learning_rate(&self) -> Option<f64> {
542        self.training_metadata().learning_rate
543    }
544
545    pub fn bootstrap(&self) -> bool {
546        self.training_metadata().bootstrap.unwrap_or(false)
547    }
548
549    pub fn top_gradient_fraction(&self) -> Option<f64> {
550        self.training_metadata().top_gradient_fraction
551    }
552
553    pub fn other_gradient_fraction(&self) -> Option<f64> {
554        self.training_metadata().other_gradient_fraction
555    }
556
557    pub fn tree_count(&self) -> usize {
558        self.to_ir().model.trees.len()
559    }
560
561    pub fn tree_structure(
562        &self,
563        tree_index: usize,
564    ) -> Result<TreeStructureSummary, IntrospectionError> {
565        tree_structure_summary(self.tree_definition(tree_index)?)
566    }
567
568    pub fn tree_prediction_stats(
569        &self,
570        tree_index: usize,
571    ) -> Result<PredictionValueStats, IntrospectionError> {
572        prediction_value_stats(self.tree_definition(tree_index)?)
573    }
574
575    pub fn tree_node(
576        &self,
577        tree_index: usize,
578        node_index: usize,
579    ) -> Result<ir::NodeTreeNode, IntrospectionError> {
580        match self.tree_definition(tree_index)? {
581            ir::TreeDefinition::NodeTree { nodes, .. } => {
582                let available = nodes.len();
583                nodes
584                    .into_iter()
585                    .nth(node_index)
586                    .ok_or(IntrospectionError::NodeIndexOutOfBounds {
587                        requested: node_index,
588                        available,
589                    })
590            }
591            ir::TreeDefinition::ObliviousLevels { .. } => Err(IntrospectionError::NotANodeTree),
592        }
593    }
594
595    pub fn tree_level(
596        &self,
597        tree_index: usize,
598        level_index: usize,
599    ) -> Result<ir::ObliviousLevel, IntrospectionError> {
600        match self.tree_definition(tree_index)? {
601            ir::TreeDefinition::ObliviousLevels { levels, .. } => {
602                let available = levels.len();
603                levels.into_iter().nth(level_index).ok_or(
604                    IntrospectionError::LevelIndexOutOfBounds {
605                        requested: level_index,
606                        available,
607                    },
608                )
609            }
610            ir::TreeDefinition::NodeTree { .. } => Err(IntrospectionError::NotAnObliviousTree),
611        }
612    }
613
614    pub fn tree_leaf(
615        &self,
616        tree_index: usize,
617        leaf_index: usize,
618    ) -> Result<ir::IndexedLeaf, IntrospectionError> {
619        match self.tree_definition(tree_index)? {
620            ir::TreeDefinition::ObliviousLevels { leaves, .. } => {
621                let available = leaves.len();
622                leaves
623                    .into_iter()
624                    .nth(leaf_index)
625                    .ok_or(IntrospectionError::LeafIndexOutOfBounds {
626                        requested: leaf_index,
627                        available,
628                    })
629            }
630            ir::TreeDefinition::NodeTree { nodes, .. } => {
631                let leaves = nodes
632                    .into_iter()
633                    .filter_map(|node| match node {
634                        ir::NodeTreeNode::Leaf {
635                            node_id,
636                            leaf,
637                            stats,
638                            ..
639                        } => Some(ir::IndexedLeaf {
640                            leaf_index: node_id,
641                            leaf,
642                            stats: ir::NodeStats {
643                                sample_count: stats.sample_count,
644                                impurity: stats.impurity,
645                                gain: stats.gain,
646                                class_counts: stats.class_counts,
647                                variance: stats.variance,
648                            },
649                        }),
650                        _ => None,
651                    })
652                    .collect::<Vec<_>>();
653                let available = leaves.len();
654                leaves
655                    .into_iter()
656                    .nth(leaf_index)
657                    .ok_or(IntrospectionError::LeafIndexOutOfBounds {
658                        requested: leaf_index,
659                        available,
660                    })
661            }
662        }
663    }
664
665    pub fn to_ir(&self) -> ModelPackageIr {
666        ir::model_to_ir(self)
667    }
668
669    pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
670        serde_json::to_string(&self.to_ir())
671    }
672
673    pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
674        serde_json::to_string_pretty(&self.to_ir())
675    }
676
677    pub fn serialize(&self) -> Result<String, serde_json::Error> {
678        self.to_ir_json()
679    }
680
681    pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
682        self.to_ir_json_pretty()
683    }
684
685    pub fn optimize_inference(
686        &self,
687        physical_cores: Option<usize>,
688    ) -> Result<OptimizedModel, OptimizeError> {
689        OptimizedModel::new(self.clone(), physical_cores, None)
690    }
691
692    pub fn optimize_inference_with_missing_features(
693        &self,
694        physical_cores: Option<usize>,
695        missing_features: Option<Vec<usize>>,
696    ) -> Result<OptimizedModel, OptimizeError> {
697        OptimizedModel::new(self.clone(), physical_cores, missing_features.as_deref())
698    }
699
700    pub fn json_schema() -> schemars::schema::RootSchema {
701        ModelPackageIr::json_schema()
702    }
703
704    pub fn json_schema_json() -> Result<String, IrError> {
705        ModelPackageIr::json_schema_json()
706    }
707
708    pub fn json_schema_json_pretty() -> Result<String, IrError> {
709        ModelPackageIr::json_schema_json_pretty()
710    }
711
712    pub fn deserialize(serialized: &str) -> Result<Self, IrError> {
713        let ir: ModelPackageIr =
714            serde_json::from_str(serialized).map_err(|err| IrError::Json(err.to_string()))?;
715        ir::model_from_ir(ir)
716    }
717
718    pub(crate) fn num_features(&self) -> usize {
719        match self {
720            Model::DecisionTreeClassifier(model) => model.num_features(),
721            Model::DecisionTreeRegressor(model) => model.num_features(),
722            Model::RandomForest(model) => model.num_features(),
723            Model::GradientBoostedTrees(model) => model.num_features(),
724        }
725    }
726
727    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
728        match self {
729            Model::DecisionTreeClassifier(model) => model.feature_preprocessing(),
730            Model::DecisionTreeRegressor(model) => model.feature_preprocessing(),
731            Model::RandomForest(model) => model.feature_preprocessing(),
732            Model::GradientBoostedTrees(model) => model.feature_preprocessing(),
733        }
734    }
735
736    pub fn used_feature_indices(&self) -> Vec<usize> {
737        model_used_feature_indices(self)
738    }
739
740    pub fn used_feature_count(&self) -> usize {
741        self.used_feature_indices().len()
742    }
743
744    pub(crate) fn class_labels(&self) -> Option<Vec<f64>> {
745        match self {
746            Model::DecisionTreeClassifier(model) => Some(model.class_labels().to_vec()),
747            Model::RandomForest(model) => model.class_labels(),
748            Model::GradientBoostedTrees(model) => model.class_labels(),
749            Model::DecisionTreeRegressor(_) => None,
750        }
751    }
752
753    pub(crate) fn training_metadata(&self) -> ir::TrainingMetadata {
754        match self {
755            Model::DecisionTreeClassifier(model) => model.training_metadata(),
756            Model::DecisionTreeRegressor(model) => model.training_metadata(),
757            Model::RandomForest(model) => model.training_metadata(),
758            Model::GradientBoostedTrees(model) => model.training_metadata(),
759        }
760    }
761
762    fn tree_definition(&self, tree_index: usize) -> Result<ir::TreeDefinition, IntrospectionError> {
763        let trees = self.to_ir().model.trees;
764        let available = trees.len();
765        trees
766            .into_iter()
767            .nth(tree_index)
768            .ok_or(IntrospectionError::TreeIndexOutOfBounds {
769                requested: tree_index,
770                available,
771            })
772    }
773}