1use crate::tree::classifier::{
13 DecisionTreeAlgorithm, DecisionTreeClassifier, DecisionTreeOptions,
14 ObliviousSplit as ClassifierObliviousSplit, TreeNode as ClassifierTreeNode,
15 TreeStructure as ClassifierTreeStructure,
16};
17use crate::tree::regressor::{
18 DecisionTreeRegressor, ObliviousSplit as RegressorObliviousSplit, RegressionNode,
19 RegressionTreeAlgorithm, RegressionTreeOptions, RegressionTreeStructure,
20};
21use crate::{
22 Criterion, FeaturePreprocessing, GradientBoostedTrees, InputFeatureKind, Model,
23 NumericBinBoundary, RandomForest, Task, TrainAlgorithm, TreeType,
24};
25use schemars::schema::RootSchema;
26use schemars::{JsonSchema, schema_for};
27use serde::{Deserialize, Serialize};
28use std::fmt::{Display, Formatter};
29
// Version of the IR document format this build emits and accepts; loading
// rejects any other value (see `validate_ir_header`).
const IR_VERSION: &str = "1.0.0";
// Magic format identifier checked on load to reject unrelated JSON documents.
const FORMAT_NAME: &str = "forestfire-ir";
32
// Top-level IR document: a self-describing, JSON-serializable model package
// bundling the model structure, I/O schemas, inference conventions,
// pre/post-processing and training provenance.
// NOTE: plain `//` comments (not `///`) are used on JsonSchema-deriving types
// so the generated schema output is left unchanged.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ModelPackageIr {
    pub ir_version: String,  // must equal IR_VERSION on load
    pub format_name: String, // must equal FORMAT_NAME on load
    pub producer: ProducerMetadata,
    pub model: ModelSection,
    pub input_schema: InputSchema,
    pub output_schema: OutputSchema,
    pub inference_options: InferenceOptions,
    pub preprocessing: PreprocessingSection,
    pub postprocessing: PostprocessingSection,
    pub training_metadata: TrainingMetadata,
    pub integrity: IntegritySection,
}
48
// Provenance of the producing library (name, version, language, CPU arch).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProducerMetadata {
    pub library: String,
    pub library_version: String,
    pub language: String,
    pub platform: String, // CPU architecture string (std::env::consts::ARCH)
}
56
// The model itself: algorithm/task/tree-type tags, capability flags, the
// serialized trees and how their outputs are aggregated at inference time.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ModelSection {
    pub algorithm: String, // "dt" | "rf" | "gbm" (see parse_algorithm)
    pub task: String,      // "regression" | "classification"
    pub tree_type: String, // "id3" | "c45" | "cart" | "randomized" | "oblivious"
    pub representation: String, // "node_tree" | "oblivious_levels"
    pub num_features: usize,
    pub num_outputs: usize,
    pub supports_missing: bool,     // must be false in IR v1
    pub supports_categorical: bool, // must be false in IR v1
    pub is_ensemble: bool,
    pub trees: Vec<TreeDefinition>,
    pub aggregation: Aggregation,
}
72
// One serialized tree, tagged by representation: an explicit node graph, or
// an oblivious tree expressed as one shared split per depth level plus an
// indexed leaf table.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "representation", rename_all = "snake_case")]
pub enum TreeDefinition {
    NodeTree {
        tree_id: usize,
        weight: f64,
        root_node_id: usize,
        nodes: Vec<NodeTreeNode>,
    },
    ObliviousLevels {
        tree_id: usize,
        weight: f64,
        depth: usize,
        levels: Vec<ObliviousLevel>,
        leaf_indexing: LeafIndexing,
        leaves: Vec<IndexedLeaf>,
    },
}
92
// A node in a `TreeDefinition::NodeTree`: a terminal leaf, a two-way split,
// or a multi-way split keyed on bin values with a fallback leaf for bins
// without a matching branch.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum NodeTreeNode {
    Leaf {
        node_id: usize,
        depth: usize,
        leaf: LeafPayload,
        stats: NodeStats,
    },
    BinaryBranch {
        node_id: usize,
        depth: usize,
        split: BinarySplit,
        children: BinaryChildren,
        stats: NodeStats,
    },
    MultiwayBranch {
        node_id: usize,
        depth: usize,
        split: MultiwaySplit,
        branches: Vec<MultiwayBranch>,
        // Prediction used when the observed bin matches no branch.
        unmatched_leaf: LeafPayload,
        stats: NodeStats,
    },
}
118
// Child node ids of a binary branch (indices into the tree's `nodes` vector).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BinaryChildren {
    pub left: usize,
    pub right: usize,
}
124
// One outgoing edge of a multiway split: bin value -> child node id.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MultiwayBranch {
    pub bin: u16,
    pub child: usize,
}
130
// Split condition of a binary branch: either a threshold on a binned numeric
// feature or a plain boolean test.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "split_type", rename_all = "snake_case")]
pub enum BinarySplit {
    NumericBinThreshold {
        feature_index: usize,
        feature_name: String,
        operator: String,
        threshold_bin: u16,
        // Raw-value bound corresponding to `threshold_bin`, when known.
        threshold_upper_bound: Option<f64>,
        comparison_dtype: String,
    },
    BooleanTest {
        feature_index: usize,
        feature_name: String,
        false_child_semantics: String,
        true_child_semantics: String,
    },
}
149
// Split descriptor of a multiway branch; the per-bin routing itself lives in
// the sibling `branches` list on the node.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MultiwaySplit {
    pub split_type: String,
    pub feature_index: usize,
    pub feature_name: String,
    pub comparison_dtype: String,
}
157
// One depth level of an oblivious tree: a single split shared by every node
// at that depth, plus aggregate statistics for the level.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ObliviousLevel {
    pub level: usize,
    pub split: ObliviousSplit,
    pub stats: NodeStats,
}
164
// Split of one oblivious level; the `bit_when_*` fields give the bit value
// contributed to the leaf index when the test is true/false.
// NOTE(review): field order differs between the two variants
// (`bit_when_true` first vs. `bit_when_false` first); harmless for named
// JSON keys, but inconsistent — confirm before normalizing, since reordering
// changes serialized key order.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "split_type", rename_all = "snake_case")]
pub enum ObliviousSplit {
    NumericBinThreshold {
        feature_index: usize,
        feature_name: String,
        operator: String,
        threshold_bin: u16,
        threshold_upper_bound: Option<f64>,
        comparison_dtype: String,
        bit_when_true: u8,
        bit_when_false: u8,
    },
    BooleanTest {
        feature_index: usize,
        feature_name: String,
        bit_when_false: u8,
        bit_when_true: u8,
    },
}
185
// Describes how per-level split bits compose into a leaf-table index.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct LeafIndexing {
    pub bit_order: String,
    pub index_formula: String,
}
191
// Leaf entry of an oblivious tree, addressed by its computed leaf index.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct IndexedLeaf {
    pub leaf_index: usize,
    pub leaf: LeafPayload,
    pub stats: NodeStats,
}
198
// Prediction stored at a leaf: a raw regression value, or a class index
// paired with the redundant class label value (cross-checked on load by
// `classifier_class_index`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "prediction_kind", rename_all = "snake_case")]
pub enum LeafPayload {
    RegressionValue {
        value: f64,
    },
    ClassIndex {
        class_index: usize,
        class_value: f64,
    },
}
210
// How per-tree outputs combine into the final raw output.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Aggregation {
    pub kind: String, // e.g. "average", "sum_tree_outputs", "identity_single_tree"
    pub tree_weights: Vec<f64>,
    pub normalize_by_weight_sum: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_score: Option<f64>, // additive offset, used by boosted models
}
219
// Declared shape and ordering of the feature inputs.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InputSchema {
    pub feature_count: usize,
    pub features: Vec<InputFeature>,
    pub ordering: String,
    pub input_tensor_layout: String,
    pub accepts_feature_names: bool,
}
228
// One declared input feature.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InputFeature {
    pub index: usize,
    pub name: String,
    pub dtype: String,
    pub logical_type: String,
    pub nullable: bool,
}
237
// Declared outputs before and after postprocessing. `class_order` is the
// primary source of class labels on load (see `classification_labels`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OutputSchema {
    pub raw_outputs: Vec<OutputField>,
    pub final_outputs: Vec<OutputField>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_order: Option<Vec<f64>>,
}
245
// One named output tensor/value of the model.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OutputField {
    pub name: String,
    pub kind: String,
    pub shape: Vec<usize>,
    pub dtype: String,
}
253
// Runtime conventions an implementation must honor; `validate_inference_options`
// rejects unsupported comparison and NaN policies on load.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InferenceOptions {
    pub numeric_precision: String,
    pub threshold_comparison: String, // only "leq_left_gt_right" is accepted
    pub nan_policy: String,           // only "not_supported" is accepted
    pub bool_encoding: BoolEncoding,
    pub tie_breaking: TieBreaking,
    pub determinism: Determinism,
}
263
// Accepted string spellings for boolean feature values.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BoolEncoding {
    pub false_values: Vec<String>,
    pub true_values: Vec<String>,
}
269
// Tie-breaking conventions for classification votes and argmax.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TieBreaking {
    pub classification: String,
    pub argmax: String,
}
275
// Whether inference is guaranteed deterministic, with explanatory notes.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Determinism {
    pub guaranteed: bool,
    pub notes: String,
}
281
// Feature preprocessing that is baked into the model package (numeric binning).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PreprocessingSection {
    pub included_in_model: bool,
    pub numeric_binning: NumericBinning,
    pub notes: String,
}
288
// Per-feature binning artifacts; every feature must have exactly one entry
// (enforced by `feature_preprocessing_from_ir`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinning {
    pub kind: String,
    pub features: Vec<FeatureBinning>,
}
294
// Binning for one feature: explicit numeric bin boundaries, or a plain
// binary (0/1) feature that needs no boundaries.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeatureBinning {
    Numeric {
        feature_index: usize,
        boundaries: Vec<NumericBinBoundary>,
    },
    Binary {
        feature_index: usize,
    },
}
306
// Steps applied to the raw model output to produce the final prediction.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PostprocessingSection {
    pub raw_output_kind: String,
    pub steps: Vec<PostprocessingStep>,
}
312
// A single postprocessing operation: pass-through, or mapping a predicted
// class index to its label value.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum PostprocessingStep {
    Identity,
    MapClassIndexToLabel { labels: Vec<f64> },
}
319
// Training provenance and hyperparameters. Optional fields are algorithm-
// specific (e.g. `learning_rate` for boosting, `oob_score` for forests) and
// are substituted with library defaults on load when absent (see
// `tree_options` and the `model_from_ir` `unwrap_or` fallbacks).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TrainingMetadata {
    pub algorithm: String,
    pub task: String,
    pub tree_type: String,
    pub criterion: String,
    pub canaries: usize,
    pub compute_oob: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_depth: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_samples_split: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_samples_leaf: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n_trees: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_features: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oob_score: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_labels: Option<Vec<f64>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub learning_rate: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_gradient_fraction: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub other_gradient_fraction: Option<f64>,
}
354
// Serialization format and runtime compatibility requirements.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct IntegritySection {
    pub serialization: String,
    pub canonical_json: bool,
    pub compatibility: Compatibility,
}
361
// Minimum runtime version and the capability tags a consumer must implement.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Compatibility {
    pub minimum_runtime_version: String,
    pub required_capabilities: Vec<String>,
}
367
// Optional per-node training statistics; task-specific fields
// (`class_counts` for classification, `variance` for regression) are omitted
// when not applicable.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct NodeStats {
    pub sample_count: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub impurity: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gain: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_counts: Option<Vec<usize>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub variance: Option<f64>,
}
380
/// Errors raised while validating an IR document or reconstructing a model
/// from it. Variants wrap the offending value or a human-readable detail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IrError {
    UnsupportedIrVersion(String),
    UnsupportedFormatName(String),
    UnsupportedAlgorithm(String),
    UnsupportedTask(String),
    UnsupportedTreeType(String),
    /// A single-tree model declared a tree count other than one.
    InvalidTreeCount(usize),
    UnsupportedRepresentation(String),
    /// Input schema and preprocessing section disagree on the feature count.
    InvalidFeatureCount { schema: usize, preprocessing: usize },
    MissingClassLabels,
    InvalidLeaf(String),
    InvalidNode(String),
    InvalidPreprocessing(String),
    InvalidInferenceOption(String),
    /// JSON (de)serialization failure, carrying serde_json's message.
    Json(String),
}
398
399impl Display for IrError {
400 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
401 match self {
402 IrError::UnsupportedIrVersion(version) => {
403 write!(f, "Unsupported IR version: {}.", version)
404 }
405 IrError::UnsupportedFormatName(name) => {
406 write!(f, "Unsupported IR format: {}.", name)
407 }
408 IrError::UnsupportedAlgorithm(algorithm) => {
409 write!(f, "Unsupported algorithm: {}.", algorithm)
410 }
411 IrError::UnsupportedTask(task) => write!(f, "Unsupported task: {}.", task),
412 IrError::UnsupportedTreeType(tree_type) => {
413 write!(f, "Unsupported tree type: {}.", tree_type)
414 }
415 IrError::InvalidTreeCount(count) => {
416 write!(f, "Expected exactly one tree in the IR, found {}.", count)
417 }
418 IrError::UnsupportedRepresentation(representation) => {
419 write!(f, "Unsupported tree representation: {}.", representation)
420 }
421 IrError::InvalidFeatureCount {
422 schema,
423 preprocessing,
424 } => write!(
425 f,
426 "Input schema declares {} features, but preprocessing declares {}.",
427 schema, preprocessing
428 ),
429 IrError::MissingClassLabels => {
430 write!(f, "Classification IR requires explicit class labels.")
431 }
432 IrError::InvalidLeaf(message) => write!(f, "Invalid leaf payload: {}.", message),
433 IrError::InvalidNode(message) => write!(f, "Invalid tree node: {}.", message),
434 IrError::InvalidPreprocessing(message) => {
435 write!(f, "Invalid preprocessing section: {}.", message)
436 }
437 IrError::InvalidInferenceOption(message) => {
438 write!(f, "Invalid inference options: {}.", message)
439 }
440 IrError::Json(message) => write!(f, "Invalid JSON: {}.", message),
441 }
442 }
443}
444
445impl std::error::Error for IrError {}
446
447impl ModelPackageIr {
448 pub fn json_schema() -> RootSchema {
449 schema_for!(ModelPackageIr)
450 }
451
452 pub fn json_schema_json() -> Result<String, IrError> {
453 serde_json::to_string(&Self::json_schema()).map_err(|err| IrError::Json(err.to_string()))
454 }
455
456 pub fn json_schema_json_pretty() -> Result<String, IrError> {
457 serde_json::to_string_pretty(&Self::json_schema())
458 .map_err(|err| IrError::Json(err.to_string()))
459 }
460}
461
/// Converts an in-memory `Model` into the versioned, self-describing IR
/// package. Ensembles contribute one `TreeDefinition` per member tree;
/// single-tree models contribute exactly one.
pub(crate) fn model_to_ir(model: &Model) -> ModelPackageIr {
    let trees = match model {
        Model::RandomForest(forest) => forest
            .trees()
            .iter()
            .map(model_tree_definition)
            .collect::<Vec<_>>(),
        Model::GradientBoostedTrees(boosted) => boosted
            .trees()
            .iter()
            .map(model_tree_definition)
            .collect::<Vec<_>>(),
        _ => vec![model_tree_definition(model)],
    };
    // Representation tag is taken from the first serialized tree; with no
    // trees, fall back to the representation implied by the tree type.
    let representation = if let Some(first_tree) = trees.first() {
        match first_tree {
            TreeDefinition::NodeTree { .. } => "node_tree",
            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels",
        }
    } else {
        match model.tree_type() {
            TreeType::Oblivious => "oblivious_levels",
            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized => "node_tree",
        }
    };
    let class_labels = model.class_labels();
    let is_ensemble = matches!(
        model,
        Model::RandomForest(_) | Model::GradientBoostedTrees(_)
    );
    let tree_count = trees.len();
    // Aggregation semantics: forests average with uniform weights; boosted
    // models sum weighted tree outputs on top of a base score; a lone tree
    // passes its output through unchanged.
    let (aggregation_kind, tree_weights, normalize_by_weight_sum, base_score) = match model {
        Model::RandomForest(_) => (
            match model.task() {
                Task::Regression => "average",
                Task::Classification => "average_class_probabilities",
            },
            vec![1.0; tree_count],
            true,
            None,
        ),
        Model::GradientBoostedTrees(boosted) => (
            match boosted.task() {
                Task::Regression => "sum_tree_outputs",
                Task::Classification => "sum_tree_outputs_then_sigmoid",
            },
            boosted.tree_weights().to_vec(),
            false,
            Some(boosted.base_score()),
        ),
        _ => ("identity_single_tree", vec![1.0; tree_count], true, None),
    };

    ModelPackageIr {
        ir_version: IR_VERSION.to_string(),
        format_name: FORMAT_NAME.to_string(),
        producer: ProducerMetadata {
            library: "forestfire-core".to_string(),
            // Version stamped at compile time from Cargo.toml.
            library_version: env!("CARGO_PKG_VERSION").to_string(),
            language: "rust".to_string(),
            platform: std::env::consts::ARCH.to_string(),
        },
        model: ModelSection {
            algorithm: algorithm_name(model.algorithm()).to_string(),
            task: task_name(model.task()).to_string(),
            tree_type: tree_type_name(model.tree_type()).to_string(),
            representation: representation.to_string(),
            num_features: model.num_features(),
            num_outputs: 1,
            // IR v1 capability flags; `validate_ir_header` rejects documents
            // that set either of these to true.
            supports_missing: false,
            supports_categorical: false,
            is_ensemble,
            trees,
            aggregation: Aggregation {
                kind: aggregation_kind.to_string(),
                tree_weights,
                normalize_by_weight_sum,
                base_score,
            },
        },
        input_schema: input_schema(model),
        output_schema: output_schema(model, class_labels.clone()),
        // Fixed inference conventions this library implements; the loader
        // validates the comparison/NaN fields against these exact values.
        inference_options: InferenceOptions {
            numeric_precision: "float64".to_string(),
            threshold_comparison: "leq_left_gt_right".to_string(),
            nan_policy: "not_supported".to_string(),
            bool_encoding: BoolEncoding {
                false_values: vec!["0".to_string(), "false".to_string()],
                true_values: vec!["1".to_string(), "true".to_string()],
            },
            tie_breaking: TieBreaking {
                classification: "lowest_class_index".to_string(),
                argmax: "first_max_index".to_string(),
            },
            determinism: Determinism {
                guaranteed: true,
                notes: "Inference is deterministic when the serialized preprocessing artifacts are applied before split evaluation."
                    .to_string(),
            },
        },
        preprocessing: preprocessing(model),
        postprocessing: postprocessing(model, class_labels),
        training_metadata: model.training_metadata(),
        integrity: IntegritySection {
            serialization: "json".to_string(),
            canonical_json: true,
            compatibility: Compatibility {
                minimum_runtime_version: IR_VERSION.to_string(),
                required_capabilities: required_capabilities(model, representation),
            },
        },
    }
}
575
/// Reconstructs a runnable `Model` from its IR package, validating the
/// header and inference options first, then dispatching on the algorithm
/// tag: single decision tree, random forest, or gradient-boosted ensemble.
pub(crate) fn model_from_ir(ir: ModelPackageIr) -> Result<Model, IrError> {
    validate_ir_header(&ir)?;
    validate_inference_options(&ir.inference_options)?;

    let algorithm = parse_algorithm(&ir.model.algorithm)?;
    let task = parse_task(&ir.model.task)?;
    let tree_type = parse_tree_type(&ir.model.tree_type)?;
    let criterion = parse_criterion(&ir.training_metadata.criterion)?;
    let feature_preprocessing = feature_preprocessing_from_ir(&ir)?;
    let num_features = ir.input_schema.feature_count;
    let options = tree_options(&ir.training_metadata);
    let training_canaries = ir.training_metadata.canaries;
    // Labels are optional here: regression models legitimately have none,
    // so the MissingClassLabels error is deferred to the classifier path.
    let deserialized_class_labels = classification_labels(&ir).ok();

    // A plain decision tree must be serialized as exactly one tree.
    if algorithm == TrainAlgorithm::Dt && ir.model.trees.len() != 1 {
        return Err(IrError::InvalidTreeCount(ir.model.trees.len()));
    }

    if algorithm == TrainAlgorithm::Rf {
        // Rebuild each member as a standalone single-tree model, then wrap
        // them in a forest with the recorded OOB / sampling metadata.
        let trees = ir
            .model
            .trees
            .into_iter()
            .map(|tree| {
                single_model_from_ir_parts(
                    task,
                    tree_type,
                    criterion,
                    feature_preprocessing.clone(),
                    num_features,
                    options.clone(),
                    training_canaries,
                    deserialized_class_labels.clone(),
                    tree,
                )
            })
            .collect::<Result<Vec<_>, _>>()?;
        return Ok(Model::RandomForest(RandomForest::new(
            task,
            criterion,
            tree_type,
            trees,
            ir.training_metadata.compute_oob,
            ir.training_metadata.oob_score,
            ir.training_metadata
                .max_features
                .unwrap_or(num_features.max(1)),
            ir.training_metadata.seed,
            num_features,
            feature_preprocessing,
        )));
    }

    if algorithm == TrainAlgorithm::Gbm {
        let tree_weights = ir.model.aggregation.tree_weights.clone();
        let base_score = ir.model.aggregation.base_score.unwrap_or(0.0);
        // Boosted members are always regression trees, regardless of task.
        let trees = ir
            .model
            .trees
            .into_iter()
            .map(|tree| {
                boosted_tree_model_from_ir_parts(
                    tree_type,
                    criterion,
                    feature_preprocessing.clone(),
                    num_features,
                    options.clone(),
                    training_canaries,
                    tree,
                )
            })
            .collect::<Result<Vec<_>, _>>()?;
        return Ok(Model::GradientBoostedTrees(GradientBoostedTrees::new(
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            // Hyperparameter fallbacks mirror the library defaults used when
            // the optional metadata fields are absent.
            ir.training_metadata.learning_rate.unwrap_or(0.1),
            ir.training_metadata.bootstrap.unwrap_or(false),
            ir.training_metadata.top_gradient_fraction.unwrap_or(0.2),
            ir.training_metadata.other_gradient_fraction.unwrap_or(0.1),
            ir.training_metadata
                .max_features
                .unwrap_or(num_features.max(1)),
            ir.training_metadata.seed,
            num_features,
            feature_preprocessing,
            deserialized_class_labels,
            training_canaries,
        )));
    }

    // Only Dt remains (parse_algorithm admits dt/rf/gbm), and the count was
    // checked above, so taking the first tree cannot fail here.
    let tree = ir
        .model
        .trees
        .into_iter()
        .next()
        .expect("validated single tree");

    single_model_from_ir_parts(
        task,
        tree_type,
        criterion,
        feature_preprocessing,
        num_features,
        options,
        training_canaries,
        deserialized_class_labels,
        tree,
    )
}
688
/// Rebuilds one gradient-boosting member tree from its IR definition.
/// Members are always regressors (the Gbm branch of `model_from_ir` calls
/// this for both tasks); only Cart/Randomized node trees and Oblivious
/// level trees are valid pairings — anything else is rejected.
fn boosted_tree_model_from_ir_parts(
    tree_type: TreeType,
    criterion: Criterion,
    feature_preprocessing: Vec<FeaturePreprocessing>,
    num_features: usize,
    options: DecisionTreeOptions,
    training_canaries: usize,
    tree: TreeDefinition,
) -> Result<Model, IrError> {
    match (tree_type, tree) {
        (
            TreeType::Cart | TreeType::Randomized,
            TreeDefinition::NodeTree {
                nodes,
                root_node_id,
                ..
            },
        ) => Ok(Model::DecisionTreeRegressor(
            DecisionTreeRegressor::from_ir_parts(
                match tree_type {
                    TreeType::Cart => RegressionTreeAlgorithm::Cart,
                    TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
                    // Guarded by the outer match arm's pattern.
                    _ => unreachable!(),
                },
                criterion,
                RegressionTreeStructure::Standard {
                    nodes: rebuild_regressor_nodes(nodes)?,
                    root: root_node_id,
                },
                // Per-tree training options; seed/max_features are not
                // round-tripped for individual ensemble members.
                RegressionTreeOptions {
                    max_depth: options.max_depth,
                    min_samples_split: options.min_samples_split,
                    min_samples_leaf: options.min_samples_leaf,
                    max_features: None,
                    random_seed: 0,
                    missing_value_strategies: Vec::new(),
                },
                num_features,
                feature_preprocessing,
                training_canaries,
            ),
        )),
        (TreeType::Oblivious, TreeDefinition::ObliviousLevels { levels, leaves, .. }) => {
            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
            let leaf_variances = rebuild_leaf_variances(&leaves)?;
            Ok(Model::DecisionTreeRegressor(
                DecisionTreeRegressor::from_ir_parts(
                    RegressionTreeAlgorithm::Oblivious,
                    criterion,
                    RegressionTreeStructure::Oblivious {
                        splits: rebuild_regressor_oblivious_splits(levels)?,
                        leaf_values: rebuild_regressor_leaf_values(leaves)?,
                        leaf_sample_counts,
                        leaf_variances,
                    },
                    RegressionTreeOptions {
                        max_depth: options.max_depth,
                        min_samples_split: options.min_samples_split,
                        min_samples_leaf: options.min_samples_leaf,
                        max_features: None,
                        random_seed: 0,
                        missing_value_strategies: Vec::new(),
                    },
                    num_features,
                    feature_preprocessing,
                    training_canaries,
                ),
            ))
        }
        // Any other (tree_type, representation) combination is inconsistent.
        (_, tree) => Err(IrError::UnsupportedRepresentation(match tree {
            TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
        })),
    }
}
764
/// Rebuilds one standalone single-tree model (classifier or regressor) from
/// its IR definition, dispatching on (task, tree type, representation).
/// Classification requires class labels; regression only accepts
/// Cart/Randomized node trees or Oblivious level trees.
#[allow(clippy::too_many_arguments)]
fn single_model_from_ir_parts(
    task: Task,
    tree_type: TreeType,
    criterion: Criterion,
    feature_preprocessing: Vec<FeaturePreprocessing>,
    num_features: usize,
    options: DecisionTreeOptions,
    training_canaries: usize,
    deserialized_class_labels: Option<Vec<f64>>,
    tree: TreeDefinition,
) -> Result<Model, IrError> {
    match (task, tree_type, tree) {
        (
            Task::Classification,
            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized,
            TreeDefinition::NodeTree {
                nodes,
                root_node_id,
                ..
            },
        ) => {
            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
            let structure = ClassifierTreeStructure::Standard {
                nodes: rebuild_classifier_nodes(nodes, &class_labels)?,
                root: root_node_id,
            };
            Ok(Model::DecisionTreeClassifier(
                DecisionTreeClassifier::from_ir_parts(
                    match tree_type {
                        TreeType::Id3 => DecisionTreeAlgorithm::Id3,
                        TreeType::C45 => DecisionTreeAlgorithm::C45,
                        TreeType::Cart => DecisionTreeAlgorithm::Cart,
                        TreeType::Randomized => DecisionTreeAlgorithm::Randomized,
                        // Excluded by the outer match arm's pattern.
                        TreeType::Oblivious => unreachable!(),
                    },
                    criterion,
                    class_labels,
                    structure,
                    options,
                    num_features,
                    feature_preprocessing,
                    training_canaries,
                ),
            ))
        }
        (
            Task::Classification,
            TreeType::Oblivious,
            TreeDefinition::ObliviousLevels { levels, leaves, .. },
        ) => {
            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
            let leaf_class_counts =
                rebuild_classifier_leaf_class_counts(&leaves, class_labels.len())?;
            let structure = ClassifierTreeStructure::Oblivious {
                splits: rebuild_classifier_oblivious_splits(levels)?,
                leaf_class_indices: rebuild_classifier_leaf_indices(leaves, &class_labels)?,
                leaf_sample_counts,
                leaf_class_counts,
            };
            Ok(Model::DecisionTreeClassifier(
                DecisionTreeClassifier::from_ir_parts(
                    DecisionTreeAlgorithm::Oblivious,
                    criterion,
                    class_labels,
                    structure,
                    options,
                    num_features,
                    feature_preprocessing,
                    training_canaries,
                ),
            ))
        }
        (
            Task::Regression,
            TreeType::Cart | TreeType::Randomized,
            TreeDefinition::NodeTree {
                nodes,
                root_node_id,
                ..
            },
        ) => Ok(Model::DecisionTreeRegressor(
            DecisionTreeRegressor::from_ir_parts(
                match tree_type {
                    TreeType::Cart => RegressionTreeAlgorithm::Cart,
                    TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
                    // Excluded by the outer match arm's pattern.
                    _ => unreachable!(),
                },
                criterion,
                RegressionTreeStructure::Standard {
                    nodes: rebuild_regressor_nodes(nodes)?,
                    root: root_node_id,
                },
                // seed/max_features are not round-tripped per tree.
                RegressionTreeOptions {
                    max_depth: options.max_depth,
                    min_samples_split: options.min_samples_split,
                    min_samples_leaf: options.min_samples_leaf,
                    max_features: None,
                    random_seed: 0,
                    missing_value_strategies: Vec::new(),
                },
                num_features,
                feature_preprocessing,
                training_canaries,
            ),
        )),
        (
            Task::Regression,
            TreeType::Oblivious,
            TreeDefinition::ObliviousLevels { levels, leaves, .. },
        ) => {
            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
            let leaf_variances = rebuild_leaf_variances(&leaves)?;
            Ok(Model::DecisionTreeRegressor(
                DecisionTreeRegressor::from_ir_parts(
                    RegressionTreeAlgorithm::Oblivious,
                    criterion,
                    RegressionTreeStructure::Oblivious {
                        splits: rebuild_regressor_oblivious_splits(levels)?,
                        leaf_values: rebuild_regressor_leaf_values(leaves)?,
                        leaf_sample_counts,
                        leaf_variances,
                    },
                    RegressionTreeOptions {
                        max_depth: options.max_depth,
                        min_samples_split: options.min_samples_split,
                        min_samples_leaf: options.min_samples_leaf,
                        max_features: None,
                        random_seed: 0,
                        missing_value_strategies: Vec::new(),
                    },
                    num_features,
                    feature_preprocessing,
                    training_canaries,
                ),
            ))
        }
        // Any other combination (e.g. Id3 regression, mismatched
        // representation) is not representable by this library.
        (_, _, tree) => Err(IrError::UnsupportedRepresentation(match tree {
            TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
        })),
    }
}
909
910fn validate_ir_header(ir: &ModelPackageIr) -> Result<(), IrError> {
911 if ir.ir_version != IR_VERSION {
912 return Err(IrError::UnsupportedIrVersion(ir.ir_version.clone()));
913 }
914 if ir.format_name != FORMAT_NAME {
915 return Err(IrError::UnsupportedFormatName(ir.format_name.clone()));
916 }
917 if ir.model.supports_missing {
918 return Err(IrError::InvalidInferenceOption(
919 "missing values are not supported in IR v1".to_string(),
920 ));
921 }
922 if ir.model.supports_categorical {
923 return Err(IrError::InvalidInferenceOption(
924 "categorical features are not supported in IR v1".to_string(),
925 ));
926 }
927 Ok(())
928}
929
930fn validate_inference_options(options: &InferenceOptions) -> Result<(), IrError> {
931 if options.threshold_comparison != "leq_left_gt_right" {
932 return Err(IrError::InvalidInferenceOption(format!(
933 "unsupported threshold comparison '{}'",
934 options.threshold_comparison
935 )));
936 }
937 if options.nan_policy != "not_supported" {
938 return Err(IrError::InvalidInferenceOption(format!(
939 "unsupported nan policy '{}'",
940 options.nan_policy
941 )));
942 }
943 Ok(())
944}
945
946fn parse_algorithm(value: &str) -> Result<TrainAlgorithm, IrError> {
947 match value {
948 "dt" => Ok(TrainAlgorithm::Dt),
949 "rf" => Ok(TrainAlgorithm::Rf),
950 "gbm" => Ok(TrainAlgorithm::Gbm),
951 _ => Err(IrError::UnsupportedAlgorithm(value.to_string())),
952 }
953}
954
955fn parse_task(value: &str) -> Result<Task, IrError> {
956 match value {
957 "regression" => Ok(Task::Regression),
958 "classification" => Ok(Task::Classification),
959 _ => Err(IrError::UnsupportedTask(value.to_string())),
960 }
961}
962
963fn parse_tree_type(value: &str) -> Result<TreeType, IrError> {
964 match value {
965 "id3" => Ok(TreeType::Id3),
966 "c45" => Ok(TreeType::C45),
967 "cart" => Ok(TreeType::Cart),
968 "randomized" => Ok(TreeType::Randomized),
969 "oblivious" => Ok(TreeType::Oblivious),
970 _ => Err(IrError::UnsupportedTreeType(value.to_string())),
971 }
972}
973
974fn parse_criterion(value: &str) -> Result<crate::Criterion, IrError> {
975 match value {
976 "gini" => Ok(crate::Criterion::Gini),
977 "entropy" => Ok(crate::Criterion::Entropy),
978 "mean" => Ok(crate::Criterion::Mean),
979 "median" => Ok(crate::Criterion::Median),
980 "second_order" => Ok(crate::Criterion::SecondOrder),
981 "auto" => Ok(crate::Criterion::Auto),
982 _ => Err(IrError::InvalidInferenceOption(format!(
983 "unsupported criterion '{}'",
984 value
985 ))),
986 }
987}
988
989fn tree_options(training: &TrainingMetadata) -> DecisionTreeOptions {
990 DecisionTreeOptions {
991 max_depth: training.max_depth.unwrap_or(8),
992 min_samples_split: training.min_samples_split.unwrap_or(2),
993 min_samples_leaf: training.min_samples_leaf.unwrap_or(1),
994 max_features: None,
995 random_seed: 0,
996 missing_value_strategies: Vec::new(),
997 }
998}
999
1000fn feature_preprocessing_from_ir(
1001 ir: &ModelPackageIr,
1002) -> Result<Vec<FeaturePreprocessing>, IrError> {
1003 let mut features: Vec<Option<FeaturePreprocessing>> = vec![None; ir.input_schema.feature_count];
1004
1005 for feature in &ir.preprocessing.numeric_binning.features {
1006 match feature {
1007 FeatureBinning::Numeric {
1008 feature_index,
1009 boundaries,
1010 } => {
1011 let slot = features.get_mut(*feature_index).ok_or_else(|| {
1012 IrError::InvalidFeatureCount {
1013 schema: ir.input_schema.feature_count,
1014 preprocessing: feature_index + 1,
1015 }
1016 })?;
1017 *slot = Some(FeaturePreprocessing::Numeric {
1018 bin_boundaries: boundaries.clone(),
1019 missing_bin: boundaries
1020 .iter()
1021 .map(|boundary| boundary.bin)
1022 .max()
1023 .map_or(0, |bin| bin.saturating_add(1)),
1024 });
1025 }
1026 FeatureBinning::Binary { feature_index } => {
1027 let slot = features.get_mut(*feature_index).ok_or_else(|| {
1028 IrError::InvalidFeatureCount {
1029 schema: ir.input_schema.feature_count,
1030 preprocessing: feature_index + 1,
1031 }
1032 })?;
1033 *slot = Some(FeaturePreprocessing::Binary);
1034 }
1035 }
1036 }
1037
1038 if features.len() != ir.input_schema.feature_count {
1039 return Err(IrError::InvalidFeatureCount {
1040 schema: ir.input_schema.feature_count,
1041 preprocessing: features.len(),
1042 });
1043 }
1044
1045 features
1046 .into_iter()
1047 .map(|feature| {
1048 feature.ok_or_else(|| {
1049 IrError::InvalidPreprocessing(
1050 "every feature must have a preprocessing entry".to_string(),
1051 )
1052 })
1053 })
1054 .collect()
1055}
1056
1057fn classification_labels(ir: &ModelPackageIr) -> Result<Vec<f64>, IrError> {
1058 ir.output_schema
1059 .class_order
1060 .clone()
1061 .or_else(|| ir.training_metadata.class_labels.clone())
1062 .ok_or(IrError::MissingClassLabels)
1063}
1064
1065fn classifier_class_index(leaf: &LeafPayload, class_labels: &[f64]) -> Result<usize, IrError> {
1066 match leaf {
1067 LeafPayload::ClassIndex {
1068 class_index,
1069 class_value,
1070 } => {
1071 let Some(expected) = class_labels.get(*class_index) else {
1072 return Err(IrError::InvalidLeaf(format!(
1073 "class index {} out of bounds",
1074 class_index
1075 )));
1076 };
1077 if expected.total_cmp(class_value).is_ne() {
1078 return Err(IrError::InvalidLeaf(format!(
1079 "class value {} does not match class order entry {}",
1080 class_value, expected
1081 )));
1082 }
1083 Ok(*class_index)
1084 }
1085 LeafPayload::RegressionValue { .. } => Err(IrError::InvalidLeaf(
1086 "expected class_index leaf".to_string(),
1087 )),
1088 }
1089}
1090
/// Rebuilds a classifier tree's dense node array from IR node records.
///
/// IR nodes carry explicit `node_id`s and may arrive in any order: each node
/// is placed into its slot with `assign_node` (rejecting duplicate or
/// out-of-bounds ids), and `collect_nodes` fails if any slot is left empty.
/// Missing-value routing is not encoded in IR v1 (see the preprocessing
/// notes), so branch nodes are restored with the `Node` default direction
/// and multiway nodes with no dedicated missing child.
fn rebuild_classifier_nodes(
    nodes: Vec<NodeTreeNode>,
    class_labels: &[f64],
) -> Result<Vec<ClassifierTreeNode>, IrError> {
    let mut rebuilt = vec![None; nodes.len()];
    for node in nodes {
        match node {
            NodeTreeNode::Leaf {
                node_id,
                leaf,
                stats,
                ..
            } => {
                // Validate the leaf's (class_index, class_value) pair against
                // the class order before accepting it.
                let class_index = classifier_class_index(&leaf, class_labels)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::Leaf {
                        class_index,
                        sample_count: stats.sample_count,
                        // Absent class counts default to all zeros.
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
            NodeTreeNode::BinaryBranch {
                node_id,
                split,
                children,
                stats,
                ..
            } => {
                let (feature_index, threshold_bin) = classifier_binary_split(split)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::BinarySplit {
                        feature_index,
                        threshold_bin,
                        // IR v1 does not encode missing-value routing; restore the default.
                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
                        left_child: children.left,
                        right_child: children.right,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
            NodeTreeNode::MultiwayBranch {
                node_id,
                split,
                branches,
                unmatched_leaf,
                stats,
                ..
            } => {
                // The unmatched leaf provides the fallback class for bins with
                // no matching branch; it is validated like any other leaf.
                let fallback_class_index = classifier_class_index(&unmatched_leaf, class_labels)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::MultiwaySplit {
                        feature_index: split.feature_index,
                        fallback_class_index,
                        branches: branches
                            .into_iter()
                            .map(|branch| (branch.bin, branch.child))
                            .collect(),
                        // No missing-value child is encoded in IR v1.
                        missing_child: None,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
        }
    }
    collect_nodes(rebuilt)
}
1176
/// Rebuilds a regression tree's dense node array from IR node records.
///
/// Mirrors `rebuild_classifier_nodes`, but only `regression_value` leaves and
/// binary branches are legal: class-index leaves and multiway branches are
/// rejected. Node ids may arrive in any order; duplicates, out-of-bounds ids,
/// and gaps are reported via `assign_node` / `collect_nodes`.
fn rebuild_regressor_nodes(nodes: Vec<NodeTreeNode>) -> Result<Vec<RegressionNode>, IrError> {
    let mut rebuilt = vec![None; nodes.len()];
    for node in nodes {
        match node {
            NodeTreeNode::Leaf {
                node_id,
                leaf: LeafPayload::RegressionValue { value },
                stats,
                ..
            } => {
                assign_node(
                    &mut rebuilt,
                    node_id,
                    RegressionNode::Leaf {
                        value,
                        sample_count: stats.sample_count,
                        variance: stats.variance,
                    },
                )?;
            }
            // Any other leaf payload (i.e. a class-index leaf) is invalid here.
            NodeTreeNode::Leaf { .. } => {
                return Err(IrError::InvalidLeaf(
                    "regression trees require regression_value leaves".to_string(),
                ));
            }
            NodeTreeNode::BinaryBranch {
                node_id,
                split,
                children,
                stats,
                ..
            } => {
                let (feature_index, threshold_bin) = regressor_binary_split(split)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    RegressionNode::BinarySplit {
                        feature_index,
                        threshold_bin,
                        // IR v1 does not encode missing-value handling (see the
                        // preprocessing notes); restore the defaults.
                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
                        missing_value: 0.0,
                        left_child: children.left,
                        right_child: children.right,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        variance: stats.variance,
                    },
                )?;
            }
            NodeTreeNode::MultiwayBranch { .. } => {
                return Err(IrError::InvalidNode(
                    "regression trees do not support multiway branches".to_string(),
                ));
            }
        }
    }
    collect_nodes(rebuilt)
}
1236
1237fn rebuild_classifier_oblivious_splits(
1238 levels: Vec<ObliviousLevel>,
1239) -> Result<Vec<ClassifierObliviousSplit>, IrError> {
1240 let mut rebuilt = Vec::with_capacity(levels.len());
1241 for level in levels {
1242 rebuilt.push(match level.split {
1243 ObliviousSplit::NumericBinThreshold {
1244 feature_index,
1245 threshold_bin,
1246 ..
1247 } => ClassifierObliviousSplit {
1248 feature_index,
1249 threshold_bin,
1250 missing_directions: Vec::new(),
1251 sample_count: level.stats.sample_count,
1252 impurity: level.stats.impurity.unwrap_or(0.0),
1253 gain: level.stats.gain.unwrap_or(0.0),
1254 },
1255 ObliviousSplit::BooleanTest { feature_index, .. } => ClassifierObliviousSplit {
1256 feature_index,
1257 threshold_bin: 0,
1258 missing_directions: Vec::new(),
1259 sample_count: level.stats.sample_count,
1260 impurity: level.stats.impurity.unwrap_or(0.0),
1261 gain: level.stats.gain.unwrap_or(0.0),
1262 },
1263 });
1264 }
1265 Ok(rebuilt)
1266}
1267
1268fn rebuild_regressor_oblivious_splits(
1269 levels: Vec<ObliviousLevel>,
1270) -> Result<Vec<RegressorObliviousSplit>, IrError> {
1271 let mut rebuilt = Vec::with_capacity(levels.len());
1272 for level in levels {
1273 rebuilt.push(match level.split {
1274 ObliviousSplit::NumericBinThreshold {
1275 feature_index,
1276 threshold_bin,
1277 ..
1278 } => RegressorObliviousSplit {
1279 feature_index,
1280 threshold_bin,
1281 sample_count: level.stats.sample_count,
1282 impurity: level.stats.impurity.unwrap_or(0.0),
1283 gain: level.stats.gain.unwrap_or(0.0),
1284 },
1285 ObliviousSplit::BooleanTest { feature_index, .. } => RegressorObliviousSplit {
1286 feature_index,
1287 threshold_bin: 0,
1288 sample_count: level.stats.sample_count,
1289 impurity: level.stats.impurity.unwrap_or(0.0),
1290 gain: level.stats.gain.unwrap_or(0.0),
1291 },
1292 });
1293 }
1294 Ok(rebuilt)
1295}
1296
1297fn rebuild_classifier_leaf_indices(
1298 leaves: Vec<IndexedLeaf>,
1299 class_labels: &[f64],
1300) -> Result<Vec<usize>, IrError> {
1301 let mut rebuilt = vec![None; leaves.len()];
1302 for indexed_leaf in leaves {
1303 let class_index = classifier_class_index(&indexed_leaf.leaf, class_labels)?;
1304 assign_node(&mut rebuilt, indexed_leaf.leaf_index, class_index)?;
1305 }
1306 collect_nodes(rebuilt)
1307}
1308
1309fn rebuild_regressor_leaf_values(leaves: Vec<IndexedLeaf>) -> Result<Vec<f64>, IrError> {
1310 let mut rebuilt = vec![None; leaves.len()];
1311 for indexed_leaf in leaves {
1312 let value = match indexed_leaf.leaf {
1313 LeafPayload::RegressionValue { value } => value,
1314 LeafPayload::ClassIndex { .. } => {
1315 return Err(IrError::InvalidLeaf(
1316 "regression oblivious leaves require regression_value".to_string(),
1317 ));
1318 }
1319 };
1320 assign_node(&mut rebuilt, indexed_leaf.leaf_index, value)?;
1321 }
1322 collect_nodes(rebuilt)
1323}
1324
1325fn rebuild_leaf_sample_counts(leaves: &[IndexedLeaf]) -> Result<Vec<usize>, IrError> {
1326 let mut rebuilt = vec![None; leaves.len()];
1327 for indexed_leaf in leaves {
1328 assign_node(
1329 &mut rebuilt,
1330 indexed_leaf.leaf_index,
1331 indexed_leaf.stats.sample_count,
1332 )?;
1333 }
1334 collect_nodes(rebuilt)
1335}
1336
1337fn rebuild_leaf_variances(leaves: &[IndexedLeaf]) -> Result<Vec<Option<f64>>, IrError> {
1338 let mut rebuilt = vec![None; leaves.len()];
1339 for indexed_leaf in leaves {
1340 assign_node(
1341 &mut rebuilt,
1342 indexed_leaf.leaf_index,
1343 indexed_leaf.stats.variance,
1344 )?;
1345 }
1346 collect_nodes(rebuilt)
1347}
1348
1349fn rebuild_classifier_leaf_class_counts(
1350 leaves: &[IndexedLeaf],
1351 num_classes: usize,
1352) -> Result<Vec<Vec<usize>>, IrError> {
1353 let mut rebuilt = vec![None; leaves.len()];
1354 for indexed_leaf in leaves {
1355 assign_node(
1356 &mut rebuilt,
1357 indexed_leaf.leaf_index,
1358 indexed_leaf
1359 .stats
1360 .class_counts
1361 .clone()
1362 .unwrap_or_else(|| vec![0; num_classes]),
1363 )?;
1364 }
1365 collect_nodes(rebuilt)
1366}
1367
1368fn classifier_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
1369 match split {
1370 BinarySplit::NumericBinThreshold {
1371 feature_index,
1372 threshold_bin,
1373 ..
1374 } => Ok((feature_index, threshold_bin)),
1375 BinarySplit::BooleanTest { feature_index, .. } => Ok((feature_index, 0)),
1376 }
1377}
1378
/// Decodes a binary split for a regression tree; the IR encoding is shared
/// with classifiers, so this simply delegates to `classifier_binary_split`.
fn regressor_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
    classifier_binary_split(split)
}
1382
1383fn assign_node<T>(slots: &mut [Option<T>], index: usize, value: T) -> Result<(), IrError> {
1384 let Some(slot) = slots.get_mut(index) else {
1385 return Err(IrError::InvalidNode(format!(
1386 "node index {} is out of bounds",
1387 index
1388 )));
1389 };
1390 if slot.is_some() {
1391 return Err(IrError::InvalidNode(format!(
1392 "duplicate node index {}",
1393 index
1394 )));
1395 }
1396 *slot = Some(value);
1397 Ok(())
1398}
1399
1400fn collect_nodes<T>(slots: Vec<Option<T>>) -> Result<Vec<T>, IrError> {
1401 slots
1402 .into_iter()
1403 .enumerate()
1404 .map(|(index, slot)| {
1405 slot.ok_or_else(|| IrError::InvalidNode(format!("missing node index {}", index)))
1406 })
1407 .collect()
1408}
1409
1410fn input_schema(model: &Model) -> InputSchema {
1411 let features = model
1412 .feature_preprocessing()
1413 .iter()
1414 .enumerate()
1415 .map(|(feature_index, preprocessing)| {
1416 let kind = match preprocessing {
1417 FeaturePreprocessing::Numeric { .. } => InputFeatureKind::Numeric,
1418 FeaturePreprocessing::Binary => InputFeatureKind::Binary,
1419 };
1420
1421 InputFeature {
1422 index: feature_index,
1423 name: feature_name(feature_index),
1424 dtype: match kind {
1425 InputFeatureKind::Numeric => "float64".to_string(),
1426 InputFeatureKind::Binary => "bool".to_string(),
1427 },
1428 logical_type: match kind {
1429 InputFeatureKind::Numeric => "numeric".to_string(),
1430 InputFeatureKind::Binary => "boolean".to_string(),
1431 },
1432 nullable: false,
1433 }
1434 })
1435 .collect();
1436
1437 InputSchema {
1438 feature_count: model.num_features(),
1439 features,
1440 ordering: "strict_index_order".to_string(),
1441 input_tensor_layout: "row_major".to_string(),
1442 accepts_feature_names: false,
1443 }
1444}
1445
1446fn output_schema(model: &Model, class_labels: Option<Vec<f64>>) -> OutputSchema {
1447 match model.task() {
1448 Task::Regression => OutputSchema {
1449 raw_outputs: vec![OutputField {
1450 name: "value".to_string(),
1451 kind: "regression_value".to_string(),
1452 shape: Vec::new(),
1453 dtype: "float64".to_string(),
1454 }],
1455 final_outputs: vec![OutputField {
1456 name: "prediction".to_string(),
1457 kind: "value".to_string(),
1458 shape: Vec::new(),
1459 dtype: "float64".to_string(),
1460 }],
1461 class_order: None,
1462 },
1463 Task::Classification => OutputSchema {
1464 raw_outputs: vec![OutputField {
1465 name: "class_index".to_string(),
1466 kind: "class_index".to_string(),
1467 shape: Vec::new(),
1468 dtype: "uint64".to_string(),
1469 }],
1470 final_outputs: vec![OutputField {
1471 name: "predicted_class".to_string(),
1472 kind: "class_label".to_string(),
1473 shape: Vec::new(),
1474 dtype: "float64".to_string(),
1475 }],
1476 class_order: class_labels,
1477 },
1478 }
1479}
1480
1481fn preprocessing(model: &Model) -> PreprocessingSection {
1482 let features = model
1483 .feature_preprocessing()
1484 .iter()
1485 .enumerate()
1486 .map(|(feature_index, preprocessing)| match preprocessing {
1487 FeaturePreprocessing::Numeric { bin_boundaries, .. } => FeatureBinning::Numeric {
1488 feature_index,
1489 boundaries: bin_boundaries.clone(),
1490 },
1491 FeaturePreprocessing::Binary => FeatureBinning::Binary { feature_index },
1492 })
1493 .collect();
1494
1495 PreprocessingSection {
1496 included_in_model: true,
1497 numeric_binning: NumericBinning {
1498 kind: "rank_bin_128".to_string(),
1499 features,
1500 },
1501 notes: "Numeric features use serialized training-time rank bins. Binary features are serialized as booleans. Missing values and categorical encodings are not implemented in IR v1."
1502 .to_string(),
1503 }
1504}
1505
1506fn postprocessing(model: &Model, class_labels: Option<Vec<f64>>) -> PostprocessingSection {
1507 match model.task() {
1508 Task::Regression => PostprocessingSection {
1509 raw_output_kind: "regression_value".to_string(),
1510 steps: vec![PostprocessingStep::Identity],
1511 },
1512 Task::Classification => PostprocessingSection {
1513 raw_output_kind: "class_index".to_string(),
1514 steps: vec![PostprocessingStep::MapClassIndexToLabel {
1515 labels: class_labels.expect("classification IR requires class labels"),
1516 }],
1517 },
1518 }
1519}
1520
1521fn required_capabilities(model: &Model, representation: &str) -> Vec<String> {
1522 let mut capabilities = vec![
1523 representation.to_string(),
1524 "training_rank_bin_128".to_string(),
1525 ];
1526 match model.tree_type() {
1527 TreeType::Id3 | TreeType::C45 => {
1528 capabilities.push("binned_multiway_splits".to_string());
1529 }
1530 TreeType::Cart | TreeType::Randomized | TreeType::Oblivious => {
1531 capabilities.push("numeric_bin_threshold_splits".to_string());
1532 }
1533 }
1534 if model
1535 .feature_preprocessing()
1536 .iter()
1537 .any(|feature| matches!(feature, FeaturePreprocessing::Binary))
1538 {
1539 capabilities.push("boolean_features".to_string());
1540 }
1541 match model.task() {
1542 Task::Regression => capabilities.push("regression_value_leaves".to_string()),
1543 Task::Classification => capabilities.push("class_index_leaves".to_string()),
1544 }
1545 capabilities
1546}
1547
/// Returns the canonical IR string identifier for a training algorithm.
pub(crate) fn algorithm_name(algorithm: TrainAlgorithm) -> &'static str {
    match algorithm {
        TrainAlgorithm::Dt => "dt",
        TrainAlgorithm::Rf => "rf",
        TrainAlgorithm::Gbm => "gbm",
    }
}
1555
/// Serializes a single-tree model into its IR tree definition.
///
/// Ensemble variants are unreachable here: callers expand ensembles into
/// their member trees before requesting per-tree definitions.
fn model_tree_definition(model: &Model) -> TreeDefinition {
    match model {
        Model::DecisionTreeClassifier(classifier) => classifier.to_ir_tree(),
        Model::DecisionTreeRegressor(regressor) => regressor.to_ir_tree(),
        Model::RandomForest(_) | Model::GradientBoostedTrees(_) => {
            unreachable!("ensemble IR expands into member trees")
        }
    }
}
1565
1566pub(crate) fn criterion_name(criterion: crate::Criterion) -> &'static str {
1567 match criterion {
1568 crate::Criterion::Auto => "auto",
1569 crate::Criterion::Gini => "gini",
1570 crate::Criterion::Entropy => "entropy",
1571 crate::Criterion::Mean => "mean",
1572 crate::Criterion::Median => "median",
1573 crate::Criterion::SecondOrder => "second_order",
1574 }
1575}
1576
/// Returns the canonical IR string identifier for a model task.
pub(crate) fn task_name(task: Task) -> &'static str {
    match task {
        Task::Regression => "regression",
        Task::Classification => "classification",
    }
}
1583
/// Returns the canonical IR string identifier for a tree type.
pub(crate) fn tree_type_name(tree_type: TreeType) -> &'static str {
    match tree_type {
        TreeType::Id3 => "id3",
        TreeType::C45 => "c45",
        TreeType::Cart => "cart",
        TreeType::Randomized => "randomized",
        TreeType::Oblivious => "oblivious",
    }
}
1593
/// Produces the synthetic feature name (`f0`, `f1`, ...) used throughout the
/// IR, since training does not retain user-supplied feature names.
pub(crate) fn feature_name(feature_index: usize) -> String {
    let mut name = String::from("f");
    name.push_str(&feature_index.to_string());
    name
}
1597
1598pub(crate) fn threshold_upper_bound(
1599 preprocessing: &[FeaturePreprocessing],
1600 feature_index: usize,
1601 threshold_bin: u16,
1602) -> Option<f64> {
1603 match preprocessing.get(feature_index)? {
1604 FeaturePreprocessing::Numeric { bin_boundaries, .. } => bin_boundaries
1605 .iter()
1606 .find(|boundary| boundary.bin == threshold_bin)
1607 .map(|boundary| boundary.upper_bound),
1608 FeaturePreprocessing::Binary => None,
1609 }
1610}