Skip to main content

ferrolearn_tree/
random_trees_embedding.rs

//! Random trees embedding for unsupervised feature transformation.
//!
//! This module provides [`RandomTreesEmbedding`], which transforms input
//! features into a high-dimensional sparse binary representation by encoding
//! each sample as the concatenation of one-hot encoded leaf indices across
//! an ensemble of randomly built trees.
//!
//! Trees are built with purely random splits (random feature, random threshold
//! between min and max), ignoring any target variable. This makes the
//! embedding entirely unsupervised.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::RandomTreesEmbedding;
//! use ferrolearn_core::{Fit, Transform};
//! use ndarray::Array2;
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
//! ]).unwrap();
//!
//! let model = RandomTreesEmbedding::<f64>::new()
//!     .with_n_estimators(5)
//!     .with_max_depth(Some(3))
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &()).unwrap();
//! let embedded = fitted.transform(&x).unwrap();
//! // Output has n_samples rows and (total_leaves_across_trees) columns.
//! assert_eq!(embedded.nrows(), 6);
//! ```

34use ferrolearn_core::error::FerroError;
35use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
36use ferrolearn_core::traits::{Fit, Transform};
37use ndarray::{Array2, ArrayView1};
38use num_traits::Float;
39use rand::SeedableRng;
40use rand::rngs::StdRng;
41use serde::{Deserialize, Serialize};
42
43use crate::decision_tree::Node;
44
45// ---------------------------------------------------------------------------
46// RandomTreesEmbedding
47// ---------------------------------------------------------------------------
48
/// Random trees embedding for unsupervised feature transformation.
///
/// Builds an ensemble of randomly split trees (ignoring targets) and
/// represents each sample as a one-hot encoding of its leaf index in
/// each tree, concatenated across all trees.
///
/// This is useful for creating a nonlinear feature representation that
/// can be fed into linear models.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomTreesEmbedding<F> {
    /// Number of random trees to build.
    pub n_estimators: usize,
    /// Maximum depth of each tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split a node.
    pub min_samples_split: usize,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    // Zero-sized marker tying the float type `F` (used only at fit time)
    // to the configuration struct.
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float> RandomTreesEmbedding<F> {
75    /// Create a new `RandomTreesEmbedding` with default settings.
76    ///
77    /// Defaults: `n_estimators = 10`, `max_depth = Some(5)`,
78    /// `min_samples_split = 2`, `random_state = None`.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            n_estimators: 10,
83            max_depth: Some(5),
84            min_samples_split: 2,
85            random_state: None,
86            _marker: std::marker::PhantomData,
87        }
88    }
89
90    /// Set the number of random trees.
91    #[must_use]
92    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
93        self.n_estimators = n_estimators;
94        self
95    }
96
97    /// Set the maximum tree depth.
98    #[must_use]
99    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
100        self.max_depth = max_depth;
101        self
102    }
103
104    /// Set the minimum number of samples required to split a node.
105    #[must_use]
106    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
107        self.min_samples_split = min_samples_split;
108        self
109    }
110
111    /// Set the random seed for reproducibility.
112    #[must_use]
113    pub fn with_random_state(mut self, seed: u64) -> Self {
114        self.random_state = Some(seed);
115        self
116    }
117}
118
119impl<F: Float> Default for RandomTreesEmbedding<F> {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125// ---------------------------------------------------------------------------
126// FittedRandomTreesEmbedding
127// ---------------------------------------------------------------------------
128
/// A fitted random trees embedding.
///
/// Stores the ensemble of randomly built trees. Each tree's leaves are
/// enumerated, and [`Transform`] produces a one-hot encoded matrix
/// of shape `(n_samples, total_leaves)`.
///
/// Invariant: `total_leaves` equals the sum of `leaf_counts`, and
/// `leaf_maps[t]` has one entry per node of `trees[t]`.
#[derive(Debug, Clone)]
pub struct FittedRandomTreesEmbedding<F> {
    /// Individual trees, each stored as a flat node vector.
    trees: Vec<Vec<Node<F>>>,
    /// Per-tree leaf count (number of leaves in each tree).
    leaf_counts: Vec<usize>,
    /// Per-tree mapping from leaf node index to enumerated leaf position.
    /// `leaf_maps[t][node_idx]` gives the leaf's position index within tree `t`;
    /// `None` for internal (split) nodes.
    leaf_maps: Vec<Vec<Option<usize>>>,
    /// Total number of leaves across all trees (output dimensionality).
    total_leaves: usize,
    /// Number of features the model was trained on.
    n_features: usize,
}
148
149impl<F: Float + Send + Sync + 'static> FittedRandomTreesEmbedding<F> {
150    /// Returns the number of trees in the ensemble.
151    #[must_use]
152    pub fn n_estimators(&self) -> usize {
153        self.trees.len()
154    }
155
156    /// Returns the number of features the model was trained on.
157    #[must_use]
158    pub fn n_features(&self) -> usize {
159        self.n_features
160    }
161
162    /// Returns the total number of output features (total leaves across all trees).
163    #[must_use]
164    pub fn n_output_features(&self) -> usize {
165        self.total_leaves
166    }
167}
168
169impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for RandomTreesEmbedding<F> {
170    type Fitted = FittedRandomTreesEmbedding<F>;
171    type Error = FerroError;
172
173    /// Fit the random trees embedding on the training data.
174    ///
175    /// Builds an ensemble of trees with purely random splits (random feature,
176    /// random threshold between feature min and max in the current node),
177    /// ignoring any target variable.
178    ///
179    /// # Errors
180    ///
181    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
182    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
183    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedRandomTreesEmbedding<F>, FerroError> {
184        let (n_samples, n_features) = x.dim();
185
186        if n_samples == 0 {
187            return Err(FerroError::InsufficientSamples {
188                required: 1,
189                actual: 0,
190                context: "RandomTreesEmbedding requires at least one sample".into(),
191            });
192        }
193        if self.n_estimators == 0 {
194            return Err(FerroError::InvalidParameter {
195                name: "n_estimators".into(),
196                reason: "must be at least 1".into(),
197            });
198        }
199        if self.min_samples_split < 2 {
200            return Err(FerroError::InvalidParameter {
201                name: "min_samples_split".into(),
202                reason: "must be at least 2".into(),
203            });
204        }
205
206        let mut rng = if let Some(seed) = self.random_state {
207            StdRng::seed_from_u64(seed)
208        } else {
209            StdRng::from_os_rng()
210        };
211
212        let indices: Vec<usize> = (0..n_samples).collect();
213
214        let mut trees = Vec::with_capacity(self.n_estimators);
215        let mut leaf_counts = Vec::with_capacity(self.n_estimators);
216        let mut leaf_maps = Vec::with_capacity(self.n_estimators);
217        let mut total_leaves = 0;
218
219        for _ in 0..self.n_estimators {
220            let mut nodes = Vec::new();
221            build_random_tree(
222                x,
223                &indices,
224                &mut nodes,
225                0,
226                self.max_depth,
227                self.min_samples_split,
228                n_features,
229                &mut rng,
230            );
231
232            // Enumerate leaf nodes.
233            let mut leaf_map = vec![None; nodes.len()];
234            let mut count = 0;
235            for (idx, node) in nodes.iter().enumerate() {
236                if matches!(node, Node::Leaf { .. }) {
237                    leaf_map[idx] = Some(count);
238                    count += 1;
239                }
240            }
241
242            trees.push(nodes);
243            leaf_counts.push(count);
244            leaf_maps.push(leaf_map);
245            total_leaves += count;
246        }
247
248        Ok(FittedRandomTreesEmbedding {
249            trees,
250            leaf_counts,
251            leaf_maps,
252            total_leaves,
253            n_features,
254        })
255    }
256}
257
258impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedRandomTreesEmbedding<F> {
259    type Output = Array2<F>;
260    type Error = FerroError;
261
262    /// Transform the input data into a one-hot encoded leaf-index representation.
263    ///
264    /// For each sample, traverse each tree to a leaf node, then one-hot encode
265    /// the leaf index within that tree. The encodings for all trees are
266    /// concatenated horizontally to produce the output.
267    ///
268    /// Output shape: `(n_samples, total_leaves_across_all_trees)`.
269    ///
270    /// # Errors
271    ///
272    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
273    /// not match the training data.
274    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
275        if x.ncols() != self.n_features {
276            return Err(FerroError::ShapeMismatch {
277                expected: vec![self.n_features],
278                actual: vec![x.ncols()],
279                context: "number of features must match fitted model".into(),
280            });
281        }
282
283        let n_samples = x.nrows();
284        let mut output = Array2::zeros((n_samples, self.total_leaves));
285
286        let mut col_offset = 0;
287        for (tree_idx, tree_nodes) in self.trees.iter().enumerate() {
288            let leaf_map = &self.leaf_maps[tree_idx];
289            let n_leaves = self.leaf_counts[tree_idx];
290
291            for i in 0..n_samples {
292                let row = x.row(i);
293                let leaf_node_idx = traverse_tree(tree_nodes, &row);
294                if let Some(leaf_pos) = leaf_map[leaf_node_idx] {
295                    output[[i, col_offset + leaf_pos]] = F::one();
296                }
297            }
298
299            col_offset += n_leaves;
300        }
301
302        Ok(output)
303    }
304}
305
306// Pipeline integration.
307impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for RandomTreesEmbedding<F> {
308    fn fit_pipeline(
309        &self,
310        x: &Array2<F>,
311        _y: &ndarray::Array1<F>,
312    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
313        let fitted = self.fit(x, &())?;
314        Ok(Box::new(fitted))
315    }
316}
317
318impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
319    for FittedRandomTreesEmbedding<F>
320{
321    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
322        self.transform(x)
323    }
324}
325
326// ---------------------------------------------------------------------------
327// Internal: random tree building (unsupervised)
328// ---------------------------------------------------------------------------
329
330/// Traverse a tree from root to leaf for a single sample, returning the leaf node index.
331fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ArrayView1<F>) -> usize {
332    let mut idx = 0;
333    loop {
334        match &nodes[idx] {
335            Node::Split {
336                feature,
337                threshold,
338                left,
339                right,
340                ..
341            } => {
342                if sample[*feature] <= *threshold {
343                    idx = *left;
344                } else {
345                    idx = *right;
346                }
347            }
348            Node::Leaf { .. } => return idx,
349        }
350    }
351}
352
353/// Generate a uniform random float in `[min_val, max_val]`.
354fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
355    use rand::RngCore;
356    let u = (rng.next_u64() as f64) / (u64::MAX as f64);
357    let range = max_val - min_val;
358    min_val + F::from(u).unwrap() * range
359}
360
361/// Build a random tree recursively with purely random splits (unsupervised).
362///
363/// At each node, a random feature is chosen and a random threshold is drawn
364/// uniformly between the feature's min and max in the current node.
365#[allow(clippy::too_many_arguments)]
366fn build_random_tree<F: Float>(
367    x: &Array2<F>,
368    indices: &[usize],
369    nodes: &mut Vec<Node<F>>,
370    depth: usize,
371    max_depth: Option<usize>,
372    min_samples_split: usize,
373    n_features: usize,
374    rng: &mut StdRng,
375) -> usize {
376    let n = indices.len();
377
378    // Stop if: too few samples, or max depth reached.
379    let should_stop = n < min_samples_split || max_depth.is_some_and(|d| depth >= d);
380
381    if should_stop {
382        let idx = nodes.len();
383        nodes.push(Node::Leaf {
384            value: F::zero(),
385            class_distribution: None,
386            n_samples: n,
387        });
388        return idx;
389    }
390
391    // Try random features until we find one that can split.
392    let max_attempts = n_features * 2;
393    for _ in 0..max_attempts {
394        use rand::RngCore;
395        let feature = (rng.next_u64() as usize) % n_features;
396
397        // Find min and max of this feature for the current indices.
398        let mut min_val = x[[indices[0], feature]];
399        let mut max_val = min_val;
400        for &i in &indices[1..] {
401            let v = x[[i, feature]];
402            if v < min_val {
403                min_val = v;
404            }
405            if v > max_val {
406                max_val = v;
407            }
408        }
409
410        // If all values are the same, try another feature.
411        if min_val >= max_val {
412            continue;
413        }
414
415        let threshold = random_threshold(rng, min_val, max_val);
416
417        // Partition indices.
418        let mut left_indices = Vec::new();
419        let mut right_indices = Vec::new();
420        for &i in indices {
421            if x[[i, feature]] <= threshold {
422                left_indices.push(i);
423            } else {
424                right_indices.push(i);
425            }
426        }
427
428        // If the split is degenerate, try again.
429        if left_indices.is_empty() || right_indices.is_empty() {
430            continue;
431        }
432
433        // Reserve a slot for this node.
434        let node_idx = nodes.len();
435        nodes.push(Node::Leaf {
436            value: F::zero(),
437            class_distribution: None,
438            n_samples: 0,
439        }); // placeholder
440
441        let left_child = build_random_tree(
442            x,
443            &left_indices,
444            nodes,
445            depth + 1,
446            max_depth,
447            min_samples_split,
448            n_features,
449            rng,
450        );
451        let right_child = build_random_tree(
452            x,
453            &right_indices,
454            nodes,
455            depth + 1,
456            max_depth,
457            min_samples_split,
458            n_features,
459            rng,
460        );
461
462        nodes[node_idx] = Node::Split {
463            feature,
464            threshold,
465            left: left_child,
466            right: right_child,
467            impurity_decrease: F::zero(),
468            n_samples: n,
469        };
470
471        return node_idx;
472    }
473
474    // Could not find a splittable feature — make this a leaf.
475    let idx = nodes.len();
476    nodes.push(Node::Leaf {
477        value: F::zero(),
478        class_distribution: None,
479        n_samples: n,
480    });
481    idx
482}
483
484// ---------------------------------------------------------------------------
485// Tests
486// ---------------------------------------------------------------------------
487
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// 8 samples x 3 features; every feature varies across rows, so random
    /// splits always have a non-degenerate range to draw thresholds from.
    fn make_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (8, 3),
            vec![
                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0, 5.0, 6.0, 7.0, 6.0,
                7.0, 8.0, 7.0, 8.0, 9.0, 8.0, 9.0, 10.0,
            ],
        )
        .unwrap()
    }

    // Defaults documented on `RandomTreesEmbedding::new`.
    #[test]
    fn test_default() {
        let model = RandomTreesEmbedding::<f64>::new();
        assert_eq!(model.n_estimators, 10);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 2);
        assert!(model.random_state.is_none());
    }

    // Each builder method overrides exactly its own field.
    #[test]
    fn test_builder() {
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(20)
            .with_max_depth(Some(3))
            .with_min_samples_split(5)
            .with_random_state(42);
        assert_eq!(model.n_estimators, 20);
        assert_eq!(model.max_depth, Some(3));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.random_state, Some(42));
    }

    // Every sample lands in exactly one leaf per tree, so each row of the
    // embedding sums to n_estimators.
    #[test]
    fn test_fit_transform_basic() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.nrows(), 8);
        // Each row should have exactly n_estimators ones (one per tree).
        for i in 0..8 {
            let row_sum: f64 = embedded.row(i).iter().copied().sum();
            assert!(
                (row_sum - 5.0).abs() < 1e-10,
                "row {i} should have exactly 5 ones, got {row_sum}"
            );
        }
    }

    // The embedding is a one-hot encoding: strictly binary entries.
    #[test]
    fn test_output_is_binary() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(2))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        // All values should be 0.0 or 1.0
        for &val in &embedded {
            assert!(
                (val - 0.0).abs() < 1e-10 || (val - 1.0).abs() < 1e-10,
                "values should be 0 or 1, got {val}"
            );
        }
    }

    // Output width must agree with the fitted model's reported dimensionality.
    #[test]
    fn test_total_leaves_matches_output_cols() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.ncols(), fitted.n_output_features());
    }

    // Zero samples -> InsufficientSamples.
    #[test]
    fn test_empty_input_error() {
        let x = Array2::<f64>::zeros((0, 3));
        let model = RandomTreesEmbedding::<f64>::new();
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    // n_estimators == 0 -> InvalidParameter.
    #[test]
    fn test_zero_estimators_error() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new().with_n_estimators(0);
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    // min_samples_split < 2 -> InvalidParameter.
    #[test]
    fn test_invalid_min_samples_split_error() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new().with_min_samples_split(1);
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    // Transforming data with the wrong feature count -> ShapeMismatch.
    #[test]
    fn test_shape_mismatch_error() {
        let x_train = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_random_state(42);
        let fitted = model.fit(&x_train, &()).unwrap();

        let x_test = Array2::<f64>::zeros((5, 10)); // wrong number of features
        let result = fitted.transform(&x_test);
        assert!(result.is_err());
    }

    // A fixed seed must make fit + transform fully deterministic.
    #[test]
    fn test_reproducibility() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);

        let fitted1 = model.fit(&x, &()).unwrap();
        let embedded1 = fitted1.transform(&x).unwrap();

        let fitted2 = model.fit(&x, &()).unwrap();
        let embedded2 = fitted2.transform(&x).unwrap();

        assert_eq!(embedded1, embedded2);
    }

    // Smoke test for the f32 instantiation of the generic pipeline.
    #[test]
    fn test_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let model = RandomTreesEmbedding::<f32>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(2))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 6);
    }

    // Accessors on the fitted model reflect the training configuration.
    #[test]
    fn test_fitted_accessors() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_estimators(), 5);
        assert_eq!(fitted.n_features(), 3);
        assert!(fitted.n_output_features() > 0);
    }

    // A larger depth budget can only grow (never shrink) the leaf count.
    #[test]
    fn test_deeper_trees_more_leaves() {
        let x = make_data();

        let shallow = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(1)
            .with_max_depth(Some(1))
            .with_random_state(42);
        let fitted_shallow = shallow.fit(&x, &()).unwrap();

        let deep = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(1)
            .with_max_depth(Some(5))
            .with_random_state(42);
        let fitted_deep = deep.fit(&x, &()).unwrap();

        assert!(
            fitted_deep.n_output_features() >= fitted_shallow.n_output_features(),
            "deeper trees should have at least as many leaves"
        );
    }

    // One sample is below min_samples_split, so every tree is a single leaf.
    #[test]
    fn test_single_sample() {
        let x = Array2::<f64>::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 1);
        // Single sample can't be split, so each tree has exactly 1 leaf.
        assert_eq!(embedded.ncols(), 3);
    }

    // max_depth = None still terminates (min_samples_split bounds growth).
    #[test]
    fn test_unlimited_depth() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(None)
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 8);
        assert!(embedded.ncols() > 0);
    }
}