Skip to main content

forestfire_core/
ir.rs

1//! Stable intermediate representation for ForestFire models.
2//!
3//! The IR sits between two worlds:
4//!
5//! - the semantic in-memory model types used for training and introspection
6//! - the lowered runtime structures used by optimized inference
7//!
8//! It exists so models can be serialized, schema-checked, inspected from other
9//! languages, and reconstructed without depending on the exact Rust memory
10//! layout of the training structs.
11
12use crate::tree::classifier::{
13    DecisionTreeAlgorithm, DecisionTreeClassifier, DecisionTreeOptions,
14    ObliviousSplit as ClassifierObliviousSplit, TreeNode as ClassifierTreeNode,
15    TreeStructure as ClassifierTreeStructure,
16};
17use crate::tree::regressor::{
18    DecisionTreeRegressor, ObliviousSplit as RegressorObliviousSplit, RegressionNode,
19    RegressionTreeAlgorithm, RegressionTreeOptions, RegressionTreeStructure,
20};
21use crate::{
22    Criterion, FeaturePreprocessing, GradientBoostedTrees, InputFeatureKind, Model,
23    NumericBinBoundary, RandomForest, Task, TrainAlgorithm, TreeType,
24};
25use schemars::schema::RootSchema;
26use schemars::{JsonSchema, schema_for};
27use serde::{Deserialize, Serialize};
28use std::fmt::{Display, Formatter};
29
/// IR specification version stamped into every package (`ir_version`) and used
/// as its `minimum_runtime_version`.
const IR_VERSION: &str = "1.0.0";
/// Format identifier written to the `format_name` field of every package.
const FORMAT_NAME: &str = "forestfire-ir";
32
33/// Top-level model package serialized by the library.
34#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
35pub struct ModelPackageIr {
36    pub ir_version: String,
37    pub format_name: String,
38    pub producer: ProducerMetadata,
39    pub model: ModelSection,
40    pub input_schema: InputSchema,
41    pub output_schema: OutputSchema,
42    pub inference_options: InferenceOptions,
43    pub preprocessing: PreprocessingSection,
44    pub postprocessing: PostprocessingSection,
45    pub training_metadata: TrainingMetadata,
46    pub integrity: IntegritySection,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
50pub struct ProducerMetadata {
51    pub library: String,
52    pub library_version: String,
53    pub language: String,
54    pub platform: String,
55}
56
57/// Structural model description independent of any concrete runtime layout.
58#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
59pub struct ModelSection {
60    pub algorithm: String,
61    pub task: String,
62    pub tree_type: String,
63    pub representation: String,
64    pub num_features: usize,
65    pub num_outputs: usize,
66    pub supports_missing: bool,
67    pub supports_categorical: bool,
68    pub is_ensemble: bool,
69    pub trees: Vec<TreeDefinition>,
70    pub aggregation: Aggregation,
71}
72
73/// Concrete tree payload stored in the IR.
74#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
75#[serde(tag = "representation", rename_all = "snake_case")]
76pub enum TreeDefinition {
77    NodeTree {
78        tree_id: usize,
79        weight: f64,
80        root_node_id: usize,
81        nodes: Vec<NodeTreeNode>,
82    },
83    ObliviousLevels {
84        tree_id: usize,
85        weight: f64,
86        depth: usize,
87        levels: Vec<ObliviousLevel>,
88        leaf_indexing: LeafIndexing,
89        leaves: Vec<IndexedLeaf>,
90    },
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
94#[serde(tag = "kind", rename_all = "snake_case")]
95pub enum NodeTreeNode {
96    Leaf {
97        node_id: usize,
98        depth: usize,
99        leaf: LeafPayload,
100        stats: NodeStats,
101    },
102    BinaryBranch {
103        node_id: usize,
104        depth: usize,
105        split: BinarySplit,
106        children: BinaryChildren,
107        stats: NodeStats,
108    },
109    MultiwayBranch {
110        node_id: usize,
111        depth: usize,
112        split: MultiwaySplit,
113        branches: Vec<MultiwayBranch>,
114        unmatched_leaf: LeafPayload,
115        stats: NodeStats,
116    },
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
120pub struct BinaryChildren {
121    pub left: usize,
122    pub right: usize,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
126pub struct MultiwayBranch {
127    pub bin: u16,
128    pub child: usize,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
132#[serde(tag = "split_type", rename_all = "snake_case")]
133pub enum BinarySplit {
134    NumericBinThreshold {
135        feature_index: usize,
136        feature_name: String,
137        operator: String,
138        threshold_bin: u16,
139        threshold_upper_bound: Option<f64>,
140        comparison_dtype: String,
141    },
142    BooleanTest {
143        feature_index: usize,
144        feature_name: String,
145        false_child_semantics: String,
146        true_child_semantics: String,
147    },
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
151pub struct MultiwaySplit {
152    pub split_type: String,
153    pub feature_index: usize,
154    pub feature_name: String,
155    pub comparison_dtype: String,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
159pub struct ObliviousLevel {
160    pub level: usize,
161    pub split: ObliviousSplit,
162    pub stats: NodeStats,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
166#[serde(tag = "split_type", rename_all = "snake_case")]
167pub enum ObliviousSplit {
168    NumericBinThreshold {
169        feature_index: usize,
170        feature_name: String,
171        operator: String,
172        threshold_bin: u16,
173        threshold_upper_bound: Option<f64>,
174        comparison_dtype: String,
175        bit_when_true: u8,
176        bit_when_false: u8,
177    },
178    BooleanTest {
179        feature_index: usize,
180        feature_name: String,
181        bit_when_false: u8,
182        bit_when_true: u8,
183    },
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
187pub struct LeafIndexing {
188    pub bit_order: String,
189    pub index_formula: String,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
193pub struct IndexedLeaf {
194    pub leaf_index: usize,
195    pub leaf: LeafPayload,
196    pub stats: NodeStats,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
200#[serde(tag = "prediction_kind", rename_all = "snake_case")]
201pub enum LeafPayload {
202    RegressionValue {
203        value: f64,
204    },
205    ClassIndex {
206        class_index: usize,
207        class_value: f64,
208    },
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
212pub struct Aggregation {
213    pub kind: String,
214    pub tree_weights: Vec<f64>,
215    pub normalize_by_weight_sum: bool,
216    #[serde(skip_serializing_if = "Option::is_none")]
217    pub base_score: Option<f64>,
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
221pub struct InputSchema {
222    pub feature_count: usize,
223    pub features: Vec<InputFeature>,
224    pub ordering: String,
225    pub input_tensor_layout: String,
226    pub accepts_feature_names: bool,
227}
228
229#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
230pub struct InputFeature {
231    pub index: usize,
232    pub name: String,
233    pub dtype: String,
234    pub logical_type: String,
235    pub nullable: bool,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
239pub struct OutputSchema {
240    pub raw_outputs: Vec<OutputField>,
241    pub final_outputs: Vec<OutputField>,
242    #[serde(skip_serializing_if = "Option::is_none")]
243    pub class_order: Option<Vec<f64>>,
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
247pub struct OutputField {
248    pub name: String,
249    pub kind: String,
250    pub shape: Vec<usize>,
251    pub dtype: String,
252}
253
254#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
255pub struct InferenceOptions {
256    pub numeric_precision: String,
257    pub threshold_comparison: String,
258    pub nan_policy: String,
259    pub bool_encoding: BoolEncoding,
260    pub tie_breaking: TieBreaking,
261    pub determinism: Determinism,
262}
263
264#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
265pub struct BoolEncoding {
266    pub false_values: Vec<String>,
267    pub true_values: Vec<String>,
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
271pub struct TieBreaking {
272    pub classification: String,
273    pub argmax: String,
274}
275
276#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
277pub struct Determinism {
278    pub guaranteed: bool,
279    pub notes: String,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
283pub struct PreprocessingSection {
284    pub included_in_model: bool,
285    pub numeric_binning: NumericBinning,
286    pub notes: String,
287}
288
289#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
290pub struct NumericBinning {
291    pub kind: String,
292    pub features: Vec<FeatureBinning>,
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
296#[serde(tag = "kind", rename_all = "snake_case")]
297pub enum FeatureBinning {
298    Numeric {
299        feature_index: usize,
300        boundaries: Vec<NumericBinBoundary>,
301    },
302    Binary {
303        feature_index: usize,
304    },
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
308pub struct PostprocessingSection {
309    pub raw_output_kind: String,
310    pub steps: Vec<PostprocessingStep>,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
314#[serde(tag = "op", rename_all = "snake_case")]
315pub enum PostprocessingStep {
316    Identity,
317    MapClassIndexToLabel { labels: Vec<f64> },
318}
319
320/// Serialized training metadata reflected back to bindings and docs.
321#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
322pub struct TrainingMetadata {
323    pub algorithm: String,
324    pub task: String,
325    pub tree_type: String,
326    pub criterion: String,
327    pub canaries: usize,
328    pub compute_oob: bool,
329    #[serde(skip_serializing_if = "Option::is_none")]
330    pub max_depth: Option<usize>,
331    #[serde(skip_serializing_if = "Option::is_none")]
332    pub min_samples_split: Option<usize>,
333    #[serde(skip_serializing_if = "Option::is_none")]
334    pub min_samples_leaf: Option<usize>,
335    #[serde(skip_serializing_if = "Option::is_none")]
336    pub n_trees: Option<usize>,
337    #[serde(skip_serializing_if = "Option::is_none")]
338    pub max_features: Option<usize>,
339    #[serde(skip_serializing_if = "Option::is_none")]
340    pub seed: Option<u64>,
341    #[serde(skip_serializing_if = "Option::is_none")]
342    pub oob_score: Option<f64>,
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub class_labels: Option<Vec<f64>>,
345    #[serde(skip_serializing_if = "Option::is_none")]
346    pub learning_rate: Option<f64>,
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub bootstrap: Option<bool>,
349    #[serde(skip_serializing_if = "Option::is_none")]
350    pub top_gradient_fraction: Option<f64>,
351    #[serde(skip_serializing_if = "Option::is_none")]
352    pub other_gradient_fraction: Option<f64>,
353}
354
355#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
356pub struct IntegritySection {
357    pub serialization: String,
358    pub canonical_json: bool,
359    pub compatibility: Compatibility,
360}
361
362#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
363pub struct Compatibility {
364    pub minimum_runtime_version: String,
365    pub required_capabilities: Vec<String>,
366}
367
368#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
369pub struct NodeStats {
370    pub sample_count: usize,
371    #[serde(skip_serializing_if = "Option::is_none")]
372    pub impurity: Option<f64>,
373    #[serde(skip_serializing_if = "Option::is_none")]
374    pub gain: Option<f64>,
375    #[serde(skip_serializing_if = "Option::is_none")]
376    pub class_counts: Option<Vec<usize>>,
377    #[serde(skip_serializing_if = "Option::is_none")]
378    pub variance: Option<f64>,
379}
380
/// Errors produced while validating an IR package or rebuilding a model from it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IrError {
    /// The package declares an IR version this build does not support.
    UnsupportedIrVersion(String),
    /// The package's format name is not supported.
    UnsupportedFormatName(String),
    /// The named training algorithm is not supported.
    UnsupportedAlgorithm(String),
    /// The named task is not supported.
    UnsupportedTask(String),
    /// The named tree type is not supported.
    UnsupportedTreeType(String),
    /// A single-tree model did not contain exactly one tree; holds the count found.
    InvalidTreeCount(usize),
    /// A tree representation cannot be combined with the declared task/tree type.
    UnsupportedRepresentation(String),
    /// The input schema and the preprocessing section disagree on the feature count.
    InvalidFeatureCount { schema: usize, preprocessing: usize },
    /// A classification package lacks the class labels needed to rebuild it.
    MissingClassLabels,
    /// A leaf payload is malformed; the message explains why.
    InvalidLeaf(String),
    /// A tree node is malformed; the message explains why.
    InvalidNode(String),
    /// The preprocessing section is malformed; the message explains why.
    InvalidPreprocessing(String),
    /// An inference option has an unsupported value.
    InvalidInferenceOption(String),
    /// A JSON conversion failed; holds the underlying error text.
    Json(String),
}

impl Display for IrError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedIrVersion(version) => {
                write!(f, "Unsupported IR version: {version}.")
            }
            Self::UnsupportedFormatName(name) => write!(f, "Unsupported IR format: {name}."),
            Self::UnsupportedAlgorithm(algorithm) => {
                write!(f, "Unsupported algorithm: {algorithm}.")
            }
            Self::UnsupportedTask(task) => write!(f, "Unsupported task: {task}."),
            Self::UnsupportedTreeType(tree_type) => {
                write!(f, "Unsupported tree type: {tree_type}.")
            }
            Self::InvalidTreeCount(count) => {
                write!(f, "Expected exactly one tree in the IR, found {count}.")
            }
            Self::UnsupportedRepresentation(representation) => {
                write!(f, "Unsupported tree representation: {representation}.")
            }
            Self::InvalidFeatureCount {
                schema,
                preprocessing,
            } => write!(
                f,
                "Input schema declares {schema} features, but preprocessing declares {preprocessing}."
            ),
            Self::MissingClassLabels => {
                f.write_str("Classification IR requires explicit class labels.")
            }
            Self::InvalidLeaf(message) => write!(f, "Invalid leaf payload: {message}."),
            Self::InvalidNode(message) => write!(f, "Invalid tree node: {message}."),
            Self::InvalidPreprocessing(message) => {
                write!(f, "Invalid preprocessing section: {message}.")
            }
            Self::InvalidInferenceOption(message) => {
                write!(f, "Invalid inference options: {message}.")
            }
            Self::Json(message) => write!(f, "Invalid JSON: {message}."),
        }
    }
}

/// `IrError` wraps no inner error; its `Display` text is the complete message.
impl std::error::Error for IrError {}
446
447impl ModelPackageIr {
448    pub fn json_schema() -> RootSchema {
449        schema_for!(ModelPackageIr)
450    }
451
452    pub fn json_schema_json() -> Result<String, IrError> {
453        serde_json::to_string(&Self::json_schema()).map_err(|err| IrError::Json(err.to_string()))
454    }
455
456    pub fn json_schema_json_pretty() -> Result<String, IrError> {
457        serde_json::to_string_pretty(&Self::json_schema())
458            .map_err(|err| IrError::Json(err.to_string()))
459    }
460}
461
462pub(crate) fn model_to_ir(model: &Model) -> ModelPackageIr {
463    let trees = match model {
464        Model::RandomForest(forest) => forest
465            .trees()
466            .iter()
467            .map(model_tree_definition)
468            .collect::<Vec<_>>(),
469        Model::GradientBoostedTrees(boosted) => boosted
470            .trees()
471            .iter()
472            .map(model_tree_definition)
473            .collect::<Vec<_>>(),
474        _ => vec![model_tree_definition(model)],
475    };
476    let representation = if let Some(first_tree) = trees.first() {
477        match first_tree {
478            TreeDefinition::NodeTree { .. } => "node_tree",
479            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels",
480        }
481    } else {
482        match model.tree_type() {
483            TreeType::Oblivious => "oblivious_levels",
484            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized => "node_tree",
485        }
486    };
487    let class_labels = model.class_labels();
488    let is_ensemble = matches!(
489        model,
490        Model::RandomForest(_) | Model::GradientBoostedTrees(_)
491    );
492    let tree_count = trees.len();
493    let (aggregation_kind, tree_weights, normalize_by_weight_sum, base_score) = match model {
494        Model::RandomForest(_) => (
495            match model.task() {
496                Task::Regression => "average",
497                Task::Classification => "average_class_probabilities",
498            },
499            vec![1.0; tree_count],
500            true,
501            None,
502        ),
503        Model::GradientBoostedTrees(boosted) => (
504            match boosted.task() {
505                Task::Regression => "sum_tree_outputs",
506                Task::Classification => "sum_tree_outputs_then_sigmoid",
507            },
508            boosted.tree_weights().to_vec(),
509            false,
510            Some(boosted.base_score()),
511        ),
512        _ => ("identity_single_tree", vec![1.0; tree_count], true, None),
513    };
514
515    ModelPackageIr {
516        ir_version: IR_VERSION.to_string(),
517        format_name: FORMAT_NAME.to_string(),
518        producer: ProducerMetadata {
519            library: "forestfire-core".to_string(),
520            library_version: env!("CARGO_PKG_VERSION").to_string(),
521            language: "rust".to_string(),
522            platform: std::env::consts::ARCH.to_string(),
523        },
524        model: ModelSection {
525            algorithm: algorithm_name(model.algorithm()).to_string(),
526            task: task_name(model.task()).to_string(),
527            tree_type: tree_type_name(model.tree_type()).to_string(),
528            representation: representation.to_string(),
529            num_features: model.num_features(),
530            num_outputs: 1,
531            supports_missing: false,
532            supports_categorical: false,
533            is_ensemble,
534            trees,
535            aggregation: Aggregation {
536                kind: aggregation_kind.to_string(),
537                tree_weights,
538                normalize_by_weight_sum,
539                base_score,
540            },
541        },
542        input_schema: input_schema(model),
543        output_schema: output_schema(model, class_labels.clone()),
544        inference_options: InferenceOptions {
545            numeric_precision: "float64".to_string(),
546            threshold_comparison: "leq_left_gt_right".to_string(),
547            nan_policy: "not_supported".to_string(),
548            bool_encoding: BoolEncoding {
549                false_values: vec!["0".to_string(), "false".to_string()],
550                true_values: vec!["1".to_string(), "true".to_string()],
551            },
552            tie_breaking: TieBreaking {
553                classification: "lowest_class_index".to_string(),
554                argmax: "first_max_index".to_string(),
555            },
556            determinism: Determinism {
557                guaranteed: true,
558                notes: "Inference is deterministic when the serialized preprocessing artifacts are applied before split evaluation."
559                    .to_string(),
560            },
561        },
562        preprocessing: preprocessing(model),
563        postprocessing: postprocessing(model, class_labels),
564        training_metadata: model.training_metadata(),
565        integrity: IntegritySection {
566            serialization: "json".to_string(),
567            canonical_json: true,
568            compatibility: Compatibility {
569                minimum_runtime_version: IR_VERSION.to_string(),
570                required_capabilities: required_capabilities(model, representation),
571            },
572        },
573    }
574}
575
576pub(crate) fn model_from_ir(ir: ModelPackageIr) -> Result<Model, IrError> {
577    validate_ir_header(&ir)?;
578    validate_inference_options(&ir.inference_options)?;
579
580    let algorithm = parse_algorithm(&ir.model.algorithm)?;
581    let task = parse_task(&ir.model.task)?;
582    let tree_type = parse_tree_type(&ir.model.tree_type)?;
583    let criterion = parse_criterion(&ir.training_metadata.criterion)?;
584    let feature_preprocessing = feature_preprocessing_from_ir(&ir)?;
585    let num_features = ir.input_schema.feature_count;
586    let options = tree_options(&ir.training_metadata);
587    let training_canaries = ir.training_metadata.canaries;
588    let deserialized_class_labels = classification_labels(&ir).ok();
589
590    if algorithm == TrainAlgorithm::Dt && ir.model.trees.len() != 1 {
591        return Err(IrError::InvalidTreeCount(ir.model.trees.len()));
592    }
593
594    if algorithm == TrainAlgorithm::Rf {
595        let trees = ir
596            .model
597            .trees
598            .into_iter()
599            .map(|tree| {
600                single_model_from_ir_parts(
601                    task,
602                    tree_type,
603                    criterion,
604                    feature_preprocessing.clone(),
605                    num_features,
606                    options.clone(),
607                    training_canaries,
608                    deserialized_class_labels.clone(),
609                    tree,
610                )
611            })
612            .collect::<Result<Vec<_>, _>>()?;
613        return Ok(Model::RandomForest(RandomForest::new(
614            task,
615            criterion,
616            tree_type,
617            trees,
618            ir.training_metadata.compute_oob,
619            ir.training_metadata.oob_score,
620            ir.training_metadata
621                .max_features
622                .unwrap_or(num_features.max(1)),
623            ir.training_metadata.seed,
624            num_features,
625            feature_preprocessing,
626        )));
627    }
628
629    if algorithm == TrainAlgorithm::Gbm {
630        let tree_weights = ir.model.aggregation.tree_weights.clone();
631        let base_score = ir.model.aggregation.base_score.unwrap_or(0.0);
632        let trees = ir
633            .model
634            .trees
635            .into_iter()
636            .map(|tree| {
637                boosted_tree_model_from_ir_parts(
638                    tree_type,
639                    criterion,
640                    feature_preprocessing.clone(),
641                    num_features,
642                    options.clone(),
643                    training_canaries,
644                    tree,
645                )
646            })
647            .collect::<Result<Vec<_>, _>>()?;
648        return Ok(Model::GradientBoostedTrees(GradientBoostedTrees::new(
649            task,
650            tree_type,
651            trees,
652            tree_weights,
653            base_score,
654            ir.training_metadata.learning_rate.unwrap_or(0.1),
655            ir.training_metadata.bootstrap.unwrap_or(false),
656            ir.training_metadata.top_gradient_fraction.unwrap_or(0.2),
657            ir.training_metadata.other_gradient_fraction.unwrap_or(0.1),
658            ir.training_metadata
659                .max_features
660                .unwrap_or(num_features.max(1)),
661            ir.training_metadata.seed,
662            num_features,
663            feature_preprocessing,
664            deserialized_class_labels,
665            training_canaries,
666        )));
667    }
668
669    let tree = ir
670        .model
671        .trees
672        .into_iter()
673        .next()
674        .expect("validated single tree");
675
676    single_model_from_ir_parts(
677        task,
678        tree_type,
679        criterion,
680        feature_preprocessing,
681        num_features,
682        options,
683        training_canaries,
684        deserialized_class_labels,
685        tree,
686    )
687}
688
689fn boosted_tree_model_from_ir_parts(
690    tree_type: TreeType,
691    criterion: Criterion,
692    feature_preprocessing: Vec<FeaturePreprocessing>,
693    num_features: usize,
694    options: DecisionTreeOptions,
695    training_canaries: usize,
696    tree: TreeDefinition,
697) -> Result<Model, IrError> {
698    match (tree_type, tree) {
699        (
700            TreeType::Cart | TreeType::Randomized,
701            TreeDefinition::NodeTree {
702                nodes,
703                root_node_id,
704                ..
705            },
706        ) => Ok(Model::DecisionTreeRegressor(
707            DecisionTreeRegressor::from_ir_parts(
708                match tree_type {
709                    TreeType::Cart => RegressionTreeAlgorithm::Cart,
710                    TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
711                    _ => unreachable!(),
712                },
713                criterion,
714                RegressionTreeStructure::Standard {
715                    nodes: rebuild_regressor_nodes(nodes)?,
716                    root: root_node_id,
717                },
718                RegressionTreeOptions {
719                    max_depth: options.max_depth,
720                    min_samples_split: options.min_samples_split,
721                    min_samples_leaf: options.min_samples_leaf,
722                    max_features: None,
723                    random_seed: 0,
724                    missing_value_strategies: Vec::new(),
725                },
726                num_features,
727                feature_preprocessing,
728                training_canaries,
729            ),
730        )),
731        (TreeType::Oblivious, TreeDefinition::ObliviousLevels { levels, leaves, .. }) => {
732            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
733            let leaf_variances = rebuild_leaf_variances(&leaves)?;
734            Ok(Model::DecisionTreeRegressor(
735                DecisionTreeRegressor::from_ir_parts(
736                    RegressionTreeAlgorithm::Oblivious,
737                    criterion,
738                    RegressionTreeStructure::Oblivious {
739                        splits: rebuild_regressor_oblivious_splits(levels)?,
740                        leaf_values: rebuild_regressor_leaf_values(leaves)?,
741                        leaf_sample_counts,
742                        leaf_variances,
743                    },
744                    RegressionTreeOptions {
745                        max_depth: options.max_depth,
746                        min_samples_split: options.min_samples_split,
747                        min_samples_leaf: options.min_samples_leaf,
748                        max_features: None,
749                        random_seed: 0,
750                        missing_value_strategies: Vec::new(),
751                    },
752                    num_features,
753                    feature_preprocessing,
754                    training_canaries,
755                ),
756            ))
757        }
758        (_, tree) => Err(IrError::UnsupportedRepresentation(match tree {
759            TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
760            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
761        })),
762    }
763}
764
765#[allow(clippy::too_many_arguments)]
766fn single_model_from_ir_parts(
767    task: Task,
768    tree_type: TreeType,
769    criterion: Criterion,
770    feature_preprocessing: Vec<FeaturePreprocessing>,
771    num_features: usize,
772    options: DecisionTreeOptions,
773    training_canaries: usize,
774    deserialized_class_labels: Option<Vec<f64>>,
775    tree: TreeDefinition,
776) -> Result<Model, IrError> {
777    match (task, tree_type, tree) {
778        (
779            Task::Classification,
780            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized,
781            TreeDefinition::NodeTree {
782                nodes,
783                root_node_id,
784                ..
785            },
786        ) => {
787            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
788            let structure = ClassifierTreeStructure::Standard {
789                nodes: rebuild_classifier_nodes(nodes, &class_labels)?,
790                root: root_node_id,
791            };
792            Ok(Model::DecisionTreeClassifier(
793                DecisionTreeClassifier::from_ir_parts(
794                    match tree_type {
795                        TreeType::Id3 => DecisionTreeAlgorithm::Id3,
796                        TreeType::C45 => DecisionTreeAlgorithm::C45,
797                        TreeType::Cart => DecisionTreeAlgorithm::Cart,
798                        TreeType::Randomized => DecisionTreeAlgorithm::Randomized,
799                        TreeType::Oblivious => unreachable!(),
800                    },
801                    criterion,
802                    class_labels,
803                    structure,
804                    options,
805                    num_features,
806                    feature_preprocessing,
807                    training_canaries,
808                ),
809            ))
810        }
811        (
812            Task::Classification,
813            TreeType::Oblivious,
814            TreeDefinition::ObliviousLevels { levels, leaves, .. },
815        ) => {
816            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
817            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
818            let leaf_class_counts =
819                rebuild_classifier_leaf_class_counts(&leaves, class_labels.len())?;
820            let structure = ClassifierTreeStructure::Oblivious {
821                splits: rebuild_classifier_oblivious_splits(levels)?,
822                leaf_class_indices: rebuild_classifier_leaf_indices(leaves, &class_labels)?,
823                leaf_sample_counts,
824                leaf_class_counts,
825            };
826            Ok(Model::DecisionTreeClassifier(
827                DecisionTreeClassifier::from_ir_parts(
828                    DecisionTreeAlgorithm::Oblivious,
829                    criterion,
830                    class_labels,
831                    structure,
832                    options,
833                    num_features,
834                    feature_preprocessing,
835                    training_canaries,
836                ),
837            ))
838        }
839        (
840            Task::Regression,
841            TreeType::Cart | TreeType::Randomized,
842            TreeDefinition::NodeTree {
843                nodes,
844                root_node_id,
845                ..
846            },
847        ) => Ok(Model::DecisionTreeRegressor(
848            DecisionTreeRegressor::from_ir_parts(
849                match tree_type {
850                    TreeType::Cart => RegressionTreeAlgorithm::Cart,
851                    TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
852                    _ => unreachable!(),
853                },
854                criterion,
855                RegressionTreeStructure::Standard {
856                    nodes: rebuild_regressor_nodes(nodes)?,
857                    root: root_node_id,
858                },
859                RegressionTreeOptions {
860                    max_depth: options.max_depth,
861                    min_samples_split: options.min_samples_split,
862                    min_samples_leaf: options.min_samples_leaf,
863                    max_features: None,
864                    random_seed: 0,
865                    missing_value_strategies: Vec::new(),
866                },
867                num_features,
868                feature_preprocessing,
869                training_canaries,
870            ),
871        )),
872        (
873            Task::Regression,
874            TreeType::Oblivious,
875            TreeDefinition::ObliviousLevels { levels, leaves, .. },
876        ) => {
877            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
878            let leaf_variances = rebuild_leaf_variances(&leaves)?;
879            Ok(Model::DecisionTreeRegressor(
880                DecisionTreeRegressor::from_ir_parts(
881                    RegressionTreeAlgorithm::Oblivious,
882                    criterion,
883                    RegressionTreeStructure::Oblivious {
884                        splits: rebuild_regressor_oblivious_splits(levels)?,
885                        leaf_values: rebuild_regressor_leaf_values(leaves)?,
886                        leaf_sample_counts,
887                        leaf_variances,
888                    },
889                    RegressionTreeOptions {
890                        max_depth: options.max_depth,
891                        min_samples_split: options.min_samples_split,
892                        min_samples_leaf: options.min_samples_leaf,
893                        max_features: None,
894                        random_seed: 0,
895                        missing_value_strategies: Vec::new(),
896                    },
897                    num_features,
898                    feature_preprocessing,
899                    training_canaries,
900                ),
901            ))
902        }
903        (_, _, tree) => Err(IrError::UnsupportedRepresentation(match tree {
904            TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
905            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
906        })),
907    }
908}
909
910fn validate_ir_header(ir: &ModelPackageIr) -> Result<(), IrError> {
911    if ir.ir_version != IR_VERSION {
912        return Err(IrError::UnsupportedIrVersion(ir.ir_version.clone()));
913    }
914    if ir.format_name != FORMAT_NAME {
915        return Err(IrError::UnsupportedFormatName(ir.format_name.clone()));
916    }
917    if ir.model.supports_missing {
918        return Err(IrError::InvalidInferenceOption(
919            "missing values are not supported in IR v1".to_string(),
920        ));
921    }
922    if ir.model.supports_categorical {
923        return Err(IrError::InvalidInferenceOption(
924            "categorical features are not supported in IR v1".to_string(),
925        ));
926    }
927    Ok(())
928}
929
930fn validate_inference_options(options: &InferenceOptions) -> Result<(), IrError> {
931    if options.threshold_comparison != "leq_left_gt_right" {
932        return Err(IrError::InvalidInferenceOption(format!(
933            "unsupported threshold comparison '{}'",
934            options.threshold_comparison
935        )));
936    }
937    if options.nan_policy != "not_supported" {
938        return Err(IrError::InvalidInferenceOption(format!(
939            "unsupported nan policy '{}'",
940            options.nan_policy
941        )));
942    }
943    Ok(())
944}
945
946fn parse_algorithm(value: &str) -> Result<TrainAlgorithm, IrError> {
947    match value {
948        "dt" => Ok(TrainAlgorithm::Dt),
949        "rf" => Ok(TrainAlgorithm::Rf),
950        "gbm" => Ok(TrainAlgorithm::Gbm),
951        _ => Err(IrError::UnsupportedAlgorithm(value.to_string())),
952    }
953}
954
955fn parse_task(value: &str) -> Result<Task, IrError> {
956    match value {
957        "regression" => Ok(Task::Regression),
958        "classification" => Ok(Task::Classification),
959        _ => Err(IrError::UnsupportedTask(value.to_string())),
960    }
961}
962
963fn parse_tree_type(value: &str) -> Result<TreeType, IrError> {
964    match value {
965        "id3" => Ok(TreeType::Id3),
966        "c45" => Ok(TreeType::C45),
967        "cart" => Ok(TreeType::Cart),
968        "randomized" => Ok(TreeType::Randomized),
969        "oblivious" => Ok(TreeType::Oblivious),
970        _ => Err(IrError::UnsupportedTreeType(value.to_string())),
971    }
972}
973
974fn parse_criterion(value: &str) -> Result<crate::Criterion, IrError> {
975    match value {
976        "gini" => Ok(crate::Criterion::Gini),
977        "entropy" => Ok(crate::Criterion::Entropy),
978        "mean" => Ok(crate::Criterion::Mean),
979        "median" => Ok(crate::Criterion::Median),
980        "second_order" => Ok(crate::Criterion::SecondOrder),
981        "auto" => Ok(crate::Criterion::Auto),
982        _ => Err(IrError::InvalidInferenceOption(format!(
983            "unsupported criterion '{}'",
984            value
985        ))),
986    }
987}
988
989fn tree_options(training: &TrainingMetadata) -> DecisionTreeOptions {
990    DecisionTreeOptions {
991        max_depth: training.max_depth.unwrap_or(8),
992        min_samples_split: training.min_samples_split.unwrap_or(2),
993        min_samples_leaf: training.min_samples_leaf.unwrap_or(1),
994        max_features: None,
995        random_seed: 0,
996        missing_value_strategies: Vec::new(),
997    }
998}
999
1000fn feature_preprocessing_from_ir(
1001    ir: &ModelPackageIr,
1002) -> Result<Vec<FeaturePreprocessing>, IrError> {
1003    let mut features: Vec<Option<FeaturePreprocessing>> = vec![None; ir.input_schema.feature_count];
1004
1005    for feature in &ir.preprocessing.numeric_binning.features {
1006        match feature {
1007            FeatureBinning::Numeric {
1008                feature_index,
1009                boundaries,
1010            } => {
1011                let slot = features.get_mut(*feature_index).ok_or_else(|| {
1012                    IrError::InvalidFeatureCount {
1013                        schema: ir.input_schema.feature_count,
1014                        preprocessing: feature_index + 1,
1015                    }
1016                })?;
1017                *slot = Some(FeaturePreprocessing::Numeric {
1018                    bin_boundaries: boundaries.clone(),
1019                    missing_bin: boundaries
1020                        .iter()
1021                        .map(|boundary| boundary.bin)
1022                        .max()
1023                        .map_or(0, |bin| bin.saturating_add(1)),
1024                });
1025            }
1026            FeatureBinning::Binary { feature_index } => {
1027                let slot = features.get_mut(*feature_index).ok_or_else(|| {
1028                    IrError::InvalidFeatureCount {
1029                        schema: ir.input_schema.feature_count,
1030                        preprocessing: feature_index + 1,
1031                    }
1032                })?;
1033                *slot = Some(FeaturePreprocessing::Binary);
1034            }
1035        }
1036    }
1037
1038    if features.len() != ir.input_schema.feature_count {
1039        return Err(IrError::InvalidFeatureCount {
1040            schema: ir.input_schema.feature_count,
1041            preprocessing: features.len(),
1042        });
1043    }
1044
1045    features
1046        .into_iter()
1047        .map(|feature| {
1048            feature.ok_or_else(|| {
1049                IrError::InvalidPreprocessing(
1050                    "every feature must have a preprocessing entry".to_string(),
1051                )
1052            })
1053        })
1054        .collect()
1055}
1056
1057fn classification_labels(ir: &ModelPackageIr) -> Result<Vec<f64>, IrError> {
1058    ir.output_schema
1059        .class_order
1060        .clone()
1061        .or_else(|| ir.training_metadata.class_labels.clone())
1062        .ok_or(IrError::MissingClassLabels)
1063}
1064
1065fn classifier_class_index(leaf: &LeafPayload, class_labels: &[f64]) -> Result<usize, IrError> {
1066    match leaf {
1067        LeafPayload::ClassIndex {
1068            class_index,
1069            class_value,
1070        } => {
1071            let Some(expected) = class_labels.get(*class_index) else {
1072                return Err(IrError::InvalidLeaf(format!(
1073                    "class index {} out of bounds",
1074                    class_index
1075                )));
1076            };
1077            if expected.total_cmp(class_value).is_ne() {
1078                return Err(IrError::InvalidLeaf(format!(
1079                    "class value {} does not match class order entry {}",
1080                    class_value, expected
1081                )));
1082            }
1083            Ok(*class_index)
1084        }
1085        LeafPayload::RegressionValue { .. } => Err(IrError::InvalidLeaf(
1086            "expected class_index leaf".to_string(),
1087        )),
1088    }
1089}
1090
1091fn rebuild_classifier_nodes(
1092    nodes: Vec<NodeTreeNode>,
1093    class_labels: &[f64],
1094) -> Result<Vec<ClassifierTreeNode>, IrError> {
1095    let mut rebuilt = vec![None; nodes.len()];
1096    for node in nodes {
1097        match node {
1098            NodeTreeNode::Leaf {
1099                node_id,
1100                leaf,
1101                stats,
1102                ..
1103            } => {
1104                let class_index = classifier_class_index(&leaf, class_labels)?;
1105                assign_node(
1106                    &mut rebuilt,
1107                    node_id,
1108                    ClassifierTreeNode::Leaf {
1109                        class_index,
1110                        sample_count: stats.sample_count,
1111                        class_counts: stats
1112                            .class_counts
1113                            .unwrap_or_else(|| vec![0; class_labels.len()]),
1114                    },
1115                )?;
1116            }
1117            NodeTreeNode::BinaryBranch {
1118                node_id,
1119                split,
1120                children,
1121                stats,
1122                ..
1123            } => {
1124                let (feature_index, threshold_bin) = classifier_binary_split(split)?;
1125                assign_node(
1126                    &mut rebuilt,
1127                    node_id,
1128                    ClassifierTreeNode::BinarySplit {
1129                        feature_index,
1130                        threshold_bin,
1131                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
1132                        left_child: children.left,
1133                        right_child: children.right,
1134                        sample_count: stats.sample_count,
1135                        impurity: stats.impurity.unwrap_or(0.0),
1136                        gain: stats.gain.unwrap_or(0.0),
1137                        class_counts: stats
1138                            .class_counts
1139                            .unwrap_or_else(|| vec![0; class_labels.len()]),
1140                    },
1141                )?;
1142            }
1143            NodeTreeNode::MultiwayBranch {
1144                node_id,
1145                split,
1146                branches,
1147                unmatched_leaf,
1148                stats,
1149                ..
1150            } => {
1151                let fallback_class_index = classifier_class_index(&unmatched_leaf, class_labels)?;
1152                assign_node(
1153                    &mut rebuilt,
1154                    node_id,
1155                    ClassifierTreeNode::MultiwaySplit {
1156                        feature_index: split.feature_index,
1157                        fallback_class_index,
1158                        branches: branches
1159                            .into_iter()
1160                            .map(|branch| (branch.bin, branch.child))
1161                            .collect(),
1162                        missing_child: None,
1163                        sample_count: stats.sample_count,
1164                        impurity: stats.impurity.unwrap_or(0.0),
1165                        gain: stats.gain.unwrap_or(0.0),
1166                        class_counts: stats
1167                            .class_counts
1168                            .unwrap_or_else(|| vec![0; class_labels.len()]),
1169                    },
1170                )?;
1171            }
1172        }
1173    }
1174    collect_nodes(rebuilt)
1175}
1176
1177fn rebuild_regressor_nodes(nodes: Vec<NodeTreeNode>) -> Result<Vec<RegressionNode>, IrError> {
1178    let mut rebuilt = vec![None; nodes.len()];
1179    for node in nodes {
1180        match node {
1181            NodeTreeNode::Leaf {
1182                node_id,
1183                leaf: LeafPayload::RegressionValue { value },
1184                stats,
1185                ..
1186            } => {
1187                assign_node(
1188                    &mut rebuilt,
1189                    node_id,
1190                    RegressionNode::Leaf {
1191                        value,
1192                        sample_count: stats.sample_count,
1193                        variance: stats.variance,
1194                    },
1195                )?;
1196            }
1197            NodeTreeNode::Leaf { .. } => {
1198                return Err(IrError::InvalidLeaf(
1199                    "regression trees require regression_value leaves".to_string(),
1200                ));
1201            }
1202            NodeTreeNode::BinaryBranch {
1203                node_id,
1204                split,
1205                children,
1206                stats,
1207                ..
1208            } => {
1209                let (feature_index, threshold_bin) = regressor_binary_split(split)?;
1210                assign_node(
1211                    &mut rebuilt,
1212                    node_id,
1213                    RegressionNode::BinarySplit {
1214                        feature_index,
1215                        threshold_bin,
1216                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
1217                        missing_value: 0.0,
1218                        left_child: children.left,
1219                        right_child: children.right,
1220                        sample_count: stats.sample_count,
1221                        impurity: stats.impurity.unwrap_or(0.0),
1222                        gain: stats.gain.unwrap_or(0.0),
1223                        variance: stats.variance,
1224                    },
1225                )?;
1226            }
1227            NodeTreeNode::MultiwayBranch { .. } => {
1228                return Err(IrError::InvalidNode(
1229                    "regression trees do not support multiway branches".to_string(),
1230                ));
1231            }
1232        }
1233    }
1234    collect_nodes(rebuilt)
1235}
1236
1237fn rebuild_classifier_oblivious_splits(
1238    levels: Vec<ObliviousLevel>,
1239) -> Result<Vec<ClassifierObliviousSplit>, IrError> {
1240    let mut rebuilt = Vec::with_capacity(levels.len());
1241    for level in levels {
1242        rebuilt.push(match level.split {
1243            ObliviousSplit::NumericBinThreshold {
1244                feature_index,
1245                threshold_bin,
1246                ..
1247            } => ClassifierObliviousSplit {
1248                feature_index,
1249                threshold_bin,
1250                missing_directions: Vec::new(),
1251                sample_count: level.stats.sample_count,
1252                impurity: level.stats.impurity.unwrap_or(0.0),
1253                gain: level.stats.gain.unwrap_or(0.0),
1254            },
1255            ObliviousSplit::BooleanTest { feature_index, .. } => ClassifierObliviousSplit {
1256                feature_index,
1257                threshold_bin: 0,
1258                missing_directions: Vec::new(),
1259                sample_count: level.stats.sample_count,
1260                impurity: level.stats.impurity.unwrap_or(0.0),
1261                gain: level.stats.gain.unwrap_or(0.0),
1262            },
1263        });
1264    }
1265    Ok(rebuilt)
1266}
1267
1268fn rebuild_regressor_oblivious_splits(
1269    levels: Vec<ObliviousLevel>,
1270) -> Result<Vec<RegressorObliviousSplit>, IrError> {
1271    let mut rebuilt = Vec::with_capacity(levels.len());
1272    for level in levels {
1273        rebuilt.push(match level.split {
1274            ObliviousSplit::NumericBinThreshold {
1275                feature_index,
1276                threshold_bin,
1277                ..
1278            } => RegressorObliviousSplit {
1279                feature_index,
1280                threshold_bin,
1281                sample_count: level.stats.sample_count,
1282                impurity: level.stats.impurity.unwrap_or(0.0),
1283                gain: level.stats.gain.unwrap_or(0.0),
1284            },
1285            ObliviousSplit::BooleanTest { feature_index, .. } => RegressorObliviousSplit {
1286                feature_index,
1287                threshold_bin: 0,
1288                sample_count: level.stats.sample_count,
1289                impurity: level.stats.impurity.unwrap_or(0.0),
1290                gain: level.stats.gain.unwrap_or(0.0),
1291            },
1292        });
1293    }
1294    Ok(rebuilt)
1295}
1296
1297fn rebuild_classifier_leaf_indices(
1298    leaves: Vec<IndexedLeaf>,
1299    class_labels: &[f64],
1300) -> Result<Vec<usize>, IrError> {
1301    let mut rebuilt = vec![None; leaves.len()];
1302    for indexed_leaf in leaves {
1303        let class_index = classifier_class_index(&indexed_leaf.leaf, class_labels)?;
1304        assign_node(&mut rebuilt, indexed_leaf.leaf_index, class_index)?;
1305    }
1306    collect_nodes(rebuilt)
1307}
1308
1309fn rebuild_regressor_leaf_values(leaves: Vec<IndexedLeaf>) -> Result<Vec<f64>, IrError> {
1310    let mut rebuilt = vec![None; leaves.len()];
1311    for indexed_leaf in leaves {
1312        let value = match indexed_leaf.leaf {
1313            LeafPayload::RegressionValue { value } => value,
1314            LeafPayload::ClassIndex { .. } => {
1315                return Err(IrError::InvalidLeaf(
1316                    "regression oblivious leaves require regression_value".to_string(),
1317                ));
1318            }
1319        };
1320        assign_node(&mut rebuilt, indexed_leaf.leaf_index, value)?;
1321    }
1322    collect_nodes(rebuilt)
1323}
1324
1325fn rebuild_leaf_sample_counts(leaves: &[IndexedLeaf]) -> Result<Vec<usize>, IrError> {
1326    let mut rebuilt = vec![None; leaves.len()];
1327    for indexed_leaf in leaves {
1328        assign_node(
1329            &mut rebuilt,
1330            indexed_leaf.leaf_index,
1331            indexed_leaf.stats.sample_count,
1332        )?;
1333    }
1334    collect_nodes(rebuilt)
1335}
1336
1337fn rebuild_leaf_variances(leaves: &[IndexedLeaf]) -> Result<Vec<Option<f64>>, IrError> {
1338    let mut rebuilt = vec![None; leaves.len()];
1339    for indexed_leaf in leaves {
1340        assign_node(
1341            &mut rebuilt,
1342            indexed_leaf.leaf_index,
1343            indexed_leaf.stats.variance,
1344        )?;
1345    }
1346    collect_nodes(rebuilt)
1347}
1348
1349fn rebuild_classifier_leaf_class_counts(
1350    leaves: &[IndexedLeaf],
1351    num_classes: usize,
1352) -> Result<Vec<Vec<usize>>, IrError> {
1353    let mut rebuilt = vec![None; leaves.len()];
1354    for indexed_leaf in leaves {
1355        assign_node(
1356            &mut rebuilt,
1357            indexed_leaf.leaf_index,
1358            indexed_leaf
1359                .stats
1360                .class_counts
1361                .clone()
1362                .unwrap_or_else(|| vec![0; num_classes]),
1363        )?;
1364    }
1365    collect_nodes(rebuilt)
1366}
1367
1368fn classifier_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
1369    match split {
1370        BinarySplit::NumericBinThreshold {
1371            feature_index,
1372            threshold_bin,
1373            ..
1374        } => Ok((feature_index, threshold_bin)),
1375        BinarySplit::BooleanTest { feature_index, .. } => Ok((feature_index, 0)),
1376    }
1377}
1378
1379fn regressor_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
1380    classifier_binary_split(split)
1381}
1382
1383fn assign_node<T>(slots: &mut [Option<T>], index: usize, value: T) -> Result<(), IrError> {
1384    let Some(slot) = slots.get_mut(index) else {
1385        return Err(IrError::InvalidNode(format!(
1386            "node index {} is out of bounds",
1387            index
1388        )));
1389    };
1390    if slot.is_some() {
1391        return Err(IrError::InvalidNode(format!(
1392            "duplicate node index {}",
1393            index
1394        )));
1395    }
1396    *slot = Some(value);
1397    Ok(())
1398}
1399
1400fn collect_nodes<T>(slots: Vec<Option<T>>) -> Result<Vec<T>, IrError> {
1401    slots
1402        .into_iter()
1403        .enumerate()
1404        .map(|(index, slot)| {
1405            slot.ok_or_else(|| IrError::InvalidNode(format!("missing node index {}", index)))
1406        })
1407        .collect()
1408}
1409
1410fn input_schema(model: &Model) -> InputSchema {
1411    let features = model
1412        .feature_preprocessing()
1413        .iter()
1414        .enumerate()
1415        .map(|(feature_index, preprocessing)| {
1416            let kind = match preprocessing {
1417                FeaturePreprocessing::Numeric { .. } => InputFeatureKind::Numeric,
1418                FeaturePreprocessing::Binary => InputFeatureKind::Binary,
1419            };
1420
1421            InputFeature {
1422                index: feature_index,
1423                name: feature_name(feature_index),
1424                dtype: match kind {
1425                    InputFeatureKind::Numeric => "float64".to_string(),
1426                    InputFeatureKind::Binary => "bool".to_string(),
1427                },
1428                logical_type: match kind {
1429                    InputFeatureKind::Numeric => "numeric".to_string(),
1430                    InputFeatureKind::Binary => "boolean".to_string(),
1431                },
1432                nullable: false,
1433            }
1434        })
1435        .collect();
1436
1437    InputSchema {
1438        feature_count: model.num_features(),
1439        features,
1440        ordering: "strict_index_order".to_string(),
1441        input_tensor_layout: "row_major".to_string(),
1442        accepts_feature_names: false,
1443    }
1444}
1445
1446fn output_schema(model: &Model, class_labels: Option<Vec<f64>>) -> OutputSchema {
1447    match model.task() {
1448        Task::Regression => OutputSchema {
1449            raw_outputs: vec![OutputField {
1450                name: "value".to_string(),
1451                kind: "regression_value".to_string(),
1452                shape: Vec::new(),
1453                dtype: "float64".to_string(),
1454            }],
1455            final_outputs: vec![OutputField {
1456                name: "prediction".to_string(),
1457                kind: "value".to_string(),
1458                shape: Vec::new(),
1459                dtype: "float64".to_string(),
1460            }],
1461            class_order: None,
1462        },
1463        Task::Classification => OutputSchema {
1464            raw_outputs: vec![OutputField {
1465                name: "class_index".to_string(),
1466                kind: "class_index".to_string(),
1467                shape: Vec::new(),
1468                dtype: "uint64".to_string(),
1469            }],
1470            final_outputs: vec![OutputField {
1471                name: "predicted_class".to_string(),
1472                kind: "class_label".to_string(),
1473                shape: Vec::new(),
1474                dtype: "float64".to_string(),
1475            }],
1476            class_order: class_labels,
1477        },
1478    }
1479}
1480
1481fn preprocessing(model: &Model) -> PreprocessingSection {
1482    let features = model
1483        .feature_preprocessing()
1484        .iter()
1485        .enumerate()
1486        .map(|(feature_index, preprocessing)| match preprocessing {
1487            FeaturePreprocessing::Numeric { bin_boundaries, .. } => FeatureBinning::Numeric {
1488                feature_index,
1489                boundaries: bin_boundaries.clone(),
1490            },
1491            FeaturePreprocessing::Binary => FeatureBinning::Binary { feature_index },
1492        })
1493        .collect();
1494
1495    PreprocessingSection {
1496        included_in_model: true,
1497        numeric_binning: NumericBinning {
1498            kind: "rank_bin_128".to_string(),
1499            features,
1500        },
1501        notes: "Numeric features use serialized training-time rank bins. Binary features are serialized as booleans. Missing values and categorical encodings are not implemented in IR v1."
1502            .to_string(),
1503    }
1504}
1505
1506fn postprocessing(model: &Model, class_labels: Option<Vec<f64>>) -> PostprocessingSection {
1507    match model.task() {
1508        Task::Regression => PostprocessingSection {
1509            raw_output_kind: "regression_value".to_string(),
1510            steps: vec![PostprocessingStep::Identity],
1511        },
1512        Task::Classification => PostprocessingSection {
1513            raw_output_kind: "class_index".to_string(),
1514            steps: vec![PostprocessingStep::MapClassIndexToLabel {
1515                labels: class_labels.expect("classification IR requires class labels"),
1516            }],
1517        },
1518    }
1519}
1520
1521fn required_capabilities(model: &Model, representation: &str) -> Vec<String> {
1522    let mut capabilities = vec![
1523        representation.to_string(),
1524        "training_rank_bin_128".to_string(),
1525    ];
1526    match model.tree_type() {
1527        TreeType::Id3 | TreeType::C45 => {
1528            capabilities.push("binned_multiway_splits".to_string());
1529        }
1530        TreeType::Cart | TreeType::Randomized | TreeType::Oblivious => {
1531            capabilities.push("numeric_bin_threshold_splits".to_string());
1532        }
1533    }
1534    if model
1535        .feature_preprocessing()
1536        .iter()
1537        .any(|feature| matches!(feature, FeaturePreprocessing::Binary))
1538    {
1539        capabilities.push("boolean_features".to_string());
1540    }
1541    match model.task() {
1542        Task::Regression => capabilities.push("regression_value_leaves".to_string()),
1543        Task::Classification => capabilities.push("class_index_leaves".to_string()),
1544    }
1545    capabilities
1546}
1547
1548pub(crate) fn algorithm_name(algorithm: TrainAlgorithm) -> &'static str {
1549    match algorithm {
1550        TrainAlgorithm::Dt => "dt",
1551        TrainAlgorithm::Rf => "rf",
1552        TrainAlgorithm::Gbm => "gbm",
1553    }
1554}
1555
1556fn model_tree_definition(model: &Model) -> TreeDefinition {
1557    match model {
1558        Model::DecisionTreeClassifier(classifier) => classifier.to_ir_tree(),
1559        Model::DecisionTreeRegressor(regressor) => regressor.to_ir_tree(),
1560        Model::RandomForest(_) | Model::GradientBoostedTrees(_) => {
1561            unreachable!("ensemble IR expands into member trees")
1562        }
1563    }
1564}
1565
1566pub(crate) fn criterion_name(criterion: crate::Criterion) -> &'static str {
1567    match criterion {
1568        crate::Criterion::Auto => "auto",
1569        crate::Criterion::Gini => "gini",
1570        crate::Criterion::Entropy => "entropy",
1571        crate::Criterion::Mean => "mean",
1572        crate::Criterion::Median => "median",
1573        crate::Criterion::SecondOrder => "second_order",
1574    }
1575}
1576
1577pub(crate) fn task_name(task: Task) -> &'static str {
1578    match task {
1579        Task::Regression => "regression",
1580        Task::Classification => "classification",
1581    }
1582}
1583
1584pub(crate) fn tree_type_name(tree_type: TreeType) -> &'static str {
1585    match tree_type {
1586        TreeType::Id3 => "id3",
1587        TreeType::C45 => "c45",
1588        TreeType::Cart => "cart",
1589        TreeType::Randomized => "randomized",
1590        TreeType::Oblivious => "oblivious",
1591    }
1592}
1593
/// Returns the synthetic IR name for a feature: `f` followed by its index.
pub(crate) fn feature_name(feature_index: usize) -> String {
    let mut name = String::from("f");
    name.push_str(&feature_index.to_string());
    name
}
1597
1598pub(crate) fn threshold_upper_bound(
1599    preprocessing: &[FeaturePreprocessing],
1600    feature_index: usize,
1601    threshold_bin: u16,
1602) -> Option<f64> {
1603    match preprocessing.get(feature_index)? {
1604        FeaturePreprocessing::Numeric { bin_boundaries, .. } => bin_boundaries
1605            .iter()
1606            .find(|boundary| boundary.bin == threshold_bin)
1607            .map(|boundary| boundary.upper_bound),
1608        FeaturePreprocessing::Binary => None,
1609    }
1610}