1use crate::tree::classifier::{
13 DecisionTreeAlgorithm, DecisionTreeClassifier, DecisionTreeOptions,
14 ObliviousSplit as ClassifierObliviousSplit, TreeNode as ClassifierTreeNode,
15 TreeStructure as ClassifierTreeStructure,
16};
17use crate::tree::regressor::{
18 DecisionTreeRegressor, ObliviousSplit as RegressorObliviousSplit, RegressionNode,
19 RegressionTreeAlgorithm, RegressionTreeOptions, RegressionTreeStructure,
20};
21use crate::{
22 Criterion, FeaturePreprocessing, GradientBoostedTrees, InputFeatureKind, Model,
23 NumericBinBoundary, RandomForest, Task, TrainAlgorithm, TreeType,
24};
25use schemars::schema::RootSchema;
26use schemars::{JsonSchema, schema_for};
27use serde::{Deserialize, Serialize};
28use std::fmt::{Display, Formatter};
29
30const IR_VERSION: &str = "1.0.0";
31const FORMAT_NAME: &str = "forestfire-ir";
32
33#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
35pub struct ModelPackageIr {
36 pub ir_version: String,
37 pub format_name: String,
38 pub producer: ProducerMetadata,
39 pub model: ModelSection,
40 pub input_schema: InputSchema,
41 pub output_schema: OutputSchema,
42 pub inference_options: InferenceOptions,
43 pub preprocessing: PreprocessingSection,
44 pub postprocessing: PostprocessingSection,
45 pub training_metadata: TrainingMetadata,
46 pub integrity: IntegritySection,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
50pub struct ProducerMetadata {
51 pub library: String,
52 pub library_version: String,
53 pub language: String,
54 pub platform: String,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
59pub struct ModelSection {
60 pub algorithm: String,
61 pub task: String,
62 pub tree_type: String,
63 pub representation: String,
64 pub num_features: usize,
65 pub num_outputs: usize,
66 pub supports_missing: bool,
67 pub supports_categorical: bool,
68 pub is_ensemble: bool,
69 pub trees: Vec<TreeDefinition>,
70 pub aggregation: Aggregation,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
75#[serde(tag = "representation", rename_all = "snake_case")]
76pub enum TreeDefinition {
77 NodeTree {
78 tree_id: usize,
79 weight: f64,
80 root_node_id: usize,
81 nodes: Vec<NodeTreeNode>,
82 },
83 ObliviousLevels {
84 tree_id: usize,
85 weight: f64,
86 depth: usize,
87 levels: Vec<ObliviousLevel>,
88 leaf_indexing: LeafIndexing,
89 leaves: Vec<IndexedLeaf>,
90 },
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
94#[serde(tag = "kind", rename_all = "snake_case")]
95pub enum NodeTreeNode {
96 Leaf {
97 node_id: usize,
98 depth: usize,
99 leaf: LeafPayload,
100 stats: NodeStats,
101 },
102 BinaryBranch {
103 node_id: usize,
104 depth: usize,
105 split: BinarySplit,
106 children: BinaryChildren,
107 stats: NodeStats,
108 },
109 MultiwayBranch {
110 node_id: usize,
111 depth: usize,
112 split: MultiwaySplit,
113 branches: Vec<MultiwayBranch>,
114 unmatched_leaf: LeafPayload,
115 stats: NodeStats,
116 },
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
120pub struct BinaryChildren {
121 pub left: usize,
122 pub right: usize,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
126pub struct MultiwayBranch {
127 pub bin: u16,
128 pub child: usize,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
132#[serde(tag = "split_type", rename_all = "snake_case")]
133pub enum BinarySplit {
134 NumericBinThreshold {
135 feature_index: usize,
136 feature_name: String,
137 operator: String,
138 threshold_bin: u16,
139 threshold_upper_bound: Option<f64>,
140 comparison_dtype: String,
141 },
142 BooleanTest {
143 feature_index: usize,
144 feature_name: String,
145 false_child_semantics: String,
146 true_child_semantics: String,
147 },
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
151pub struct MultiwaySplit {
152 pub split_type: String,
153 pub feature_index: usize,
154 pub feature_name: String,
155 pub comparison_dtype: String,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
159pub struct ObliviousLevel {
160 pub level: usize,
161 pub split: ObliviousSplit,
162 pub stats: NodeStats,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
166#[serde(tag = "split_type", rename_all = "snake_case")]
167pub enum ObliviousSplit {
168 NumericBinThreshold {
169 feature_index: usize,
170 feature_name: String,
171 operator: String,
172 threshold_bin: u16,
173 threshold_upper_bound: Option<f64>,
174 comparison_dtype: String,
175 bit_when_true: u8,
176 bit_when_false: u8,
177 },
178 BooleanTest {
179 feature_index: usize,
180 feature_name: String,
181 bit_when_false: u8,
182 bit_when_true: u8,
183 },
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
187pub struct LeafIndexing {
188 pub bit_order: String,
189 pub index_formula: String,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
193pub struct IndexedLeaf {
194 pub leaf_index: usize,
195 pub leaf: LeafPayload,
196 pub stats: NodeStats,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
200#[serde(tag = "prediction_kind", rename_all = "snake_case")]
201pub enum LeafPayload {
202 RegressionValue {
203 value: f64,
204 },
205 ClassIndex {
206 class_index: usize,
207 class_value: f64,
208 },
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
212pub struct Aggregation {
213 pub kind: String,
214 pub tree_weights: Vec<f64>,
215 pub normalize_by_weight_sum: bool,
216 #[serde(skip_serializing_if = "Option::is_none")]
217 pub base_score: Option<f64>,
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
221pub struct InputSchema {
222 pub feature_count: usize,
223 pub features: Vec<InputFeature>,
224 pub ordering: String,
225 pub input_tensor_layout: String,
226 pub accepts_feature_names: bool,
227}
228
229#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
230pub struct InputFeature {
231 pub index: usize,
232 pub name: String,
233 pub dtype: String,
234 pub logical_type: String,
235 pub nullable: bool,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
239pub struct OutputSchema {
240 pub raw_outputs: Vec<OutputField>,
241 pub final_outputs: Vec<OutputField>,
242 #[serde(skip_serializing_if = "Option::is_none")]
243 pub class_order: Option<Vec<f64>>,
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
247pub struct OutputField {
248 pub name: String,
249 pub kind: String,
250 pub shape: Vec<usize>,
251 pub dtype: String,
252}
253
254#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
255pub struct InferenceOptions {
256 pub numeric_precision: String,
257 pub threshold_comparison: String,
258 pub nan_policy: String,
259 pub bool_encoding: BoolEncoding,
260 pub tie_breaking: TieBreaking,
261 pub determinism: Determinism,
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
265pub struct BoolEncoding {
266 pub false_values: Vec<String>,
267 pub true_values: Vec<String>,
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
271pub struct TieBreaking {
272 pub classification: String,
273 pub argmax: String,
274}
275
276#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
277pub struct Determinism {
278 pub guaranteed: bool,
279 pub notes: String,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
283pub struct PreprocessingSection {
284 pub included_in_model: bool,
285 pub numeric_binning: NumericBinning,
286 pub notes: String,
287}
288
289#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
290pub struct NumericBinning {
291 pub kind: String,
292 pub features: Vec<FeatureBinning>,
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
296#[serde(tag = "kind", rename_all = "snake_case")]
297pub enum FeatureBinning {
298 Numeric {
299 feature_index: usize,
300 boundaries: Vec<NumericBinBoundary>,
301 },
302 Binary {
303 feature_index: usize,
304 },
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
308pub struct PostprocessingSection {
309 pub raw_output_kind: String,
310 pub steps: Vec<PostprocessingStep>,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
314#[serde(tag = "op", rename_all = "snake_case")]
315pub enum PostprocessingStep {
316 Identity,
317 MapClassIndexToLabel { labels: Vec<f64> },
318}
319
320#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
322pub struct TrainingMetadata {
323 pub algorithm: String,
324 pub task: String,
325 pub tree_type: String,
326 pub criterion: String,
327 pub canaries: usize,
328 pub compute_oob: bool,
329 #[serde(skip_serializing_if = "Option::is_none")]
330 pub max_depth: Option<usize>,
331 #[serde(skip_serializing_if = "Option::is_none")]
332 pub min_samples_split: Option<usize>,
333 #[serde(skip_serializing_if = "Option::is_none")]
334 pub min_samples_leaf: Option<usize>,
335 #[serde(skip_serializing_if = "Option::is_none")]
336 pub n_trees: Option<usize>,
337 #[serde(skip_serializing_if = "Option::is_none")]
338 pub max_features: Option<usize>,
339 #[serde(skip_serializing_if = "Option::is_none")]
340 pub seed: Option<u64>,
341 #[serde(skip_serializing_if = "Option::is_none")]
342 pub oob_score: Option<f64>,
343 #[serde(skip_serializing_if = "Option::is_none")]
344 pub class_labels: Option<Vec<f64>>,
345 #[serde(skip_serializing_if = "Option::is_none")]
346 pub learning_rate: Option<f64>,
347 #[serde(skip_serializing_if = "Option::is_none")]
348 pub bootstrap: Option<bool>,
349 #[serde(skip_serializing_if = "Option::is_none")]
350 pub top_gradient_fraction: Option<f64>,
351 #[serde(skip_serializing_if = "Option::is_none")]
352 pub other_gradient_fraction: Option<f64>,
353}
354
355#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
356pub struct IntegritySection {
357 pub serialization: String,
358 pub canonical_json: bool,
359 pub compatibility: Compatibility,
360}
361
362#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
363pub struct Compatibility {
364 pub minimum_runtime_version: String,
365 pub required_capabilities: Vec<String>,
366}
367
368#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
369pub struct NodeStats {
370 pub sample_count: usize,
371 #[serde(skip_serializing_if = "Option::is_none")]
372 pub impurity: Option<f64>,
373 #[serde(skip_serializing_if = "Option::is_none")]
374 pub gain: Option<f64>,
375 #[serde(skip_serializing_if = "Option::is_none")]
376 pub class_counts: Option<Vec<usize>>,
377 #[serde(skip_serializing_if = "Option::is_none")]
378 pub variance: Option<f64>,
379}
380
381#[derive(Debug, Clone, PartialEq, Eq)]
382pub enum IrError {
383 UnsupportedIrVersion(String),
384 UnsupportedFormatName(String),
385 UnsupportedAlgorithm(String),
386 UnsupportedTask(String),
387 UnsupportedTreeType(String),
388 InvalidTreeCount(usize),
389 UnsupportedRepresentation(String),
390 InvalidFeatureCount { schema: usize, preprocessing: usize },
391 MissingClassLabels,
392 InvalidLeaf(String),
393 InvalidNode(String),
394 InvalidPreprocessing(String),
395 InvalidInferenceOption(String),
396 Json(String),
397}
398
399impl Display for IrError {
400 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
401 match self {
402 IrError::UnsupportedIrVersion(version) => {
403 write!(f, "Unsupported IR version: {}.", version)
404 }
405 IrError::UnsupportedFormatName(name) => {
406 write!(f, "Unsupported IR format: {}.", name)
407 }
408 IrError::UnsupportedAlgorithm(algorithm) => {
409 write!(f, "Unsupported algorithm: {}.", algorithm)
410 }
411 IrError::UnsupportedTask(task) => write!(f, "Unsupported task: {}.", task),
412 IrError::UnsupportedTreeType(tree_type) => {
413 write!(f, "Unsupported tree type: {}.", tree_type)
414 }
415 IrError::InvalidTreeCount(count) => {
416 write!(f, "Expected exactly one tree in the IR, found {}.", count)
417 }
418 IrError::UnsupportedRepresentation(representation) => {
419 write!(f, "Unsupported tree representation: {}.", representation)
420 }
421 IrError::InvalidFeatureCount {
422 schema,
423 preprocessing,
424 } => write!(
425 f,
426 "Input schema declares {} features, but preprocessing declares {}.",
427 schema, preprocessing
428 ),
429 IrError::MissingClassLabels => {
430 write!(f, "Classification IR requires explicit class labels.")
431 }
432 IrError::InvalidLeaf(message) => write!(f, "Invalid leaf payload: {}.", message),
433 IrError::InvalidNode(message) => write!(f, "Invalid tree node: {}.", message),
434 IrError::InvalidPreprocessing(message) => {
435 write!(f, "Invalid preprocessing section: {}.", message)
436 }
437 IrError::InvalidInferenceOption(message) => {
438 write!(f, "Invalid inference options: {}.", message)
439 }
440 IrError::Json(message) => write!(f, "Invalid JSON: {}.", message),
441 }
442 }
443}
444
445impl std::error::Error for IrError {}
446
447impl ModelPackageIr {
448 pub fn json_schema() -> RootSchema {
449 schema_for!(ModelPackageIr)
450 }
451
452 pub fn json_schema_json() -> Result<String, IrError> {
453 serde_json::to_string(&Self::json_schema()).map_err(|err| IrError::Json(err.to_string()))
454 }
455
456 pub fn json_schema_json_pretty() -> Result<String, IrError> {
457 serde_json::to_string_pretty(&Self::json_schema())
458 .map_err(|err| IrError::Json(err.to_string()))
459 }
460}
461
462pub(crate) fn model_to_ir(model: &Model) -> ModelPackageIr {
463 let trees = match model {
464 Model::RandomForest(forest) => forest
465 .trees()
466 .iter()
467 .map(model_tree_definition)
468 .collect::<Vec<_>>(),
469 Model::GradientBoostedTrees(boosted) => boosted
470 .trees()
471 .iter()
472 .map(model_tree_definition)
473 .collect::<Vec<_>>(),
474 _ => vec![model_tree_definition(model)],
475 };
476 let representation = if let Some(first_tree) = trees.first() {
477 match first_tree {
478 TreeDefinition::NodeTree { .. } => "node_tree",
479 TreeDefinition::ObliviousLevels { .. } => "oblivious_levels",
480 }
481 } else {
482 match model.tree_type() {
483 TreeType::Oblivious => "oblivious_levels",
484 TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized => "node_tree",
485 }
486 };
487 let class_labels = model.class_labels();
488 let is_ensemble = matches!(
489 model,
490 Model::RandomForest(_) | Model::GradientBoostedTrees(_)
491 );
492 let tree_count = trees.len();
493 let (aggregation_kind, tree_weights, normalize_by_weight_sum, base_score) = match model {
494 Model::RandomForest(_) => (
495 match model.task() {
496 Task::Regression => "average",
497 Task::Classification => "average_class_probabilities",
498 },
499 vec![1.0; tree_count],
500 true,
501 None,
502 ),
503 Model::GradientBoostedTrees(boosted) => (
504 match boosted.task() {
505 Task::Regression => "sum_tree_outputs",
506 Task::Classification => "sum_tree_outputs_then_sigmoid",
507 },
508 boosted.tree_weights().to_vec(),
509 false,
510 Some(boosted.base_score()),
511 ),
512 _ => ("identity_single_tree", vec![1.0; tree_count], true, None),
513 };
514
515 ModelPackageIr {
516 ir_version: IR_VERSION.to_string(),
517 format_name: FORMAT_NAME.to_string(),
518 producer: ProducerMetadata {
519 library: "forestfire-core".to_string(),
520 library_version: env!("CARGO_PKG_VERSION").to_string(),
521 language: "rust".to_string(),
522 platform: std::env::consts::ARCH.to_string(),
523 },
524 model: ModelSection {
525 algorithm: algorithm_name(model.algorithm()).to_string(),
526 task: task_name(model.task()).to_string(),
527 tree_type: tree_type_name(model.tree_type()).to_string(),
528 representation: representation.to_string(),
529 num_features: model.num_features(),
530 num_outputs: 1,
531 supports_missing: false,
532 supports_categorical: false,
533 is_ensemble,
534 trees,
535 aggregation: Aggregation {
536 kind: aggregation_kind.to_string(),
537 tree_weights,
538 normalize_by_weight_sum,
539 base_score,
540 },
541 },
542 input_schema: input_schema(model),
543 output_schema: output_schema(model, class_labels.clone()),
544 inference_options: InferenceOptions {
545 numeric_precision: "float64".to_string(),
546 threshold_comparison: "leq_left_gt_right".to_string(),
547 nan_policy: "not_supported".to_string(),
548 bool_encoding: BoolEncoding {
549 false_values: vec!["0".to_string(), "false".to_string()],
550 true_values: vec!["1".to_string(), "true".to_string()],
551 },
552 tie_breaking: TieBreaking {
553 classification: "lowest_class_index".to_string(),
554 argmax: "first_max_index".to_string(),
555 },
556 determinism: Determinism {
557 guaranteed: true,
558 notes: "Inference is deterministic when the serialized preprocessing artifacts are applied before split evaluation."
559 .to_string(),
560 },
561 },
562 preprocessing: preprocessing(model),
563 postprocessing: postprocessing(model, class_labels),
564 training_metadata: model.training_metadata(),
565 integrity: IntegritySection {
566 serialization: "json".to_string(),
567 canonical_json: true,
568 compatibility: Compatibility {
569 minimum_runtime_version: IR_VERSION.to_string(),
570 required_capabilities: required_capabilities(model, representation),
571 },
572 },
573 }
574}
575
576pub(crate) fn model_from_ir(ir: ModelPackageIr) -> Result<Model, IrError> {
577 validate_ir_header(&ir)?;
578 validate_inference_options(&ir.inference_options)?;
579
580 let algorithm = parse_algorithm(&ir.model.algorithm)?;
581 let task = parse_task(&ir.model.task)?;
582 let tree_type = parse_tree_type(&ir.model.tree_type)?;
583 let criterion = parse_criterion(&ir.training_metadata.criterion)?;
584 let feature_preprocessing = feature_preprocessing_from_ir(&ir)?;
585 let num_features = ir.input_schema.feature_count;
586 let options = tree_options(&ir.training_metadata);
587 let training_canaries = ir.training_metadata.canaries;
588 let deserialized_class_labels = classification_labels(&ir).ok();
589
590 if algorithm == TrainAlgorithm::Dt && ir.model.trees.len() != 1 {
591 return Err(IrError::InvalidTreeCount(ir.model.trees.len()));
592 }
593
594 if algorithm == TrainAlgorithm::Rf {
595 let trees = ir
596 .model
597 .trees
598 .into_iter()
599 .map(|tree| {
600 single_model_from_ir_parts(
601 task,
602 tree_type,
603 criterion,
604 feature_preprocessing.clone(),
605 num_features,
606 options.clone(),
607 training_canaries,
608 deserialized_class_labels.clone(),
609 tree,
610 )
611 })
612 .collect::<Result<Vec<_>, _>>()?;
613 return Ok(Model::RandomForest(RandomForest::new(
614 task,
615 criterion,
616 tree_type,
617 trees,
618 ir.training_metadata.compute_oob,
619 ir.training_metadata.oob_score,
620 ir.training_metadata
621 .max_features
622 .unwrap_or(num_features.max(1)),
623 ir.training_metadata.seed,
624 num_features,
625 feature_preprocessing,
626 )));
627 }
628
629 if algorithm == TrainAlgorithm::Gbm {
630 let tree_weights = ir.model.aggregation.tree_weights.clone();
631 let base_score = ir.model.aggregation.base_score.unwrap_or(0.0);
632 let trees = ir
633 .model
634 .trees
635 .into_iter()
636 .map(|tree| {
637 boosted_tree_model_from_ir_parts(
638 tree_type,
639 criterion,
640 feature_preprocessing.clone(),
641 num_features,
642 options.clone(),
643 training_canaries,
644 tree,
645 )
646 })
647 .collect::<Result<Vec<_>, _>>()?;
648 return Ok(Model::GradientBoostedTrees(GradientBoostedTrees::new(
649 task,
650 tree_type,
651 trees,
652 tree_weights,
653 base_score,
654 ir.training_metadata.learning_rate.unwrap_or(0.1),
655 ir.training_metadata.bootstrap.unwrap_or(false),
656 ir.training_metadata.top_gradient_fraction.unwrap_or(0.2),
657 ir.training_metadata.other_gradient_fraction.unwrap_or(0.1),
658 ir.training_metadata
659 .max_features
660 .unwrap_or(num_features.max(1)),
661 ir.training_metadata.seed,
662 num_features,
663 feature_preprocessing,
664 deserialized_class_labels,
665 training_canaries,
666 )));
667 }
668
669 let tree = ir
670 .model
671 .trees
672 .into_iter()
673 .next()
674 .expect("validated single tree");
675
676 single_model_from_ir_parts(
677 task,
678 tree_type,
679 criterion,
680 feature_preprocessing,
681 num_features,
682 options,
683 training_canaries,
684 deserialized_class_labels,
685 tree,
686 )
687}
688
689fn boosted_tree_model_from_ir_parts(
690 tree_type: TreeType,
691 criterion: Criterion,
692 feature_preprocessing: Vec<FeaturePreprocessing>,
693 num_features: usize,
694 options: DecisionTreeOptions,
695 training_canaries: usize,
696 tree: TreeDefinition,
697) -> Result<Model, IrError> {
698 match (tree_type, tree) {
699 (
700 TreeType::Cart | TreeType::Randomized,
701 TreeDefinition::NodeTree {
702 nodes,
703 root_node_id,
704 ..
705 },
706 ) => Ok(Model::DecisionTreeRegressor(
707 DecisionTreeRegressor::from_ir_parts(
708 match tree_type {
709 TreeType::Cart => RegressionTreeAlgorithm::Cart,
710 TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
711 _ => unreachable!(),
712 },
713 criterion,
714 RegressionTreeStructure::Standard {
715 nodes: rebuild_regressor_nodes(nodes)?,
716 root: root_node_id,
717 },
718 RegressionTreeOptions {
719 max_depth: options.max_depth,
720 min_samples_split: options.min_samples_split,
721 min_samples_leaf: options.min_samples_leaf,
722 max_features: None,
723 random_seed: 0,
724 missing_value_strategies: Vec::new(),
725 canary_filter: crate::CanaryFilter::default(),
726 },
727 num_features,
728 feature_preprocessing,
729 training_canaries,
730 ),
731 )),
732 (TreeType::Oblivious, TreeDefinition::ObliviousLevels { levels, leaves, .. }) => {
733 let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
734 let leaf_variances = rebuild_leaf_variances(&leaves)?;
735 Ok(Model::DecisionTreeRegressor(
736 DecisionTreeRegressor::from_ir_parts(
737 RegressionTreeAlgorithm::Oblivious,
738 criterion,
739 RegressionTreeStructure::Oblivious {
740 splits: rebuild_regressor_oblivious_splits(levels)?,
741 leaf_values: rebuild_regressor_leaf_values(leaves)?,
742 leaf_sample_counts,
743 leaf_variances,
744 },
745 RegressionTreeOptions {
746 max_depth: options.max_depth,
747 min_samples_split: options.min_samples_split,
748 min_samples_leaf: options.min_samples_leaf,
749 max_features: None,
750 random_seed: 0,
751 missing_value_strategies: Vec::new(),
752 canary_filter: crate::CanaryFilter::default(),
753 },
754 num_features,
755 feature_preprocessing,
756 training_canaries,
757 ),
758 ))
759 }
760 (_, tree) => Err(IrError::UnsupportedRepresentation(match tree {
761 TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
762 TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
763 })),
764 }
765}
766
767#[allow(clippy::too_many_arguments)]
768fn single_model_from_ir_parts(
769 task: Task,
770 tree_type: TreeType,
771 criterion: Criterion,
772 feature_preprocessing: Vec<FeaturePreprocessing>,
773 num_features: usize,
774 options: DecisionTreeOptions,
775 training_canaries: usize,
776 deserialized_class_labels: Option<Vec<f64>>,
777 tree: TreeDefinition,
778) -> Result<Model, IrError> {
779 match (task, tree_type, tree) {
780 (
781 Task::Classification,
782 TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized,
783 TreeDefinition::NodeTree {
784 nodes,
785 root_node_id,
786 ..
787 },
788 ) => {
789 let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
790 let structure = ClassifierTreeStructure::Standard {
791 nodes: rebuild_classifier_nodes(nodes, &class_labels)?,
792 root: root_node_id,
793 };
794 Ok(Model::DecisionTreeClassifier(
795 DecisionTreeClassifier::from_ir_parts(
796 match tree_type {
797 TreeType::Id3 => DecisionTreeAlgorithm::Id3,
798 TreeType::C45 => DecisionTreeAlgorithm::C45,
799 TreeType::Cart => DecisionTreeAlgorithm::Cart,
800 TreeType::Randomized => DecisionTreeAlgorithm::Randomized,
801 TreeType::Oblivious => unreachable!(),
802 },
803 criterion,
804 class_labels,
805 structure,
806 options,
807 num_features,
808 feature_preprocessing,
809 training_canaries,
810 ),
811 ))
812 }
813 (
814 Task::Classification,
815 TreeType::Oblivious,
816 TreeDefinition::ObliviousLevels { levels, leaves, .. },
817 ) => {
818 let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
819 let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
820 let leaf_class_counts =
821 rebuild_classifier_leaf_class_counts(&leaves, class_labels.len())?;
822 let structure = ClassifierTreeStructure::Oblivious {
823 splits: rebuild_classifier_oblivious_splits(levels)?,
824 leaf_class_indices: rebuild_classifier_leaf_indices(leaves, &class_labels)?,
825 leaf_sample_counts,
826 leaf_class_counts,
827 };
828 Ok(Model::DecisionTreeClassifier(
829 DecisionTreeClassifier::from_ir_parts(
830 DecisionTreeAlgorithm::Oblivious,
831 criterion,
832 class_labels,
833 structure,
834 options,
835 num_features,
836 feature_preprocessing,
837 training_canaries,
838 ),
839 ))
840 }
841 (
842 Task::Regression,
843 TreeType::Cart | TreeType::Randomized,
844 TreeDefinition::NodeTree {
845 nodes,
846 root_node_id,
847 ..
848 },
849 ) => Ok(Model::DecisionTreeRegressor(
850 DecisionTreeRegressor::from_ir_parts(
851 match tree_type {
852 TreeType::Cart => RegressionTreeAlgorithm::Cart,
853 TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
854 _ => unreachable!(),
855 },
856 criterion,
857 RegressionTreeStructure::Standard {
858 nodes: rebuild_regressor_nodes(nodes)?,
859 root: root_node_id,
860 },
861 RegressionTreeOptions {
862 max_depth: options.max_depth,
863 min_samples_split: options.min_samples_split,
864 min_samples_leaf: options.min_samples_leaf,
865 max_features: None,
866 random_seed: 0,
867 missing_value_strategies: Vec::new(),
868 canary_filter: crate::CanaryFilter::default(),
869 },
870 num_features,
871 feature_preprocessing,
872 training_canaries,
873 ),
874 )),
875 (
876 Task::Regression,
877 TreeType::Oblivious,
878 TreeDefinition::ObliviousLevels { levels, leaves, .. },
879 ) => {
880 let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
881 let leaf_variances = rebuild_leaf_variances(&leaves)?;
882 Ok(Model::DecisionTreeRegressor(
883 DecisionTreeRegressor::from_ir_parts(
884 RegressionTreeAlgorithm::Oblivious,
885 criterion,
886 RegressionTreeStructure::Oblivious {
887 splits: rebuild_regressor_oblivious_splits(levels)?,
888 leaf_values: rebuild_regressor_leaf_values(leaves)?,
889 leaf_sample_counts,
890 leaf_variances,
891 },
892 RegressionTreeOptions {
893 max_depth: options.max_depth,
894 min_samples_split: options.min_samples_split,
895 min_samples_leaf: options.min_samples_leaf,
896 max_features: None,
897 random_seed: 0,
898 missing_value_strategies: Vec::new(),
899 canary_filter: crate::CanaryFilter::default(),
900 },
901 num_features,
902 feature_preprocessing,
903 training_canaries,
904 ),
905 ))
906 }
907 (_, _, tree) => Err(IrError::UnsupportedRepresentation(match tree {
908 TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
909 TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
910 })),
911 }
912}
913
914fn validate_ir_header(ir: &ModelPackageIr) -> Result<(), IrError> {
915 if ir.ir_version != IR_VERSION {
916 return Err(IrError::UnsupportedIrVersion(ir.ir_version.clone()));
917 }
918 if ir.format_name != FORMAT_NAME {
919 return Err(IrError::UnsupportedFormatName(ir.format_name.clone()));
920 }
921 if ir.model.supports_missing {
922 return Err(IrError::InvalidInferenceOption(
923 "missing values are not supported in IR v1".to_string(),
924 ));
925 }
926 if ir.model.supports_categorical {
927 return Err(IrError::InvalidInferenceOption(
928 "categorical features are not supported in IR v1".to_string(),
929 ));
930 }
931 Ok(())
932}
933
934fn validate_inference_options(options: &InferenceOptions) -> Result<(), IrError> {
935 if options.threshold_comparison != "leq_left_gt_right" {
936 return Err(IrError::InvalidInferenceOption(format!(
937 "unsupported threshold comparison '{}'",
938 options.threshold_comparison
939 )));
940 }
941 if options.nan_policy != "not_supported" {
942 return Err(IrError::InvalidInferenceOption(format!(
943 "unsupported nan policy '{}'",
944 options.nan_policy
945 )));
946 }
947 Ok(())
948}
949
950fn parse_algorithm(value: &str) -> Result<TrainAlgorithm, IrError> {
951 match value {
952 "dt" => Ok(TrainAlgorithm::Dt),
953 "rf" => Ok(TrainAlgorithm::Rf),
954 "gbm" => Ok(TrainAlgorithm::Gbm),
955 _ => Err(IrError::UnsupportedAlgorithm(value.to_string())),
956 }
957}
958
959fn parse_task(value: &str) -> Result<Task, IrError> {
960 match value {
961 "regression" => Ok(Task::Regression),
962 "classification" => Ok(Task::Classification),
963 _ => Err(IrError::UnsupportedTask(value.to_string())),
964 }
965}
966
967fn parse_tree_type(value: &str) -> Result<TreeType, IrError> {
968 match value {
969 "id3" => Ok(TreeType::Id3),
970 "c45" => Ok(TreeType::C45),
971 "cart" => Ok(TreeType::Cart),
972 "randomized" => Ok(TreeType::Randomized),
973 "oblivious" => Ok(TreeType::Oblivious),
974 _ => Err(IrError::UnsupportedTreeType(value.to_string())),
975 }
976}
977
978fn parse_criterion(value: &str) -> Result<crate::Criterion, IrError> {
979 match value {
980 "gini" => Ok(crate::Criterion::Gini),
981 "entropy" => Ok(crate::Criterion::Entropy),
982 "mean" => Ok(crate::Criterion::Mean),
983 "median" => Ok(crate::Criterion::Median),
984 "second_order" => Ok(crate::Criterion::SecondOrder),
985 "auto" => Ok(crate::Criterion::Auto),
986 _ => Err(IrError::InvalidInferenceOption(format!(
987 "unsupported criterion '{}'",
988 value
989 ))),
990 }
991}
992
993fn tree_options(training: &TrainingMetadata) -> DecisionTreeOptions {
994 DecisionTreeOptions {
995 max_depth: training.max_depth.unwrap_or(8),
996 min_samples_split: training.min_samples_split.unwrap_or(2),
997 min_samples_leaf: training.min_samples_leaf.unwrap_or(1),
998 max_features: None,
999 random_seed: 0,
1000 missing_value_strategies: Vec::new(),
1001 canary_filter: crate::CanaryFilter::default(),
1002 }
1003}
1004
1005fn feature_preprocessing_from_ir(
1006 ir: &ModelPackageIr,
1007) -> Result<Vec<FeaturePreprocessing>, IrError> {
1008 let mut features: Vec<Option<FeaturePreprocessing>> = vec![None; ir.input_schema.feature_count];
1009
1010 for feature in &ir.preprocessing.numeric_binning.features {
1011 match feature {
1012 FeatureBinning::Numeric {
1013 feature_index,
1014 boundaries,
1015 } => {
1016 let slot = features.get_mut(*feature_index).ok_or_else(|| {
1017 IrError::InvalidFeatureCount {
1018 schema: ir.input_schema.feature_count,
1019 preprocessing: feature_index + 1,
1020 }
1021 })?;
1022 *slot = Some(FeaturePreprocessing::Numeric {
1023 bin_boundaries: boundaries.clone(),
1024 missing_bin: boundaries
1025 .iter()
1026 .map(|boundary| boundary.bin)
1027 .max()
1028 .map_or(0, |bin| bin.saturating_add(1)),
1029 });
1030 }
1031 FeatureBinning::Binary { feature_index } => {
1032 let slot = features.get_mut(*feature_index).ok_or_else(|| {
1033 IrError::InvalidFeatureCount {
1034 schema: ir.input_schema.feature_count,
1035 preprocessing: feature_index + 1,
1036 }
1037 })?;
1038 *slot = Some(FeaturePreprocessing::Binary);
1039 }
1040 }
1041 }
1042
1043 if features.len() != ir.input_schema.feature_count {
1044 return Err(IrError::InvalidFeatureCount {
1045 schema: ir.input_schema.feature_count,
1046 preprocessing: features.len(),
1047 });
1048 }
1049
1050 features
1051 .into_iter()
1052 .map(|feature| {
1053 feature.ok_or_else(|| {
1054 IrError::InvalidPreprocessing(
1055 "every feature must have a preprocessing entry".to_string(),
1056 )
1057 })
1058 })
1059 .collect()
1060}
1061
1062fn classification_labels(ir: &ModelPackageIr) -> Result<Vec<f64>, IrError> {
1063 ir.output_schema
1064 .class_order
1065 .clone()
1066 .or_else(|| ir.training_metadata.class_labels.clone())
1067 .ok_or(IrError::MissingClassLabels)
1068}
1069
1070fn classifier_class_index(leaf: &LeafPayload, class_labels: &[f64]) -> Result<usize, IrError> {
1071 match leaf {
1072 LeafPayload::ClassIndex {
1073 class_index,
1074 class_value,
1075 } => {
1076 let Some(expected) = class_labels.get(*class_index) else {
1077 return Err(IrError::InvalidLeaf(format!(
1078 "class index {} out of bounds",
1079 class_index
1080 )));
1081 };
1082 if expected.total_cmp(class_value).is_ne() {
1083 return Err(IrError::InvalidLeaf(format!(
1084 "class value {} does not match class order entry {}",
1085 class_value, expected
1086 )));
1087 }
1088 Ok(*class_index)
1089 }
1090 LeafPayload::RegressionValue { .. } => Err(IrError::InvalidLeaf(
1091 "expected class_index leaf".to_string(),
1092 )),
1093 }
1094}
1095
/// Rebuild classifier tree nodes from their node-list IR representation.
///
/// Serialized nodes carry explicit `node_id`s and may appear in any order;
/// each is placed into a slot vector with `assign_node` (which rejects
/// duplicate or out-of-bounds ids) and the result is densified by
/// `collect_nodes` (which rejects missing ids).
fn rebuild_classifier_nodes(
    nodes: Vec<NodeTreeNode>,
    class_labels: &[f64],
) -> Result<Vec<ClassifierTreeNode>, IrError> {
    let mut rebuilt = vec![None; nodes.len()];
    for node in nodes {
        match node {
            NodeTreeNode::Leaf {
                node_id,
                leaf,
                stats,
                ..
            } => {
                // Checks the leaf's class index and value against the label order.
                let class_index = classifier_class_index(&leaf, class_labels)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::Leaf {
                        class_index,
                        sample_count: stats.sample_count,
                        // Absent per-class counts default to all zeros.
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
            NodeTreeNode::BinaryBranch {
                node_id,
                split,
                children,
                stats,
                ..
            } => {
                let (feature_index, threshold_bin) = classifier_binary_split(split)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::BinarySplit {
                        feature_index,
                        threshold_bin,
                        // Missing-value routing is not serialized in IR v1, so
                        // the default `Node` direction is used.
                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
                        left_child: children.left,
                        right_child: children.right,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
            NodeTreeNode::MultiwayBranch {
                node_id,
                split,
                branches,
                unmatched_leaf,
                stats,
                ..
            } => {
                // Bins without a dedicated branch fall through to this class.
                let fallback_class_index = classifier_class_index(&unmatched_leaf, class_labels)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::MultiwaySplit {
                        feature_index: split.feature_index,
                        fallback_class_index,
                        branches: branches
                            .into_iter()
                            .map(|branch| (branch.bin, branch.child))
                            .collect(),
                        // Missing-value routing is not serialized in IR v1.
                        missing_child: None,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
        }
    }
    collect_nodes(rebuilt)
}
1181
/// Rebuild regression tree nodes from their node-list IR representation.
///
/// Nodes are slotted by `node_id` (duplicates and out-of-bounds ids error via
/// `assign_node`) and densified by `collect_nodes`. Regression trees accept
/// only `RegressionValue` leaves and binary branches; class-index leaves and
/// multiway branches are rejected.
fn rebuild_regressor_nodes(nodes: Vec<NodeTreeNode>) -> Result<Vec<RegressionNode>, IrError> {
    let mut rebuilt = vec![None; nodes.len()];
    for node in nodes {
        match node {
            NodeTreeNode::Leaf {
                node_id,
                leaf: LeafPayload::RegressionValue { value },
                stats,
                ..
            } => {
                assign_node(
                    &mut rebuilt,
                    node_id,
                    RegressionNode::Leaf {
                        value,
                        sample_count: stats.sample_count,
                        variance: stats.variance,
                    },
                )?;
            }
            // Any other leaf payload (i.e. a class index) is invalid here.
            NodeTreeNode::Leaf { .. } => {
                return Err(IrError::InvalidLeaf(
                    "regression trees require regression_value leaves".to_string(),
                ));
            }
            NodeTreeNode::BinaryBranch {
                node_id,
                split,
                children,
                stats,
                ..
            } => {
                let (feature_index, threshold_bin) = regressor_binary_split(split)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    RegressionNode::BinarySplit {
                        feature_index,
                        threshold_bin,
                        // Missing-value routing is not serialized in IR v1, so
                        // the default direction and a 0.0 placeholder are used.
                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
                        missing_value: 0.0,
                        left_child: children.left,
                        right_child: children.right,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        variance: stats.variance,
                    },
                )?;
            }
            NodeTreeNode::MultiwayBranch { .. } => {
                return Err(IrError::InvalidNode(
                    "regression trees do not support multiway branches".to_string(),
                ));
            }
        }
    }
    collect_nodes(rebuilt)
}
1241
1242fn rebuild_classifier_oblivious_splits(
1243 levels: Vec<ObliviousLevel>,
1244) -> Result<Vec<ClassifierObliviousSplit>, IrError> {
1245 let mut rebuilt = Vec::with_capacity(levels.len());
1246 for level in levels {
1247 rebuilt.push(match level.split {
1248 ObliviousSplit::NumericBinThreshold {
1249 feature_index,
1250 threshold_bin,
1251 ..
1252 } => ClassifierObliviousSplit {
1253 feature_index,
1254 threshold_bin,
1255 missing_directions: Vec::new(),
1256 sample_count: level.stats.sample_count,
1257 impurity: level.stats.impurity.unwrap_or(0.0),
1258 gain: level.stats.gain.unwrap_or(0.0),
1259 },
1260 ObliviousSplit::BooleanTest { feature_index, .. } => ClassifierObliviousSplit {
1261 feature_index,
1262 threshold_bin: 0,
1263 missing_directions: Vec::new(),
1264 sample_count: level.stats.sample_count,
1265 impurity: level.stats.impurity.unwrap_or(0.0),
1266 gain: level.stats.gain.unwrap_or(0.0),
1267 },
1268 });
1269 }
1270 Ok(rebuilt)
1271}
1272
1273fn rebuild_regressor_oblivious_splits(
1274 levels: Vec<ObliviousLevel>,
1275) -> Result<Vec<RegressorObliviousSplit>, IrError> {
1276 let mut rebuilt = Vec::with_capacity(levels.len());
1277 for level in levels {
1278 rebuilt.push(match level.split {
1279 ObliviousSplit::NumericBinThreshold {
1280 feature_index,
1281 threshold_bin,
1282 ..
1283 } => RegressorObliviousSplit {
1284 feature_index,
1285 threshold_bin,
1286 sample_count: level.stats.sample_count,
1287 impurity: level.stats.impurity.unwrap_or(0.0),
1288 gain: level.stats.gain.unwrap_or(0.0),
1289 },
1290 ObliviousSplit::BooleanTest { feature_index, .. } => RegressorObliviousSplit {
1291 feature_index,
1292 threshold_bin: 0,
1293 sample_count: level.stats.sample_count,
1294 impurity: level.stats.impurity.unwrap_or(0.0),
1295 gain: level.stats.gain.unwrap_or(0.0),
1296 },
1297 });
1298 }
1299 Ok(rebuilt)
1300}
1301
1302fn rebuild_classifier_leaf_indices(
1303 leaves: Vec<IndexedLeaf>,
1304 class_labels: &[f64],
1305) -> Result<Vec<usize>, IrError> {
1306 let mut rebuilt = vec![None; leaves.len()];
1307 for indexed_leaf in leaves {
1308 let class_index = classifier_class_index(&indexed_leaf.leaf, class_labels)?;
1309 assign_node(&mut rebuilt, indexed_leaf.leaf_index, class_index)?;
1310 }
1311 collect_nodes(rebuilt)
1312}
1313
1314fn rebuild_regressor_leaf_values(leaves: Vec<IndexedLeaf>) -> Result<Vec<f64>, IrError> {
1315 let mut rebuilt = vec![None; leaves.len()];
1316 for indexed_leaf in leaves {
1317 let value = match indexed_leaf.leaf {
1318 LeafPayload::RegressionValue { value } => value,
1319 LeafPayload::ClassIndex { .. } => {
1320 return Err(IrError::InvalidLeaf(
1321 "regression oblivious leaves require regression_value".to_string(),
1322 ));
1323 }
1324 };
1325 assign_node(&mut rebuilt, indexed_leaf.leaf_index, value)?;
1326 }
1327 collect_nodes(rebuilt)
1328}
1329
1330fn rebuild_leaf_sample_counts(leaves: &[IndexedLeaf]) -> Result<Vec<usize>, IrError> {
1331 let mut rebuilt = vec![None; leaves.len()];
1332 for indexed_leaf in leaves {
1333 assign_node(
1334 &mut rebuilt,
1335 indexed_leaf.leaf_index,
1336 indexed_leaf.stats.sample_count,
1337 )?;
1338 }
1339 collect_nodes(rebuilt)
1340}
1341
1342fn rebuild_leaf_variances(leaves: &[IndexedLeaf]) -> Result<Vec<Option<f64>>, IrError> {
1343 let mut rebuilt = vec![None; leaves.len()];
1344 for indexed_leaf in leaves {
1345 assign_node(
1346 &mut rebuilt,
1347 indexed_leaf.leaf_index,
1348 indexed_leaf.stats.variance,
1349 )?;
1350 }
1351 collect_nodes(rebuilt)
1352}
1353
1354fn rebuild_classifier_leaf_class_counts(
1355 leaves: &[IndexedLeaf],
1356 num_classes: usize,
1357) -> Result<Vec<Vec<usize>>, IrError> {
1358 let mut rebuilt = vec![None; leaves.len()];
1359 for indexed_leaf in leaves {
1360 assign_node(
1361 &mut rebuilt,
1362 indexed_leaf.leaf_index,
1363 indexed_leaf
1364 .stats
1365 .class_counts
1366 .clone()
1367 .unwrap_or_else(|| vec![0; num_classes]),
1368 )?;
1369 }
1370 collect_nodes(rebuilt)
1371}
1372
1373fn classifier_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
1374 match split {
1375 BinarySplit::NumericBinThreshold {
1376 feature_index,
1377 threshold_bin,
1378 ..
1379 } => Ok((feature_index, threshold_bin)),
1380 BinarySplit::BooleanTest { feature_index, .. } => Ok((feature_index, 0)),
1381 }
1382}
1383
/// Decode a regressor binary split into `(feature_index, threshold_bin)`.
///
/// Regressor and classifier binary splits share the same IR encoding, so
/// this simply delegates to `classifier_binary_split`.
fn regressor_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
    classifier_binary_split(split)
}
1387
1388fn assign_node<T>(slots: &mut [Option<T>], index: usize, value: T) -> Result<(), IrError> {
1389 let Some(slot) = slots.get_mut(index) else {
1390 return Err(IrError::InvalidNode(format!(
1391 "node index {} is out of bounds",
1392 index
1393 )));
1394 };
1395 if slot.is_some() {
1396 return Err(IrError::InvalidNode(format!(
1397 "duplicate node index {}",
1398 index
1399 )));
1400 }
1401 *slot = Some(value);
1402 Ok(())
1403}
1404
1405fn collect_nodes<T>(slots: Vec<Option<T>>) -> Result<Vec<T>, IrError> {
1406 slots
1407 .into_iter()
1408 .enumerate()
1409 .map(|(index, slot)| {
1410 slot.ok_or_else(|| IrError::InvalidNode(format!("missing node index {}", index)))
1411 })
1412 .collect()
1413}
1414
1415fn input_schema(model: &Model) -> InputSchema {
1416 let features = model
1417 .feature_preprocessing()
1418 .iter()
1419 .enumerate()
1420 .map(|(feature_index, preprocessing)| {
1421 let kind = match preprocessing {
1422 FeaturePreprocessing::Numeric { .. } => InputFeatureKind::Numeric,
1423 FeaturePreprocessing::Binary => InputFeatureKind::Binary,
1424 };
1425
1426 InputFeature {
1427 index: feature_index,
1428 name: feature_name(feature_index),
1429 dtype: match kind {
1430 InputFeatureKind::Numeric => "float64".to_string(),
1431 InputFeatureKind::Binary => "bool".to_string(),
1432 },
1433 logical_type: match kind {
1434 InputFeatureKind::Numeric => "numeric".to_string(),
1435 InputFeatureKind::Binary => "boolean".to_string(),
1436 },
1437 nullable: false,
1438 }
1439 })
1440 .collect();
1441
1442 InputSchema {
1443 feature_count: model.num_features(),
1444 features,
1445 ordering: "strict_index_order".to_string(),
1446 input_tensor_layout: "row_major".to_string(),
1447 accepts_feature_names: false,
1448 }
1449}
1450
1451fn output_schema(model: &Model, class_labels: Option<Vec<f64>>) -> OutputSchema {
1452 match model.task() {
1453 Task::Regression => OutputSchema {
1454 raw_outputs: vec![OutputField {
1455 name: "value".to_string(),
1456 kind: "regression_value".to_string(),
1457 shape: Vec::new(),
1458 dtype: "float64".to_string(),
1459 }],
1460 final_outputs: vec![OutputField {
1461 name: "prediction".to_string(),
1462 kind: "value".to_string(),
1463 shape: Vec::new(),
1464 dtype: "float64".to_string(),
1465 }],
1466 class_order: None,
1467 },
1468 Task::Classification => OutputSchema {
1469 raw_outputs: vec![OutputField {
1470 name: "class_index".to_string(),
1471 kind: "class_index".to_string(),
1472 shape: Vec::new(),
1473 dtype: "uint64".to_string(),
1474 }],
1475 final_outputs: vec![OutputField {
1476 name: "predicted_class".to_string(),
1477 kind: "class_label".to_string(),
1478 shape: Vec::new(),
1479 dtype: "float64".to_string(),
1480 }],
1481 class_order: class_labels,
1482 },
1483 }
1484}
1485
1486fn preprocessing(model: &Model) -> PreprocessingSection {
1487 let features = model
1488 .feature_preprocessing()
1489 .iter()
1490 .enumerate()
1491 .map(|(feature_index, preprocessing)| match preprocessing {
1492 FeaturePreprocessing::Numeric { bin_boundaries, .. } => FeatureBinning::Numeric {
1493 feature_index,
1494 boundaries: bin_boundaries.clone(),
1495 },
1496 FeaturePreprocessing::Binary => FeatureBinning::Binary { feature_index },
1497 })
1498 .collect();
1499
1500 PreprocessingSection {
1501 included_in_model: true,
1502 numeric_binning: NumericBinning {
1503 kind: "rank_bin_128".to_string(),
1504 features,
1505 },
1506 notes: "Numeric features use serialized training-time rank bins. Binary features are serialized as booleans. Missing values and categorical encodings are not implemented in IR v1."
1507 .to_string(),
1508 }
1509}
1510
1511fn postprocessing(model: &Model, class_labels: Option<Vec<f64>>) -> PostprocessingSection {
1512 match model.task() {
1513 Task::Regression => PostprocessingSection {
1514 raw_output_kind: "regression_value".to_string(),
1515 steps: vec![PostprocessingStep::Identity],
1516 },
1517 Task::Classification => PostprocessingSection {
1518 raw_output_kind: "class_index".to_string(),
1519 steps: vec![PostprocessingStep::MapClassIndexToLabel {
1520 labels: class_labels.expect("classification IR requires class labels"),
1521 }],
1522 },
1523 }
1524}
1525
1526fn required_capabilities(model: &Model, representation: &str) -> Vec<String> {
1527 let mut capabilities = vec![
1528 representation.to_string(),
1529 "training_rank_bin_128".to_string(),
1530 ];
1531 match model.tree_type() {
1532 TreeType::Id3 | TreeType::C45 => {
1533 capabilities.push("binned_multiway_splits".to_string());
1534 }
1535 TreeType::Cart | TreeType::Randomized | TreeType::Oblivious => {
1536 capabilities.push("numeric_bin_threshold_splits".to_string());
1537 }
1538 }
1539 if model
1540 .feature_preprocessing()
1541 .iter()
1542 .any(|feature| matches!(feature, FeaturePreprocessing::Binary))
1543 {
1544 capabilities.push("boolean_features".to_string());
1545 }
1546 match model.task() {
1547 Task::Regression => capabilities.push("regression_value_leaves".to_string()),
1548 Task::Classification => capabilities.push("class_index_leaves".to_string()),
1549 }
1550 capabilities
1551}
1552
1553pub(crate) fn algorithm_name(algorithm: TrainAlgorithm) -> &'static str {
1554 match algorithm {
1555 TrainAlgorithm::Dt => "dt",
1556 TrainAlgorithm::Rf => "rf",
1557 TrainAlgorithm::Gbm => "gbm",
1558 }
1559}
1560
1561fn model_tree_definition(model: &Model) -> TreeDefinition {
1562 match model {
1563 Model::DecisionTreeClassifier(classifier) => classifier.to_ir_tree(),
1564 Model::DecisionTreeRegressor(regressor) => regressor.to_ir_tree(),
1565 Model::RandomForest(_) | Model::GradientBoostedTrees(_) => {
1566 unreachable!("ensemble IR expands into member trees")
1567 }
1568 }
1569}
1570
1571pub(crate) fn criterion_name(criterion: crate::Criterion) -> &'static str {
1572 match criterion {
1573 crate::Criterion::Auto => "auto",
1574 crate::Criterion::Gini => "gini",
1575 crate::Criterion::Entropy => "entropy",
1576 crate::Criterion::Mean => "mean",
1577 crate::Criterion::Median => "median",
1578 crate::Criterion::SecondOrder => "second_order",
1579 }
1580}
1581
1582pub(crate) fn task_name(task: Task) -> &'static str {
1583 match task {
1584 Task::Regression => "regression",
1585 Task::Classification => "classification",
1586 }
1587}
1588
1589pub(crate) fn tree_type_name(tree_type: TreeType) -> &'static str {
1590 match tree_type {
1591 TreeType::Id3 => "id3",
1592 TreeType::C45 => "c45",
1593 TreeType::Cart => "cart",
1594 TreeType::Randomized => "randomized",
1595 TreeType::Oblivious => "oblivious",
1596 }
1597}
1598
/// Produce the synthetic IR feature name for an index, e.g. `f0`, `f1`.
pub(crate) fn feature_name(feature_index: usize) -> String {
    let mut name = String::from("f");
    name.push_str(&feature_index.to_string());
    name
}
1602
1603pub(crate) fn threshold_upper_bound(
1604 preprocessing: &[FeaturePreprocessing],
1605 feature_index: usize,
1606 threshold_bin: u16,
1607) -> Option<f64> {
1608 match preprocessing.get(feature_index)? {
1609 FeaturePreprocessing::Numeric { bin_boundaries, .. } => bin_boundaries
1610 .iter()
1611 .find(|boundary| boundary.bin == threshold_bin)
1612 .map(|boundary| boundary.upper_bound),
1613 FeaturePreprocessing::Binary => None,
1614 }
1615}