Skip to main content

anofox_ml_trees/
split.rs

1use std::collections::HashMap;
2
3use anofox_ml_core::Float;
4use ndarray::{Array1, Array2};
5
6/// Convert a Float value to a u64 key suitable for HashMap use.
7/// Uses f64 bit representation for exact equality matching.
8#[inline]
9fn float_key<F: Float>(v: F) -> u64 {
10    v.to_f64().unwrap().to_bits()
11}
12
13/// Criterion for evaluating splits.
14#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
15pub enum SplitCriterion {
16    /// Gini impurity (for classification).
17    Gini,
18    /// Entropy / information gain (for classification).
19    Entropy,
20    /// Mean squared error (for regression).
21    Mse,
22}
23
24/// Strategy for selecting split thresholds.
25#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
26pub enum SplitStrategy {
27    /// Find the best split threshold (standard CART).
28    Best,
29    /// Pick a random threshold between min and max of each feature (ExtraTrees).
30    Random,
31}
32
33/// Controls the number of features considered at each split.
34#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
35pub enum MaxFeatures {
36    /// Use `floor(sqrt(n_features))` features.
37    Sqrt,
38    /// Use `floor(log2(n_features))` features (at least 1).
39    Log2,
40    /// Use exactly `k` features.
41    Fixed(usize),
42    /// Use `floor(fraction * n_features)` features (at least 1).
43    Fraction(f64),
44}
45
46impl MaxFeatures {
47    /// Resolve to a concrete number of features given `n_features` total.
48    pub fn resolve(&self, n_features: usize) -> usize {
49        match self {
50            MaxFeatures::Sqrt => (n_features as f64).sqrt().floor().max(1.0) as usize,
51            MaxFeatures::Log2 => (n_features as f64).log2().floor().max(1.0) as usize,
52            MaxFeatures::Fixed(k) => (*k).min(n_features).max(1),
53            MaxFeatures::Fraction(f) => (*f * n_features as f64).floor().max(1.0) as usize,
54        }
55    }
56}
57
/// Select a random subset of feature indices using xorshift64.
///
/// Returns `k` distinct indices from `0..n_features`, in ascending order,
/// chosen by a partial Fisher-Yates shuffle driven by a deterministic
/// xorshift64 generator seeded from `seed`. If `k >= n_features` all indices
/// `0..n_features` are returned and the generator is not used.
///
/// xorshift64 never leaves the all-zero state, so if the mixed seed happens to
/// be zero a fixed non-zero state is substituted; otherwise that one seed
/// would always select the first `k` features.
pub fn select_feature_subset(n_features: usize, k: usize, seed: u64) -> Vec<usize> {
    const SEED_MIX: u64 = 0x9E3779B97F4A7C15;

    if k >= n_features {
        return (0..n_features).collect();
    }

    let mut indices: Vec<usize> = (0..n_features).collect();
    let mut state = seed.wrapping_add(SEED_MIX);
    if state == 0 {
        state = SEED_MIX;
    }

    for i in 0..k {
        // xorshift64
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        // Reduce in u64 before narrowing so the draw does not depend on the
        // width of `usize`.
        let remaining = (n_features - i) as u64;
        let j = i + (state % remaining) as usize;
        indices.swap(i, j);
    }

    indices.truncate(k);
    indices.sort_unstable();
    indices
}
81
82/// Result of finding the best split at a node.
83#[derive(Debug, Clone)]
84pub struct BestSplit<F: Float> {
85    pub feature_index: usize,
86    pub threshold: F,
87    pub left_indices: Vec<usize>,
88    pub right_indices: Vec<usize>,
89    pub improvement: F,
90}
91
92/// Find the best split over all features and thresholds.
93///
94/// Uses an incremental class-count / running-sum approach so that each
95/// candidate threshold is evaluated in O(k) (classification, k = n_classes)
96/// or O(1) (regression) instead of O(n).
97pub fn find_best_split<F: Float>(
98    x: &Array2<F>,
99    y: &Array1<F>,
100    indices: &[usize],
101    criterion: SplitCriterion,
102    min_samples_leaf: usize,
103) -> Option<BestSplit<F>> {
104    let all_features: Vec<usize> = (0..x.ncols()).collect();
105    find_best_split_with_features(x, y, indices, criterion, min_samples_leaf, &all_features)
106}
107
108/// Like [`find_best_split`] but only considers the given feature indices.
109pub fn find_best_split_with_features<F: Float>(
110    x: &Array2<F>,
111    y: &Array1<F>,
112    indices: &[usize],
113    criterion: SplitCriterion,
114    min_samples_leaf: usize,
115    feature_indices: &[usize],
116) -> Option<BestSplit<F>> {
117    let n = indices.len();
118    if n < 2 * min_samples_leaf {
119        return None;
120    }
121
122    let parent_impurity = compute_impurity(y, indices, criterion);
123
124    match criterion {
125        SplitCriterion::Gini | SplitCriterion::Entropy => find_best_split_classification(
126            x,
127            y,
128            indices,
129            criterion,
130            min_samples_leaf,
131            feature_indices,
132            n,
133            parent_impurity,
134        ),
135        SplitCriterion::Mse => find_best_split_regression(
136            x,
137            y,
138            indices,
139            min_samples_leaf,
140            feature_indices,
141            n,
142            parent_impurity,
143        ),
144    }
145}
146
147/// Sort indices by feature value, filling the provided buffer.
148///
149/// Clears `sorted_pairs` and fills it with `(feature_value, original_index)`
150/// pairs sorted by feature value.
151#[inline]
152fn sort_feature_pairs<F: Float>(
153    x: &Array2<F>,
154    indices: &[usize],
155    feature: usize,
156    sorted_pairs: &mut Vec<(F, usize)>,
157) {
158    sorted_pairs.clear();
159    sorted_pairs.extend(indices.iter().map(|&i| (x[[i, feature]], i)));
160    sorted_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
161}
162
163/// Candidate best split (without index vectors) used during the search.
164struct CandidateSplit<F: Float> {
165    feature: usize,
166    threshold: F,
167    improvement: F,
168}
169
170/// If `improvement` beats `best_improvement`, record the candidate.
171///
172/// Index vectors are NOT allocated here — they are reconstructed once
173/// after the search completes, avoiding repeated allocations.
174#[inline]
175fn try_update_best_split<F: Float>(
176    improvement: F,
177    best_improvement: &mut F,
178    best: &mut Option<CandidateSplit<F>>,
179    feature: usize,
180    threshold: F,
181) {
182    if improvement > *best_improvement {
183        *best_improvement = improvement;
184        *best = Some(CandidateSplit {
185            feature,
186            threshold,
187            improvement,
188        });
189    }
190}
191
192/// Accumulator that tracks split impurity incrementally as samples move
193/// from the "right" partition to the "left" partition.
194trait SplitAccumulator<F: Float> {
195    /// Reset the accumulator with all samples in "right".
196    fn reset(&mut self, y: &Array1<F>, indices: &[usize]);
197    /// Move sample `idx` from right to left.
198    fn move_to_left(&mut self, y: &Array1<F>, idx: usize);
199    /// Compute weighted impurity: (n_left/n)*left + (n_right/n)*right.
200    fn weighted_impurity(&self, n: usize) -> F;
201    /// Number of samples currently in the left partition.
202    fn n_left(&self) -> usize;
203    /// Number of samples currently in the right partition.
204    fn n_right(&self) -> usize;
205}
206
207/// Classification accumulator using incremental class counts.
208struct ClassificationAccumulator<F: Float> {
209    left_counts: Vec<usize>,
210    right_counts: Vec<usize>,
211    n_left: usize,
212    n_right: usize,
213    criterion: SplitCriterion,
214    class_map: HashMap<u64, usize>,
215    _marker: std::marker::PhantomData<F>,
216}
217
218impl<F: Float> ClassificationAccumulator<F> {
219    /// Create a new accumulator, building the class_map once.
220    fn new(y: &Array1<F>, indices: &[usize]) -> Self {
221        let class_map = build_class_map(y, indices);
222        let n_classes = class_map.len();
223
224        let mut total_counts = vec![0usize; n_classes];
225        for &i in indices {
226            let cls = class_map[&float_key(y[i])];
227            total_counts[cls] += 1;
228        }
229
230        Self {
231            left_counts: vec![0usize; n_classes],
232            right_counts: total_counts,
233            n_left: 0,
234            n_right: indices.len(),
235            criterion: SplitCriterion::Gini,
236            class_map,
237            _marker: std::marker::PhantomData,
238        }
239    }
240}
241
242impl<F: Float> SplitAccumulator<F> for ClassificationAccumulator<F> {
243    fn reset(&mut self, y: &Array1<F>, indices: &[usize]) {
244        // Reuse existing Vecs — just zero and refill counts.
245        self.left_counts.fill(0);
246        self.right_counts.fill(0);
247        for &i in indices {
248            let cls = self.class_map[&float_key(y[i])];
249            self.right_counts[cls] += 1;
250        }
251        self.n_left = 0;
252        self.n_right = indices.len();
253    }
254
255    fn move_to_left(&mut self, y: &Array1<F>, idx: usize) {
256        let cls = self.class_map[&float_key(y[idx])];
257        self.left_counts[cls] += 1;
258        self.right_counts[cls] -= 1;
259        self.n_left += 1;
260        self.n_right -= 1;
261    }
262
263    fn weighted_impurity(&self, n: usize) -> F {
264        let n_f = F::from_usize(n).unwrap();
265        let nl = F::from_usize(self.n_left).unwrap();
266        let nr = F::from_usize(self.n_right).unwrap();
267        let left_imp = impurity_from_counts(&self.left_counts, self.n_left, self.criterion);
268        let right_imp = impurity_from_counts(&self.right_counts, self.n_right, self.criterion);
269        (nl / n_f) * left_imp + (nr / n_f) * right_imp
270    }
271
272    fn n_left(&self) -> usize {
273        self.n_left
274    }
275
276    fn n_right(&self) -> usize {
277        self.n_right
278    }
279}
280
281impl<F: Float> ClassificationAccumulator<F> {
282    fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
283        self.criterion = criterion;
284        self
285    }
286}
287
288/// Regression accumulator using running sum and sum-of-squares.
289struct RegressionAccumulator<F: Float> {
290    left_sum: F,
291    left_sum_sq: F,
292    right_sum: F,
293    right_sum_sq: F,
294    n_left: usize,
295    n_right: usize,
296}
297
298impl<F: Float> RegressionAccumulator<F> {
299    fn new(y: &Array1<F>, indices: &[usize]) -> Self {
300        let mut total_sum = F::zero();
301        let mut total_sum_sq = F::zero();
302        for &i in indices {
303            let v = y[i];
304            total_sum += v;
305            total_sum_sq += v * v;
306        }
307
308        Self {
309            left_sum: F::zero(),
310            left_sum_sq: F::zero(),
311            right_sum: total_sum,
312            right_sum_sq: total_sum_sq,
313            n_left: 0,
314            n_right: indices.len(),
315        }
316    }
317}
318
319impl<F: Float> SplitAccumulator<F> for RegressionAccumulator<F> {
320    fn reset(&mut self, y: &Array1<F>, indices: &[usize]) {
321        self.left_sum = F::zero();
322        self.left_sum_sq = F::zero();
323        self.right_sum = F::zero();
324        self.right_sum_sq = F::zero();
325        for &i in indices {
326            let v = y[i];
327            self.right_sum += v;
328            self.right_sum_sq += v * v;
329        }
330        self.n_left = 0;
331        self.n_right = indices.len();
332    }
333
334    fn move_to_left(&mut self, y: &Array1<F>, idx: usize) {
335        let v = y[idx];
336        self.left_sum += v;
337        self.left_sum_sq += v * v;
338        self.right_sum -= v;
339        self.right_sum_sq -= v * v;
340        self.n_left += 1;
341        self.n_right -= 1;
342    }
343
344    fn weighted_impurity(&self, n: usize) -> F {
345        let n_f = F::from_usize(n).unwrap();
346        let nl = F::from_usize(self.n_left).unwrap();
347        let nr = F::from_usize(self.n_right).unwrap();
348        // MSE = sum_sq/n - (sum/n)^2
349        let left_mse = self.left_sum_sq / nl - (self.left_sum / nl) * (self.left_sum / nl);
350        let right_mse = self.right_sum_sq / nr - (self.right_sum / nr) * (self.right_sum / nr);
351        (nl / n_f) * left_mse + (nr / n_f) * right_mse
352    }
353
354    fn n_left(&self) -> usize {
355        self.n_left
356    }
357
358    fn n_right(&self) -> usize {
359        self.n_right
360    }
361}
362
363/// Evaluate a candidate split point between `cur_val` and `next_val`.
364///
365/// Returns `Some((threshold, improvement))` when the split passes the
366/// distinct-value and min-samples-leaf gates, or `None` otherwise.
367#[inline]
368fn evaluate_candidate_split<F: Float, A: SplitAccumulator<F>>(
369    acc: &A,
370    n: usize,
371    min_samples_leaf: usize,
372    cur_val: F,
373    next_val: F,
374    parent_impurity: F,
375) -> Option<(F, F)> {
376    // Only consider a split between distinct values.
377    if (next_val - cur_val).abs() < F::from_f64(1e-15).unwrap() {
378        return None;
379    }
380
381    // Check min_samples_leaf constraint.
382    if acc.n_left() < min_samples_leaf || acc.n_right() < min_samples_leaf {
383        return None;
384    }
385
386    let threshold = (cur_val + next_val) / (F::one() + F::one());
387    let improvement = parent_impurity - acc.weighted_impurity(n);
388    Some((threshold, improvement))
389}
390
391/// Unified split-finding loop parameterised by accumulator type.
392///
393/// Scans each feature's sorted values, moving samples from right to left
394/// and evaluating candidate splits via the accumulator's impurity method.
395/// The accumulator is created once and reset per feature to avoid repeated
396/// allocations (e.g. HashMap/Vec for class counts).
397#[allow(clippy::too_many_arguments)]
398fn find_best_split_inner<F, A>(
399    x: &Array2<F>,
400    y: &Array1<F>,
401    indices: &[usize],
402    min_samples_leaf: usize,
403    feature_indices: &[usize],
404    n: usize,
405    parent_impurity: F,
406    mut acc: A,
407) -> Option<BestSplit<F>>
408where
409    F: Float,
410    A: SplitAccumulator<F>,
411{
412    let mut best: Option<CandidateSplit<F>> = None;
413    let mut best_improvement = F::neg_infinity();
414
415    let mut sorted_pairs: Vec<(F, usize)> = Vec::with_capacity(n);
416
417    for &feature in feature_indices {
418        sort_feature_pairs(x, indices, feature, &mut sorted_pairs);
419
420        acc.reset(y, indices);
421
422        for pos in 0..n - 1 {
423            let (cur_val, cur_idx) = sorted_pairs[pos];
424            acc.move_to_left(y, cur_idx);
425
426            let next_val = sorted_pairs[pos + 1].0;
427            if let Some((threshold, improvement)) = evaluate_candidate_split(
428                &acc,
429                n,
430                min_samples_leaf,
431                cur_val,
432                next_val,
433                parent_impurity,
434            ) {
435                try_update_best_split(
436                    improvement,
437                    &mut best_improvement,
438                    &mut best,
439                    feature,
440                    threshold,
441                );
442            }
443        }
444    }
445
446    // Reconstruct index vectors only for the winning split via O(n) partition.
447    best.map(|candidate| {
448        let mut left_indices = Vec::with_capacity(n);
449        let mut right_indices = Vec::with_capacity(n);
450        for &i in indices {
451            if x[[i, candidate.feature]] <= candidate.threshold {
452                left_indices.push(i);
453            } else {
454                right_indices.push(i);
455            }
456        }
457        BestSplit {
458            feature_index: candidate.feature,
459            threshold: candidate.threshold,
460            left_indices,
461            right_indices,
462            improvement: candidate.improvement,
463        }
464    })
465}
466
467/// Classification split finding with incremental class counts.
468#[allow(clippy::too_many_arguments)]
469fn find_best_split_classification<F: Float>(
470    x: &Array2<F>,
471    y: &Array1<F>,
472    indices: &[usize],
473    criterion: SplitCriterion,
474    min_samples_leaf: usize,
475    feature_indices: &[usize],
476    n: usize,
477    parent_impurity: F,
478) -> Option<BestSplit<F>> {
479    let acc = ClassificationAccumulator::<F>::new(y, indices).with_criterion(criterion);
480    find_best_split_inner(
481        x,
482        y,
483        indices,
484        min_samples_leaf,
485        feature_indices,
486        n,
487        parent_impurity,
488        acc,
489    )
490}
491
492/// Regression split finding with running sum and sum-of-squares.
493fn find_best_split_regression<F: Float>(
494    x: &Array2<F>,
495    y: &Array1<F>,
496    indices: &[usize],
497    min_samples_leaf: usize,
498    feature_indices: &[usize],
499    n: usize,
500    parent_impurity: F,
501) -> Option<BestSplit<F>> {
502    let acc = RegressionAccumulator::<F>::new(y, indices);
503    find_best_split_inner(
504        x,
505        y,
506        indices,
507        min_samples_leaf,
508        feature_indices,
509        n,
510        parent_impurity,
511        acc,
512    )
513}
514
515/// Build a mapping from class label (as f64 bits) to contiguous index.
516fn build_class_map<F: Float>(y: &Array1<F>, indices: &[usize]) -> HashMap<u64, usize> {
517    let mut map = HashMap::new();
518    let mut next_idx = 0;
519    for &i in indices {
520        let bits = float_key(y[i]);
521        if let std::collections::hash_map::Entry::Vacant(e) = map.entry(bits) {
522            e.insert(next_idx);
523            next_idx += 1;
524        }
525    }
526    map
527}
528
529/// Compute Gini or Entropy impurity from class counts (O(k) where k = n_classes).
530#[inline]
531fn impurity_from_counts<F: Float>(counts: &[usize], total: usize, criterion: SplitCriterion) -> F {
532    let n = F::from_usize(total).unwrap();
533    match criterion {
534        SplitCriterion::Gini => {
535            let sum_sq: F = counts
536                .iter()
537                .filter(|&&c| c > 0)
538                .map(|&c| {
539                    let p = F::from_usize(c).unwrap() / n;
540                    p * p
541                })
542                .fold(F::zero(), |a, b| a + b);
543            F::one() - sum_sq
544        }
545        SplitCriterion::Entropy => {
546            let sum: F = counts
547                .iter()
548                .filter(|&&c| c > 0)
549                .map(|&c| {
550                    let p = F::from_usize(c).unwrap() / n;
551                    p * p.ln()
552                })
553                .fold(F::zero(), |a, b| a + b);
554            -sum
555        }
556        SplitCriterion::Mse => unreachable!("MSE does not use class counts"),
557    }
558}
559
560/// Compute impurity for a subset of samples.
561#[inline]
562pub fn compute_impurity<F: Float>(
563    y: &Array1<F>,
564    indices: &[usize],
565    criterion: SplitCriterion,
566) -> F {
567    match criterion {
568        SplitCriterion::Gini => gini(y, indices),
569        SplitCriterion::Entropy => entropy(y, indices),
570        SplitCriterion::Mse => mse_impurity(y, indices),
571    }
572}
573
574#[inline]
575fn gini<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
576    let n = F::from_usize(indices.len()).unwrap();
577    let class_counts = count_classes(y, indices);
578
579    let sum_sq: F = class_counts
580        .iter()
581        .map(|&(_, count)| {
582            let p = F::from_usize(count).unwrap() / n;
583            p * p
584        })
585        .fold(F::zero(), |a, b| a + b);
586
587    F::one() - sum_sq
588}
589
590#[inline]
591fn entropy<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
592    let n = F::from_usize(indices.len()).unwrap();
593    let class_counts = count_classes(y, indices);
594
595    let sum: F = class_counts
596        .iter()
597        .map(|&(_, count)| {
598            let p = F::from_usize(count).unwrap() / n;
599            if p > F::zero() {
600                p * p.ln()
601            } else {
602                F::zero()
603            }
604        })
605        .fold(F::zero(), |a, b| a + b);
606
607    -sum
608}
609
610#[inline]
611fn mse_impurity<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
612    let n = F::from_usize(indices.len()).unwrap();
613    let mean: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b) / n;
614
615    indices
616        .iter()
617        .map(|&i| (y[i] - mean) * (y[i] - mean))
618        .fold(F::zero(), |a, b| a + b)
619        / n
620}
621
622/// Count occurrences of each class in a subset.
623pub fn count_classes<F: Float>(y: &Array1<F>, indices: &[usize]) -> Vec<(F, usize)> {
624    let mut map: HashMap<u64, (F, usize)> = HashMap::new();
625    for &i in indices {
626        let val = y[i];
627        let bits = float_key(val);
628        map.entry(bits).and_modify(|e| e.1 += 1).or_insert((val, 1));
629    }
630    let mut counts: Vec<(F, usize)> = map.into_values().collect();
631    counts.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
632    counts
633}
634
635/// Compute the majority class (for classification) or mean (for regression).
636#[inline]
637pub fn leaf_value<F: Float>(y: &Array1<F>, indices: &[usize], criterion: SplitCriterion) -> F {
638    match criterion {
639        SplitCriterion::Mse => {
640            let n = F::from_usize(indices.len()).unwrap();
641            indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b) / n
642        }
643        SplitCriterion::Gini | SplitCriterion::Entropy => {
644            let counts = count_classes(y, indices);
645            counts
646                .into_iter()
647                .max_by_key(|&(_, count)| count)
648                .unwrap()
649                .0
650        }
651    }
652}
653
654/// Find a split using random thresholds (ExtraTrees strategy).
655///
656/// For each feature, picks a random threshold between the min and max value,
657/// then evaluates the resulting split. Returns the best across all features.
658pub fn find_random_split<F: Float>(
659    x: &Array2<F>,
660    y: &Array1<F>,
661    indices: &[usize],
662    criterion: SplitCriterion,
663    min_samples_leaf: usize,
664    seed: u64,
665) -> Option<BestSplit<F>> {
666    let n_features = x.ncols();
667    let n = indices.len();
668    if n < 2 * min_samples_leaf {
669        return None;
670    }
671
672    let parent_impurity = compute_impurity(y, indices, criterion);
673
674    let mut best: Option<CandidateSplit<F>> = None;
675    let mut best_improvement = F::neg_infinity();
676
677    // Simple deterministic RNG (xorshift) to avoid depending on rand in this crate
678    let mut rng_state = seed.wrapping_add(0x9E3779B97F4A7C15);
679
680    for feature in 0..n_features {
681        // Find min and max of this feature across the indices
682        let mut min_val = x[[indices[0], feature]];
683        let mut max_val = min_val;
684        for &i in &indices[1..] {
685            let v = x[[i, feature]];
686            if v < min_val {
687                min_val = v;
688            }
689            if v > max_val {
690                max_val = v;
691            }
692        }
693
694        if (max_val - min_val).abs() < F::from_f64(1e-15).unwrap() {
695            continue;
696        }
697
698        // Generate random threshold between min and max using xorshift64
699        rng_state ^= rng_state << 13;
700        rng_state ^= rng_state >> 7;
701        rng_state ^= rng_state << 17;
702        let t = F::from_f64((rng_state as f64) / (u64::MAX as f64)).unwrap();
703        let threshold = min_val + t * (max_val - min_val);
704
705        // Partition and compute impurity
706        let mut n_left = 0usize;
707        let mut n_right = 0usize;
708        for &i in indices {
709            if x[[i, feature]] <= threshold {
710                n_left += 1;
711            } else {
712                n_right += 1;
713            }
714        }
715
716        if n_left < min_samples_leaf || n_right < min_samples_leaf {
717            continue;
718        }
719
720        // Compute weighted impurity for this split
721        let left_indices: Vec<usize> = indices
722            .iter()
723            .copied()
724            .filter(|&i| x[[i, feature]] <= threshold)
725            .collect();
726        let right_indices: Vec<usize> = indices
727            .iter()
728            .copied()
729            .filter(|&i| x[[i, feature]] > threshold)
730            .collect();
731
732        let left_imp = compute_impurity(y, &left_indices, criterion);
733        let right_imp = compute_impurity(y, &right_indices, criterion);
734
735        let n_f = F::from_usize(n).unwrap();
736        let nl_f = F::from_usize(n_left).unwrap();
737        let nr_f = F::from_usize(n_right).unwrap();
738        let weighted = (nl_f / n_f) * left_imp + (nr_f / n_f) * right_imp;
739        let improvement = parent_impurity - weighted;
740
741        try_update_best_split(
742            improvement,
743            &mut best_improvement,
744            &mut best,
745            feature,
746            threshold,
747        );
748    }
749
750    // Reconstruct index vectors for winning split
751    best.map(|candidate| {
752        let mut left_indices = Vec::with_capacity(n);
753        let mut right_indices = Vec::with_capacity(n);
754        for &i in indices {
755            if x[[i, candidate.feature]] <= candidate.threshold {
756                left_indices.push(i);
757            } else {
758                right_indices.push(i);
759            }
760        }
761        BestSplit {
762            feature_index: candidate.feature,
763            threshold: candidate.threshold,
764            left_indices,
765            right_indices,
766            improvement: candidate.improvement,
767        }
768    })
769}
770
771/// Class weighting strategy for classifiers.
772#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
773pub enum ClassWeight {
774    /// Weights inversely proportional to class frequencies: `n_samples / (n_classes * count_k)`.
775    Balanced,
776    /// Manual per-class weights. Keys are class labels (as f64), values are weights.
777    Manual(Vec<(f64, f64)>),
778}
779
780/// Compute per-sample weights from a class weight strategy.
781pub fn compute_sample_weights_from_class_weight<F: Float>(
782    y: &Array1<F>,
783    class_weight: &ClassWeight,
784) -> Array1<F> {
785    let n_samples = y.len();
786    match class_weight {
787        ClassWeight::Balanced => {
788            let counts = count_classes(y, &(0..n_samples).collect::<Vec<_>>());
789            let n_classes = counts.len();
790            let n_f = F::from_usize(n_samples).unwrap();
791            let nc_f = F::from_usize(n_classes).unwrap();
792            let mut weights = Array1::<F>::ones(n_samples);
793            for i in 0..n_samples {
794                for &(class_val, count) in &counts {
795                    if (y[i] - class_val).abs() < F::from_f64(1e-9).unwrap() {
796                        weights[i] = n_f / (nc_f * F::from_usize(count).unwrap());
797                        break;
798                    }
799                }
800            }
801            weights
802        }
803        ClassWeight::Manual(mapping) => {
804            let mut weights = Array1::<F>::ones(n_samples);
805            for i in 0..n_samples {
806                let yi = y[i].to_f64().unwrap();
807                for &(class_val, w) in mapping {
808                    if (yi - class_val).abs() < 1e-9 {
809                        weights[i] = F::from_f64(w).unwrap();
810                        break;
811                    }
812                }
813            }
814            weights
815        }
816    }
817}
818
819/// Compute weighted impurity for a subset of samples.
820pub fn compute_weighted_impurity<F: Float>(
821    y: &Array1<F>,
822    indices: &[usize],
823    weights: &Array1<F>,
824    criterion: SplitCriterion,
825) -> F {
826    let total_weight: F = indices
827        .iter()
828        .map(|&i| weights[i])
829        .fold(F::zero(), |a, b| a + b);
830    if total_weight <= F::zero() {
831        return F::zero();
832    }
833
834    match criterion {
835        SplitCriterion::Gini => {
836            // Weighted Gini: 1 - sum(p_k^2) where p_k = weighted_count_k / total_weight
837            let mut class_weights: HashMap<u64, F> = HashMap::new();
838            for &i in indices {
839                let key = float_key(y[i]);
840                *class_weights.entry(key).or_insert(F::zero()) += weights[i];
841            }
842            let sum_sq: F = class_weights
843                .values()
844                .map(|&w| {
845                    let p = w / total_weight;
846                    p * p
847                })
848                .fold(F::zero(), |a, b| a + b);
849            F::one() - sum_sq
850        }
851        SplitCriterion::Entropy => {
852            let mut class_weights: HashMap<u64, F> = HashMap::new();
853            for &i in indices {
854                let key = float_key(y[i]);
855                *class_weights.entry(key).or_insert(F::zero()) += weights[i];
856            }
857            let sum: F = class_weights
858                .values()
859                .filter(|&&w| w > F::zero())
860                .map(|&w| {
861                    let p = w / total_weight;
862                    p * p.ln()
863                })
864                .fold(F::zero(), |a, b| a + b);
865            -sum
866        }
867        SplitCriterion::Mse => {
868            // Weighted MSE: sum(w_i * (y_i - weighted_mean)^2) / total_weight
869            let w_mean: F = indices
870                .iter()
871                .map(|&i| weights[i] * y[i])
872                .fold(F::zero(), |a, b| a + b)
873                / total_weight;
874            indices
875                .iter()
876                .map(|&i| weights[i] * (y[i] - w_mean) * (y[i] - w_mean))
877                .fold(F::zero(), |a, b| a + b)
878                / total_weight
879        }
880    }
881}
882
883/// Compute weighted leaf value.
884pub fn weighted_leaf_value<F: Float>(
885    y: &Array1<F>,
886    indices: &[usize],
887    weights: &Array1<F>,
888    criterion: SplitCriterion,
889) -> F {
890    match criterion {
891        SplitCriterion::Mse => {
892            let total_weight: F = indices
893                .iter()
894                .map(|&i| weights[i])
895                .fold(F::zero(), |a, b| a + b);
896            if total_weight <= F::zero() {
897                return F::zero();
898            }
899            indices
900                .iter()
901                .map(|&i| weights[i] * y[i])
902                .fold(F::zero(), |a, b| a + b)
903                / total_weight
904        }
905        SplitCriterion::Gini | SplitCriterion::Entropy => {
906            // Weighted majority class
907            let mut class_weights: HashMap<u64, (F, F)> = HashMap::new();
908            for &i in indices {
909                let key = float_key(y[i]);
910                class_weights
911                    .entry(key)
912                    .and_modify(|e| e.1 += weights[i])
913                    .or_insert((y[i], weights[i]));
914            }
915            class_weights
916                .into_values()
917                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
918                .unwrap()
919                .0
920        }
921    }
922}
923
924/// Count weighted occurrences of each class.
925pub fn weighted_count_classes<F: Float>(
926    y: &Array1<F>,
927    indices: &[usize],
928    weights: &Array1<F>,
929) -> Vec<(F, F)> {
930    let mut map: HashMap<u64, (F, F)> = HashMap::new();
931    for &i in indices {
932        let val = y[i];
933        let bits = float_key(val);
934        map.entry(bits)
935            .and_modify(|e| e.1 += weights[i])
936            .or_insert((val, weights[i]));
937    }
938    let mut counts: Vec<(F, F)> = map.into_values().collect();
939    counts.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
940    counts
941}
942
943/// Find the best split considering sample weights.
944///
945/// Like [`find_best_split`] but weights each sample's contribution to impurity.
946pub fn find_best_split_weighted<F: Float>(
947    x: &Array2<F>,
948    y: &Array1<F>,
949    indices: &[usize],
950    weights: &Array1<F>,
951    criterion: SplitCriterion,
952    min_samples_leaf: usize,
953    feature_indices: &[usize],
954) -> Option<BestSplit<F>> {
955    let n = indices.len();
956    if n < 2 * min_samples_leaf {
957        return None;
958    }
959
960    let parent_impurity = compute_weighted_impurity(y, indices, weights, criterion);
961
962    let mut best: Option<CandidateSplit<F>> = None;
963    let mut best_improvement = F::neg_infinity();
964    let mut sorted_pairs: Vec<(F, usize)> = Vec::with_capacity(n);
965
966    let total_weight: F = indices
967        .iter()
968        .map(|&i| weights[i])
969        .fold(F::zero(), |a, b| a + b);
970
971    for &feature in feature_indices {
972        sort_feature_pairs(x, indices, feature, &mut sorted_pairs);
973
974        // Track left/right weight sums and weighted impurity incrementally
975        let mut left_weight = F::zero();
976        let mut right_weight = total_weight;
977        let mut left_class_weights: HashMap<u64, F> = HashMap::new();
978        let mut right_class_weights: HashMap<u64, F> = HashMap::new();
979
980        // Initialize right with all samples
981        for &i in indices {
982            let key = float_key(y[i]);
983            *right_class_weights.entry(key).or_insert(F::zero()) += weights[i];
984        }
985
986        for pos in 0..n - 1 {
987            let (cur_val, cur_idx) = sorted_pairs[pos];
988            let w = weights[cur_idx];
989            let key = float_key(y[cur_idx]);
990
991            // Move from right to left
992            left_weight += w;
993            right_weight -= w;
994            *left_class_weights.entry(key).or_insert(F::zero()) += w;
995            *right_class_weights.entry(key).or_insert(F::zero()) -= w;
996
997            let next_val = sorted_pairs[pos + 1].0;
998            if (next_val - cur_val).abs() < F::from_f64(1e-15).unwrap() {
999                continue;
1000            }
1001
1002            // Check min_samples_leaf (count, not weight)
1003            let n_left = pos + 1;
1004            let n_right = n - n_left;
1005            if n_left < min_samples_leaf || n_right < min_samples_leaf {
1006                continue;
1007            }
1008
1009            // Compute weighted impurity
1010            let left_imp = match criterion {
1011                SplitCriterion::Gini => {
1012                    let sum_sq: F = left_class_weights
1013                        .values()
1014                        .filter(|&&w| w > F::zero())
1015                        .map(|&w| {
1016                            let p = w / left_weight;
1017                            p * p
1018                        })
1019                        .fold(F::zero(), |a, b| a + b);
1020                    F::one() - sum_sq
1021                }
1022                SplitCriterion::Entropy => {
1023                    let sum: F = left_class_weights
1024                        .values()
1025                        .filter(|&&w| w > F::zero())
1026                        .map(|&w| {
1027                            let p = w / left_weight;
1028                            p * p.ln()
1029                        })
1030                        .fold(F::zero(), |a, b| a + b);
1031                    -sum
1032                }
1033                SplitCriterion::Mse => {
1034                    // For MSE, we need running sums — use a simpler approach
1035                    let left_indices: Vec<usize> =
1036                        sorted_pairs[..=pos].iter().map(|&(_, i)| i).collect();
1037                    compute_weighted_impurity(y, &left_indices, weights, criterion)
1038                }
1039            };
1040
1041            let right_imp = match criterion {
1042                SplitCriterion::Gini => {
1043                    let sum_sq: F = right_class_weights
1044                        .values()
1045                        .filter(|&&w| w > F::zero())
1046                        .map(|&w| {
1047                            let p = w / right_weight;
1048                            p * p
1049                        })
1050                        .fold(F::zero(), |a, b| a + b);
1051                    F::one() - sum_sq
1052                }
1053                SplitCriterion::Entropy => {
1054                    let sum: F = right_class_weights
1055                        .values()
1056                        .filter(|&&w| w > F::zero())
1057                        .map(|&w| {
1058                            let p = w / right_weight;
1059                            p * p.ln()
1060                        })
1061                        .fold(F::zero(), |a, b| a + b);
1062                    -sum
1063                }
1064                SplitCriterion::Mse => {
1065                    let right_indices: Vec<usize> =
1066                        sorted_pairs[pos + 1..].iter().map(|&(_, i)| i).collect();
1067                    compute_weighted_impurity(y, &right_indices, weights, criterion)
1068                }
1069            };
1070
1071            let weighted_imp =
1072                (left_weight / total_weight) * left_imp + (right_weight / total_weight) * right_imp;
1073            let improvement = parent_impurity - weighted_imp;
1074            let threshold = (cur_val + next_val) / (F::one() + F::one());
1075
1076            try_update_best_split(
1077                improvement,
1078                &mut best_improvement,
1079                &mut best,
1080                feature,
1081                threshold,
1082            );
1083        }
1084    }
1085
1086    best.map(|candidate| {
1087        let mut left_indices = Vec::with_capacity(n);
1088        let mut right_indices = Vec::with_capacity(n);
1089        for &i in indices {
1090            if x[[i, candidate.feature]] <= candidate.threshold {
1091                left_indices.push(i);
1092            } else {
1093                right_indices.push(i);
1094            }
1095        }
1096        BestSplit {
1097            feature_index: candidate.feature,
1098            threshold: candidate.threshold,
1099            left_indices,
1100            right_indices,
1101            improvement: candidate.improvement,
1102        }
1103    })
1104}
1105
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// A node holding a single class has zero Gini impurity.
    #[test]
    fn test_gini_pure() {
        let labels = array![1.0, 1.0, 1.0];
        let rows: Vec<usize> = vec![0, 1, 2];
        let impurity = gini(&labels, &rows);
        assert_abs_diff_eq!(impurity, 0.0, epsilon = 1e-10);
    }

    /// Two equally represented classes give a Gini impurity of 0.5.
    #[test]
    fn test_gini_balanced() {
        let labels = array![0.0, 1.0];
        let rows: Vec<usize> = vec![0, 1];
        let impurity = gini(&labels, &rows);
        assert_abs_diff_eq!(impurity, 0.5, epsilon = 1e-10);
    }

    /// Constant targets have zero MSE impurity.
    #[test]
    fn test_mse_pure() {
        let targets = array![5.0, 5.0, 5.0];
        let rows: Vec<usize> = vec![0, 1, 2];
        let impurity = mse_impurity(&targets, &rows);
        assert_abs_diff_eq!(impurity, 0.0, epsilon = 1e-10);
    }

    /// Gini: the classes change between 2.0 and 3.0, so the threshold must lie there.
    #[test]
    fn test_find_best_split() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];
        let rows: Vec<usize> = vec![0, 1, 2, 3];

        let found = find_best_split(&features, &labels, &rows, SplitCriterion::Gini, 1).unwrap();
        assert!(found.threshold > 2.0);
        assert!(found.threshold < 3.0);
    }

    /// MSE: the targets jump between 2.0 and 3.0, giving two samples per side.
    #[test]
    fn test_find_best_split_regression() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let targets = array![1.0, 1.5, 10.0, 10.5];
        let rows: Vec<usize> = vec![0, 1, 2, 3];

        let found = find_best_split(&features, &targets, &rows, SplitCriterion::Mse, 1).unwrap();
        assert!(found.threshold > 2.0);
        assert!(found.threshold < 3.0);
        assert_eq!(found.left_indices.len(), 2);
        assert_eq!(found.right_indices.len(), 2);
    }

    /// Class counting yields one entry per distinct label, ordered by label value.
    #[test]
    fn test_count_classes_uses_exact_bits() {
        let labels = array![0.0, 1.0, 0.0, 2.0, 1.0];
        let rows: Vec<usize> = vec![0, 1, 2, 3, 4];
        let per_class = count_classes(&labels, &rows);

        assert_eq!(per_class.len(), 3);
        // Expected, in label order: 0.0 twice, 1.0 twice, 2.0 once.
        assert_eq!(per_class[0].1, 2);
        assert_eq!(per_class[1].1, 2);
        assert_eq!(per_class[2].1, 1);
    }

    /// Entropy picks the same boundary as Gini on cleanly separable data.
    #[test]
    fn test_find_best_split_entropy() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];
        let rows: Vec<usize> = vec![0, 1, 2, 3];

        let found = find_best_split(&features, &labels, &rows, SplitCriterion::Entropy, 1).unwrap();
        assert!(found.threshold > 2.0);
        assert!(found.threshold < 3.0);
    }

    /// With four samples, requiring three per leaf leaves no admissible split.
    #[test]
    fn test_min_samples_leaf_respected() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];
        let rows: Vec<usize> = vec![0, 1, 2, 3];

        let outcome = find_best_split(&features, &labels, &rows, SplitCriterion::Gini, 3);
        assert!(outcome.is_none());
    }
}