1use super::*;
2
/// Shape statistics for a single tree: node counts and root-to-leaf depth figures.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeStructureSummary {
    /// Name of the tree encoding; this file emits `"node_tree"` or `"oblivious_levels"`.
    pub representation: String,
    /// Total node count, reported as `internal_node_count + leaf_count`.
    pub node_count: usize,
    /// Number of non-leaf (branching) nodes.
    pub internal_node_count: usize,
    /// Number of leaves counted for the tree.
    pub leaf_count: usize,
    /// Depth of the tree; filled with the same value as `longest_path`.
    pub actual_depth: usize,
    /// Smallest leaf depth (`0` when no leaf was counted).
    pub shortest_path: usize,
    /// Largest leaf depth (`0` when no leaf was counted).
    pub longest_path: usize,
    /// Arithmetic mean of the leaf depths (`0.0` when no leaf was counted).
    pub average_path: f64,
}
14
/// Descriptive statistics over the prediction values stored in a tree's leaves.
///
/// `min`, `max`, `mean` and `std_dev` are reported as `0.0` when the tree has no
/// prediction values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionValueStats {
    /// Number of prediction values that were collected.
    pub count: usize,
    /// Number of distinct prediction values (the length of `histogram`).
    pub unique_count: usize,
    /// Smallest prediction value.
    pub min: f64,
    /// Largest prediction value.
    pub max: f64,
    /// Arithmetic mean of the prediction values.
    pub mean: f64,
    /// Population standard deviation (the variance is divided by `count`).
    pub std_dev: f64,
    /// One entry per distinct prediction value with its number of occurrences.
    pub histogram: Vec<PredictionHistogramEntry>,
}
25
/// A single histogram bucket: one distinct prediction value and how often it occurs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionHistogramEntry {
    /// The prediction value this bucket represents.
    pub prediction: f64,
    /// Number of collected predictions equal to `prediction`.
    pub count: usize,
}
31
/// Errors reported by the model introspection helpers.
///
/// In the out-of-bounds variants `requested` is the index that was asked for and
/// `available` is the number of items that actually exist.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntrospectionError {
    /// The tree index does not exist in the model.
    TreeIndexOutOfBounds { requested: usize, available: usize },
    /// The node index (or node id) does not exist in the tree.
    NodeIndexOutOfBounds { requested: usize, available: usize },
    /// The level index does not exist in the tree.
    LevelIndexOutOfBounds { requested: usize, available: usize },
    /// The leaf index does not exist in the tree.
    LeafIndexOutOfBounds { requested: usize, available: usize },
    /// Node inspection was requested on a tree stored in oblivious-level representation.
    NotANodeTree,
    /// Level/leaf inspection was requested on a tree stored in node-tree representation.
    NotAnObliviousTree,
}
41
42impl Display for IntrospectionError {
43 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44 match self {
45 IntrospectionError::TreeIndexOutOfBounds {
46 requested,
47 available,
48 } => write!(
49 f,
50 "Tree index {} is out of bounds for model with {} trees.",
51 requested, available
52 ),
53 IntrospectionError::NodeIndexOutOfBounds {
54 requested,
55 available,
56 } => write!(
57 f,
58 "Node index {} is out of bounds for tree with {} nodes.",
59 requested, available
60 ),
61 IntrospectionError::LevelIndexOutOfBounds {
62 requested,
63 available,
64 } => write!(
65 f,
66 "Level index {} is out of bounds for tree with {} levels.",
67 requested, available
68 ),
69 IntrospectionError::LeafIndexOutOfBounds {
70 requested,
71 available,
72 } => write!(
73 f,
74 "Leaf index {} is out of bounds for tree with {} leaves.",
75 requested, available
76 ),
77 IntrospectionError::NotANodeTree => write!(
78 f,
79 "This tree uses oblivious-level representation; inspect levels or leaves instead."
80 ),
81 IntrospectionError::NotAnObliviousTree => write!(
82 f,
83 "This tree uses node-tree representation; inspect nodes instead."
84 ),
85 }
86 }
87}
88
// Marker impl: relies on the default `Error` methods, so no underlying source is exposed.
impl Error for IntrospectionError {}
90
91pub(crate) fn tree_structure_summary(
92 tree: ir::TreeDefinition,
93) -> Result<TreeStructureSummary, IntrospectionError> {
94 match tree {
95 ir::TreeDefinition::NodeTree {
96 root_node_id,
97 nodes,
98 ..
99 } => {
100 let node_map = nodes
101 .iter()
102 .cloned()
103 .map(|node| match &node {
104 ir::NodeTreeNode::Leaf { node_id, .. }
105 | ir::NodeTreeNode::BinaryBranch { node_id, .. }
106 | ir::NodeTreeNode::MultiwayBranch { node_id, .. } => (*node_id, node),
107 })
108 .collect::<BTreeMap<_, _>>();
109 let mut leaf_depths = Vec::new();
110 collect_leaf_depths(&node_map, root_node_id, &mut leaf_depths)?;
111 let internal_node_count = nodes
112 .iter()
113 .filter(|node| !matches!(node, ir::NodeTreeNode::Leaf { .. }))
114 .count();
115 let leaf_count = leaf_depths.len();
116 let shortest_path = *leaf_depths.iter().min().unwrap_or(&0);
117 let longest_path = *leaf_depths.iter().max().unwrap_or(&0);
118 let average_path = if leaf_depths.is_empty() {
119 0.0
120 } else {
121 leaf_depths.iter().sum::<usize>() as f64 / leaf_depths.len() as f64
122 };
123 Ok(TreeStructureSummary {
124 representation: "node_tree".to_string(),
125 node_count: internal_node_count + leaf_count,
126 internal_node_count,
127 leaf_count,
128 actual_depth: longest_path,
129 shortest_path,
130 longest_path,
131 average_path,
132 })
133 }
134 ir::TreeDefinition::ObliviousLevels { depth, leaves, .. } => Ok(TreeStructureSummary {
135 representation: "oblivious_levels".to_string(),
136 node_count: ((1usize << depth) - 1) + leaves.len(),
137 internal_node_count: (1usize << depth) - 1,
138 leaf_count: leaves.len(),
139 actual_depth: depth,
140 shortest_path: depth,
141 longest_path: depth,
142 average_path: depth as f64,
143 }),
144 }
145}
146
147fn collect_leaf_depths(
148 nodes: &BTreeMap<usize, ir::NodeTreeNode>,
149 node_id: usize,
150 output: &mut Vec<usize>,
151) -> Result<(), IntrospectionError> {
152 match nodes
153 .get(&node_id)
154 .ok_or(IntrospectionError::NodeIndexOutOfBounds {
155 requested: node_id,
156 available: nodes.len(),
157 })? {
158 ir::NodeTreeNode::Leaf { depth, .. } => output.push(*depth),
159 ir::NodeTreeNode::BinaryBranch {
160 depth: _, children, ..
161 } => {
162 collect_leaf_depths(nodes, children.left, output)?;
163 collect_leaf_depths(nodes, children.right, output)?;
164 }
165 ir::NodeTreeNode::MultiwayBranch {
166 depth,
167 branches,
168 unmatched_leaf: _,
169 ..
170 } => {
171 output.push(depth + 1);
172 for branch in branches {
173 collect_leaf_depths(nodes, branch.child, output)?;
174 }
175 }
176 }
177 Ok(())
178}
179
180pub(crate) fn prediction_value_stats(
181 tree: ir::TreeDefinition,
182) -> Result<PredictionValueStats, IntrospectionError> {
183 let predictions = match tree {
184 ir::TreeDefinition::NodeTree { nodes, .. } => nodes
185 .into_iter()
186 .flat_map(|node| match node {
187 ir::NodeTreeNode::Leaf { leaf, .. } => vec![leaf_payload_value(&leaf)],
188 ir::NodeTreeNode::MultiwayBranch { unmatched_leaf, .. } => {
189 vec![leaf_payload_value(&unmatched_leaf)]
190 }
191 ir::NodeTreeNode::BinaryBranch { .. } => Vec::new(),
192 })
193 .collect::<Vec<_>>(),
194 ir::TreeDefinition::ObliviousLevels { leaves, .. } => leaves
195 .into_iter()
196 .map(|leaf| leaf_payload_value(&leaf.leaf))
197 .collect::<Vec<_>>(),
198 };
199
200 let count = predictions.len();
201 let min = predictions
202 .iter()
203 .copied()
204 .min_by(f64::total_cmp)
205 .unwrap_or(0.0);
206 let max = predictions
207 .iter()
208 .copied()
209 .max_by(f64::total_cmp)
210 .unwrap_or(0.0);
211 let mean = if count == 0 {
212 0.0
213 } else {
214 predictions.iter().sum::<f64>() / count as f64
215 };
216 let std_dev = if count == 0 {
217 0.0
218 } else {
219 let variance = predictions
220 .iter()
221 .map(|value| (*value - mean).powi(2))
222 .sum::<f64>()
223 / count as f64;
224 variance.sqrt()
225 };
226 let mut histogram = BTreeMap::<String, usize>::new();
227 for prediction in &predictions {
228 *histogram.entry(prediction.to_string()).or_insert(0) += 1;
229 }
230 let histogram = histogram
231 .into_iter()
232 .map(|(prediction, count)| PredictionHistogramEntry {
233 prediction: prediction
234 .parse::<f64>()
235 .expect("histogram keys are numeric"),
236 count,
237 })
238 .collect::<Vec<_>>();
239
240 Ok(PredictionValueStats {
241 count,
242 unique_count: histogram.len(),
243 min,
244 max,
245 mean,
246 std_dev,
247 histogram,
248 })
249}
250
251fn leaf_payload_value(leaf: &ir::LeafPayload) -> f64 {
252 match leaf {
253 ir::LeafPayload::RegressionValue { value } => *value,
254 ir::LeafPayload::ClassIndex { class_value, .. } => *class_value,
255 }
256}