Skip to main content

forestfire_core/
model_api.rs

1use super::*;
2
3/// Runtime-lowered model used for faster inference.
4///
5/// The optimized model keeps a copy of the source [`Model`] so it can preserve
6/// serialization and introspection behavior even after the runtime has been
7/// lowered into lookup-table-friendly structures.
8#[derive(Debug, Clone)]
9pub struct OptimizedModel {
10    pub(crate) source_model: Model,
11    pub(crate) runtime: OptimizedRuntime,
12    pub(crate) executor: InferenceExecutor,
13    pub(crate) feature_projection: Vec<usize>,
14}
15
16impl OptimizedModel {
17    pub(crate) fn new(
18        source_model: Model,
19        physical_cores: Option<usize>,
20    ) -> Result<Self, OptimizeError> {
21        let thread_count = resolve_inference_thread_count(physical_cores)?;
22        let feature_projection = build_feature_projection(&source_model);
23        let feature_index_map =
24            build_feature_index_map(source_model.num_features(), &feature_projection);
25        let runtime = OptimizedRuntime::from_model(&source_model, &feature_index_map);
26        let executor = InferenceExecutor::new(thread_count)?;
27
28        Ok(Self {
29            source_model,
30            runtime,
31            executor,
32            feature_projection,
33        })
34    }
35
36    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
37        let projected = ProjectedTableView::new(table, &self.feature_projection);
38        if self.runtime.should_use_batch_matrix(table.n_rows()) {
39            let matrix = ColumnMajorBinnedMatrix::from_table_access_projected(
40                table,
41                &self.feature_projection,
42            );
43            return self.predict_column_major_binned_matrix(&matrix);
44        }
45
46        self.executor.predict_rows(projected.n_rows(), |row_index| {
47            self.runtime.predict_table_row(&projected, row_index)
48        })
49    }
50
51    pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
52        let table = InferenceTable::from_rows_projected(
53            rows,
54            self.source_model.feature_preprocessing(),
55            &self.feature_projection,
56        )?;
57        if self.runtime.should_use_batch_matrix(table.n_rows()) {
58            let matrix = table.to_column_major_binned_matrix();
59            Ok(self.predict_column_major_binned_matrix(&matrix))
60        } else {
61            Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
62                self.runtime.predict_table_row(&table, row_index)
63            }))
64        }
65    }
66
67    pub fn predict_named_columns(
68        &self,
69        columns: BTreeMap<String, Vec<f64>>,
70    ) -> Result<Vec<f64>, PredictError> {
71        let table = InferenceTable::from_named_columns_projected(
72            columns,
73            self.source_model.feature_preprocessing(),
74            &self.feature_projection,
75        )?;
76        if self.runtime.should_use_batch_matrix(table.n_rows()) {
77            let matrix = table.to_column_major_binned_matrix();
78            Ok(self.predict_column_major_binned_matrix(&matrix))
79        } else {
80            Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
81                self.runtime.predict_table_row(&table, row_index)
82            }))
83        }
84    }
85
86    pub fn predict_proba_table(
87        &self,
88        table: &dyn TableAccess,
89    ) -> Result<Vec<Vec<f64>>, PredictError> {
90        let projected = ProjectedTableView::new(table, &self.feature_projection);
91        self.runtime.predict_proba_table(&projected, &self.executor)
92    }
93
94    pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
95        let table = InferenceTable::from_rows_projected(
96            rows,
97            self.source_model.feature_preprocessing(),
98            &self.feature_projection,
99        )?;
100        self.runtime.predict_proba_table(&table, &self.executor)
101    }
102
103    pub fn predict_proba_named_columns(
104        &self,
105        columns: BTreeMap<String, Vec<f64>>,
106    ) -> Result<Vec<Vec<f64>>, PredictError> {
107        let table = InferenceTable::from_named_columns_projected(
108            columns,
109            self.source_model.feature_preprocessing(),
110            &self.feature_projection,
111        )?;
112        self.runtime.predict_proba_table(&table, &self.executor)
113    }
114
115    pub fn predict_proba_sparse_binary_columns(
116        &self,
117        n_rows: usize,
118        n_features: usize,
119        columns: Vec<Vec<usize>>,
120    ) -> Result<Vec<Vec<f64>>, PredictError> {
121        let table = InferenceTable::from_sparse_binary_columns_projected(
122            n_rows,
123            n_features,
124            columns,
125            self.source_model.feature_preprocessing(),
126            &self.feature_projection,
127        )?;
128        self.runtime.predict_proba_table(&table, &self.executor)
129    }
130
131    pub fn predict_sparse_binary_columns(
132        &self,
133        n_rows: usize,
134        n_features: usize,
135        columns: Vec<Vec<usize>>,
136    ) -> Result<Vec<f64>, PredictError> {
137        let table = InferenceTable::from_sparse_binary_columns_projected(
138            n_rows,
139            n_features,
140            columns,
141            self.source_model.feature_preprocessing(),
142            &self.feature_projection,
143        )?;
144        if self.runtime.should_use_batch_matrix(table.n_rows()) {
145            let matrix = table.to_column_major_binned_matrix();
146            Ok(self.predict_column_major_binned_matrix(&matrix))
147        } else {
148            Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
149                self.runtime.predict_table_row(&table, row_index)
150            }))
151        }
152    }
153
154    #[cfg(feature = "polars")]
155    pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
156        let columns = polars_named_columns(df)?;
157        self.predict_named_columns(columns)
158    }
159
160    #[cfg(feature = "polars")]
161    pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
162        let mut predictions = Vec::new();
163        let mut offset = 0i64;
164        loop {
165            let batch = lf
166                .clone()
167                .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
168                .collect()?;
169            let height = batch.height();
170            if height == 0 {
171                break;
172            }
173            predictions.extend(self.predict_polars_dataframe(&batch)?);
174            if height < LAZYFRAME_PREDICT_BATCH_ROWS {
175                break;
176            }
177            offset += height as i64;
178        }
179        Ok(predictions)
180    }
181
182    pub fn algorithm(&self) -> TrainAlgorithm {
183        self.source_model.algorithm()
184    }
185
186    pub fn task(&self) -> Task {
187        self.source_model.task()
188    }
189
190    pub fn criterion(&self) -> Criterion {
191        self.source_model.criterion()
192    }
193
194    pub fn tree_type(&self) -> TreeType {
195        self.source_model.tree_type()
196    }
197
198    pub fn mean_value(&self) -> Option<f64> {
199        self.source_model.mean_value()
200    }
201
202    pub fn canaries(&self) -> usize {
203        self.source_model.canaries()
204    }
205
206    pub fn max_depth(&self) -> Option<usize> {
207        self.source_model.max_depth()
208    }
209
210    pub fn min_samples_split(&self) -> Option<usize> {
211        self.source_model.min_samples_split()
212    }
213
214    pub fn min_samples_leaf(&self) -> Option<usize> {
215        self.source_model.min_samples_leaf()
216    }
217
218    pub fn n_trees(&self) -> Option<usize> {
219        self.source_model.n_trees()
220    }
221
222    pub fn max_features(&self) -> Option<usize> {
223        self.source_model.max_features()
224    }
225
226    pub fn seed(&self) -> Option<u64> {
227        self.source_model.seed()
228    }
229
230    pub fn compute_oob(&self) -> bool {
231        self.source_model.compute_oob()
232    }
233
234    pub fn oob_score(&self) -> Option<f64> {
235        self.source_model.oob_score()
236    }
237
238    pub fn learning_rate(&self) -> Option<f64> {
239        self.source_model.learning_rate()
240    }
241
242    pub fn bootstrap(&self) -> bool {
243        self.source_model.bootstrap()
244    }
245
246    pub fn top_gradient_fraction(&self) -> Option<f64> {
247        self.source_model.top_gradient_fraction()
248    }
249
250    pub fn other_gradient_fraction(&self) -> Option<f64> {
251        self.source_model.other_gradient_fraction()
252    }
253
254    pub fn tree_count(&self) -> usize {
255        self.source_model.tree_count()
256    }
257
258    pub fn used_feature_indices(&self) -> Vec<usize> {
259        self.feature_projection.clone()
260    }
261
262    pub fn used_feature_count(&self) -> usize {
263        self.feature_projection.len()
264    }
265
266    pub fn tree_structure(
267        &self,
268        tree_index: usize,
269    ) -> Result<TreeStructureSummary, IntrospectionError> {
270        self.source_model.tree_structure(tree_index)
271    }
272
273    pub fn tree_prediction_stats(
274        &self,
275        tree_index: usize,
276    ) -> Result<PredictionValueStats, IntrospectionError> {
277        self.source_model.tree_prediction_stats(tree_index)
278    }
279
280    pub fn tree_node(
281        &self,
282        tree_index: usize,
283        node_index: usize,
284    ) -> Result<ir::NodeTreeNode, IntrospectionError> {
285        self.source_model.tree_node(tree_index, node_index)
286    }
287
288    pub fn tree_level(
289        &self,
290        tree_index: usize,
291        level_index: usize,
292    ) -> Result<ir::ObliviousLevel, IntrospectionError> {
293        self.source_model.tree_level(tree_index, level_index)
294    }
295
296    pub fn tree_leaf(
297        &self,
298        tree_index: usize,
299        leaf_index: usize,
300    ) -> Result<ir::IndexedLeaf, IntrospectionError> {
301        self.source_model.tree_leaf(tree_index, leaf_index)
302    }
303
304    pub fn to_ir(&self) -> ModelPackageIr {
305        self.source_model.to_ir()
306    }
307
308    pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
309        self.source_model.to_ir_json()
310    }
311
312    pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
313        self.source_model.to_ir_json_pretty()
314    }
315
316    pub fn serialize(&self) -> Result<String, serde_json::Error> {
317        self.source_model.serialize()
318    }
319
320    pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
321        self.source_model.serialize_pretty()
322    }
323
324    pub(crate) fn predict_column_major_binned_matrix(
325        &self,
326        matrix: &ColumnMajorBinnedMatrix,
327    ) -> Vec<f64> {
328        self.runtime
329            .predict_column_major_matrix(matrix, &self.executor)
330    }
331}
332
333impl Model {
334    pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
335        match self {
336            Model::DecisionTreeClassifier(model) => model.predict_table(table),
337            Model::DecisionTreeRegressor(model) => model.predict_table(table),
338            Model::RandomForest(model) => model.predict_table(table),
339            Model::GradientBoostedTrees(model) => model.predict_table(table),
340        }
341    }
342
343    pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
344        let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
345        Ok(self.predict_table(&table))
346    }
347
348    pub fn predict_proba_table(
349        &self,
350        table: &dyn TableAccess,
351    ) -> Result<Vec<Vec<f64>>, PredictError> {
352        match self {
353            Model::DecisionTreeClassifier(model) => Ok(model.predict_proba_table(table)),
354            Model::RandomForest(model) => model.predict_proba_table(table),
355            Model::GradientBoostedTrees(model) => model.predict_proba_table(table),
356            Model::DecisionTreeRegressor(_) => {
357                Err(PredictError::ProbabilityPredictionRequiresClassification)
358            }
359        }
360    }
361
362    pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
363        let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
364        self.predict_proba_table(&table)
365    }
366
367    pub fn predict_named_columns(
368        &self,
369        columns: BTreeMap<String, Vec<f64>>,
370    ) -> Result<Vec<f64>, PredictError> {
371        let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
372        Ok(self.predict_table(&table))
373    }
374
375    pub fn predict_proba_named_columns(
376        &self,
377        columns: BTreeMap<String, Vec<f64>>,
378    ) -> Result<Vec<Vec<f64>>, PredictError> {
379        let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
380        self.predict_proba_table(&table)
381    }
382
383    pub fn predict_sparse_binary_columns(
384        &self,
385        n_rows: usize,
386        n_features: usize,
387        columns: Vec<Vec<usize>>,
388    ) -> Result<Vec<f64>, PredictError> {
389        let table = InferenceTable::from_sparse_binary_columns(
390            n_rows,
391            n_features,
392            columns,
393            self.feature_preprocessing(),
394        )?;
395        Ok(self.predict_table(&table))
396    }
397
398    pub fn predict_proba_sparse_binary_columns(
399        &self,
400        n_rows: usize,
401        n_features: usize,
402        columns: Vec<Vec<usize>>,
403    ) -> Result<Vec<Vec<f64>>, PredictError> {
404        let table = InferenceTable::from_sparse_binary_columns(
405            n_rows,
406            n_features,
407            columns,
408            self.feature_preprocessing(),
409        )?;
410        self.predict_proba_table(&table)
411    }
412
413    #[cfg(feature = "polars")]
414    pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
415        let columns = polars_named_columns(df)?;
416        self.predict_named_columns(columns)
417    }
418
419    #[cfg(feature = "polars")]
420    pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
421        let mut predictions = Vec::new();
422        let mut offset = 0i64;
423        loop {
424            let batch = lf
425                .clone()
426                .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
427                .collect()?;
428            let height = batch.height();
429            if height == 0 {
430                break;
431            }
432            predictions.extend(self.predict_polars_dataframe(&batch)?);
433            if height < LAZYFRAME_PREDICT_BATCH_ROWS {
434                break;
435            }
436            offset += height as i64;
437        }
438        Ok(predictions)
439    }
440
441    pub fn algorithm(&self) -> TrainAlgorithm {
442        match self {
443            Model::DecisionTreeClassifier(_) | Model::DecisionTreeRegressor(_) => {
444                TrainAlgorithm::Dt
445            }
446            Model::RandomForest(_) => TrainAlgorithm::Rf,
447            Model::GradientBoostedTrees(_) => TrainAlgorithm::Gbm,
448        }
449    }
450
451    pub fn task(&self) -> Task {
452        match self {
453            Model::DecisionTreeRegressor(_) => Task::Regression,
454            Model::DecisionTreeClassifier(_) => Task::Classification,
455            Model::RandomForest(model) => model.task(),
456            Model::GradientBoostedTrees(model) => model.task(),
457        }
458    }
459
460    pub fn criterion(&self) -> Criterion {
461        match self {
462            Model::DecisionTreeClassifier(model) => model.criterion(),
463            Model::DecisionTreeRegressor(model) => model.criterion(),
464            Model::RandomForest(model) => model.criterion(),
465            Model::GradientBoostedTrees(model) => model.criterion(),
466        }
467    }
468
469    pub fn tree_type(&self) -> TreeType {
470        match self {
471            Model::DecisionTreeClassifier(model) => match model.algorithm() {
472                DecisionTreeAlgorithm::Id3 => TreeType::Id3,
473                DecisionTreeAlgorithm::C45 => TreeType::C45,
474                DecisionTreeAlgorithm::Cart => TreeType::Cart,
475                DecisionTreeAlgorithm::Randomized => TreeType::Randomized,
476                DecisionTreeAlgorithm::Oblivious => TreeType::Oblivious,
477            },
478            Model::DecisionTreeRegressor(model) => match model.algorithm() {
479                RegressionTreeAlgorithm::Cart => TreeType::Cart,
480                RegressionTreeAlgorithm::Randomized => TreeType::Randomized,
481                RegressionTreeAlgorithm::Oblivious => TreeType::Oblivious,
482            },
483            Model::RandomForest(model) => model.tree_type(),
484            Model::GradientBoostedTrees(model) => model.tree_type(),
485        }
486    }
487
488    pub fn mean_value(&self) -> Option<f64> {
489        match self {
490            Model::DecisionTreeClassifier(_)
491            | Model::DecisionTreeRegressor(_)
492            | Model::RandomForest(_)
493            | Model::GradientBoostedTrees(_) => None,
494        }
495    }
496
497    pub fn canaries(&self) -> usize {
498        self.training_metadata().canaries
499    }
500
501    pub fn max_depth(&self) -> Option<usize> {
502        self.training_metadata().max_depth
503    }
504
505    pub fn min_samples_split(&self) -> Option<usize> {
506        self.training_metadata().min_samples_split
507    }
508
509    pub fn min_samples_leaf(&self) -> Option<usize> {
510        self.training_metadata().min_samples_leaf
511    }
512
513    pub fn n_trees(&self) -> Option<usize> {
514        self.training_metadata().n_trees
515    }
516
517    pub fn max_features(&self) -> Option<usize> {
518        self.training_metadata().max_features
519    }
520
521    pub fn seed(&self) -> Option<u64> {
522        self.training_metadata().seed
523    }
524
525    pub fn compute_oob(&self) -> bool {
526        self.training_metadata().compute_oob
527    }
528
529    pub fn oob_score(&self) -> Option<f64> {
530        self.training_metadata().oob_score
531    }
532
533    pub fn learning_rate(&self) -> Option<f64> {
534        self.training_metadata().learning_rate
535    }
536
537    pub fn bootstrap(&self) -> bool {
538        self.training_metadata().bootstrap.unwrap_or(false)
539    }
540
541    pub fn top_gradient_fraction(&self) -> Option<f64> {
542        self.training_metadata().top_gradient_fraction
543    }
544
545    pub fn other_gradient_fraction(&self) -> Option<f64> {
546        self.training_metadata().other_gradient_fraction
547    }
548
549    pub fn tree_count(&self) -> usize {
550        self.to_ir().model.trees.len()
551    }
552
553    pub fn tree_structure(
554        &self,
555        tree_index: usize,
556    ) -> Result<TreeStructureSummary, IntrospectionError> {
557        tree_structure_summary(self.tree_definition(tree_index)?)
558    }
559
560    pub fn tree_prediction_stats(
561        &self,
562        tree_index: usize,
563    ) -> Result<PredictionValueStats, IntrospectionError> {
564        prediction_value_stats(self.tree_definition(tree_index)?)
565    }
566
567    pub fn tree_node(
568        &self,
569        tree_index: usize,
570        node_index: usize,
571    ) -> Result<ir::NodeTreeNode, IntrospectionError> {
572        match self.tree_definition(tree_index)? {
573            ir::TreeDefinition::NodeTree { nodes, .. } => {
574                let available = nodes.len();
575                nodes
576                    .into_iter()
577                    .nth(node_index)
578                    .ok_or(IntrospectionError::NodeIndexOutOfBounds {
579                        requested: node_index,
580                        available,
581                    })
582            }
583            ir::TreeDefinition::ObliviousLevels { .. } => Err(IntrospectionError::NotANodeTree),
584        }
585    }
586
587    pub fn tree_level(
588        &self,
589        tree_index: usize,
590        level_index: usize,
591    ) -> Result<ir::ObliviousLevel, IntrospectionError> {
592        match self.tree_definition(tree_index)? {
593            ir::TreeDefinition::ObliviousLevels { levels, .. } => {
594                let available = levels.len();
595                levels.into_iter().nth(level_index).ok_or(
596                    IntrospectionError::LevelIndexOutOfBounds {
597                        requested: level_index,
598                        available,
599                    },
600                )
601            }
602            ir::TreeDefinition::NodeTree { .. } => Err(IntrospectionError::NotAnObliviousTree),
603        }
604    }
605
606    pub fn tree_leaf(
607        &self,
608        tree_index: usize,
609        leaf_index: usize,
610    ) -> Result<ir::IndexedLeaf, IntrospectionError> {
611        match self.tree_definition(tree_index)? {
612            ir::TreeDefinition::ObliviousLevels { leaves, .. } => {
613                let available = leaves.len();
614                leaves
615                    .into_iter()
616                    .nth(leaf_index)
617                    .ok_or(IntrospectionError::LeafIndexOutOfBounds {
618                        requested: leaf_index,
619                        available,
620                    })
621            }
622            ir::TreeDefinition::NodeTree { nodes, .. } => {
623                let leaves = nodes
624                    .into_iter()
625                    .filter_map(|node| match node {
626                        ir::NodeTreeNode::Leaf {
627                            node_id,
628                            leaf,
629                            stats,
630                            ..
631                        } => Some(ir::IndexedLeaf {
632                            leaf_index: node_id,
633                            leaf,
634                            stats: ir::NodeStats {
635                                sample_count: stats.sample_count,
636                                impurity: stats.impurity,
637                                gain: stats.gain,
638                                class_counts: stats.class_counts,
639                                variance: stats.variance,
640                            },
641                        }),
642                        _ => None,
643                    })
644                    .collect::<Vec<_>>();
645                let available = leaves.len();
646                leaves
647                    .into_iter()
648                    .nth(leaf_index)
649                    .ok_or(IntrospectionError::LeafIndexOutOfBounds {
650                        requested: leaf_index,
651                        available,
652                    })
653            }
654        }
655    }
656
657    pub fn to_ir(&self) -> ModelPackageIr {
658        ir::model_to_ir(self)
659    }
660
661    pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
662        serde_json::to_string(&self.to_ir())
663    }
664
665    pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
666        serde_json::to_string_pretty(&self.to_ir())
667    }
668
669    pub fn serialize(&self) -> Result<String, serde_json::Error> {
670        self.to_ir_json()
671    }
672
673    pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
674        self.to_ir_json_pretty()
675    }
676
677    pub fn optimize_inference(
678        &self,
679        physical_cores: Option<usize>,
680    ) -> Result<OptimizedModel, OptimizeError> {
681        OptimizedModel::new(self.clone(), physical_cores)
682    }
683
684    pub fn json_schema() -> schemars::schema::RootSchema {
685        ModelPackageIr::json_schema()
686    }
687
688    pub fn json_schema_json() -> Result<String, IrError> {
689        ModelPackageIr::json_schema_json()
690    }
691
692    pub fn json_schema_json_pretty() -> Result<String, IrError> {
693        ModelPackageIr::json_schema_json_pretty()
694    }
695
696    pub fn deserialize(serialized: &str) -> Result<Self, IrError> {
697        let ir: ModelPackageIr =
698            serde_json::from_str(serialized).map_err(|err| IrError::Json(err.to_string()))?;
699        ir::model_from_ir(ir)
700    }
701
702    pub(crate) fn num_features(&self) -> usize {
703        match self {
704            Model::DecisionTreeClassifier(model) => model.num_features(),
705            Model::DecisionTreeRegressor(model) => model.num_features(),
706            Model::RandomForest(model) => model.num_features(),
707            Model::GradientBoostedTrees(model) => model.num_features(),
708        }
709    }
710
711    pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
712        match self {
713            Model::DecisionTreeClassifier(model) => model.feature_preprocessing(),
714            Model::DecisionTreeRegressor(model) => model.feature_preprocessing(),
715            Model::RandomForest(model) => model.feature_preprocessing(),
716            Model::GradientBoostedTrees(model) => model.feature_preprocessing(),
717        }
718    }
719
720    pub fn used_feature_indices(&self) -> Vec<usize> {
721        model_used_feature_indices(self)
722    }
723
724    pub fn used_feature_count(&self) -> usize {
725        self.used_feature_indices().len()
726    }
727
728    pub(crate) fn class_labels(&self) -> Option<Vec<f64>> {
729        match self {
730            Model::DecisionTreeClassifier(model) => Some(model.class_labels().to_vec()),
731            Model::RandomForest(model) => model.class_labels(),
732            Model::GradientBoostedTrees(model) => model.class_labels(),
733            Model::DecisionTreeRegressor(_) => None,
734        }
735    }
736
737    pub(crate) fn training_metadata(&self) -> ir::TrainingMetadata {
738        match self {
739            Model::DecisionTreeClassifier(model) => model.training_metadata(),
740            Model::DecisionTreeRegressor(model) => model.training_metadata(),
741            Model::RandomForest(model) => model.training_metadata(),
742            Model::GradientBoostedTrees(model) => model.training_metadata(),
743        }
744    }
745
746    fn tree_definition(&self, tree_index: usize) -> Result<ir::TreeDefinition, IntrospectionError> {
747        let trees = self.to_ir().model.trees;
748        let available = trees.len();
749        trees
750            .into_iter()
751            .nth(tree_index)
752            .ok_or(IntrospectionError::TreeIndexOutOfBounds {
753                requested: tree_index,
754                available,
755            })
756    }
757}