1use alloc::vec;
8use alloc::vec::Vec;
9
10use crate::tree::node::{NodeId, TreeArena};
11
12#[inline]
22pub fn traverse_to_leaf(arena: &TreeArena, root: NodeId, features: &[f64]) -> NodeId {
23 let mut current = root;
24 loop {
25 let idx = current.idx();
26 if arena.is_leaf[idx] {
27 return current;
28 }
29 let feat_idx = arena.feature_idx[idx] as usize;
30 if features[feat_idx] <= arena.threshold[idx] {
31 current = arena.left[idx];
32 } else {
33 current = arena.right[idx];
34 }
35 }
36}
37
38#[inline]
42pub fn predict_from_root(arena: &TreeArena, root: NodeId, features: &[f64]) -> f64 {
43 let leaf = traverse_to_leaf(arena, root, features);
44 arena.leaf_value[leaf.idx()]
45}
46
47pub fn predict_batch(arena: &TreeArena, root: NodeId, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
52 feature_matrix
53 .iter()
54 .map(|features| predict_from_root(arena, root, features))
55 .collect()
56}
57
58pub fn collect_leaves(arena: &TreeArena, root: NodeId) -> Vec<NodeId> {
63 let mut leaves = Vec::new();
64 let mut stack = vec![root];
65 while let Some(node) = stack.pop() {
66 let idx = node.idx();
67 if arena.is_leaf[idx] {
68 leaves.push(node);
69 } else {
70 stack.push(arena.right[idx]);
72 stack.push(arena.left[idx]);
73 }
74 }
75 leaves
76}
77
78pub fn tree_depth(arena: &TreeArena, root: NodeId) -> u16 {
83 let leaves = collect_leaves(arena, root);
84 leaves
85 .iter()
86 .map(|id| arena.depth[id.idx()])
87 .max()
88 .unwrap_or(0)
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::node::{NodeId, TreeArena};

    /// Builds an arena with no nodes.
    fn empty_arena() -> TreeArena {
        TreeArena {
            feature_idx: Vec::new(),
            threshold: Vec::new(),
            left: Vec::new(),
            right: Vec::new(),
            leaf_value: Vec::new(),
            is_leaf: Vec::new(),
            depth: Vec::new(),
            sample_count: Vec::new(),
            categorical_mask: Vec::new(),
        }
    }

    /// Appends a leaf with the given value and stored depth; returns its id
    /// (ids are assigned sequentially from 0).
    fn push_leaf(arena: &mut TreeArena, value: f64, depth: u16) -> NodeId {
        let id = NodeId(arena.is_leaf.len() as u32);
        arena.feature_idx.push(0);
        arena.threshold.push(0.0);
        arena.left.push(NodeId::NONE);
        arena.right.push(NodeId::NONE);
        arena.leaf_value.push(value);
        arena.is_leaf.push(true);
        arena.depth.push(depth);
        arena.sample_count.push(0);
        arena.categorical_mask.push(None);
        id
    }

    /// Turns an existing node into the split `features[feature] <= threshold`
    /// with the given children.
    fn convert_to_split(
        arena: &mut TreeArena,
        node: NodeId,
        feature: u32,
        threshold: f64,
        left: NodeId,
        right: NodeId,
    ) {
        let idx = node.idx();
        arena.feature_idx[idx] = feature;
        arena.threshold[idx] = threshold;
        arena.left[idx] = left;
        arena.right[idx] = right;
        arena.is_leaf[idx] = false;
    }

    #[test]
    fn single_leaf_returns_value() {
        let mut arena = empty_arena();
        let root = push_leaf(&mut arena, 0.0, 0);

        assert_eq!(predict_from_root(&arena, root, &[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    fn single_leaf_with_nonzero_value() {
        let mut arena = empty_arena();
        let root = push_leaf(&mut arena, 42.5, 0);

        assert_eq!(predict_from_root(&arena, root, &[]), 42.5);
    }

    /// Shape (node ids in brackets):
    ///
    /// ```text
    ///        [0] f0 <= 5.0
    ///        /          \
    ///   [1] -1.0      [2] 1.0
    /// ```
    fn build_one_split() -> (TreeArena, NodeId) {
        let mut arena = empty_arena();

        let root = push_leaf(&mut arena, 0.0, 0);
        let left_child = push_leaf(&mut arena, -1.0, 1);
        let right_child = push_leaf(&mut arena, 1.0, 1);
        convert_to_split(&mut arena, root, 0, 5.0, left_child, right_child);

        (arena, root)
    }

    #[test]
    fn one_split_goes_left() {
        let (arena, root) = build_one_split();
        assert_eq!(predict_from_root(&arena, root, &[3.0]), -1.0);
    }

    #[test]
    fn one_split_goes_right() {
        let (arena, root) = build_one_split();
        assert_eq!(predict_from_root(&arena, root, &[7.0]), 1.0);
    }

    #[test]
    fn one_split_equal_goes_left() {
        let (arena, root) = build_one_split();
        assert_eq!(predict_from_root(&arena, root, &[5.0]), -1.0);
    }

    #[test]
    fn nan_feature_routes_right() {
        // `NaN <= threshold` is false, so a NaN feature value takes the right branch.
        let (arena, root) = build_one_split();
        assert_eq!(predict_from_root(&arena, root, &[f64::NAN]), 1.0);
    }

    /// Shape (node ids in brackets):
    ///
    /// ```text
    ///             [0] f0 <= 5.0
    ///             /           \
    ///     [1] f1 <= 2.0     [2] 10.0
    ///      /        \
    /// [3] -5.0    [4] 3.0
    /// ```
    fn build_two_level() -> (TreeArena, NodeId) {
        let mut arena = empty_arena();

        let root = push_leaf(&mut arena, 0.0, 0);
        let inner = push_leaf(&mut arena, 0.0, 1);
        let right_leaf = push_leaf(&mut arena, 10.0, 1);
        let left_left = push_leaf(&mut arena, -5.0, 2);
        let left_right = push_leaf(&mut arena, 3.0, 2);

        convert_to_split(&mut arena, root, 0, 5.0, inner, right_leaf);
        convert_to_split(&mut arena, inner, 1, 2.0, left_left, left_right);

        (arena, root)
    }

    #[test]
    fn two_level_reaches_left_left() {
        let (arena, root) = build_two_level();
        assert_eq!(predict_from_root(&arena, root, &[1.0, 0.5]), -5.0);
    }

    #[test]
    fn two_level_reaches_left_right() {
        let (arena, root) = build_two_level();
        assert_eq!(predict_from_root(&arena, root, &[4.0, 3.0]), 3.0);
    }

    #[test]
    fn two_level_reaches_right_leaf() {
        let (arena, root) = build_two_level();
        assert_eq!(predict_from_root(&arena, root, &[8.0, 999.0]), 10.0);
    }

    #[test]
    fn batch_matches_individual() {
        let (arena, root) = build_two_level();

        let rows = vec![
            vec![1.0, 0.5],
            vec![4.0, 3.0],
            vec![8.0, 0.0],
            vec![5.0, 2.0],
        ];

        let batch = predict_batch(&arena, root, &rows);

        for (i, row) in rows.iter().enumerate() {
            let individual = predict_from_root(&arena, root, row);
            assert_eq!(
                batch[i], individual,
                "batch[{}] = {} but individual = {} for features {:?}",
                i, batch[i], individual, row
            );
        }
    }

    #[test]
    fn batch_empty_input() {
        let (arena, root) = build_one_split();
        let result = predict_batch(&arena, root, &[]);
        assert!(result.is_empty());
    }

    #[test]
    fn collect_leaves_single_leaf() {
        let mut arena = empty_arena();
        let root = push_leaf(&mut arena, 0.0, 0);
        let leaves = collect_leaves(&arena, root);
        assert_eq!(leaves.len(), 1);
        assert_eq!(leaves[0].idx(), root.idx());
    }

    #[test]
    fn collect_leaves_one_split() {
        let (arena, root) = build_one_split();
        let leaves = collect_leaves(&arena, root);
        assert_eq!(leaves.len(), 2);
    }

    #[test]
    fn collect_leaves_two_level() {
        let (arena, root) = build_two_level();
        let leaves = collect_leaves(&arena, root);
        assert_eq!(leaves.len(), 3);

        // Leaves must come back in left-to-right order.
        let values: Vec<f64> = leaves.iter().map(|id| arena.leaf_value[id.idx()]).collect();
        assert_eq!(values, vec![-5.0, 3.0, 10.0]);
    }

    #[test]
    fn collect_leaves_balanced_depth2() {
        let mut arena = empty_arena();

        let root = push_leaf(&mut arena, 0.0, 0);
        let left = push_leaf(&mut arena, 0.0, 1);
        let right = push_leaf(&mut arena, 0.0, 1);
        let ll = push_leaf(&mut arena, 1.0, 2);
        let lr = push_leaf(&mut arena, 2.0, 2);
        let rl = push_leaf(&mut arena, 3.0, 2);
        let rr = push_leaf(&mut arena, 4.0, 2);

        convert_to_split(&mut arena, root, 0, 5.0, left, right);
        convert_to_split(&mut arena, left, 1, 2.0, ll, lr);
        convert_to_split(&mut arena, right, 1, 8.0, rl, rr);

        let leaves = collect_leaves(&arena, root);
        assert_eq!(leaves.len(), 4);

        let values: Vec<f64> = leaves.iter().map(|id| arena.leaf_value[id.idx()]).collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn depth_single_leaf() {
        let mut arena = empty_arena();
        let root = push_leaf(&mut arena, 0.0, 0);
        assert_eq!(tree_depth(&arena, root), 0);
    }

    #[test]
    fn depth_one_split() {
        let (arena, root) = build_one_split();
        assert_eq!(tree_depth(&arena, root), 1);
    }

    #[test]
    fn depth_two_level_unbalanced() {
        let (arena, root) = build_two_level();
        assert_eq!(tree_depth(&arena, root), 2);
    }

    #[test]
    fn depth_left_skewed() {
        let mut arena = empty_arena();

        // Each left child splits again; right children are leaves:
        // n0 -> (n1, n2), n1 -> (n3, n4), n3 -> (n5, n6), n5 -> (n7, n8).
        let n0 = push_leaf(&mut arena, 0.0, 0);
        let n1 = push_leaf(&mut arena, 0.0, 1);
        let n2 = push_leaf(&mut arena, 0.0, 1);
        let n3 = push_leaf(&mut arena, 0.0, 2);
        let n4 = push_leaf(&mut arena, 0.0, 2);
        let n5 = push_leaf(&mut arena, 0.0, 3);
        let n6 = push_leaf(&mut arena, 0.0, 3);
        let n7 = push_leaf(&mut arena, 0.0, 4);
        let n8 = push_leaf(&mut arena, 0.0, 4);

        convert_to_split(&mut arena, n0, 0, 1.0, n1, n2);
        convert_to_split(&mut arena, n1, 0, 2.0, n3, n4);
        convert_to_split(&mut arena, n3, 0, 3.0, n5, n6);
        convert_to_split(&mut arena, n5, 0, 4.0, n7, n8);

        assert_eq!(tree_depth(&arena, n0), 4);
        assert_eq!(collect_leaves(&arena, n0).len(), 5);
    }

    #[test]
    fn threshold_equality_goes_left() {
        let (arena, root) = build_one_split();
        let leaf = traverse_to_leaf(&arena, root, &[5.0]);
        assert_eq!(leaf.idx(), 1, "value == threshold must route left");
        assert_eq!(arena.leaf_value[leaf.idx()], -1.0);
    }

    #[test]
    fn threshold_equality_two_level() {
        let (arena, root) = build_two_level();
        // Equal at both levels: left at the root, then left again at the inner node.
        assert_eq!(predict_from_root(&arena, root, &[5.0, 2.0]), -5.0);
    }

    #[test]
    fn traverse_returns_correct_node_id() {
        let (arena, root) = build_two_level();

        let leaf = traverse_to_leaf(&arena, root, &[0.0, 0.0]);
        assert_eq!(leaf.idx(), 3);

        let leaf = traverse_to_leaf(&arena, root, &[0.0, 5.0]);
        assert_eq!(leaf.idx(), 4);

        let leaf = traverse_to_leaf(&arena, root, &[10.0, 0.0]);
        assert_eq!(leaf.idx(), 2);
    }

    #[test]
    fn traverse_from_inner_subtree_root() {
        let (arena, _root) = build_two_level();
        // Node 1 is the inner split on feature 1; feature 0 is never consulted from here.
        let inner = NodeId(1);
        assert_eq!(traverse_to_leaf(&arena, inner, &[999.0, 5.0]).idx(), 4);
        assert_eq!(collect_leaves(&arena, inner).len(), 2);
    }
}
454}