Skip to main content

forestfire_core/
introspection.rs

1use super::*;
2
/// Shape summary of a single tree, as produced by `tree_structure_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeStructureSummary {
    /// Storage layout of the tree: `"node_tree"` or `"oblivious_levels"`.
    pub representation: String,
    /// Total node count: `internal_node_count + leaf_count`.
    pub node_count: usize,
    /// Number of non-leaf nodes. For node trees this counts every non-`Leaf` entry in the
    /// node list; for oblivious trees it is `2^depth - 1`, the internal nodes of the
    /// equivalent complete binary tree.
    pub internal_node_count: usize,
    /// Number of leaves. For node trees this is the number of leaf depths collected from the
    /// root, where each multiway branch's unmatched leaf counts as one; for oblivious trees
    /// it is the number of stored leaves.
    pub leaf_count: usize,
    /// Tree depth: the largest leaf depth for node trees, the declared `depth` for
    /// oblivious trees.
    pub actual_depth: usize,
    /// Smallest leaf depth for node trees (0 if no leaves were collected); the declared
    /// `depth` for oblivious trees.
    pub shortest_path: usize,
    /// Largest leaf depth for node trees (0 if no leaves were collected); the declared
    /// `depth` for oblivious trees.
    pub longest_path: usize,
    /// Mean leaf depth for node trees (0.0 if no leaves were collected); the declared
    /// `depth` for oblivious trees.
    pub average_path: f64,
}
14
/// Summary statistics over the prediction values stored in a tree's leaves, as produced by
/// `prediction_value_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionValueStats {
    /// Number of prediction values collected.
    pub count: usize,
    /// Number of distinct prediction values (equal to the number of histogram entries).
    pub unique_count: usize,
    /// Smallest value under `f64::total_cmp` ordering; 0.0 when `count == 0`.
    pub min: f64,
    /// Largest value under `f64::total_cmp` ordering; 0.0 when `count == 0`.
    pub max: f64,
    /// Arithmetic mean; 0.0 when `count == 0`.
    pub mean: f64,
    /// Population standard deviation (the variance divides by `count`); 0.0 when
    /// `count == 0`.
    pub std_dev: f64,
    /// One entry per distinct prediction value, with the number of times it occurs.
    pub histogram: Vec<PredictionHistogramEntry>,
}
25
/// One bucket of a prediction histogram.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionHistogramEntry {
    /// The distinct prediction value this bucket represents.
    pub prediction: f64,
    /// How many collected prediction values fell into this bucket.
    pub count: usize,
}
31
/// Errors returned when inspecting a model's trees.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntrospectionError {
    /// A requested tree index does not exist; `available` is the model's tree count.
    TreeIndexOutOfBounds { requested: usize, available: usize },
    /// A node id could not be resolved; `available` is the number of nodes in the tree.
    /// Also returned when a node tree references a node id (including its root) that is
    /// absent from its node list — ids are looked up by key, so they need not be contiguous.
    NodeIndexOutOfBounds { requested: usize, available: usize },
    /// A requested level index does not exist; `available` is the tree's level count.
    LevelIndexOutOfBounds { requested: usize, available: usize },
    /// A requested leaf index does not exist; `available` is the tree's leaf count.
    LeafIndexOutOfBounds { requested: usize, available: usize },
    /// A node-level operation was requested on a tree stored as oblivious levels.
    NotANodeTree,
    /// A level/leaf operation was requested on a tree stored as a node tree.
    NotAnObliviousTree,
}
41
42impl Display for IntrospectionError {
43    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44        match self {
45            IntrospectionError::TreeIndexOutOfBounds {
46                requested,
47                available,
48            } => write!(
49                f,
50                "Tree index {} is out of bounds for model with {} trees.",
51                requested, available
52            ),
53            IntrospectionError::NodeIndexOutOfBounds {
54                requested,
55                available,
56            } => write!(
57                f,
58                "Node index {} is out of bounds for tree with {} nodes.",
59                requested, available
60            ),
61            IntrospectionError::LevelIndexOutOfBounds {
62                requested,
63                available,
64            } => write!(
65                f,
66                "Level index {} is out of bounds for tree with {} levels.",
67                requested, available
68            ),
69            IntrospectionError::LeafIndexOutOfBounds {
70                requested,
71                available,
72            } => write!(
73                f,
74                "Leaf index {} is out of bounds for tree with {} leaves.",
75                requested, available
76            ),
77            IntrospectionError::NotANodeTree => write!(
78                f,
79                "This tree uses oblivious-level representation; inspect levels or leaves instead."
80            ),
81            IntrospectionError::NotAnObliviousTree => write!(
82                f,
83                "This tree uses node-tree representation; inspect nodes instead."
84            ),
85        }
86    }
87}
88
// All `Error` methods use their defaults: no variant wraps an underlying cause, and the
// message text comes from the `Display` impl.
impl Error for IntrospectionError {}
90
91pub(crate) fn tree_structure_summary(
92    tree: ir::TreeDefinition,
93) -> Result<TreeStructureSummary, IntrospectionError> {
94    match tree {
95        ir::TreeDefinition::NodeTree {
96            root_node_id,
97            nodes,
98            ..
99        } => {
100            let node_map = nodes
101                .iter()
102                .cloned()
103                .map(|node| match &node {
104                    ir::NodeTreeNode::Leaf { node_id, .. }
105                    | ir::NodeTreeNode::BinaryBranch { node_id, .. }
106                    | ir::NodeTreeNode::MultiwayBranch { node_id, .. } => (*node_id, node),
107                })
108                .collect::<BTreeMap<_, _>>();
109            let mut leaf_depths = Vec::new();
110            collect_leaf_depths(&node_map, root_node_id, &mut leaf_depths)?;
111            let internal_node_count = nodes
112                .iter()
113                .filter(|node| !matches!(node, ir::NodeTreeNode::Leaf { .. }))
114                .count();
115            let leaf_count = leaf_depths.len();
116            let shortest_path = *leaf_depths.iter().min().unwrap_or(&0);
117            let longest_path = *leaf_depths.iter().max().unwrap_or(&0);
118            let average_path = if leaf_depths.is_empty() {
119                0.0
120            } else {
121                leaf_depths.iter().sum::<usize>() as f64 / leaf_depths.len() as f64
122            };
123            Ok(TreeStructureSummary {
124                representation: "node_tree".to_string(),
125                node_count: internal_node_count + leaf_count,
126                internal_node_count,
127                leaf_count,
128                actual_depth: longest_path,
129                shortest_path,
130                longest_path,
131                average_path,
132            })
133        }
134        ir::TreeDefinition::ObliviousLevels { depth, leaves, .. } => Ok(TreeStructureSummary {
135            representation: "oblivious_levels".to_string(),
136            node_count: ((1usize << depth) - 1) + leaves.len(),
137            internal_node_count: (1usize << depth) - 1,
138            leaf_count: leaves.len(),
139            actual_depth: depth,
140            shortest_path: depth,
141            longest_path: depth,
142            average_path: depth as f64,
143        }),
144    }
145}
146
147fn collect_leaf_depths(
148    nodes: &BTreeMap<usize, ir::NodeTreeNode>,
149    node_id: usize,
150    output: &mut Vec<usize>,
151) -> Result<(), IntrospectionError> {
152    match nodes
153        .get(&node_id)
154        .ok_or(IntrospectionError::NodeIndexOutOfBounds {
155            requested: node_id,
156            available: nodes.len(),
157        })? {
158        ir::NodeTreeNode::Leaf { depth, .. } => output.push(*depth),
159        ir::NodeTreeNode::BinaryBranch {
160            depth: _, children, ..
161        } => {
162            collect_leaf_depths(nodes, children.left, output)?;
163            collect_leaf_depths(nodes, children.right, output)?;
164        }
165        ir::NodeTreeNode::MultiwayBranch {
166            depth,
167            branches,
168            unmatched_leaf: _,
169            ..
170        } => {
171            output.push(depth + 1);
172            for branch in branches {
173                collect_leaf_depths(nodes, branch.child, output)?;
174            }
175        }
176    }
177    Ok(())
178}
179
180pub(crate) fn prediction_value_stats(
181    tree: ir::TreeDefinition,
182) -> Result<PredictionValueStats, IntrospectionError> {
183    let predictions = match tree {
184        ir::TreeDefinition::NodeTree { nodes, .. } => nodes
185            .into_iter()
186            .flat_map(|node| match node {
187                ir::NodeTreeNode::Leaf { leaf, .. } => vec![leaf_payload_value(&leaf)],
188                ir::NodeTreeNode::MultiwayBranch { unmatched_leaf, .. } => {
189                    vec![leaf_payload_value(&unmatched_leaf)]
190                }
191                ir::NodeTreeNode::BinaryBranch { .. } => Vec::new(),
192            })
193            .collect::<Vec<_>>(),
194        ir::TreeDefinition::ObliviousLevels { leaves, .. } => leaves
195            .into_iter()
196            .map(|leaf| leaf_payload_value(&leaf.leaf))
197            .collect::<Vec<_>>(),
198    };
199
200    let count = predictions.len();
201    let min = predictions
202        .iter()
203        .copied()
204        .min_by(f64::total_cmp)
205        .unwrap_or(0.0);
206    let max = predictions
207        .iter()
208        .copied()
209        .max_by(f64::total_cmp)
210        .unwrap_or(0.0);
211    let mean = if count == 0 {
212        0.0
213    } else {
214        predictions.iter().sum::<f64>() / count as f64
215    };
216    let std_dev = if count == 0 {
217        0.0
218    } else {
219        let variance = predictions
220            .iter()
221            .map(|value| (*value - mean).powi(2))
222            .sum::<f64>()
223            / count as f64;
224        variance.sqrt()
225    };
226    let mut histogram = BTreeMap::<String, usize>::new();
227    for prediction in &predictions {
228        *histogram.entry(prediction.to_string()).or_insert(0) += 1;
229    }
230    let histogram = histogram
231        .into_iter()
232        .map(|(prediction, count)| PredictionHistogramEntry {
233            prediction: prediction
234                .parse::<f64>()
235                .expect("histogram keys are numeric"),
236            count,
237        })
238        .collect::<Vec<_>>();
239
240    Ok(PredictionValueStats {
241        count,
242        unique_count: histogram.len(),
243        min,
244        max,
245        mean,
246        std_dev,
247        histogram,
248    })
249}
250
251fn leaf_payload_value(leaf: &ir::LeafPayload) -> f64 {
252    match leaf {
253        ir::LeafPayload::RegressionValue { value } => *value,
254        ir::LeafPayload::ClassIndex { class_value, .. } => *class_value,
255    }
256}