1use ferrolearn_core::error::FerroError;
35use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
36use ferrolearn_core::traits::{Fit, Transform};
37use ndarray::{Array2, ArrayView1};
38use num_traits::Float;
39use rand::SeedableRng;
40use rand::rngs::StdRng;
41use serde::{Deserialize, Serialize};
42
43use crate::decision_tree::Node;
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct RandomTreesEmbedding<F> {
63 pub n_estimators: usize,
65 pub max_depth: Option<usize>,
67 pub min_samples_split: usize,
69 pub random_state: Option<u64>,
71 _marker: std::marker::PhantomData<F>,
72}
73
74impl<F: Float> RandomTreesEmbedding<F> {
75 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 n_estimators: 10,
83 max_depth: Some(5),
84 min_samples_split: 2,
85 random_state: None,
86 _marker: std::marker::PhantomData,
87 }
88 }
89
90 #[must_use]
92 pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
93 self.n_estimators = n_estimators;
94 self
95 }
96
97 #[must_use]
99 pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
100 self.max_depth = max_depth;
101 self
102 }
103
104 #[must_use]
106 pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
107 self.min_samples_split = min_samples_split;
108 self
109 }
110
111 #[must_use]
113 pub fn with_random_state(mut self, seed: u64) -> Self {
114 self.random_state = Some(seed);
115 self
116 }
117}
118
119impl<F: Float> Default for RandomTreesEmbedding<F> {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125#[derive(Debug, Clone)]
135pub struct FittedRandomTreesEmbedding<F> {
136 trees: Vec<Vec<Node<F>>>,
138 leaf_counts: Vec<usize>,
140 leaf_maps: Vec<Vec<Option<usize>>>,
143 total_leaves: usize,
145 n_features: usize,
147}
148
149impl<F: Float + Send + Sync + 'static> FittedRandomTreesEmbedding<F> {
150 #[must_use]
152 pub fn n_estimators(&self) -> usize {
153 self.trees.len()
154 }
155
156 #[must_use]
158 pub fn n_features(&self) -> usize {
159 self.n_features
160 }
161
162 #[must_use]
164 pub fn n_output_features(&self) -> usize {
165 self.total_leaves
166 }
167}
168
169impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for RandomTreesEmbedding<F> {
170 type Fitted = FittedRandomTreesEmbedding<F>;
171 type Error = FerroError;
172
173 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedRandomTreesEmbedding<F>, FerroError> {
184 let (n_samples, n_features) = x.dim();
185
186 if n_samples == 0 {
187 return Err(FerroError::InsufficientSamples {
188 required: 1,
189 actual: 0,
190 context: "RandomTreesEmbedding requires at least one sample".into(),
191 });
192 }
193 if self.n_estimators == 0 {
194 return Err(FerroError::InvalidParameter {
195 name: "n_estimators".into(),
196 reason: "must be at least 1".into(),
197 });
198 }
199 if self.min_samples_split < 2 {
200 return Err(FerroError::InvalidParameter {
201 name: "min_samples_split".into(),
202 reason: "must be at least 2".into(),
203 });
204 }
205
206 let mut rng = if let Some(seed) = self.random_state {
207 StdRng::seed_from_u64(seed)
208 } else {
209 StdRng::from_os_rng()
210 };
211
212 let indices: Vec<usize> = (0..n_samples).collect();
213
214 let mut trees = Vec::with_capacity(self.n_estimators);
215 let mut leaf_counts = Vec::with_capacity(self.n_estimators);
216 let mut leaf_maps = Vec::with_capacity(self.n_estimators);
217 let mut total_leaves = 0;
218
219 for _ in 0..self.n_estimators {
220 let mut nodes = Vec::new();
221 build_random_tree(
222 x,
223 &indices,
224 &mut nodes,
225 0,
226 self.max_depth,
227 self.min_samples_split,
228 n_features,
229 &mut rng,
230 );
231
232 let mut leaf_map = vec![None; nodes.len()];
234 let mut count = 0;
235 for (idx, node) in nodes.iter().enumerate() {
236 if matches!(node, Node::Leaf { .. }) {
237 leaf_map[idx] = Some(count);
238 count += 1;
239 }
240 }
241
242 trees.push(nodes);
243 leaf_counts.push(count);
244 leaf_maps.push(leaf_map);
245 total_leaves += count;
246 }
247
248 Ok(FittedRandomTreesEmbedding {
249 trees,
250 leaf_counts,
251 leaf_maps,
252 total_leaves,
253 n_features,
254 })
255 }
256}
257
258impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedRandomTreesEmbedding<F> {
259 type Output = Array2<F>;
260 type Error = FerroError;
261
262 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
275 if x.ncols() != self.n_features {
276 return Err(FerroError::ShapeMismatch {
277 expected: vec![self.n_features],
278 actual: vec![x.ncols()],
279 context: "number of features must match fitted model".into(),
280 });
281 }
282
283 let n_samples = x.nrows();
284 let mut output = Array2::zeros((n_samples, self.total_leaves));
285
286 let mut col_offset = 0;
287 for (tree_idx, tree_nodes) in self.trees.iter().enumerate() {
288 let leaf_map = &self.leaf_maps[tree_idx];
289 let n_leaves = self.leaf_counts[tree_idx];
290
291 for i in 0..n_samples {
292 let row = x.row(i);
293 let leaf_node_idx = traverse_tree(tree_nodes, &row);
294 if let Some(leaf_pos) = leaf_map[leaf_node_idx] {
295 output[[i, col_offset + leaf_pos]] = F::one();
296 }
297 }
298
299 col_offset += n_leaves;
300 }
301
302 Ok(output)
303 }
304}
305
306impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for RandomTreesEmbedding<F> {
308 fn fit_pipeline(
309 &self,
310 x: &Array2<F>,
311 _y: &ndarray::Array1<F>,
312 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
313 let fitted = self.fit(x, &())?;
314 Ok(Box::new(fitted))
315 }
316}
317
318impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
319 for FittedRandomTreesEmbedding<F>
320{
321 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
322 self.transform(x)
323 }
324}
325
326fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ArrayView1<F>) -> usize {
332 let mut idx = 0;
333 loop {
334 match &nodes[idx] {
335 Node::Split {
336 feature,
337 threshold,
338 left,
339 right,
340 ..
341 } => {
342 if sample[*feature] <= *threshold {
343 idx = *left;
344 } else {
345 idx = *right;
346 }
347 }
348 Node::Leaf { .. } => return idx,
349 }
350 }
351}
352
353fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
355 use rand::RngCore;
356 let u = (rng.next_u64() as f64) / (u64::MAX as f64);
357 let range = max_val - min_val;
358 min_val + F::from(u).unwrap() * range
359}
360
361#[allow(clippy::too_many_arguments)]
366fn build_random_tree<F: Float>(
367 x: &Array2<F>,
368 indices: &[usize],
369 nodes: &mut Vec<Node<F>>,
370 depth: usize,
371 max_depth: Option<usize>,
372 min_samples_split: usize,
373 n_features: usize,
374 rng: &mut StdRng,
375) -> usize {
376 let n = indices.len();
377
378 let should_stop = n < min_samples_split || max_depth.is_some_and(|d| depth >= d);
380
381 if should_stop {
382 let idx = nodes.len();
383 nodes.push(Node::Leaf {
384 value: F::zero(),
385 class_distribution: None,
386 n_samples: n,
387 });
388 return idx;
389 }
390
391 let max_attempts = n_features * 2;
393 for _ in 0..max_attempts {
394 use rand::RngCore;
395 let feature = (rng.next_u64() as usize) % n_features;
396
397 let mut min_val = x[[indices[0], feature]];
399 let mut max_val = min_val;
400 for &i in &indices[1..] {
401 let v = x[[i, feature]];
402 if v < min_val {
403 min_val = v;
404 }
405 if v > max_val {
406 max_val = v;
407 }
408 }
409
410 if min_val >= max_val {
412 continue;
413 }
414
415 let threshold = random_threshold(rng, min_val, max_val);
416
417 let mut left_indices = Vec::new();
419 let mut right_indices = Vec::new();
420 for &i in indices {
421 if x[[i, feature]] <= threshold {
422 left_indices.push(i);
423 } else {
424 right_indices.push(i);
425 }
426 }
427
428 if left_indices.is_empty() || right_indices.is_empty() {
430 continue;
431 }
432
433 let node_idx = nodes.len();
435 nodes.push(Node::Leaf {
436 value: F::zero(),
437 class_distribution: None,
438 n_samples: 0,
439 }); let left_child = build_random_tree(
442 x,
443 &left_indices,
444 nodes,
445 depth + 1,
446 max_depth,
447 min_samples_split,
448 n_features,
449 rng,
450 );
451 let right_child = build_random_tree(
452 x,
453 &right_indices,
454 nodes,
455 depth + 1,
456 max_depth,
457 min_samples_split,
458 n_features,
459 rng,
460 );
461
462 nodes[node_idx] = Node::Split {
463 feature,
464 threshold,
465 left: left_child,
466 right: right_child,
467 impurity_decrease: F::zero(),
468 n_samples: n,
469 };
470
471 return node_idx;
472 }
473
474 let idx = nodes.len();
476 nodes.push(Node::Leaf {
477 value: F::zero(),
478 class_distribution: None,
479 n_samples: n,
480 });
481 idx
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Eight samples with three features; row `i` is `[i + 1, i + 2, i + 3]`.
    fn make_data() -> Array2<f64> {
        Array2::from_shape_fn((8, 3), |(row, col)| (row + col + 1) as f64)
    }

    /// Configuration seeded with `42` for the given tree count and depth limit.
    fn seeded(n_estimators: usize, max_depth: Option<usize>) -> RandomTreesEmbedding<f64> {
        RandomTreesEmbedding::new()
            .with_n_estimators(n_estimators)
            .with_max_depth(max_depth)
            .with_random_state(42)
    }

    #[test]
    fn test_default() {
        let RandomTreesEmbedding {
            n_estimators,
            max_depth,
            min_samples_split,
            random_state,
            ..
        } = RandomTreesEmbedding::<f64>::new();

        assert_eq!(n_estimators, 10);
        assert_eq!(max_depth, Some(5));
        assert_eq!(min_samples_split, 2);
        assert!(random_state.is_none());
    }

    #[test]
    fn test_builder() {
        let base = RandomTreesEmbedding::<f64>::new();
        let model = base
            .with_n_estimators(20)
            .with_max_depth(Some(3))
            .with_min_samples_split(5)
            .with_random_state(42);

        assert_eq!(model.n_estimators, 20);
        assert_eq!(model.max_depth, Some(3));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.random_state, Some(42));
    }

    #[test]
    fn test_fit_transform_basic() {
        let x = make_data();
        let fitted = seeded(5, Some(3)).fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.nrows(), 8);
        // Each of the 5 trees contributes exactly one active leaf per sample.
        for (i, row) in embedded.outer_iter().enumerate() {
            let row_sum: f64 = row.iter().sum();
            assert!(
                (row_sum - 5.0).abs() < 1e-10,
                "row {i} should have exactly 5 ones, got {row_sum}"
            );
        }
    }

    #[test]
    fn test_output_is_binary() {
        let x = make_data();
        let fitted = seeded(3, Some(2)).fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        embedded.iter().for_each(|&val| {
            let is_zero = (val - 0.0).abs() < 1e-10;
            let is_one = (val - 1.0).abs() < 1e-10;
            assert!(is_zero || is_one, "values should be 0 or 1, got {val}");
        });
    }

    #[test]
    fn test_total_leaves_matches_output_cols() {
        let x = make_data();
        let fitted = seeded(5, Some(3)).fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.ncols(), fitted.n_output_features());
    }

    #[test]
    fn test_empty_input_error() {
        let empty = Array2::<f64>::zeros((0, 3));
        let outcome = RandomTreesEmbedding::<f64>::new().fit(&empty, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_zero_estimators_error() {
        let x = make_data();
        let outcome = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(0)
            .fit(&x, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_invalid_min_samples_split_error() {
        let x = make_data();
        let outcome = RandomTreesEmbedding::<f64>::new()
            .with_min_samples_split(1)
            .fit(&x, &());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x_train = make_data();
        // Depth limit `Some(5)` is the default the builder would have kept.
        let fitted = seeded(3, Some(5)).fit(&x_train, &()).unwrap();

        // Ten columns instead of the three seen during fitting.
        let x_test = Array2::<f64>::zeros((5, 10));
        assert!(fitted.transform(&x_test).is_err());
    }

    #[test]
    fn test_reproducibility() {
        let x = make_data();
        let model = seeded(5, Some(3));
        let embed = || model.fit(&x, &()).unwrap().transform(&x).unwrap();

        let embedded1 = embed();
        let embedded2 = embed();

        assert_eq!(embedded1, embedded2);
    }

    #[test]
    fn test_f32() {
        let values: Vec<f32> = vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0];
        let x = Array2::<f32>::from_shape_vec((6, 2), values).unwrap();

        let model = RandomTreesEmbedding::<f32>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(2))
            .with_random_state(42);
        let embedded = model.fit(&x, &()).unwrap().transform(&x).unwrap();

        assert_eq!(embedded.nrows(), 6);
    }

    #[test]
    fn test_fitted_accessors() {
        let x = make_data();
        let fitted = seeded(5, Some(3)).fit(&x, &()).unwrap();

        assert_eq!(fitted.n_estimators(), 5);
        assert_eq!(fitted.n_features(), 3);
        assert!(fitted.n_output_features() > 0);
    }

    #[test]
    fn test_deeper_trees_more_leaves() {
        let x = make_data();
        let leaves_at = |depth: usize| {
            seeded(1, Some(depth))
                .fit(&x, &())
                .unwrap()
                .n_output_features()
        };

        let shallow_leaves = leaves_at(1);
        let deep_leaves = leaves_at(5);

        assert!(
            deep_leaves >= shallow_leaves,
            "deeper trees should have at least as many leaves"
        );
    }

    #[test]
    fn test_single_sample() {
        let x = Array2::<f64>::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let fitted = seeded(3, Some(3)).fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.nrows(), 1);
        // One sample cannot be split, so each of the 3 trees is a single leaf.
        assert_eq!(embedded.ncols(), 3);
    }

    #[test]
    fn test_unlimited_depth() {
        let x = make_data();
        let fitted = seeded(3, None).fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.nrows(), 8);
        assert!(embedded.ncols() > 0);
    }
}