// sklears_utils/data_generation.rs

1use crate::random::get_rng;
2use crate::{UtilsError, UtilsResult};
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::Rng;
5use scirs2_core::random::StandardNormal;
6
7#[allow(clippy::too_many_arguments)]
8pub fn make_classification(
9    n_samples: usize,
10    n_features: usize,
11    n_classes: usize,
12    n_informative: Option<usize>,
13    n_redundant: Option<usize>,
14    flip_y: f64,
15    class_sep: f64,
16    random_state: Option<u64>,
17) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
18    if n_samples == 0 {
19        return Err(UtilsError::EmptyInput);
20    }
21
22    if n_classes < 2 {
23        return Err(UtilsError::InvalidParameter(
24            "n_classes must be >= 2".to_string(),
25        ));
26    }
27
28    let n_informative = n_informative.unwrap_or(n_features);
29    let n_redundant = n_redundant.unwrap_or(0);
30
31    if n_informative + n_redundant > n_features {
32        return Err(UtilsError::InvalidParameter(
33            "n_informative + n_redundant must be <= n_features".to_string(),
34        ));
35    }
36
37    let mut rng = get_rng(random_state);
38
39    // Generate informative features
40    let mut x = Array2::<f64>::zeros((n_samples, n_features));
41    let mut y = Array1::<i32>::zeros(n_samples);
42
43    // Assign classes uniformly
44    for i in 0..n_samples {
45        y[i] = (i % n_classes) as i32;
46    }
47
48    // Shuffle class labels
49    for i in (1..n_samples).rev() {
50        let j = rng.gen_range(0..=i);
51        y.swap(i, j);
52    }
53
54    // Generate class centroids
55    let mut centroids = Array2::<f64>::zeros((n_classes, n_informative));
56    for i in 0..n_classes {
57        for j in 0..n_informative {
58            centroids[[i, j]] = rng.sample::<f64, _>(StandardNormal) * class_sep;
59        }
60    }
61
62    // Generate features around centroids
63    for i in 0..n_samples {
64        let class_idx = y[i] as usize;
65        for j in 0..n_informative {
66            x[[i, j]] = centroids[[class_idx, j]] + rng.sample::<f64, _>(StandardNormal);
67        }
68    }
69
70    // Add redundant features (linear combinations of informative features)
71    for j in 0..n_redundant {
72        let feat_idx = n_informative + j;
73        let base_feat = j % n_informative;
74        let coeff = rng.gen_range(-1.0..1.0);
75
76        for i in 0..n_samples {
77            x[[i, feat_idx]] =
78                x[[i, base_feat]] * coeff + rng.sample::<f64, _>(StandardNormal) * 0.1;
79        }
80    }
81
82    // Add noise to remaining features
83    for j in (n_informative + n_redundant)..n_features {
84        for i in 0..n_samples {
85            x[[i, j]] = rng.sample::<f64, _>(StandardNormal);
86        }
87    }
88
89    // Flip some labels
90    if flip_y > 0.0 {
91        let n_flip = (n_samples as f64 * flip_y) as usize;
92        for _ in 0..n_flip {
93            let idx = rng.gen_range(0..n_samples);
94            y[idx] = rng.gen_range(0..n_classes as i32);
95        }
96    }
97
98    Ok((x, y))
99}
100
101pub fn make_regression(
102    n_samples: usize,
103    n_features: usize,
104    n_informative: Option<usize>,
105    noise: f64,
106    bias: f64,
107    random_state: Option<u64>,
108) -> UtilsResult<(Array2<f64>, Array1<f64>)> {
109    if n_samples == 0 {
110        return Err(UtilsError::EmptyInput);
111    }
112
113    let n_informative = n_informative.unwrap_or(n_features);
114
115    if n_informative > n_features {
116        return Err(UtilsError::InvalidParameter(
117            "n_informative must be <= n_features".to_string(),
118        ));
119    }
120
121    let mut rng = get_rng(random_state);
122
123    // Generate features
124    let mut x = Array2::<f64>::zeros((n_samples, n_features));
125    for i in 0..n_samples {
126        for j in 0..n_features {
127            x[[i, j]] = rng.sample::<f64, _>(StandardNormal);
128        }
129    }
130
131    // Generate true coefficients
132    let mut coef = Array1::<f64>::zeros(n_features);
133    for i in 0..n_informative {
134        coef[i] = rng.sample::<f64, _>(StandardNormal) * 100.0;
135    }
136
137    // Generate target values
138    let mut y = Array1::<f64>::zeros(n_samples);
139    for i in 0..n_samples {
140        let mut target = bias;
141        for j in 0..n_features {
142            target += x[[i, j]] * coef[j];
143        }
144
145        if noise > 0.0 {
146            target += rng.sample::<f64, _>(StandardNormal) * noise;
147        }
148
149        y[i] = target;
150    }
151
152    Ok((x, y))
153}
154
155pub fn make_blobs(
156    n_samples: usize,
157    n_features: usize,
158    centers: Option<usize>,
159    cluster_std: f64,
160    center_box: (f64, f64),
161    random_state: Option<u64>,
162) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
163    if n_samples == 0 {
164        return Err(UtilsError::EmptyInput);
165    }
166
167    let n_centers = centers.unwrap_or(3);
168    let mut rng = get_rng(random_state);
169
170    // Generate cluster centers
171    let mut cluster_centers = Array2::<f64>::zeros((n_centers, n_features));
172    for i in 0..n_centers {
173        for j in 0..n_features {
174            cluster_centers[[i, j]] = rng.gen_range(center_box.0..center_box.1);
175        }
176    }
177
178    // Generate samples
179    let mut x = Array2::<f64>::zeros((n_samples, n_features));
180    let mut y = Array1::<i32>::zeros(n_samples);
181
182    let samples_per_center = n_samples / n_centers;
183    let remainder = n_samples % n_centers;
184
185    let mut sample_idx = 0;
186    for center_idx in 0..n_centers {
187        let n_samples_this_center = samples_per_center + if center_idx < remainder { 1 } else { 0 };
188
189        for _ in 0..n_samples_this_center {
190            y[sample_idx] = center_idx as i32;
191
192            for j in 0..n_features {
193                let center_val = cluster_centers[[center_idx, j]];
194                x[[sample_idx, j]] =
195                    center_val + rng.sample::<f64, _>(StandardNormal) * cluster_std;
196            }
197
198            sample_idx += 1;
199        }
200    }
201
202    // Shuffle the samples
203    for i in (1..n_samples).rev() {
204        let j = rng.gen_range(0..=i);
205
206        // Swap labels
207        y.swap(i, j);
208
209        // Swap rows in X
210        for k in 0..n_features {
211            let temp = x[[i, k]];
212            x[[i, k]] = x[[j, k]];
213            x[[j, k]] = temp;
214        }
215    }
216
217    Ok((x, y))
218}
219
220pub fn make_circles(
221    n_samples: usize,
222    noise: f64,
223    factor: f64,
224    random_state: Option<u64>,
225) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
226    if n_samples == 0 {
227        return Err(UtilsError::EmptyInput);
228    }
229
230    if factor <= 0.0 || factor >= 1.0 {
231        return Err(UtilsError::InvalidParameter(
232            "factor must be in (0, 1)".to_string(),
233        ));
234    }
235
236    let mut rng = get_rng(random_state);
237    let mut x = Array2::<f64>::zeros((n_samples, 2));
238    let mut y = Array1::<i32>::zeros(n_samples);
239
240    let n_outer = n_samples / 2;
241    let n_inner = n_samples - n_outer;
242
243    // Generate outer circle
244    for i in 0..n_outer {
245        let angle = 2.0 * std::f64::consts::PI * rng.gen::<f64>();
246        x[[i, 0]] = angle.cos() + rng.sample::<f64, _>(StandardNormal) * noise;
247        x[[i, 1]] = angle.sin() + rng.sample::<f64, _>(StandardNormal) * noise;
248        y[i] = 0;
249    }
250
251    // Generate inner circle
252    for i in 0..n_inner {
253        let idx = n_outer + i;
254        let angle = 2.0 * std::f64::consts::PI * rng.gen::<f64>();
255        x[[idx, 0]] = factor * angle.cos() + rng.sample::<f64, _>(StandardNormal) * noise;
256        x[[idx, 1]] = factor * angle.sin() + rng.sample::<f64, _>(StandardNormal) * noise;
257        y[idx] = 1;
258    }
259
260    Ok((x, y))
261}
262
263/// Generate a moon-shaped 2D dataset for non-linear classification
264pub fn make_moons(
265    n_samples: usize,
266    noise: f64,
267    random_state: Option<u64>,
268) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
269    if n_samples == 0 {
270        return Err(UtilsError::EmptyInput);
271    }
272
273    let mut rng = get_rng(random_state);
274    let mut x = Array2::<f64>::zeros((n_samples, 2));
275    let mut y = Array1::<i32>::zeros(n_samples);
276
277    let n_samples_per_class = n_samples / 2;
278    let remainder = n_samples % 2;
279
280    // Generate first moon (upper moon)
281    for i in 0..n_samples_per_class + remainder {
282        let angle = std::f64::consts::PI * (i as f64) / (n_samples_per_class as f64);
283        x[[i, 0]] = angle.cos() + noise * rng.gen::<f64>() * 2.0 - noise;
284        x[[i, 1]] = angle.sin() + noise * rng.gen::<f64>() * 2.0 - noise;
285        y[i] = 0;
286    }
287
288    // Generate second moon (lower moon)
289    for i in 0..n_samples_per_class {
290        let idx = i + n_samples_per_class + remainder;
291        let angle = std::f64::consts::PI * (i as f64) / (n_samples_per_class as f64);
292        x[[idx, 0]] = 1.0 - angle.cos() + noise * rng.gen::<f64>() * 2.0 - noise;
293        x[[idx, 1]] = 1.0 - angle.sin() - 0.5 + noise * rng.gen::<f64>() * 2.0 - noise;
294        y[idx] = 1;
295    }
296
297    Ok((x, y))
298}
299
300/// Generate a sparse classification dataset
301pub fn make_sparse_classification(
302    n_samples: usize,
303    n_features: usize,
304    n_classes: usize,
305    n_informative: Option<usize>,
306    sparsity: f64,
307    random_state: Option<u64>,
308) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
309    if n_samples == 0 {
310        return Err(UtilsError::EmptyInput);
311    }
312
313    if !(0.0..=1.0).contains(&sparsity) {
314        return Err(UtilsError::InvalidParameter(
315            "sparsity must be between 0.0 and 1.0".to_string(),
316        ));
317    }
318
319    // Generate a dense classification dataset first
320    let (mut x, y) = make_classification(
321        n_samples,
322        n_features,
323        n_classes,
324        n_informative,
325        Some(0),
326        0.0,
327        1.0,
328        random_state,
329    )?;
330
331    // Make it sparse by setting a fraction of values to zero
332    let mut rng = get_rng(random_state);
333    let total_elements = n_samples * n_features;
334    let n_zeros = (total_elements as f64 * sparsity) as usize;
335
336    for _ in 0..n_zeros {
337        let row = rng.gen_range(0..n_samples);
338        let col = rng.gen_range(0..n_features);
339        x[[row, col]] = 0.0;
340    }
341
342    Ok((x, y))
343}
344
345/// Generate a multilabel classification dataset
346pub fn make_multilabel_classification(
347    n_samples: usize,
348    n_features: usize,
349    n_classes: usize,
350    n_labels: usize,
351    random_state: Option<u64>,
352) -> UtilsResult<(Array2<f64>, Array2<i32>)> {
353    if n_samples == 0 {
354        return Err(UtilsError::EmptyInput);
355    }
356
357    if n_labels > n_classes {
358        return Err(UtilsError::InvalidParameter(
359            "n_labels cannot be greater than n_classes".to_string(),
360        ));
361    }
362
363    let mut rng = get_rng(random_state);
364    let mut x = Array2::<f64>::zeros((n_samples, n_features));
365    let mut y = Array2::<i32>::zeros((n_samples, n_classes));
366
367    // Generate features
368    for i in 0..n_samples {
369        for j in 0..n_features {
370            x[[i, j]] = rng.sample::<f64, _>(StandardNormal);
371        }
372    }
373
374    // Generate multilabel targets
375    for i in 0..n_samples {
376        // Randomly select which labels are active for this sample
377        let mut available_labels: Vec<usize> = (0..n_classes).collect();
378        for _ in 0..n_labels {
379            if available_labels.is_empty() {
380                break;
381            }
382            let idx = rng.gen_range(0..available_labels.len());
383            let label = available_labels.remove(idx);
384            y[[i, label]] = 1;
385        }
386    }
387
388    Ok((x, y))
389}
390
391/// Generate the Hastie 10-2 dataset for binary classification
392pub fn make_hastie_10_2(
393    n_samples: usize,
394    random_state: Option<u64>,
395) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
396    if n_samples == 0 {
397        return Err(UtilsError::EmptyInput);
398    }
399
400    let mut rng = get_rng(random_state);
401    let n_features = 10;
402    let mut x = Array2::<f64>::zeros((n_samples, n_features));
403    let mut y = Array1::<i32>::zeros(n_samples);
404
405    for i in 0..n_samples {
406        // Generate 10 features from standard normal distribution
407        for j in 0..n_features {
408            x[[i, j]] = rng.sample::<f64, _>(StandardNormal);
409        }
410
411        // Hastie's formula: y = 1 if sum(X_j^2) > 9.34, else y = -1
412        let sum_of_squares: f64 = x.row(i).iter().map(|&val| val * val).sum();
413        y[i] = if sum_of_squares > 9.34 { 1 } else { -1 };
414    }
415
416    Ok((x, y))
417}
418
419/// Generate a dataset with Gaussian quantiles for classification
420pub fn make_gaussian_quantiles(
421    n_samples: usize,
422    n_features: usize,
423    n_classes: usize,
424    mean: f64,
425    cov: f64,
426    random_state: Option<u64>,
427) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
428    if n_samples == 0 {
429        return Err(UtilsError::EmptyInput);
430    }
431
432    if n_classes < 2 {
433        return Err(UtilsError::InvalidParameter(
434            "n_classes must be >= 2".to_string(),
435        ));
436    }
437
438    let mut rng = get_rng(random_state);
439    let mut x = Array2::<f64>::zeros((n_samples, n_features));
440    let mut y = Array1::<i32>::zeros(n_samples);
441
442    // Generate features from multivariate normal distribution
443    for i in 0..n_samples {
444        for j in 0..n_features {
445            x[[i, j]] = mean + cov * rng.sample::<f64, _>(StandardNormal);
446        }
447    }
448
449    // Compute quantiles for class assignment
450    // Calculate the L2 norm (distance from origin) for each sample
451    let mut norms: Vec<(f64, usize)> = Vec::new();
452    for i in 0..n_samples {
453        let norm = x.row(i).iter().map(|&val| val * val).sum::<f64>().sqrt();
454        norms.push((norm, i));
455    }
456
457    // Sort by norm and assign classes based on quantiles
458    norms.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
459
460    let samples_per_class = n_samples / n_classes;
461    let remainder = n_samples % n_classes;
462
463    let mut current_class = 0;
464    let mut samples_in_current_class = 0;
465    let mut max_samples_for_class =
466        samples_per_class + if current_class < remainder { 1 } else { 0 };
467
468    for (_, original_idx) in norms {
469        y[original_idx] = current_class as i32;
470        samples_in_current_class += 1;
471
472        if samples_in_current_class >= max_samples_for_class && current_class < n_classes - 1 {
473            current_class += 1;
474            samples_in_current_class = 0;
475            max_samples_for_class =
476                samples_per_class + if current_class < remainder { 1 } else { 0 };
477        }
478    }
479
480    Ok((x, y))
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_make_classification() {
        let (x, y) = make_classification(100, 5, 3, None, None, 0.0, 1.0, Some(42)).unwrap();

        assert_eq!(x.shape(), &[100, 5]);
        assert_eq!(y.len(), 100);

        // Labels are assigned round-robin before shuffling and flip_y is 0,
        // so with 100 samples all 3 classes must be present.
        let unique_classes: std::collections::HashSet<i32> = y.iter().copied().collect();
        assert_eq!(unique_classes.len(), 3);
    }

    #[test]
    fn test_make_regression() {
        let (x, y) = make_regression(50, 3, Some(2), 0.1, 5.0, Some(42)).unwrap();

        assert_eq!(x.shape(), &[50, 3]);
        assert_eq!(y.len(), 50);
    }

    #[test]
    fn test_make_blobs() {
        let (x, y) = make_blobs(60, 2, Some(3), 1.0, (-10.0, 10.0), Some(42)).unwrap();

        assert_eq!(x.shape(), &[60, 2]);
        assert_eq!(y.len(), 60);

        // 60 samples across 3 centers: every cluster label must appear.
        let unique_labels: std::collections::HashSet<i32> = y.iter().copied().collect();
        assert_eq!(unique_labels.len(), 3);
    }

    #[test]
    fn test_make_circles() {
        let (x, y) = make_circles(100, 0.1, 0.5, Some(42)).unwrap();

        assert_eq!(x.shape(), &[100, 2]);
        assert_eq!(y.len(), 100);

        // Both the outer (0) and inner (1) circle classes are present.
        let unique_labels: std::collections::HashSet<i32> = y.iter().copied().collect();
        assert_eq!(unique_labels.len(), 2);
        assert!(unique_labels.contains(&0));
        assert!(unique_labels.contains(&1));
    }

    #[test]
    fn test_make_moons() {
        let (x, y) = make_moons(100, 0.1, Some(42)).unwrap();

        assert_eq!(x.shape(), &[100, 2]);
        assert_eq!(y.len(), 100);

        // Both moons are present.
        let unique_labels: std::collections::HashSet<i32> = y.iter().copied().collect();
        assert_eq!(unique_labels.len(), 2);
        assert!(unique_labels.contains(&0));
        assert!(unique_labels.contains(&1));
    }

    #[test]
    fn test_make_sparse_classification() {
        let (x, y) = make_sparse_classification(50, 10, 2, Some(5), 0.3, Some(42)).unwrap();

        assert_eq!(x.shape(), &[50, 10]);
        assert_eq!(y.len(), 50);

        // With sparsity 0.3 some entries must have been zeroed.
        let zero_count = x.iter().filter(|&&val| val == 0.0).count();
        assert!(zero_count > 0);
    }

    #[test]
    fn test_make_multilabel_classification() {
        let (x, y) = make_multilabel_classification(30, 5, 4, 2, Some(42)).unwrap();

        assert_eq!(x.shape(), &[30, 5]);
        assert_eq!(y.shape(), &[30, 4]);

        // No sample may exceed n_labels active labels.
        for i in 0..30 {
            let active_labels = y.row(i).iter().filter(|&&val| val == 1).count();
            assert!(active_labels <= 2);
        }
    }

    #[test]
    fn test_make_hastie_10_2() {
        let (x, y) = make_hastie_10_2(100, Some(42)).unwrap();

        assert_eq!(x.shape(), &[100, 10]);
        assert_eq!(y.len(), 100);

        // The 9.34 threshold splits 100 samples into both classes in practice.
        let unique_labels: std::collections::HashSet<i32> = y.iter().copied().collect();
        assert_eq!(unique_labels.len(), 2);
        assert!(unique_labels.contains(&-1));
        assert!(unique_labels.contains(&1));
    }

    #[test]
    fn test_make_gaussian_quantiles() {
        let (x, y) = make_gaussian_quantiles(60, 3, 3, 0.0, 1.0, Some(42)).unwrap();

        assert_eq!(x.shape(), &[60, 3]);
        assert_eq!(y.len(), 60);

        // Quantile assignment guarantees every class receives samples.
        let unique_labels: std::collections::HashSet<i32> = y.iter().copied().collect();
        assert_eq!(unique_labels.len(), 3);
        assert!(unique_labels.contains(&0));
        assert!(unique_labels.contains(&1));
        assert!(unique_labels.contains(&2));
    }
}