1use crate::error::{DatasetsError, Result};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::prelude::*;
10use scirs2_core::random::prelude::*;
11use scirs2_core::random::rngs::StdRng;
12use scirs2_core::random::seq::SliceRandom;
13use scirs2_core::random::Uniform;
14use std::collections::HashMap;
15
16#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
18pub enum BalancingStrategy {
19 RandomOversample,
21 RandomUndersample,
23 SMOTE {
25 k_neighbors: usize,
27 },
28}
29
30#[allow(dead_code)]
57pub fn random_oversample(
58 data: &Array2<f64>,
59 targets: &Array1<f64>,
60 random_seed: Option<u64>,
61) -> Result<(Array2<f64>, Array1<f64>)> {
62 if data.nrows() != targets.len() {
63 return Err(DatasetsError::InvalidFormat(
64 "Data rows and targets length must match".to_string(),
65 ));
66 }
67
68 if data.is_empty() || targets.is_empty() {
69 return Err(DatasetsError::InvalidFormat(
70 "Data and targets cannot be empty".to_string(),
71 ));
72 }
73
74 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
76 for (i, &target) in targets.iter().enumerate() {
77 let class = target.round() as i64;
78 class_indices.entry(class).or_default().push(i);
79 }
80
81 let max_class_size = class_indices
83 .values()
84 .map(|v| v.len())
85 .max()
86 .expect("Operation failed");
87
88 let mut rng = match random_seed {
89 Some(_seed) => StdRng::seed_from_u64(_seed),
90 None => {
91 let mut r = thread_rng();
92 StdRng::seed_from_u64(r.next_u64())
93 }
94 };
95
96 let mut resampled_indices = Vec::new();
98
99 for (_, indices) in class_indices {
100 let class_size = indices.len();
101
102 resampled_indices.extend(&indices);
104
105 if class_size < max_class_size {
107 let samples_needed = max_class_size - class_size;
108 for _ in 0..samples_needed {
109 let random_idx = rng.sample(Uniform::new(0, class_size).expect("Operation failed"));
110 resampled_indices.push(indices[random_idx]);
111 }
112 }
113 }
114
115 let resampled_data = data.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
117 let resampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &resampled_indices);
118
119 Ok((resampled_data, resampled_targets))
120}
121
122#[allow(dead_code)]
149pub fn random_undersample(
150 data: &Array2<f64>,
151 targets: &Array1<f64>,
152 random_seed: Option<u64>,
153) -> Result<(Array2<f64>, Array1<f64>)> {
154 if data.nrows() != targets.len() {
155 return Err(DatasetsError::InvalidFormat(
156 "Data rows and targets length must match".to_string(),
157 ));
158 }
159
160 if data.is_empty() || targets.is_empty() {
161 return Err(DatasetsError::InvalidFormat(
162 "Data and targets cannot be empty".to_string(),
163 ));
164 }
165
166 let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
168 for (i, &target) in targets.iter().enumerate() {
169 let class = target.round() as i64;
170 class_indices.entry(class).or_default().push(i);
171 }
172
173 let min_class_size = class_indices
175 .values()
176 .map(|v| v.len())
177 .min()
178 .expect("Operation failed");
179
180 let mut rng = match random_seed {
181 Some(_seed) => StdRng::seed_from_u64(_seed),
182 None => {
183 let mut r = thread_rng();
184 StdRng::seed_from_u64(r.next_u64())
185 }
186 };
187
188 let mut undersampled_indices = Vec::new();
190
191 for (_, mut indices) in class_indices {
192 if indices.len() > min_class_size {
193 indices.shuffle(&mut rng);
195 undersampled_indices.extend(&indices[0..min_class_size]);
196 } else {
197 undersampled_indices.extend(&indices);
199 }
200 }
201
202 let undersampled_data = data.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
204 let undersampled_targets = targets.select(scirs2_core::ndarray::Axis(0), &undersampled_indices);
205
206 Ok((undersampled_data, undersampled_targets))
207}
208
209#[allow(dead_code)]
240pub fn generate_synthetic_samples(
241 data: &Array2<f64>,
242 targets: &Array1<f64>,
243 target_class: f64,
244 n_synthetic: usize,
245 k_neighbors: usize,
246 random_seed: Option<u64>,
247) -> Result<(Array2<f64>, Array1<f64>)> {
248 if data.nrows() != targets.len() {
249 return Err(DatasetsError::InvalidFormat(
250 "Data rows and targets length must match".to_string(),
251 ));
252 }
253
254 if n_synthetic == 0 {
255 return Err(DatasetsError::InvalidFormat(
256 "Number of _synthetic samples must be > 0".to_string(),
257 ));
258 }
259
260 if k_neighbors == 0 {
261 return Err(DatasetsError::InvalidFormat(
262 "Number of _neighbors must be > 0".to_string(),
263 ));
264 }
265
266 let class_indices: Vec<usize> = targets
268 .iter()
269 .enumerate()
270 .filter(|(_, &target)| (target - target_class).abs() < 1e-10)
271 .map(|(i, _)| i)
272 .collect();
273
274 if class_indices.len() < 2 {
275 return Err(DatasetsError::InvalidFormat(
276 "Need at least 2 samples of the target _class for _synthetic generation".to_string(),
277 ));
278 }
279
280 if k_neighbors >= class_indices.len() {
281 return Err(DatasetsError::InvalidFormat(
282 "k_neighbors must be less than the number of samples in the target _class".to_string(),
283 ));
284 }
285
286 let mut rng = match random_seed {
287 Some(_seed) => StdRng::seed_from_u64(_seed),
288 None => {
289 let mut r = thread_rng();
290 StdRng::seed_from_u64(r.next_u64())
291 }
292 };
293
294 let n_features = data.ncols();
295 let mut synthetic_data = Array2::zeros((n_synthetic, n_features));
296 let synthetic_targets = Array1::from_elem(n_synthetic, target_class);
297
298 for i in 0..n_synthetic {
299 let base_idx = class_indices
301 [rng.sample(Uniform::new(0, class_indices.len()).expect("Operation failed"))];
302 let base_sample = data.row(base_idx);
303
304 let mut distances: Vec<(usize, f64)> = class_indices
306 .iter()
307 .filter(|&&idx| idx != base_idx)
308 .map(|&idx| {
309 let neighbor = data.row(idx);
310 let distance: f64 = base_sample
311 .iter()
312 .zip(neighbor.iter())
313 .map(|(&a, &b)| (a - b).powi(2))
314 .sum::<f64>()
315 .sqrt();
316 (idx, distance)
317 })
318 .collect();
319
320 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
321 let k_nearest = &distances[0..k_neighbors.min(distances.len())];
322
323 let neighbor_idx =
325 k_nearest[rng.sample(Uniform::new(0, k_nearest.len()).expect("Operation failed"))].0;
326 let neighbor_sample = data.row(neighbor_idx);
327
328 let alpha = rng.random_range(0.0..1.0);
330 for (j, synthetic_feature) in synthetic_data.row_mut(i).iter_mut().enumerate() {
331 *synthetic_feature = base_sample[j] + alpha * (neighbor_sample[j] - base_sample[j]);
332 }
333 }
334
335 Ok((synthetic_data, synthetic_targets))
336}
337
338#[allow(dead_code)]
365pub fn create_balanced_dataset(
366 data: &Array2<f64>,
367 targets: &Array1<f64>,
368 strategy: BalancingStrategy,
369 random_seed: Option<u64>,
370) -> Result<(Array2<f64>, Array1<f64>)> {
371 match strategy {
372 BalancingStrategy::RandomOversample => random_oversample(data, targets, random_seed),
373 BalancingStrategy::RandomUndersample => random_undersample(data, targets, random_seed),
374 BalancingStrategy::SMOTE { k_neighbors } => {
375 let mut class_counts: HashMap<i64, usize> = HashMap::new();
377 for &target in targets.iter() {
378 let class = target.round() as i64;
379 *class_counts.entry(class).or_default() += 1;
380 }
381
382 let max_count = *class_counts.values().max().expect("Operation failed");
383 let mut combined_data = data.clone();
384 let mut combined_targets = targets.clone();
385
386 for (&class, &count) in &class_counts {
387 if count < max_count {
388 let samples_needed = max_count - count;
389 let (synthetic_data, synthetic_targets) = generate_synthetic_samples(
390 data,
391 targets,
392 class as f64,
393 samples_needed,
394 k_neighbors,
395 random_seed,
396 )?;
397
398 combined_data = scirs2_core::ndarray::concatenate(
400 scirs2_core::ndarray::Axis(0),
401 &[combined_data.view(), synthetic_data.view()],
402 )
403 .map_err(|_| {
404 DatasetsError::InvalidFormat("Failed to concatenate data".to_string())
405 })?;
406
407 combined_targets = scirs2_core::ndarray::concatenate(
408 scirs2_core::ndarray::Axis(0),
409 &[combined_targets.view(), synthetic_targets.view()],
410 )
411 .map_err(|_| {
412 DatasetsError::InvalidFormat("Failed to concatenate targets".to_string())
413 })?;
414 }
415 }
416
417 Ok((combined_data, combined_targets))
418 }
419 }
420}
421
/// Unit tests for the class-balancing utilities.
#[cfg(test)]
mod tests {
    use super::*;

