1use ferrolearn_core::error::FerroError;
35use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
36use ferrolearn_core::traits::{Fit, Transform};
37use ndarray::{Array2, ArrayView1};
38use num_traits::Float;
39use rand::SeedableRng;
40use rand::rngs::StdRng;
41use serde::{Deserialize, Serialize};
42
43use crate::decision_tree::Node;
44
/// Unsupervised embedding built from an ensemble of totally-random trees.
///
/// Fitting grows `n_estimators` trees whose split features and thresholds are
/// drawn at random (no target is used). Transforming maps each sample to a
/// binary vector with one indicator column per leaf across all trees.
///
/// Field order is kept stable because this type derives `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomTreesEmbedding<F> {
    /// Number of random trees grown by `fit`; must be at least 1.
    pub n_estimators: usize,
    /// Maximum tree depth; `None` lets nodes split until another stopping rule applies.
    pub max_depth: Option<usize>,
    /// Nodes holding fewer samples than this become leaves; must be at least 2.
    pub min_samples_split: usize,
    /// Seed for the tree-building RNG; `None` seeds from the operating system.
    pub random_state: Option<u64>,
    /// Ties the float type `F` to the estimator without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float> RandomTreesEmbedding<F> {
75 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 n_estimators: 10,
83 max_depth: Some(5),
84 min_samples_split: 2,
85 random_state: None,
86 _marker: std::marker::PhantomData,
87 }
88 }
89
90 #[must_use]
92 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
93 self.n_estimators = n_estimators;
94 self
95 }
96
97 #[must_use]
99 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
100 self.max_depth = max_depth;
101 self
102 }
103
104 #[must_use]
106 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
107 self.min_samples_split = min_samples_split;
108 self
109 }
110
111 #[must_use]
113 pub fn with_random_state(mut self, seed: u64) -> Self {
114 self.random_state = Some(seed);
115 self
116 }
117}
118
119impl<F: Float> Default for RandomTreesEmbedding<F> {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
/// A fitted [`RandomTreesEmbedding`]: the grown trees plus the bookkeeping
/// needed to map each reached leaf to an output column.
#[derive(Debug, Clone)]
pub struct FittedRandomTreesEmbedding<F> {
    /// One flat node array per tree; traversal starts at index 0 (the root).
    trees: Vec<Vec<Node<F>>>,
    /// Number of leaves in each tree, i.e. the width of that tree's column block.
    leaf_counts: Vec<usize>,
    /// Per tree: node index -> position of that leaf within the tree's column
    /// block (`None` for split nodes).
    leaf_maps: Vec<Vec<Option<usize>>>,
    /// Sum of `leaf_counts`; the number of columns produced by `transform`.
    total_leaves: usize,
    /// Number of input columns seen during fitting; `transform` requires the same.
    n_features: usize,
}
148
149impl<F: Float + Send + Sync + 'static> FittedRandomTreesEmbedding<F> {
150 #[must_use]
152 pub fn n_estimators(&self) -> usize {
153 self.trees.len()
154 }
155
156 #[must_use]
158 pub fn n_features(&self) -> usize {
159 self.n_features
160 }
161
162 #[must_use]
164 pub fn n_output_features(&self) -> usize {
165 self.total_leaves
166 }
167}
168
169impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for RandomTreesEmbedding<F> {
170 type Fitted = FittedRandomTreesEmbedding<F>;
171 type Error = FerroError;
172
173 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedRandomTreesEmbedding<F>, FerroError> {
184 let (n_samples, n_features) = x.dim();
185
186 if n_samples == 0 {
187 return Err(FerroError::InsufficientSamples {
188 required: 1,
189 actual: 0,
190 context: "RandomTreesEmbedding requires at least one sample".into(),
191 });
192 }
193 if self.n_estimators == 0 {
194 return Err(FerroError::InvalidParameter {
195 name: "n_estimators".into(),
196 reason: "must be at least 1".into(),
197 });
198 }
199 if self.min_samples_split < 2 {
200 return Err(FerroError::InvalidParameter {
201 name: "min_samples_split".into(),
202 reason: "must be at least 2".into(),
203 });
204 }
205
206 let mut rng = if let Some(seed) = self.random_state {
207 StdRng::seed_from_u64(seed)
208 } else {
209 StdRng::from_os_rng()
210 };
211
212 let indices: Vec<usize> = (0..n_samples).collect();
213
214 let mut trees = Vec::with_capacity(self.n_estimators);
215 let mut leaf_counts = Vec::with_capacity(self.n_estimators);
216 let mut leaf_maps = Vec::with_capacity(self.n_estimators);
217 let mut total_leaves = 0;
218
219 for _ in 0..self.n_estimators {
220 let mut nodes = Vec::new();
221 build_random_tree(
222 x,
223 &indices,
224 &mut nodes,
225 0,
226 self.max_depth,
227 self.min_samples_split,
228 n_features,
229 &mut rng,
230 );
231
232 let mut leaf_map = vec![None; nodes.len()];
234 let mut count = 0;
235 for (idx, node) in nodes.iter().enumerate() {
236 if matches!(node, Node::Leaf { .. }) {
237 leaf_map[idx] = Some(count);
238 count += 1;
239 }
240 }
241
242 trees.push(nodes);
243 leaf_counts.push(count);
244 leaf_maps.push(leaf_map);
245 total_leaves += count;
246 }
247
248 Ok(FittedRandomTreesEmbedding {
249 trees,
250 leaf_counts,
251 leaf_maps,
252 total_leaves,
253 n_features,
254 })
255 }
256}
257
258impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedRandomTreesEmbedding<F> {
259 type Output = Array2<F>;
260 type Error = FerroError;
261
262 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
275 if x.ncols() != self.n_features {
276 return Err(FerroError::ShapeMismatch {
277 expected: vec![self.n_features],
278 actual: vec![x.ncols()],
279 context: "number of features must match fitted model".into(),
280 });
281 }
282
283 let n_samples = x.nrows();
284 let mut output = Array2::zeros((n_samples, self.total_leaves));
285
286 let mut col_offset = 0;
287 for (tree_idx, tree_nodes) in self.trees.iter().enumerate() {
288 let leaf_map = &self.leaf_maps[tree_idx];
289 let n_leaves = self.leaf_counts[tree_idx];
290
291 for i in 0..n_samples {
292 let row = x.row(i);
293 let leaf_node_idx = traverse_tree(tree_nodes, &row);
294 if let Some(leaf_pos) = leaf_map[leaf_node_idx] {
295 output[[i, col_offset + leaf_pos]] = F::one();
296 }
297 }
298
299 col_offset += n_leaves;
300 }
301
302 Ok(output)
303 }
304}
305
306impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for RandomTreesEmbedding<F> {
308 fn fit_pipeline(
309 &self,
310 x: &Array2<F>,
311 _y: &ndarray::Array1<F>,
312 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
313 let fitted = self.fit(x, &())?;
314 Ok(Box::new(fitted))
315 }
316}
317
318impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
319 for FittedRandomTreesEmbedding<F>
320{
321 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
322 self.transform(x)
323 }
324}
325
326fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ArrayView1<F>) -> usize {
332 let mut idx = 0;
333 loop {
334 match &nodes[idx] {
335 Node::Split {
336 feature,
337 threshold,
338 left,
339 right,
340 ..
341 } => {
342 if sample[*feature] <= *threshold {
343 idx = *left;
344 } else {
345 idx = *right;
346 }
347 }
348 Node::Leaf { .. } => return idx,
349 }
350 }
351}
352
353fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
355 use rand::RngCore;
356 let u = (rng.next_u64() as f64) / (u64::MAX as f64);
357 let range = max_val - min_val;
358 min_val + F::from(u).unwrap() * range
359}
360
361#[allow(clippy::too_many_arguments)]
366fn build_random_tree<F: Float>(
367 x: &Array2<F>,
368 indices: &[usize],
369 nodes: &mut Vec<Node<F>>,
370 depth: usize,
371 max_depth: Option<usize>,
372 min_samples_split: usize,
373 n_features: usize,
374 rng: &mut StdRng,
375) -> usize {
376 let n = indices.len();
377
378 let should_stop =
380 n < min_samples_split || max_depth.is_some_and(|d| depth >= d);
381
382 if should_stop {
383 let idx = nodes.len();
384 nodes.push(Node::Leaf {
385 value: F::zero(),
386 class_distribution: None,
387 n_samples: n,
388 });
389 return idx;
390 }
391
392 let max_attempts = n_features * 2;
394 for _ in 0..max_attempts {
395 use rand::RngCore;
396 let feature = (rng.next_u64() as usize) % n_features;
397
398 let mut min_val = x[[indices[0], feature]];
400 let mut max_val = min_val;
401 for &i in &indices[1..] {
402 let v = x[[i, feature]];
403 if v < min_val {
404 min_val = v;
405 }
406 if v > max_val {
407 max_val = v;
408 }
409 }
410
411 if min_val >= max_val {
413 continue;
414 }
415
416 let threshold = random_threshold(rng, min_val, max_val);
417
418 let mut left_indices = Vec::new();
420 let mut right_indices = Vec::new();
421 for &i in indices {
422 if x[[i, feature]] <= threshold {
423 left_indices.push(i);
424 } else {
425 right_indices.push(i);
426 }
427 }
428
429 if left_indices.is_empty() || right_indices.is_empty() {
431 continue;
432 }
433
434 let node_idx = nodes.len();
436 nodes.push(Node::Leaf {
437 value: F::zero(),
438 class_distribution: None,
439 n_samples: 0,
440 }); let left_child = build_random_tree(
443 x,
444 &left_indices,
445 nodes,
446 depth + 1,
447 max_depth,
448 min_samples_split,
449 n_features,
450 rng,
451 );
452 let right_child = build_random_tree(
453 x,
454 &right_indices,
455 nodes,
456 depth + 1,
457 max_depth,
458 min_samples_split,
459 n_features,
460 rng,
461 );
462
463 nodes[node_idx] = Node::Split {
464 feature,
465 threshold,
466 left: left_child,
467 right: right_child,
468 impurity_decrease: F::zero(),
469 n_samples: n,
470 };
471
472 return node_idx;
473 }
474
475 let idx = nodes.len();
477 nodes.push(Node::Leaf {
478 value: F::zero(),
479 class_distribution: None,
480 n_samples: n,
481 });
482 idx
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// 8 x 3 fixture where row `k` (0-based) is `[k + 1, k + 2, k + 3]`, so
    /// every feature varies across the rows.
    fn make_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (8, 3),
            vec![
                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0, 5.0, 6.0, 7.0, 6.0,
                7.0, 8.0, 7.0, 8.0, 9.0, 8.0, 9.0, 10.0,
            ],
        )
        .unwrap()
    }

    /// `new()` exposes the documented default hyper-parameters.
    #[test]
    fn test_default() {
        let model = RandomTreesEmbedding::<f64>::new();
        assert_eq!(model.n_estimators, 10);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 2);
        assert!(model.random_state.is_none());
    }

    /// Each `with_*` setter updates exactly its own field.
    #[test]
    fn test_builder() {
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(20)
            .with_max_depth(Some(3))
            .with_min_samples_split(5)
            .with_random_state(42);
        assert_eq!(model.n_estimators, 20);
        assert_eq!(model.max_depth, Some(3));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.random_state, Some(42));
    }

    /// Every sample lands in exactly one leaf per tree, so each embedded row
    /// sums to the number of trees.
    #[test]
    fn test_fit_transform_basic() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.nrows(), 8);
        for i in 0..8 {
            let row_sum: f64 = embedded.row(i).iter().copied().sum();
            assert!(
                (row_sum - 5.0).abs() < 1e-10,
                "row {i} should have exactly 5 ones, got {row_sum}"
            );
        }
    }

    /// The embedding contains only indicator values 0 and 1.
    #[test]
    fn test_output_is_binary() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(2))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        for &val in embedded.iter() {
            assert!(
                (val - 0.0).abs() < 1e-10 || (val - 1.0).abs() < 1e-10,
                "values should be 0 or 1, got {val}"
            );
        }
    }

    /// The output width equals the reported total leaf count.
    #[test]
    fn test_total_leaves_matches_output_cols() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.ncols(), fitted.n_output_features());
    }

    /// Fitting on zero rows is rejected.
    #[test]
    fn test_empty_input_error() {
        let x = Array2::<f64>::zeros((0, 3));
        let model = RandomTreesEmbedding::<f64>::new();
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    /// `n_estimators == 0` is rejected.
    #[test]
    fn test_zero_estimators_error() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new().with_n_estimators(0);
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    /// `min_samples_split < 2` is rejected.
    #[test]
    fn test_invalid_min_samples_split_error() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new().with_min_samples_split(1);
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    /// Transforming data with a different column count than the training data fails.
    #[test]
    fn test_shape_mismatch_error() {
        let x_train = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_random_state(42);
        let fitted = model.fit(&x_train, &()).unwrap();

        // Trained on 3 features; 10 must be refused.
        let x_test = Array2::<f64>::zeros((5, 10));
        let result = fitted.transform(&x_test);
        assert!(result.is_err());
    }

    /// The same seed yields identical embeddings across fits.
    #[test]
    fn test_reproducibility() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);

        let fitted1 = model.fit(&x, &()).unwrap();
        let embedded1 = fitted1.transform(&x).unwrap();

        let fitted2 = model.fit(&x, &()).unwrap();
        let embedded2 = fitted2.transform(&x).unwrap();

        assert_eq!(embedded1, embedded2);
    }

    /// The estimator also works with `f32` data.
    #[test]
    fn test_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let model = RandomTreesEmbedding::<f32>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(2))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 6);
    }

    /// Accessors report the tree count and the input width seen during fitting.
    #[test]
    fn test_fitted_accessors() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_estimators(), 5);
        assert_eq!(fitted.n_features(), 3);
        assert!(fitted.n_output_features() > 0);
    }

    /// With the same seed, allowing more depth never yields fewer leaves.
    #[test]
    fn test_deeper_trees_more_leaves() {
        let x = make_data();

        let shallow = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(1)
            .with_max_depth(Some(1))
            .with_random_state(42);
        let fitted_shallow = shallow.fit(&x, &()).unwrap();

        let deep = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(1)
            .with_max_depth(Some(5))
            .with_random_state(42);
        let fitted_deep = deep.fit(&x, &()).unwrap();

        assert!(
            fitted_deep.n_output_features() >= fitted_shallow.n_output_features(),
            "deeper trees should have at least as many leaves"
        );
    }

    /// A single sample cannot be split, so each of the 3 trees is one leaf,
    /// giving 3 output columns.
    #[test]
    fn test_single_sample() {
        let x = Array2::<f64>::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 1);
        assert_eq!(embedded.ncols(), 3);
    }

    /// `max_depth = None` still terminates and produces a non-empty embedding.
    #[test]
    fn test_unlimited_depth() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(None)
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 8);
        assert!(embedded.ncols() > 0);
    }
}