//! Stable intermediate representation for ForestFire models.
//!
//! The IR sits between two worlds:
//!
//! - the semantic in-memory model types used for training and introspection
//! - the lowered runtime structures used by optimized inference
//!
//! It exists so models can be serialized, schema-checked, inspected from other
//! languages, and reconstructed without depending on the exact Rust memory
//! layout of the training structs.

// Classifier-side training types, aliased to disambiguate from regressor twins.
use crate::tree::classifier::{
    DecisionTreeAlgorithm, DecisionTreeClassifier, DecisionTreeOptions,
    ObliviousSplit as ClassifierObliviousSplit, TreeNode as ClassifierTreeNode,
    TreeStructure as ClassifierTreeStructure,
};
// Regressor-side training types.
use crate::tree::regressor::{
    DecisionTreeRegressor, ObliviousSplit as RegressorObliviousSplit, RegressionNode,
    RegressionTreeAlgorithm, RegressionTreeOptions, RegressionTreeStructure,
};
// Crate-level model and configuration types shared by both tree families.
use crate::{
    Criterion, FeaturePreprocessing, GradientBoostedTrees, InputFeatureKind, Model,
    NumericBinBoundary, RandomForest, Task, TrainAlgorithm, TreeType,
};
use schemars::schema::RootSchema;
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};

/// Semantic version of the IR schema emitted by this module; also used as the
/// minimum runtime version in the integrity section.
const IR_VERSION: &str = "1.0.0";
/// Format discriminator stamped into every package so readers can reject
/// unrelated JSON documents early.
const FORMAT_NAME: &str = "forestfire-ir";

/// Top-level model package serialized by the library.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ModelPackageIr {
    /// IR schema version (see `IR_VERSION`).
    pub ir_version: String,
    /// Format discriminator (see `FORMAT_NAME`).
    pub format_name: String,
    /// Which library, version, and platform produced this package.
    pub producer: ProducerMetadata,
    /// Structural description of the trees and their aggregation rule.
    pub model: ModelSection,
    /// Declared input features and tensor layout.
    pub input_schema: InputSchema,
    /// Raw and final (post-processed) output descriptors.
    pub output_schema: OutputSchema,
    /// Numeric, boolean, tie-breaking, and NaN conventions for inference.
    pub inference_options: InferenceOptions,
    /// Feature preprocessing (numeric binning) carried with the model.
    pub preprocessing: PreprocessingSection,
    /// Steps mapping raw model outputs to final predictions.
    pub postprocessing: PostprocessingSection,
    /// Training hyperparameters and statistics, reflected for bindings/docs.
    pub training_metadata: TrainingMetadata,
    /// Serialization and runtime-compatibility requirements.
    pub integrity: IntegritySection,
}

/// Identifies the producing library and its build environment.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProducerMetadata {
    /// Producing library name ("forestfire-core" for this crate).
    pub library: String,
    /// Crate version captured at compile time.
    pub library_version: String,
    /// Implementation language of the producer ("rust" for this crate).
    pub language: String,
    /// CPU architecture the producer ran on (`std::env::consts::ARCH`).
    pub platform: String,
}

/// Structural model description independent of any concrete runtime layout.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ModelSection {
    /// Training algorithm name, as serialized by the producer.
    pub algorithm: String,
    /// Task name (regression vs. classification family).
    pub task: String,
    /// Tree type name (CART, oblivious, etc.), as serialized by the producer.
    pub tree_type: String,
    /// Tree payload layout: "node_tree" or "oblivious_levels".
    pub representation: String,
    pub num_features: usize,
    /// Number of model outputs; currently always 1 (see `model_to_ir`).
    pub num_outputs: usize,
    pub supports_missing: bool,
    pub supports_categorical: bool,
    /// True for random forests and gradient-boosted ensembles.
    pub is_ensemble: bool,
    /// One payload per tree; single-tree models contain exactly one entry.
    pub trees: Vec<TreeDefinition>,
    /// How per-tree outputs combine into the final model output.
    pub aggregation: Aggregation,
}

/// Concrete tree payload stored in the IR.
///
/// JSON uses an internal `"representation"` tag to distinguish the layouts.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "representation", rename_all = "snake_case")]
pub enum TreeDefinition {
    /// General tree stored as a node arena plus a root node id.
    NodeTree {
        tree_id: usize,
        /// Per-tree aggregation weight.
        weight: f64,
        /// Index of the root inside `nodes`.
        root_node_id: usize,
        nodes: Vec<NodeTreeNode>,
    },
    /// Oblivious (level-symmetric) tree: one split per depth level and
    /// leaves addressed by a computed index.
    ObliviousLevels {
        tree_id: usize,
        weight: f64,
        depth: usize,
        levels: Vec<ObliviousLevel>,
        /// How per-level bits combine into indices of `leaves`.
        leaf_indexing: LeafIndexing,
        leaves: Vec<IndexedLeaf>,
    },
}

/// A single node of a "node_tree" payload. Nodes reference each other by id
/// (index into the tree's `nodes` vector) so the IR stays pointer-free.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum NodeTreeNode {
    /// Terminal node carrying the prediction payload.
    Leaf {
        node_id: usize,
        depth: usize,
        leaf: LeafPayload,
        stats: NodeStats,
    },
    /// Two-way split evaluated with a `BinarySplit` test.
    BinaryBranch {
        node_id: usize,
        depth: usize,
        split: BinarySplit,
        children: BinaryChildren,
        stats: NodeStats,
    },
    /// N-way split routing on a bin value, with a fallback payload for
    /// values that match no branch.
    MultiwayBranch {
        node_id: usize,
        depth: usize,
        split: MultiwaySplit,
        branches: Vec<MultiwayBranch>,
        /// Prediction used when no branch matches the observed bin.
        unmatched_leaf: LeafPayload,
        stats: NodeStats,
    },
}

/// Child node ids of a binary branch.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BinaryChildren {
    pub left: usize,
    pub right: usize,
}

/// One outgoing edge of a multiway branch: bin value -> child node id.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MultiwayBranch {
    pub bin: u16,
    pub child: usize,
}

/// Test evaluated at a binary branch.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "split_type", rename_all = "snake_case")]
pub enum BinarySplit {
    /// Threshold comparison against a binned numeric feature.
    NumericBinThreshold {
        feature_index: usize,
        feature_name: String,
        /// Comparison operator, serialized as a string by the producer.
        operator: String,
        /// Bin index the binned feature value is compared against.
        threshold_bin: u16,
        /// Raw-value upper bound of the threshold bin, when available.
        threshold_upper_bound: Option<f64>,
        comparison_dtype: String,
    },
    /// Boolean feature test; the semantics strings describe each child branch.
    BooleanTest {
        feature_index: usize,
        feature_name: String,
        false_child_semantics: String,
        true_child_semantics: String,
    },
}

/// Descriptor of an N-way split on a single feature.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MultiwaySplit {
    pub split_type: String,
    pub feature_index: usize,
    pub feature_name: String,
    pub comparison_dtype: String,
}

/// One depth level of an oblivious tree; every node at that depth applies
/// the same split.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ObliviousLevel {
    pub level: usize,
    pub split: ObliviousSplit,
    pub stats: NodeStats,
}

/// Split applied at one oblivious level. The `bit_when_*` fields record the
/// bit this level contributes; `LeafIndexing` describes how level bits are
/// assembled into a leaf index.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "split_type", rename_all = "snake_case")]
pub enum ObliviousSplit {
    /// Threshold comparison against a binned numeric feature.
    NumericBinThreshold {
        feature_index: usize,
        feature_name: String,
        operator: String,
        threshold_bin: u16,
        /// Raw-value upper bound of the threshold bin, when available.
        threshold_upper_bound: Option<f64>,
        comparison_dtype: String,
        bit_when_true: u8,
        bit_when_false: u8,
    },
    /// Boolean feature test.
    BooleanTest {
        feature_index: usize,
        feature_name: String,
        bit_when_false: u8,
        bit_when_true: u8,
    },
}

