use super::*;
/// Turns raw per-class counts into a probability distribution.
///
/// Entry `i` of the result is `class_counts[i] / sum(class_counts)`, computed in `f64`.
/// When the counts sum to zero (all zeros, or an empty slice), the result is a vector of
/// `0.0` with the same length as the input rather than a vector of NaNs.
pub(super) fn normalized_class_probabilities(class_counts: &[usize]) -> Vec<f64> {
    let total: usize = class_counts.iter().sum();
    match total {
        // Avoid 0/0 = NaN: an empty distribution is reported as all zeros.
        0 => vec![0.0; class_counts.len()],
        _ => {
            let denominator = total as f64;
            class_counts
                .iter()
                .map(|&count| count as f64 / denominator)
                .collect()
        }
    }
}
/// Computes the depth of every node reachable from `root`, indexed by node id.
///
/// The root is assigned depth 0 and each child is one deeper than its parent. The returned
/// vector has one slot per entry of `nodes`; slots for nodes not reachable from `root` are
/// left at 0.
///
/// # Panics
/// Panics if `root` or any child index encountered is out of bounds for `nodes`.
pub(super) fn standard_node_depths(nodes: &[TreeNode], root: usize) -> Vec<usize> {
    let mut depth_by_node = vec![0_usize; nodes.len()];
    populate_depths(nodes, root, 0, &mut depth_by_node);
    depth_by_node
}
/// Writes the depth of `node_id` and of every node below it into `depths`.
///
/// `node_id` is assigned `depth`, and every descendant is assigned `depth` plus its distance
/// from `node_id`. Leaves have no children; binary splits contribute `left_child` and
/// `right_child`; multiway splits contribute the child of every branch.
///
/// The walk uses an explicit work stack instead of recursion, so very deep trees (for
/// example, chain-shaped ones) cannot overflow the call stack. Nodes are still visited in
/// pre-order: a binary split's left subtree before its right, and multiway branches in
/// iteration order. This matches the write order of a recursive walk.
///
/// # Panics
/// Panics if `node_id` or any child index is out of bounds for `nodes` or `depths`.
fn populate_depths(nodes: &[TreeNode], node_id: usize, depth: usize, depths: &mut [usize]) {
    let mut pending: Vec<(usize, usize)> = vec![(node_id, depth)];
    while let Some((current, current_depth)) = pending.pop() {
        depths[current] = current_depth;
        let child_depth = current_depth + 1;
        match &nodes[current] {
            TreeNode::Leaf { .. } => {}
            TreeNode::BinarySplit {
                left_child,
                right_child,
                ..
            } => {
                // LIFO stack: push right first so the left subtree is processed first.
                pending.push((*right_child, child_depth));
                pending.push((*left_child, child_depth));
            }
            TreeNode::MultiwaySplit { branches, .. } => {
                // Push in iteration order, then reverse just the newly pushed tail so the
                // first branch is popped first (same order as the recursive walk).
                let first_new = pending.len();
                for (_, child) in branches {
                    pending.push((*child, child_depth));
                }
                pending[first_new..].reverse();
            }
        }
    }
}
/// Builds the IR description of a binary split on `feature_index`.
///
/// * If the feature's preprocessing entry is [`FeaturePreprocessing::Binary`], the result is
///   a [`BinarySplit::BooleanTest`] whose false outcome is labelled `"left"` and whose true
///   outcome is labelled `"right"`.
/// * Otherwise the result is a [`BinarySplit::NumericBinThreshold`]. This covers a
///   [`FeaturePreprocessing::Numeric`] feature or an index with no preprocessing entry. The
///   split records a `<=` comparison against `threshold_bin` with a `uint16` comparison
///   dtype, plus the value `threshold_upper_bound` returns for that bin.
///
/// `_missing_direction` is accepted but not read by this function.
/// NOTE(review): missing-value routing is therefore not captured in the returned IR —
/// confirm this is intentional.
pub(super) fn binary_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    _missing_direction: MissingBranchDirection,
    preprocessing: &[FeaturePreprocessing],
) -> BinarySplit {
    let feature_kind = preprocessing.get(feature_index);
    let name = feature_name(feature_index);
    match feature_kind {
        Some(FeaturePreprocessing::Binary) => BinarySplit::BooleanTest {
            feature_index,
            feature_name: name,
            false_child_semantics: "left".to_string(),
            true_child_semantics: "right".to_string(),
        },
        Some(FeaturePreprocessing::Numeric { .. }) | None => {
            let upper_bound = threshold_upper_bound(preprocessing, feature_index, threshold_bin);
            BinarySplit::NumericBinThreshold {
                feature_index,
                feature_name: name,
                operator: "<=".to_string(),
                threshold_bin,
                threshold_upper_bound: upper_bound,
                comparison_dtype: "uint16".to_string(),
            }
        }
    }
}
/// Builds the IR description of an oblivious-tree split on `feature_index`.
///
/// Each variant records which bit the split emits for each outcome of its test:
/// * If the feature's preprocessing entry is [`FeaturePreprocessing::Binary`], the result is
///   an [`IrObliviousSplit::BooleanTest`] emitting bit `0` when the feature is false and bit
///   `1` when it is true.
/// * Otherwise the result is an [`IrObliviousSplit::NumericBinThreshold`]. This covers a
///   [`FeaturePreprocessing::Numeric`] feature or an index with no preprocessing entry. The
///   split records a `<=` comparison against `threshold_bin` with a `uint16` comparison
///   dtype and the value `threshold_upper_bound` returns for that bin. It emits bit `0` when
///   the comparison holds and bit `1` when it does not.
pub(super) fn oblivious_split_ir(
    feature_index: usize,
    threshold_bin: u16,
    preprocessing: &[FeaturePreprocessing],
) -> IrObliviousSplit {
    let feature_kind = preprocessing.get(feature_index);
    let name = feature_name(feature_index);
    match feature_kind {
        Some(FeaturePreprocessing::Binary) => IrObliviousSplit::BooleanTest {
            feature_index,
            feature_name: name,
            bit_when_false: 0,
            bit_when_true: 1,
        },
        Some(FeaturePreprocessing::Numeric { .. }) | None => {
            let upper_bound = threshold_upper_bound(preprocessing, feature_index, threshold_bin);
            IrObliviousSplit::NumericBinThreshold {
                feature_index,
                feature_name: name,
                operator: "<=".to_string(),
                threshold_bin,
                threshold_upper_bound: upper_bound,
                comparison_dtype: "uint16".to_string(),
                // A satisfied `<=` test maps to bit 0, matching the "left" side of binary splits.
                bit_when_true: 0,
                bit_when_false: 1,
            }
        }
    }
}