// forestfire-core 0.4.0
//
// Core tree-learning algorithms for the ForestFire project.
// Documentation
use super::*;

/// Shape statistics for a single tree, as produced by `tree_structure_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeStructureSummary {
    /// Name of the tree encoding: `"node_tree"` or `"oblivious_levels"`.
    pub representation: String,
    /// Sum of `internal_node_count` and `leaf_count`.
    pub node_count: usize,
    /// Number of non-leaf nodes. For oblivious trees this is `2^depth - 1`.
    pub internal_node_count: usize,
    /// Number of leaves. For node trees this is the number of leaf depths
    /// collected while walking from the root (each multiway branch adds one
    /// extra entry); for oblivious trees it is the length of the leaf list.
    pub leaf_count: usize,
    /// Depth of the tree; always equal to `longest_path`.
    pub actual_depth: usize,
    /// Smallest leaf depth, or `0` when no leaf was found.
    pub shortest_path: usize,
    /// Largest leaf depth, or `0` when no leaf was found.
    pub longest_path: usize,
    /// Arithmetic mean of all leaf depths, or `0.0` when no leaf was found.
    pub average_path: f64,
}

/// Descriptive statistics over the prediction values stored in a tree's
/// leaves, as produced by `prediction_value_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionValueStats {
    /// Total number of prediction values collected.
    pub count: usize,
    /// Number of distinct prediction values (the length of `histogram`).
    pub unique_count: usize,
    /// Smallest prediction under `f64::total_cmp`, or `0.0` when `count` is 0.
    pub min: f64,
    /// Largest prediction under `f64::total_cmp`, or `0.0` when `count` is 0.
    pub max: f64,
    /// Arithmetic mean of the predictions, or `0.0` when `count` is 0.
    pub mean: f64,
    /// Population standard deviation (variance divided by `count`, not
    /// `count - 1`), or `0.0` when `count` is 0.
    pub std_dev: f64,
    /// One entry per distinct prediction value with its occurrence count.
    pub histogram: Vec<PredictionHistogramEntry>,
}

/// A single bucket of the prediction histogram: one distinct value and how
/// many leaves carry it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionHistogramEntry {
    /// The prediction value shared by every leaf counted in this bucket.
    pub prediction: f64,
    /// Number of leaves whose prediction equals `prediction`.
    pub count: usize,
}

/// Errors reported when inspecting the structure of a trained model.
///
/// The `*OutOfBounds` variants carry the index that was `requested` and the
/// number of items that were `available`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntrospectionError {
    /// The requested tree index does not exist in the model.
    TreeIndexOutOfBounds { requested: usize, available: usize },
    /// The requested node id does not exist in the tree. Also returned when a
    /// tree walk follows a child id that is missing from the node list.
    NodeIndexOutOfBounds { requested: usize, available: usize },
    /// The requested level index does not exist in the tree.
    LevelIndexOutOfBounds { requested: usize, available: usize },
    /// The requested leaf index does not exist in the tree.
    LeafIndexOutOfBounds { requested: usize, available: usize },
    /// Node-level inspection was requested on a tree stored as oblivious levels.
    NotANodeTree,
    /// Level/leaf inspection was requested on a tree stored as a node tree.
    NotAnObliviousTree,
}

impl Display for IntrospectionError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            IntrospectionError::TreeIndexOutOfBounds {
                requested,
                available,
            } => write!(
                f,
                "Tree index {} is out of bounds for model with {} trees.",
                requested, available
            ),
            IntrospectionError::NodeIndexOutOfBounds {
                requested,
                available,
            } => write!(
                f,
                "Node index {} is out of bounds for tree with {} nodes.",
                requested, available
            ),
            IntrospectionError::LevelIndexOutOfBounds {
                requested,
                available,
            } => write!(
                f,
                "Level index {} is out of bounds for tree with {} levels.",
                requested, available
            ),
            IntrospectionError::LeafIndexOutOfBounds {
                requested,
                available,
            } => write!(
                f,
                "Leaf index {} is out of bounds for tree with {} leaves.",
                requested, available
            ),
            IntrospectionError::NotANodeTree => write!(
                f,
                "This tree uses oblivious-level representation; inspect levels or leaves instead."
            ),
            IntrospectionError::NotAnObliviousTree => write!(
                f,
                "This tree uses node-tree representation; inspect nodes instead."
            ),
        }
    }
}

