1use crate::random::get_rng;
2use crate::{UtilsError, UtilsResult};
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::Rng;
5use scirs2_core::random::StandardNormal;
6
7#[allow(clippy::too_many_arguments)]
8pub fn make_classification(
9 n_samples: usize,
10 n_features: usize,
11 n_classes: usize,
12 n_informative: Option<usize>,
13 n_redundant: Option<usize>,
14 flip_y: f64,
15 class_sep: f64,
16 random_state: Option<u64>,
17) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
18 if n_samples == 0 {
19 return Err(UtilsError::EmptyInput);
20 }
21
22 if n_classes < 2 {
23 return Err(UtilsError::InvalidParameter(
24 "n_classes must be >= 2".to_string(),
25 ));
26 }
27
28 let n_informative = n_informative.unwrap_or(n_features);
29 let n_redundant = n_redundant.unwrap_or(0);
30
31 if n_informative + n_redundant > n_features {
32 return Err(UtilsError::InvalidParameter(
33 "n_informative + n_redundant must be <= n_features".to_string(),
34 ));
35 }
36
37 let mut rng = get_rng(random_state);
38
39 let mut x = Array2::<f64>::zeros((n_samples, n_features));
41 let mut y = Array1::<i32>::zeros(n_samples);
42
43 for i in 0..n_samples {
45 y[i] = (i % n_classes) as i32;
46 }
47
48 for i in (1..n_samples).rev() {
50 let j = rng.gen_range(0..=i);
51 y.swap(i, j);
52 }
53
54 let mut centroids = Array2::<f64>::zeros((n_classes, n_informative));
56 for i in 0..n_classes {
57 for j in 0..n_informative {
58 centroids[[i, j]] = rng.sample::<f64, _>(StandardNormal) * class_sep;
59 }
60 }
61
62 for i in 0..n_samples {
64 let class_idx = y[i] as usize;
65 for j in 0..n_informative {
66 x[[i, j]] = centroids[[class_idx, j]] + rng.sample::<f64, _>(StandardNormal);
67 }
68 }
69
70 for j in 0..n_redundant {
72 let feat_idx = n_informative + j;
73 let base_feat = j % n_informative;
74 let coeff = rng.gen_range(-1.0..1.0);
75
76 for i in 0..n_samples {
77 x[[i, feat_idx]] =
78 x[[i, base_feat]] * coeff + rng.sample::<f64, _>(StandardNormal) * 0.1;
79 }
80 }
81
82 for j in (n_informative + n_redundant)..n_features {
84 for i in 0..n_samples {
85 x[[i, j]] = rng.sample::<f64, _>(StandardNormal);
86 }
87 }
88
89 if flip_y > 0.0 {
91 let n_flip = (n_samples as f64 * flip_y) as usize;
92 for _ in 0..n_flip {
93 let idx = rng.gen_range(0..n_samples);
94 y[idx] = rng.gen_range(0..n_classes as i32);
95 }
96 }
97
98 Ok((x, y))
99}
100
101pub fn make_regression(
102 n_samples: usize,
103 n_features: usize,
104 n_informative: Option<usize>,
105 noise: f64,
106 bias: f64,
107 random_state: Option<u64>,
108) -> UtilsResult<(Array2<f64>, Array1<f64>)> {
109 if n_samples == 0 {
110 return Err(UtilsError::EmptyInput);
111 }
112
113 let n_informative = n_informative.unwrap_or(n_features);
114
115 if n_informative > n_features {
116 return Err(UtilsError::InvalidParameter(
117 "n_informative must be <= n_features".to_string(),
118 ));
119 }
120
121 let mut rng = get_rng(random_state);
122
123 let mut x = Array2::<f64>::zeros((n_samples, n_features));
125 for i in 0..n_samples {
126 for j in 0..n_features {
127 x[[i, j]] = rng.sample::<f64, _>(StandardNormal);
128 }
129 }
130
131 let mut coef = Array1::<f64>::zeros(n_features);
133 for i in 0..n_informative {
134 coef[i] = rng.sample::<f64, _>(StandardNormal) * 100.0;
135 }
136
137 let mut y = Array1::<f64>::zeros(n_samples);
139 for i in 0..n_samples {
140 let mut target = bias;
141 for j in 0..n_features {
142 target += x[[i, j]] * coef[j];
143 }
144
145 if noise > 0.0 {
146 target += rng.sample::<f64, _>(StandardNormal) * noise;
147 }
148
149 y[i] = target;
150 }
151
152 Ok((x, y))
153}
154
155pub fn make_blobs(
156 n_samples: usize,
157 n_features: usize,
158 centers: Option<usize>,
159 cluster_std: f64,
160 center_box: (f64, f64),
161 random_state: Option<u64>,
162) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
163 if n_samples == 0 {
164 return Err(UtilsError::EmptyInput);
165 }
166
167 let n_centers = centers.unwrap_or(3);
168 let mut rng = get_rng(random_state);
169
170 let mut cluster_centers = Array2::<f64>::zeros((n_centers, n_features));
172 for i in 0..n_centers {
173 for j in 0..n_features {
174 cluster_centers[[i, j]] = rng.gen_range(center_box.0..center_box.1);
175 }
176 }
177
178 let mut x = Array2::<f64>::zeros((n_samples, n_features));
180 let mut y = Array1::<i32>::zeros(n_samples);
181
182 let samples_per_center = n_samples / n_centers;
183 let remainder = n_samples % n_centers;
184
185 let mut sample_idx = 0;
186 for center_idx in 0..n_centers {
187 let n_samples_this_center = samples_per_center + if center_idx < remainder { 1 } else { 0 };
188
189 for _ in 0..n_samples_this_center {
190 y[sample_idx] = center_idx as i32;
191
192 for j in 0..n_features {
193 let center_val = cluster_centers[[center_idx, j]];
194 x[[sample_idx, j]] =
195 center_val + rng.sample::<f64, _>(StandardNormal) * cluster_std;
196 }
197
198 sample_idx += 1;
199 }
200 }
201
202 for i in (1..n_samples).rev() {
204 let j = rng.gen_range(0..=i);
205
206 y.swap(i, j);
208
209 for k in 0..n_features {
211 let temp = x[[i, k]];
212 x[[i, k]] = x[[j, k]];
213 x[[j, k]] = temp;
214 }
215 }
216
217 Ok((x, y))
218}
219
220pub fn make_circles(
221 n_samples: usize,
222 noise: f64,
223 factor: f64,
224 random_state: Option<u64>,
225) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
226 if n_samples == 0 {
227 return Err(UtilsError::EmptyInput);
228 }
229
230 if factor <= 0.0 || factor >= 1.0 {
231 return Err(UtilsError::InvalidParameter(
232 "factor must be in (0, 1)".to_string(),
233 ));
234 }
235
236 let mut rng = get_rng(random_state);
237 let mut x = Array2::<f64>::zeros((n_samples, 2));
238 let mut y = Array1::<i32>::zeros(n_samples);
239
240 let n_outer = n_samples / 2;
241 let n_inner = n_samples - n_outer;
242
243 for i in 0..n_outer {
245 let angle = 2.0 * std::f64::consts::PI * rng.gen::<f64>();
246 x[[i, 0]] = angle.cos() + rng.sample::<f64, _>(StandardNormal) * noise;
247 x[[i, 1]] = angle.sin() + rng.sample::<f64, _>(StandardNormal) * noise;
248 y[i] = 0;
249 }
250
251 for i in 0..n_inner {
253 let idx = n_outer + i;
254 let angle = 2.0 * std::f64::consts::PI * rng.gen::<f64>();
255 x[[idx, 0]] = factor * angle.cos() + rng.sample::<f64, _>(StandardNormal) * noise;
256 x[[idx, 1]] = factor * angle.sin() + rng.sample::<f64, _>(StandardNormal) * noise;
257 y[idx] = 1;
258 }
259
260 Ok((x, y))
261}
262
263pub fn make_moons(
265 n_samples: usize,
266 noise: f64,
267 random_state: Option<u64>,
268) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
269 if n_samples == 0 {
270 return Err(UtilsError::EmptyInput);
271 }
272
273 let mut rng = get_rng(random_state);
274 let mut x = Array2::<f64>::zeros((n_samples, 2));
275 let mut y = Array1::<i32>::zeros(n_samples);
276
277 let n_samples_per_class = n_samples / 2;
278 let remainder = n_samples % 2;
279
280 for i in 0..n_samples_per_class + remainder {
282 let angle = std::f64::consts::PI * (i as f64) / (n_samples_per_class as f64);
283 x[[i, 0]] = angle.cos() + noise * rng.gen::<f64>() * 2.0 - noise;
284 x[[i, 1]] = angle.sin() + noise * rng.gen::<f64>() * 2.0 - noise;
285 y[i] = 0;
286 }
287
288 for i in 0..n_samples_per_class {
290 let idx = i + n_samples_per_class + remainder;
291 let angle = std::f64::consts::PI * (i as f64) / (n_samples_per_class as f64);
292 x[[idx, 0]] = 1.0 - angle.cos() + noise * rng.gen::<f64>() * 2.0 - noise;
293 x[[idx, 1]] = 1.0 - angle.sin() - 0.5 + noise * rng.gen::<f64>() * 2.0 - noise;
294 y[idx] = 1;
295 }
296
297 Ok((x, y))
298}
299
300pub fn make_sparse_classification(
302 n_samples: usize,
303 n_features: usize,
304 n_classes: usize,
305 n_informative: Option<usize>,
306 sparsity: f64,
307 random_state: Option<u64>,
308) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
309 if n_samples == 0 {
310 return Err(UtilsError::EmptyInput);
311 }
312
313 if !(0.0..=1.0).contains(&sparsity) {
314 return Err(UtilsError::InvalidParameter(
315 "sparsity must be between 0.0 and 1.0".to_string(),
316 ));
317 }
318
319 let (mut x, y) = make_classification(
321 n_samples,
322 n_features,
323 n_classes,
324 n_informative,
325 Some(0),
326 0.0,
327 1.0,
328 random_state,
329 )?;
330
331 let mut rng = get_rng(random_state);
333 let total_elements = n_samples * n_features;
334 let n_zeros = (total_elements as f64 * sparsity) as usize;
335
336 for _ in 0..n_zeros {
337 let row = rng.gen_range(0..n_samples);
338 let col = rng.gen_range(0..n_features);
339 x[[row, col]] = 0.0;
340 }
341
342 Ok((x, y))
343}
344
345pub fn make_multilabel_classification(
347 n_samples: usize,
348 n_features: usize,
349 n_classes: usize,
350 n_labels: usize,
351 random_state: Option<u64>,
352) -> UtilsResult<(Array2<f64>, Array2<i32>)> {
353 if n_samples == 0 {
354 return Err(UtilsError::EmptyInput);
355 }
356
357 if n_labels > n_classes {
358 return Err(UtilsError::InvalidParameter(
359 "n_labels cannot be greater than n_classes".to_string(),
360 ));
361 }
362
363 let mut rng = get_rng(random_state);
364 let mut x = Array2::<f64>::zeros((n_samples, n_features));
365 let mut y = Array2::<i32>::zeros((n_samples, n_classes));
366
367 for i in 0..n_samples {
369 for j in 0..n_features {
370 x[[i, j]] = rng.sample::<f64, _>(StandardNormal);
371 }
372 }
373
374 for i in 0..n_samples {
376 let mut available_labels: Vec<usize> = (0..n_classes).collect();
378 for _ in 0..n_labels {
379 if available_labels.is_empty() {
380 break;
381 }
382 let idx = rng.gen_range(0..available_labels.len());
383 let label = available_labels.remove(idx);
384 y[[i, label]] = 1;
385 }
386 }
387
388 Ok((x, y))
389}
390
391pub fn make_hastie_10_2(
393 n_samples: usize,
394 random_state: Option<u64>,
395) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
396 if n_samples == 0 {
397 return Err(UtilsError::EmptyInput);
398 }
399
400 let mut rng = get_rng(random_state);
401 let n_features = 10;
402 let mut x = Array2::<f64>::zeros((n_samples, n_features));
403 let mut y = Array1::<i32>::zeros(n_samples);
404
405 for i in 0..n_samples {
406 for j in 0..n_features {
408 x[[i, j]] = rng.sample::<f64, _>(StandardNormal);
409 }
410
411 let sum_of_squares: f64 = x.row(i).iter().map(|&val| val * val).sum();
413 y[i] = if sum_of_squares > 9.34 { 1 } else { -1 };
414 }
415
416 Ok((x, y))
417}
418
419pub fn make_gaussian_quantiles(
421 n_samples: usize,
422 n_features: usize,
423 n_classes: usize,
424 mean: f64,
425 cov: f64,
426 random_state: Option<u64>,
427) -> UtilsResult<(Array2<f64>, Array1<i32>)> {
428 if n_samples == 0 {
429 return Err(UtilsError::EmptyInput);
430 }
431
432 if n_classes < 2 {
433 return Err(UtilsError::InvalidParameter(
434 "n_classes must be >= 2".to_string(),
435 ));
436 }
437
438 let mut rng = get_rng(random_state);
439 let mut x = Array2::<f64>::zeros((n_samples, n_features));
440 let mut y = Array1::<i32>::zeros(n_samples);
441
442 for i in 0..n_samples {
444 for j in 0..n_features {
445 x[[i, j]] = mean + cov * rng.sample::<f64, _>(StandardNormal);
446 }
447 }
448
449 let mut norms: Vec<(f64, usize)> = Vec::new();
452 for i in 0..n_samples {
453 let norm = x.row(i).iter().map(|&val| val * val).sum::<f64>().sqrt();
454 norms.push((norm, i));
455 }
456
457 norms.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
459
460 let samples_per_class = n_samples / n_classes;
461 let remainder = n_samples % n_classes;
462
463 let mut current_class = 0;
464 let mut samples_in_current_class = 0;
465 let mut max_samples_for_class =
466 samples_per_class + if current_class < remainder { 1 } else { 0 };
467
468 for (_, original_idx) in norms {
469 y[original_idx] = current_class as i32;
470 samples_in_current_class += 1;
471
472 if samples_in_current_class >= max_samples_for_class && current_class < n_classes - 1 {
473 current_class += 1;
474 samples_in_current_class = 0;
475 max_samples_for_class =
476 samples_per_class + if current_class < remainder { 1 } else { 0 };
477 }
478 }
479
480 Ok((x, y))
481}
482
// Smoke tests for the dataset generators: output shapes and label sets.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Collects the distinct label values produced by a generator.
    fn distinct<'a>(labels: impl IntoIterator<Item = &'a i32>) -> HashSet<i32> {
        labels.into_iter().copied().collect()
    }

    #[test]
    fn test_make_classification() {
        let (features, labels) =
            make_classification(100, 5, 3, None, None, 0.0, 1.0, Some(42)).unwrap();

        assert_eq!(features.shape(), &[100, 5]);
        assert_eq!(labels.len(), 100);

        // At most three distinct classes may appear.
        assert!(distinct(labels.iter()).len() <= 3);
    }

    #[test]
    fn test_make_regression() {
        let (features, targets) = make_regression(50, 3, Some(2), 0.1, 5.0, Some(42)).unwrap();

        assert_eq!(features.shape(), &[50, 3]);
        assert_eq!(targets.len(), 50);
    }

    #[test]
    fn test_make_blobs() {
        let (features, labels) =
            make_blobs(60, 2, Some(3), 1.0, (-10.0, 10.0), Some(42)).unwrap();

        assert_eq!(features.shape(), &[60, 2]);
        assert_eq!(labels.len(), 60);

        assert_eq!(distinct(labels.iter()).len(), 3);
    }

    #[test]
    fn test_make_circles() {
        let (features, labels) = make_circles(100, 0.1, 0.5, Some(42)).unwrap();

        assert_eq!(features.shape(), &[100, 2]);
        assert_eq!(labels.len(), 100);

        let seen = distinct(labels.iter());
        assert_eq!(seen.len(), 2);
        assert!(seen.contains(&0));
        assert!(seen.contains(&1));
    }

    #[test]
    fn test_make_moons() {
        let (features, labels) = make_moons(100, 0.1, Some(42)).unwrap();

        assert_eq!(features.shape(), &[100, 2]);
        assert_eq!(labels.len(), 100);

        let seen = distinct(labels.iter());
        assert_eq!(seen.len(), 2);
        assert!(seen.contains(&0));
        assert!(seen.contains(&1));
    }

    #[test]
    fn test_make_sparse_classification() {
        let (features, labels) =
            make_sparse_classification(50, 10, 2, Some(5), 0.3, Some(42)).unwrap();

        assert_eq!(features.shape(), &[50, 10]);
        assert_eq!(labels.len(), 50);

        // At least one cell must have been zeroed.
        assert!(features.iter().any(|&value| value == 0.0));
    }

    #[test]
    fn test_make_multilabel_classification() {
        let (features, indicators) =
            make_multilabel_classification(30, 5, 4, 2, Some(42)).unwrap();

        assert_eq!(features.shape(), &[30, 5]);
        assert_eq!(indicators.shape(), &[30, 4]);

        // No sample may carry more than the requested two labels.
        for sample in indicators.outer_iter() {
            let active = sample.iter().filter(|&&flag| flag == 1).count();
            assert!(active <= 2);
        }
    }

    #[test]
    fn test_make_hastie_10_2() {
        let (features, labels) = make_hastie_10_2(100, Some(42)).unwrap();

        assert_eq!(features.shape(), &[100, 10]);
        assert_eq!(labels.len(), 100);

        let seen = distinct(labels.iter());
        assert_eq!(seen.len(), 2);
        assert!(seen.contains(&-1));
        assert!(seen.contains(&1));
    }

    #[test]
    fn test_make_gaussian_quantiles() {
        let (features, labels) = make_gaussian_quantiles(60, 3, 3, 0.0, 1.0, Some(42)).unwrap();

        assert_eq!(features.shape(), &[60, 3]);
        assert_eq!(labels.len(), 60);

        let seen = distinct(labels.iter());
        assert_eq!(seen.len(), 3);
        for class in 0..3 {
            assert!(seen.contains(&class));
        }
    }
}