/// How per-level bits map to indices of the oblivious tree's `leaves`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct LeafIndexing {
    /// Bit ordering convention (which level occupies which bit position).
    pub bit_order: String,
    /// Textual formula mapping the level bits to a leaf index.
    pub index_formula: String,
}

/// A leaf of an oblivious tree addressed by its computed index.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct IndexedLeaf {
    pub leaf_index: usize,
    pub leaf: LeafPayload,
    pub stats: NodeStats,
}

/// Prediction stored at a leaf.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "prediction_kind", rename_all = "snake_case")]
pub enum LeafPayload {
    /// Continuous prediction produced by regression trees.
    RegressionValue {
        value: f64,
    },
    /// Class prediction: position in the class order plus the label value.
    ClassIndex {
        class_index: usize,
        class_value: f64,
    },
}

/// How per-tree outputs are combined into the final model output.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Aggregation {
    /// Aggregation rule name; see `model_to_ir` for the values emitted
    /// (e.g. "average", "sum_tree_outputs", "identity_single_tree").
    pub kind: String,
    /// One weight per tree, in tree order.
    pub tree_weights: Vec<f64>,
    pub normalize_by_weight_sum: bool,
    /// Additive offset used by gradient boosting; omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_score: Option<f64>,
}

/// Declared model inputs: ordered feature descriptors plus layout metadata.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InputSchema {
    /// Number of input features; used as `num_features` on reconstruction.
    pub feature_count: usize,
    pub features: Vec<InputFeature>,
    /// Feature ordering convention used by `features`.
    pub ordering: String,
    pub input_tensor_layout: String,
    /// Whether inputs may be addressed by feature name instead of index.
    pub accepts_feature_names: bool,
}

/// Descriptor of a single input feature.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InputFeature {
    pub index: usize,
    pub name: String,
    /// Physical data type of the feature, as a string.
    pub dtype: String,
    /// Logical interpretation (distinct from the physical dtype).
    pub logical_type: String,
    pub nullable: bool,
}

/// Declared model outputs, before and after postprocessing.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OutputSchema {
    /// Outputs as produced directly by the tree aggregation.
    pub raw_outputs: Vec<OutputField>,
    /// Outputs after the postprocessing steps are applied.
    pub final_outputs: Vec<OutputField>,
    /// Class label order for classification models; omitted otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_order: Option<Vec<f64>>,
}

/// A single named output tensor.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OutputField {
    pub name: String,
    pub kind: String,
    pub shape: Vec<usize>,
    pub dtype: String,
}

/// Conventions a runtime must honor to reproduce training-time predictions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct InferenceOptions {
    /// Numeric width for comparisons (e.g. "float64").
    pub numeric_precision: String,
    /// Threshold routing convention (e.g. "leq_left_gt_right").
    pub threshold_comparison: String,
    /// Handling of NaN inputs (e.g. "not_supported").
    pub nan_policy: String,
    pub bool_encoding: BoolEncoding,
    pub tie_breaking: TieBreaking,
    pub determinism: Determinism,
}

/// String spellings accepted for boolean feature values.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BoolEncoding {
    pub false_values: Vec<String>,
    pub true_values: Vec<String>,
}

/// Tie-breaking rules for classification votes and argmax selection.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TieBreaking {
    pub classification: String,
    pub argmax: String,
}

/// Whether inference is guaranteed reproducible, plus free-text notes.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Determinism {
    pub guaranteed: bool,
    pub notes: String,
}

/// Feature preprocessing carried inside the model package.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PreprocessingSection {
    /// Whether the preprocessing below is part of the model itself.
    pub included_in_model: bool,
    pub numeric_binning: NumericBinning,
    pub notes: String,
}

/// Numeric binning artifacts, one entry per feature.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct NumericBinning {
    pub kind: String,
    pub features: Vec<FeatureBinning>,
}

/// Per-feature binning artifact.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum FeatureBinning {
    /// Numeric feature discretized by explicit bin boundaries.
    Numeric {
        feature_index: usize,
        boundaries: Vec<NumericBinBoundary>,
    },
    /// Binary feature; no boundaries are stored.
    Binary {
        feature_index: usize,
    },
}

/// Steps applied to raw model outputs to obtain final predictions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PostprocessingSection {
    pub raw_output_kind: String,
    pub steps: Vec<PostprocessingStep>,
}

/// A single postprocessing operation, tagged by `"op"` in JSON.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum PostprocessingStep {
    /// Pass the raw output through unchanged.
    Identity,
    /// Replace a class index with the label at that position in `labels`.
    MapClassIndexToLabel { labels: Vec<f64> },
}

/// Serialized training metadata reflected back to bindings and docs.
///
/// The first six fields are always present; every `Option` field applies
/// only to a subset of algorithms and is omitted from JSON when `None`.
/// `model_from_ir` supplies defaults for absent optional fields.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TrainingMetadata {
    pub algorithm: String,
    pub task: String,
    pub tree_type: String,
    /// Split-quality criterion used during training.
    pub criterion: String,
    /// Canary count, forwarded unchanged to tree reconstruction.
    pub canaries: usize,
    /// Whether out-of-bag scoring was enabled during training.
    pub compute_oob: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_depth: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_samples_split: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_samples_leaf: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n_trees: Option<usize>,
    /// Feature-subsampling limit; defaults to the feature count on rebuild.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_features: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oob_score: Option<f64>,
    /// Class labels for classification models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_labels: Option<Vec<f64>>,
    /// Gradient-boosting shrinkage; defaults to 0.1 on rebuild.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub learning_rate: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bootstrap: Option<bool>,
    /// GBM sampling fraction; defaults to 0.2 on rebuild.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_gradient_fraction: Option<f64>,
    /// GBM sampling fraction; defaults to 0.1 on rebuild.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub other_gradient_fraction: Option<f64>,
}

/// Serialization requirements and runtime-compatibility constraints.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct IntegritySection {
    /// Wire format of the package (e.g. "json").
    pub serialization: String,
    pub canonical_json: bool,
    pub compatibility: Compatibility,
}

/// Minimum runtime version and capabilities needed to execute this model.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Compatibility {
    pub minimum_runtime_version: String,
    pub required_capabilities: Vec<String>,
}

/// Per-node training statistics; optional fields are omitted from JSON.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct NodeStats {
    /// Training samples that reached this node.
    pub sample_count: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub impurity: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gain: Option<f64>,
    /// Per-class sample counts; used when rebuilding classifier leaves.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_counts: Option<Vec<usize>>,
    /// Target variance; used when rebuilding regressor leaves.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub variance: Option<f64>,
}

/// Errors raised while serializing, validating, or reconstructing an IR
/// package.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IrError {
    UnsupportedIrVersion(String),
    UnsupportedFormatName(String),
    UnsupportedAlgorithm(String),
    UnsupportedTask(String),
    UnsupportedTreeType(String),
    InvalidTreeCount(usize),
    UnsupportedRepresentation(String),
    InvalidFeatureCount { schema: usize, preprocessing: usize },
    MissingClassLabels,
    InvalidLeaf(String),
    InvalidNode(String),
    InvalidPreprocessing(String),
    InvalidInferenceOption(String),
    Json(String),
}