    /// Oversampling grows the minority class to the size of the majority class.
    #[test]
    fn test_random_oversample() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("shape (6, 2) matches 12 values");
        // Imbalanced: 2 samples of class 0, 4 samples of class 1.
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = random_oversample(&data, &targets, Some(42))
            .expect("oversampling valid input should succeed");

        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, 4);
        assert_eq!(class_1_count, 4);

        assert_eq!(balanced_data.nrows(), 8);
        assert_eq!(balanced_targets.len(), 8);

        // The number of features is unchanged.
        assert_eq!(balanced_data.ncols(), 2);
    }

    /// Oversampling rejects mismatched and empty inputs.
    #[test]
    fn test_random_oversample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape (3, 2) matches 6 values");
        let targets = Array1::from(vec![0.0, 1.0]);

        // 3 data rows but only 2 targets.
        assert!(random_oversample(&data, &targets, None).is_err());

        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_oversample(&empty_data, &empty_targets, None).is_err());
    }

    /// Undersampling shrinks the majority class to the size of the minority class.
    #[test]
    fn test_random_undersample() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("shape (6, 2) matches 12 values");
        // Imbalanced: 2 samples of class 0, 4 samples of class 1.
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = random_undersample(&data, &targets, Some(42))
            .expect("undersampling valid input should succeed");

        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, 2);
        assert_eq!(class_1_count, 2);

        assert_eq!(balanced_data.nrows(), 4);
        assert_eq!(balanced_targets.len(), 4);

        // The number of features is unchanged.
        assert_eq!(balanced_data.ncols(), 2);
    }

    /// Undersampling rejects mismatched and empty inputs.
    #[test]
    fn test_random_undersample_invalid_params() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape (3, 2) matches 6 values");
        let targets = Array1::from(vec![0.0, 1.0]);

        // 3 data rows but only 2 targets.
        assert!(random_undersample(&data, &targets, None).is_err());

        let empty_data = Array2::zeros((0, 2));
        let empty_targets = Array1::from(vec![]);
        assert!(random_undersample(&empty_data, &empty_targets, None).is_err());
    }

    /// Synthetic samples have the requested count, label and feature width.
    #[test]
    fn test_generate_synthetic_samples() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5])
            .expect("shape (4, 2) matches 8 values");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        let (synthetic_data, synthetic_targets) =
            generate_synthetic_samples(&data, &targets, 0.0, 2, 2, Some(42))
                .expect("synthetic generation on valid input should succeed");

        assert_eq!(synthetic_data.nrows(), 2);
        assert_eq!(synthetic_targets.len(), 2);

        // Every synthetic sample carries the requested class label.
        for &target in synthetic_targets.iter() {
            assert_eq!(target, 0.0);
        }

        assert_eq!(synthetic_data.ncols(), 2);

        // Interpolated values stay within a generous range around the class samples.
        for i in 0..synthetic_data.nrows() {
            for j in 0..synthetic_data.ncols() {
                let value = synthetic_data[[i, j]];
                assert!((0.5..=2.5).contains(&value));
            }
        }
    }

    /// Synthetic generation rejects every kind of invalid parameter.
    #[test]
    fn test_generate_synthetic_samples_invalid_params() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 1.5, 1.5, 2.5, 2.5])
            .expect("shape (4, 2) matches 8 values");
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0]);

        // Mismatched number of targets.
        let bad_targets = Array1::from(vec![0.0, 1.0]);
        assert!(generate_synthetic_samples(&data, &bad_targets, 0.0, 2, 2, None).is_err());

        // Zero synthetic samples requested.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 0, 2, None).is_err());

        // Zero neighbors requested.
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 0, None).is_err());

        // Class 1 has only a single sample.
        assert!(generate_synthetic_samples(&data, &targets, 1.0, 2, 2, None).is_err());

        // k_neighbors is not smaller than the class size (3).
        assert!(generate_synthetic_samples(&data, &targets, 0.0, 2, 3, None).is_err());
    }

    /// The dispatcher balances classes with the random-oversampling strategy.
    #[test]
    fn test_create_balanced_dataset_random_oversample() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("shape (6, 2) matches 12 values");
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .expect("balancing by random oversampling should succeed");

        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, class_1_count);
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    /// The dispatcher balances classes with the random-undersampling strategy.
    #[test]
    fn test_create_balanced_dataset_random_undersample() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("shape (6, 2) matches 12 values");
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .expect("balancing by random undersampling should succeed");

        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, class_1_count);
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    /// SMOTE leaves an already balanced dataset balanced.
    #[test]
    fn test_create_balanced_dataset_smote() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .expect("shape (8, 2) matches 16 values");
        // Already balanced: 4 samples per class.
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .expect("balancing by SMOTE should succeed");

        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, class_1_count);
        assert_eq!(balanced_data.nrows(), balanced_targets.len());
    }

    /// SMOTE appends synthetic minority samples to an imbalanced dataset.
    #[test]
    fn test_create_balanced_dataset_smote_imbalanced() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.5, 1.5, 2.0, 2.0, 2.5, 2.5, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0,
            ],
        )
        .expect("shape (8, 2) matches 16 values");
        // Imbalanced: 3 samples of class 0, 5 samples of class 1.
        let targets = Array1::from(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]);

        let (balanced_data, balanced_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::SMOTE { k_neighbors: 2 },
            Some(42),
        )
        .expect("balancing by SMOTE should succeed");

        let class_0_count = balanced_targets.iter().filter(|&&x| x == 0.0).count();
        let class_1_count = balanced_targets.iter().filter(|&&x| x == 1.0).count();
        assert_eq!(class_0_count, 5);
        assert_eq!(class_1_count, 5);

        // The 8 original rows are kept and 2 synthetic rows are appended.
        assert_eq!(balanced_data.nrows(), 10);
        assert_eq!(balanced_targets.len(), 10);
        assert_eq!(balanced_data.ncols(), 2);

        // Synthetic rows are interpolated between class-0 samples, whose
        // features all lie in [1.0, 2.0].
        for i in 8..balanced_data.nrows() {
            assert_eq!(balanced_targets[i], 0.0);
            for j in 0..balanced_data.ncols() {
                let value = balanced_data[[i, j]];
                assert!((1.0..=2.0).contains(&value));
            }
        }
    }

    /// Random over- and undersampling balance datasets with more than two classes.
    #[test]
    fn test_balancing_strategy_with_multiple_classes() {
        let data = Array2::from_shape_vec((9, 2), (0..18).map(|x| x as f64).collect())
            .expect("shape (9, 2) matches 18 values");
        // Class sizes: 2, 4 and 3.
        let targets = Array1::from(vec![0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

        let (_over_data, over_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomOversample,
            Some(42),
        )
        .expect("balancing by random oversampling should succeed");

        let over_class_0_count = over_targets.iter().filter(|&&x| x == 0.0).count();
        let over_class_1_count = over_targets.iter().filter(|&&x| x == 1.0).count();
        let over_class_2_count = over_targets.iter().filter(|&&x| x == 2.0).count();

        // Every class is grown to the size of the largest class (4).
        assert_eq!(over_class_0_count, 4);
        assert_eq!(over_class_1_count, 4);
        assert_eq!(over_class_2_count, 4);

        let (_under_data, under_targets) = create_balanced_dataset(
            &data,
            &targets,
            BalancingStrategy::RandomUndersample,
            Some(42),
        )
        .expect("balancing by random undersampling should succeed");

        let under_class_0_count = under_targets.iter().filter(|&&x| x == 0.0).count();
        let under_class_1_count = under_targets.iter().filter(|&&x| x == 1.0).count();
        let under_class_2_count = under_targets.iter().filter(|&&x| x == 2.0).count();

        // Every class is shrunk to the size of the smallest class (2).
        assert_eq!(under_class_0_count, 2);
        assert_eq!(under_class_1_count, 2);
        assert_eq!(under_class_2_count, 2);
    }
}