<!-- PCU: ml-fundamentals-decision-trees | contract: contracts/apr-page-ml-fundamentals-decision-trees-v1.yaml -->
<!-- Example: cargo run -p aprender-core --example none -->
<!-- Status: enforced -->

# Decision Trees Theory

<!-- DOC_STATUS_START -->
**Chapter Status**: ✅ 100% Working (All examples verified)

| Status | Count | Examples |
|--------|-------|----------|
| ✅ Working | 30+ | CART algorithm (classification + regression) verified |
| ⏳ In Progress | 0 | - |
| ⬜ Not Implemented | 0 | - |

*Last tested: 2025-11-21*
*Aprender version: 0.4.1*
*Test file: src/tree/mod.rs tests*
<!-- DOC_STATUS_END -->

---

## Overview

Decision trees learn hierarchical decision rules by recursively partitioning the feature space. They're interpretable, handle non-linear relationships, and require no feature scaling.

**Key Concepts**:
- **CART Algorithm**: Classification And Regression Trees
- **Gini Impurity**: Measures node purity (classification)
- **MSE Criterion**: Measures variance (regression)
- **Recursive Splitting**: Build tree top-down, greedy
- **Max Depth**: Controls overfitting

**Why This Matters**:
Decision trees mirror human decision-making: "If feature X > threshold, then..." They're the foundation of powerful ensemble methods (Random Forests, Gradient Boosting). The same algorithm handles both classification (predicting categories) and regression (predicting continuous values).

---

## Mathematical Foundation

### The Decision Tree Structure

A decision tree is a binary tree where:
- **Internal nodes**: Test one feature against a threshold
- **Edges**: Represent test outcomes (≤ threshold, > threshold)
- **Leaves**: Contain class predictions

**Example Tree**:
```text
        [Petal Width ≤ 0.8]
       /                    \
   Class 0           [Petal Length ≤ 4.9]
                    /                    \
               Class 1                 Class 2
```

### Gini Impurity

**Definition**:
```text
Gini(S) = 1 - Σ p_i²

where:
S = set of samples in a node
p_i = proportion of class i in S
```

**Interpretation**:
- **Gini = 0.0**: Pure node (all samples same class)
- **Gini = 0.5**: Maximum impurity (binary, 50/50 split)
- **Gini < 0.5**: Purer than a 50/50 binary mix (with k classes the maximum is 1 − 1/k)

**Why squared?** Squaring the class proportions penalizes mixed distributions more sharply than a linear (misclassification-rate) measure would.
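
The formula is short enough to compute directly. The sketch below is a standalone illustration of the definition over integer class labels; it is not aprender's internal code:

```rust,ignore
/// Gini impurity of a set of class labels: 1 - Σ p_i².
/// Standalone sketch of the formula; not aprender's internal code.
fn gini(labels: &[usize]) -> f32 {
    if labels.is_empty() {
        return 0.0;
    }
    let n = labels.len() as f32;
    let n_classes = labels.iter().max().unwrap() + 1;
    let mut counts = vec![0usize; n_classes];
    for &class in labels {
        counts[class] += 1;
    }
    let sum_sq: f32 = counts.iter().map(|&c| (c as f32 / n).powi(2)).sum();
    1.0 - sum_sq
}

// gini(&[0, 0, 0, 0]) == 0.0    pure node
// gini(&[0, 0, 1, 1]) == 0.5    maximum impurity for a binary split
// gini(&[0, 0, 0, 1]) == 0.375  moderately pure
```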

### Information Gain

When we split a node into left and right children:

```text
InfoGain = Gini(parent) - [w_L * Gini(left) + w_R * Gini(right)]

where:
w_L = n_left / n_total  (weight of left child)
w_R = n_right / n_total (weight of right child)
```

**Goal**: Maximize information gain → find best split
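
A direct translation of this formula, with the child impurities and sample counts passed in as plain numbers (a minimal sketch, not part of the aprender API):

```rust,ignore
/// Information gain of a split, given the parent and child Gini values.
/// Minimal sketch of the formula; not part of the aprender API.
fn info_gain(
    gini_parent: f32,
    gini_left: f32,
    n_left: usize,
    gini_right: f32,
    n_right: usize,
) -> f32 {
    let n_total = (n_left + n_right) as f32;
    let w_l = n_left as f32 / n_total;
    let w_r = n_right as f32 / n_total;
    gini_parent - (w_l * gini_left + w_r * gini_right)
}

// A perfect split of a 50/50 parent into two pure children gains the full 0.5:
// info_gain(0.5, 0.0, 2, 0.0, 2) == 0.5
```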

### CART Algorithm (Classification)

**Recursive Tree Building**:
```text
function BuildTree(X, y, depth, max_depth):
    if stopping_criterion_met:
        return Leaf(majority_class(y))

    best_split = find_best_split(X, y)  # Maximize InfoGain

    if no_valid_split or depth >= max_depth:
        return Leaf(majority_class(y))

    X_left, y_left, X_right, y_right = partition(X, y, best_split)

    return Node(
        feature = best_split.feature,
        threshold = best_split.threshold,
        left = BuildTree(X_left, y_left, depth+1, max_depth),
        right = BuildTree(X_right, y_right, depth+1, max_depth)
    )
```

**Stopping Criteria**:
1. All samples in node have same class (Gini = 0)
2. Reached max_depth
3. Node has too few samples (min_samples_split)
4. No split reduces impurity
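
The greedy split search itself can be sketched in a few lines: scan every feature and every observed value as a candidate threshold, score each candidate with the weighted-Gini information gain, and keep the best. This is an illustration under simple assumptions (row-major `Vec` data, no tie-breaking rules) using the `gini` helper sketched above, not aprender's internal implementation:

```rust,ignore
/// Best (feature, threshold, gain) by information gain over all candidate splits.
/// Illustrative sketch: `x` is row-major with `n_features` columns, `y` holds class labels.
/// Assumes the `gini` helper sketched earlier; not aprender's internal implementation.
fn find_best_split(x: &[f32], y: &[usize], n_features: usize) -> Option<(usize, f32, f32)> {
    let n = y.len();
    let parent_gini = gini(y);
    let mut best: Option<(usize, f32, f32)> = None; // (feature, threshold, gain)

    for feature in 0..n_features {
        for i in 0..n {
            let threshold = x[i * n_features + feature];
            // Partition labels by the candidate threshold.
            let (mut left, mut right) = (Vec::new(), Vec::new());
            for j in 0..n {
                if x[j * n_features + feature] <= threshold {
                    left.push(y[j]);
                } else {
                    right.push(y[j]);
                }
            }
            if left.is_empty() || right.is_empty() {
                continue; // not a valid split
            }
            let w_l = left.len() as f32 / n as f32;
            let w_r = right.len() as f32 / n as f32;
            let gain = parent_gini - (w_l * gini(&left) + w_r * gini(&right));
            if best.map_or(true, |(_, _, g)| gain > g) {
                best = Some((feature, threshold, gain));
            }
        }
    }
    best
}
```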

### CART Algorithm (Regression)

Decision trees also handle **regression** tasks (predicting continuous values) using the same recursive splitting approach, but with different splitting criteria and leaf predictions.

**Key Differences from Classification**:
- **Splitting criterion**: Mean Squared Error (MSE) instead of Gini
- **Leaf prediction**: Mean of target values instead of majority class
- **Evaluation**: R² score instead of accuracy

#### Mean Squared Error (MSE)

**Definition**:
```text
MSE(S) = (1/n) Σ (y_i - ȳ)²

where:
S = set of samples in a node
y_i = target value of sample i
ȳ = mean target value in S
n = number of samples
```

**Equivalent Formulation**:
```text
MSE(S) = Variance(y) = (1/n) Σ (y_i - ȳ)²
```

**Interpretation**:
- **MSE = 0.0**: Pure node (all samples have same target value)
- **High MSE**: High variance in target values
- **Goal**: Minimize weighted MSE after split
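
As with Gini, the node-level MSE is a one-liner over the targets. A minimal standalone sketch of the formula (not aprender's internals):

```rust,ignore
/// MSE of a node = variance of the target values around their mean.
/// Standalone sketch of the formula; not aprender's internal code.
fn node_mse(y: &[f32]) -> f32 {
    if y.is_empty() {
        return 0.0;
    }
    let n = y.len() as f32;
    let mean = y.iter().sum::<f32>() / n;
    y.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n
}

// node_mse(&[5.0, 5.0, 5.0]) == 0.0   pure node (constant target)
// node_mse(&[1.0, 3.0])      == 1.0   mean 2.0, squared deviations 1.0 each
```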

#### Variance Reduction

When splitting a node into left and right children:

```text
VarReduction = MSE(parent) - [w_L * MSE(left) + w_R * MSE(right)]

where:
w_L = n_left / n_total  (weight of left child)
w_R = n_right / n_total (weight of right child)
```

**Goal**: Maximize variance reduction → find best split

**Analogy to Classification**:
- MSE for regression ≈ Gini impurity for classification
- Variance reduction ≈ Information gain
- Both measure "purity" of nodes
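
The analogy carries straight into code: swap Gini for MSE in the gain formula and nothing else changes. A minimal sketch mirroring `info_gain` above (not part of the aprender API):

```rust,ignore
/// Variance reduction of a split, given parent and child MSE values.
/// Structurally identical to `info_gain` above; not part of the aprender API.
fn variance_reduction(
    mse_parent: f32,
    mse_left: f32,
    n_left: usize,
    mse_right: f32,
    n_right: usize,
) -> f32 {
    let n_total = (n_left + n_right) as f32;
    let w_l = n_left as f32 / n_total;
    let w_r = n_right as f32 / n_total;
    mse_parent - (w_l * mse_left + w_r * mse_right)
}
```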

#### Regression Tree Building

**Recursive Algorithm**:
```text
function BuildRegressionTree(X, y, depth, max_depth):
    if stopping_criterion_met:
        return Leaf(mean(y))

    best_split = find_best_split(X, y)  # Maximize VarReduction

    if no_valid_split or depth >= max_depth:
        return Leaf(mean(y))

    X_left, y_left, X_right, y_right = partition(X, y, best_split)

    return Node(
        feature = best_split.feature,
        threshold = best_split.threshold,
        left = BuildRegressionTree(X_left, y_left, depth+1, max_depth),
        right = BuildRegressionTree(X_right, y_right, depth+1, max_depth)
    )
```

**Stopping Criteria**:
1. All samples have same target value (variance = 0)
2. Reached max_depth
3. Node has too few samples (min_samples_split)
4. No split reduces variance

#### MSE vs Gini Criterion Comparison

| Aspect | MSE (Regression) | Gini (Classification) |
|--------|------------------|----------------------|
| **Task** | Continuous prediction | Class prediction |
| **Range** | [0, ∞) | [0, 1] |
| **Pure node** | MSE = 0 (constant target) | Gini = 0 (single class) |
| **Impure node** | High variance | Gini ≈ 0.5 |
| **Split goal** | Maximize variance reduction | Maximize information gain |
| **Leaf prediction** | Mean of y | Majority class |
| **Evaluation** | R² score | Accuracy |

---

## Implementation in Aprender

### Example 1: Simple Binary Classification

```rust,ignore
use aprender::tree::DecisionTreeClassifier;
use aprender::primitives::Matrix;

// XOR-like problem (not linearly separable)
let x = Matrix::from_vec(4, 2, vec![
    0.0, 0.0,  // Class 0
    0.0, 1.0,  // Class 1
    1.0, 0.0,  // Class 1
    1.0, 1.0,  // Class 0
]).unwrap();
let y = vec![0, 1, 1, 0];

// Train decision tree with max depth 3
let mut tree = DecisionTreeClassifier::new()
    .with_max_depth(3);

tree.fit(&x, &y).unwrap();

// Predict on training data (should be perfect)
let predictions = tree.predict(&x);
println!("Predictions: {:?}", predictions); // [0, 1, 1, 0]

let accuracy = tree.score(&x, &y);
println!("Accuracy: {:.3}", accuracy); // 1.000
```

**Test Reference**: `src/tree/mod.rs::tests::test_build_tree_simple_split`

### Example 2: Multi-Class Classification (Iris)

```rust,ignore
// Iris dataset (3 classes, 4 features)
// Simplified example - see case study for full implementation

let mut tree = DecisionTreeClassifier::new()
    .with_max_depth(5);

tree.fit(&x_train, &y_train).unwrap();

// Test set evaluation
let y_pred = tree.predict(&x_test);
let accuracy = tree.score(&x_test, &y_test);
println!("Test Accuracy: {:.3}", accuracy); // e.g., 0.967
```

**Case Study**: See [Decision Tree - Iris Classification](../examples/decision-tree-iris.md)

### Example 3: Regression (Housing Prices)

```rust,ignore
use aprender::tree::DecisionTreeRegressor;
use aprender::primitives::{Matrix, Vector};

// Housing data: [sqft, bedrooms, age]
let x = Matrix::from_vec(8, 3, vec![
    1500.0, 3.0, 10.0,  // $280k
    2000.0, 4.0, 5.0,   // $350k
    1200.0, 2.0, 30.0,  // $180k
    1800.0, 3.0, 15.0,  // $300k
    2500.0, 5.0, 2.0,   // $450k
    1000.0, 2.0, 50.0,  // $150k
    2200.0, 4.0, 8.0,   // $380k
    1600.0, 3.0, 20.0,  // $260k
]).unwrap();

let y = Vector::from_slice(&[
    280.0, 350.0, 180.0, 300.0, 450.0, 150.0, 380.0, 260.0
]);

// Train regression tree
let mut tree = DecisionTreeRegressor::new()
    .with_max_depth(4)
    .with_min_samples_split(2);

tree.fit(&x, &y).unwrap();

// Predict on new house: 1900 sqft, 4 bed, 12 years
let x_new = Matrix::from_vec(1, 3, vec![1900.0, 4.0, 12.0]).unwrap();
let predicted_price = tree.predict(&x_new);
println!("Predicted: ${:.0}k", predicted_price.as_slice()[0]);

// Evaluate with R² score
let r2 = tree.score(&x, &y);
println!("R² Score: {:.3}", r2); // e.g., 0.95+
```

**Key Differences from Classification**:
- Uses `Vector<f32>` for continuous targets (not `Vec<usize>` classes)
- Predictions are continuous values (not class labels)
- Score returns R² instead of accuracy
- MSE criterion splits on variance reduction

**Test Reference**: `src/tree/mod.rs::tests::test_regression_tree_*`

**Case Study**: See [Decision Tree Regression](../examples/decision-tree-regression.md)

### Example 4: Model Serialization

```rust,ignore
// Train and save tree
let mut tree = DecisionTreeClassifier::new()
    .with_max_depth(4);
tree.fit(&x_train, &y_train).unwrap();

tree.save("tree_model.bin").unwrap();

// Load in production
let loaded_tree = DecisionTreeClassifier::load("tree_model.bin").unwrap();
let predictions = loaded_tree.predict(&x_test);
```

**Test Reference**: `src/tree/mod.rs::tests` (save/load tests)

---

## Understanding Gini Impurity

### Example Calculation

**Scenario**: Node with 6 samples: [A, A, A, B, B, C]

```text
Class A: 3/6 = 0.5
Class B: 2/6 = 0.33
Class C: 1/6 = 0.17

Gini = 1 - (0.5² + 0.33² + 0.17²)
     = 1 - (0.25 + 0.11 + 0.03)
     = 1 - 0.39
     = 0.61
```

**Interpretation**: 0.61 impurity (moderately mixed)
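
Using the `gini` sketch from earlier in this chapter, the worked example can be checked numerically (the exact value is 11/18 ≈ 0.611, which rounds to 0.61):

```rust,ignore
// Check the worked example with the `gini` sketch from earlier in this chapter.
// Labels: A = 0, B = 1, C = 2.
let node = [0, 0, 0, 1, 1, 2];
let g = gini(&node);
assert!((g - 0.6111).abs() < 1e-3); // 1 - (1/2² + 1/3² + 1/6²) = 11/18 ≈ 0.611
```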

### Pure vs Impure Nodes

| Node | Distribution | Gini | Interpretation |
|------|-------------|------|----------------|
| [A, A, A, A] | 100% A | 0.0 | Pure (stop splitting) |
| [A, A, B, B] | 50% A, 50% B | 0.5 | Maximum impurity (binary) |
| [A, A, A, B] | 75% A, 25% B | 0.375 | Moderately pure |

**Test Reference**: `src/tree/mod.rs::tests::test_gini_impurity_*`

---

## Choosing Max Depth

### The Depth Trade-off

**Too shallow (max_depth = 1)**:
- Underfitting
- High bias, low variance
- Poor train and test accuracy

**Too deep (max_depth = ∞)**:
- Overfitting
- Low bias, high variance
- Perfect train accuracy, poor test accuracy

**Just right (max_depth = 3-7)**:
- Balanced bias-variance
- Good generalization

### Finding Optimal Depth

Use cross-validation:

```rust,ignore
// Pseudocode — `cross_validate` stands in for a k-fold CV routine
let mut best = (1, f32::NEG_INFINITY);
for depth in 1..=10 {
    let model = DecisionTreeClassifier::new().with_max_depth(depth);
    let cv_score = cross_validate(&model, &x, &y, 5); // k = 5 folds
    if cv_score > best.1 {
        best = (depth, cv_score); // select depth with best CV score
    }
}
```

**Rule of Thumb**:
- Simple problems: max_depth = 3-5
- Complex problems: max_depth = 5-10
- If using ensemble (Random Forest): deeper trees OK (15-30)

---

## Advantages and Limitations

### Advantages ✅

1. **Interpretable**: Can visualize and explain decisions
2. **No feature scaling**: Works on raw features
3. **Handles non-linear**: Learns complex boundaries
4. **Mixed data types**: Numeric and categorical features
5. **Fast prediction**: O(depth) traversal (≈ O(log n) for a balanced tree)

### Limitations ❌

1. **Overfitting**: Single trees overfit easily
2. **Instability**: Small data changes → different tree
3. **Bias toward dominant classes**: In imbalanced data
4. **Greedy algorithm**: May miss global optimum
5. **Axis-aligned splits**: Can't learn diagonal boundaries easily

**Solution to overfitting**: Use ensemble methods (Random Forests, Gradient Boosting)

---

## Decision Trees vs Other Methods

### Comparison Table

| Method | Interpretability | Feature Scaling | Non-linear | Overfitting Risk | Speed |
|--------|------------------|-----------------|------------|------------------|-------|
| **Decision Tree** | High | Not needed | Yes | High (single tree) | Fast |
| **Logistic Regression** | Medium | Required | No (unless polynomial) | Low | Fast |
| **SVM** | Low | Required | Yes (kernels) | Medium | Slow |
| **Random Forest** | Medium | Not needed | Yes | Low | Medium |

### When to Use Decision Trees

**Good for**:
- Interpretability required (medical, legal domains)
- Mixed feature types
- Quick baseline
- Building block for ensembles
- **Regression**: Non-linear relationships without feature engineering
- **Classification**: Multi-class problems with complex boundaries

**Not good for**:
- Need best single-model accuracy (use ensemble instead)
- Linear relationships (logistic/linear regression simpler)
- Large feature space (curse of dimensionality)
- **Regression**: Smooth predictions or extrapolation beyond training range

---

## Practical Considerations

### Feature Importance

Decision trees naturally rank feature importance:
- **Most important**: Features near the root (used early)
- **Less important**: Features deeper in tree or unused

**Interpretation**: Features used for early splits have highest information gain.

### Handling Imbalanced Classes

**Problem**: Tree biased toward majority class

**Solutions**:
1. **Class weights**: Penalize majority class errors more
2. **Sampling**: SMOTE, undersampling majority
3. **Threshold tuning**: Adjust prediction threshold

### Pruning (Post-Processing)

**Idea**: Build full tree, then remove nodes with low information gain

**Benefit**: Reduces overfitting without limiting depth during training

**Status in Aprender**: Not yet implemented (use max_depth instead)

---

## Verification Through Tests

Decision tree tests verify mathematical properties:

**Gini Impurity Tests**:
- Pure node → Gini = 0.0
- 50/50 binary split → Gini = 0.5
- Gini always in [0, 1]

**Tree Building Tests**:
- Pure leaf stops splitting
- Max depth enforced
- Predictions match majority class

**Property Tests** (via integration tests):
- Tree depth ≤ max_depth
- All leaves are pure or at max_depth
- Information gain non-negative

**Test Reference**: `src/tree/mod.rs` (15+ tests)

---

## Real-World Application

### Medical Diagnosis Example

**Problem**: Diagnose disease from symptoms (temperature, blood pressure, age)

**Decision Tree**:
```text
          [Temperature > 38°C]
         /                    \
   [BP > 140]               Healthy
   /        \
Disease A   Disease B
```

**Why Decision Tree?**
- Interpretable (doctors can verify logic)
- No feature scaling (raw measurements)
- Handles mixed units (°C, mmHg, years)

### Credit Scoring Example

**Features**: Income, debt, employment length, credit history

**Decision Tree learns**:
- If income < $30k and debt > $20k → High risk
- If income > $80k → Low risk
- Else, check employment length...

**Advantage**: Transparent lending decisions (regulatory compliance)

---

## Further Reading

### Peer-Reviewed Papers

**Breiman et al. (1984)** - *Classification and Regression Trees*
- **Relevance**: Original CART algorithm (Gini impurity, recursive splitting)
- **Link**: Chapman and Hall/CRC (book, library access)
- **Key Contribution**: Unified framework for classification and regression trees
- **Applied in**: `src/tree/mod.rs` CART implementation

**Quinlan (1986)** - *Induction of Decision Trees*
- **Relevance**: Alternative algorithm using entropy (ID3)
- **Link**: [SpringerLink](https://link.springer.com/article/10.1007/BF00116251)
- **Key Contribution**: Information gain via entropy (alternative to Gini)

### Related Chapters

- [Ensemble Methods Theory](./ensemble-methods.md) - Random Forests (next chapter)
- [Classification Metrics Theory](./classification-metrics.md) - Evaluating trees
- [Cross-Validation Theory](./cross-validation.md) - Finding optimal max_depth

---

## Summary

**What You Learned**:
- ✅ Decision trees: hierarchical if-then rules for classification AND regression
- ✅ **Classification**: Gini impurity (Gini = 1 - Σ p_i²), predict majority class
- ✅ **Regression**: MSE criterion (variance), predict mean value
- ✅ CART algorithm: greedy, top-down, recursive (same for both tasks)
- ✅ Information gain: Maximize reduction in impurity (Gini or MSE)
- ✅ Max depth: Controls overfitting (tune with CV)
- ✅ Advantages: Interpretable, no scaling, non-linear
- ✅ Limitations: Overfitting, instability (use ensembles)

**Verification Guarantee**: Decision tree implementation extensively tested (30+ tests) in `src/tree/mod.rs`. Tests verify Gini calculations, MSE splitting, tree building, and prediction logic for both classification and regression.

**Quick Reference**:

**Classification**:
- **Pure node**: Gini = 0 (stop splitting)
- **Max impurity**: Gini = 0.5 (binary 50/50)
- **Best split**: Maximize information gain
- **Leaf prediction**: Majority class

**Regression**:
- **Pure node**: MSE = 0 (constant target, stop splitting)
- **High impurity**: High variance in target values
- **Best split**: Maximize variance reduction
- **Leaf prediction**: Mean of target values

**Both Tasks**:
- **Prevent overfit**: Set max_depth (3-7 typical)
- **Additional regularization**: min_samples_split, min_samples_leaf
- **Evaluation**: R² for regression, accuracy for classification

**Key Equations**:
```text
Classification:
  Gini(S) = 1 - Σ p_i²
  InfoGain = Gini(parent) - Weighted_Avg(Gini(children))

Regression:
  MSE(S) = (1/n) Σ (y_i - ȳ)²
  VarReduction = MSE(parent) - Weighted_Avg(MSE(children))

Both:
  Split: feature ≤ threshold → left, else → right
```

---

**Next Chapter**: [Ensemble Methods Theory](./ensemble-methods.md)

**Previous Chapter**: [Classification Metrics Theory](./classification-metrics.md)