1use alloc::vec::Vec;
14
/// Handle to a single node: a plain `u32` index wrapped in a newtype.
///
/// The largest `u32` value is reserved as the [`NodeId::NONE`] sentinel and
/// never refers to a real node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);

impl NodeId {
    /// Sentinel meaning "no node" (for example, the missing child of a leaf).
    pub const NONE: NodeId = NodeId(u32::MAX);

    /// Returns `true` when this id is the [`NodeId::NONE`] sentinel.
    #[inline]
    pub fn is_none(self) -> bool {
        self == Self::NONE
    }

    /// Returns the wrapped index widened to `usize` for slice indexing.
    ///
    /// No sentinel check is performed: calling this on [`NodeId::NONE`]
    /// simply yields `u32::MAX` as a `usize`.
    #[inline]
    pub fn idx(self) -> usize {
        let NodeId(raw) = self;
        raw as usize
    }
}
38
39#[derive(Debug, Clone)]
52pub struct TreeArena {
53 pub feature_idx: Vec<u32>,
55 pub threshold: Vec<f64>,
57 pub left: Vec<NodeId>,
59 pub right: Vec<NodeId>,
61 pub leaf_value: Vec<f64>,
63 pub is_leaf: Vec<bool>,
65 pub depth: Vec<u16>,
67 pub sample_count: Vec<u64>,
69 pub categorical_mask: Vec<Option<u64>>,
72}
73
74impl TreeArena {
75 pub fn new() -> Self {
77 Self {
78 feature_idx: Vec::new(),
79 threshold: Vec::new(),
80 left: Vec::new(),
81 right: Vec::new(),
82 leaf_value: Vec::new(),
83 is_leaf: Vec::new(),
84 depth: Vec::new(),
85 sample_count: Vec::new(),
86 categorical_mask: Vec::new(),
87 }
88 }
89
90 pub fn with_capacity(cap: usize) -> Self {
94 Self {
95 feature_idx: Vec::with_capacity(cap),
96 threshold: Vec::with_capacity(cap),
97 left: Vec::with_capacity(cap),
98 right: Vec::with_capacity(cap),
99 leaf_value: Vec::with_capacity(cap),
100 is_leaf: Vec::with_capacity(cap),
101 depth: Vec::with_capacity(cap),
102 sample_count: Vec::with_capacity(cap),
103 categorical_mask: Vec::with_capacity(cap),
104 }
105 }
106
107 pub fn add_leaf(&mut self, depth: u16) -> NodeId {
111 let id = self.feature_idx.len() as u32;
112 self.feature_idx.push(0);
113 self.threshold.push(0.0);
114 self.left.push(NodeId::NONE);
115 self.right.push(NodeId::NONE);
116 self.leaf_value.push(0.0);
117 self.is_leaf.push(true);
118 self.depth.push(depth);
119 self.sample_count.push(0);
120 self.categorical_mask.push(None);
121 NodeId(id)
122 }
123
124 pub fn split_leaf(
139 &mut self,
140 leaf_id: NodeId,
141 feature_idx: u32,
142 threshold: f64,
143 left_value: f64,
144 right_value: f64,
145 ) -> (NodeId, NodeId) {
146 let i = leaf_id.idx();
147 assert!(
148 self.is_leaf[i],
149 "split_leaf called on non-leaf node {:?}",
150 leaf_id
151 );
152
153 let child_depth = self.depth[i] + 1;
154
155 let left_id = self.add_leaf(child_depth);
157 self.leaf_value[left_id.idx()] = left_value;
158
159 let right_id = self.add_leaf(child_depth);
161 self.leaf_value[right_id.idx()] = right_value;
162
163 self.is_leaf[i] = false;
165 self.feature_idx[i] = feature_idx;
166 self.threshold[i] = threshold;
167 self.left[i] = left_id;
168 self.right[i] = right_id;
169
170 (left_id, right_id)
171 }
172
173 pub fn split_leaf_categorical(
183 &mut self,
184 leaf_id: NodeId,
185 feature_idx: u32,
186 threshold: f64,
187 left_value: f64,
188 right_value: f64,
189 mask: u64,
190 ) -> (NodeId, NodeId) {
191 let (left_id, right_id) =
192 self.split_leaf(leaf_id, feature_idx, threshold, left_value, right_value);
193 self.categorical_mask[leaf_id.idx()] = Some(mask);
194 (left_id, right_id)
195 }
196
197 #[inline]
199 pub fn get_categorical_mask(&self, id: NodeId) -> Option<u64> {
200 self.categorical_mask[id.idx()]
201 }
202
203 #[inline]
205 pub fn is_leaf(&self, id: NodeId) -> bool {
206 self.is_leaf[id.idx()]
207 }
208
209 #[inline]
215 pub fn predict(&self, id: NodeId) -> f64 {
216 let i = id.idx();
217 assert!(self.is_leaf[i], "predict called on internal node {:?}", id);
218 self.leaf_value[i]
219 }
220
221 #[inline]
227 pub fn set_leaf_value(&mut self, id: NodeId, value: f64) {
228 let i = id.idx();
229 assert!(
230 self.is_leaf[i],
231 "set_leaf_value called on internal node {:?}",
232 id
233 );
234 self.leaf_value[i] = value;
235 }
236
237 #[inline]
239 pub fn get_depth(&self, id: NodeId) -> u16 {
240 self.depth[id.idx()]
241 }
242
243 #[inline]
245 pub fn get_feature_idx(&self, id: NodeId) -> u32 {
246 self.feature_idx[id.idx()]
247 }
248
249 #[inline]
251 pub fn get_threshold(&self, id: NodeId) -> f64 {
252 self.threshold[id.idx()]
253 }
254
255 #[inline]
257 pub fn get_left(&self, id: NodeId) -> NodeId {
258 self.left[id.idx()]
259 }
260
261 #[inline]
263 pub fn get_right(&self, id: NodeId) -> NodeId {
264 self.right[id.idx()]
265 }
266
267 #[inline]
269 pub fn get_sample_count(&self, id: NodeId) -> u64 {
270 self.sample_count[id.idx()]
271 }
272
273 #[inline]
275 pub fn increment_sample_count(&mut self, id: NodeId) {
276 self.sample_count[id.idx()] += 1;
277 }
278
279 #[inline]
281 pub fn n_nodes(&self) -> usize {
282 self.is_leaf.len()
283 }
284
285 pub fn n_leaves(&self) -> usize {
287 self.is_leaf.iter().filter(|&&b| b).count()
288 }
289
290 pub fn reset(&mut self) {
295 self.feature_idx.clear();
296 self.threshold.clear();
297 self.left.clear();
298 self.right.clear();
299 self.leaf_value.clear();
300 self.is_leaf.clear();
301 self.depth.clear();
302 self.sample_count.clear();
303 self.categorical_mask.clear();
304 }
305}
306
307impl Default for TreeArena {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly added leaf has zeroed payload and no children.
    #[test]
    fn single_leaf() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        assert_eq!(root, NodeId(0));
        assert!(arena.is_leaf(root));
        assert_eq!(arena.predict(root), 0.0);
        assert_eq!(arena.get_depth(root), 0);
        assert_eq!(arena.get_sample_count(root), 0);
        assert_eq!(arena.get_left(root), NodeId::NONE);
        assert_eq!(arena.get_right(root), NodeId::NONE);
        assert_eq!(arena.get_categorical_mask(root), None);
    }

    /// Splitting the root records the split and creates two depth-1 leaves.
    #[test]
    fn split_leaf_basic() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        let (left, right) = arena.split_leaf(root, 3, 1.5, -0.25, 0.75);

        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_feature_idx(root), 3);
        assert_eq!(arena.get_threshold(root), 1.5);
        assert_eq!(arena.get_left(root), left);
        assert_eq!(arena.get_right(root), right);

        assert!(arena.is_leaf(left));
        assert_eq!(arena.predict(left), -0.25);
        assert_eq!(arena.get_depth(left), 1);

        assert!(arena.is_leaf(right));
        assert_eq!(arena.predict(right), 0.75);
        assert_eq!(arena.get_depth(right), 1);
    }

    /// Splitting a child produces grandchildren at depth 2 and leaves the
    /// sibling untouched.
    #[test]
    fn split_child_three_levels() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        let (left, right) = arena.split_leaf(root, 0, 5.0, 0.0, 0.0);

        let (ll, lr) = arena.split_leaf(left, 1, 2.0, -1.0, 1.0);

        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_depth(root), 0);

        assert!(!arena.is_leaf(left));
        assert_eq!(arena.get_depth(left), 1);
        assert_eq!(arena.get_feature_idx(left), 1);
        assert_eq!(arena.get_threshold(left), 2.0);
        assert_eq!(arena.get_left(left), ll);
        assert_eq!(arena.get_right(left), lr);

        assert!(arena.is_leaf(right));
        assert_eq!(arena.get_depth(right), 1);

        assert!(arena.is_leaf(ll));
        assert_eq!(arena.get_depth(ll), 2);
        assert_eq!(arena.predict(ll), -1.0);

        assert!(arena.is_leaf(lr));
        assert_eq!(arena.get_depth(lr), 2);
        assert_eq!(arena.predict(lr), 1.0);
    }

    /// Each split adds two nodes and a net of one leaf.
    #[test]
    fn node_and_leaf_counting() {
        let mut arena = TreeArena::new();

        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);

        let root = arena.add_leaf(0);
        assert_eq!(arena.n_nodes(), 1);
        assert_eq!(arena.n_leaves(), 1);

        let (_left, right) = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.n_nodes(), 3);
        assert_eq!(arena.n_leaves(), 2);

        let _ = arena.split_leaf(right, 1, 2.0, 0.0, 0.0);
        assert_eq!(arena.n_nodes(), 5);
        assert_eq!(arena.n_leaves(), 3);
    }

    /// `NodeId::NONE` is `u32::MAX` and differs from any ordinary id.
    #[test]
    fn node_id_none_sentinel() {
        let none = NodeId::NONE;
        assert!(none.is_none());
        assert_eq!(none.0, u32::MAX);

        let valid = NodeId(0);
        assert!(!valid.is_none());
        assert_ne!(valid, NodeId::NONE);
    }

    /// `idx` widens the raw index without altering it.
    #[test]
    fn node_id_idx_matches_raw_value() {
        assert_eq!(NodeId(0).idx(), 0);
        assert_eq!(NodeId(41).idx(), 41);
    }

    /// `reset` empties the arena but keeps the allocations for reuse.
    #[test]
    fn reset_clears_everything() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        assert_eq!(arena.n_nodes(), 3);

        arena.reset();

        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);
        assert!(arena.categorical_mask.is_empty());

        assert!(arena.feature_idx.capacity() >= 3);
        assert!(arena.is_leaf.capacity() >= 3);

        // Ids restart from zero after a reset.
        let new_root = arena.add_leaf(0);
        assert_eq!(new_root, NodeId(0));
        assert_eq!(arena.n_nodes(), 1);
        assert_eq!(arena.n_leaves(), 1);
    }

    /// Sample counters are per node and survive a split of that node.
    #[test]
    fn sample_count_tracking() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        assert_eq!(arena.get_sample_count(root), 0);

        arena.increment_sample_count(root);
        assert_eq!(arena.get_sample_count(root), 1);

        arena.increment_sample_count(root);
        arena.increment_sample_count(root);
        assert_eq!(arena.get_sample_count(root), 3);

        let (left, right) = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.get_sample_count(left), 0);
        assert_eq!(arena.get_sample_count(right), 0);

        assert_eq!(arena.get_sample_count(root), 3);

        arena.increment_sample_count(left);
        assert_eq!(arena.get_sample_count(left), 1);
        assert_eq!(arena.get_sample_count(right), 0);
    }

    /// `with_capacity` reserves space in every column, including the
    /// categorical masks.
    #[test]
    fn with_capacity_preallocates() {
        let arena = TreeArena::with_capacity(64);

        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);

        assert!(arena.feature_idx.capacity() >= 64);
        assert!(arena.threshold.capacity() >= 64);
        assert!(arena.left.capacity() >= 64);
        assert!(arena.right.capacity() >= 64);
        assert!(arena.leaf_value.capacity() >= 64);
        assert!(arena.is_leaf.capacity() >= 64);
        assert!(arena.depth.capacity() >= 64);
        assert!(arena.sample_count.capacity() >= 64);
        assert!(arena.categorical_mask.capacity() >= 64);
    }

    /// `set_leaf_value` overwrites what `predict` returns.
    #[test]
    fn set_leaf_value_updates() {
        let mut arena = TreeArena::new();
        let leaf = arena.add_leaf(0);

        assert_eq!(arena.predict(leaf), 0.0);

        arena.set_leaf_value(leaf, 42.5);
        assert_eq!(arena.predict(leaf), 42.5);

        arena.set_leaf_value(leaf, -3.25);
        assert_eq!(arena.predict(leaf), -3.25);
    }

    /// A categorical split behaves like a plain split and also stores the
    /// mask on the split node only.
    #[test]
    fn categorical_split_records_mask() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        assert_eq!(arena.get_categorical_mask(root), None);

        let (left, right) = arena.split_leaf_categorical(root, 2, 0.0, -1.0, 1.0, 0b1010);

        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_feature_idx(root), 2);
        assert_eq!(arena.get_threshold(root), 0.0);
        assert_eq!(arena.get_left(root), left);
        assert_eq!(arena.get_right(root), right);
        assert_eq!(arena.get_categorical_mask(root), Some(0b1010));

        assert_eq!(arena.get_categorical_mask(left), None);
        assert_eq!(arena.get_categorical_mask(right), None);
        assert_eq!(arena.predict(left), -1.0);
        assert_eq!(arena.predict(right), 1.0);
        assert_eq!(arena.get_depth(left), 1);
        assert_eq!(arena.get_depth(right), 1);
    }

    /// A plain (non-categorical) split never sets a mask.
    #[test]
    fn continuous_split_has_no_mask() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let (left, right) = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        assert_eq!(arena.get_categorical_mask(root), None);
        assert_eq!(arena.get_categorical_mask(left), None);
        assert_eq!(arena.get_categorical_mask(right), None);
    }

    /// `predict` refuses to read a value from an internal node.
    #[test]
    #[should_panic(expected = "predict called on internal node")]
    fn predict_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        let _ = arena.predict(root);
    }

    /// `set_leaf_value` refuses to write to an internal node.
    #[test]
    #[should_panic(expected = "set_leaf_value called on internal node")]
    fn set_leaf_value_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        arena.set_leaf_value(root, 1.0);
    }

    /// A node can only be split once.
    #[test]
    #[should_panic(expected = "split_leaf called on non-leaf node")]
    fn split_leaf_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        let _ = arena.split_leaf(root, 1, 2.0, 0.0, 0.0);
    }

    /// The categorical variant enforces the same leaf precondition.
    #[test]
    #[should_panic(expected = "split_leaf called on non-leaf node")]
    fn split_leaf_categorical_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        let _ = arena.split_leaf_categorical(root, 1, 2.0, 0.0, 0.0, 0b1);
    }

    /// `Default` and `new` both produce an empty arena.
    #[test]
    fn default_matches_new() {
        let a = TreeArena::new();
        let b = TreeArena::default();

        assert_eq!(a.n_nodes(), b.n_nodes());
        assert_eq!(a.n_leaves(), b.n_leaves());
    }
}