1use super::*;
2
/// Inference-optimized wrapper around a trained [`Model`].
///
/// Created by [`Model::optimize_inference`], which clones the model into `source_model`.
/// Predictions go through the precompiled `runtime` and the `executor`; metadata,
/// introspection, IR export and serialization calls are forwarded to `source_model`.
#[derive(Debug, Clone)]
pub struct OptimizedModel {
    /// The original model that metadata, introspection and serialization queries forward to.
    pub(crate) source_model: Model,
    /// Runtime built from `source_model` against the projected feature space
    /// (see `OptimizedModel::new`).
    pub(crate) runtime: OptimizedRuntime,
    /// Executor created with the resolved inference thread count; runs row-wise prediction
    /// and is handed to the runtime for matrix and probability prediction.
    pub(crate) executor: InferenceExecutor,
    /// Indices of the input features the model actually uses; every input is projected onto
    /// these columns before prediction.
    pub(crate) feature_projection: Vec<usize>,
}
15
16impl OptimizedModel {
17 pub(crate) fn new(
18 source_model: Model,
19 physical_cores: Option<usize>,
20 ) -> Result<Self, OptimizeError> {
21 let thread_count = resolve_inference_thread_count(physical_cores)?;
22 let feature_projection = build_feature_projection(&source_model);
23 let feature_index_map =
24 build_feature_index_map(source_model.num_features(), &feature_projection);
25 let runtime = OptimizedRuntime::from_model(&source_model, &feature_index_map);
26 let executor = InferenceExecutor::new(thread_count)?;
27
28 Ok(Self {
29 source_model,
30 runtime,
31 executor,
32 feature_projection,
33 })
34 }
35
36 pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
37 let projected = ProjectedTableView::new(table, &self.feature_projection);
38 if self.runtime.should_use_batch_matrix(table.n_rows()) {
39 let matrix = ColumnMajorBinnedMatrix::from_table_access_projected(
40 table,
41 &self.feature_projection,
42 );
43 return self.predict_column_major_binned_matrix(&matrix);
44 }
45
46 self.executor.predict_rows(projected.n_rows(), |row_index| {
47 self.runtime.predict_table_row(&projected, row_index)
48 })
49 }
50
51 pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
52 let table = InferenceTable::from_rows_projected(
53 rows,
54 self.source_model.feature_preprocessing(),
55 &self.feature_projection,
56 )?;
57 if self.runtime.should_use_batch_matrix(table.n_rows()) {
58 let matrix = table.to_column_major_binned_matrix();
59 Ok(self.predict_column_major_binned_matrix(&matrix))
60 } else {
61 Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
62 self.runtime.predict_table_row(&table, row_index)
63 }))
64 }
65 }
66
67 pub fn predict_named_columns(
68 &self,
69 columns: BTreeMap<String, Vec<f64>>,
70 ) -> Result<Vec<f64>, PredictError> {
71 let table = InferenceTable::from_named_columns_projected(
72 columns,
73 self.source_model.feature_preprocessing(),
74 &self.feature_projection,
75 )?;
76 if self.runtime.should_use_batch_matrix(table.n_rows()) {
77 let matrix = table.to_column_major_binned_matrix();
78 Ok(self.predict_column_major_binned_matrix(&matrix))
79 } else {
80 Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
81 self.runtime.predict_table_row(&table, row_index)
82 }))
83 }
84 }
85
86 pub fn predict_proba_table(
87 &self,
88 table: &dyn TableAccess,
89 ) -> Result<Vec<Vec<f64>>, PredictError> {
90 let projected = ProjectedTableView::new(table, &self.feature_projection);
91 self.runtime.predict_proba_table(&projected, &self.executor)
92 }
93
94 pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
95 let table = InferenceTable::from_rows_projected(
96 rows,
97 self.source_model.feature_preprocessing(),
98 &self.feature_projection,
99 )?;
100 self.runtime.predict_proba_table(&table, &self.executor)
101 }
102
103 pub fn predict_proba_named_columns(
104 &self,
105 columns: BTreeMap<String, Vec<f64>>,
106 ) -> Result<Vec<Vec<f64>>, PredictError> {
107 let table = InferenceTable::from_named_columns_projected(
108 columns,
109 self.source_model.feature_preprocessing(),
110 &self.feature_projection,
111 )?;
112 self.runtime.predict_proba_table(&table, &self.executor)
113 }
114
115 pub fn predict_proba_sparse_binary_columns(
116 &self,
117 n_rows: usize,
118 n_features: usize,
119 columns: Vec<Vec<usize>>,
120 ) -> Result<Vec<Vec<f64>>, PredictError> {
121 let table = InferenceTable::from_sparse_binary_columns_projected(
122 n_rows,
123 n_features,
124 columns,
125 self.source_model.feature_preprocessing(),
126 &self.feature_projection,
127 )?;
128 self.runtime.predict_proba_table(&table, &self.executor)
129 }
130
131 pub fn predict_sparse_binary_columns(
132 &self,
133 n_rows: usize,
134 n_features: usize,
135 columns: Vec<Vec<usize>>,
136 ) -> Result<Vec<f64>, PredictError> {
137 let table = InferenceTable::from_sparse_binary_columns_projected(
138 n_rows,
139 n_features,
140 columns,
141 self.source_model.feature_preprocessing(),
142 &self.feature_projection,
143 )?;
144 if self.runtime.should_use_batch_matrix(table.n_rows()) {
145 let matrix = table.to_column_major_binned_matrix();
146 Ok(self.predict_column_major_binned_matrix(&matrix))
147 } else {
148 Ok(self.executor.predict_rows(table.n_rows(), |row_index| {
149 self.runtime.predict_table_row(&table, row_index)
150 }))
151 }
152 }
153
154 #[cfg(feature = "polars")]
155 pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
156 let columns = polars_named_columns(df)?;
157 self.predict_named_columns(columns)
158 }
159
160 #[cfg(feature = "polars")]
161 pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
162 let mut predictions = Vec::new();
163 let mut offset = 0i64;
164 loop {
165 let batch = lf
166 .clone()
167 .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
168 .collect()?;
169 let height = batch.height();
170 if height == 0 {
171 break;
172 }
173 predictions.extend(self.predict_polars_dataframe(&batch)?);
174 if height < LAZYFRAME_PREDICT_BATCH_ROWS {
175 break;
176 }
177 offset += height as i64;
178 }
179 Ok(predictions)
180 }
181
182 pub fn algorithm(&self) -> TrainAlgorithm {
183 self.source_model.algorithm()
184 }
185
186 pub fn task(&self) -> Task {
187 self.source_model.task()
188 }
189
190 pub fn criterion(&self) -> Criterion {
191 self.source_model.criterion()
192 }
193
194 pub fn tree_type(&self) -> TreeType {
195 self.source_model.tree_type()
196 }
197
198 pub fn mean_value(&self) -> Option<f64> {
199 self.source_model.mean_value()
200 }
201
202 pub fn canaries(&self) -> usize {
203 self.source_model.canaries()
204 }
205
206 pub fn max_depth(&self) -> Option<usize> {
207 self.source_model.max_depth()
208 }
209
210 pub fn min_samples_split(&self) -> Option<usize> {
211 self.source_model.min_samples_split()
212 }
213
214 pub fn min_samples_leaf(&self) -> Option<usize> {
215 self.source_model.min_samples_leaf()
216 }
217
218 pub fn n_trees(&self) -> Option<usize> {
219 self.source_model.n_trees()
220 }
221
222 pub fn max_features(&self) -> Option<usize> {
223 self.source_model.max_features()
224 }
225
226 pub fn seed(&self) -> Option<u64> {
227 self.source_model.seed()
228 }
229
230 pub fn compute_oob(&self) -> bool {
231 self.source_model.compute_oob()
232 }
233
234 pub fn oob_score(&self) -> Option<f64> {
235 self.source_model.oob_score()
236 }
237
238 pub fn learning_rate(&self) -> Option<f64> {
239 self.source_model.learning_rate()
240 }
241
242 pub fn bootstrap(&self) -> bool {
243 self.source_model.bootstrap()
244 }
245
246 pub fn top_gradient_fraction(&self) -> Option<f64> {
247 self.source_model.top_gradient_fraction()
248 }
249
250 pub fn other_gradient_fraction(&self) -> Option<f64> {
251 self.source_model.other_gradient_fraction()
252 }
253
254 pub fn tree_count(&self) -> usize {
255 self.source_model.tree_count()
256 }
257
258 pub fn used_feature_indices(&self) -> Vec<usize> {
259 self.feature_projection.clone()
260 }
261
262 pub fn used_feature_count(&self) -> usize {
263 self.feature_projection.len()
264 }
265
266 pub fn tree_structure(
267 &self,
268 tree_index: usize,
269 ) -> Result<TreeStructureSummary, IntrospectionError> {
270 self.source_model.tree_structure(tree_index)
271 }
272
273 pub fn tree_prediction_stats(
274 &self,
275 tree_index: usize,
276 ) -> Result<PredictionValueStats, IntrospectionError> {
277 self.source_model.tree_prediction_stats(tree_index)
278 }
279
280 pub fn tree_node(
281 &self,
282 tree_index: usize,
283 node_index: usize,
284 ) -> Result<ir::NodeTreeNode, IntrospectionError> {
285 self.source_model.tree_node(tree_index, node_index)
286 }
287
288 pub fn tree_level(
289 &self,
290 tree_index: usize,
291 level_index: usize,
292 ) -> Result<ir::ObliviousLevel, IntrospectionError> {
293 self.source_model.tree_level(tree_index, level_index)
294 }
295
296 pub fn tree_leaf(
297 &self,
298 tree_index: usize,
299 leaf_index: usize,
300 ) -> Result<ir::IndexedLeaf, IntrospectionError> {
301 self.source_model.tree_leaf(tree_index, leaf_index)
302 }
303
304 pub fn to_ir(&self) -> ModelPackageIr {
305 self.source_model.to_ir()
306 }
307
308 pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
309 self.source_model.to_ir_json()
310 }
311
312 pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
313 self.source_model.to_ir_json_pretty()
314 }
315
316 pub fn serialize(&self) -> Result<String, serde_json::Error> {
317 self.source_model.serialize()
318 }
319
320 pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
321 self.source_model.serialize_pretty()
322 }
323
324 pub(crate) fn predict_column_major_binned_matrix(
325 &self,
326 matrix: &ColumnMajorBinnedMatrix,
327 ) -> Vec<f64> {
328 self.runtime
329 .predict_column_major_matrix(matrix, &self.executor)
330 }
331}
332
333impl Model {
334 pub fn predict_table(&self, table: &dyn TableAccess) -> Vec<f64> {
335 match self {
336 Model::DecisionTreeClassifier(model) => model.predict_table(table),
337 Model::DecisionTreeRegressor(model) => model.predict_table(table),
338 Model::RandomForest(model) => model.predict_table(table),
339 Model::GradientBoostedTrees(model) => model.predict_table(table),
340 }
341 }
342
343 pub fn predict_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<f64>, PredictError> {
344 let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
345 Ok(self.predict_table(&table))
346 }
347
348 pub fn predict_proba_table(
349 &self,
350 table: &dyn TableAccess,
351 ) -> Result<Vec<Vec<f64>>, PredictError> {
352 match self {
353 Model::DecisionTreeClassifier(model) => Ok(model.predict_proba_table(table)),
354 Model::RandomForest(model) => model.predict_proba_table(table),
355 Model::GradientBoostedTrees(model) => model.predict_proba_table(table),
356 Model::DecisionTreeRegressor(_) => {
357 Err(PredictError::ProbabilityPredictionRequiresClassification)
358 }
359 }
360 }
361
362 pub fn predict_proba_rows(&self, rows: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PredictError> {
363 let table = InferenceTable::from_rows(rows, self.feature_preprocessing())?;
364 self.predict_proba_table(&table)
365 }
366
367 pub fn predict_named_columns(
368 &self,
369 columns: BTreeMap<String, Vec<f64>>,
370 ) -> Result<Vec<f64>, PredictError> {
371 let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
372 Ok(self.predict_table(&table))
373 }
374
375 pub fn predict_proba_named_columns(
376 &self,
377 columns: BTreeMap<String, Vec<f64>>,
378 ) -> Result<Vec<Vec<f64>>, PredictError> {
379 let table = InferenceTable::from_named_columns(columns, self.feature_preprocessing())?;
380 self.predict_proba_table(&table)
381 }
382
383 pub fn predict_sparse_binary_columns(
384 &self,
385 n_rows: usize,
386 n_features: usize,
387 columns: Vec<Vec<usize>>,
388 ) -> Result<Vec<f64>, PredictError> {
389 let table = InferenceTable::from_sparse_binary_columns(
390 n_rows,
391 n_features,
392 columns,
393 self.feature_preprocessing(),
394 )?;
395 Ok(self.predict_table(&table))
396 }
397
398 pub fn predict_proba_sparse_binary_columns(
399 &self,
400 n_rows: usize,
401 n_features: usize,
402 columns: Vec<Vec<usize>>,
403 ) -> Result<Vec<Vec<f64>>, PredictError> {
404 let table = InferenceTable::from_sparse_binary_columns(
405 n_rows,
406 n_features,
407 columns,
408 self.feature_preprocessing(),
409 )?;
410 self.predict_proba_table(&table)
411 }
412
413 #[cfg(feature = "polars")]
414 pub fn predict_polars_dataframe(&self, df: &DataFrame) -> Result<Vec<f64>, PredictError> {
415 let columns = polars_named_columns(df)?;
416 self.predict_named_columns(columns)
417 }
418
419 #[cfg(feature = "polars")]
420 pub fn predict_polars_lazyframe(&self, lf: &LazyFrame) -> Result<Vec<f64>, PredictError> {
421 let mut predictions = Vec::new();
422 let mut offset = 0i64;
423 loop {
424 let batch = lf
425 .clone()
426 .slice(offset, LAZYFRAME_PREDICT_BATCH_ROWS as IdxSize)
427 .collect()?;
428 let height = batch.height();
429 if height == 0 {
430 break;
431 }
432 predictions.extend(self.predict_polars_dataframe(&batch)?);
433 if height < LAZYFRAME_PREDICT_BATCH_ROWS {
434 break;
435 }
436 offset += height as i64;
437 }
438 Ok(predictions)
439 }
440
441 pub fn algorithm(&self) -> TrainAlgorithm {
442 match self {
443 Model::DecisionTreeClassifier(_) | Model::DecisionTreeRegressor(_) => {
444 TrainAlgorithm::Dt
445 }
446 Model::RandomForest(_) => TrainAlgorithm::Rf,
447 Model::GradientBoostedTrees(_) => TrainAlgorithm::Gbm,
448 }
449 }
450
451 pub fn task(&self) -> Task {
452 match self {
453 Model::DecisionTreeRegressor(_) => Task::Regression,
454 Model::DecisionTreeClassifier(_) => Task::Classification,
455 Model::RandomForest(model) => model.task(),
456 Model::GradientBoostedTrees(model) => model.task(),
457 }
458 }
459
460 pub fn criterion(&self) -> Criterion {
461 match self {
462 Model::DecisionTreeClassifier(model) => model.criterion(),
463 Model::DecisionTreeRegressor(model) => model.criterion(),
464 Model::RandomForest(model) => model.criterion(),
465 Model::GradientBoostedTrees(model) => model.criterion(),
466 }
467 }
468
469 pub fn tree_type(&self) -> TreeType {
470 match self {
471 Model::DecisionTreeClassifier(model) => match model.algorithm() {
472 DecisionTreeAlgorithm::Id3 => TreeType::Id3,
473 DecisionTreeAlgorithm::C45 => TreeType::C45,
474 DecisionTreeAlgorithm::Cart => TreeType::Cart,
475 DecisionTreeAlgorithm::Randomized => TreeType::Randomized,
476 DecisionTreeAlgorithm::Oblivious => TreeType::Oblivious,
477 },
478 Model::DecisionTreeRegressor(model) => match model.algorithm() {
479 RegressionTreeAlgorithm::Cart => TreeType::Cart,
480 RegressionTreeAlgorithm::Randomized => TreeType::Randomized,
481 RegressionTreeAlgorithm::Oblivious => TreeType::Oblivious,
482 },
483 Model::RandomForest(model) => model.tree_type(),
484 Model::GradientBoostedTrees(model) => model.tree_type(),
485 }
486 }
487
488 pub fn mean_value(&self) -> Option<f64> {
489 match self {
490 Model::DecisionTreeClassifier(_)
491 | Model::DecisionTreeRegressor(_)
492 | Model::RandomForest(_)
493 | Model::GradientBoostedTrees(_) => None,
494 }
495 }
496
497 pub fn canaries(&self) -> usize {
498 self.training_metadata().canaries
499 }
500
501 pub fn max_depth(&self) -> Option<usize> {
502 self.training_metadata().max_depth
503 }
504
505 pub fn min_samples_split(&self) -> Option<usize> {
506 self.training_metadata().min_samples_split
507 }
508
509 pub fn min_samples_leaf(&self) -> Option<usize> {
510 self.training_metadata().min_samples_leaf
511 }
512
513 pub fn n_trees(&self) -> Option<usize> {
514 self.training_metadata().n_trees
515 }
516
517 pub fn max_features(&self) -> Option<usize> {
518 self.training_metadata().max_features
519 }
520
521 pub fn seed(&self) -> Option<u64> {
522 self.training_metadata().seed
523 }
524
525 pub fn compute_oob(&self) -> bool {
526 self.training_metadata().compute_oob
527 }
528
529 pub fn oob_score(&self) -> Option<f64> {
530 self.training_metadata().oob_score
531 }
532
533 pub fn learning_rate(&self) -> Option<f64> {
534 self.training_metadata().learning_rate
535 }
536
537 pub fn bootstrap(&self) -> bool {
538 self.training_metadata().bootstrap.unwrap_or(false)
539 }
540
541 pub fn top_gradient_fraction(&self) -> Option<f64> {
542 self.training_metadata().top_gradient_fraction
543 }
544
545 pub fn other_gradient_fraction(&self) -> Option<f64> {
546 self.training_metadata().other_gradient_fraction
547 }
548
549 pub fn tree_count(&self) -> usize {
550 self.to_ir().model.trees.len()
551 }
552
553 pub fn tree_structure(
554 &self,
555 tree_index: usize,
556 ) -> Result<TreeStructureSummary, IntrospectionError> {
557 tree_structure_summary(self.tree_definition(tree_index)?)
558 }
559
560 pub fn tree_prediction_stats(
561 &self,
562 tree_index: usize,
563 ) -> Result<PredictionValueStats, IntrospectionError> {
564 prediction_value_stats(self.tree_definition(tree_index)?)
565 }
566
567 pub fn tree_node(
568 &self,
569 tree_index: usize,
570 node_index: usize,
571 ) -> Result<ir::NodeTreeNode, IntrospectionError> {
572 match self.tree_definition(tree_index)? {
573 ir::TreeDefinition::NodeTree { nodes, .. } => {
574 let available = nodes.len();
575 nodes
576 .into_iter()
577 .nth(node_index)
578 .ok_or(IntrospectionError::NodeIndexOutOfBounds {
579 requested: node_index,
580 available,
581 })
582 }
583 ir::TreeDefinition::ObliviousLevels { .. } => Err(IntrospectionError::NotANodeTree),
584 }
585 }
586
587 pub fn tree_level(
588 &self,
589 tree_index: usize,
590 level_index: usize,
591 ) -> Result<ir::ObliviousLevel, IntrospectionError> {
592 match self.tree_definition(tree_index)? {
593 ir::TreeDefinition::ObliviousLevels { levels, .. } => {
594 let available = levels.len();
595 levels.into_iter().nth(level_index).ok_or(
596 IntrospectionError::LevelIndexOutOfBounds {
597 requested: level_index,
598 available,
599 },
600 )
601 }
602 ir::TreeDefinition::NodeTree { .. } => Err(IntrospectionError::NotAnObliviousTree),
603 }
604 }
605
606 pub fn tree_leaf(
607 &self,
608 tree_index: usize,
609 leaf_index: usize,
610 ) -> Result<ir::IndexedLeaf, IntrospectionError> {
611 match self.tree_definition(tree_index)? {
612 ir::TreeDefinition::ObliviousLevels { leaves, .. } => {
613 let available = leaves.len();
614 leaves
615 .into_iter()
616 .nth(leaf_index)
617 .ok_or(IntrospectionError::LeafIndexOutOfBounds {
618 requested: leaf_index,
619 available,
620 })
621 }
622 ir::TreeDefinition::NodeTree { nodes, .. } => {
623 let leaves = nodes
624 .into_iter()
625 .filter_map(|node| match node {
626 ir::NodeTreeNode::Leaf {
627 node_id,
628 leaf,
629 stats,
630 ..
631 } => Some(ir::IndexedLeaf {
632 leaf_index: node_id,
633 leaf,
634 stats: ir::NodeStats {
635 sample_count: stats.sample_count,
636 impurity: stats.impurity,
637 gain: stats.gain,
638 class_counts: stats.class_counts,
639 variance: stats.variance,
640 },
641 }),
642 _ => None,
643 })
644 .collect::<Vec<_>>();
645 let available = leaves.len();
646 leaves
647 .into_iter()
648 .nth(leaf_index)
649 .ok_or(IntrospectionError::LeafIndexOutOfBounds {
650 requested: leaf_index,
651 available,
652 })
653 }
654 }
655 }
656
657 pub fn to_ir(&self) -> ModelPackageIr {
658 ir::model_to_ir(self)
659 }
660
661 pub fn to_ir_json(&self) -> Result<String, serde_json::Error> {
662 serde_json::to_string(&self.to_ir())
663 }
664
665 pub fn to_ir_json_pretty(&self) -> Result<String, serde_json::Error> {
666 serde_json::to_string_pretty(&self.to_ir())
667 }
668
669 pub fn serialize(&self) -> Result<String, serde_json::Error> {
670 self.to_ir_json()
671 }
672
673 pub fn serialize_pretty(&self) -> Result<String, serde_json::Error> {
674 self.to_ir_json_pretty()
675 }
676
677 pub fn optimize_inference(
678 &self,
679 physical_cores: Option<usize>,
680 ) -> Result<OptimizedModel, OptimizeError> {
681 OptimizedModel::new(self.clone(), physical_cores)
682 }
683
684 pub fn json_schema() -> schemars::schema::RootSchema {
685 ModelPackageIr::json_schema()
686 }
687
688 pub fn json_schema_json() -> Result<String, IrError> {
689 ModelPackageIr::json_schema_json()
690 }
691
692 pub fn json_schema_json_pretty() -> Result<String, IrError> {
693 ModelPackageIr::json_schema_json_pretty()
694 }
695
696 pub fn deserialize(serialized: &str) -> Result<Self, IrError> {
697 let ir: ModelPackageIr =
698 serde_json::from_str(serialized).map_err(|err| IrError::Json(err.to_string()))?;
699 ir::model_from_ir(ir)
700 }
701
702 pub(crate) fn num_features(&self) -> usize {
703 match self {
704 Model::DecisionTreeClassifier(model) => model.num_features(),
705 Model::DecisionTreeRegressor(model) => model.num_features(),
706 Model::RandomForest(model) => model.num_features(),
707 Model::GradientBoostedTrees(model) => model.num_features(),
708 }
709 }
710
711 pub(crate) fn feature_preprocessing(&self) -> &[FeaturePreprocessing] {
712 match self {
713 Model::DecisionTreeClassifier(model) => model.feature_preprocessing(),
714 Model::DecisionTreeRegressor(model) => model.feature_preprocessing(),
715 Model::RandomForest(model) => model.feature_preprocessing(),
716 Model::GradientBoostedTrees(model) => model.feature_preprocessing(),
717 }
718 }
719
720 pub fn used_feature_indices(&self) -> Vec<usize> {
721 model_used_feature_indices(self)
722 }
723
724 pub fn used_feature_count(&self) -> usize {
725 self.used_feature_indices().len()
726 }
727
728 pub(crate) fn class_labels(&self) -> Option<Vec<f64>> {
729 match self {
730 Model::DecisionTreeClassifier(model) => Some(model.class_labels().to_vec()),
731 Model::RandomForest(model) => model.class_labels(),
732 Model::GradientBoostedTrees(model) => model.class_labels(),
733 Model::DecisionTreeRegressor(_) => None,
734 }
735 }
736
737 pub(crate) fn training_metadata(&self) -> ir::TrainingMetadata {
738 match self {
739 Model::DecisionTreeClassifier(model) => model.training_metadata(),
740 Model::DecisionTreeRegressor(model) => model.training_metadata(),
741 Model::RandomForest(model) => model.training_metadata(),
742 Model::GradientBoostedTrees(model) => model.training_metadata(),
743 }
744 }
745
746 fn tree_definition(&self, tree_index: usize) -> Result<ir::TreeDefinition, IntrospectionError> {
747 let trees = self.to_ir().model.trees;
748 let available = trees.len();
749 trees
750 .into_iter()
751 .nth(tree_index)
752 .ok_or(IntrospectionError::TreeIndexOutOfBounds {
753 requested: tree_index,
754 available,
755 })
756 }
757}