impl Display for IrError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Build the human-readable message first, then emit it in one call.
        let rendered = match self {
            IrError::UnsupportedIrVersion(version) => {
                format!("Unsupported IR version: {}.", version)
            }
            IrError::UnsupportedFormatName(name) => format!("Unsupported IR format: {}.", name),
            IrError::UnsupportedAlgorithm(algorithm) => {
                format!("Unsupported algorithm: {}.", algorithm)
            }
            IrError::UnsupportedTask(task) => format!("Unsupported task: {}.", task),
            IrError::UnsupportedTreeType(tree_type) => {
                format!("Unsupported tree type: {}.", tree_type)
            }
            IrError::InvalidTreeCount(count) => {
                format!("Expected exactly one tree in the IR, found {}.", count)
            }
            IrError::UnsupportedRepresentation(representation) => {
                format!("Unsupported tree representation: {}.", representation)
            }
            IrError::InvalidFeatureCount {
                schema,
                preprocessing,
            } => format!(
                "Input schema declares {} features, but preprocessing declares {}.",
                schema, preprocessing
            ),
            IrError::MissingClassLabels => {
                "Classification IR requires explicit class labels.".to_string()
            }
            IrError::InvalidLeaf(message) => format!("Invalid leaf payload: {}.", message),
            IrError::InvalidNode(message) => format!("Invalid tree node: {}.", message),
            IrError::InvalidPreprocessing(message) => {
                format!("Invalid preprocessing section: {}.", message)
            }
            IrError::InvalidInferenceOption(message) => {
                format!("Invalid inference options: {}.", message)
            }
            IrError::Json(message) => format!("Invalid JSON: {}.", message),
        };
        f.write_str(&rendered)
    }
}

impl std::error::Error for IrError {}

447impl ModelPackageIr {
448    pub fn json_schema() -> RootSchema {
449        schema_for!(ModelPackageIr)
450    }
451
452    pub fn json_schema_json() -> Result<String, IrError> {
453        serde_json::to_string(&Self::json_schema()).map_err(|err| IrError::Json(err.to_string()))
454    }
455
456    pub fn json_schema_json_pretty() -> Result<String, IrError> {
457        serde_json::to_string_pretty(&Self::json_schema())
458            .map_err(|err| IrError::Json(err.to_string()))
459    }
460}
461
/// Lowers a trained [`Model`] into its serializable IR package.
pub(crate) fn model_to_ir(model: &Model) -> ModelPackageIr {
    // One TreeDefinition per tree; single-tree models contribute exactly one.
    let trees = match model {
        Model::RandomForest(forest) => forest
            .trees()
            .iter()
            .map(model_tree_definition)
            .collect::<Vec<_>>(),
        Model::GradientBoostedTrees(boosted) => boosted
            .trees()
            .iter()
            .map(model_tree_definition)
            .collect::<Vec<_>>(),
        _ => vec![model_tree_definition(model)],
    };
    // Derive the payload layout from the first tree; with no trees, fall back
    // to whatever the declared tree type implies.
    let representation = if let Some(first_tree) = trees.first() {
        match first_tree {
            TreeDefinition::NodeTree { .. } => "node_tree",
            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels",
        }
    } else {
        match model.tree_type() {
            TreeType::Oblivious => "oblivious_levels",
            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized => "node_tree",
        }
    };
    let class_labels = model.class_labels();
    let is_ensemble = matches!(
        model,
        Model::RandomForest(_) | Model::GradientBoostedTrees(_)
    );
    let tree_count = trees.len();
    // Aggregation semantics per algorithm: forests average equally-weighted
    // trees (normalized), boosted ensembles sum weighted tree outputs on top
    // of a base score, and a single tree passes its output through unchanged.
    let (aggregation_kind, tree_weights, normalize_by_weight_sum, base_score) = match model {
        Model::RandomForest(_) => (
            match model.task() {
                Task::Regression => "average",
                Task::Classification => "average_class_probabilities",
            },
            vec![1.0; tree_count],
            true,
            None,
        ),
        Model::GradientBoostedTrees(boosted) => (
            match boosted.task() {
                Task::Regression => "sum_tree_outputs",
                Task::Classification => "sum_tree_outputs_then_sigmoid",
            },
            boosted.tree_weights().to_vec(),
            false,
            Some(boosted.base_score()),
        ),
        _ => ("identity_single_tree", vec![1.0; tree_count], true, None),
    };

    ModelPackageIr {
        ir_version: IR_VERSION.to_string(),
        format_name: FORMAT_NAME.to_string(),
        producer: ProducerMetadata {
            library: "forestfire-core".to_string(),
            // Crate version captured at compile time.
            library_version: env!("CARGO_PKG_VERSION").to_string(),
            language: "rust".to_string(),
            platform: std::env::consts::ARCH.to_string(),
        },
        model: ModelSection {
            algorithm: algorithm_name(model.algorithm()).to_string(),
            task: task_name(model.task()).to_string(),
            tree_type: tree_type_name(model.tree_type()).to_string(),
            representation: representation.to_string(),
            num_features: model.num_features(),
            num_outputs: 1,
            supports_missing: false,
            supports_categorical: false,
            is_ensemble,
            trees,
            aggregation: Aggregation {
                kind: aggregation_kind.to_string(),
                tree_weights,
                normalize_by_weight_sum,
                base_score,
            },
        },
        input_schema: input_schema(model),
        output_schema: output_schema(model, class_labels.clone()),
        // Fixed inference conventions for all models produced by this crate.
        inference_options: InferenceOptions {
            numeric_precision: "float64".to_string(),
            // Per the convention string: values <= threshold route left,
            // values > threshold route right.
            threshold_comparison: "leq_left_gt_right".to_string(),
            nan_policy: "not_supported".to_string(),
            bool_encoding: BoolEncoding {
                false_values: vec!["0".to_string(), "false".to_string()],
                true_values: vec!["1".to_string(), "true".to_string()],
            },
            tie_breaking: TieBreaking {
                classification: "lowest_class_index".to_string(),
                argmax: "first_max_index".to_string(),
            },
            determinism: Determinism {
                guaranteed: true,
                notes: "Inference is deterministic when the serialized preprocessing artifacts are applied before split evaluation."
                    .to_string(),
            },
        },
        preprocessing: preprocessing(model),
        postprocessing: postprocessing(model, class_labels),
        training_metadata: model.training_metadata(),
        integrity: IntegritySection {
            serialization: "json".to_string(),
            canonical_json: true,
            compatibility: Compatibility {
                minimum_runtime_version: IR_VERSION.to_string(),
                required_capabilities: required_capabilities(model, representation),
            },
        },
    }
}

