use super::*;
use std::collections::BTreeSet;
/// Returns the feature projection for `model`.
///
/// The projection is the ascending, de-duplicated list of global feature
/// indices that any split in the model references, exactly as produced by
/// `model_used_feature_indices`. A feature's position in this list is its
/// local (projected) index.
pub(crate) fn build_feature_projection(model: &Model) -> Vec<usize> {
    let projection = model_used_feature_indices(model);
    projection
}
/// Inverts a feature projection into a dense global-to-local lookup table.
///
/// The returned vector has `num_features` slots. Slot `g` holds the position
/// of global feature `g` inside `projection`, or `usize::MAX` when `g` does not
/// appear in `projection`. If a feature appears more than once, its last
/// position wins.
///
/// # Panics
///
/// Panics if any entry of `projection` is `>= num_features`.
pub(crate) fn build_feature_index_map(num_features: usize, projection: &[usize]) -> Vec<usize> {
    // `usize::MAX` marks "not projected"; every projected slot is overwritten below.
    let mut lookup: Vec<usize> = std::iter::repeat(usize::MAX).take(num_features).collect();
    for (position, &global) in projection.iter().enumerate() {
        lookup[global] = position;
    }
    lookup
}
/// Translates a global feature index into its local (projected) index.
///
/// Returns whatever `feature_index_map` stores for `feature_index`. For a map
/// built by `build_feature_index_map`, that is `usize::MAX` when the feature is
/// not projected. No check for that value is performed here.
///
/// # Panics
///
/// Panics if `feature_index >= feature_index_map.len()`.
pub(crate) fn remap_feature_index(feature_index: usize, feature_index_map: &[usize]) -> usize {
    let table: &[usize] = feature_index_map;
    table[feature_index]
}
/// Returns the indices of `trees` in a deterministic canonical order.
///
/// Trees are ordered by, in priority order:
/// 1. the feature of their first tree's root split, with trees that have no
///    such split (`None`) sorting last via `usize::MAX`;
/// 2. the number of distinct features they use;
/// 3. the sorted list of those features, compared lexicographically;
/// 4. their original position in `trees`, which makes the order total.
///
/// NOTE(review): `model_used_feature_indices` and `tree_primary_feature` each
/// call `to_ir()` on the model, so every tree is lowered to IR twice here.
pub(crate) fn ordered_ensemble_indices(trees: &[Model]) -> Vec<usize> {
    // Tuple fields are laid out in sort-priority order so that the derived
    // lexicographic `Ord` on tuples gives exactly the ordering described above.
    let mut sort_keys: Vec<(usize, usize, Vec<usize>, usize)> = Vec::with_capacity(trees.len());
    for (position, member) in trees.iter().enumerate() {
        let features = model_used_feature_indices(member);
        let root_feature = tree_primary_feature(member).unwrap_or(usize::MAX);
        sort_keys.push((root_feature, features.len(), features, position));
    }
    sort_keys.sort();
    sort_keys
        .into_iter()
        .map(|(_, _, _, position)| position)
        .collect()
}
/// Collects every feature index referenced by any split in `model`.
///
/// The model is lowered with `to_ir()`, and every tree definition in the IR is
/// scanned. The result is sorted ascending and contains no duplicates.
pub(crate) fn model_used_feature_indices(model: &Model) -> Vec<usize> {
    let lowered = model.to_ir();
    let features = lowered
        .model
        .trees
        .iter()
        .fold(BTreeSet::new(), |mut seen, definition| {
            collect_tree_used_features(definition, &mut seen);
            seen
        });
    features.into_iter().collect()
}
pub(crate) fn tree_primary_feature(model: &Model) -> Option<usize> {
let ir = model.to_ir();
ir.model
.trees
.first()
.and_then(tree_definition_primary_feature)
}
/// Adds to `used` every feature index that a split in `tree` references.
///
/// Leaf nodes contribute nothing. Binary and multiway branches contribute the
/// feature of their split. Oblivious trees contribute the split feature of every
/// level. Entries already in `used` are kept.
fn collect_tree_used_features(tree: &ir::TreeDefinition, used: &mut BTreeSet<usize>) {
    match tree {
        ir::TreeDefinition::NodeTree { nodes, .. } => {
            used.extend(nodes.iter().filter_map(|node| match node {
                ir::NodeTreeNode::Leaf { .. } => None,
                ir::NodeTreeNode::BinaryBranch { split, .. } => {
                    Some(binary_split_feature_index(split))
                }
                ir::NodeTreeNode::MultiwayBranch { split, .. } => Some(split.feature_index),
            }));
        }
        ir::TreeDefinition::ObliviousLevels { levels, .. } => {
            used.extend(
                levels
                    .iter()
                    .map(|level| oblivious_split_feature_index(&level.split)),
            );
        }
    }
}
/// Returns the feature tested by the root split of a single tree definition.
///
/// For a node tree, the result comes from the first branch node whose id equals
/// `root_node_id`. The result is `None` if the root is a leaf or if no branch
/// carries the root id. For an oblivious tree, the result is the split feature
/// of the first level, or `None` when there are no levels.
fn tree_definition_primary_feature(tree: &ir::TreeDefinition) -> Option<usize> {
    match tree {
        ir::TreeDefinition::NodeTree {
            root_node_id,
            nodes,
            ..
        } => nodes.iter().find_map(|node| match node {
            // A leaf never yields a feature, whether or not it is the root.
            ir::NodeTreeNode::Leaf { .. } => None,
            ir::NodeTreeNode::BinaryBranch { node_id, split, .. } => {
                (node_id == root_node_id).then(|| binary_split_feature_index(split))
            }
            ir::NodeTreeNode::MultiwayBranch { node_id, split, .. } => {
                (node_id == root_node_id).then(|| split.feature_index)
            }
        }),
        ir::TreeDefinition::ObliviousLevels { levels, .. } => {
            let first_level = levels.first()?;
            Some(oblivious_split_feature_index(&first_level.split))
        }
    }
}
/// Returns the feature index tested by a binary split, whatever its variant.
fn binary_split_feature_index(split: &ir::BinarySplit) -> usize {
    match *split {
        ir::BinarySplit::NumericBinThreshold { feature_index, .. } => feature_index,
        ir::BinarySplit::BooleanTest { feature_index, .. } => feature_index,
    }
}
/// Returns the feature index tested by an oblivious-level split, whatever its variant.
fn oblivious_split_feature_index(split: &ir::ObliviousSplit) -> usize {
    match *split {
        ir::ObliviousSplit::NumericBinThreshold { feature_index, .. } => feature_index,
        ir::ObliviousSplit::BooleanTest { feature_index, .. } => feature_index,
    }
}