// ferrolearn_tree/random_trees_embedding.rs
//! Random trees embedding for unsupervised feature transformation.
//!
//! This module provides [`RandomTreesEmbedding`], which transforms input
//! features into a high-dimensional sparse binary representation by encoding
//! each sample as the concatenation of one-hot encoded leaf indices across
//! an ensemble of randomly built trees.
//!
//! Trees are built with purely random splits (random feature, random threshold
//! between min and max), ignoring any target variable. This makes the
//! embedding entirely unsupervised.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_tree::RandomTreesEmbedding;
//! use ferrolearn_core::{Fit, Transform};
//! use ndarray::Array2;
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 2.0,  2.0, 3.0,  3.0, 3.0,
//!     5.0, 6.0,  6.0, 7.0,  7.0, 8.0,
//! ]).unwrap();
//!
//! let model = RandomTreesEmbedding::<f64>::new()
//!     .with_n_estimators(5)
//!     .with_max_depth(Some(3))
//!     .with_random_state(42);
//! let fitted = model.fit(&x, &()).unwrap();
//! let embedded = fitted.transform(&x).unwrap();
//! // Output has n_samples rows and (total_leaves_across_trees) columns.
//! assert_eq!(embedded.nrows(), 6);
//! ```

use ferrolearn_core::error::FerroError;
use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
use ferrolearn_core::traits::{Fit, Transform};
use ndarray::{Array2, ArrayView1};
use num_traits::Float;
use rand::SeedableRng;
use rand::rngs::StdRng;
use serde::{Deserialize, Serialize};

