1use crate::tree::classifier::{
13 DecisionTreeAlgorithm, DecisionTreeClassifier, DecisionTreeOptions,
14 ObliviousSplit as ClassifierObliviousSplit, TreeNode as ClassifierTreeNode,
15 TreeStructure as ClassifierTreeStructure,
16};
17use crate::tree::regressor::{
18 DecisionTreeRegressor, ObliviousSplit as RegressorObliviousSplit, RegressionNode,
19 RegressionTreeAlgorithm, RegressionTreeOptions, RegressionTreeStructure,
20};
21use crate::{
22 Criterion, FeaturePreprocessing, GradientBoostedTrees, InputFeatureKind, Model,
23 NumericBinBoundary, RandomForest, Task, TrainAlgorithm, TreeType,
24};
25use schemars::schema::RootSchema;
26use schemars::{JsonSchema, schema_for};
27use serde::{Deserialize, Serialize};
28use std::fmt::{Display, Formatter};
29
/// Version of the IR document schema; packages with any other value are rejected on load.
const IR_VERSION: &str = "1.0.0";
/// Format identifier stamped into every package so readers can reject foreign documents.
const FORMAT_NAME: &str = "forestfire-ir";
32
/// Root document of the `forestfire-ir` model interchange format.
///
/// Bundles the trained model, its input/output schemas, pre/postprocessing,
/// training metadata, and integrity info into one serializable unit.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ModelPackageIr {
    /// IR schema version; validated against `IR_VERSION` when deserializing.
    pub ir_version: String,
    /// Format identifier; validated against `FORMAT_NAME` when deserializing.
    pub format_name: String,
    pub producer: ProducerMetadata,
    pub model: ModelSection,
    pub input_schema: InputSchema,
    pub output_schema: OutputSchema,
    pub inference_options: InferenceOptions,
    pub preprocessing: PreprocessingSection,
    pub postprocessing: PostprocessingSection,
    pub training_metadata: TrainingMetadata,
    pub integrity: IntegritySection,
}
48
/// Identifies the library build that produced the package (for provenance only).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProducerMetadata {
    pub library: String,
    pub library_version: String,
    pub language: String,
    /// CPU architecture the producer ran on (from `std::env::consts::ARCH`).
    pub platform: String,
}
56
/// The model itself: identification fields, capability flags, the serialized
/// trees, and the rule for combining tree outputs.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ModelSection {
    pub algorithm: String,
    pub task: String,
    pub tree_type: String,
    /// Either "node_tree" or "oblivious_levels"; mirrors the tree variants below.
    pub representation: String,
    pub num_features: usize,
    pub num_outputs: usize,
    /// Must be false in IR v1 — rejected by header validation otherwise.
    pub supports_missing: bool,
    /// Must be false in IR v1 — rejected by header validation otherwise.
    pub supports_categorical: bool,
    pub is_ensemble: bool,
    pub trees: Vec<TreeDefinition>,
    pub aggregation: Aggregation,
}
72
/// A single tree in one of the two supported serialized layouts.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "representation", rename_all = "snake_case")]
pub enum TreeDefinition {
    /// Pointer-style tree: an arena of nodes entered at `root_node_id`.
    NodeTree {
        tree_id: usize,
        weight: f64,
        root_node_id: usize,
        nodes: Vec<NodeTreeNode>,
    },
    /// Oblivious tree: one shared split per depth level plus an indexed leaf table.
    ObliviousLevels {
        tree_id: usize,
        weight: f64,
        depth: usize,
        levels: Vec<ObliviousLevel>,
        leaf_indexing: LeafIndexing,
        leaves: Vec<IndexedLeaf>,
    },
}
92
/// One node of a `NodeTree`, addressed by its `node_id` within the arena.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum NodeTreeNode {
    /// Terminal node carrying the prediction payload.
    Leaf {
        node_id: usize,
        depth: usize,
        leaf: LeafPayload,
        stats: NodeStats,
    },
    /// Two-way split; `children` hold arena ids of the left/right subtrees.
    BinaryBranch {
        node_id: usize,
        depth: usize,
        split: BinarySplit,
        children: BinaryChildren,
        stats: NodeStats,
    },
    /// One child per matched bin; `unmatched_leaf` is the fallback prediction
    /// for bins without an explicit branch.
    MultiwayBranch {
        node_id: usize,
        depth: usize,
        split: MultiwaySplit,
        branches: Vec<MultiwayBranch>,
        unmatched_leaf: LeafPayload,
        stats: NodeStats,
    },
}
118
/// Arena node ids of the two children of a binary branch.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BinaryChildren {
    pub left: usize,
    pub right: usize,
}
124
/// One outgoing edge of a multiway branch: the matched bin and the child node id.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MultiwayBranch {
    pub bin: u16,
    pub child: usize,
}
130
/// Test evaluated at a binary branch.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "split_type", rename_all = "snake_case")]
pub enum BinarySplit {
    /// Compare the feature's bin index against `threshold_bin`.
    NumericBinThreshold {
        feature_index: usize,
        feature_name: String,
        operator: String,
        threshold_bin: u16,
        /// Raw-value upper bound of the threshold bin, when known.
        threshold_upper_bound: Option<f64>,
        comparison_dtype: String,
    },
    /// Route on a boolean feature; the semantics strings document which child
    /// receives false/true values.
    BooleanTest {
        feature_index: usize,
        feature_name: String,
        false_child_semantics: String,
        true_child_semantics: String,
    },
}
149
/// Feature description for a multiway branch (one child per bin).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MultiwaySplit {
    pub split_type: String,
    pub feature_index: usize,
    pub feature_name: String,
    pub comparison_dtype: String,
}
157
/// One depth level of an oblivious tree: a single split shared by every node
/// at that depth, plus level-wide statistics.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ObliviousLevel {
    pub level: usize,
    pub split: ObliviousSplit,
    pub stats: NodeStats,
}
164
/// Test evaluated at an oblivious level; the `bit_when_*` fields give the bit
/// contributed to the leaf index for each outcome.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "split_type", rename_all = "snake_case")]
pub enum ObliviousSplit {
    NumericBinThreshold {
        feature_index: usize,
        feature_name: String,
        operator: String,
        threshold_bin: u16,
        /// Raw-value upper bound of the threshold bin, when known.
        threshold_upper_bound: Option<f64>,
        comparison_dtype: String,
        bit_when_true: u8,
        bit_when_false: u8,
    },
    BooleanTest {
        feature_index: usize,
        feature_name: String,
        bit_when_false: u8,
        bit_when_true: u8,
    },
}
185
/// Documents how the per-level outcome bits combine into a leaf-table index.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct LeafIndexing {
    pub bit_order: String,
    pub index_formula: String,
}
191
/// One entry of an oblivious tree's leaf table, addressed by `leaf_index`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct IndexedLeaf {
    pub leaf_index: usize,
    pub leaf: LeafPayload,
    pub stats: NodeStats,
}
198
/// Prediction stored at a leaf.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "prediction_kind", rename_all = "snake_case")]
pub enum LeafPayload {
    /// Regression output value.
    RegressionValue {
        value: f64,
    },
    /// Classification output; `class_value` must equal the class-order entry
    /// at `class_index` (cross-checked on load).
    ClassIndex {
        class_index: usize,
        class_value: f64,
    },
}
210
/// How per-tree outputs combine into the final prediction.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Aggregation {
    pub kind: String,
    /// One weight per tree, in tree order.
    pub tree_weights: Vec<f64>,
    pub normalize_by_weight_sum: bool,
    /// Additive offset used by gradient-boosted models; absent otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_score: Option<f64>,
}
219
/// Shape and ordering of the feature vector the model consumes.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InputSchema {
    pub feature_count: usize,
    pub features: Vec<InputFeature>,
    pub ordering: String,
    pub input_tensor_layout: String,
    pub accepts_feature_names: bool,
}
228
/// Description of a single input feature at position `index`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InputFeature {
    pub index: usize,
    pub name: String,
    pub dtype: String,
    pub logical_type: String,
    pub nullable: bool,
}
237
/// Raw and post-processed outputs of the model.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OutputSchema {
    pub raw_outputs: Vec<OutputField>,
    pub final_outputs: Vec<OutputField>,
    /// Ordered class labels for classification; preferred source of labels on
    /// load (training metadata is the fallback).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_order: Option<Vec<f64>>,
}
245
/// One named output tensor of the model.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OutputField {
    pub name: String,
    pub kind: String,
    pub shape: Vec<usize>,
    pub dtype: String,
}
253
/// Numeric and semantic conventions a runtime must honor to reproduce the
/// producer's predictions; validated on load.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InferenceOptions {
    pub numeric_precision: String,
    /// Must be "leq_left_gt_right" in IR v1.
    pub threshold_comparison: String,
    /// Must be "not_supported" in IR v1.
    pub nan_policy: String,
    pub bool_encoding: BoolEncoding,
    pub tie_breaking: TieBreaking,
    pub determinism: Determinism,
}
263
/// Accepted textual encodings of boolean feature values.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BoolEncoding {
    pub false_values: Vec<String>,
    pub true_values: Vec<String>,
}
269
/// Tie-breaking conventions for classification and argmax operations.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TieBreaking {
    pub classification: String,
    pub argmax: String,
}
275
/// Whether inference is deterministic, with free-text caveats.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Determinism {
    pub guaranteed: bool,
    pub notes: String,
}
281
/// Preprocessing that must run before split evaluation (numeric binning).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PreprocessingSection {
    pub included_in_model: bool,
    pub numeric_binning: NumericBinning,
    pub notes: String,
}
288
/// Per-feature binning artifacts; on load every schema feature must appear here.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinning {
    pub kind: String,
    pub features: Vec<FeatureBinning>,
}
294
/// Binning rule for one input feature.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeatureBinning {
    /// Numeric feature discretized by the listed bin boundaries.
    Numeric {
        feature_index: usize,
        boundaries: Vec<NumericBinBoundary>,
    },
    /// Binary feature; no boundaries needed.
    Binary {
        feature_index: usize,
    },
}
306
/// Transformations applied to raw model output to obtain the final prediction.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PostprocessingSection {
    pub raw_output_kind: String,
    pub steps: Vec<PostprocessingStep>,
}
312
/// A single postprocessing operation, applied in order.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum PostprocessingStep {
    /// Raw output passes through unchanged.
    Identity,
    /// Replace a predicted class index with the label at that index.
    MapClassIndexToLabel { labels: Vec<f64> },
}
319
/// Hyperparameters and training-time results, kept so a model can be
/// reconstructed (and retrained comparably) from the IR. Optional fields are
/// omitted from JSON when absent; loaders substitute defaults.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TrainingMetadata {
    pub algorithm: String,
    pub task: String,
    pub tree_type: String,
    pub criterion: String,
    pub canaries: usize,
    pub compute_oob: bool,
    // Per-tree growth limits (see `tree_options` for loader defaults).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_depth: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_samples_split: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_samples_leaf: Option<usize>,
    // Ensemble-level settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n_trees: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_features: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oob_score: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_labels: Option<Vec<f64>>,
    // Gradient-boosting specific settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub learning_rate: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_gradient_fraction: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub other_gradient_fraction: Option<f64>,
}
354
/// Serialization conventions and runtime-compatibility requirements.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct IntegritySection {
    pub serialization: String,
    pub canonical_json: bool,
    pub compatibility: Compatibility,
}
361
/// Minimum runtime version and the capability strings a consumer must support.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Compatibility {
    pub minimum_runtime_version: String,
    pub required_capabilities: Vec<String>,
}
367
/// Training statistics attached to a node, level, or leaf. All fields except
/// `sample_count` are optional; loaders fall back to zeros when absent.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct NodeStats {
    pub sample_count: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub impurity: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gain: Option<f64>,
    /// Classification only: per-class sample counts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_counts: Option<Vec<usize>>,
    /// Regression only: target variance at the node.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub variance: Option<f64>,
}
380
/// Errors produced while serializing, deserializing, or validating the IR.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IrError {
    UnsupportedIrVersion(String),
    UnsupportedFormatName(String),
    UnsupportedAlgorithm(String),
    UnsupportedTask(String),
    UnsupportedTreeType(String),
    InvalidTreeCount(usize),
    UnsupportedRepresentation(String),
    InvalidFeatureCount { schema: usize, preprocessing: usize },
    MissingClassLabels,
    InvalidLeaf(String),
    InvalidNode(String),
    InvalidPreprocessing(String),
    InvalidInferenceOption(String),
    Json(String),
}