// Marker impl: `IntrospectionError` wraps no underlying cause, so the default
// `Error` methods are sufficient.
impl Error for IntrospectionError {}

/// Computes node counts and root-to-leaf path statistics for one tree.
///
/// For a `NodeTree`, the nodes are indexed by id and walked from
/// `root_node_id`; leaf statistics come from the depths gathered by
/// `collect_leaf_depths`, and `internal_node_count` is the number of non-leaf
/// entries in the node list. For `ObliviousLevels`, every path has length
/// `depth` and the tree has `2^depth - 1` internal nodes.
///
/// # Errors
///
/// Returns `IntrospectionError::NodeIndexOutOfBounds` when the root id or any
/// child id reached during the walk is missing from the node list.
pub(crate) fn tree_structure_summary(
    tree: ir::TreeDefinition,
) -> Result<TreeStructureSummary, IntrospectionError> {
    match tree {
        ir::TreeDefinition::NodeTree {
            root_node_id,
            nodes,
            ..
        } => {
            // Count branches while `nodes` is still intact; it is consumed
            // below so the nodes can be moved into the map instead of cloned.
            let internal_node_count = nodes
                .iter()
                .filter(|node| !matches!(node, ir::NodeTreeNode::Leaf { .. }))
                .count();
            let node_map = nodes
                .into_iter()
                .map(|node| {
                    let node_id = match &node {
                        ir::NodeTreeNode::Leaf { node_id, .. }
                        | ir::NodeTreeNode::BinaryBranch { node_id, .. }
                        | ir::NodeTreeNode::MultiwayBranch { node_id, .. } => *node_id,
                    };
                    (node_id, node)
                })
                .collect::<BTreeMap<_, _>>();

            let mut leaf_depths = Vec::new();
            collect_leaf_depths(&node_map, root_node_id, &mut leaf_depths)?;

            let leaf_count = leaf_depths.len();
            let shortest_path = *leaf_depths.iter().min().unwrap_or(&0);
            let longest_path = *leaf_depths.iter().max().unwrap_or(&0);
            let average_path = if leaf_depths.is_empty() {
                0.0
            } else {
                leaf_depths.iter().sum::<usize>() as f64 / leaf_depths.len() as f64
            };
            Ok(TreeStructureSummary {
                representation: "node_tree".to_string(),
                node_count: internal_node_count + leaf_count,
                internal_node_count,
                leaf_count,
                actual_depth: longest_path,
                shortest_path,
                longest_path,
                average_path,
            })
        }
        ir::TreeDefinition::ObliviousLevels { depth, leaves, .. } => Ok(TreeStructureSummary {
            representation: "oblivious_levels".to_string(),
            node_count: ((1usize << depth) - 1) + leaves.len(),
            internal_node_count: (1usize << depth) - 1,
            leaf_count: leaves.len(),
            actual_depth: depth,
            shortest_path: depth,
            longest_path: depth,
            average_path: depth as f64,
        }),
    }
}