/// Reconstructs an in-memory [`Model`] from a deserialized IR package.
///
/// # Errors
///
/// Returns an [`IrError`] when the package header, inference options,
/// algorithm/task/tree-type/criterion names, or preprocessing fail
/// validation, when a single-tree package does not contain exactly one
/// tree, or when any tree payload cannot be rebuilt.
pub(crate) fn model_from_ir(ir: ModelPackageIr) -> Result<Model, IrError> {
    // Reject incompatible headers and inference conventions before touching
    // the tree payloads.
    validate_ir_header(&ir)?;
    validate_inference_options(&ir.inference_options)?;

    let algorithm = parse_algorithm(&ir.model.algorithm)?;
    let task = parse_task(&ir.model.task)?;
    let tree_type = parse_tree_type(&ir.model.tree_type)?;
    let criterion = parse_criterion(&ir.training_metadata.criterion)?;
    let feature_preprocessing = feature_preprocessing_from_ir(&ir)?;
    let num_features = ir.input_schema.feature_count;
    let options = tree_options(&ir.training_metadata);
    let training_canaries = ir.training_metadata.canaries;
    // Labels are optional here; classification paths re-check and fail with
    // MissingClassLabels where they are actually required.
    let deserialized_class_labels = classification_labels(&ir).ok();

    if algorithm == TrainAlgorithm::Dt && ir.model.trees.len() != 1 {
        return Err(IrError::InvalidTreeCount(ir.model.trees.len()));
    }

    if algorithm == TrainAlgorithm::Rf {
        // Rebuild every member tree, then wrap them in a RandomForest with
        // the recorded training hyperparameters (defaults fill in absences).
        let trees = ir
            .model
            .trees
            .into_iter()
            .map(|tree| {
                single_model_from_ir_parts(
                    task,
                    tree_type,
                    criterion,
                    feature_preprocessing.clone(),
                    num_features,
                    options.clone(),
                    training_canaries,
                    deserialized_class_labels.clone(),
                    tree,
                )
            })
            .collect::<Result<Vec<_>, _>>()?;
        return Ok(Model::RandomForest(RandomForest::new(
            task,
            criterion,
            tree_type,
            trees,
            ir.training_metadata.compute_oob,
            ir.training_metadata.oob_score,
            // Absent max_features defaults to all features (at least 1).
            ir.training_metadata
                .max_features
                .unwrap_or(num_features.max(1)),
            ir.training_metadata.seed,
            num_features,
            feature_preprocessing,
        )));
    }

    if algorithm == TrainAlgorithm::Gbm {
        let tree_weights = ir.model.aggregation.tree_weights.clone();
        // Missing base score defaults to 0.0 (the additive identity).
        let base_score = ir.model.aggregation.base_score.unwrap_or(0.0);
        let trees = ir
            .model
            .trees
            .into_iter()
            .map(|tree| {
                boosted_tree_model_from_ir_parts(
                    tree_type,
                    criterion,
                    feature_preprocessing.clone(),
                    num_features,
                    options.clone(),
                    training_canaries,
                    tree,
                )
            })
            .collect::<Result<Vec<_>, _>>()?;
        return Ok(Model::GradientBoostedTrees(GradientBoostedTrees::new(
            task,
            tree_type,
            trees,
            tree_weights,
            base_score,
            ir.training_metadata.learning_rate.unwrap_or(0.1),
            ir.training_metadata.bootstrap.unwrap_or(false),
            ir.training_metadata.top_gradient_fraction.unwrap_or(0.2),
            ir.training_metadata.other_gradient_fraction.unwrap_or(0.1),
            ir.training_metadata
                .max_features
                .unwrap_or(num_features.max(1)),
            ir.training_metadata.seed,
            num_features,
            feature_preprocessing,
            deserialized_class_labels,
            training_canaries,
        )));
    }

    // Single-tree path. Guarded by the Dt tree-count check above; this
    // assumes Dt is the only algorithm besides Rf/Gbm. NOTE(review): confirm
    // TrainAlgorithm has no further variants — otherwise this expect could
    // panic on an empty tree list.
    let tree = ir
        .model
        .trees
        .into_iter()
        .next()
        .expect("validated single tree");

    single_model_from_ir_parts(
        task,
        tree_type,
        criterion,
        feature_preprocessing,
        num_features,
        options,
        training_canaries,
        deserialized_class_labels,
        tree,
    )
}

/// Rebuilds one gradient-boosting member tree from its IR payload.
///
/// Members are always reconstructed as regressor trees, regardless of the
/// ensemble-level task (no task parameter is taken here).
///
/// # Errors
///
/// Returns [`IrError::UnsupportedRepresentation`] when the payload layout
/// does not match the declared tree type, and propagates node/leaf rebuild
/// errors.
fn boosted_tree_model_from_ir_parts(
    tree_type: TreeType,
    criterion: Criterion,
    feature_preprocessing: Vec<FeaturePreprocessing>,
    num_features: usize,
    options: DecisionTreeOptions,
    training_canaries: usize,
    tree: TreeDefinition,
) -> Result<Model, IrError> {
    match (tree_type, tree) {
        // Standard node-arena trees (CART / randomized splits).
        (
            TreeType::Cart | TreeType::Randomized,
            TreeDefinition::NodeTree {
                nodes,
                root_node_id,
                ..
            },
        ) => Ok(Model::DecisionTreeRegressor(
            DecisionTreeRegressor::from_ir_parts(
                match tree_type {
                    TreeType::Cart => RegressionTreeAlgorithm::Cart,
                    TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
                    // Unreachable: the outer arm only admits Cart/Randomized.
                    _ => unreachable!(),
                },
                criterion,
                RegressionTreeStructure::Standard {
                    nodes: rebuild_regressor_nodes(nodes)?,
                    root: root_node_id,
                },
                // Training-only settings are not round-tripped; only the
                // depth/sample limits are restored from the IR options.
                RegressionTreeOptions {
                    max_depth: options.max_depth,
                    min_samples_split: options.min_samples_split,
                    min_samples_leaf: options.min_samples_leaf,
                    max_features: None,
                    random_seed: 0,
                    missing_value_strategies: Vec::new(),
                    canary_filter: crate::CanaryFilter::default(),
                },
                num_features,
                feature_preprocessing,
                training_canaries,
            ),
        )),
        // Oblivious trees: per-level splits plus index-addressed leaves.
        (TreeType::Oblivious, TreeDefinition::ObliviousLevels { levels, leaves, .. }) => {
            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
            let leaf_variances = rebuild_leaf_variances(&leaves)?;
            Ok(Model::DecisionTreeRegressor(
                DecisionTreeRegressor::from_ir_parts(
                    RegressionTreeAlgorithm::Oblivious,
                    criterion,
                    RegressionTreeStructure::Oblivious {
                        splits: rebuild_regressor_oblivious_splits(levels)?,
                        leaf_values: rebuild_regressor_leaf_values(leaves)?,
                        leaf_sample_counts,
                        leaf_variances,
                    },
                    RegressionTreeOptions {
                        max_depth: options.max_depth,
                        min_samples_split: options.min_samples_split,
                        min_samples_leaf: options.min_samples_leaf,
                        max_features: None,
                        random_seed: 0,
                        missing_value_strategies: Vec::new(),
                        canary_filter: crate::CanaryFilter::default(),
                    },
                    num_features,
                    feature_preprocessing,
                    training_canaries,
                ),
            ))
        }
        // Any other tree-type/representation pairing is unsupported here.
        (_, tree) => Err(IrError::UnsupportedRepresentation(match tree {
            TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
        })),
    }
}