impl Display for IrError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Render each variant to its full message first, then emit it with a
        // single formatter call; the texts are identical for every variant.
        let rendered = match self {
            IrError::UnsupportedIrVersion(version) => {
                format!("Unsupported IR version: {}.", version)
            }
            IrError::UnsupportedFormatName(name) => format!("Unsupported IR format: {}.", name),
            IrError::UnsupportedAlgorithm(algorithm) => {
                format!("Unsupported algorithm: {}.", algorithm)
            }
            IrError::UnsupportedTask(task) => format!("Unsupported task: {}.", task),
            IrError::UnsupportedTreeType(tree_type) => {
                format!("Unsupported tree type: {}.", tree_type)
            }
            IrError::InvalidTreeCount(count) => {
                format!("Expected exactly one tree in the IR, found {}.", count)
            }
            IrError::UnsupportedRepresentation(representation) => {
                format!("Unsupported tree representation: {}.", representation)
            }
            IrError::InvalidFeatureCount {
                schema,
                preprocessing,
            } => format!(
                "Input schema declares {} features, but preprocessing declares {}.",
                schema, preprocessing
            ),
            IrError::MissingClassLabels => {
                "Classification IR requires explicit class labels.".to_string()
            }
            IrError::InvalidLeaf(message) => format!("Invalid leaf payload: {}.", message),
            IrError::InvalidNode(message) => format!("Invalid tree node: {}.", message),
            IrError::InvalidPreprocessing(message) => {
                format!("Invalid preprocessing section: {}.", message)
            }
            IrError::InvalidInferenceOption(message) => {
                format!("Invalid inference options: {}.", message)
            }
            IrError::Json(message) => format!("Invalid JSON: {}.", message),
        };
        f.write_str(&rendered)
    }
}