/// Walks the subtree rooted at `node_id` depth-first and appends one depth
/// value to `output` for every leaf encountered.
///
/// * A `Leaf` appends its own recorded `depth`.
/// * A `BinaryBranch` appends nothing itself and descends into the left
///   child, then the right child.
/// * A `MultiwayBranch` appends `depth + 1` (presumably accounting for its
///   `unmatched_leaf` fallback) and then descends into each branch child in
///   order.
///
/// # Errors
///
/// Returns `IntrospectionError::NodeIndexOutOfBounds` as soon as an id is not
/// present in `nodes`. Depths appended before the failure remain in `output`.
fn collect_leaf_depths(
    nodes: &BTreeMap<usize, ir::NodeTreeNode>,
    node_id: usize,
    output: &mut Vec<usize>,
) -> Result<(), IntrospectionError> {
    let current = nodes
        .get(&node_id)
        .ok_or_else(|| IntrospectionError::NodeIndexOutOfBounds {
            requested: node_id,
            available: nodes.len(),
        })?;

    match current {
        ir::NodeTreeNode::Leaf { depth, .. } => {
            output.push(*depth);
            Ok(())
        }
        ir::NodeTreeNode::BinaryBranch { children, .. } => {
            collect_leaf_depths(nodes, children.left, output)?;
            collect_leaf_depths(nodes, children.right, output)
        }
        ir::NodeTreeNode::MultiwayBranch {
            depth, branches, ..
        } => {
            // Emitted before visiting the children so the output order stays
            // pre-order.
            output.push(depth + 1);
            for branch in branches {
                collect_leaf_depths(nodes, branch.child, output)?;
            }
            Ok(())
        }
    }
}

pub(crate) fn prediction_value_stats(
    tree: ir::TreeDefinition,
) -> Result<PredictionValueStats, IntrospectionError> {
    let predictions = match tree {
        ir::TreeDefinition::NodeTree { nodes, .. } => nodes
            .into_iter()
            .flat_map(|node| match node {
                ir::NodeTreeNode::Leaf { leaf, .. } => vec![leaf_payload_value(&leaf)],
                ir::NodeTreeNode::MultiwayBranch { unmatched_leaf, .. } => {
                    vec![leaf_payload_value(&unmatched_leaf)]
                }
                ir::NodeTreeNode::BinaryBranch { .. } => Vec::new(),
            })
            .collect::<Vec<_>>(),
        ir::TreeDefinition::ObliviousLevels { leaves, .. } => leaves
            .into_iter()
            .map(|leaf| leaf_payload_value(&leaf.leaf))
            .collect::<Vec<_>>(),
    };

    let count = predictions.len();
    let min = predictions
        .iter()
        .copied()
        .min_by(f64::total_cmp)
        .unwrap_or(0.0);
    let max = predictions
        .iter()
        .copied()
        .max_by(f64::total_cmp)
        .unwrap_or(0.0);
    let mean = if count == 0 {
        0.0
    } else {
        predictions.iter().sum::<f64>() / count as f64
    };
    let std_dev = if count == 0 {
        0.0
    } else {
        let variance = predictions
            .iter()
            .map(|value| (*value - mean).powi(2))
            .sum::<f64>()
            / count as f64;
        variance.sqrt()
    };
    let mut histogram = BTreeMap::<String, usize>::new();
    for prediction in &predictions {
        *histogram.entry(prediction.to_string()).or_insert(0) += 1;
    }
    let histogram = histogram
        .into_iter()
        .map(|(prediction, count)| PredictionHistogramEntry {
            prediction: prediction
                .parse::<f64>()
                .expect("histogram keys are numeric"),
            count,
        })
        .collect::<Vec<_>>();

    Ok(PredictionValueStats {
        count,
        unique_count: histogram.len(),
        min,
        max,
        mean,
        std_dev,
        histogram,
    })
}

/// Extracts the numeric prediction carried by a leaf payload: the regression
/// value itself, or the `class_value` of a classification leaf.
fn leaf_payload_value(leaf: &ir::LeafPayload) -> f64 {
    // Both bound fields are `f64`, so matching on the dereferenced payload
    // copies the number out directly.
    match *leaf {
        ir::LeafPayload::ClassIndex { class_value, .. } => class_value,
        ir::LeafPayload::RegressionValue { value } => value,
    }
}