1use ghostflow_core::Tensor;
4
/// Split-quality criterion used when growing a tree.
#[derive(Debug, Clone, Copy)]
pub enum Criterion {
    /// Gini impurity (classification): `1 - Σ p_k²`.
    Gini,
    /// Shannon entropy in nats (classification): `-Σ p_k · ln p_k`.
    Entropy,
    /// Mean squared error (regression).
    MSE,
    /// Mean absolute error (regression).
    /// NOTE(review): not implemented by the impurity functions in this file
    /// (classifier impurity returns 0 for it; the regressor hard-codes MSE) —
    /// selecting it silently degrades. Confirm intended behavior.
    MAE,
}
13
/// A single node of a decision tree.
///
/// Internal nodes carry `feature_index`/`threshold` plus both children;
/// leaves carry a prediction in `value` (and, for classification, the
/// per-class probabilities in `class_probs`).
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Feature tested at this node (`None` for leaves).
    pub feature_index: Option<usize>,
    /// Split threshold: samples with `x[feature] <= threshold` go left.
    pub threshold: Option<f32>,
    /// Left child (`None` for leaves).
    pub left: Option<Box<TreeNode>>,
    /// Right child (`None` for leaves).
    pub right: Option<Box<TreeNode>>,
    /// Prediction stored at this node: class index (classification) or
    /// mean target (regression).
    pub value: Option<f32>,
    /// Per-class probabilities at this node (classification only).
    pub class_probs: Option<Vec<f32>>,
    /// Number of training samples that reached this node.
    pub n_samples: usize,
    /// Impurity (classification) or variance (regression) of the node's samples.
    pub impurity: f32,
}

impl TreeNode {
    /// Builds a regression leaf predicting `value`.
    fn leaf(value: f32, n_samples: usize, impurity: f32) -> Self {
        TreeNode {
            feature_index: None,
            threshold: None,
            left: None,
            right: None,
            value: Some(value),
            class_probs: None,
            n_samples,
            impurity,
        }
    }

    /// Builds a classification leaf; the predicted `value` is the argmax of
    /// `class_probs` (ties resolve to the highest index, since `max_by`
    /// returns the last maximal element).
    fn leaf_classification(class_probs: Vec<f32>, n_samples: usize, impurity: f32) -> Self {
        // total_cmp is NaN-safe; the previous partial_cmp().unwrap() would
        // panic if any probability were NaN (e.g. 0/0 on a degenerate node).
        let value = class_probs.iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i as f32)
            .unwrap_or(0.0);