use crate::decision_tree::Node;
45// ---------------------------------------------------------------------------
46// RandomTreesEmbedding
47// ---------------------------------------------------------------------------
48
/// Random trees embedding for unsupervised feature transformation.
///
/// Builds an ensemble of randomly split trees (ignoring targets) and
/// represents each sample as a one-hot encoding of its leaf index in
/// each tree, concatenated across all trees.
///
/// This is useful for creating a nonlinear feature representation that
/// can be fed into linear models.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomTreesEmbedding<F> {
    /// Number of random trees to build.
    pub n_estimators: usize,
    /// Maximum depth of each tree. `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split a node.
    pub min_samples_split: usize,
    /// Random seed for reproducibility. `None` means non-deterministic.
    pub random_state: Option<u64>,
    // Ties the struct to `F` without storing a value (no data of type F
    // is held before fitting).
    _marker: std::marker::PhantomData<F>,
}
73
74impl<F: Float> RandomTreesEmbedding<F> {
75    /// Create a new `RandomTreesEmbedding` with default settings.
76    ///
77    /// Defaults: `n_estimators = 10`, `max_depth = Some(5)`,
78    /// `min_samples_split = 2`, `random_state = None`.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            n_estimators: 10,
83            max_depth: Some(5),
84            min_samples_split: 2,
85            random_state: None,
86            _marker: std::marker::PhantomData,
87        }
88    }
89
90    /// Set the number of random trees.
91    #[must_use]
92    pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
93        self.n_estimators = n_estimators;
94        self
95    }
96
97    /// Set the maximum tree depth.
98    #[must_use]
99    pub fn with_max_depth(mut self, max_depth: Option<usize>) -> Self {
100        self.max_depth = max_depth;
101        self
102    }
103
104    /// Set the minimum number of samples required to split a node.
105    #[must_use]
106    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
107        self.min_samples_split = min_samples_split;
108        self
109    }
110
111    /// Set the random seed for reproducibility.
112    #[must_use]
113    pub fn with_random_state(mut self, seed: u64) -> Self {
114        self.random_state = Some(seed);
115        self
116    }
117}
118
impl<F: Float> Default for RandomTreesEmbedding<F> {
    /// Same configuration as [`RandomTreesEmbedding::new`].
    fn default() -> Self {
        Self::new()
    }
}
124
125// ---------------------------------------------------------------------------
126// FittedRandomTreesEmbedding
127// ---------------------------------------------------------------------------
128
/// A fitted random trees embedding.
///
/// Stores the ensemble of randomly built trees. Each tree's leaves are
/// enumerated, and [`Transform`] produces a one-hot encoded matrix
/// of shape `(n_samples, total_leaves)`.
#[derive(Debug, Clone)]
pub struct FittedRandomTreesEmbedding<F> {
    /// Individual trees, each stored as a flat node vector.
    trees: Vec<Vec<Node<F>>>,
    /// Per-tree leaf count (number of leaves in each tree).
    leaf_counts: Vec<usize>,
    /// Per-tree mapping from leaf node index to enumerated leaf position.
    /// `leaf_maps[t][node_idx]` gives the leaf's position index within tree `t`;
    /// the entry is `None` for split (non-leaf) nodes.
    leaf_maps: Vec<Vec<Option<usize>>>,
    /// Total number of leaves across all trees (output dimensionality).
    total_leaves: usize,
    /// Number of features the model was trained on.
    n_features: usize,
}
148
149impl<F: Float + Send + Sync + 'static> FittedRandomTreesEmbedding<F> {
150    /// Returns the number of trees in the ensemble.
151    #[must_use]
152    pub fn n_estimators(&self) -> usize {
153        self.trees.len()
154    }
155
156    /// Returns the number of features the model was trained on.
157    #[must_use]
158    pub fn n_features(&self) -> usize {
159        self.n_features
160    }
161
162    /// Returns the total number of output features (total leaves across all trees).
163    #[must_use]
164    pub fn n_output_features(&self) -> usize {
165        self.total_leaves
166    }
167}
168
169impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for RandomTreesEmbedding<F> {
170    type Fitted = FittedRandomTreesEmbedding<F>;
171    type Error = FerroError;
172
173    /// Fit the random trees embedding on the training data.
174    ///
175    /// Builds an ensemble of trees with purely random splits (random feature,
176    /// random threshold between feature min and max in the current node),
177    /// ignoring any target variable.
178    ///
179    /// # Errors
180    ///
181    /// Returns [`FerroError::InsufficientSamples`] if there are no samples.
182    /// Returns [`FerroError::InvalidParameter`] if hyperparameters are invalid.
183    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedRandomTreesEmbedding<F>, FerroError> {
184        let (n_samples, n_features) = x.dim();
185
186        if n_samples == 0 {
187            return Err(FerroError::InsufficientSamples {
188                required: 1,
189                actual: 0,
190                context: "RandomTreesEmbedding requires at least one sample".into(),
191            });
192        }
193        if self.n_estimators == 0 {
194            return Err(FerroError::InvalidParameter {
195                name: "n_estimators".into(),
196                reason: "must be at least 1".into(),
197            });
198        }
199        if self.min_samples_split < 2 {
200            return Err(FerroError::InvalidParameter {
201                name: "min_samples_split".into(),
202                reason: "must be at least 2".into(),
203            });
204        }
205
206        let mut rng = if let Some(seed) = self.random_state {
207            StdRng::seed_from_u64(seed)
208        } else {
209            StdRng::from_os_rng()
210        };
211
212        let indices: Vec<usize> = (0..n_samples).collect();
213
214        let mut trees = Vec::with_capacity(self.n_estimators);
215        let mut leaf_counts = Vec::with_capacity(self.n_estimators);
216        let mut leaf_maps = Vec::with_capacity(self.n_estimators);
217        let mut total_leaves = 0;
218
219        for _ in 0..self.n_estimators {
220            let mut nodes = Vec::new();
221            build_random_tree(
222                x,
223                &indices,
224                &mut nodes,
225                0,
226                self.max_depth,
227                self.min_samples_split,
228                n_features,
229                &mut rng,
230            );
231
232            // Enumerate leaf nodes.
233            let mut leaf_map = vec![None; nodes.len()];
234            let mut count = 0;
235            for (idx, node) in nodes.iter().enumerate() {
236                if matches!(node, Node::Leaf { .. }) {
237                    leaf_map[idx] = Some(count);
238                    count += 1;
239                }
240            }
241
242            trees.push(nodes);
243            leaf_counts.push(count);
244            leaf_maps.push(leaf_map);
245            total_leaves += count;
246        }
247
248        Ok(FittedRandomTreesEmbedding {
249            trees,
250            leaf_counts,
251            leaf_maps,
252            total_leaves,
253            n_features,
254        })
255    }
256}
257
258impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedRandomTreesEmbedding<F> {
259    type Output = Array2<F>;
260    type Error = FerroError;
261
262    /// Transform the input data into a one-hot encoded leaf-index representation.
263    ///
264    /// For each sample, traverse each tree to a leaf node, then one-hot encode
265    /// the leaf index within that tree. The encodings for all trees are
266    /// concatenated horizontally to produce the output.
267    ///
268    /// Output shape: `(n_samples, total_leaves_across_all_trees)`.
269    ///
270    /// # Errors
271    ///
272    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
273    /// not match the training data.
274    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
275        if x.ncols() != self.n_features {
276            return Err(FerroError::ShapeMismatch {
277                expected: vec![self.n_features],
278                actual: vec![x.ncols()],
279                context: "number of features must match fitted model".into(),
280            });
281        }
282
283        let n_samples = x.nrows();
284        let mut output = Array2::zeros((n_samples, self.total_leaves));
285
286        let mut col_offset = 0;
287        for (tree_idx, tree_nodes) in self.trees.iter().enumerate() {
288            let leaf_map = &self.leaf_maps[tree_idx];
289            let n_leaves = self.leaf_counts[tree_idx];
290
291            for i in 0..n_samples {
292                let row = x.row(i);
293                let leaf_node_idx = traverse_tree(tree_nodes, &row);
294                if let Some(leaf_pos) = leaf_map[leaf_node_idx] {
295                    output[[i, col_offset + leaf_pos]] = F::one();
296                }
297            }
298
299            col_offset += n_leaves;
300        }
301
302        Ok(output)
303    }
304}
305
306// Pipeline integration.
307impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for RandomTreesEmbedding<F> {
308    fn fit_pipeline(
309        &self,
310        x: &Array2<F>,
311        _y: &ndarray::Array1<F>,
312    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
313        let fitted = self.fit(x, &())?;
314        Ok(Box::new(fitted))
315    }
316}
317
318impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
319    for FittedRandomTreesEmbedding<F>
320{
321    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
322        self.transform(x)
323    }
324}
325
326// ---------------------------------------------------------------------------
327// Internal: random tree building (unsupervised)
328// ---------------------------------------------------------------------------
329
330/// Traverse a tree from root to leaf for a single sample, returning the leaf node index.
331fn traverse_tree<F: Float>(nodes: &[Node<F>], sample: &ArrayView1<F>) -> usize {
332    let mut idx = 0;
333    loop {
334        match &nodes[idx] {
335            Node::Split {
336                feature,
337                threshold,
338                left,
339                right,
340                ..
341            } => {
342                if sample[*feature] <= *threshold {
343                    idx = *left;
344                } else {
345                    idx = *right;
346                }
347            }
348            Node::Leaf { .. } => return idx,
349        }
350    }
351}
352
353/// Generate a uniform random float in `[min_val, max_val]`.
354fn random_threshold<F: Float>(rng: &mut StdRng, min_val: F, max_val: F) -> F {
355    use rand::RngCore;
356    let u = (rng.next_u64() as f64) / (u64::MAX as f64);
357    let range = max_val - min_val;
358    min_val + F::from(u).unwrap() * range
359}
360
361/// Build a random tree recursively with purely random splits (unsupervised).
362///
363/// At each node, a random feature is chosen and a random threshold is drawn
364/// uniformly between the feature's min and max in the current node.
365#[allow(clippy::too_many_arguments)]
366fn build_random_tree<F: Float>(
367    x: &Array2<F>,
368    indices: &[usize],
369    nodes: &mut Vec<Node<F>>,
370    depth: usize,
371    max_depth: Option<usize>,
372    min_samples_split: usize,
373    n_features: usize,
374    rng: &mut StdRng,
375) -> usize {
376    let n = indices.len();
377
378    // Stop if: too few samples, or max depth reached.
379    let should_stop =
380        n < min_samples_split || max_depth.is_some_and(|d| depth >= d);
381
382    if should_stop {
383        let idx = nodes.len();
384        nodes.push(Node::Leaf {
385            value: F::zero(),
386            class_distribution: None,
387            n_samples: n,
388        });
389        return idx;
390    }
391
392    // Try random features until we find one that can split.
393    let max_attempts = n_features * 2;
394    for _ in 0..max_attempts {
395        use rand::RngCore;
396        let feature = (rng.next_u64() as usize) % n_features;
397
398        // Find min and max of this feature for the current indices.
399        let mut min_val = x[[indices[0], feature]];
400        let mut max_val = min_val;
401        for &i in &indices[1..] {
402            let v = x[[i, feature]];
403            if v < min_val {
404                min_val = v;
405            }
406            if v > max_val {
407                max_val = v;
408            }
409        }
410
411        // If all values are the same, try another feature.
412        if min_val >= max_val {
413            continue;
414        }
415
416        let threshold = random_threshold(rng, min_val, max_val);
417
418        // Partition indices.
419        let mut left_indices = Vec::new();
420        let mut right_indices = Vec::new();
421        for &i in indices {
422            if x[[i, feature]] <= threshold {
423                left_indices.push(i);
424            } else {
425                right_indices.push(i);
426            }
427        }
428
429        // If the split is degenerate, try again.
430        if left_indices.is_empty() || right_indices.is_empty() {
431            continue;
432        }
433
434        // Reserve a slot for this node.
435        let node_idx = nodes.len();
436        nodes.push(Node::Leaf {
437            value: F::zero(),
438            class_distribution: None,
439            n_samples: 0,
440        }); // placeholder
441
442        let left_child = build_random_tree(
443            x,
444            &left_indices,
445            nodes,
446            depth + 1,
447            max_depth,
448            min_samples_split,
449            n_features,
450            rng,
451        );
452        let right_child = build_random_tree(
453            x,
454            &right_indices,
455            nodes,
456            depth + 1,
457            max_depth,
458            min_samples_split,
459            n_features,
460            rng,
461        );
462
463        nodes[node_idx] = Node::Split {
464            feature,
465            threshold,
466            left: left_child,
467            right: right_child,
468            impurity_decrease: F::zero(),
469            n_samples: n,
470        };
471
472        return node_idx;
473    }
474
475    // Could not find a splittable feature — make this a leaf.
476    let idx = nodes.len();
477    nodes.push(Node::Leaf {
478        value: F::zero(),
479        class_distribution: None,
480        n_samples: n,
481    });
482    idx
483}
484
485// ---------------------------------------------------------------------------
486// Tests
487// ---------------------------------------------------------------------------
488
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    // Shared fixture: 8 samples x 3 features, all features strictly
    // increasing so every feature is splittable.
    fn make_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (8, 3),
            vec![
                1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0, 5.0, 6.0, 7.0, 6.0,
                7.0, 8.0, 7.0, 8.0, 9.0, 8.0, 9.0, 10.0,
            ],
        )
        .unwrap()
    }

    // Defaults documented on `new()` are what `new()` actually produces.
    #[test]
    fn test_default() {
        let model = RandomTreesEmbedding::<f64>::new();
        assert_eq!(model.n_estimators, 10);
        assert_eq!(model.max_depth, Some(5));
        assert_eq!(model.min_samples_split, 2);
        assert!(model.random_state.is_none());
    }

    // Every builder setter stores its value.
    #[test]
    fn test_builder() {
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(20)
            .with_max_depth(Some(3))
            .with_min_samples_split(5)
            .with_random_state(42);
        assert_eq!(model.n_estimators, 20);
        assert_eq!(model.max_depth, Some(3));
        assert_eq!(model.min_samples_split, 5);
        assert_eq!(model.random_state, Some(42));
    }

    // Each sample lands in exactly one leaf per tree, so every row of the
    // embedding sums to n_estimators.
    #[test]
    fn test_fit_transform_basic() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.nrows(), 8);
        // Each row should have exactly n_estimators ones (one per tree).
        for i in 0..8 {
            let row_sum: f64 = embedded.row(i).iter().copied().sum();
            assert!(
                (row_sum - 5.0).abs() < 1e-10,
                "row {i} should have exactly 5 ones, got {row_sum}"
            );
        }
    }

    // The embedding is strictly one-hot: no value other than 0 or 1.
    #[test]
    fn test_output_is_binary() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(2))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        // All values should be 0.0 or 1.0
        for &val in embedded.iter() {
            assert!(
                (val - 0.0).abs() < 1e-10 || (val - 1.0).abs() < 1e-10,
                "values should be 0 or 1, got {val}"
            );
        }
    }

    // The accessor `n_output_features` agrees with the actual output width.
    #[test]
    fn test_total_leaves_matches_output_cols() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();

        assert_eq!(embedded.ncols(), fitted.n_output_features());
    }

    // Zero samples is rejected at fit time.
    #[test]
    fn test_empty_input_error() {
        let x = Array2::<f64>::zeros((0, 3));
        let model = RandomTreesEmbedding::<f64>::new();
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    // n_estimators == 0 is an invalid hyperparameter.
    #[test]
    fn test_zero_estimators_error() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new().with_n_estimators(0);
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    // min_samples_split < 2 is an invalid hyperparameter.
    #[test]
    fn test_invalid_min_samples_split_error() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new().with_min_samples_split(1);
        let result = model.fit(&x, &());
        assert!(result.is_err());
    }

    // Transforming data with a different feature count is rejected.
    #[test]
    fn test_shape_mismatch_error() {
        let x_train = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_random_state(42);
        let fitted = model.fit(&x_train, &()).unwrap();

        let x_test = Array2::<f64>::zeros((5, 10)); // wrong number of features
        let result = fitted.transform(&x_test);
        assert!(result.is_err());
    }

    // A fixed seed makes two fits produce identical embeddings.
    #[test]
    fn test_reproducibility() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);

        let fitted1 = model.fit(&x, &()).unwrap();
        let embedded1 = fitted1.transform(&x).unwrap();

        let fitted2 = model.fit(&x, &()).unwrap();
        let embedded2 = fitted2.transform(&x).unwrap();

        assert_eq!(embedded1, embedded2);
    }

    // The embedding works with f32 as well as f64.
    #[test]
    fn test_f32() {
        let x = Array2::<f32>::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0],
        )
        .unwrap();
        let model = RandomTreesEmbedding::<f32>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(2))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 6);
    }

    // Fitted accessors report the configuration used at fit time.
    #[test]
    fn test_fitted_accessors() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(5)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_estimators(), 5);
        assert_eq!(fitted.n_features(), 3);
        assert!(fitted.n_output_features() > 0);
    }

    // A larger depth cap never yields fewer leaves than a smaller one
    // (same seed, same data).
    #[test]
    fn test_deeper_trees_more_leaves() {
        let x = make_data();

        let shallow = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(1)
            .with_max_depth(Some(1))
            .with_random_state(42);
        let fitted_shallow = shallow.fit(&x, &()).unwrap();

        let deep = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(1)
            .with_max_depth(Some(5))
            .with_random_state(42);
        let fitted_deep = deep.fit(&x, &()).unwrap();

        assert!(
            fitted_deep.n_output_features() >= fitted_shallow.n_output_features(),
            "deeper trees should have at least as many leaves"
        );
    }

    // One sample is below min_samples_split, so every tree is a single leaf.
    #[test]
    fn test_single_sample() {
        let x = Array2::<f64>::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(Some(3))
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 1);
        // Single sample can't be split, so each tree has exactly 1 leaf.
        assert_eq!(embedded.ncols(), 3);
    }

    // max_depth = None (unlimited) still terminates via min_samples_split.
    #[test]
    fn test_unlimited_depth() {
        let x = make_data();
        let model = RandomTreesEmbedding::<f64>::new()
            .with_n_estimators(3)
            .with_max_depth(None)
            .with_random_state(42);
        let fitted = model.fit(&x, &()).unwrap();
        let embedded = fitted.transform(&x).unwrap();
        assert_eq!(embedded.nrows(), 8);
        assert!(embedded.ncols() > 0);
    }
}
711}