767#[allow(clippy::too_many_arguments)]
768fn single_model_from_ir_parts(
769    task: Task,
770    tree_type: TreeType,
771    criterion: Criterion,
772    feature_preprocessing: Vec<FeaturePreprocessing>,
773    num_features: usize,
774    options: DecisionTreeOptions,
775    training_canaries: usize,
776    deserialized_class_labels: Option<Vec<f64>>,
777    tree: TreeDefinition,
778) -> Result<Model, IrError> {
779    match (task, tree_type, tree) {
780        (
781            Task::Classification,
782            TreeType::Id3 | TreeType::C45 | TreeType::Cart | TreeType::Randomized,
783            TreeDefinition::NodeTree {
784                nodes,
785                root_node_id,
786                ..
787            },
788        ) => {
789            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
790            let structure = ClassifierTreeStructure::Standard {
791                nodes: rebuild_classifier_nodes(nodes, &class_labels)?,
792                root: root_node_id,
793            };
794            Ok(Model::DecisionTreeClassifier(
795                DecisionTreeClassifier::from_ir_parts(
796                    match tree_type {
797                        TreeType::Id3 => DecisionTreeAlgorithm::Id3,
798                        TreeType::C45 => DecisionTreeAlgorithm::C45,
799                        TreeType::Cart => DecisionTreeAlgorithm::Cart,
800                        TreeType::Randomized => DecisionTreeAlgorithm::Randomized,
801                        TreeType::Oblivious => unreachable!(),
802                    },
803                    criterion,
804                    class_labels,
805                    structure,
806                    options,
807                    num_features,
808                    feature_preprocessing,
809                    training_canaries,
810                ),
811            ))
812        }
813        (
814            Task::Classification,
815            TreeType::Oblivious,
816            TreeDefinition::ObliviousLevels { levels, leaves, .. },
817        ) => {
818            let class_labels = deserialized_class_labels.ok_or(IrError::MissingClassLabels)?;
819            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
820            let leaf_class_counts =
821                rebuild_classifier_leaf_class_counts(&leaves, class_labels.len())?;
822            let structure = ClassifierTreeStructure::Oblivious {
823                splits: rebuild_classifier_oblivious_splits(levels)?,
824                leaf_class_indices: rebuild_classifier_leaf_indices(leaves, &class_labels)?,
825                leaf_sample_counts,
826                leaf_class_counts,
827            };
828            Ok(Model::DecisionTreeClassifier(
829                DecisionTreeClassifier::from_ir_parts(
830                    DecisionTreeAlgorithm::Oblivious,
831                    criterion,
832                    class_labels,
833                    structure,
834                    options,
835                    num_features,
836                    feature_preprocessing,
837                    training_canaries,
838                ),
839            ))
840        }
841        (
842            Task::Regression,
843            TreeType::Cart | TreeType::Randomized,
844            TreeDefinition::NodeTree {
845                nodes,
846                root_node_id,
847                ..
848            },
849        ) => Ok(Model::DecisionTreeRegressor(
850            DecisionTreeRegressor::from_ir_parts(
851                match tree_type {
852                    TreeType::Cart => RegressionTreeAlgorithm::Cart,
853                    TreeType::Randomized => RegressionTreeAlgorithm::Randomized,
854                    _ => unreachable!(),
855                },
856                criterion,
857                RegressionTreeStructure::Standard {
858                    nodes: rebuild_regressor_nodes(nodes)?,
859                    root: root_node_id,
860                },
861                RegressionTreeOptions {
862                    max_depth: options.max_depth,
863                    min_samples_split: options.min_samples_split,
864                    min_samples_leaf: options.min_samples_leaf,
865                    max_features: None,
866                    random_seed: 0,
867                    missing_value_strategies: Vec::new(),
868                    canary_filter: crate::CanaryFilter::default(),
869                },
870                num_features,
871                feature_preprocessing,
872                training_canaries,
873            ),
874        )),
875        (
876            Task::Regression,
877            TreeType::Oblivious,
878            TreeDefinition::ObliviousLevels { levels, leaves, .. },
879        ) => {
880            let leaf_sample_counts = rebuild_leaf_sample_counts(&leaves)?;
881            let leaf_variances = rebuild_leaf_variances(&leaves)?;
882            Ok(Model::DecisionTreeRegressor(
883                DecisionTreeRegressor::from_ir_parts(
884                    RegressionTreeAlgorithm::Oblivious,
885                    criterion,
886                    RegressionTreeStructure::Oblivious {
887                        splits: rebuild_regressor_oblivious_splits(levels)?,
888                        leaf_values: rebuild_regressor_leaf_values(leaves)?,
889                        leaf_sample_counts,
890                        leaf_variances,
891                    },
892                    RegressionTreeOptions {
893                        max_depth: options.max_depth,
894                        min_samples_split: options.min_samples_split,
895                        min_samples_leaf: options.min_samples_leaf,
896                        max_features: None,
897                        random_seed: 0,
898                        missing_value_strategies: Vec::new(),
899                        canary_filter: crate::CanaryFilter::default(),
900                    },
901                    num_features,
902                    feature_preprocessing,
903                    training_canaries,
904                ),
905            ))
906        }
907        (_, _, tree) => Err(IrError::UnsupportedRepresentation(match tree {
908            TreeDefinition::NodeTree { .. } => "node_tree".to_string(),
909            TreeDefinition::ObliviousLevels { .. } => "oblivious_levels".to_string(),
910        })),
911    }
912}
913
914fn validate_ir_header(ir: &ModelPackageIr) -> Result<(), IrError> {
915    if ir.ir_version != IR_VERSION {
916        return Err(IrError::UnsupportedIrVersion(ir.ir_version.clone()));
917    }
918    if ir.format_name != FORMAT_NAME {
919        return Err(IrError::UnsupportedFormatName(ir.format_name.clone()));
920    }
921    if ir.model.supports_missing {
922        return Err(IrError::InvalidInferenceOption(
923            "missing values are not supported in IR v1".to_string(),
924        ));
925    }
926    if ir.model.supports_categorical {
927        return Err(IrError::InvalidInferenceOption(
928            "categorical features are not supported in IR v1".to_string(),
929        ));
930    }
931    Ok(())
932}
933
934fn validate_inference_options(options: &InferenceOptions) -> Result<(), IrError> {
935    if options.threshold_comparison != "leq_left_gt_right" {
936        return Err(IrError::InvalidInferenceOption(format!(
937            "unsupported threshold comparison '{}'",
938            options.threshold_comparison
939        )));
940    }
941    if options.nan_policy != "not_supported" {
942        return Err(IrError::InvalidInferenceOption(format!(
943            "unsupported nan policy '{}'",
944            options.nan_policy
945        )));
946    }
947    Ok(())
948}
949
950fn parse_algorithm(value: &str) -> Result<TrainAlgorithm, IrError> {
951    match value {
952        "dt" => Ok(TrainAlgorithm::Dt),
953        "rf" => Ok(TrainAlgorithm::Rf),
954        "gbm" => Ok(TrainAlgorithm::Gbm),
955        _ => Err(IrError::UnsupportedAlgorithm(value.to_string())),
956    }
957}
958
959fn parse_task(value: &str) -> Result<Task, IrError> {
960    match value {
961        "regression" => Ok(Task::Regression),
962        "classification" => Ok(Task::Classification),
963        _ => Err(IrError::UnsupportedTask(value.to_string())),
964    }
965}
966
967fn parse_tree_type(value: &str) -> Result<TreeType, IrError> {
968    match value {
969        "id3" => Ok(TreeType::Id3),
970        "c45" => Ok(TreeType::C45),
971        "cart" => Ok(TreeType::Cart),
972        "randomized" => Ok(TreeType::Randomized),
973        "oblivious" => Ok(TreeType::Oblivious),
974        _ => Err(IrError::UnsupportedTreeType(value.to_string())),
975    }
976}
977
978fn parse_criterion(value: &str) -> Result<crate::Criterion, IrError> {
979    match value {
980        "gini" => Ok(crate::Criterion::Gini),
981        "entropy" => Ok(crate::Criterion::Entropy),
982        "mean" => Ok(crate::Criterion::Mean),
983        "median" => Ok(crate::Criterion::Median),
984        "second_order" => Ok(crate::Criterion::SecondOrder),
985        "auto" => Ok(crate::Criterion::Auto),
986        _ => Err(IrError::InvalidInferenceOption(format!(
987            "unsupported criterion '{}'",
988            value
989        ))),
990    }
991}
992
993fn tree_options(training: &TrainingMetadata) -> DecisionTreeOptions {
994    DecisionTreeOptions {
995        max_depth: training.max_depth.unwrap_or(8),
996        min_samples_split: training.min_samples_split.unwrap_or(2),
997        min_samples_leaf: training.min_samples_leaf.unwrap_or(1),
998        max_features: None,
999        random_seed: 0,
1000        missing_value_strategies: Vec::new(),
1001        canary_filter: crate::CanaryFilter::default(),
1002    }
1003}
1004
1005fn feature_preprocessing_from_ir(
1006    ir: &ModelPackageIr,
1007) -> Result<Vec<FeaturePreprocessing>, IrError> {
1008    let mut features: Vec<Option<FeaturePreprocessing>> = vec![None; ir.input_schema.feature_count];
1009
1010    for feature in &ir.preprocessing.numeric_binning.features {
1011        match feature {
1012            FeatureBinning::Numeric {
1013                feature_index,
1014                boundaries,
1015            } => {
1016                let slot = features.get_mut(*feature_index).ok_or_else(|| {
1017                    IrError::InvalidFeatureCount {
1018                        schema: ir.input_schema.feature_count,
1019                        preprocessing: feature_index + 1,
1020                    }
1021                })?;
1022                *slot = Some(FeaturePreprocessing::Numeric {
1023                    bin_boundaries: boundaries.clone(),
1024                    missing_bin: boundaries
1025                        .iter()
1026                        .map(|boundary| boundary.bin)
1027                        .max()
1028                        .map_or(0, |bin| bin.saturating_add(1)),
1029                });
1030            }
1031            FeatureBinning::Binary { feature_index } => {
1032                let slot = features.get_mut(*feature_index).ok_or_else(|| {
1033                    IrError::InvalidFeatureCount {
1034                        schema: ir.input_schema.feature_count,
1035                        preprocessing: feature_index + 1,
1036                    }
1037                })?;
1038                *slot = Some(FeaturePreprocessing::Binary);
1039            }
1040        }
1041    }
1042
1043    if features.len() != ir.input_schema.feature_count {
1044        return Err(IrError::InvalidFeatureCount {
1045            schema: ir.input_schema.feature_count,
1046            preprocessing: features.len(),
1047        });
1048    }
1049
1050    features
1051        .into_iter()
1052        .map(|feature| {
1053            feature.ok_or_else(|| {
1054                IrError::InvalidPreprocessing(
1055                    "every feature must have a preprocessing entry".to_string(),
1056                )
1057            })
1058        })
1059        .collect()
1060}
1061
1062fn classification_labels(ir: &ModelPackageIr) -> Result<Vec<f64>, IrError> {
1063    ir.output_schema
1064        .class_order
1065        .clone()
1066        .or_else(|| ir.training_metadata.class_labels.clone())
1067        .ok_or(IrError::MissingClassLabels)
1068}
1069
1070fn classifier_class_index(leaf: &LeafPayload, class_labels: &[f64]) -> Result<usize, IrError> {
1071    match leaf {
1072        LeafPayload::ClassIndex {
1073            class_index,
1074            class_value,
1075        } => {
1076            let Some(expected) = class_labels.get(*class_index) else {
1077                return Err(IrError::InvalidLeaf(format!(
1078                    "class index {} out of bounds",
1079                    class_index
1080                )));
1081            };
1082            if expected.total_cmp(class_value).is_ne() {
1083                return Err(IrError::InvalidLeaf(format!(
1084                    "class value {} does not match class order entry {}",
1085                    class_value, expected
1086                )));
1087            }
1088            Ok(*class_index)
1089        }
1090        LeafPayload::RegressionValue { .. } => Err(IrError::InvalidLeaf(
1091            "expected class_index leaf".to_string(),
1092        )),
1093    }
1094}
1095
/// Rebuild the in-memory classifier node arena from IR `NodeTreeNode`s.
///
/// Nodes may arrive in any order: each carries its own `node_id`, which is
/// its index in the rebuilt vector. `assign_node` rejects out-of-bounds and
/// duplicate ids, and `collect_nodes` rejects gaps, so the result is a
/// dense, fully populated arena.
fn rebuild_classifier_nodes(
    nodes: Vec<NodeTreeNode>,
    class_labels: &[f64],
) -> Result<Vec<ClassifierTreeNode>, IrError> {
    let mut rebuilt = vec![None; nodes.len()];
    for node in nodes {
        match node {
            NodeTreeNode::Leaf {
                node_id,
                leaf,
                stats,
                ..
            } => {
                // Validates the payload variant and that class_value matches
                // class_labels[class_index].
                let class_index = classifier_class_index(&leaf, class_labels)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::Leaf {
                        class_index,
                        sample_count: stats.sample_count,
                        // IR class counts are optional; default to all-zero.
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
            NodeTreeNode::BinaryBranch {
                node_id,
                split,
                children,
                stats,
                ..
            } => {
                let (feature_index, threshold_bin) = classifier_binary_split(split)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::BinarySplit {
                        feature_index,
                        threshold_bin,
                        // IR v1 rejects missing values, so routing is neutral.
                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
                        left_child: children.left,
                        right_child: children.right,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
            NodeTreeNode::MultiwayBranch {
                node_id,
                split,
                branches,
                unmatched_leaf,
                stats,
                ..
            } => {
                // The unmatched leaf supplies the prediction for bins with no
                // explicit branch; it is validated like any other class leaf.
                let fallback_class_index = classifier_class_index(&unmatched_leaf, class_labels)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    ClassifierTreeNode::MultiwaySplit {
                        feature_index: split.feature_index,
                        fallback_class_index,
                        branches: branches
                            .into_iter()
                            .map(|branch| (branch.bin, branch.child))
                            .collect(),
                        missing_child: None,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        class_counts: stats
                            .class_counts
                            .unwrap_or_else(|| vec![0; class_labels.len()]),
                    },
                )?;
            }
        }
    }
    collect_nodes(rebuilt)
}
1181
/// Rebuild the in-memory regression node arena from IR `NodeTreeNode`s.
///
/// Each node carries its own `node_id` (its index in the resulting vector),
/// so input order is irrelevant; `assign_node`/`collect_nodes` enforce that
/// ids are in-bounds, unique, and gap-free. Class-index leaves and multiway
/// branches are classifier-only and rejected here.
fn rebuild_regressor_nodes(nodes: Vec<NodeTreeNode>) -> Result<Vec<RegressionNode>, IrError> {
    let mut rebuilt = vec![None; nodes.len()];
    for node in nodes {
        match node {
            NodeTreeNode::Leaf {
                node_id,
                leaf: LeafPayload::RegressionValue { value },
                stats,
                ..
            } => {
                assign_node(
                    &mut rebuilt,
                    node_id,
                    RegressionNode::Leaf {
                        value,
                        sample_count: stats.sample_count,
                        variance: stats.variance,
                    },
                )?;
            }
            // Any other leaf payload (class_index) is invalid for regression.
            NodeTreeNode::Leaf { .. } => {
                return Err(IrError::InvalidLeaf(
                    "regression trees require regression_value leaves".to_string(),
                ));
            }
            NodeTreeNode::BinaryBranch {
                node_id,
                split,
                children,
                stats,
                ..
            } => {
                let (feature_index, threshold_bin) = regressor_binary_split(split)?;
                assign_node(
                    &mut rebuilt,
                    node_id,
                    RegressionNode::BinarySplit {
                        feature_index,
                        threshold_bin,
                        // IR v1 rejects missing values, so routing stays neutral.
                        missing_direction: crate::tree::shared::MissingBranchDirection::Node,
                        missing_value: 0.0,
                        left_child: children.left,
                        right_child: children.right,
                        sample_count: stats.sample_count,
                        impurity: stats.impurity.unwrap_or(0.0),
                        gain: stats.gain.unwrap_or(0.0),
                        variance: stats.variance,
                    },
                )?;
            }
            NodeTreeNode::MultiwayBranch { .. } => {
                return Err(IrError::InvalidNode(
                    "regression trees do not support multiway branches".to_string(),
                ));
            }
        }
    }
    collect_nodes(rebuilt)
}
1241
1242fn rebuild_classifier_oblivious_splits(
1243    levels: Vec<ObliviousLevel>,
1244) -> Result<Vec<ClassifierObliviousSplit>, IrError> {
1245    let mut rebuilt = Vec::with_capacity(levels.len());
1246    for level in levels {
1247        rebuilt.push(match level.split {
1248            ObliviousSplit::NumericBinThreshold {
1249                feature_index,
1250                threshold_bin,
1251                ..
1252            } => ClassifierObliviousSplit {
1253                feature_index,
1254                threshold_bin,
1255                missing_directions: Vec::new(),
1256                sample_count: level.stats.sample_count,
1257                impurity: level.stats.impurity.unwrap_or(0.0),
1258                gain: level.stats.gain.unwrap_or(0.0),
1259            },
1260            ObliviousSplit::BooleanTest { feature_index, .. } => ClassifierObliviousSplit {
1261                feature_index,
1262                threshold_bin: 0,
1263                missing_directions: Vec::new(),
1264                sample_count: level.stats.sample_count,
1265                impurity: level.stats.impurity.unwrap_or(0.0),
1266                gain: level.stats.gain.unwrap_or(0.0),
1267            },
1268        });
1269    }
1270    Ok(rebuilt)
1271}
1272
1273fn rebuild_regressor_oblivious_splits(
1274    levels: Vec<ObliviousLevel>,
1275) -> Result<Vec<RegressorObliviousSplit>, IrError> {
1276    let mut rebuilt = Vec::with_capacity(levels.len());
1277    for level in levels {
1278        rebuilt.push(match level.split {
1279            ObliviousSplit::NumericBinThreshold {
1280                feature_index,
1281                threshold_bin,
1282                ..
1283            } => RegressorObliviousSplit {
1284                feature_index,
1285                threshold_bin,
1286                sample_count: level.stats.sample_count,
1287                impurity: level.stats.impurity.unwrap_or(0.0),
1288                gain: level.stats.gain.unwrap_or(0.0),
1289            },
1290            ObliviousSplit::BooleanTest { feature_index, .. } => RegressorObliviousSplit {
1291                feature_index,
1292                threshold_bin: 0,
1293                sample_count: level.stats.sample_count,
1294                impurity: level.stats.impurity.unwrap_or(0.0),
1295                gain: level.stats.gain.unwrap_or(0.0),
1296            },
1297        });
1298    }
1299    Ok(rebuilt)
1300}
1301
1302fn rebuild_classifier_leaf_indices(
1303    leaves: Vec<IndexedLeaf>,
1304    class_labels: &[f64],
1305) -> Result<Vec<usize>, IrError> {
1306    let mut rebuilt = vec![None; leaves.len()];
1307    for indexed_leaf in leaves {
1308        let class_index = classifier_class_index(&indexed_leaf.leaf, class_labels)?;
1309        assign_node(&mut rebuilt, indexed_leaf.leaf_index, class_index)?;
1310    }
1311    collect_nodes(rebuilt)
1312}
1313
1314fn rebuild_regressor_leaf_values(leaves: Vec<IndexedLeaf>) -> Result<Vec<f64>, IrError> {
1315    let mut rebuilt = vec![None; leaves.len()];
1316    for indexed_leaf in leaves {
1317        let value = match indexed_leaf.leaf {
1318            LeafPayload::RegressionValue { value } => value,
1319            LeafPayload::ClassIndex { .. } => {
1320                return Err(IrError::InvalidLeaf(
1321                    "regression oblivious leaves require regression_value".to_string(),
1322                ));
1323            }
1324        };
1325        assign_node(&mut rebuilt, indexed_leaf.leaf_index, value)?;
1326    }
1327    collect_nodes(rebuilt)
1328}
1329
1330fn rebuild_leaf_sample_counts(leaves: &[IndexedLeaf]) -> Result<Vec<usize>, IrError> {
1331    let mut rebuilt = vec![None; leaves.len()];
1332    for indexed_leaf in leaves {
1333        assign_node(
1334            &mut rebuilt,
1335            indexed_leaf.leaf_index,
1336            indexed_leaf.stats.sample_count,
1337        )?;
1338    }
1339    collect_nodes(rebuilt)
1340}
1341
1342fn rebuild_leaf_variances(leaves: &[IndexedLeaf]) -> Result<Vec<Option<f64>>, IrError> {
1343    let mut rebuilt = vec![None; leaves.len()];
1344    for indexed_leaf in leaves {
1345        assign_node(
1346            &mut rebuilt,
1347            indexed_leaf.leaf_index,
1348            indexed_leaf.stats.variance,
1349        )?;
1350    }
1351    collect_nodes(rebuilt)
1352}
1353
1354fn rebuild_classifier_leaf_class_counts(
1355    leaves: &[IndexedLeaf],
1356    num_classes: usize,
1357) -> Result<Vec<Vec<usize>>, IrError> {
1358    let mut rebuilt = vec![None; leaves.len()];
1359    for indexed_leaf in leaves {
1360        assign_node(
1361            &mut rebuilt,
1362            indexed_leaf.leaf_index,
1363            indexed_leaf
1364                .stats
1365                .class_counts
1366                .clone()
1367                .unwrap_or_else(|| vec![0; num_classes]),
1368        )?;
1369    }
1370    collect_nodes(rebuilt)
1371}
1372
1373fn classifier_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
1374    match split {
1375        BinarySplit::NumericBinThreshold {
1376            feature_index,
1377            threshold_bin,
1378            ..
1379        } => Ok((feature_index, threshold_bin)),
1380        BinarySplit::BooleanTest { feature_index, .. } => Ok((feature_index, 0)),
1381    }
1382}
1383
/// Regressor splits share the classifier's binary-split encoding, so this
/// simply delegates; kept as a separate name for call-site clarity.
fn regressor_binary_split(split: BinarySplit) -> Result<(usize, u16), IrError> {
    classifier_binary_split(split)
}
1387
1388fn assign_node<T>(slots: &mut [Option<T>], index: usize, value: T) -> Result<(), IrError> {
1389    let Some(slot) = slots.get_mut(index) else {
1390        return Err(IrError::InvalidNode(format!(
1391            "node index {} is out of bounds",
1392            index
1393        )));
1394    };
1395    if slot.is_some() {
1396        return Err(IrError::InvalidNode(format!(
1397            "duplicate node index {}",
1398            index
1399        )));
1400    }
1401    *slot = Some(value);
1402    Ok(())
1403}
1404
1405fn collect_nodes<T>(slots: Vec<Option<T>>) -> Result<Vec<T>, IrError> {
1406    slots
1407        .into_iter()
1408        .enumerate()
1409        .map(|(index, slot)| {
1410            slot.ok_or_else(|| IrError::InvalidNode(format!("missing node index {}", index)))
1411        })
1412        .collect()
1413}
1414
1415fn input_schema(model: &Model) -> InputSchema {
1416    let features = model
1417        .feature_preprocessing()
1418        .iter()
1419        .enumerate()
1420        .map(|(feature_index, preprocessing)| {
1421            let kind = match preprocessing {
1422                FeaturePreprocessing::Numeric { .. } => InputFeatureKind::Numeric,
1423                FeaturePreprocessing::Binary => InputFeatureKind::Binary,
1424            };
1425
1426            InputFeature {
1427                index: feature_index,
1428                name: feature_name(feature_index),
1429                dtype: match kind {
1430                    InputFeatureKind::Numeric => "float64".to_string(),
1431                    InputFeatureKind::Binary => "bool".to_string(),
1432                },
1433                logical_type: match kind {
1434                    InputFeatureKind::Numeric => "numeric".to_string(),
1435                    InputFeatureKind::Binary => "boolean".to_string(),
1436                },
1437                nullable: false,
1438            }
1439        })
1440        .collect();
1441
1442    InputSchema {
1443        feature_count: model.num_features(),
1444        features,
1445        ordering: "strict_index_order".to_string(),
1446        input_tensor_layout: "row_major".to_string(),
1447        accepts_feature_names: false,
1448    }
1449}
1450
1451fn output_schema(model: &Model, class_labels: Option<Vec<f64>>) -> OutputSchema {
1452    match model.task() {
1453        Task::Regression => OutputSchema {
1454            raw_outputs: vec![OutputField {
1455                name: "value".to_string(),
1456                kind: "regression_value".to_string(),
1457                shape: Vec::new(),
1458                dtype: "float64".to_string(),
1459            }],
1460            final_outputs: vec![OutputField {
1461                name: "prediction".to_string(),
1462                kind: "value".to_string(),
1463                shape: Vec::new(),
1464                dtype: "float64".to_string(),
1465            }],
1466            class_order: None,
1467        },
1468        Task::Classification => OutputSchema {
1469            raw_outputs: vec![OutputField {
1470                name: "class_index".to_string(),
1471                kind: "class_index".to_string(),
1472                shape: Vec::new(),
1473                dtype: "uint64".to_string(),
1474            }],
1475            final_outputs: vec![OutputField {
1476                name: "predicted_class".to_string(),
1477                kind: "class_label".to_string(),
1478                shape: Vec::new(),
1479                dtype: "float64".to_string(),
1480            }],
1481            class_order: class_labels,
1482        },
1483    }
1484}
1485
1486fn preprocessing(model: &Model) -> PreprocessingSection {
1487    let features = model
1488        .feature_preprocessing()
1489        .iter()
1490        .enumerate()
1491        .map(|(feature_index, preprocessing)| match preprocessing {
1492            FeaturePreprocessing::Numeric { bin_boundaries, .. } => FeatureBinning::Numeric {
1493                feature_index,
1494                boundaries: bin_boundaries.clone(),
1495            },
1496            FeaturePreprocessing::Binary => FeatureBinning::Binary { feature_index },
1497        })
1498        .collect();
1499
1500    PreprocessingSection {
1501        included_in_model: true,
1502        numeric_binning: NumericBinning {
1503            kind: "rank_bin_128".to_string(),
1504            features,
1505        },
1506        notes: "Numeric features use serialized training-time rank bins. Binary features are serialized as booleans. Missing values and categorical encodings are not implemented in IR v1."
1507            .to_string(),
1508    }
1509}
1510
1511fn postprocessing(model: &Model, class_labels: Option<Vec<f64>>) -> PostprocessingSection {
1512    match model.task() {
1513        Task::Regression => PostprocessingSection {
1514            raw_output_kind: "regression_value".to_string(),
1515            steps: vec![PostprocessingStep::Identity],
1516        },
1517        Task::Classification => PostprocessingSection {
1518            raw_output_kind: "class_index".to_string(),
1519            steps: vec![PostprocessingStep::MapClassIndexToLabel {
1520                labels: class_labels.expect("classification IR requires class labels"),
1521            }],
1522        },
1523    }
1524}
1525
1526fn required_capabilities(model: &Model, representation: &str) -> Vec<String> {
1527    let mut capabilities = vec![
1528        representation.to_string(),
1529        "training_rank_bin_128".to_string(),
1530    ];
1531    match model.tree_type() {
1532        TreeType::Id3 | TreeType::C45 => {
1533            capabilities.push("binned_multiway_splits".to_string());
1534        }
1535        TreeType::Cart | TreeType::Randomized | TreeType::Oblivious => {
1536            capabilities.push("numeric_bin_threshold_splits".to_string());
1537        }
1538    }
1539    if model
1540        .feature_preprocessing()
1541        .iter()
1542        .any(|feature| matches!(feature, FeaturePreprocessing::Binary))
1543    {
1544        capabilities.push("boolean_features".to_string());
1545    }
1546    match model.task() {
1547        Task::Regression => capabilities.push("regression_value_leaves".to_string()),
1548        Task::Classification => capabilities.push("class_index_leaves".to_string()),
1549    }
1550    capabilities
1551}
1552
/// Serialized IR tag for a training algorithm (inverse of `parse_algorithm`).
pub(crate) fn algorithm_name(algorithm: TrainAlgorithm) -> &'static str {
    match algorithm {
        TrainAlgorithm::Dt => "dt",
        TrainAlgorithm::Rf => "rf",
        TrainAlgorithm::Gbm => "gbm",
    }
}
1560
/// Produce the IR tree definition for a single-tree model.
///
/// Ensemble variants never reach this function: callers expand
/// `RandomForest`/`GradientBoostedTrees` into their member trees first,
/// hence the `unreachable!` on those arms.
fn model_tree_definition(model: &Model) -> TreeDefinition {
    match model {
        Model::DecisionTreeClassifier(classifier) => classifier.to_ir_tree(),
        Model::DecisionTreeRegressor(regressor) => regressor.to_ir_tree(),
        Model::RandomForest(_) | Model::GradientBoostedTrees(_) => {
            unreachable!("ensemble IR expands into member trees")
        }
    }
}
1570
1571pub(crate) fn criterion_name(criterion: crate::Criterion) -> &'static str {
1572    match criterion {
1573        crate::Criterion::Auto => "auto",
1574        crate::Criterion::Gini => "gini",
1575        crate::Criterion::Entropy => "entropy",
1576        crate::Criterion::Mean => "mean",
1577        crate::Criterion::Median => "median",
1578        crate::Criterion::SecondOrder => "second_order",
1579    }
1580}
1581
/// Serialized IR tag for a task (inverse of `parse_task`).
pub(crate) fn task_name(task: Task) -> &'static str {
    match task {
        Task::Regression => "regression",
        Task::Classification => "classification",
    }
}
1588
/// Serialized IR tag for a tree type (inverse of `parse_tree_type`).
pub(crate) fn tree_type_name(tree_type: TreeType) -> &'static str {
    match tree_type {
        TreeType::Id3 => "id3",
        TreeType::C45 => "c45",
        TreeType::Cart => "cart",
        TreeType::Randomized => "randomized",
        TreeType::Oblivious => "oblivious",
    }
}
1598
/// Synthesized positional feature name: `f0`, `f1`, `f2`, ...
pub(crate) fn feature_name(feature_index: usize) -> String {
    let mut name = String::from("f");
    name.push_str(&feature_index.to_string());
    name
}
1602
1603pub(crate) fn threshold_upper_bound(
1604    preprocessing: &[FeaturePreprocessing],
1605    feature_index: usize,
1606    threshold_bin: u16,
1607) -> Option<f64> {
1608    match preprocessing.get(feature_index)? {
1609        FeaturePreprocessing::Numeric { bin_boundaries, .. } => bin_boundaries
1610            .iter()
1611            .find(|boundary| boundary.bin == threshold_bin)
1612            .map(|boundary| boundary.upper_bound),
1613        FeaturePreprocessing::Binary => None,
1614    }
1615}