impl std::error::Error for IrError {}
446
447impl ModelPackageIr {
448 pub fn json_schema() -> RootSchema {
449 schema_for!(ModelPackageIr)
450 }
451
452 pub fn json_schema_json() -> Result<String, IrError> {
453 serde_json::to_string(&Self::json_schema()).map_err(|err| IrError::Json(err.to_string()))
454 }
455
456 pub fn json_schema_json_pretty() -> Result<String, IrError> {
457 serde_json::to_string_pretty(&Self::json_schema())
458 .map_err(|err| IrError::Json(err.to_string()))
459 }
460}
461
/// Serializes a trained [`Model`] into a self-describing IR package.
///
/// Ensembles contribute one `TreeDefinition` per member tree; single-tree
/// models are wrapped as a one-element list. The aggregation rule, the
/// representation tag, and all schema/metadata sections are derived from the
/// model so the emitted package round-trips through `model_from_ir`.
pub(crate) fn model_to_ir(model: &Model) -> ModelPackageIr {
    // Flatten the model into per-tree definitions.
    let trees = match model {
        Model::RandomForest(forest) => forest
            .trees()
            .iter()
            .map(model_tree_definition)
            .collect::<Vec<_>>(),
        Model::GradientBoostedTrees(boosted) => boosted
            .trees()
            .iter()
            .map(model_tree_definition)
            .collect::<Vec<_>>(),
        // Single-tree models become a one-element tree list.
        _ => vec![model_tree_definition(model)],
    };
    // Representation is read off the first serialized tree; with no trees it
    // falls back to what the tree type implies.
    let representation = if let Some(first_tree) = trees.first() {
        match first_tree {
            TreeDefinition::NodeTree { .. } => "node_tree",
            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels",
        }
    } else {
        match model.tree_type() {
            TreeType::Oblivious => "oblivious_levels",
            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized => "node_tree",
        }
    };
    let class_labels = model.class_labels();
    let is_ensemble = matches!(
        model,
        Model::RandomForest(_) | Model::GradientBoostedTrees(_)
    );
    let tree_count = trees.len();
    // Pick the aggregation rule and per-tree weights: forests average with
    // unit weights, boosted models sum weighted outputs over a base score,
    // and a lone tree passes through unchanged.
    let (aggregation_kind, tree_weights, normalize_by_weight_sum, base_score) = match model {
        Model::RandomForest(_) => (
            match model.task() {
                Task::Regression => "average",
                Task::Classification => "average_class_probabilities",
            },
            vec![1.0; tree_count],
            true,
            None,
        ),
        Model::GradientBoostedTrees(boosted) => (
            match boosted.task() {
                Task::Regression => "sum_tree_outputs",
                Task::Classification => "sum_tree_outputs_then_sigmoid",
            },
            boosted.tree_weights().to_vec(),
            false,
            Some(boosted.base_score()),
        ),
        _ => ("identity_single_tree", vec![1.0; tree_count], true, None),
    };

    ModelPackageIr {
        ir_version: IR_VERSION.to_string(),
        format_name: FORMAT_NAME.to_string(),
        producer: ProducerMetadata {
            library: "forestfire-core".to_string(),
            // Version of this crate at compile time.
            library_version: env!("CARGO_PKG_VERSION").to_string(),
            language: "rust".to_string(),
            platform: std::env::consts::ARCH.to_string(),
        },
        model: ModelSection {
            algorithm: algorithm_name(model.algorithm()).to_string(),
            task: task_name(model.task()).to_string(),
            tree_type: tree_type_name(model.tree_type()).to_string(),
            representation: representation.to_string(),
            num_features: model.num_features(),
            num_outputs: 1,
            // IR v1 capability flags: neither missing values nor categorical
            // features are supported (enforced again on load).
            supports_missing: false,
            supports_categorical: false,
            is_ensemble,
            trees,
            aggregation: Aggregation {
                kind: aggregation_kind.to_string(),
                tree_weights,
                normalize_by_weight_sum,
                base_score,
            },
        },
        input_schema: input_schema(model),
        output_schema: output_schema(model, class_labels.clone()),
        inference_options: InferenceOptions {
            numeric_precision: "float64".to_string(),
            // These two values are the only ones `model_from_ir` accepts.
            threshold_comparison: "leq_left_gt_right".to_string(),
            nan_policy: "not_supported".to_string(),
            bool_encoding: BoolEncoding {
                false_values: vec!["0".to_string(), "false".to_string()],
                true_values: vec!["1".to_string(), "true".to_string()],
            },
            tie_breaking: TieBreaking {
                classification: "lowest_class_index".to_string(),
                argmax: "first_max_index".to_string(),
            },
            determinism: Determinism {
                guaranteed: true,
                notes: "Inference is deterministic when the serialized preprocessing artifacts are applied before split evaluation."
                    .to_string(),
            },
        },
        preprocessing: preprocessing(model),
        postprocessing: postprocessing(model, class_labels),
        training_metadata: model.training_metadata(),
        integrity: IntegritySection {
            serialization: "json".to_string(),
            canonical_json: true,
            compatibility: Compatibility {
                minimum_runtime_version: IR_VERSION.to_string(),
                required_capabilities: required_capabilities(model, representation),
            },
        },
    }
}
575
/// Reconstructs a [`Model`] from a deserialized IR package.
///
/// Validates the header and inference options first, then rebuilds either a
/// random forest, a gradient-boosted ensemble, or a single tree depending on
/// the declared training algorithm. Missing optional hyperparameters are
/// replaced with defaults.
///
/// # Errors
/// Returns an [`IrError`] when the header, options, enum strings, tree count,
/// preprocessing, or tree structures are invalid or unsupported.
pub(crate) fn model_from_ir(ir: ModelPackageIr) -> Result<Model, IrError> {
    validate_ir_header(&ir)?;
    validate_inference_options(&ir.inference_options)?;

    // Parse the stringly-typed fields into internal enums up front.
    let algorithm = parse_algorithm(&ir.model.algorithm)?;
    let task = parse_task(&ir.model.task)?;
    let tree_type = parse_tree_type(&ir.model.tree_type)?;
    let criterion = parse_criterion(&ir.training_metadata.criterion)?;
    let feature_preprocessing = feature_preprocessing_from_ir(&ir)?;
    let num_features = ir.input_schema.feature_count;
    let options = tree_options(&ir.training_metadata);
    let training_canaries = ir.training_metadata.canaries;
    // Labels are optional here: only classification models require them, and
    // that requirement is enforced later when rebuilding classifier trees.
    let deserialized_class_labels = classification_labels(&ir).ok();

    // A plain decision tree must be serialized as exactly one tree.
    if algorithm == TrainAlgorithm::Dt && ir.model.trees.len() != 1 {
        return Err(IrError::InvalidTreeCount(ir.model.trees.len()));
    }

    if algorithm == TrainAlgorithm::Rf {
        // Rebuild each member tree, then wrap them in a forest using the
        // stored (or defaulted) ensemble hyperparameters.
        let trees = ir
            .model
            .trees
            .into_iter()
            .map(|tree| {
                single_model_from_ir_parts(
                    task,
                    tree_type,
                    criterion,
                    feature_preprocessing.clone(),
                    num_features,
                    options,
                    training_canaries,
                    deserialized_class_labels.clone(),
                    tree,
                )
            })
            .collect::<Result<Vec<_>, _>>()?;
        return Ok(Model::RandomForest(RandomForest::new(
            task,
            criterion,
            tree_type,
            trees,
            ir.training_metadata.compute_oob,
            ir.training_metadata.oob_score,
            ir.training_metadata
                .max_features
                .unwrap_or(num_features.max(1)),
            ir.training_metadata.seed,
            num_features,
            feature_preprocessing,
        )));
    }

    if algorithm == TrainAlgorithm::Gbm {
        let tree_weights = ir.model.aggregation.tree_weights.clone();
        let base_score = ir.model.aggregation.base_score.unwrap_or(0.0);
        let trees = ir
            .model
            .trees
            .into_iter()
            .map(|tree| {
                single_model_from_ir_parts(
                    task,
                    tree_type,
                    criterion,
                    feature_preprocessing.clone(),
                    num_features,
                    options,
                    training_canaries,
                    deserialized_class_labels.clone(),
                    tree,
                )
            })
            .collect::<Result<Vec<_>, _>>()?;
        return Ok(Model::GradientBoostedTrees(GradientBoostedTrees::new(
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            // Defaults for boosting hyperparameters absent from the metadata.
            ir.training_metadata.learning_rate.unwrap_or(0.1),
            ir.training_metadata.bootstrap.unwrap_or(false),
            ir.training_metadata.top_gradient_fraction.unwrap_or(0.2),
            ir.training_metadata.other_gradient_fraction.unwrap_or(0.1),
            ir.training_metadata
                .max_features
                .unwrap_or(num_features.max(1)),
            ir.training_metadata.seed,
            num_features,
            feature_preprocessing,
            deserialized_class_labels,
            training_canaries,
        )));
    }

    // Dt: the single-tree count was validated above, so `next()` must succeed.
    let tree = ir
        .model
        .trees
        .into_iter()
        .next()
        .expect("validated single tree");

    single_model_from_ir_parts(
        task,
        tree_type,
        criterion,
        feature_preprocessing,
        num_features,
        options,
        training_canaries,
        deserialized_class_labels,
        tree,
    )
}
690
/// Rebuilds one tree model from an IR `TreeDefinition` plus the shared
/// (already-parsed) training context.
///
/// Dispatches on the (task, tree type, representation) triple:
/// - classification + node tree types -> standard classifier structure,
/// - classification + oblivious -> oblivious classifier structure,
/// - regression + cart/randomized -> standard regressor structure,
/// - regression + oblivious -> oblivious regressor structure.
/// Any other combination is reported as an unsupported representation.
///
/// # Errors
/// [`IrError::MissingClassLabels`] for classification without labels, node or
/// leaf errors from the rebuild helpers, or
/// [`IrError::UnsupportedRepresentation`] for invalid combinations.
#[allow(clippy::too_many_arguments)]
fn single_model_from_ir_parts(
    task: Task,
    tree_type: TreeType,
    criterion: Criterion,
    feature_preprocessing: Vec<FeaturePreprocessing>,
    num_features: usize,
    options: DecisionTreeOptions,
    training_canaries: usize,
    deserialized_class_labels: Option<Vec<f64>>,
    tree: TreeDefinition,
) -> Result<Model, IrError> {
    match (task, tree_type, tree) {
        (
            Task::Classification,
            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized,
            TreeDefinition::NodeTree {
                nodes,
                root_node_id,
                ..
            },
        ) => {
            // Classification requires explicit labels to decode leaf indices.
            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
            let structure = ClassifierTreeStructure::Standard {
                nodes: rebuild_classifier_nodes(nodes, &class_labels)?,
                root: root_node_id,
            };
            Ok(Model::DecisionTreeClassifier(
                DecisionTreeClassifier::from_ir_parts(
                    // Map the generic tree type onto the classifier algorithm;
                    // the Oblivious arm is excluded by the outer pattern.
                    match tree_type {
                        TreeType::Id3 => DecisionTreeAlgorithm::Id3,
                        TreeType::C45 => DecisionTreeAlgorithm::C45,
                        TreeType::Cart => DecisionTreeAlgorithm::Cart,
                        TreeType::Randomized => DecisionTreeAlgorithm::Randomized,
                        TreeType::Oblivious => unreachable!(),
                    },
                    criterion,
                    class_labels,
                    structure,
                    options,
                    num_features,
                    feature_preprocessing,
                    training_canaries,
                ),
            ))
        }
        (
            Task::Classification,
            TreeType::Oblivious,
            TreeDefinition::ObliviousLevels { levels, leaves, .. },
        ) => {
            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
            // Derive the leaf statistics before `leaves` is consumed below.
            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
            let leaf_class_counts =
                rebuild_classifier_leaf_class_counts(&leaves, class_labels.len())?;
            let structure = ClassifierTreeStructure::Oblivious {
                splits: rebuild_classifier_oblivious_splits(levels)?,
                leaf_class_indices: rebuild_classifier_leaf_indices(leaves, &class_labels)?,
                leaf_sample_counts,
                leaf_class_counts,
            };
            Ok(Model::DecisionTreeClassifier(
                DecisionTreeClassifier::from_ir_parts(
                    DecisionTreeAlgorithm::Oblivious,
                    criterion,
                    class_labels,
                    structure,
                    options,
                    num_features,
                    feature_preprocessing,
                    training_canaries,
                ),
            ))
        }
        (
            Task::Regression,
            TreeType::Cart | TreeType::Randomized,
            TreeDefinition::NodeTree {
                nodes,
                root_node_id,
                ..
            },
        ) => Ok(Model::DecisionTreeRegressor(
            DecisionTreeRegressor::from_ir_parts(
                match tree_type {
                    TreeType::Cart => RegressionTreeAlgorithm::Cart,
                    TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
                    _ => unreachable!(),
                },
                criterion,
                RegressionTreeStructure::Standard {
                    nodes: rebuild_regressor_nodes(nodes)?,
                    root: root_node_id,
                },
                // Regressor options carry over the classifier growth limits;
                // max_features/seed are not round-tripped per tree.
                RegressionTreeOptions {
                    max_depth: options.max_depth,
                    min_samples_split: options.min_samples_split,
                    min_samples_leaf: options.min_samples_leaf,
                    max_features: None,
                    random_seed: 0,
                },
                num_features,
                feature_preprocessing,
                training_canaries,
            ),
        )),
        (
            Task::Regression,
            TreeType::Oblivious,
            TreeDefinition::ObliviousLevels { levels, leaves, .. },
        ) => {
            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
            let leaf_variances = rebuild_leaf_variances(&leaves)?;
            Ok(Model::DecisionTreeRegressor(
                DecisionTreeRegressor::from_ir_parts(
                    RegressionTreeAlgorithm::Oblivious,
                    criterion,
                    RegressionTreeStructure::Oblivious {
                        splits: rebuild_regressor_oblivious_splits(levels)?,
                        leaf_values: rebuild_regressor_leaf_values(leaves)?,
                        leaf_sample_counts,
                        leaf_variances,
                    },
                    RegressionTreeOptions {
                        max_depth: options.max_depth,
                        min_samples_split: options.min_samples_split,
                        min_samples_leaf: options.min_samples_leaf,
                        max_features: None,
                        random_seed: 0,
                    },
                    num_features,
                    feature_preprocessing,
                    training_canaries,
                ),
            ))
        }
        // Every remaining combination is unsupported; report which tree
        // representation was supplied.
        (_, _, tree) => Err(IrError::UnsupportedRepresentation(match tree {
            TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
        })),
    }
}
833
834fn validate_ir_header(ir: &ModelPackageIr) -> Result<(), IrError> {
835 if ir.ir_version != IR_VERSION {
836 return Err(IrError::UnsupportedIrVersion(ir.ir_version.clone()));
837 }
838 if ir.format_name != FORMAT_NAME {
839 return Err(IrError::UnsupportedFormatName(ir.format_name.clone()));
840 }
841 if ir.model.supports_missing {
842 return Err(IrError::InvalidInferenceOption(
843 "missing values are not supported in IR v1".to_string(),
844 ));
845 }
846 if ir.model.supports_categorical {
847 return Err(IrError::InvalidInferenceOption(
848 "categorical features are not supported in IR v1".to_string(),
849 ));
850 }
851 Ok(())
852}
853
854fn validate_inference_options(options: &InferenceOptions) -> Result<(), IrError> {
855 if options.threshold_comparison != "leq_left_gt_right" {
856 return Err(IrError::InvalidInferenceOption(format!(
857 "unsupported threshold comparison '{}'",
858 options.threshold_comparison
859 )));
860 }
861 if options.nan_policy != "not_supported" {
862 return Err(IrError::InvalidInferenceOption(format!(
863 "unsupported nan policy '{}'",
864 options.nan_policy
865 )));
866 }
867 Ok(())
868}
869
870fn parse_algorithm(value: &str) -> Result<TrainAlgorithm, IrError> {
871 match value {
872 "dt" => Ok(TrainAlgorithm::Dt),
873 "rf" => Ok(TrainAlgorithm::Rf),
874 "gbm" => Ok(TrainAlgorithm::Gbm),
875 _ => Err(IrError::UnsupportedAlgorithm(value.to_string())),
876 }
877}
878
879fn parse_task(value: &str) -> Result<Task, IrError> {
880 match value {
881 "regression" => Ok(Task::Regression),
882 "classification" => Ok(Task::Classification),
883 _ => Err(IrError::UnsupportedTask(value.to_string())),
884 }
885}
886
887fn parse_tree_type(value: &str) -> Result<TreeType, IrError> {
888 match value {
889 "id3" => Ok(TreeType::Id3),
890 "c45" => Ok(TreeType::C45),
891 "cart" => Ok(TreeType::Cart),
892 "randomized" => Ok(TreeType::Randomized),
893 "oblivious" => Ok(TreeType::Oblivious),
894 _ => Err(IrError::UnsupportedTreeType(value.to_string())),
895 }
896}
897
898fn parse_criterion(value: &str) -> Result<crate::Criterion, IrError> {
899 match value {
900 "gini" => Ok(crate::Criterion::Gini),
901 "entropy" => Ok(crate::Criterion::Entropy),
902 "mean" => Ok(crate::Criterion::Mean),
903 "median" => Ok(crate::Criterion::Median),
904 "second_order" => Ok(crate::Criterion::SecondOrder),
905 "auto" => Ok(crate::Criterion::Auto),
906 _ => Err(IrError::InvalidInferenceOption(format!(
907 "unsupported criterion '{}'",
908 value
909 ))),
910 }
911}
912
913fn tree_options(training: &TrainingMetadata) -> DecisionTreeOptions {
914 DecisionTreeOptions {
915 max_depth: training.max_depth.unwrap_or(8),
916 min_samples_split: training.min_samples_split.unwrap_or(2),
917 min_samples_leaf: training.min_samples_leaf.unwrap_or(1),
918 max_features: None,
919 random_seed: 0,
920 }
921}
922
923fn feature_preprocessing_from_ir(
924 ir: &ModelPackageIr,
925) -> Result<Vec<FeaturePreprocessing>, IrError> {
926 let mut features: Vec<Option<FeaturePreprocessing>> = vec![None; ir.input_schema.feature_count];
927
928 for feature in &ir.preprocessing.numeric_binning.features {
929 match feature {
930 FeatureBinning::Numeric {
931 feature_index,
932 boundaries,
933 } => {
934 let slot = features.get_mut(*feature_index).ok_or_else(|| {
935 IrError::InvalidFeatureCount {
936 schema: ir.input_schema.feature_count,
937 preprocessing: feature_index + 1,
938 }
939 })?;
940 *slot = Some(FeaturePreprocessing::Numeric {
941 bin_boundaries: boundaries.clone(),
942 });
943 }
944 FeatureBinning::Binary { feature_index } => {
945 let slot = features.get_mut(*feature_index).ok_or_else(|| {
946 IrError::InvalidFeatureCount {
947 schema: ir.input_schema.feature_count,
948 preprocessing: feature_index + 1,
949 }
950 })?;
951 *slot = Some(FeaturePreprocessing::Binary);
952 }
953 }
954 }
955
956 if features.len() != ir.input_schema.feature_count {
957 return Err(IrError::InvalidFeatureCount {
958 schema: ir.input_schema.feature_count,
959 preprocessing: features.len(),
960 });
961 }
962
963 features
964 .into_iter()
965 .map(|feature| {
966 feature.ok_or_else(|| {
967 IrError::InvalidPreprocessing(
968 "every feature must have a preprocessing entry".to_string(),
969 )
970 })
971 })
972 .collect()
973}
974
975fn classification_labels(ir: &ModelPackageIr) -> Result<Vec<f64>, IrError> {
976 ir.output_schema
977 .class_order
978 .clone()
979 .or_else(|| ir.training_metadata.class_labels.clone())
980 .ok_or(IrError::MissingClassLabels)
981}
982
983fn classifier_class_index(leaf: &LeafPayload, class_labels: &[f64]) -> Result<usize, IrError> {
984 match leaf {
985 LeafPayload::ClassIndex {
986 class_index,
987 class_value,
988 } => {
989 let Some(expected) = class_labels.get(*class_index) else {
990 return Err(IrError::InvalidLeaf(format!(
991 "class index {} out of bounds",
992 class_index
993 )));
994 };
995 if expected.total_cmp(class_value).is_ne() {
996 return Err(IrError::InvalidLeaf(format!(
997 "class value {} does not match class order entry {}",
998 class_value, expected
999 )));
1000 }
1001 Ok(*class_index)
1002 }
1003 LeafPayload::RegressionValue { .. } => Err(IrError::InvalidLeaf(
1004 "expected class_index leaf".to_string(),
1005 )),
1006 }
1007}
1008
/// Rebuilds the classifier node arena from IR nodes.
///
/// Nodes are placed by `node_id` into a slot vector (via `assign_node`) and
/// then compacted (via `collect_nodes`), so the IR may list nodes in any
/// order. Missing `class_counts` default to zeros; missing impurity/gain to 0.
///
/// # Errors
/// Propagates leaf-payload validation errors plus whatever `assign_node` /
/// `collect_nodes` report (e.g. out-of-range or duplicate/missing ids).
fn rebuild_classifier_nodes(
    nodes: Vec<NodeTreeNode>,
    class_labels: &[f64],
) -> Result<Vec<ClassifierTreeNode>, IrError> {
    // Slot per node id; assign_node fills them, collect_nodes unwraps them.
    let mut rebuilt = vec![None; nodes.len()];
    for node in nodes {
        match node {
            NodeTreeNode::Leaf {
                node_id,
                leaf,
                stats,
                ..
            } => {
                // Validates the payload kind and the index/value consistency.
                let class_index = classifier_class_index(&leaf, class_labels)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::Leaf {
                        class_index,
                        sample_count: stats.sample_count,
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
            NodeTreeNode::BinaryBranch {
                node_id,
                split,
                children,
                stats,
                ..
            } => {
                let (feature_index, threshold_bin) = classifier_binary_split(split)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::BinarySplit {
                        feature_index,
                        threshold_bin,
                        left_child: children.left,
                        right_child: children.right,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
            NodeTreeNode::MultiwayBranch {
                node_id,
                split,
                branches,
                unmatched_leaf,
                stats,
                ..
            } => {
                // The unmatched leaf becomes the fallback prediction for bins
                // without an explicit branch.
                let fallback_class_index = classifier_class_index(&unmatched_leaf, class_labels)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::MultiwaySplit {
                        feature_index: split.feature_index,
                        fallback_class_index,
                        branches: branches
                            .into_iter()
                            .map(|branch| (branch.bin, branch.child))
                            .collect(),
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
        }
    }
    collect_nodes(rebuilt)
}
1092
/// Rebuilds the regressor node arena from IR nodes.
///
/// Same slot-then-compact scheme as `rebuild_classifier_nodes`, but regression
/// trees admit only `RegressionValue` leaves and binary branches — class-index
/// leaves and multiway branches are rejected.
///
/// # Errors
/// [`IrError::InvalidLeaf`] for non-regression leaves, [`IrError::InvalidNode`]
/// for multiway branches, plus errors from `assign_node` / `collect_nodes`.
fn rebuild_regressor_nodes(nodes: Vec<NodeTreeNode>) -> Result<Vec<RegressionNode>, IrError> {
    let mut rebuilt = vec![None; nodes.len()];
    for node in nodes {
        match node {
            NodeTreeNode::Leaf {
                node_id,
                leaf: LeafPayload::RegressionValue { value },
                stats,
                ..
            } => {
                assign_node(
                    &mut rebuilt,
                    node_id,
                    RegressionNode::Leaf {
                        value,
                        sample_count: stats.sample_count,
                        variance: stats.variance,
                    },
                )?;
            }
            // Any other payload kind (class_index) is invalid here.
            NodeTreeNode::Leaf { .. } => {
                return Err(IrError::InvalidLeaf(
                    "regression trees require regression_value leaves".to_string(),
                ));
            }
            NodeTreeNode::BinaryBranch {
                node_id,
                split,
                children,
                stats,
                ..
            } => {
                let (feature_index, threshold_bin) = regressor_binary_split(split)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    RegressionNode::BinarySplit {
                        feature_index,
                        threshold_bin,
                        left_child: children.left,
                        right_child: children.right,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        variance: stats.variance,
                    },
                )?;
            }
            NodeTreeNode::MultiwayBranch { .. } => {
                return Err(IrError::InvalidNode(
                    "regression trees do not support multiway branches".to_string(),
                ));
            }
        }
    }
    collect_nodes(rebuilt)
}
1150
1151fn rebuild_classifier_oblivious_splits(
1152 levels: Vec<ObliviousLevel>,
1153) -> Result<Vec<ClassifierObliviousSplit>, IrError> {
1154 let mut rebuilt = Vec::with_capacity(levels.len());
1155 for level in levels {
1156 rebuilt.push(match level.split {
1157 ObliviousSplit::NumericBinThreshold {
1158 feature_index,
1159 threshold_bin,
1160 ..
1161 } => ClassifierObliviousSplit {
1162 feature_index,
1163 threshold_bin,
1164 sample_count: level.stats.sample_count,
1165 impurity: level.stats.impurity.unwrap_or(0.0),
1166 gain: level.stats.gain.unwrap_or(0.0),
1167 },
1168 ObliviousSplit::BooleanTest { feature_index, .. } => ClassifierObliviousSplit {
1169 feature_index,
1170 threshold_bin: 0,
1171 sample_count: level.stats.sample_count,
1172 impurity: level.stats.impurity.unwrap_or(0.0),
1173 gain: level.stats.gain.unwrap_or(0.0),
1174 },
1175 });
1176 }
1177 Ok(rebuilt)
1178}
1179
1180fn rebuild_regressor_oblivious_splits(
1181 levels: Vec<ObliviousLevel>,
1182) -> Result<Vec<RegressorObliviousSplit>, IrError> {
1183 let mut rebuilt = Vec::with_capacity(levels.len());
1184 for level in levels {
1185 rebuilt.push(match level.split {
1186 ObliviousSplit::NumericBinThreshold {
1187 feature_index,
1188 threshold_bin,
1189 ..
1190 } => RegressorObliviousSplit {
1191 feature_index,
1192 threshold_bin,
1193 sample_count: level.stats.sample_count,
1194 impurity: level.stats.impurity.unwrap_or(0.0),
1195 gain: level.stats.gain.unwrap_or(0.0),
1196 },
1197 ObliviousSplit::BooleanTest { feature_index, .. } => RegressorObliviousSplit {
1198 feature_index,
1199 threshold_bin: 0,
1200 sample_count: level.stats.sample_count,
1201 impurity: level.stats.impurity.unwrap_or(0.0),
1202 gain: level.stats.gain.unwrap_or(0.0),
1203 },
1204 });
1205 }
1206 Ok(rebuilt)
1207}
1208
1209fn rebuild_classifier_leaf_indices(
1210 leaves: Vec<IndexedLeaf>,
1211 class_labels: &[f64],
1212) -> Result<Vec<usize>, IrError> {
1213 let mut rebuilt = vec![None; leaves.len()];
1214 for indexed_leaf in leaves {
1215 let class_index = classifier_class_index(&indexed_leaf.leaf, class_labels)?;
1216 assign_node(&mut rebuilt, indexed_leaf.leaf_index, class_index)?;
1217 }
1218 collect_nodes(rebuilt)
1219}
1220
1221fn rebuild_regressor_leaf_values(leaves: Vec<IndexedLeaf>) -> Result<Vec<f64>, IrError> {
1222 let mut rebuilt = vec![None; leaves.len()];
1223 for indexed_leaf in leaves {
1224 let value = match indexed_leaf.leaf {
1225 LeafPayload::RegressionValue { value } => value,
1226 LeafPayload::ClassIndex { .. } => {
1227 return Err(IrError::InvalidLeaf(
1228 "regression oblivious leaves require regression_value".to_string(),
1229 ));
1230 }
1231 };
1232 assign_node(&mut rebuilt, indexed_leaf.leaf_index, value)?;
1233 }
1234 collect_nodes(rebuilt)
1235}
1236
1237fn rebuild_leaf_sample_counts(leaves: &[IndexedLeaf]) -> Result<Vec<usize>, IrError> {
1238 let mut rebuilt = vec![None; leaves.len()];
1239 for indexed_leaf in leaves {
1240 assign_node(
1241 &mut rebuilt,
1242 indexed_leaf.leaf_index,
1243 indexed_leaf.stats.sample_count,
1244 )?;
1245 }
1246 collect_nodes(rebuilt)
1247}
1248
1249fn rebuild_leaf_variances(leaves: &[IndexedLeaf]) -> Result<Vec<Option<f64>>, IrError> {
1250 let mut rebuilt = vec![None; leaves.len()];
1251 for indexed_leaf in leaves {
1252 assign_node(
1253 &mut rebuilt,
1254 indexed_leaf.leaf_index,
1255 indexed_leaf.stats.variance,
1256 )?;
1257 }
1258 collect_nodes(rebuilt)
1259}
1260
1261fn rebuild_classifier_leaf_class_counts(
1262 leaves: &[IndexedLeaf],
1263 num_classes: usize,
1264) -> Result<Vec<Vec<usize>>, IrError> {
1265 let mut rebuilt = vec![None; leaves.len()];
1266 for indexed_leaf in leaves {
1267 assign_node(
1268 &mut rebuilt,
1269 indexed_leaf.leaf_index,
1270 indexed_leaf
1271 .stats
1272 .class_counts
1273 .clone()
1274 .unwrap_or_else(|| vec![0; num_classes]),
1275 )?;
1276 }
1277 collect_nodes(rebuilt)
1278}
1279
1280fn classifier_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
1281 match split {
1282 BinarySplit::NumericBinThreshold {
1283 feature_index,
1284 threshold_bin,
1285 ..
1286 } => Ok((feature_index, threshold_bin)),
1287 BinarySplit::BooleanTest { feature_index, .. } => Ok((feature_index, 0)),
1288 }
1289}
1290
/// Extracts `(feature_index, threshold_bin)` from an IR binary split for a
/// regression tree. The encoding is identical to the classifier case, so this
/// delegates to [`classifier_binary_split`].
fn regressor_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
    classifier_binary_split(split)
}
1294
1295fn assign_node<T>(slots: &mut [Option<T>], index: usize, value: T) -> Result<(), IrError> {
1296 let Some(slot) = slots.get_mut(index) else {
1297 return Err(IrError::InvalidNode(format!(
1298 "node index {} is out of bounds",
1299 index
1300 )));
1301 };
1302 if slot.is_some() {
1303 return Err(IrError::InvalidNode(format!(
1304 "duplicate node index {}",
1305 index
1306 )));
1307 }
1308 *slot = Some(value);
1309 Ok(())
1310}
1311
1312fn collect_nodes<T>(slots: Vec<Option<T>>) -> Result<Vec<T>, IrError> {
1313 slots
1314 .into_iter()
1315 .enumerate()
1316 .map(|(index, slot)| {
1317 slot.ok_or_else(|| IrError::InvalidNode(format!("missing node index {}", index)))
1318 })
1319 .collect()
1320}
1321
1322fn input_schema(model: &Model) -> InputSchema {
1323 let features = model
1324 .feature_preprocessing()
1325 .iter()
1326 .enumerate()
1327 .map(|(feature_index, preprocessing)| {
1328 let kind = match preprocessing {
1329 FeaturePreprocessing::Numeric { .. } => InputFeatureKind::Numeric,
1330 FeaturePreprocessing::Binary => InputFeatureKind::Binary,
1331 };
1332
1333 InputFeature {
1334 index: feature_index,
1335 name: feature_name(feature_index),
1336 dtype: match kind {
1337 InputFeatureKind::Numeric => "float64".to_string(),
1338 InputFeatureKind::Binary => "bool".to_string(),
1339 },
1340 logical_type: match kind {
1341 InputFeatureKind::Numeric => "numeric".to_string(),
1342 InputFeatureKind::Binary => "boolean".to_string(),
1343 },
1344 nullable: false,
1345 }
1346 })
1347 .collect();
1348
1349 InputSchema {
1350 feature_count: model.num_features(),
1351 features,
1352 ordering: "strict_index_order".to_string(),
1353 input_tensor_layout: "row_major".to_string(),
1354 accepts_feature_names: false,
1355 }
1356}
1357
1358fn output_schema(model: &Model, class_labels: Option<Vec<f64>>) -> OutputSchema {
1359 match model.task() {
1360 Task::Regression => OutputSchema {
1361 raw_outputs: vec![OutputField {
1362 name: "value".to_string(),
1363 kind: "regression_value".to_string(),
1364 shape: Vec::new(),
1365 dtype: "float64".to_string(),
1366 }],
1367 final_outputs: vec![OutputField {
1368 name: "prediction".to_string(),
1369 kind: "value".to_string(),
1370 shape: Vec::new(),
1371 dtype: "float64".to_string(),
1372 }],
1373 class_order: None,
1374 },
1375 Task::Classification => OutputSchema {
1376 raw_outputs: vec![OutputField {
1377 name: "class_index".to_string(),
1378 kind: "class_index".to_string(),
1379 shape: Vec::new(),
1380 dtype: "uint64".to_string(),
1381 }],
1382 final_outputs: vec![OutputField {
1383 name: "predicted_class".to_string(),
1384 kind: "class_label".to_string(),
1385 shape: Vec::new(),
1386 dtype: "float64".to_string(),
1387 }],
1388 class_order: class_labels,
1389 },
1390 }
1391}
1392
1393fn preprocessing(model: &Model) -> PreprocessingSection {
1394 let features = model
1395 .feature_preprocessing()
1396 .iter()
1397 .enumerate()
1398 .map(|(feature_index, preprocessing)| match preprocessing {
1399 FeaturePreprocessing::Numeric { bin_boundaries } => FeatureBinning::Numeric {
1400 feature_index,
1401 boundaries: bin_boundaries.clone(),
1402 },
1403 FeaturePreprocessing::Binary => FeatureBinning::Binary { feature_index },
1404 })
1405 .collect();
1406
1407 PreprocessingSection {
1408 included_in_model: true,
1409 numeric_binning: NumericBinning {
1410 kind: "rank_bin_128".to_string(),
1411 features,
1412 },
1413 notes: "Numeric features use serialized training-time rank bins. Binary features are serialized as booleans. Missing values and categorical encodings are not implemented in IR v1."
1414 .to_string(),
1415 }
1416}
1417
1418fn postprocessing(model: &Model, class_labels: Option<Vec<f64>>) -> PostprocessingSection {
1419 match model.task() {
1420 Task::Regression => PostprocessingSection {
1421 raw_output_kind: "regression_value".to_string(),
1422 steps: vec![PostprocessingStep::Identity],
1423 },
1424 Task::Classification => PostprocessingSection {
1425 raw_output_kind: "class_index".to_string(),
1426 steps: vec![PostprocessingStep::MapClassIndexToLabel {
1427 labels: class_labels.expect("classification IR requires class labels"),
1428 }],
1429 },
1430 }
1431}
1432
1433fn required_capabilities(model: &Model, representation: &str) -> Vec<String> {
1434 let mut capabilities = vec![
1435 representation.to_string(),
1436 "training_rank_bin_128".to_string(),
1437 ];
1438 match model.tree_type() {
1439 TreeType::Id3 | TreeType::C45 => {
1440 capabilities.push("binned_multiway_splits".to_string());
1441 }
1442 TreeType::Cart | TreeType::Randomized | TreeType::Oblivious => {
1443 capabilities.push("numeric_bin_threshold_splits".to_string());
1444 }
1445 }
1446 if model
1447 .feature_preprocessing()
1448 .iter()
1449 .any(|feature| matches!(feature, FeaturePreprocessing::Binary))
1450 {
1451 capabilities.push("boolean_features".to_string());
1452 }
1453 match model.task() {
1454 Task::Regression => capabilities.push("regression_value_leaves".to_string()),
1455 Task::Classification => capabilities.push("class_index_leaves".to_string()),
1456 }
1457 capabilities
1458}
1459
/// Maps a [`TrainAlgorithm`] to its canonical IR string name.
pub(crate) fn algorithm_name(algorithm: TrainAlgorithm) -> &'static str {
    match algorithm {
        TrainAlgorithm::Dt => "dt",
        TrainAlgorithm::Rf => "rf",
        TrainAlgorithm::Gbm => "gbm",
    }
}
1467
/// Serializes a single-tree model into its IR tree definition.
///
/// Ensemble models must never reach this function — callers are expected to
/// expand ensembles into their member trees first, so hitting an ensemble
/// variant here is an internal bug.
fn model_tree_definition(model: &Model) -> TreeDefinition {
    match model {
        Model::DecisionTreeClassifier(classifier) => classifier.to_ir_tree(),
        Model::DecisionTreeRegressor(regressor) => regressor.to_ir_tree(),
        Model::RandomForest(_) | Model::GradientBoostedTrees(_) => {
            unreachable!("ensemble IR expands into member trees")
        }
    }
}
1477
1478pub(crate) fn criterion_name(criterion: crate::Criterion) -> &'static str {
1479 match criterion {
1480 crate::Criterion::Auto => "auto",
1481 crate::Criterion::Gini => "gini",
1482 crate::Criterion::Entropy => "entropy",
1483 crate::Criterion::Mean => "mean",
1484 crate::Criterion::Median => "median",
1485 crate::Criterion::SecondOrder => "second_order",
1486 }
1487}
1488
/// Maps a [`Task`] to its canonical IR string name.
pub(crate) fn task_name(task: Task) -> &'static str {
    match task {
        Task::Regression => "regression",
        Task::Classification => "classification",
    }
}
1495
/// Maps a [`TreeType`] to its canonical IR string name.
pub(crate) fn tree_type_name(tree_type: TreeType) -> &'static str {
    match tree_type {
        TreeType::Id3 => "id3",
        TreeType::C45 => "c45",
        TreeType::Cart => "cart",
        TreeType::Randomized => "randomized",
        TreeType::Oblivious => "oblivious",
    }
}
1505
/// Produces the synthetic IR column name (`f0`, `f1`, …) for a feature index.
pub(crate) fn feature_name(feature_index: usize) -> String {
    format!("f{feature_index}")
}
1509
1510pub(crate) fn threshold_upper_bound(
1511 preprocessing: &[FeaturePreprocessing],
1512 feature_index: usize,
1513 threshold_bin: u16,
1514) -> Option<f64> {
1515 match preprocessing.get(feature_index)? {
1516 FeaturePreprocessing::Numeric { bin_boundaries } => bin_boundaries
1517 .iter()
1518 .find(|boundary| boundary.bin == threshold_bin)
1519 .map(|boundary| boundary.upper_bound),
1520 FeaturePreprocessing::Binary => None,
1521 }
1522}