1use super::*;
2use std::collections::BTreeSet;
3
/// A trained [`Model`] paired with a prediction runtime and executor built for
/// fast inference.
///
/// Metadata, introspection and serialization calls are forwarded to the
/// original model; prediction calls go through `runtime` and `executor`.
#[derive(Debug, Clone)]
pub struct OptimizedModel {
    /// The model this optimized form was built from; kept for metadata,
    /// introspection and serialization.
    pub(crate) source_model: Model,
    /// Prediction runtime produced by `OptimizedRuntime::from_model`.
    pub(crate) runtime: OptimizedRuntime,
    /// Executor that drives row-by-row and matrix prediction; created with the
    /// thread count resolved at construction time.
    pub(crate) executor: InferenceExecutor,
    /// Indices of the input features kept by `build_feature_projection`; input
    /// tables are read through this projection.
    pub(crate) feature_projection: Vec<usize>,
}
16
17impl OptimizedModel {
18 pub(crate) fn new(
19 source_model: Model,
20 physical_cores: Option<usize>,
21 missing_features: Option<&[usize]>,
22 ) -> Result<Self, OptimizeError> {
23 let thread_count = resolve_inference_thread_count(physical_cores)?;
24 let feature_projection = build_feature_projection(&source_model);
25 let feature_index_map =
26 build_feature_index_map(source_model.num_features(), &feature_projection);
27 let missing_feature_set =
28 missing_features.map(|features| features.iter().copied().collect::<BTreeSet<_>>());
29 let runtime = OptimizedRuntime::from_model(
30 &source_model,
31 &feature_index_map,
32 missing_feature_set.as_ref(),
33 );
34 let executor = InferenceExecutor::new(thread_count)?;
35
36 Ok(Self {
37 source_model,
38 runtime,
39 executor,
40 feature_projection,
41 })
42 }
43
44 pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
45 let projected = ProjectedTableView::new(table, &self.feature_projection);
46 if self.runtime.should_use_batch_matrix(table.n_rows()) {
47 let matrix = ColumnMajorBinnedMatrix::from_table_access_projected(
48 table,
49 &self.feature_projection,
50 );
51 return self.predict_column_major_binned_matrix(&matrix);
52 }
53
54 self.executor.predict_rows(projected.n_rows(), |row_index| {
55 self.runtime.predict_table_row(&projected, row_index)
56 })
57 }
58
59 pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
60 let table = InferenceTable::from_rows_projected(
61 rows,
62 self.source_model.feature_preprocessing(),
63 &self.feature_projection,
64 )?;
65 if self.runtime.should_use_batch_matrix(table.n_rows()) {
66 let matrix = table.to_column_major_binned_matrix();
67 Ok(self.predict_column_major_binned_matrix(&matrix))
68 } else {
69 Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
70 self.runtime.predict_table_row(&table, row_index)
71 }))
72 }
73 }
74
75 pub fn predict_named_columns(
76 &self,
77 columns: BTreeMap<String, Vec<f64>>,
78 ) -> Result<Vec<f64>, PredictError> {
79 let table = InferenceTable::from_named_columns_projected(
80 columns,
81 self.source_model.feature_preprocessing(),
82 &self.feature_projection,
83 )?;
84 if self.runtime.should_use_batch_matrix(table.n_rows()) {
85 let matrix = table.to_column_major_binned_matrix();
86 Ok(self.predict_column_major_binned_matrix(&matrix))
87 } else {
88 Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
89 self.runtime.predict_table_row(&table, row_index)
90 }))
91 }
92 }
93
94 pub fn predict_proba_table(
95 &self,
96 table: &dyn TableAccess,
97 ) -> Result<Vec<Vec<f64>>, PredictError> {
98 let projected = ProjectedTableView::new(table, &self.feature_projection);
99 self.runtime.predict_proba_table(&projected, &self.executor)
100 }
101
102 pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
103 let table = InferenceTable::from_rows_projected(
104 rows,
105 self.source_model.feature_preprocessing(),
106 &self.feature_projection,
107 )?;
108 self.runtime.predict_proba_table(&table, &self.executor)
109 }
110
111 pub fn predict_proba_named_columns(
112 &self,
113 columns: BTreeMap<String, Vec<f64>>,
114 ) -> Result<Vec<Vec<f64>>, PredictError> {
115 let table = InferenceTable::from_named_columns_projected(
116 columns,
117 self.source_model.feature_preprocessing(),
118 &self.feature_projection,
119 )?;
120 self.runtime.predict_proba_table(&table, &self.executor)
121 }
122
123 pub fn predict_proba_sparse_binary_columns(
124 &self,
125 n_rows: usize,
126 n_features: usize,
127 columns: Vec<Vec<usize>>,
128 ) -> Result<Vec<Vec<f64>>, PredictError> {
129 let table = InferenceTable::from_sparse_binary_columns_projected(
130 n_rows,
131 n_features,
132 columns,
133 self.source_model.feature_preprocessing(),
134 &self.feature_projection,
135 )?;
136 self.runtime.predict_proba_table(&table, &self.executor)
137 }
138
139 pub fn predict_sparse_binary_columns(
140 &self,
141 n_rows: usize,
142 n_features: usize,
143 columns: Vec<Vec<usize>>,
144 ) -> Result<Vec<f64>, PredictError> {
145 let table = InferenceTable::from_sparse_binary_columns_projected(
146 n_rows,
147 n_features,
148 columns,
149 self.source_model.feature_preprocessing(),
150 &self.feature_projection,
151 )?;
152 if self.runtime.should_use_batch_matrix(table.n_rows()) {
153 let matrix = table.to_column_major_binned_matrix();
154 Ok(self.predict_column_major_binned_matrix(&matrix))
155 } else {
156 Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
157 self.runtime.predict_table_row(&table, row_index)
158 }))
159 }
160 }
161
162 #[cfg(feature = "polars")]
163 pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
164 let columns = polars_named_columns(df)?;
165 self.predict_named_columns(columns)
166 }
167
168 #[cfg(feature = "polars")]
169 pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
170 let mut predictions = Vec::new();
171 let mut offset = 0i64;
172 loop {
173 let batch = lf
174 .clone()
175 .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
176 .collect()?;
177 let height = batch.height();
178 if height == 0 {
179 break;
180 }
181 predictions.extend(self.predict_polars_dataframe(&batch)?);
182 if height < LAZYFRAME_PREDICT_BATCH_ROWS {
183 break;
184 }
185 offset += height as i64;
186 }
187 Ok(predictions)
188 }
189
190 pub fn algorithm(&self) -> TrainAlgorithm {
191 self.source_model.algorithm()
192 }
193
194 pub fn task(&self) -> Task {
195 self.source_model.task()
196 }
197
198 pub fn criterion(&self) -> Criterion {
199 self.source_model.criterion()
200 }
201
202 pub fn tree_type(&self) -> TreeType {
203 self.source_model.tree_type()
204 }
205
206 pub fn mean_value(&self) -> Option<f64> {
207 self.source_model.mean_value()
208 }
209
210 pub fn canaries(&self) -> usize {
211 self.source_model.canaries()
212 }
213
214 pub fn max_depth(&self) -> Option<usize> {
215 self.source_model.max_depth()
216 }
217
218 pub fn min_samples_split(&self) -> Option<usize> {
219 self.source_model.min_samples_split()
220 }
221
222 pub fn min_samples_leaf(&self) -> Option<usize> {
223 self.source_model.min_samples_leaf()
224 }
225
226 pub fn n_trees(&self) -> Option<usize> {
227 self.source_model.n_trees()
228 }
229
230 pub fn max_features(&self) -> Option<usize> {
231 self.source_model.max_features()
232 }
233
234 pub fn seed(&self) -> Option<u64> {
235 self.source_model.seed()
236 }
237
238 pub fn compute_oob(&self) -> bool {
239 self.source_model.compute_oob()
240 }
241
242 pub fn oob_score(&self) -> Option<f64> {
243 self.source_model.oob_score()
244 }
245
246 pub fn learning_rate(&self) -> Option<f64> {
247 self.source_model.learning_rate()
248 }
249
250 pub fn bootstrap(&self) -> bool {
251 self.source_model.bootstrap()
252 }
253
254 pub fn top_gradient_fraction(&self) -> Option<f64> {
255 self.source_model.top_gradient_fraction()
256 }
257
258 pub fn other_gradient_fraction(&self) -> Option<f64> {
259 self.source_model.other_gradient_fraction()
260 }
261
262 pub fn tree_count(&self) -> usize {
263 self.source_model.tree_count()
264 }
265
266 pub fn used_feature_indices(&self) -> Vec<usize> {
267 self.feature_projection.clone()
268 }
269
270 pub fn used_feature_count(&self) -> usize {
271 self.feature_projection.len()
272 }
273
274 pub fn tree_structure(
275 &self,
276 tree_index: usize,
277 ) -> Result<TreeStructureSummary, IntrospectionError> {
278 self.source_model.tree_structure(tree_index)
279 }
280
281 pub fn tree_prediction_stats(
282 &self,
283 tree_index: usize,
284 ) -> Result<PredictionValueStats, IntrospectionError> {
285 self.source_model.tree_prediction_stats(tree_index)
286 }
287
288 pub fn tree_node(
289 &self,
290 tree_index: usize,
291 node_index: usize,
292 ) -> Result<ir::NodeTreeNode, IntrospectionError> {
293 self.source_model.tree_node(tree_index, node_index)
294 }
295
296 pub fn tree_level(
297 &self,
298 tree_index: usize,
299 level_index: usize,
300 ) -> Result<ir::ObliviousLevel, IntrospectionError> {
301 self.source_model.tree_level(tree_index, level_index)
302 }
303
304 pub fn tree_leaf(
305 &self,
306 tree_index: usize,
307 leaf_index: usize,
308 ) -> Result<ir::IndexedLeaf, IntrospectionError> {
309 self.source_model.tree_leaf(tree_index, leaf_index)
310 }
311
312 pub fn to_ir(&self) -> ModelPackageIr {
313 self.source_model.to_ir()
314 }
315
316 pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
317 self.source_model.to_ir_json()
318 }
319
320 pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
321 self.source_model.to_ir_json_pretty()
322 }
323
324 pub fn serialize(&self) -> Result<String, serde_json::Error> {
325 self.source_model.serialize()
326 }
327
328 pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
329 self.source_model.serialize_pretty()
330 }
331
332 pub(crate) fn predict_column_major_binned_matrix(
333 &self,
334 matrix: &ColumnMajorBinnedMatrix,
335 ) -> Vec<f64> {
336 self.runtime
337 .predict_column_major_matrix(matrix, &self.executor)
338 }
339}
340
341impl Model {
342 pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
343 match self {
344 Model::DecisionTreeClassifier(model) => model.predict_table(table),
345 Model::DecisionTreeRegressor(model) => model.predict_table(table),
346 Model::RandomForest(model) => model.predict_table(table),
347 Model::GradientBoostedTrees(model) => model.predict_table(table),
348 }
349 }
350
351 pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
352 let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
353 Ok(self.predict_table(&table))
354 }
355
356 pub fn predict_proba_table(
357 &self,
358 table: &dyn TableAccess,
359 ) -> Result<Vec<Vec<f64>>, PredictError> {
360 match self {
361 Model::DecisionTreeClassifier(model) => Ok(model.predict_proba_table(table)),
362 Model::RandomForest(model) => model.predict_proba_table(table),
363 Model::GradientBoostedTrees(model) => model.predict_proba_table(table),
364 Model::DecisionTreeRegressor(_) => {
365 Err(PredictError::ProbabilityPredictionRequiresClassification)
366 }
367 }
368 }
369
370 pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
371 let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
372 self.predict_proba_table(&table)
373 }
374
375 pub fn predict_named_columns(
376 &self,
377 columns: BTreeMap<String, Vec<f64>>,
378 ) -> Result<Vec<f64>, PredictError> {
379 let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
380 Ok(self.predict_table(&table))
381 }
382
383 pub fn predict_proba_named_columns(
384 &self,
385 columns: BTreeMap<String, Vec<f64>>,
386 ) -> Result<Vec<Vec<f64>>, PredictError> {
387 let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
388 self.predict_proba_table(&table)
389 }
390
391 pub fn predict_sparse_binary_columns(
392 &self,
393 n_rows: usize,
394 n_features: usize,
395 columns: Vec<Vec<usize>>,
396 ) -> Result<Vec<f64>, PredictError> {
397 let table = InferenceTable::from_sparse_binary_columns(
398 n_rows,
399 n_features,
400 columns,
401 self.feature_preprocessing(),
402 )?;
403 Ok(self.predict_table(&table))
404 }
405
406 pub fn predict_proba_sparse_binary_columns(
407 &self,
408 n_rows: usize,
409 n_features: usize,
410 columns: Vec<Vec<usize>>,
411 ) -> Result<Vec<Vec<f64>>, PredictError> {
412 let table = InferenceTable::from_sparse_binary_columns(
413 n_rows,
414 n_features,
415 columns,
416 self.feature_preprocessing(),
417 )?;
418 self.predict_proba_table(&table)
419 }
420
421 #[cfg(feature = "polars")]
422 pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
423 let columns = polars_named_columns(df)?;
424 self.predict_named_columns(columns)
425 }
426
427 #[cfg(feature = "polars")]
428 pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
429 let mut predictions = Vec::new();
430 let mut offset = 0i64;
431 loop {
432 let batch = lf
433 .clone()
434 .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
435 .collect()?;
436 let height = batch.height();
437 if height == 0 {
438 break;
439 }
440 predictions.extend(self.predict_polars_dataframe(&batch)?);
441 if height < LAZYFRAME_PREDICT_BATCH_ROWS {
442 break;
443 }
444 offset += height as i64;
445 }
446 Ok(predictions)
447 }
448
449 pub fn algorithm(&self) -> TrainAlgorithm {
450 match self {
451 Model::DecisionTreeClassifier(_) | Model::DecisionTreeRegressor(_) => {
452 TrainAlgorithm::Dt
453 }
454 Model::RandomForest(_) => TrainAlgorithm::Rf,
455 Model::GradientBoostedTrees(_) => TrainAlgorithm::Gbm,
456 }
457 }
458
459 pub fn task(&self) -> Task {
460 match self {
461 Model::DecisionTreeRegressor(_) => Task::Regression,
462 Model::DecisionTreeClassifier(_) => Task::Classification,
463 Model::RandomForest(model) => model.task(),
464 Model::GradientBoostedTrees(model) => model.task(),
465 }
466 }
467
468 pub fn criterion(&self) -> Criterion {
469 match self {
470 Model::DecisionTreeClassifier(model) => model.criterion(),
471 Model::DecisionTreeRegressor(model) => model.criterion(),
472 Model::RandomForest(model) => model.criterion(),
473 Model::GradientBoostedTrees(model) => model.criterion(),
474 }
475 }
476
477 pub fn tree_type(&self) -> TreeType {
478 match self {
479 Model::DecisionTreeClassifier(model) => match model.algorithm() {
480 DecisionTreeAlgorithm::Id3 => TreeType::Id3,
481 DecisionTreeAlgorithm::C45 => TreeType::C45,
482 DecisionTreeAlgorithm::Cart => TreeType::Cart,
483 DecisionTreeAlgorithm::Randomized => TreeType::Randomized,
484 DecisionTreeAlgorithm::Oblivious => TreeType::Oblivious,
485 },
486 Model::DecisionTreeRegressor(model) => match model.algorithm() {
487 RegressionTreeAlgorithm::Cart => TreeType::Cart,
488 RegressionTreeAlgorithm::Randomized => TreeType::Randomized,
489 RegressionTreeAlgorithm::Oblivious => TreeType::Oblivious,
490 },
491 Model::RandomForest(model) => model.tree_type(),
492 Model::GradientBoostedTrees(model) => model.tree_type(),
493 }
494 }
495
496 pub fn mean_value(&self) -> Option<f64> {
497 match self {
498 Model::DecisionTreeClassifier(_)
499 | Model::DecisionTreeRegressor(_)
500 | Model::RandomForest(_)
501 | Model::GradientBoostedTrees(_) => None,
502 }
503 }
504
505 pub fn canaries(&self) -> usize {
506 self.training_metadata().canaries
507 }
508
509 pub fn max_depth(&self) -> Option<usize> {
510 self.training_metadata().max_depth
511 }
512
513 pub fn min_samples_split(&self) -> Option<usize> {
514 self.training_metadata().min_samples_split
515 }
516
517 pub fn min_samples_leaf(&self) -> Option<usize> {
518 self.training_metadata().min_samples_leaf
519 }
520
521 pub fn n_trees(&self) -> Option<usize> {
522 self.training_metadata().n_trees
523 }
524
525 pub fn max_features(&self) -> Option<usize> {
526 self.training_metadata().max_features
527 }
528
529 pub fn seed(&self) -> Option<u64> {
530 self.training_metadata().seed
531 }
532
533 pub fn compute_oob(&self) -> bool {
534 self.training_metadata().compute_oob
535 }
536
537 pub fn oob_score(&self) -> Option<f64> {
538 self.training_metadata().oob_score
539 }
540
541 pub fn learning_rate(&self) -> Option<f64> {
542 self.training_metadata().learning_rate
543 }
544
545 pub fn bootstrap(&self) -> bool {
546 self.training_metadata().bootstrap.unwrap_or(false)
547 }
548
549 pub fn top_gradient_fraction(&self) -> Option<f64> {
550 self.training_metadata().top_gradient_fraction
551 }
552
553 pub fn other_gradient_fraction(&self) -> Option<f64> {
554 self.training_metadata().other_gradient_fraction
555 }
556
557 pub fn tree_count(&self) -> usize {
558 self.to_ir().model.trees.len()
559 }
560
561 pub fn tree_structure(
562 &self,
563 tree_index: usize,
564 ) -> Result<TreeStructureSummary, IntrospectionError> {
565 tree_structure_summary(self.tree_definition(tree_index)?)
566 }
567
568 pub fn tree_prediction_stats(
569 &self,
570 tree_index: usize,
571 ) -> Result<PredictionValueStats, IntrospectionError> {
572 prediction_value_stats(self.tree_definition(tree_index)?)
573 }
574
575 pub fn tree_node(
576 &self,
577 tree_index: usize,
578 node_index: usize,
579 ) -> Result<ir::NodeTreeNode, IntrospectionError> {
580 match self.tree_definition(tree_index)? {
581 ir::TreeDefinition::NodeTree { nodes, .. } => {
582 let available = nodes.len();
583 nodes
584 .into_iter()
585 .nth(node_index)
586 .ok_or(IntrospectionError::NodeIndexOutOfBounds {
587 requested: node_index,
588 available,
589 })
590 }
591 ir::TreeDefinition::ObliviousLevels { .. } => Err(IntrospectionError::NotANodeTree),
592 }
593 }
594
595 pub fn tree_level(
596 &self,
597 tree_index: usize,
598 level_index: usize,
599 ) -> Result<ir::ObliviousLevel, IntrospectionError> {
600 match self.tree_definition(tree_index)? {
601 ir::TreeDefinition::ObliviousLevels { levels, .. } => {
602 let available = levels.len();
603 levels.into_iter().nth(level_index).ok_or(
604 IntrospectionError::LevelIndexOutOfBounds {
605 requested: level_index,
606 available,
607 },
608 )
609 }
610 ir::TreeDefinition::NodeTree { .. } => Err(IntrospectionError::NotAnObliviousTree),
611 }
612 }
613
614 pub fn tree_leaf(
615 &self,
616 tree_index: usize,
617 leaf_index: usize,
618 ) -> Result<ir::IndexedLeaf, IntrospectionError> {
619 match self.tree_definition(tree_index)? {
620 ir::TreeDefinition::ObliviousLevels { leaves, .. } => {
621 let available = leaves.len();
622 leaves
623 .into_iter()
624 .nth(leaf_index)
625 .ok_or(IntrospectionError::LeafIndexOutOfBounds {
626 requested: leaf_index,
627 available,
628 })
629 }
630 ir::TreeDefinition::NodeTree { nodes, .. } => {
631 let leaves = nodes
632 .into_iter()
633 .filter_map(|node| match node {
634 ir::NodeTreeNode::Leaf {
635 node_id,
636 leaf,
637 stats,
638 ..
639 } => Some(ir::IndexedLeaf {
640 leaf_index: node_id,
641 leaf,
642 stats: ir::NodeStats {
643 sample_count: stats.sample_count,
644 impurity: stats.impurity,
645 gain: stats.gain,
646 class_counts: stats.class_counts,
647 variance: stats.variance,
648 },
649 }),
650 _ => None,
651 })
652 .collect::<Vec<_>>();
653 let available = leaves.len();
654 leaves
655 .into_iter()
656 .nth(leaf_index)
657 .ok_or(IntrospectionError::LeafIndexOutOfBounds {
658 requested: leaf_index,
659 available,
660 })
661 }
662 }
663 }
664
665 pub fn to_ir(&self) -> ModelPackageIr {
666 ir::model_to_ir(self)
667 }
668
669 pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
670 serde_json::to_string(&self.to_ir())
671 }
672
673 pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
674 serde_json::to_string_pretty(&self.to_ir())
675 }
676
677 pub fn serialize(&self) -> Result<String, serde_json::Error> {
678 self.to_ir_json()
679 }
680
681 pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
682 self.to_ir_json_pretty()
683 }
684
685 pub fn optimize_inference(
686 &self,
687 physical_cores: Option<usize>,
688 ) -> Result<OptimizedModel, OptimizeError> {
689 OptimizedModel::new(self.clone(), physical_cores, None)
690 }
691
692 pub fn optimize_inference_with_missing_features(
693 &self,
694 physical_cores: Option<usize>,
695 missing_features: Option<Vec<usize>>,
696 ) -> Result<OptimizedModel, OptimizeError> {
697 OptimizedModel::new(self.clone(), physical_cores, missing_features.as_deref())
698 }
699
700 pub fn json_schema() -> schemars::schema::RootSchema {
701 ModelPackageIr::json_schema()
702 }
703
704 pub fn json_schema_json() -> Result<String, IrError> {
705 ModelPackageIr::json_schema_json()
706 }
707
708 pub fn json_schema_json_pretty() -> Result<String, IrError> {
709 ModelPackageIr::json_schema_json_pretty()
710 }
711
712 pub fn deserialize(serialized: &str) -> Result<Self, IrError> {
713 let ir: ModelPackageIr =
714 serde_json::from_str(serialized).map_err(|err| IrError::Json(err.to_string()))?;
715 ir::model_from_ir(ir)
716 }
717
718 pub(crate) fn num_features(&self) -> usize {
719 match self {
720 Model::DecisionTreeClassifier(model) => model.num_features(),
721 Model::DecisionTreeRegressor(model) => model.num_features(),
722 Model::RandomForest(model) => model.num_features(),
723 Model::GradientBoostedTrees(model) => model.num_features(),
724 }
725 }
726
727 pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
728 match self {
729 Model::DecisionTreeClassifier(model) => model.feature_preprocessing(),
730 Model::DecisionTreeRegressor(model) => model.feature_preprocessing(),
731 Model::RandomForest(model) => model.feature_preprocessing(),
732 Model::GradientBoostedTrees(model) => model.feature_preprocessing(),
733 }
734 }
735
736 pub fn used_feature_indices(&self) -> Vec<usize> {
737 model_used_feature_indices(self)
738 }
739
740 pub fn used_feature_count(&self) -> usize {
741 self.used_feature_indices().len()
742 }
743
744 pub(crate) fn class_labels(&self) -> Option<Vec<f64>> {
745 match self {
746 Model::DecisionTreeClassifier(model) => Some(model.class_labels().to_vec()),
747 Model::RandomForest(model) => model.class_labels(),
748 Model::GradientBoostedTrees(model) => model.class_labels(),
749 Model::DecisionTreeRegressor(_) => None,
750 }
751 }
752
753 pub(crate) fn training_metadata(&self) -> ir::TrainingMetadata {
754 match self {
755 Model::DecisionTreeClassifier(model) => model.training_metadata(),
756 Model::DecisionTreeRegressor(model) => model.training_metadata(),
757 Model::RandomForest(model) => model.training_metadata(),
758 Model::GradientBoostedTrees(model) => model.training_metadata(),
759 }
760 }
761
762 fn tree_definition(&self, tree_index: usize) -> Result<ir::TreeDefinition, IntrospectionError> {
763 let trees = self.to_ir().model.trees;
764 let available = trees.len();
765 trees
766 .into_iter()
767 .nth(tree_index)
768 .ok_or(IntrospectionError::TreeIndexOutOfBounds {
769 requested: tree_index,
770 available,
771 })
772 }
773}