ghostflow_ml/
tree.rs

1//! Decision Tree implementations - Real CART algorithm
2
3use ghostflow_core::Tensor;
4
/// Split criterion for decision trees
#[derive(Debug, Clone, Copy)]
pub enum Criterion {
    /// Gini impurity (classification): 1 - Σ p_i²
    Gini,
    /// Shannon entropy (classification): -Σ p_i · ln(p_i) (natural log)
    Entropy,
    /// Mean squared error (regression; the regressor's default)
    MSE,
    /// Mean absolute error
    /// NOTE(review): no visible impurity implementation handles MAE — verify before use
    MAE,
}
13
/// A node in the decision tree.
///
/// Internal nodes carry `feature_index`, `threshold`, and both children;
/// leaves carry a prediction `value` (and `class_probs` for classification).
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Feature column used for the split (internal nodes only)
    pub feature_index: Option<usize>,
    /// Split threshold: samples with feature value <= threshold go left
    pub threshold: Option<f32>,
    /// Left child (feature value <= threshold)
    pub left: Option<Box<TreeNode>>,
    /// Right child (feature value > threshold)
    pub right: Option<Box<TreeNode>>,
    /// Prediction value for leaf nodes (class index or regression mean)
    pub value: Option<f32>,
    /// Class probabilities (for classification)
    pub class_probs: Option<Vec<f32>>,
    /// Number of training samples that reached this node
    pub n_samples: usize,
    /// Impurity at this node (Gini/entropy for classification, variance for regression)
    pub impurity: f32,
}
34
35impl TreeNode {
36    fn leaf(value: f32, n_samples: usize, impurity: f32) -> Self {
37        TreeNode {
38            feature_index: None,
39            threshold: None,
40            left: None,
41            right: None,
42            value: Some(value),
43            class_probs: None,
44            n_samples,
45            impurity,
46        }
47    }
48
49    fn leaf_classification(class_probs: Vec<f32>, n_samples: usize, impurity: f32) -> Self {
50        let value = class_probs.iter()
51            .enumerate()
52            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
53            .map(|(i, _)| i as f32)
54            .unwrap_or(0.0);
55        
56        TreeNode {
57            feature_index: None,
58            threshold: None,
59            left: None,
60            right: None,
61            value: Some(value),
62            class_probs: Some(class_probs),
63            n_samples,
64            impurity,
65        }
66    }
67
68    fn is_leaf(&self) -> bool {
69        self.left.is_none() && self.right.is_none()
70    }
71}
72
/// Decision Tree Classifier using CART algorithm
pub struct DecisionTreeClassifier {
    /// Maximum depth of tree (`None` = grow until the other stopping rules fire)
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to attempt a split
    pub min_samples_split: usize,
    /// Minimum number of samples each child of a split must keep
    pub min_samples_leaf: usize,
    /// Maximum number of features considered per split (`None` = all features)
    pub max_features: Option<usize>,
    /// Split criterion (Gini or Entropy; other variants yield 0 impurity here)
    pub criterion: Criterion,
    /// Number of classes, inferred during `fit` as max(label) + 1
    n_classes: usize,
    /// Root of the fitted tree (`None` until `fit` is called)
    root: Option<TreeNode>,
}
90
91impl DecisionTreeClassifier {
92    pub fn new() -> Self {
93        DecisionTreeClassifier {
94            max_depth: None,
95            min_samples_split: 2,
96            min_samples_leaf: 1,
97            max_features: None,
98            criterion: Criterion::Gini,
99            n_classes: 0,
100            root: None,
101        }
102    }
103
104    pub fn max_depth(mut self, depth: usize) -> Self {
105        self.max_depth = Some(depth);
106        self
107    }
108
109    pub fn min_samples_split(mut self, n: usize) -> Self {
110        self.min_samples_split = n;
111        self
112    }
113
114    pub fn min_samples_leaf(mut self, n: usize) -> Self {
115        self.min_samples_leaf = n;
116        self
117    }
118
119    pub fn criterion(mut self, criterion: Criterion) -> Self {
120        self.criterion = criterion;
121        self
122    }
123
124    /// Fit the decision tree
125    pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
126        let x_data = x.data_f32();
127        let y_data = y.data_f32();
128        let n_samples = x.dims()[0];
129        let n_features = x.dims()[1];
130        
131        // Determine number of classes
132        self.n_classes = y_data.iter()
133            .map(|&v| v as usize)
134            .max()
135            .unwrap_or(0) + 1;
136        
137        let indices: Vec<usize> = (0..n_samples).collect();
138        
139        self.root = Some(self.build_tree(
140            &x_data, &y_data, &indices, n_features, 0
141        ));
142    }
143
144    fn build_tree(
145        &self,
146        x: &[f32],
147        y: &[f32],
148        indices: &[usize],
149        n_features: usize,
150        depth: usize,
151    ) -> TreeNode {
152        let n_samples = indices.len();
153        
154        // Calculate class distribution
155        let mut class_counts = vec![0usize; self.n_classes];
156        for &idx in indices {
157            let class = y[idx] as usize;
158            if class < self.n_classes {
159                class_counts[class] += 1;
160            }
161        }
162        
163        let class_probs: Vec<f32> = class_counts.iter()
164            .map(|&c| c as f32 / n_samples as f32)
165            .collect();
166        
167        let impurity = self.calculate_impurity(&class_probs);
168        
169        // Check stopping conditions
170        let should_stop = 
171            n_samples < self.min_samples_split ||
172            self.max_depth.is_some_and(|d| depth >= d) ||
173            impurity < 1e-7 ||
174            class_counts.iter().filter(|&&c| c > 0).count() <= 1;
175        
176        if should_stop {
177            return TreeNode::leaf_classification(class_probs, n_samples, impurity);
178        }
179        
180        // Find best split
181        let max_features = self.max_features.unwrap_or(n_features);
182        let features_to_try: Vec<usize> = if max_features < n_features {
183            use rand::seq::SliceRandom;
184            let mut rng = rand::thread_rng();
185            let mut all: Vec<usize> = (0..n_features).collect();
186            all.shuffle(&mut rng);
187            all.into_iter().take(max_features).collect()
188        } else {
189            (0..n_features).collect()
190        };
191        
192        let mut best_gain = 0.0f32;
193        let mut best_feature = 0;
194        let mut best_threshold = 0.0f32;
195        let mut best_left_indices = Vec::new();
196        let mut best_right_indices = Vec::new();
197        
198        for &feature in &features_to_try {
199            // Get unique values for this feature
200            let mut values: Vec<f32> = indices.iter()
201                .map(|&idx| x[idx * n_features + feature])
202                .collect();
203            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
204            values.dedup();
205            
206            // Try each threshold
207            for i in 0..values.len().saturating_sub(1) {
208                let threshold = (values[i] + values[i + 1]) / 2.0;
209                
210                let (left_indices, right_indices): (Vec<_>, Vec<_>) = indices.iter()
211                    .partition(|&&idx| x[idx * n_features + feature] <= threshold);
212                
213                if left_indices.len() < self.min_samples_leaf || 
214                   right_indices.len() < self.min_samples_leaf {
215                    continue;
216                }
217                
218                let gain = self.information_gain(
219                    y, indices, &left_indices, &right_indices, impurity
220                );
221                
222                if gain > best_gain {
223                    best_gain = gain;
224                    best_feature = feature;
225                    best_threshold = threshold;
226                    best_left_indices = left_indices;
227                    best_right_indices = right_indices;
228                }
229            }
230        }
231        
232        // If no good split found, make leaf
233        if best_gain <= 0.0 || best_left_indices.is_empty() || best_right_indices.is_empty() {
234            return TreeNode::leaf_classification(class_probs, n_samples, impurity);
235        }
236        
237        // Recursively build children
238        let left = self.build_tree(x, y, &best_left_indices, n_features, depth + 1);
239        let right = self.build_tree(x, y, &best_right_indices, n_features, depth + 1);
240        
241        TreeNode {
242            feature_index: Some(best_feature),
243            threshold: Some(best_threshold),
244            left: Some(Box::new(left)),
245            right: Some(Box::new(right)),
246            value: None,
247            class_probs: Some(class_probs),
248            n_samples,
249            impurity,
250        }
251    }
252
253    fn calculate_impurity(&self, probs: &[f32]) -> f32 {
254        match self.criterion {
255            Criterion::Gini => {
256                1.0 - probs.iter().map(|&p| p * p).sum::<f32>()
257            }
258            Criterion::Entropy => {
259                -probs.iter()
260                    .filter(|&&p| p > 0.0)
261                    .map(|&p| p * p.ln())
262                    .sum::<f32>()
263            }
264            _ => 0.0,
265        }
266    }
267
268    fn information_gain(
269        &self,
270        y: &[f32],
271        parent_indices: &[usize],
272        left_indices: &[usize],
273        right_indices: &[usize],
274        parent_impurity: f32,
275    ) -> f32 {
276        let n_parent = parent_indices.len() as f32;
277        let n_left = left_indices.len() as f32;
278        let n_right = right_indices.len() as f32;
279        
280        let left_probs = self.class_probs_from_indices(y, left_indices);
281        let right_probs = self.class_probs_from_indices(y, right_indices);
282        
283        let left_impurity = self.calculate_impurity(&left_probs);
284        let right_impurity = self.calculate_impurity(&right_probs);
285        
286        parent_impurity - (n_left / n_parent) * left_impurity - (n_right / n_parent) * right_impurity
287    }
288
289    fn class_probs_from_indices(&self, y: &[f32], indices: &[usize]) -> Vec<f32> {
290        let mut counts = vec![0usize; self.n_classes];
291        for &idx in indices {
292            let class = y[idx] as usize;
293            if class < self.n_classes {
294                counts[class] += 1;
295            }
296        }
297        let total = indices.len() as f32;
298        counts.iter().map(|&c| c as f32 / total).collect()
299    }
300
301    /// Predict class labels
302    pub fn predict(&self, x: &Tensor) -> Tensor {
303        let x_data = x.data_f32();
304        let n_samples = x.dims()[0];
305        let n_features = x.dims()[1];
306        
307        let predictions: Vec<f32> = (0..n_samples)
308            .map(|i| {
309                let sample = &x_data[i * n_features..(i + 1) * n_features];
310                self.predict_sample(sample)
311            })
312            .collect();
313        
314        Tensor::from_slice(&predictions, &[n_samples]).unwrap()
315    }
316
317    /// Predict class probabilities
318    pub fn predict_proba(&self, x: &Tensor) -> Tensor {
319        let x_data = x.data_f32();
320        let n_samples = x.dims()[0];
321        let n_features = x.dims()[1];
322        
323        let mut probs = Vec::with_capacity(n_samples * self.n_classes);
324        
325        for i in 0..n_samples {
326            let sample = &x_data[i * n_features..(i + 1) * n_features];
327            let sample_probs = self.predict_proba_sample(sample);
328            probs.extend(sample_probs);
329        }
330        
331        Tensor::from_slice(&probs, &[n_samples, self.n_classes]).unwrap()
332    }
333
334    fn predict_sample(&self, sample: &[f32]) -> f32 {
335        let mut node = self.root.as_ref().unwrap();
336        
337        while !node.is_leaf() {
338            let feature = node.feature_index.unwrap();
339            let threshold = node.threshold.unwrap();
340            
341            if sample[feature] <= threshold {
342                node = node.left.as_ref().unwrap();
343            } else {
344                node = node.right.as_ref().unwrap();
345            }
346        }
347        
348        node.value.unwrap()
349    }
350
351    fn predict_proba_sample(&self, sample: &[f32]) -> Vec<f32> {
352        let mut node = self.root.as_ref().unwrap();
353        
354        while !node.is_leaf() {
355            let feature = node.feature_index.unwrap();
356            let threshold = node.threshold.unwrap();
357            
358            if sample[feature] <= threshold {
359                node = node.left.as_ref().unwrap();
360            } else {
361                node = node.right.as_ref().unwrap();
362            }
363        }
364        
365        node.class_probs.clone().unwrap_or_else(|| vec![0.0; self.n_classes])
366    }
367}
368
369impl Default for DecisionTreeClassifier {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
/// Decision Tree Regressor using CART algorithm
pub struct DecisionTreeRegressor {
    /// Maximum depth of tree (`None` = grow until the other stopping rules fire)
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to attempt a split
    pub min_samples_split: usize,
    /// Minimum number of samples each child of a split must keep
    pub min_samples_leaf: usize,
    /// Maximum features to consider per split
    /// NOTE(review): not consulted by the current `build_tree` — verify intent
    pub max_features: Option<usize>,
    /// Split criterion (defaults to MSE)
    /// NOTE(review): splits are always scored by squared error regardless of this value
    pub criterion: Criterion,
    /// Root of the fitted tree (`None` until `fit` is called)
    root: Option<TreeNode>,
}
384
385impl DecisionTreeRegressor {
386    pub fn new() -> Self {
387        DecisionTreeRegressor {
388            max_depth: None,
389            min_samples_split: 2,
390            min_samples_leaf: 1,
391            max_features: None,
392            criterion: Criterion::MSE,
393            root: None,
394        }
395    }
396
397    pub fn max_depth(mut self, depth: usize) -> Self {
398        self.max_depth = Some(depth);
399        self
400    }
401
402    pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
403        let x_data = x.data_f32();
404        let y_data = y.data_f32();
405        let n_samples = x.dims()[0];
406        let n_features = x.dims()[1];
407        
408        let indices: Vec<usize> = (0..n_samples).collect();
409        
410        self.root = Some(self.build_tree(&x_data, &y_data, &indices, n_features, 0));
411    }
412
413    fn build_tree(
414        &self,
415        x: &[f32],
416        y: &[f32],
417        indices: &[usize],
418        n_features: usize,
419        depth: usize,
420    ) -> TreeNode {
421        let n_samples = indices.len();
422        
423        // Calculate mean and variance
424        let mean: f32 = indices.iter().map(|&i| y[i]).sum::<f32>() / n_samples as f32;
425        let variance: f32 = indices.iter()
426            .map(|&i| (y[i] - mean).powi(2))
427            .sum::<f32>() / n_samples as f32;
428        
429        // Check stopping conditions
430        let should_stop = 
431            n_samples < self.min_samples_split ||
432            self.max_depth.is_some_and(|d| depth >= d) ||
433            variance < 1e-7;
434        
435        if should_stop {
436            return TreeNode::leaf(mean, n_samples, variance);
437        }
438        
439        // Find best split
440        let mut best_mse = f32::INFINITY;
441        let mut best_feature = 0;
442        let mut best_threshold = 0.0f32;
443        let mut best_left_indices = Vec::new();
444        let mut best_right_indices = Vec::new();
445        
446        for feature in 0..n_features {
447            let mut values: Vec<f32> = indices.iter()
448                .map(|&idx| x[idx * n_features + feature])
449                .collect();
450            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
451            values.dedup();
452            
453            for i in 0..values.len().saturating_sub(1) {
454                let threshold = (values[i] + values[i + 1]) / 2.0;
455                
456                let (left_indices, right_indices): (Vec<_>, Vec<_>) = indices.iter()
457                    .partition(|&&idx| x[idx * n_features + feature] <= threshold);
458                
459                if left_indices.len() < self.min_samples_leaf || 
460                   right_indices.len() < self.min_samples_leaf {
461                    continue;
462                }
463                
464                let left_mean: f32 = left_indices.iter().map(|&i| y[i]).sum::<f32>() / left_indices.len() as f32;
465                let right_mean: f32 = right_indices.iter().map(|&i| y[i]).sum::<f32>() / right_indices.len() as f32;
466                
467                let left_mse: f32 = left_indices.iter().map(|&i| {
468                    let diff: f32 = y[i] - left_mean;
469                    diff.powi(2)
470                }).sum::<f32>();
471                let right_mse: f32 = right_indices.iter().map(|&i| {
472                    let diff: f32 = y[i] - right_mean;
473                    diff.powi(2)
474                }).sum::<f32>();
475                let total_mse = left_mse + right_mse;
476                
477                if total_mse < best_mse {
478                    best_mse = total_mse;
479                    best_feature = feature;
480                    best_threshold = threshold;
481                    best_left_indices = left_indices;
482                    best_right_indices = right_indices;
483                }
484            }
485        }
486        
487        if best_left_indices.is_empty() || best_right_indices.is_empty() {
488            return TreeNode::leaf(mean, n_samples, variance);
489        }
490        
491        let left = self.build_tree(x, y, &best_left_indices, n_features, depth + 1);
492        let right = self.build_tree(x, y, &best_right_indices, n_features, depth + 1);
493        
494        TreeNode {
495            feature_index: Some(best_feature),
496            threshold: Some(best_threshold),
497            left: Some(Box::new(left)),
498            right: Some(Box::new(right)),
499            value: Some(mean),
500            class_probs: None,
501            n_samples,
502            impurity: variance,
503        }
504    }
505
506    pub fn predict(&self, x: &Tensor) -> Tensor {
507        let x_data = x.data_f32();
508        let n_samples = x.dims()[0];
509        let n_features = x.dims()[1];
510        
511        let predictions: Vec<f32> = (0..n_samples)
512            .map(|i| {
513                let sample = &x_data[i * n_features..(i + 1) * n_features];
514                self.predict_sample(sample)
515            })
516            .collect();
517        
518        Tensor::from_slice(&predictions, &[n_samples]).unwrap()
519    }
520
521    fn predict_sample(&self, sample: &[f32]) -> f32 {
522        let mut node = self.root.as_ref().unwrap();
523        
524        while !node.is_leaf() {
525            let feature = node.feature_index.unwrap();
526            let threshold = node.threshold.unwrap();
527            
528            if sample[feature] <= threshold {
529                node = node.left.as_ref().unwrap();
530            } else {
531                node = node.right.as_ref().unwrap();
532            }
533        }
534        
535        node.value.unwrap()
536    }
537}
538
539impl Default for DecisionTreeRegressor {
540    fn default() -> Self {
541        Self::new()
542    }
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decision_tree_classifier() {
        // XOR labels. Note: a greedy CART tree cannot split XOR — every
        // single-feature split has zero information gain, so fitting
        // collapses to a single leaf. We therefore only check that the
        // classifier runs and emits valid class labels, not that it
        // recovers the XOR pattern.
        let x = Tensor::from_slice(&[0.0f32, 0.0,
            0.0, 1.0,
            1.0, 0.0,
            1.0, 1.0,
        ], &[4, 2]).unwrap();

        let y = Tensor::from_slice(&[0.0f32, 1.0, 1.0, 0.0], &[4]).unwrap();

        let mut tree = DecisionTreeClassifier::new().max_depth(3);
        tree.fit(&x, &y);

        let predictions = tree.predict(&x);
        let pred_data = predictions.storage().as_slice::<f32>().to_vec();

        assert_eq!(pred_data.len(), 4);
        // Every prediction must be one of the two observed classes.
        assert!(pred_data.iter().all(|&p| p == 0.0 || p == 1.0));
    }

    #[test]
    fn test_decision_tree_regressor() {
        // y = 2x over five distinct inputs. With max_depth(5),
        // min_samples_split = 2, and min_samples_leaf = 1 the tree can
        // isolate every training sample, so it must reproduce the
        // targets exactly (leaf means of single samples are exact).
        let x = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0,
        ], &[5, 1]).unwrap();

        let y = Tensor::from_slice(&[2.0f32, 4.0, 6.0, 8.0, 10.0], &[5]).unwrap();

        let mut tree = DecisionTreeRegressor::new().max_depth(5);
        tree.fit(&x, &y);

        let predictions = tree.predict(&x);
        let pred_data = predictions.storage().as_slice::<f32>().to_vec();

        assert_eq!(pred_data, vec![2.0, 4.0, 6.0, 8.0, 10.0]);
    }
}
587
588