1use std::collections::HashMap;
2
3use anofox_ml_core::Float;
4use ndarray::{Array1, Array2};
5
6#[inline]
9fn float_key<F: Float>(v: F) -> u64 {
10 v.to_f64().unwrap().to_bits()
11}
12
/// Impurity measure used to score a node and its candidate splits.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum SplitCriterion {
    /// Gini impurity: one minus the sum of squared class proportions.
    Gini,
    /// Entropy of the class proportions, computed with the natural logarithm.
    Entropy,
    /// Mean squared deviation of the targets from their mean.
    Mse,
}
23
/// How a split is chosen at a node.
///
/// NOTE(review): nothing in this part of the file consumes this enum; the variant
/// descriptions are inferred from the names of `find_best_split` and
/// `find_random_split` — confirm against the tree builder.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum SplitStrategy {
    /// Presumably the exhaustive threshold search (`find_best_split`).
    Best,
    /// Presumably the random-threshold search (`find_random_split`).
    Random,
}
32
/// Rule for how many features to consider; turned into a concrete count by
/// `MaxFeatures::resolve`.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum MaxFeatures {
    /// Square root of the number of features, rounded down.
    Sqrt,
    /// Base-2 logarithm of the number of features, rounded down.
    Log2,
    /// An explicit number of features.
    Fixed(usize),
    /// A fraction of the number of features, rounded down.
    Fraction(f64),
}
45
46impl MaxFeatures {
47 pub fn resolve(&self, n_features: usize) -> usize {
49 match self {
50 MaxFeatures::Sqrt => (n_features as f64).sqrt().floor().max(1.0) as usize,
51 MaxFeatures::Log2 => (n_features as f64).log2().floor().max(1.0) as usize,
52 MaxFeatures::Fixed(k) => (*k).min(n_features).max(1),
53 MaxFeatures::Fraction(f) => (*f * n_features as f64).floor().max(1.0) as usize,
54 }
55 }
56}
57
/// Picks `k` distinct indices out of `0..n_features`, deterministically for a
/// given `seed`, and returns them in ascending order.
///
/// Runs the first `k` steps of a Fisher–Yates shuffle driven by a xorshift
/// generator. When `k >= n_features` every index is returned.
pub fn select_feature_subset(n_features: usize, k: usize, seed: u64) -> Vec<usize> {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

    if k >= n_features {
        return (0..n_features).collect();
    }
    let mut indices: Vec<usize> = (0..n_features).collect();
    let mut state = seed.wrapping_add(GOLDEN_GAMMA);
    if state == 0 {
        // xorshift maps 0 to 0 forever, which would always select `0..k`.
        state = GOLDEN_GAMMA;
    }
    for i in 0..k {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        // Reduce in `u64` so the draw is identical on 32-bit and 64-bit targets.
        let remaining = (n_features - i) as u64;
        let j = i + (state % remaining) as usize;
        indices.swap(i, j);
    }
    indices.truncate(k);
    indices.sort_unstable();
    indices
}
81
/// Result of a split search: the chosen feature and threshold together with the
/// partition of the node's samples they induce.
#[derive(Debug, Clone)]
pub struct BestSplit<F: Float> {
    /// Column of the feature matrix that the split tests.
    pub feature_index: usize,
    /// Split point: samples whose feature value is `<= threshold` go left.
    pub threshold: F,
    /// Sample indices whose feature value is `<= threshold`.
    pub left_indices: Vec<usize>,
    /// All remaining sample indices of the node.
    pub right_indices: Vec<usize>,
    /// Parent impurity minus the weighted impurity of the two children.
    pub improvement: F,
}
91
92pub fn find_best_split<F: Float>(
98 x: &Array2<F>,
99 y: &Array1<F>,
100 indices: &[usize],
101 criterion: SplitCriterion,
102 min_samples_leaf: usize,
103) -> Option<BestSplit<F>> {
104 let all_features: Vec<usize> = (0..x.ncols()).collect();
105 find_best_split_with_features(x, y, indices, criterion, min_samples_leaf, &all_features)
106}
107
108pub fn find_best_split_with_features<F: Float>(
110 x: &Array2<F>,
111 y: &Array1<F>,
112 indices: &[usize],
113 criterion: SplitCriterion,
114 min_samples_leaf: usize,
115 feature_indices: &[usize],
116) -> Option<BestSplit<F>> {
117 let n = indices.len();
118 if n < 2 * min_samples_leaf {
119 return None;
120 }
121
122 let parent_impurity = compute_impurity(y, indices, criterion);
123
124 match criterion {
125 SplitCriterion::Gini | SplitCriterion::Entropy => find_best_split_classification(
126 x,
127 y,
128 indices,
129 criterion,
130 min_samples_leaf,
131 feature_indices,
132 n,
133 parent_impurity,
134 ),
135 SplitCriterion::Mse => find_best_split_regression(
136 x,
137 y,
138 indices,
139 min_samples_leaf,
140 feature_indices,
141 n,
142 parent_impurity,
143 ),
144 }
145}
146
147#[inline]
152fn sort_feature_pairs<F: Float>(
153 x: &Array2<F>,
154 indices: &[usize],
155 feature: usize,
156 sorted_pairs: &mut Vec<(F, usize)>,
157) {
158 sorted_pairs.clear();
159 sorted_pairs.extend(indices.iter().map(|&i| (x[[i, feature]], i)));
160 sorted_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
161}
162
/// Best split seen so far during a search, before the sample partition is built.
struct CandidateSplit<F: Float> {
    /// Column of the feature matrix that the split tests.
    feature: usize,
    /// Split point on that feature.
    threshold: F,
    /// Impurity reduction achieved by the split.
    improvement: F,
}
169
170#[inline]
175fn try_update_best_split<F: Float>(
176 improvement: F,
177 best_improvement: &mut F,
178 best: &mut Option<CandidateSplit<F>>,
179 feature: usize,
180 threshold: F,
181) {
182 if improvement > *best_improvement {
183 *best_improvement = improvement;
184 *best = Some(CandidateSplit {
185 feature,
186 threshold,
187 improvement,
188 });
189 }
190}
191
/// Running statistics for the two sides of a candidate split.
///
/// The split search resets the accumulator for each feature (all samples on the
/// right) and then moves samples to the left one at a time in ascending feature
/// order, querying the weighted impurity in between.
trait SplitAccumulator<F: Float> {
    /// Puts every sample in `indices` on the right side and empties the left side.
    fn reset(&mut self, y: &Array1<F>, indices: &[usize]);
    /// Transfers sample `idx` from the right side to the left side.
    fn move_to_left(&mut self, y: &Array1<F>, idx: usize);
    /// Impurity of the current partition: each side's impurity weighted by its
    /// share of the `n` samples.
    fn weighted_impurity(&self, n: usize) -> F;
    /// Number of samples currently on the left side.
    fn n_left(&self) -> usize;
    /// Number of samples currently on the right side.
    fn n_right(&self) -> usize;
}
206
/// Per-class sample counts for the left and right side of a candidate split.
struct ClassificationAccumulator<F: Float> {
    /// Samples per class on the left side, indexed by dense class id.
    left_counts: Vec<usize>,
    /// Samples per class on the right side, indexed by dense class id.
    right_counts: Vec<usize>,
    /// Total number of samples on the left side.
    n_left: usize,
    /// Total number of samples on the right side.
    n_right: usize,
    /// Impurity measure applied to the counts (`Gini` unless overridden through
    /// `with_criterion`).
    criterion: SplitCriterion,
    /// Maps a label's `float_key` to its dense class id.
    class_map: HashMap<u64, usize>,
    /// Ties the accumulator to the float type `F` without storing a value of it.
    _marker: std::marker::PhantomData<F>,
}
217
218impl<F: Float> ClassificationAccumulator<F> {
219 fn new(y: &Array1<F>, indices: &[usize]) -> Self {
221 let class_map = build_class_map(y, indices);
222 let n_classes = class_map.len();
223
224 let mut total_counts = vec![0usize; n_classes];
225 for &i in indices {
226 let cls = class_map[&float_key(y[i])];
227 total_counts[cls] += 1;
228 }
229
230 Self {
231 left_counts: vec![0usize; n_classes],
232 right_counts: total_counts,
233 n_left: 0,
234 n_right: indices.len(),
235 criterion: SplitCriterion::Gini,
236 class_map,
237 _marker: std::marker::PhantomData,
238 }
239 }
240}
241
242impl<F: Float> SplitAccumulator<F> for ClassificationAccumulator<F> {
243 fn reset(&mut self, y: &Array1<F>, indices: &[usize]) {
244 self.left_counts.fill(0);
246 self.right_counts.fill(0);
247 for &i in indices {
248 let cls = self.class_map[&float_key(y[i])];
249 self.right_counts[cls] += 1;
250 }
251 self.n_left = 0;
252 self.n_right = indices.len();
253 }
254
255 fn move_to_left(&mut self, y: &Array1<F>, idx: usize) {
256 let cls = self.class_map[&float_key(y[idx])];
257 self.left_counts[cls] += 1;
258 self.right_counts[cls] -= 1;
259 self.n_left += 1;
260 self.n_right -= 1;
261 }
262
263 fn weighted_impurity(&self, n: usize) -> F {
264 let n_f = F::from_usize(n).unwrap();
265 let nl = F::from_usize(self.n_left).unwrap();
266 let nr = F::from_usize(self.n_right).unwrap();
267 let left_imp = impurity_from_counts(&self.left_counts, self.n_left, self.criterion);
268 let right_imp = impurity_from_counts(&self.right_counts, self.n_right, self.criterion);
269 (nl / n_f) * left_imp + (nr / n_f) * right_imp
270 }
271
272 fn n_left(&self) -> usize {
273 self.n_left
274 }
275
276 fn n_right(&self) -> usize {
277 self.n_right
278 }
279}
280
281impl<F: Float> ClassificationAccumulator<F> {
282 fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
283 self.criterion = criterion;
284 self
285 }
286}
287
/// Running target sums for the left and right side of a candidate split.
struct RegressionAccumulator<F: Float> {
    /// Sum of the targets on the left side.
    left_sum: F,
    /// Sum of the squared targets on the left side.
    left_sum_sq: F,
    /// Sum of the targets on the right side.
    right_sum: F,
    /// Sum of the squared targets on the right side.
    right_sum_sq: F,
    /// Number of samples on the left side.
    n_left: usize,
    /// Number of samples on the right side.
    n_right: usize,
}
297
298impl<F: Float> RegressionAccumulator<F> {
299 fn new(y: &Array1<F>, indices: &[usize]) -> Self {
300 let mut total_sum = F::zero();
301 let mut total_sum_sq = F::zero();
302 for &i in indices {
303 let v = y[i];
304 total_sum += v;
305 total_sum_sq += v * v;
306 }
307
308 Self {
309 left_sum: F::zero(),
310 left_sum_sq: F::zero(),
311 right_sum: total_sum,
312 right_sum_sq: total_sum_sq,
313 n_left: 0,
314 n_right: indices.len(),
315 }
316 }
317}
318
319impl<F: Float> SplitAccumulator<F> for RegressionAccumulator<F> {
320 fn reset(&mut self, y: &Array1<F>, indices: &[usize]) {
321 self.left_sum = F::zero();
322 self.left_sum_sq = F::zero();
323 self.right_sum = F::zero();
324 self.right_sum_sq = F::zero();
325 for &i in indices {
326 let v = y[i];
327 self.right_sum += v;
328 self.right_sum_sq += v * v;
329 }
330 self.n_left = 0;
331 self.n_right = indices.len();
332 }
333
334 fn move_to_left(&mut self, y: &Array1<F>, idx: usize) {
335 let v = y[idx];
336 self.left_sum += v;
337 self.left_sum_sq += v * v;
338 self.right_sum -= v;
339 self.right_sum_sq -= v * v;
340 self.n_left += 1;
341 self.n_right -= 1;
342 }
343
344 fn weighted_impurity(&self, n: usize) -> F {
345 let n_f = F::from_usize(n).unwrap();
346 let nl = F::from_usize(self.n_left).unwrap();
347 let nr = F::from_usize(self.n_right).unwrap();
348 let left_mse = self.left_sum_sq / nl - (self.left_sum / nl) * (self.left_sum / nl);
350 let right_mse = self.right_sum_sq / nr - (self.right_sum / nr) * (self.right_sum / nr);
351 (nl / n_f) * left_mse + (nr / n_f) * right_mse
352 }
353
354 fn n_left(&self) -> usize {
355 self.n_left
356 }
357
358 fn n_right(&self) -> usize {
359 self.n_right
360 }
361}
362
363#[inline]
368fn evaluate_candidate_split<F: Float, A: SplitAccumulator<F>>(
369 acc: &A,
370 n: usize,
371 min_samples_leaf: usize,
372 cur_val: F,
373 next_val: F,
374 parent_impurity: F,
375) -> Option<(F, F)> {
376 if (next_val - cur_val).abs() < F::from_f64(1e-15).unwrap() {
378 return None;
379 }
380
381 if acc.n_left() < min_samples_leaf || acc.n_right() < min_samples_leaf {
383 return None;
384 }
385
386 let threshold = (cur_val + next_val) / (F::one() + F::one());
387 let improvement = parent_impurity - acc.weighted_impurity(n);
388 Some((threshold, improvement))
389}
390
391#[allow(clippy::too_many_arguments)]
398fn find_best_split_inner<F, A>(
399 x: &Array2<F>,
400 y: &Array1<F>,
401 indices: &[usize],
402 min_samples_leaf: usize,
403 feature_indices: &[usize],
404 n: usize,
405 parent_impurity: F,
406 mut acc: A,
407) -> Option<BestSplit<F>>
408where
409 F: Float,
410 A: SplitAccumulator<F>,
411{
412 let mut best: Option<CandidateSplit<F>> = None;
413 let mut best_improvement = F::neg_infinity();
414
415 let mut sorted_pairs: Vec<(F, usize)> = Vec::with_capacity(n);
416
417 for &feature in feature_indices {
418 sort_feature_pairs(x, indices, feature, &mut sorted_pairs);
419
420 acc.reset(y, indices);
421
422 for pos in 0..n - 1 {
423 let (cur_val, cur_idx) = sorted_pairs[pos];
424 acc.move_to_left(y, cur_idx);
425
426 let next_val = sorted_pairs[pos + 1].0;
427 if let Some((threshold, improvement)) = evaluate_candidate_split(
428 &acc,
429 n,
430 min_samples_leaf,
431 cur_val,
432 next_val,
433 parent_impurity,
434 ) {
435 try_update_best_split(
436 improvement,
437 &mut best_improvement,
438 &mut best,
439 feature,
440 threshold,
441 );
442 }
443 }
444 }
445
446 best.map(|candidate| {
448 let mut left_indices = Vec::with_capacity(n);
449 let mut right_indices = Vec::with_capacity(n);
450 for &i in indices {
451 if x[[i, candidate.feature]] <= candidate.threshold {
452 left_indices.push(i);
453 } else {
454 right_indices.push(i);
455 }
456 }
457 BestSplit {
458 feature_index: candidate.feature,
459 threshold: candidate.threshold,
460 left_indices,
461 right_indices,
462 improvement: candidate.improvement,
463 }
464 })
465}
466
467#[allow(clippy::too_many_arguments)]
469fn find_best_split_classification<F: Float>(
470 x: &Array2<F>,
471 y: &Array1<F>,
472 indices: &[usize],
473 criterion: SplitCriterion,
474 min_samples_leaf: usize,
475 feature_indices: &[usize],
476 n: usize,
477 parent_impurity: F,
478) -> Option<BestSplit<F>> {
479 let acc = ClassificationAccumulator::<F>::new(y, indices).with_criterion(criterion);
480 find_best_split_inner(
481 x,
482 y,
483 indices,
484 min_samples_leaf,
485 feature_indices,
486 n,
487 parent_impurity,
488 acc,
489 )
490}
491
492fn find_best_split_regression<F: Float>(
494 x: &Array2<F>,
495 y: &Array1<F>,
496 indices: &[usize],
497 min_samples_leaf: usize,
498 feature_indices: &[usize],
499 n: usize,
500 parent_impurity: F,
501) -> Option<BestSplit<F>> {
502 let acc = RegressionAccumulator::<F>::new(y, indices);
503 find_best_split_inner(
504 x,
505 y,
506 indices,
507 min_samples_leaf,
508 feature_indices,
509 n,
510 parent_impurity,
511 acc,
512 )
513}
514
515fn build_class_map<F: Float>(y: &Array1<F>, indices: &[usize]) -> HashMap<u64, usize> {
517 let mut map = HashMap::new();
518 let mut next_idx = 0;
519 for &i in indices {
520 let bits = float_key(y[i]);
521 if let std::collections::hash_map::Entry::Vacant(e) = map.entry(bits) {
522 e.insert(next_idx);
523 next_idx += 1;
524 }
525 }
526 map
527}
528
529#[inline]
531fn impurity_from_counts<F: Float>(counts: &[usize], total: usize, criterion: SplitCriterion) -> F {
532 let n = F::from_usize(total).unwrap();
533 match criterion {
534 SplitCriterion::Gini => {
535 let sum_sq: F = counts
536 .iter()
537 .filter(|&&c| c > 0)
538 .map(|&c| {
539 let p = F::from_usize(c).unwrap() / n;
540 p * p
541 })
542 .fold(F::zero(), |a, b| a + b);
543 F::one() - sum_sq
544 }
545 SplitCriterion::Entropy => {
546 let sum: F = counts
547 .iter()
548 .filter(|&&c| c > 0)
549 .map(|&c| {
550 let p = F::from_usize(c).unwrap() / n;
551 p * p.ln()
552 })
553 .fold(F::zero(), |a, b| a + b);
554 -sum
555 }
556 SplitCriterion::Mse => unreachable!("MSE does not use class counts"),
557 }
558}
559
560#[inline]
562pub fn compute_impurity<F: Float>(
563 y: &Array1<F>,
564 indices: &[usize],
565 criterion: SplitCriterion,
566) -> F {
567 match criterion {
568 SplitCriterion::Gini => gini(y, indices),
569 SplitCriterion::Entropy => entropy(y, indices),
570 SplitCriterion::Mse => mse_impurity(y, indices),
571 }
572}
573
574#[inline]
575fn gini<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
576 let n = F::from_usize(indices.len()).unwrap();
577 let class_counts = count_classes(y, indices);
578
579 let sum_sq: F = class_counts
580 .iter()
581 .map(|&(_, count)| {
582 let p = F::from_usize(count).unwrap() / n;
583 p * p
584 })
585 .fold(F::zero(), |a, b| a + b);
586
587 F::one() - sum_sq
588}
589
590#[inline]
591fn entropy<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
592 let n = F::from_usize(indices.len()).unwrap();
593 let class_counts = count_classes(y, indices);
594
595 let sum: F = class_counts
596 .iter()
597 .map(|&(_, count)| {
598 let p = F::from_usize(count).unwrap() / n;
599 if p > F::zero() {
600 p * p.ln()
601 } else {
602 F::zero()
603 }
604 })
605 .fold(F::zero(), |a, b| a + b);
606
607 -sum
608}
609
610#[inline]
611fn mse_impurity<F: Float>(y: &Array1<F>, indices: &[usize]) -> F {
612 let n = F::from_usize(indices.len()).unwrap();
613 let mean: F = indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b) / n;
614
615 indices
616 .iter()
617 .map(|&i| (y[i] - mean) * (y[i] - mean))
618 .fold(F::zero(), |a, b| a + b)
619 / n
620}
621
622pub fn count_classes<F: Float>(y: &Array1<F>, indices: &[usize]) -> Vec<(F, usize)> {
624 let mut map: HashMap<u64, (F, usize)> = HashMap::new();
625 for &i in indices {
626 let val = y[i];
627 let bits = float_key(val);
628 map.entry(bits).and_modify(|e| e.1 += 1).or_insert((val, 1));
629 }
630 let mut counts: Vec<(F, usize)> = map.into_values().collect();
631 counts.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
632 counts
633}
634
635#[inline]
637pub fn leaf_value<F: Float>(y: &Array1<F>, indices: &[usize], criterion: SplitCriterion) -> F {
638 match criterion {
639 SplitCriterion::Mse => {
640 let n = F::from_usize(indices.len()).unwrap();
641 indices.iter().map(|&i| y[i]).fold(F::zero(), |a, b| a + b) / n
642 }
643 SplitCriterion::Gini | SplitCriterion::Entropy => {
644 let counts = count_classes(y, indices);
645 counts
646 .into_iter()
647 .max_by_key(|&(_, count)| count)
648 .unwrap()
649 .0
650 }
651 }
652}
653
654pub fn find_random_split<F: Float>(
659 x: &Array2<F>,
660 y: &Array1<F>,
661 indices: &[usize],
662 criterion: SplitCriterion,
663 min_samples_leaf: usize,
664 seed: u64,
665) -> Option<BestSplit<F>> {
666 let n_features = x.ncols();
667 let n = indices.len();
668 if n < 2 * min_samples_leaf {
669 return None;
670 }
671
672 let parent_impurity = compute_impurity(y, indices, criterion);
673
674 let mut best: Option<CandidateSplit<F>> = None;
675 let mut best_improvement = F::neg_infinity();
676
677 let mut rng_state = seed.wrapping_add(0x9E3779B97F4A7C15);
679
680 for feature in 0..n_features {
681 let mut min_val = x[[indices[0], feature]];
683 let mut max_val = min_val;
684 for &i in &indices[1..] {
685 let v = x[[i, feature]];
686 if v < min_val {
687 min_val = v;
688 }
689 if v > max_val {
690 max_val = v;
691 }
692 }
693
694 if (max_val - min_val).abs() < F::from_f64(1e-15).unwrap() {
695 continue;
696 }
697
698 rng_state ^= rng_state << 13;
700 rng_state ^= rng_state >> 7;
701 rng_state ^= rng_state << 17;
702 let t = F::from_f64((rng_state as f64) / (u64::MAX as f64)).unwrap();
703 let threshold = min_val + t * (max_val - min_val);
704
705 let mut n_left = 0usize;
707 let mut n_right = 0usize;
708 for &i in indices {
709 if x[[i, feature]] <= threshold {
710 n_left += 1;
711 } else {
712 n_right += 1;
713 }
714 }
715
716 if n_left < min_samples_leaf || n_right < min_samples_leaf {
717 continue;
718 }
719
720 let left_indices: Vec<usize> = indices
722 .iter()
723 .copied()
724 .filter(|&i| x[[i, feature]] <= threshold)
725 .collect();
726 let right_indices: Vec<usize> = indices
727 .iter()
728 .copied()
729 .filter(|&i| x[[i, feature]] > threshold)
730 .collect();
731
732 let left_imp = compute_impurity(y, &left_indices, criterion);
733 let right_imp = compute_impurity(y, &right_indices, criterion);
734
735 let n_f = F::from_usize(n).unwrap();
736 let nl_f = F::from_usize(n_left).unwrap();
737 let nr_f = F::from_usize(n_right).unwrap();
738 let weighted = (nl_f / n_f) * left_imp + (nr_f / n_f) * right_imp;
739 let improvement = parent_impurity - weighted;
740
741 try_update_best_split(
742 improvement,
743 &mut best_improvement,
744 &mut best,
745 feature,
746 threshold,
747 );
748 }
749
750 best.map(|candidate| {
752 let mut left_indices = Vec::with_capacity(n);
753 let mut right_indices = Vec::with_capacity(n);
754 for &i in indices {
755 if x[[i, candidate.feature]] <= candidate.threshold {
756 left_indices.push(i);
757 } else {
758 right_indices.push(i);
759 }
760 }
761 BestSplit {
762 feature_index: candidate.feature,
763 threshold: candidate.threshold,
764 left_indices,
765 right_indices,
766 improvement: candidate.improvement,
767 }
768 })
769}
770
/// Rule for deriving per-sample weights from class labels.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ClassWeight {
    /// Weight each class by `n_samples / (n_classes * class_count)`.
    Balanced,
    /// Explicit `(class value, weight)` pairs; classes not listed keep weight 1.
    Manual(Vec<(f64, f64)>),
}
779
780pub fn compute_sample_weights_from_class_weight<F: Float>(
782 y: &Array1<F>,
783 class_weight: &ClassWeight,
784) -> Array1<F> {
785 let n_samples = y.len();
786 match class_weight {
787 ClassWeight::Balanced => {
788 let counts = count_classes(y, &(0..n_samples).collect::<Vec<_>>());
789 let n_classes = counts.len();
790 let n_f = F::from_usize(n_samples).unwrap();
791 let nc_f = F::from_usize(n_classes).unwrap();
792 let mut weights = Array1::<F>::ones(n_samples);
793 for i in 0..n_samples {
794 for &(class_val, count) in &counts {
795 if (y[i] - class_val).abs() < F::from_f64(1e-9).unwrap() {
796 weights[i] = n_f / (nc_f * F::from_usize(count).unwrap());
797 break;
798 }
799 }
800 }
801 weights
802 }
803 ClassWeight::Manual(mapping) => {
804 let mut weights = Array1::<F>::ones(n_samples);
805 for i in 0..n_samples {
806 let yi = y[i].to_f64().unwrap();
807 for &(class_val, w) in mapping {
808 if (yi - class_val).abs() < 1e-9 {
809 weights[i] = F::from_f64(w).unwrap();
810 break;
811 }
812 }
813 }
814 weights
815 }
816 }
817}
818
819pub fn compute_weighted_impurity<F: Float>(
821 y: &Array1<F>,
822 indices: &[usize],
823 weights: &Array1<F>,
824 criterion: SplitCriterion,
825) -> F {
826 let total_weight: F = indices
827 .iter()
828 .map(|&i| weights[i])
829 .fold(F::zero(), |a, b| a + b);
830 if total_weight <= F::zero() {
831 return F::zero();
832 }
833
834 match criterion {
835 SplitCriterion::Gini => {
836 let mut class_weights: HashMap<u64, F> = HashMap::new();
838 for &i in indices {
839 let key = float_key(y[i]);
840 *class_weights.entry(key).or_insert(F::zero()) += weights[i];
841 }
842 let sum_sq: F = class_weights
843 .values()
844 .map(|&w| {
845 let p = w / total_weight;
846 p * p
847 })
848 .fold(F::zero(), |a, b| a + b);
849 F::one() - sum_sq
850 }
851 SplitCriterion::Entropy => {
852 let mut class_weights: HashMap<u64, F> = HashMap::new();
853 for &i in indices {
854 let key = float_key(y[i]);
855 *class_weights.entry(key).or_insert(F::zero()) += weights[i];
856 }
857 let sum: F = class_weights
858 .values()
859 .filter(|&&w| w > F::zero())
860 .map(|&w| {
861 let p = w / total_weight;
862 p * p.ln()
863 })
864 .fold(F::zero(), |a, b| a + b);
865 -sum
866 }
867 SplitCriterion::Mse => {
868 let w_mean: F = indices
870 .iter()
871 .map(|&i| weights[i] * y[i])
872 .fold(F::zero(), |a, b| a + b)
873 / total_weight;
874 indices
875 .iter()
876 .map(|&i| weights[i] * (y[i] - w_mean) * (y[i] - w_mean))
877 .fold(F::zero(), |a, b| a + b)
878 / total_weight
879 }
880 }
881}
882
883pub fn weighted_leaf_value<F: Float>(
885 y: &Array1<F>,
886 indices: &[usize],
887 weights: &Array1<F>,
888 criterion: SplitCriterion,
889) -> F {
890 match criterion {
891 SplitCriterion::Mse => {
892 let total_weight: F = indices
893 .iter()
894 .map(|&i| weights[i])
895 .fold(F::zero(), |a, b| a + b);
896 if total_weight <= F::zero() {
897 return F::zero();
898 }
899 indices
900 .iter()
901 .map(|&i| weights[i] * y[i])
902 .fold(F::zero(), |a, b| a + b)
903 / total_weight
904 }
905 SplitCriterion::Gini | SplitCriterion::Entropy => {
906 let mut class_weights: HashMap<u64, (F, F)> = HashMap::new();
908 for &i in indices {
909 let key = float_key(y[i]);
910 class_weights
911 .entry(key)
912 .and_modify(|e| e.1 += weights[i])
913 .or_insert((y[i], weights[i]));
914 }
915 class_weights
916 .into_values()
917 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
918 .unwrap()
919 .0
920 }
921 }
922}
923
924pub fn weighted_count_classes<F: Float>(
926 y: &Array1<F>,
927 indices: &[usize],
928 weights: &Array1<F>,
929) -> Vec<(F, F)> {
930 let mut map: HashMap<u64, (F, F)> = HashMap::new();
931 for &i in indices {
932 let val = y[i];
933 let bits = float_key(val);
934 map.entry(bits)
935 .and_modify(|e| e.1 += weights[i])
936 .or_insert((val, weights[i]));
937 }
938 let mut counts: Vec<(F, F)> = map.into_values().collect();
939 counts.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
940 counts
941}
942
943pub fn find_best_split_weighted<F: Float>(
947 x: &Array2<F>,
948 y: &Array1<F>,
949 indices: &[usize],
950 weights: &Array1<F>,
951 criterion: SplitCriterion,
952 min_samples_leaf: usize,
953 feature_indices: &[usize],
954) -> Option<BestSplit<F>> {
955 let n = indices.len();
956 if n < 2 * min_samples_leaf {
957 return None;
958 }
959
960 let parent_impurity = compute_weighted_impurity(y, indices, weights, criterion);
961
962 let mut best: Option<CandidateSplit<F>> = None;
963 let mut best_improvement = F::neg_infinity();
964 let mut sorted_pairs: Vec<(F, usize)> = Vec::with_capacity(n);
965
966 let total_weight: F = indices
967 .iter()
968 .map(|&i| weights[i])
969 .fold(F::zero(), |a, b| a + b);
970
971 for &feature in feature_indices {
972 sort_feature_pairs(x, indices, feature, &mut sorted_pairs);
973
974 let mut left_weight = F::zero();
976 let mut right_weight = total_weight;
977 let mut left_class_weights: HashMap<u64, F> = HashMap::new();
978 let mut right_class_weights: HashMap<u64, F> = HashMap::new();
979
980 for &i in indices {
982 let key = float_key(y[i]);
983 *right_class_weights.entry(key).or_insert(F::zero()) += weights[i];
984 }
985
986 for pos in 0..n - 1 {
987 let (cur_val, cur_idx) = sorted_pairs[pos];
988 let w = weights[cur_idx];
989 let key = float_key(y[cur_idx]);
990
991 left_weight += w;
993 right_weight -= w;
994 *left_class_weights.entry(key).or_insert(F::zero()) += w;
995 *right_class_weights.entry(key).or_insert(F::zero()) -= w;
996
997 let next_val = sorted_pairs[pos + 1].0;
998 if (next_val - cur_val).abs() < F::from_f64(1e-15).unwrap() {
999 continue;
1000 }
1001
1002 let n_left = pos + 1;
1004 let n_right = n - n_left;
1005 if n_left < min_samples_leaf || n_right < min_samples_leaf {
1006 continue;
1007 }
1008
1009 let left_imp = match criterion {
1011 SplitCriterion::Gini => {
1012 let sum_sq: F = left_class_weights
1013 .values()
1014 .filter(|&&w| w > F::zero())
1015 .map(|&w| {
1016 let p = w / left_weight;
1017 p * p
1018 })
1019 .fold(F::zero(), |a, b| a + b);
1020 F::one() - sum_sq
1021 }
1022 SplitCriterion::Entropy => {
1023 let sum: F = left_class_weights
1024 .values()
1025 .filter(|&&w| w > F::zero())
1026 .map(|&w| {
1027 let p = w / left_weight;
1028 p * p.ln()
1029 })
1030 .fold(F::zero(), |a, b| a + b);
1031 -sum
1032 }
1033 SplitCriterion::Mse => {
1034 let left_indices: Vec<usize> =
1036 sorted_pairs[..=pos].iter().map(|&(_, i)| i).collect();
1037 compute_weighted_impurity(y, &left_indices, weights, criterion)
1038 }
1039 };
1040
1041 let right_imp = match criterion {
1042 SplitCriterion::Gini => {
1043 let sum_sq: F = right_class_weights
1044 .values()
1045 .filter(|&&w| w > F::zero())
1046 .map(|&w| {
1047 let p = w / right_weight;
1048 p * p
1049 })
1050 .fold(F::zero(), |a, b| a + b);
1051 F::one() - sum_sq
1052 }
1053 SplitCriterion::Entropy => {
1054 let sum: F = right_class_weights
1055 .values()
1056 .filter(|&&w| w > F::zero())
1057 .map(|&w| {
1058 let p = w / right_weight;
1059 p * p.ln()
1060 })
1061 .fold(F::zero(), |a, b| a + b);
1062 -sum
1063 }
1064 SplitCriterion::Mse => {
1065 let right_indices: Vec<usize> =
1066 sorted_pairs[pos + 1..].iter().map(|&(_, i)| i).collect();
1067 compute_weighted_impurity(y, &right_indices, weights, criterion)
1068 }
1069 };
1070
1071 let weighted_imp =
1072 (left_weight / total_weight) * left_imp + (right_weight / total_weight) * right_imp;
1073 let improvement = parent_impurity - weighted_imp;
1074 let threshold = (cur_val + next_val) / (F::one() + F::one());
1075
1076 try_update_best_split(
1077 improvement,
1078 &mut best_improvement,
1079 &mut best,
1080 feature,
1081 threshold,
1082 );
1083 }
1084 }
1085
1086 best.map(|candidate| {
1087 let mut left_indices = Vec::with_capacity(n);
1088 let mut right_indices = Vec::with_capacity(n);
1089 for &i in indices {
1090 if x[[i, candidate.feature]] <= candidate.threshold {
1091 left_indices.push(i);
1092 } else {
1093 right_indices.push(i);
1094 }
1095 }
1096 BestSplit {
1097 feature_index: candidate.feature,
1098 threshold: candidate.threshold,
1099 left_indices,
1100 right_indices,
1101 improvement: candidate.improvement,
1102 }
1103 })
1104}
1105
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// A node with a single class has zero Gini impurity.
    #[test]
    fn test_gini_pure() {
        let labels = array![1.0, 1.0, 1.0];
        let rows = [0usize, 1, 2];
        let impurity = gini(&labels, &rows);
        assert_abs_diff_eq!(impurity, 0.0, epsilon = 1e-10);
    }

    /// Two equally frequent classes give a Gini impurity of 0.5.
    #[test]
    fn test_gini_balanced() {
        let labels = array![0.0, 1.0];
        let rows = [0usize, 1];
        let impurity = gini(&labels, &rows);
        assert_abs_diff_eq!(impurity, 0.5, epsilon = 1e-10);
    }

    /// Constant targets have zero mean squared error.
    #[test]
    fn test_mse_pure() {
        let targets = array![5.0, 5.0, 5.0];
        let rows = [0usize, 1, 2];
        let impurity = mse_impurity(&targets, &rows);
        assert_abs_diff_eq!(impurity, 0.0, epsilon = 1e-10);
    }

    /// The Gini search separates the two classes between 2.0 and 3.0.
    #[test]
    fn test_find_best_split() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];
        let rows = [0usize, 1, 2, 3];

        let found = find_best_split(&features, &labels, &rows, SplitCriterion::Gini, 1).unwrap();
        assert!(found.threshold > 2.0 && found.threshold < 3.0);
    }

    /// The MSE search separates the low targets from the high targets.
    #[test]
    fn test_find_best_split_regression() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let targets = array![1.0, 1.5, 10.0, 10.5];
        let rows = [0usize, 1, 2, 3];

        let found = find_best_split(&features, &targets, &rows, SplitCriterion::Mse, 1).unwrap();
        assert!(found.threshold > 2.0 && found.threshold < 3.0);
        assert_eq!(found.left_indices.len(), 2);
        assert_eq!(found.right_indices.len(), 2);
    }

    /// Classes come back sorted by label with their exact counts.
    #[test]
    fn test_count_classes_uses_exact_bits() {
        let labels = array![0.0, 1.0, 0.0, 2.0, 1.0];
        let rows = [0usize, 1, 2, 3, 4];
        let counts = count_classes(&labels, &rows);
        assert_eq!(counts.len(), 3);
        let tallies: Vec<usize> = counts.iter().map(|entry| entry.1).collect();
        assert_eq!(tallies, vec![2, 2, 1]);
    }

    /// The entropy search finds the same separating threshold as Gini.
    #[test]
    fn test_find_best_split_entropy() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];
        let rows = [0usize, 1, 2, 3];

        let found =
            find_best_split(&features, &labels, &rows, SplitCriterion::Entropy, 1).unwrap();
        assert!(found.threshold > 2.0 && found.threshold < 3.0);
    }

    /// Four samples cannot be split into two leaves of at least three samples.
    #[test]
    fn test_min_samples_leaf_respected() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let labels = array![0.0, 0.0, 1.0, 1.0];
        let rows = [0usize, 1, 2, 3];

        let outcome = find_best_split(&features, &labels, &rows, SplitCriterion::Gini, 3);
        assert!(outcome.is_none());
    }
}