Skip to main content

forestfire_core/
ir.rs

1//! Stable intermediate representation for ForestFire models.
2//!
3//! The IR sits between two worlds:
4//!
5//! - the semantic in-memory model types used for training and introspection
6//! - the lowered runtime structures used by optimized inference
7//!
8//! It exists so models can be serialized, schema-checked, inspected from other
9//! languages, and reconstructed without depending on the exact Rust memory
10//! layout of the training structs.
11
12use crate::tree::classifier::{
13    DecisionTreeAlgorithm, DecisionTreeClassifier, DecisionTreeOptions,
14    ObliviousSplit as ClassifierObliviousSplit, TreeNode as ClassifierTreeNode,
15    TreeStructure as ClassifierTreeStructure,
16};
17use crate::tree::regressor::{
18    DecisionTreeRegressor, ObliviousSplit as RegressorObliviousSplit, RegressionNode,
19    RegressionTreeAlgorithm, RegressionTreeOptions, RegressionTreeStructure,
20};
21use crate::{
22    Criterion, FeaturePreprocessing, GradientBoostedTrees, InputFeatureKind, Model,
23    NumericBinBoundary, RandomForest, Task, TrainAlgorithm, TreeType,
24};
25use schemars::schema::RootSchema;
26use schemars::{JsonSchema, schema_for};
27use serde::{Deserialize, Serialize};
28use std::fmt::{Display, Formatter};
29
// Version string written into `ModelPackageIr::ir_version` by `model_to_ir`;
// `validate_ir_header` rejects any package whose version differs from it.
const IR_VERSION: &str = "1.0.0";
// Format identifier written into `ModelPackageIr::format_name`; checked the same way.
const FORMAT_NAME: &str = "forestfire-ir";
32
33/// Top-level model package serialized by the library.
34#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
35pub struct ModelPackageIr {
36    pub ir_version: String,
37    pub format_name: String,
38    pub producer: ProducerMetadata,
39    pub model: ModelSection,
40    pub input_schema: InputSchema,
41    pub output_schema: OutputSchema,
42    pub inference_options: InferenceOptions,
43    pub preprocessing: PreprocessingSection,
44    pub postprocessing: PostprocessingSection,
45    pub training_metadata: TrainingMetadata,
46    pub integrity: IntegritySection,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
50pub struct ProducerMetadata {
51    pub library: String,
52    pub library_version: String,
53    pub language: String,
54    pub platform: String,
55}
56
57/// Structural model description independent of any concrete runtime layout.
58#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
59pub struct ModelSection {
60    pub algorithm: String,
61    pub task: String,
62    pub tree_type: String,
63    pub representation: String,
64    pub num_features: usize,
65    pub num_outputs: usize,
66    pub supports_missing: bool,
67    pub supports_categorical: bool,
68    pub is_ensemble: bool,
69    pub trees: Vec<TreeDefinition>,
70    pub aggregation: Aggregation,
71}
72
73/// Concrete tree payload stored in the IR.
74#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
75#[serde(tag = "representation", rename_all = "snake_case")]
76pub enum TreeDefinition {
77    NodeTree {
78        tree_id: usize,
79        weight: f64,
80        root_node_id: usize,
81        nodes: Vec<NodeTreeNode>,
82    },
83    ObliviousLevels {
84        tree_id: usize,
85        weight: f64,
86        depth: usize,
87        levels: Vec<ObliviousLevel>,
88        leaf_indexing: LeafIndexing,
89        leaves: Vec<IndexedLeaf>,
90    },
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
94#[serde(tag = "kind", rename_all = "snake_case")]
95pub enum NodeTreeNode {
96    Leaf {
97        node_id: usize,
98        depth: usize,
99        leaf: LeafPayload,
100        stats: NodeStats,
101    },
102    BinaryBranch {
103        node_id: usize,
104        depth: usize,
105        split: BinarySplit,
106        children: BinaryChildren,
107        stats: NodeStats,
108    },
109    MultiwayBranch {
110        node_id: usize,
111        depth: usize,
112        split: MultiwaySplit,
113        branches: Vec<MultiwayBranch>,
114        unmatched_leaf: LeafPayload,
115        stats: NodeStats,
116    },
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
120pub struct BinaryChildren {
121    pub left: usize,
122    pub right: usize,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
126pub struct MultiwayBranch {
127    pub bin: u16,
128    pub child: usize,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
132#[serde(tag = "split_type", rename_all = "snake_case")]
133pub enum BinarySplit {
134    NumericBinThreshold {
135        feature_index: usize,
136        feature_name: String,
137        operator: String,
138        threshold_bin: u16,
139        threshold_upper_bound: Option<f64>,
140        comparison_dtype: String,
141    },
142    BooleanTest {
143        feature_index: usize,
144        feature_name: String,
145        false_child_semantics: String,
146        true_child_semantics: String,
147    },
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
151pub struct MultiwaySplit {
152    pub split_type: String,
153    pub feature_index: usize,
154    pub feature_name: String,
155    pub comparison_dtype: String,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
159pub struct ObliviousLevel {
160    pub level: usize,
161    pub split: ObliviousSplit,
162    pub stats: NodeStats,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
166#[serde(tag = "split_type", rename_all = "snake_case")]
167pub enum ObliviousSplit {
168    NumericBinThreshold {
169        feature_index: usize,
170        feature_name: String,
171        operator: String,
172        threshold_bin: u16,
173        threshold_upper_bound: Option<f64>,
174        comparison_dtype: String,
175        bit_when_true: u8,
176        bit_when_false: u8,
177    },
178    BooleanTest {
179        feature_index: usize,
180        feature_name: String,
181        bit_when_false: u8,
182        bit_when_true: u8,
183    },
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
187pub struct LeafIndexing {
188    pub bit_order: String,
189    pub index_formula: String,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
193pub struct IndexedLeaf {
194    pub leaf_index: usize,
195    pub leaf: LeafPayload,
196    pub stats: NodeStats,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
200#[serde(tag = "prediction_kind", rename_all = "snake_case")]
201pub enum LeafPayload {
202    RegressionValue {
203        value: f64,
204    },
205    ClassIndex {
206        class_index: usize,
207        class_value: f64,
208    },
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
212pub struct Aggregation {
213    pub kind: String,
214    pub tree_weights: Vec<f64>,
215    pub normalize_by_weight_sum: bool,
216    #[serde(skip_serializing_if = "Option::is_none")]
217    pub base_score: Option<f64>,
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
221pub struct InputSchema {
222    pub feature_count: usize,
223    pub features: Vec<InputFeature>,
224    pub ordering: String,
225    pub input_tensor_layout: String,
226    pub accepts_feature_names: bool,
227}
228
229#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
230pub struct InputFeature {
231    pub index: usize,
232    pub name: String,
233    pub dtype: String,
234    pub logical_type: String,
235    pub nullable: bool,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
239pub struct OutputSchema {
240    pub raw_outputs: Vec<OutputField>,
241    pub final_outputs: Vec<OutputField>,
242    #[serde(skip_serializing_if = "Option::is_none")]
243    pub class_order: Option<Vec<f64>>,
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
247pub struct OutputField {
248    pub name: String,
249    pub kind: String,
250    pub shape: Vec<usize>,
251    pub dtype: String,
252}
253
254#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
255pub struct InferenceOptions {
256    pub numeric_precision: String,
257    pub threshold_comparison: String,
258    pub nan_policy: String,
259    pub bool_encoding: BoolEncoding,
260    pub tie_breaking: TieBreaking,
261    pub determinism: Determinism,
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
265pub struct BoolEncoding {
266    pub false_values: Vec<String>,
267    pub true_values: Vec<String>,
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
271pub struct TieBreaking {
272    pub classification: String,
273    pub argmax: String,
274}
275
276#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
277pub struct Determinism {
278    pub guaranteed: bool,
279    pub notes: String,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
283pub struct PreprocessingSection {
284    pub included_in_model: bool,
285    pub numeric_binning: NumericBinning,
286    pub notes: String,
287}
288
289#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
290pub struct NumericBinning {
291    pub kind: String,
292    pub features: Vec<FeatureBinning>,
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
296#[serde(tag = "kind", rename_all = "snake_case")]
297pub enum FeatureBinning {
298    Numeric {
299        feature_index: usize,
300        boundaries: Vec<NumericBinBoundary>,
301    },
302    Binary {
303        feature_index: usize,
304    },
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
308pub struct PostprocessingSection {
309    pub raw_output_kind: String,
310    pub steps: Vec<PostprocessingStep>,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
314#[serde(tag = "op", rename_all = "snake_case")]
315pub enum PostprocessingStep {
316    Identity,
317    MapClassIndexToLabel { labels: Vec<f64> },
318}
319
320/// Serialized training metadata reflected back to bindings and docs.
321#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
322pub struct TrainingMetadata {
323    pub algorithm: String,
324    pub task: String,
325    pub tree_type: String,
326    pub criterion: String,
327    pub canaries: usize,
328    pub compute_oob: bool,
329    #[serde(skip_serializing_if = "Option::is_none")]
330    pub max_depth: Option<usize>,
331    #[serde(skip_serializing_if = "Option::is_none")]
332    pub min_samples_split: Option<usize>,
333    #[serde(skip_serializing_if = "Option::is_none")]
334    pub min_samples_leaf: Option<usize>,
335    #[serde(skip_serializing_if = "Option::is_none")]
336    pub n_trees: Option<usize>,
337    #[serde(skip_serializing_if = "Option::is_none")]
338    pub max_features: Option<usize>,
339    #[serde(skip_serializing_if = "Option::is_none")]
340    pub seed: Option<u64>,
341    #[serde(skip_serializing_if = "Option::is_none")]
342    pub oob_score: Option<f64>,
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub class_labels: Option<Vec<f64>>,
345    #[serde(skip_serializing_if = "Option::is_none")]
346    pub learning_rate: Option<f64>,
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub bootstrap: Option<bool>,
349    #[serde(skip_serializing_if = "Option::is_none")]
350    pub top_gradient_fraction: Option<f64>,
351    #[serde(skip_serializing_if = "Option::is_none")]
352    pub other_gradient_fraction: Option<f64>,
353}
354
355#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
356pub struct IntegritySection {
357    pub serialization: String,
358    pub canonical_json: bool,
359    pub compatibility: Compatibility,
360}
361
362#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
363pub struct Compatibility {
364    pub minimum_runtime_version: String,
365    pub required_capabilities: Vec<String>,
366}
367
368#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
369pub struct NodeStats {
370    pub sample_count: usize,
371    #[serde(skip_serializing_if = "Option::is_none")]
372    pub impurity: Option<f64>,
373    #[serde(skip_serializing_if = "Option::is_none")]
374    pub gain: Option<f64>,
375    #[serde(skip_serializing_if = "Option::is_none")]
376    pub class_counts: Option<Vec<usize>>,
377    #[serde(skip_serializing_if = "Option::is_none")]
378    pub variance: Option<f64>,
379}
380
/// Errors raised while validating, parsing or (de)serializing the IR.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum IrError {
    /// The package declares an IR version other than the supported one.
    UnsupportedIrVersion(String),
    /// The package declares an unknown format name.
    UnsupportedFormatName(String),
    /// The algorithm string is not recognised.
    UnsupportedAlgorithm(String),
    /// The task string is not recognised.
    UnsupportedTask(String),
    /// The tree-type string is not recognised.
    UnsupportedTreeType(String),
    /// A single-tree model did not contain exactly one tree; carries the count found.
    InvalidTreeCount(usize),
    /// The tree representation does not fit the requested task / tree type.
    UnsupportedRepresentation(String),
    /// The input schema and the preprocessing section disagree on the feature count.
    InvalidFeatureCount { schema: usize, preprocessing: usize },
    /// A classification model was supplied without class labels.
    MissingClassLabels,
    /// A leaf payload is malformed.
    InvalidLeaf(String),
    /// A tree node is malformed.
    InvalidNode(String),
    /// The preprocessing section is malformed.
    InvalidPreprocessing(String),
    /// An inference option carries an unsupported value.
    InvalidInferenceOption(String),
    /// JSON (de)serialization failed; carries the underlying message.
    Json(String),
}

impl Display for IrError {
    /// Writes a one-sentence, period-terminated description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedIrVersion(found) => write!(f, "Unsupported IR version: {}.", found),
            Self::UnsupportedFormatName(found) => write!(f, "Unsupported IR format: {}.", found),
            Self::UnsupportedAlgorithm(found) => write!(f, "Unsupported algorithm: {}.", found),
            Self::UnsupportedTask(found) => write!(f, "Unsupported task: {}.", found),
            Self::UnsupportedTreeType(found) => write!(f, "Unsupported tree type: {}.", found),
            Self::InvalidTreeCount(found) => {
                write!(f, "Expected exactly one tree in the IR, found {}.", found)
            }
            Self::UnsupportedRepresentation(found) => {
                write!(f, "Unsupported tree representation: {}.", found)
            }
            Self::InvalidFeatureCount {
                schema: declared,
                preprocessing: preprocessed,
            } => write!(
                f,
                "Input schema declares {} features, but preprocessing declares {}.",
                declared, preprocessed
            ),
            // Constant message: no formatting machinery needed.
            Self::MissingClassLabels => {
                f.write_str("Classification IR requires explicit class labels.")
            }
            Self::InvalidLeaf(detail) => write!(f, "Invalid leaf payload: {}.", detail),
            Self::InvalidNode(detail) => write!(f, "Invalid tree node: {}.", detail),
            Self::InvalidPreprocessing(detail) => {
                write!(f, "Invalid preprocessing section: {}.", detail)
            }
            Self::InvalidInferenceOption(detail) => {
                write!(f, "Invalid inference options: {}.", detail)
            }
            Self::Json(detail) => write!(f, "Invalid JSON: {}.", detail),
        }
    }
}

impl std::error::Error for IrError {}
446
447impl ModelPackageIr {
448    pub fn json_schema() -> RootSchema {
449        schema_for!(ModelPackageIr)
450    }
451
452    pub fn json_schema_json() -> Result<String, IrError> {
453        serde_json::to_string(&Self::json_schema()).map_err(|err| IrError::Json(err.to_string()))
454    }
455
456    pub fn json_schema_json_pretty() -> Result<String, IrError> {
457        serde_json::to_string_pretty(&Self::json_schema())
458            .map_err(|err| IrError::Json(err.to_string()))
459    }
460}
461
462pub(crate) fn model_to_ir(model: &Model) -> ModelPackageIr {
463    let trees = match model {
464        Model::RandomForest(forest) => forest
465            .trees()
466            .iter()
467            .map(model_tree_definition)
468            .collect::<Vec<_>>(),
469        Model::GradientBoostedTrees(boosted) => boosted
470            .trees()
471            .iter()
472            .map(model_tree_definition)
473            .collect::<Vec<_>>(),
474        _ => vec![model_tree_definition(model)],
475    };
476    let representation = if let Some(first_tree) = trees.first() {
477        match first_tree {
478            TreeDefinition::NodeTree { .. } => "node_tree",
479            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels",
480        }
481    } else {
482        match model.tree_type() {
483            TreeType::Oblivious => "oblivious_levels",
484            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized => "node_tree",
485        }
486    };
487    let class_labels = model.class_labels();
488    let is_ensemble = matches!(
489        model,
490        Model::RandomForest(_) | Model::GradientBoostedTrees(_)
491    );
492    let tree_count = trees.len();
493    let (aggregation_kind, tree_weights, normalize_by_weight_sum, base_score) = match model {
494        Model::RandomForest(_) => (
495            match model.task() {
496                Task::Regression => "average",
497                Task::Classification => "average_class_probabilities",
498            },
499            vec![1.0; tree_count],
500            true,
501            None,
502        ),
503        Model::GradientBoostedTrees(boosted) => (
504            match boosted.task() {
505                Task::Regression => "sum_tree_outputs",
506                Task::Classification => "sum_tree_outputs_then_sigmoid",
507            },
508            boosted.tree_weights().to_vec(),
509            false,
510            Some(boosted.base_score()),
511        ),
512        _ => ("identity_single_tree", vec![1.0; tree_count], true, None),
513    };
514
515    ModelPackageIr {
516        ir_version: IR_VERSION.to_string(),
517        format_name: FORMAT_NAME.to_string(),
518        producer: ProducerMetadata {
519            library: "forestfire-core".to_string(),
520            library_version: env!("CARGO_PKG_VERSION").to_string(),
521            language: "rust".to_string(),
522            platform: std::env::consts::ARCH.to_string(),
523        },
524        model: ModelSection {
525            algorithm: algorithm_name(model.algorithm()).to_string(),
526            task: task_name(model.task()).to_string(),
527            tree_type: tree_type_name(model.tree_type()).to_string(),
528            representation: representation.to_string(),
529            num_features: model.num_features(),
530            num_outputs: 1,
531            supports_missing: false,
532            supports_categorical: false,
533            is_ensemble,
534            trees,
535            aggregation: Aggregation {
536                kind: aggregation_kind.to_string(),
537                tree_weights,
538                normalize_by_weight_sum,
539                base_score,
540            },
541        },
542        input_schema: input_schema(model),
543        output_schema: output_schema(model, class_labels.clone()),
544        inference_options: InferenceOptions {
545            numeric_precision: "float64".to_string(),
546            threshold_comparison: "leq_left_gt_right".to_string(),
547            nan_policy: "not_supported".to_string(),
548            bool_encoding: BoolEncoding {
549                false_values: vec!["0".to_string(), "false".to_string()],
550                true_values: vec!["1".to_string(), "true".to_string()],
551            },
552            tie_breaking: TieBreaking {
553                classification: "lowest_class_index".to_string(),
554                argmax: "first_max_index".to_string(),
555            },
556            determinism: Determinism {
557                guaranteed: true,
558                notes: "Inference is deterministic when the serialized preprocessing artifacts are applied before split evaluation."
559                    .to_string(),
560            },
561        },
562        preprocessing: preprocessing(model),
563        postprocessing: postprocessing(model, class_labels),
564        training_metadata: model.training_metadata(),
565        integrity: IntegritySection {
566            serialization: "json".to_string(),
567            canonical_json: true,
568            compatibility: Compatibility {
569                minimum_runtime_version: IR_VERSION.to_string(),
570                required_capabilities: required_capabilities(model, representation),
571            },
572        },
573    }
574}
575
576pub(crate) fn model_from_ir(ir: ModelPackageIr) -> Result<Model, IrError> {
577    validate_ir_header(&ir)?;
578    validate_inference_options(&ir.inference_options)?;
579
580    let algorithm = parse_algorithm(&ir.model.algorithm)?;
581    let task = parse_task(&ir.model.task)?;
582    let tree_type = parse_tree_type(&ir.model.tree_type)?;
583    let criterion = parse_criterion(&ir.training_metadata.criterion)?;
584    let feature_preprocessing = feature_preprocessing_from_ir(&ir)?;
585    let num_features = ir.input_schema.feature_count;
586    let options = tree_options(&ir.training_metadata);
587    let training_canaries = ir.training_metadata.canaries;
588    let deserialized_class_labels = classification_labels(&ir).ok();
589
590    if algorithm == TrainAlgorithm::Dt && ir.model.trees.len() != 1 {
591        return Err(IrError::InvalidTreeCount(ir.model.trees.len()));
592    }
593
594    if algorithm == TrainAlgorithm::Rf {
595        let trees = ir
596            .model
597            .trees
598            .into_iter()
599            .map(|tree| {
600                single_model_from_ir_parts(
601                    task,
602                    tree_type,
603                    criterion,
604                    feature_preprocessing.clone(),
605                    num_features,
606                    options,
607                    training_canaries,
608                    deserialized_class_labels.clone(),
609                    tree,
610                )
611            })
612            .collect::<Result<Vec<_>, _>>()?;
613        return Ok(Model::RandomForest(RandomForest::new(
614            task,
615            criterion,
616            tree_type,
617            trees,
618            ir.training_metadata.compute_oob,
619            ir.training_metadata.oob_score,
620            ir.training_metadata
621                .max_features
622                .unwrap_or(num_features.max(1)),
623            ir.training_metadata.seed,
624            num_features,
625            feature_preprocessing,
626        )));
627    }
628
629    if algorithm == TrainAlgorithm::Gbm {
630        let tree_weights = ir.model.aggregation.tree_weights.clone();
631        let base_score = ir.model.aggregation.base_score.unwrap_or(0.0);
632        let trees = ir
633            .model
634            .trees
635            .into_iter()
636            .map(|tree| {
637                single_model_from_ir_parts(
638                    task,
639                    tree_type,
640                    criterion,
641                    feature_preprocessing.clone(),
642                    num_features,
643                    options,
644                    training_canaries,
645                    deserialized_class_labels.clone(),
646                    tree,
647                )
648            })
649            .collect::<Result<Vec<_>, _>>()?;
650        return Ok(Model::GradientBoostedTrees(GradientBoostedTrees::new(
651            task,
652            tree_type,
653            trees,
654            tree_weights,
655            base_score,
656            ir.training_metadata.learning_rate.unwrap_or(0.1),
657            ir.training_metadata.bootstrap.unwrap_or(false),
658            ir.training_metadata.top_gradient_fraction.unwrap_or(0.2),
659            ir.training_metadata.other_gradient_fraction.unwrap_or(0.1),
660            ir.training_metadata
661                .max_features
662                .unwrap_or(num_features.max(1)),
663            ir.training_metadata.seed,
664            num_features,
665            feature_preprocessing,
666            deserialized_class_labels,
667            training_canaries,
668        )));
669    }
670
671    let tree = ir
672        .model
673        .trees
674        .into_iter()
675        .next()
676        .expect("validated single tree");
677
678    single_model_from_ir_parts(
679        task,
680        tree_type,
681        criterion,
682        feature_preprocessing,
683        num_features,
684        options,
685        training_canaries,
686        deserialized_class_labels,
687        tree,
688    )
689}
690
691#[allow(clippy::too_many_arguments)]
692fn single_model_from_ir_parts(
693    task: Task,
694    tree_type: TreeType,
695    criterion: Criterion,
696    feature_preprocessing: Vec<FeaturePreprocessing>,
697    num_features: usize,
698    options: DecisionTreeOptions,
699    training_canaries: usize,
700    deserialized_class_labels: Option<Vec<f64>>,
701    tree: TreeDefinition,
702) -> Result<Model, IrError> {
703    match (task, tree_type, tree) {
704        (
705            Task::Classification,
706            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized,
707            TreeDefinition::NodeTree {
708                nodes,
709                root_node_id,
710                ..
711            },
712        ) => {
713            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
714            let structure = ClassifierTreeStructure::Standard {
715                nodes: rebuild_classifier_nodes(nodes, &class_labels)?,
716                root: root_node_id,
717            };
718            Ok(Model::DecisionTreeClassifier(
719                DecisionTreeClassifier::from_ir_parts(
720                    match tree_type {
721                        TreeType::Id3 => DecisionTreeAlgorithm::Id3,
722                        TreeType::C45 => DecisionTreeAlgorithm::C45,
723                        TreeType::Cart => DecisionTreeAlgorithm::Cart,
724                        TreeType::Randomized => DecisionTreeAlgorithm::Randomized,
725                        TreeType::Oblivious => unreachable!(),
726                    },
727                    criterion,
728                    class_labels,
729                    structure,
730                    options,
731                    num_features,
732                    feature_preprocessing,
733                    training_canaries,
734                ),
735            ))
736        }
737        (
738            Task::Classification,
739            TreeType::Oblivious,
740            TreeDefinition::ObliviousLevels { levels, leaves, .. },
741        ) => {
742            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
743            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
744            let leaf_class_counts =
745                rebuild_classifier_leaf_class_counts(&leaves, class_labels.len())?;
746            let structure = ClassifierTreeStructure::Oblivious {
747                splits: rebuild_classifier_oblivious_splits(levels)?,
748                leaf_class_indices: rebuild_classifier_leaf_indices(leaves, &class_labels)?,
749                leaf_sample_counts,
750                leaf_class_counts,
751            };
752            Ok(Model::DecisionTreeClassifier(
753                DecisionTreeClassifier::from_ir_parts(
754                    DecisionTreeAlgorithm::Oblivious,
755                    criterion,
756                    class_labels,
757                    structure,
758                    options,
759                    num_features,
760                    feature_preprocessing,
761                    training_canaries,
762                ),
763            ))
764        }
765        (
766            Task::Regression,
767            TreeType::Cart | TreeType::Randomized,
768            TreeDefinition::NodeTree {
769                nodes,
770                root_node_id,
771                ..
772            },
773        ) => Ok(Model::DecisionTreeRegressor(
774            DecisionTreeRegressor::from_ir_parts(
775                match tree_type {
776                    TreeType::Cart => RegressionTreeAlgorithm::Cart,
777                    TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
778                    _ => unreachable!(),
779                },
780                criterion,
781                RegressionTreeStructure::Standard {
782                    nodes: rebuild_regressor_nodes(nodes)?,
783                    root: root_node_id,
784                },
785                RegressionTreeOptions {
786                    max_depth: options.max_depth,
787                    min_samples_split: options.min_samples_split,
788                    min_samples_leaf: options.min_samples_leaf,
789                    max_features: None,
790                    random_seed: 0,
791                },
792                num_features,
793                feature_preprocessing,
794                training_canaries,
795            ),
796        )),
797        (
798            Task::Regression,
799            TreeType::Oblivious,
800            TreeDefinition::ObliviousLevels { levels, leaves, .. },
801        ) => {
802            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
803            let leaf_variances = rebuild_leaf_variances(&leaves)?;
804            Ok(Model::DecisionTreeRegressor(
805                DecisionTreeRegressor::from_ir_parts(
806                    RegressionTreeAlgorithm::Oblivious,
807                    criterion,
808                    RegressionTreeStructure::Oblivious {
809                        splits: rebuild_regressor_oblivious_splits(levels)?,
810                        leaf_values: rebuild_regressor_leaf_values(leaves)?,
811                        leaf_sample_counts,
812                        leaf_variances,
813                    },
814                    RegressionTreeOptions {
815                        max_depth: options.max_depth,
816                        min_samples_split: options.min_samples_split,
817                        min_samples_leaf: options.min_samples_leaf,
818                        max_features: None,
819                        random_seed: 0,
820                    },
821                    num_features,
822                    feature_preprocessing,
823                    training_canaries,
824                ),
825            ))
826        }
827        (_, _, tree) => Err(IrError::UnsupportedRepresentation(match tree {
828            TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
829            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
830        })),
831    }
832}
833
834fn validate_ir_header(ir: &ModelPackageIr) -> Result<(), IrError> {
835    if ir.ir_version != IR_VERSION {
836        return Err(IrError::UnsupportedIrVersion(ir.ir_version.clone()));
837    }
838    if ir.format_name != FORMAT_NAME {
839        return Err(IrError::UnsupportedFormatName(ir.format_name.clone()));
840    }
841    if ir.model.supports_missing {
842        return Err(IrError::InvalidInferenceOption(
843            "missing values are not supported in IR v1".to_string(),
844        ));
845    }
846    if ir.model.supports_categorical {
847        return Err(IrError::InvalidInferenceOption(
848            "categorical features are not supported in IR v1".to_string(),
849        ));
850    }
851    Ok(())
852}
853
854fn validate_inference_options(options: &InferenceOptions) -> Result<(), IrError> {
855    if options.threshold_comparison != "leq_left_gt_right" {
856        return Err(IrError::InvalidInferenceOption(format!(
857            "unsupported threshold comparison '{}'",
858            options.threshold_comparison
859        )));
860    }
861    if options.nan_policy != "not_supported" {
862        return Err(IrError::InvalidInferenceOption(format!(
863            "unsupported nan policy '{}'",
864            options.nan_policy
865        )));
866    }
867    Ok(())
868}
869
870fn parse_algorithm(value: &str) -> Result<TrainAlgorithm, IrError> {
871    match value {
872        "dt" => Ok(TrainAlgorithm::Dt),
873        "rf" => Ok(TrainAlgorithm::Rf),
874        "gbm" => Ok(TrainAlgorithm::Gbm),
875        _ => Err(IrError::UnsupportedAlgorithm(value.to_string())),
876    }
877}
878
879fn parse_task(value: &str) -> Result<Task, IrError> {
880    match value {
881        "regression" => Ok(Task::Regression),
882        "classification" => Ok(Task::Classification),
883        _ => Err(IrError::UnsupportedTask(value.to_string())),
884    }
885}
886
887fn parse_tree_type(value: &str) -> Result<TreeType, IrError> {
888    match value {
889        "id3" => Ok(TreeType::Id3),
890        "c45" => Ok(TreeType::C45),
891        "cart" => Ok(TreeType::Cart),
892        "randomized" => Ok(TreeType::Randomized),
893        "oblivious" => Ok(TreeType::Oblivious),
894        _ => Err(IrError::UnsupportedTreeType(value.to_string())),
895    }
896}
897
898fn parse_criterion(value: &str) -> Result<crate::Criterion, IrError> {
899    match value {
900        "gini" => Ok(crate::Criterion::Gini),
901        "entropy" => Ok(crate::Criterion::Entropy),
902        "mean" => Ok(crate::Criterion::Mean),
903        "median" => Ok(crate::Criterion::Median),
904        "second_order" => Ok(crate::Criterion::SecondOrder),
905        "auto" => Ok(crate::Criterion::Auto),
906        _ => Err(IrError::InvalidInferenceOption(format!(
907            "unsupported criterion '{}'",
908            value
909        ))),
910    }
911}
912
913fn tree_options(training: &TrainingMetadata) -> DecisionTreeOptions {
914    DecisionTreeOptions {
915        max_depth: training.max_depth.unwrap_or(8),
916        min_samples_split: training.min_samples_split.unwrap_or(2),
917        min_samples_leaf: training.min_samples_leaf.unwrap_or(1),
918        max_features: None,
919        random_seed: 0,
920    }
921}
922
923fn feature_preprocessing_from_ir(
924    ir: &ModelPackageIr,
925) -> Result<Vec<FeaturePreprocessing>, IrError> {
926    let mut features: Vec<Option<FeaturePreprocessing>> = vec![None; ir.input_schema.feature_count];
927
928    for feature in &ir.preprocessing.numeric_binning.features {
929        match feature {
930            FeatureBinning::Numeric {
931                feature_index,
932                boundaries,
933            } => {
934                let slot = features.get_mut(*feature_index).ok_or_else(|| {
935                    IrError::InvalidFeatureCount {
936                        schema: ir.input_schema.feature_count,
937                        preprocessing: feature_index + 1,
938                    }
939                })?;
940                *slot = Some(FeaturePreprocessing::Numeric {
941                    bin_boundaries: boundaries.clone(),
942                });
943            }
944            FeatureBinning::Binary { feature_index } => {
945                let slot = features.get_mut(*feature_index).ok_or_else(|| {
946                    IrError::InvalidFeatureCount {
947                        schema: ir.input_schema.feature_count,
948                        preprocessing: feature_index + 1,
949                    }
950                })?;
951                *slot = Some(FeaturePreprocessing::Binary);
952            }
953        }
954    }
955
956    if features.len() != ir.input_schema.feature_count {
957        return Err(IrError::InvalidFeatureCount {
958            schema: ir.input_schema.feature_count,
959            preprocessing: features.len(),
960        });
961    }
962
963    features
964        .into_iter()
965        .map(|feature| {
966            feature.ok_or_else(|| {
967                IrError::InvalidPreprocessing(
968                    "every feature must have a preprocessing entry".to_string(),
969                )
970            })
971        })
972        .collect()
973}
974
975fn classification_labels(ir: &ModelPackageIr) -> Result<Vec<f64>, IrError> {
976    ir.output_schema
977        .class_order
978        .clone()
979        .or_else(|| ir.training_metadata.class_labels.clone())
980        .ok_or(IrError::MissingClassLabels)
981}
982
983fn classifier_class_index(leaf: &LeafPayload, class_labels: &[f64]) -> Result<usize, IrError> {
984    match leaf {
985        LeafPayload::ClassIndex {
986            class_index,
987            class_value,
988        } => {
989            let Some(expected) = class_labels.get(*class_index) else {
990                return Err(IrError::InvalidLeaf(format!(
991                    "class index {} out of bounds",
992                    class_index
993                )));
994            };
995            if expected.total_cmp(class_value).is_ne() {
996                return Err(IrError::InvalidLeaf(format!(
997                    "class value {} does not match class order entry {}",
998                    class_value, expected
999                )));
1000            }
1001            Ok(*class_index)
1002        }
1003        LeafPayload::RegressionValue { .. } => Err(IrError::InvalidLeaf(
1004            "expected class_index leaf".to_string(),
1005        )),
1006    }
1007}
1008
1009fn rebuild_classifier_nodes(
1010    nodes: Vec<NodeTreeNode>,
1011    class_labels: &[f64],
1012) -> Result<Vec<ClassifierTreeNode>, IrError> {
1013    let mut rebuilt = vec![None; nodes.len()];
1014    for node in nodes {
1015        match node {
1016            NodeTreeNode::Leaf {
1017                node_id,
1018                leaf,
1019                stats,
1020                ..
1021            } => {
1022                let class_index = classifier_class_index(&leaf, class_labels)?;
1023                assign_node(
1024                    &mut rebuilt,
1025                    node_id,
1026                    ClassifierTreeNode::Leaf {
1027                        class_index,
1028                        sample_count: stats.sample_count,
1029                        class_counts: stats
1030                            .class_counts
1031                            .unwrap_or_else(|| vec![0; class_labels.len()]),
1032                    },
1033                )?;
1034            }
1035            NodeTreeNode::BinaryBranch {
1036                node_id,
1037                split,
1038                children,
1039                stats,
1040                ..
1041            } => {
1042                let (feature_index, threshold_bin) = classifier_binary_split(split)?;
1043                assign_node(
1044                    &mut rebuilt,
1045                    node_id,
1046                    ClassifierTreeNode::BinarySplit {
1047                        feature_index,
1048                        threshold_bin,
1049                        left_child: children.left,
1050                        right_child: children.right,
1051                        sample_count: stats.sample_count,
1052                        impurity: stats.impurity.unwrap_or(0.0),
1053                        gain: stats.gain.unwrap_or(0.0),
1054                        class_counts: stats
1055                            .class_counts
1056                            .unwrap_or_else(|| vec![0; class_labels.len()]),
1057                    },
1058                )?;
1059            }
1060            NodeTreeNode::MultiwayBranch {
1061                node_id,
1062                split,
1063                branches,
1064                unmatched_leaf,
1065                stats,
1066                ..
1067            } => {
1068                let fallback_class_index = classifier_class_index(&unmatched_leaf, class_labels)?;
1069                assign_node(
1070                    &mut rebuilt,
1071                    node_id,
1072                    ClassifierTreeNode::MultiwaySplit {
1073                        feature_index: split.feature_index,
1074                        fallback_class_index,
1075                        branches: branches
1076                            .into_iter()
1077                            .map(|branch| (branch.bin, branch.child))
1078                            .collect(),
1079                        sample_count: stats.sample_count,
1080                        impurity: stats.impurity.unwrap_or(0.0),
1081                        gain: stats.gain.unwrap_or(0.0),
1082                        class_counts: stats
1083                            .class_counts
1084                            .unwrap_or_else(|| vec![0; class_labels.len()]),
1085                    },
1086                )?;
1087            }
1088        }
1089    }
1090    collect_nodes(rebuilt)
1091}
1092
1093fn rebuild_regressor_nodes(nodes: Vec<NodeTreeNode>) -> Result<Vec<RegressionNode>, IrError> {
1094    let mut rebuilt = vec![None; nodes.len()];
1095    for node in nodes {
1096        match node {
1097            NodeTreeNode::Leaf {
1098                node_id,
1099                leaf: LeafPayload::RegressionValue { value },
1100                stats,
1101                ..
1102            } => {
1103                assign_node(
1104                    &mut rebuilt,
1105                    node_id,
1106                    RegressionNode::Leaf {
1107                        value,
1108                        sample_count: stats.sample_count,
1109                        variance: stats.variance,
1110                    },
1111                )?;
1112            }
1113            NodeTreeNode::Leaf { .. } => {
1114                return Err(IrError::InvalidLeaf(
1115                    "regression trees require regression_value leaves".to_string(),
1116                ));
1117            }
1118            NodeTreeNode::BinaryBranch {
1119                node_id,
1120                split,
1121                children,
1122                stats,
1123                ..
1124            } => {
1125                let (feature_index, threshold_bin) = regressor_binary_split(split)?;
1126                assign_node(
1127                    &mut rebuilt,
1128                    node_id,
1129                    RegressionNode::BinarySplit {
1130                        feature_index,
1131                        threshold_bin,
1132                        left_child: children.left,
1133                        right_child: children.right,
1134                        sample_count: stats.sample_count,
1135                        impurity: stats.impurity.unwrap_or(0.0),
1136                        gain: stats.gain.unwrap_or(0.0),
1137                        variance: stats.variance,
1138                    },
1139                )?;
1140            }
1141            NodeTreeNode::MultiwayBranch { .. } => {
1142                return Err(IrError::InvalidNode(
1143                    "regression trees do not support multiway branches".to_string(),
1144                ));
1145            }
1146        }
1147    }
1148    collect_nodes(rebuilt)
1149}
1150
1151fn rebuild_classifier_oblivious_splits(
1152    levels: Vec<ObliviousLevel>,
1153) -> Result<Vec<ClassifierObliviousSplit>, IrError> {
1154    let mut rebuilt = Vec::with_capacity(levels.len());
1155    for level in levels {
1156        rebuilt.push(match level.split {
1157            ObliviousSplit::NumericBinThreshold {
1158                feature_index,
1159                threshold_bin,
1160                ..
1161            } => ClassifierObliviousSplit {
1162                feature_index,
1163                threshold_bin,
1164                sample_count: level.stats.sample_count,
1165                impurity: level.stats.impurity.unwrap_or(0.0),
1166                gain: level.stats.gain.unwrap_or(0.0),
1167            },
1168            ObliviousSplit::BooleanTest { feature_index, .. } => ClassifierObliviousSplit {
1169                feature_index,
1170                threshold_bin: 0,
1171                sample_count: level.stats.sample_count,
1172                impurity: level.stats.impurity.unwrap_or(0.0),
1173                gain: level.stats.gain.unwrap_or(0.0),
1174            },
1175        });
1176    }
1177    Ok(rebuilt)
1178}
1179
1180fn rebuild_regressor_oblivious_splits(
1181    levels: Vec<ObliviousLevel>,
1182) -> Result<Vec<RegressorObliviousSplit>, IrError> {
1183    let mut rebuilt = Vec::with_capacity(levels.len());
1184    for level in levels {
1185        rebuilt.push(match level.split {
1186            ObliviousSplit::NumericBinThreshold {
1187                feature_index,
1188                threshold_bin,
1189                ..
1190            } => RegressorObliviousSplit {
1191                feature_index,
1192                threshold_bin,
1193                sample_count: level.stats.sample_count,
1194                impurity: level.stats.impurity.unwrap_or(0.0),
1195                gain: level.stats.gain.unwrap_or(0.0),
1196            },
1197            ObliviousSplit::BooleanTest { feature_index, .. } => RegressorObliviousSplit {
1198                feature_index,
1199                threshold_bin: 0,
1200                sample_count: level.stats.sample_count,
1201                impurity: level.stats.impurity.unwrap_or(0.0),
1202                gain: level.stats.gain.unwrap_or(0.0),
1203            },
1204        });
1205    }
1206    Ok(rebuilt)
1207}
1208
1209fn rebuild_classifier_leaf_indices(
1210    leaves: Vec<IndexedLeaf>,
1211    class_labels: &[f64],
1212) -> Result<Vec<usize>, IrError> {
1213    let mut rebuilt = vec![None; leaves.len()];
1214    for indexed_leaf in leaves {
1215        let class_index = classifier_class_index(&indexed_leaf.leaf, class_labels)?;
1216        assign_node(&mut rebuilt, indexed_leaf.leaf_index, class_index)?;
1217    }
1218    collect_nodes(rebuilt)
1219}
1220
1221fn rebuild_regressor_leaf_values(leaves: Vec<IndexedLeaf>) -> Result<Vec<f64>, IrError> {
1222    let mut rebuilt = vec![None; leaves.len()];
1223    for indexed_leaf in leaves {
1224        let value = match indexed_leaf.leaf {
1225            LeafPayload::RegressionValue { value } => value,
1226            LeafPayload::ClassIndex { .. } => {
1227                return Err(IrError::InvalidLeaf(
1228                    "regression oblivious leaves require regression_value".to_string(),
1229                ));
1230            }
1231        };
1232        assign_node(&mut rebuilt, indexed_leaf.leaf_index, value)?;
1233    }
1234    collect_nodes(rebuilt)
1235}
1236
1237fn rebuild_leaf_sample_counts(leaves: &[IndexedLeaf]) -> Result<Vec<usize>, IrError> {
1238    let mut rebuilt = vec![None; leaves.len()];
1239    for indexed_leaf in leaves {
1240        assign_node(
1241            &mut rebuilt,
1242            indexed_leaf.leaf_index,
1243            indexed_leaf.stats.sample_count,
1244        )?;
1245    }
1246    collect_nodes(rebuilt)
1247}
1248
1249fn rebuild_leaf_variances(leaves: &[IndexedLeaf]) -> Result<Vec<Option<f64>>, IrError> {
1250    let mut rebuilt = vec![None; leaves.len()];
1251    for indexed_leaf in leaves {
1252        assign_node(
1253            &mut rebuilt,
1254            indexed_leaf.leaf_index,
1255            indexed_leaf.stats.variance,
1256        )?;
1257    }
1258    collect_nodes(rebuilt)
1259}
1260
1261fn rebuild_classifier_leaf_class_counts(
1262    leaves: &[IndexedLeaf],
1263    num_classes: usize,
1264) -> Result<Vec<Vec<usize>>, IrError> {
1265    let mut rebuilt = vec![None; leaves.len()];
1266    for indexed_leaf in leaves {
1267        assign_node(
1268            &mut rebuilt,
1269            indexed_leaf.leaf_index,
1270            indexed_leaf
1271                .stats
1272                .class_counts
1273                .clone()
1274                .unwrap_or_else(|| vec![0; num_classes]),
1275        )?;
1276    }
1277    collect_nodes(rebuilt)
1278}
1279
1280fn classifier_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
1281    match split {
1282        BinarySplit::NumericBinThreshold {
1283            feature_index,
1284            threshold_bin,
1285            ..
1286        } => Ok((feature_index, threshold_bin)),
1287        BinarySplit::BooleanTest { feature_index, .. } => Ok((feature_index, 0)),
1288    }
1289}
1290
1291fn regressor_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
1292    classifier_binary_split(split)
1293}
1294
1295fn assign_node<T>(slots: &mut [Option<T>], index: usize, value: T) -> Result<(), IrError> {
1296    let Some(slot) = slots.get_mut(index) else {
1297        return Err(IrError::InvalidNode(format!(
1298            "node index {} is out of bounds",
1299            index
1300        )));
1301    };
1302    if slot.is_some() {
1303        return Err(IrError::InvalidNode(format!(
1304            "duplicate node index {}",
1305            index
1306        )));
1307    }
1308    *slot = Some(value);
1309    Ok(())
1310}
1311
1312fn collect_nodes<T>(slots: Vec<Option<T>>) -> Result<Vec<T>, IrError> {
1313    slots
1314        .into_iter()
1315        .enumerate()
1316        .map(|(index, slot)| {
1317            slot.ok_or_else(|| IrError::InvalidNode(format!("missing node index {}", index)))
1318        })
1319        .collect()
1320}
1321
1322fn input_schema(model: &Model) -> InputSchema {
1323    let features = model
1324        .feature_preprocessing()
1325        .iter()
1326        .enumerate()
1327        .map(|(feature_index, preprocessing)| {
1328            let kind = match preprocessing {
1329                FeaturePreprocessing::Numeric { .. } => InputFeatureKind::Numeric,
1330                FeaturePreprocessing::Binary => InputFeatureKind::Binary,
1331            };
1332
1333            InputFeature {
1334                index: feature_index,
1335                name: feature_name(feature_index),
1336                dtype: match kind {
1337                    InputFeatureKind::Numeric => "float64".to_string(),
1338                    InputFeatureKind::Binary => "bool".to_string(),
1339                },
1340                logical_type: match kind {
1341                    InputFeatureKind::Numeric => "numeric".to_string(),
1342                    InputFeatureKind::Binary => "boolean".to_string(),
1343                },
1344                nullable: false,
1345            }
1346        })
1347        .collect();
1348
1349    InputSchema {
1350        feature_count: model.num_features(),
1351        features,
1352        ordering: "strict_index_order".to_string(),
1353        input_tensor_layout: "row_major".to_string(),
1354        accepts_feature_names: false,
1355    }
1356}
1357
1358fn output_schema(model: &Model, class_labels: Option<Vec<f64>>) -> OutputSchema {
1359    match model.task() {
1360        Task::Regression => OutputSchema {
1361            raw_outputs: vec![OutputField {
1362                name: "value".to_string(),
1363                kind: "regression_value".to_string(),
1364                shape: Vec::new(),
1365                dtype: "float64".to_string(),
1366            }],
1367            final_outputs: vec![OutputField {
1368                name: "prediction".to_string(),
1369                kind: "value".to_string(),
1370                shape: Vec::new(),
1371                dtype: "float64".to_string(),
1372            }],
1373            class_order: None,
1374        },
1375        Task::Classification => OutputSchema {
1376            raw_outputs: vec![OutputField {
1377                name: "class_index".to_string(),
1378                kind: "class_index".to_string(),
1379                shape: Vec::new(),
1380                dtype: "uint64".to_string(),
1381            }],
1382            final_outputs: vec![OutputField {
1383                name: "predicted_class".to_string(),
1384                kind: "class_label".to_string(),
1385                shape: Vec::new(),
1386                dtype: "float64".to_string(),
1387            }],
1388            class_order: class_labels,
1389        },
1390    }
1391}
1392
1393fn preprocessing(model: &Model) -> PreprocessingSection {
1394    let features = model
1395        .feature_preprocessing()
1396        .iter()
1397        .enumerate()
1398        .map(|(feature_index, preprocessing)| match preprocessing {
1399            FeaturePreprocessing::Numeric { bin_boundaries } => FeatureBinning::Numeric {
1400                feature_index,
1401                boundaries: bin_boundaries.clone(),
1402            },
1403            FeaturePreprocessing::Binary => FeatureBinning::Binary { feature_index },
1404        })
1405        .collect();
1406
1407    PreprocessingSection {
1408        included_in_model: true,
1409        numeric_binning: NumericBinning {
1410            kind: "rank_bin_128".to_string(),
1411            features,
1412        },
1413        notes: "Numeric features use serialized training-time rank bins. Binary features are serialized as booleans. Missing values and categorical encodings are not implemented in IR v1."
1414            .to_string(),
1415    }
1416}
1417
1418fn postprocessing(model: &Model, class_labels: Option<Vec<f64>>) -> PostprocessingSection {
1419    match model.task() {
1420        Task::Regression => PostprocessingSection {
1421            raw_output_kind: "regression_value".to_string(),
1422            steps: vec![PostprocessingStep::Identity],
1423        },
1424        Task::Classification => PostprocessingSection {
1425            raw_output_kind: "class_index".to_string(),
1426            steps: vec![PostprocessingStep::MapClassIndexToLabel {
1427                labels: class_labels.expect("classification IR requires class labels"),
1428            }],
1429        },
1430    }
1431}
1432
1433fn required_capabilities(model: &Model, representation: &str) -> Vec<String> {
1434    let mut capabilities = vec![
1435        representation.to_string(),
1436        "training_rank_bin_128".to_string(),
1437    ];
1438    match model.tree_type() {
1439        TreeType::Id3 | TreeType::C45 => {
1440            capabilities.push("binned_multiway_splits".to_string());
1441        }
1442        TreeType::Cart | TreeType::Randomized | TreeType::Oblivious => {
1443            capabilities.push("numeric_bin_threshold_splits".to_string());
1444        }
1445    }
1446    if model
1447        .feature_preprocessing()
1448        .iter()
1449        .any(|feature| matches!(feature, FeaturePreprocessing::Binary))
1450    {
1451        capabilities.push("boolean_features".to_string());
1452    }
1453    match model.task() {
1454        Task::Regression => capabilities.push("regression_value_leaves".to_string()),
1455        Task::Classification => capabilities.push("class_index_leaves".to_string()),
1456    }
1457    capabilities
1458}
1459
1460pub(crate) fn algorithm_name(algorithm: TrainAlgorithm) -> &'static str {
1461    match algorithm {
1462        TrainAlgorithm::Dt => "dt",
1463        TrainAlgorithm::Rf => "rf",
1464        TrainAlgorithm::Gbm => "gbm",
1465    }
1466}
1467
1468fn model_tree_definition(model: &Model) -> TreeDefinition {
1469    match model {
1470        Model::DecisionTreeClassifier(classifier) => classifier.to_ir_tree(),
1471        Model::DecisionTreeRegressor(regressor) => regressor.to_ir_tree(),
1472        Model::RandomForest(_) | Model::GradientBoostedTrees(_) => {
1473            unreachable!("ensemble IR expands into member trees")
1474        }
1475    }
1476}
1477
1478pub(crate) fn criterion_name(criterion: crate::Criterion) -> &'static str {
1479    match criterion {
1480        crate::Criterion::Auto => "auto",
1481        crate::Criterion::Gini => "gini",
1482        crate::Criterion::Entropy => "entropy",
1483        crate::Criterion::Mean => "mean",
1484        crate::Criterion::Median => "median",
1485        crate::Criterion::SecondOrder => "second_order",
1486    }
1487}
1488
1489pub(crate) fn task_name(task: Task) -> &'static str {
1490    match task {
1491        Task::Regression => "regression",
1492        Task::Classification => "classification",
1493    }
1494}
1495
1496pub(crate) fn tree_type_name(tree_type: TreeType) -> &'static str {
1497    match tree_type {
1498        TreeType::Id3 => "id3",
1499        TreeType::C45 => "c45",
1500        TreeType::Cart => "cart",
1501        TreeType::Randomized => "randomized",
1502        TreeType::Oblivious => "oblivious",
1503    }
1504}
1505
/// Returns the positional name of a feature: the letter `f` followed by its
/// zero-based index in decimal (for example `f0`, `f17`).
pub(crate) fn feature_name(feature_index: usize) -> String {
    format!("f{feature_index}")
}
1509
1510pub(crate) fn threshold_upper_bound(
1511    preprocessing: &[FeaturePreprocessing],
1512    feature_index: usize,
1513    threshold_bin: u16,
1514) -> Option<f64> {
1515    match preprocessing.get(feature_index)? {
1516        FeaturePreprocessing::Numeric { bin_boundaries } => bin_boundaries
1517            .iter()
1518            .find(|boundary| boundary.bin == threshold_bin)
1519            .map(|boundary| boundary.upper_bound),
1520        FeaturePreprocessing::Binary => None,
1521    }
1522}