        TreeNode {
            feature_index: None,
            threshold: None,
            left: None,
            right: None,
            value: Some(value),
            class_probs: Some(class_probs),
            n_samples,
            impurity,
        }
    }

    /// A node is a leaf iff it has no children.
    fn is_leaf(&self) -> bool {
        self.left.is_none() && self.right.is_none()
    }
}
72
/// CART-style decision tree classifier.
///
/// Configure via the builder methods, then call `fit` followed by
/// `predict` / `predict_proba`.
pub struct DecisionTreeClassifier {
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples a node needs to be considered for splitting.
    pub min_samples_split: usize,
    /// Minimum number of samples each child of a split must retain.
    pub min_samples_leaf: usize,
    /// Number of features randomly sampled per split; `None` = all features.
    pub max_features: Option<usize>,
    /// Impurity measure used to score candidate splits.
    pub criterion: Criterion,
    /// Number of classes, inferred from the labels during `fit` as `max(y) + 1`.
    n_classes: usize,
    /// Root of the fitted tree; `None` until `fit` is called.
    root: Option<TreeNode>,
}
90
91impl DecisionTreeClassifier {
92 pub fn new() -> Self {
93 DecisionTreeClassifier {
94 max_depth: None,
95 min_samples_split: 2,
96 min_samples_leaf: 1,
97 max_features: None,
98 criterion: Criterion::Gini,
99 n_classes: 0,
100 root: None,
101 }
102 }
103
104 pub fn max_depth(mut self, depth: usize) -> Self {
105 self.max_depth = Some(depth);
106 self
107 }
108
109 pub fn min_samples_split(mut self, n: usize) -> Self {
110 self.min_samples_split = n;
111 self
112 }
113
114 pub fn min_samples_leaf(mut self, n: usize) -> Self {
115 self.min_samples_leaf = n;
116 self
117 }
118
119 pub fn criterion(mut self, criterion: Criterion) -> Self {
120 self.criterion = criterion;
121 self
122 }
123
124 pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
126 let x_data = x.data_f32();
127 let y_data = y.data_f32();
128 let n_samples = x.dims()[0];
129 let n_features = x.dims()[1];
130
131 self.n_classes = y_data.iter()
133 .map(|&v| v as usize)
134 .max()
135 .unwrap_or(0) + 1;
136
137 let indices: Vec<usize> = (0..n_samples).collect();
138
139 self.root = Some(self.build_tree(
140 &x_data, &y_data, &indices, n_features, 0
141 ));
142 }
143
144 fn build_tree(
145 &self,
146 x: &[f32],
147 y: &[f32],
148 indices: &[usize],
149 n_features: usize,
150 depth: usize,
151 ) -> TreeNode {
152 let n_samples = indices.len();
153
154 let mut class_counts = vec![0usize; self.n_classes];
156 for &idx in indices {
157 let class = y[idx] as usize;
158 if class < self.n_classes {
159 class_counts[class] += 1;
160 }
161 }
162
163 let class_probs: Vec<f32> = class_counts.iter()
164 .map(|&c| c as f32 / n_samples as f32)
165 .collect();
166
167 let impurity = self.calculate_impurity(&class_probs);
168
169 let should_stop =
171 n_samples < self.min_samples_split ||
172 self.max_depth.map_or(false, |d| depth >= d) ||
173 impurity < 1e-7 ||
174 class_counts.iter().filter(|&&c| c > 0).count() <= 1;
175
176 if should_stop {
177 return TreeNode::leaf_classification(class_probs, n_samples, impurity);
178 }
179
180 let max_features = self.max_features.unwrap_or(n_features);
182 let features_to_try: Vec<usize> = if max_features < n_features {
183 use rand::seq::SliceRandom;
184 let mut rng = rand::thread_rng();
185 let mut all: Vec<usize> = (0..n_features).collect();
186 all.shuffle(&mut rng);
187 all.into_iter().take(max_features).collect()
188 } else {
189 (0..n_features).collect()
190 };
191
192 let mut best_gain = 0.0f32;
193 let mut best_feature = 0;
194 let mut best_threshold = 0.0f32;
195 let mut best_left_indices = Vec::new();
196 let mut best_right_indices = Vec::new();
197
198 for &feature in &features_to_try {
199 let mut values: Vec<f32> = indices.iter()
201 .map(|&idx| x[idx * n_features + feature])
202 .collect();
203 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
204 values.dedup();
205
206 for i in 0..values.len().saturating_sub(1) {
208 let threshold = (values[i] + values[i + 1]) / 2.0;
209
210 let (left_indices, right_indices): (Vec<_>, Vec<_>) = indices.iter()
211 .partition(|&&idx| x[idx * n_features + feature] <= threshold);
212
213 if left_indices.len() < self.min_samples_leaf ||
214 right_indices.len() < self.min_samples_leaf {
215 continue;
216 }
217
218 let gain = self.information_gain(
219 y, indices, &left_indices, &right_indices, impurity
220 );
221
222 if gain > best_gain {
223 best_gain = gain;
224 best_feature = feature;
225 best_threshold = threshold;
226 best_left_indices = left_indices;
227 best_right_indices = right_indices;
228 }
229 }
230 }
231
232 if best_gain <= 0.0 || best_left_indices.is_empty() || best_right_indices.is_empty() {
234 return TreeNode::leaf_classification(class_probs, n_samples, impurity);
235 }
236
237 let left = self.build_tree(x, y, &best_left_indices, n_features, depth + 1);
239 let right = self.build_tree(x, y, &best_right_indices, n_features, depth + 1);
240
241 TreeNode {
242 feature_index: Some(best_feature),
243 threshold: Some(best_threshold),
244 left: Some(Box::new(left)),
245 right: Some(Box::new(right)),
246 value: None,
247 class_probs: Some(class_probs),
248 n_samples,
249 impurity,
250 }
251 }
252
253 fn calculate_impurity(&self, probs: &[f32]) -> f32 {
254 match self.criterion {
255 Criterion::Gini => {
256 1.0 - probs.iter().map(|&p| p * p).sum::<f32>()
257 }
258 Criterion::Entropy => {
259 -probs.iter()
260 .filter(|&&p| p > 0.0)
261 .map(|&p| p * p.ln())
262 .sum::<f32>()
263 }
264 _ => 0.0,
265 }
266 }
267
268 fn information_gain(
269 &self,
270 y: &[f32],
271 parent_indices: &[usize],
272 left_indices: &[usize],
273 right_indices: &[usize],
274 parent_impurity: f32,
275 ) -> f32 {
276 let n_parent = parent_indices.len() as f32;
277 let n_left = left_indices.len() as f32;
278 let n_right = right_indices.len() as f32;
279
280 let left_probs = self.class_probs_from_indices(y, left_indices);
281 let right_probs = self.class_probs_from_indices(y, right_indices);
282
283 let left_impurity = self.calculate_impurity(&left_probs);
284 let right_impurity = self.calculate_impurity(&right_probs);
285
286 parent_impurity - (n_left / n_parent) * left_impurity - (n_right / n_parent) * right_impurity
287 }
288
289 fn class_probs_from_indices(&self, y: &[f32], indices: &[usize]) -> Vec<f32> {
290 let mut counts = vec![0usize; self.n_classes];
291 for &idx in indices {
292 let class = y[idx] as usize;
293 if class < self.n_classes {
294 counts[class] += 1;
295 }
296 }
297 let total = indices.len() as f32;
298 counts.iter().map(|&c| c as f32 / total).collect()
299 }
300
301 pub fn predict(&self, x: &Tensor) -> Tensor {
303 let x_data = x.data_f32();
304 let n_samples = x.dims()[0];
305 let n_features = x.dims()[1];
306
307 let predictions: Vec<f32> = (0..n_samples)
308 .map(|i| {
309 let sample = &x_data[i * n_features..(i + 1) * n_features];
310 self.predict_sample(sample)
311 })
312 .collect();
313
314 Tensor::from_slice(&predictions, &[n_samples]).unwrap()
315 }
316
317 pub fn predict_proba(&self, x: &Tensor) -> Tensor {
319 let x_data = x.data_f32();
320 let n_samples = x.dims()[0];
321 let n_features = x.dims()[1];
322
323 let mut probs = Vec::with_capacity(n_samples * self.n_classes);
324
325 for i in 0..n_samples {
326 let sample = &x_data[i * n_features..(i + 1) * n_features];
327 let sample_probs = self.predict_proba_sample(sample);
328 probs.extend(sample_probs);
329 }
330
331 Tensor::from_slice(&probs, &[n_samples, self.n_classes]).unwrap()
332 }
333
334 fn predict_sample(&self, sample: &[f32]) -> f32 {
335 let mut node = self.root.as_ref().unwrap();
336
337 while !node.is_leaf() {
338 let feature = node.feature_index.unwrap();
339 let threshold = node.threshold.unwrap();
340
341 if sample[feature] <= threshold {
342 node = node.left.as_ref().unwrap();
343 } else {
344 node = node.right.as_ref().unwrap();
345 }
346 }
347
348 node.value.unwrap()
349 }
350
351 fn predict_proba_sample(&self, sample: &[f32]) -> Vec<f32> {
352 let mut node = self.root.as_ref().unwrap();
353
354 while !node.is_leaf() {
355 let feature = node.feature_index.unwrap();
356 let threshold = node.threshold.unwrap();
357
358 if sample[feature] <= threshold {
359 node = node.left.as_ref().unwrap();
360 } else {
361 node = node.right.as_ref().unwrap();
362 }
363 }
364
365 node.class_probs.clone().unwrap_or_else(|| vec![0.0; self.n_classes])
366 }
367}
368
369impl Default for DecisionTreeClassifier {
370 fn default() -> Self {
371 Self::new()
372 }
373}
374
/// CART-style decision tree regressor minimizing squared error.
pub struct DecisionTreeRegressor {
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples a node needs to be considered for splitting.
    pub min_samples_split: usize,
    /// Minimum number of samples each child of a split must retain.
    pub min_samples_leaf: usize,
    /// Number of features randomly sampled per split; `None` = all features.
    /// NOTE(review): this field is never read by the regressor's
    /// `build_tree`, which always scans every feature — confirm intent.
    pub max_features: Option<usize>,
    /// Split criterion.
    /// NOTE(review): also never consulted — `build_tree` hard-codes
    /// squared-error splitting regardless of this setting.
    pub criterion: Criterion,
    /// Root of the fitted tree; `None` until `fit` is called.
    root: Option<TreeNode>,
}
384
385impl DecisionTreeRegressor {
386 pub fn new() -> Self {
387 DecisionTreeRegressor {
388 max_depth: None,
389 min_samples_split: 2,
390 min_samples_leaf: 1,
391 max_features: None,
392 criterion: Criterion::MSE,
393 root: None,
394 }
395 }
396
397 pub fn max_depth(mut self, depth: usize) -> Self {
398 self.max_depth = Some(depth);
399 self
400 }
401
402 pub fn fit(&mut self, x: &Tensor, y: &Tensor) {
403 let x_data = x.data_f32();
404 let y_data = y.data_f32();
405 let n_samples = x.dims()[0];
406 let n_features = x.dims()[1];
407
408 let indices: Vec<usize> = (0..n_samples).collect();
409
410 self.root = Some(self.build_tree(&x_data, &y_data, &indices, n_features, 0));
411 }
412
413 fn build_tree(
414 &self,
415 x: &[f32],
416 y: &[f32],
417 indices: &[usize],
418 n_features: usize,
419 depth: usize,
420 ) -> TreeNode {
421 let n_samples = indices.len();
422
423 let mean: f32 = indices.iter().map(|&i| y[i]).sum::<f32>() / n_samples as f32;
425 let variance: f32 = indices.iter()
426 .map(|&i| (y[i] - mean).powi(2))
427 .sum::<f32>() / n_samples as f32;
428
429 let should_stop =
431 n_samples < self.min_samples_split ||
432 self.max_depth.map_or(false, |d| depth >= d) ||
433 variance < 1e-7;
434
435 if should_stop {
436 return TreeNode::leaf(mean, n_samples, variance);
437 }
438
439 let mut best_mse = f32::INFINITY;
441 let mut best_feature = 0;
442 let mut best_threshold = 0.0f32;
443 let mut best_left_indices = Vec::new();
444 let mut best_right_indices = Vec::new();
445
446 for feature in 0..n_features {
447 let mut values: Vec<f32> = indices.iter()
448 .map(|&idx| x[idx * n_features + feature])
449 .collect();
450 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
451 values.dedup();
452
453 for i in 0..values.len().saturating_sub(1) {
454 let threshold = (values[i] + values[i + 1]) / 2.0;
455
456 let (left_indices, right_indices): (Vec<_>, Vec<_>) = indices.iter()
457 .partition(|&&idx| x[idx * n_features + feature] <= threshold);
458
459 if left_indices.len() < self.min_samples_leaf ||
460 right_indices.len() < self.min_samples_leaf {
461 continue;
462 }
463
464 let left_mean: f32 = left_indices.iter().map(|&i| y[i]).sum::<f32>() / left_indices.len() as f32;
465 let right_mean: f32 = right_indices.iter().map(|&i| y[i]).sum::<f32>() / right_indices.len() as f32;
466
467 let left_mse: f32 = left_indices.iter().map(|&i| {
468 let diff: f32 = y[i] - left_mean;
469 diff.powi(2)
470 }).sum::<f32>();
471 let right_mse: f32 = right_indices.iter().map(|&i| {
472 let diff: f32 = y[i] - right_mean;
473 diff.powi(2)
474 }).sum::<f32>();
475 let total_mse = left_mse + right_mse;
476
477 if total_mse < best_mse {
478 best_mse = total_mse;
479 best_feature = feature;
480 best_threshold = threshold;
481 best_left_indices = left_indices;
482 best_right_indices = right_indices;
483 }
484 }
485 }
486
487 if best_left_indices.is_empty() || best_right_indices.is_empty() {
488 return TreeNode::leaf(mean, n_samples, variance);
489 }
490
491 let left = self.build_tree(x, y, &best_left_indices, n_features, depth + 1);
492 let right = self.build_tree(x, y, &best_right_indices, n_features, depth + 1);
493
494 TreeNode {
495 feature_index: Some(best_feature),
496 threshold: Some(best_threshold),
497 left: Some(Box::new(left)),
498 right: Some(Box::new(right)),
499 value: Some(mean),
500 class_probs: None,
501 n_samples,
502 impurity: variance,
503 }
504 }
505
506 pub fn predict(&self, x: &Tensor) -> Tensor {
507 let x_data = x.data_f32();
508 let n_samples = x.dims()[0];
509 let n_features = x.dims()[1];
510
511 let predictions: Vec<f32> = (0..n_samples)
512 .map(|i| {
513 let sample = &x_data[i * n_features..(i + 1) * n_features];
514 self.predict_sample(sample)
515 })
516 .collect();
517
518 Tensor::from_slice(&predictions, &[n_samples]).unwrap()
519 }
520
521 fn predict_sample(&self, sample: &[f32]) -> f32 {
522 let mut node = self.root.as_ref().unwrap();
523
524 while !node.is_leaf() {
525 let feature = node.feature_index.unwrap();
526 let threshold = node.threshold.unwrap();
527
528 if sample[feature] <= threshold {
529 node = node.left.as_ref().unwrap();
530 } else {
531 node = node.right.as_ref().unwrap();
532 }
533 }
534
535 node.value.unwrap()
536 }
537}
538
539impl Default for DecisionTreeRegressor {
540 fn default() -> Self {
541 Self::new()
542 }
543}
544
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decision_tree_classifier() {
        // XOR-style data. Note: no single axis-aligned split of XOR has
        // positive information gain, so the greedy tree may legitimately
        // remain a single leaf — we assert label validity and probability
        // shape rather than exact predictions.
        let x = Tensor::from_slice(&[
            0.0, 0.0,
            0.0, 1.0,
            1.0, 0.0,
            1.0, 1.0,
        ], &[4, 2]).unwrap();

        let y = Tensor::from_slice(&[0.0, 1.0, 1.0, 0.0], &[4]).unwrap();

        let mut tree = DecisionTreeClassifier::new().max_depth(3);
        tree.fit(&x, &y);

        let predictions = tree.predict(&x);
        let pred_data = predictions.data_f32();

        assert_eq!(pred_data.len(), 4);
        // Every prediction must be one of the two observed class labels.
        for &p in pred_data.iter() {
            assert!(p == 0.0 || p == 1.0, "invalid class label {}", p);
        }

        // Probabilities: one row of n_classes = 2 per sample, each summing to 1.
        let proba = tree.predict_proba(&x);
        let proba_data = proba.data_f32();
        assert_eq!(proba_data.len(), 8);
        for row in proba_data.chunks(2) {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "row does not sum to 1: {:?}", row);
        }
    }

    #[test]
    fn test_decision_tree_regressor() {
        // Perfectly separable 1-D target (y = 2x). With depth 5 and
        // min_samples_leaf = 1 the tree can isolate every sample, so it
        // should reproduce the training targets (near-)exactly.
        let x = Tensor::from_slice(&[
            1.0, 2.0, 3.0, 4.0, 5.0,
        ], &[5, 1]).unwrap();

        let y = Tensor::from_slice(&[2.0, 4.0, 6.0, 8.0, 10.0], &[5]).unwrap();

        let mut tree = DecisionTreeRegressor::new().max_depth(5);
        tree.fit(&x, &y);

        let predictions = tree.predict(&x);
        let pred_data = predictions.data_f32();

        assert_eq!(pred_data.len(), 5);
        let expected = [2.0f32, 4.0, 6.0, 8.0, 10.0];
        for (p, e) in pred_data.iter().zip(expected.iter()) {
            assert!((p - e).abs() < 1e-4, "prediction {} != expected {}", p, e);
        }
    }
}
588}