1use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::random::rngs::StdRng;
16use scirs2_core::random::{Rng, RngExt, SeedableRng};
17
18use crate::error::{OptimizeError, OptimizeResult};
19
/// Point-placement strategy used by [`generate_samples`].
///
/// The default is [`SamplingStrategy::LatinHypercube`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SamplingStrategy {
    /// Independent uniform draws in every dimension.
    Random,
    /// Latin hypercube design: each dimension is split into `n` equal strata and every
    /// stratum receives exactly one point (optionally refined by a maximin search).
    #[default]
    LatinHypercube,
    /// Base-2 Sobol low-discrepancy sequence (optionally randomly shifted).
    Sobol,
    /// Halton low-discrepancy sequence using the first `d` primes as bases
    /// (optionally randomly shifted).
    Halton,
}
38
/// Tuning options shared by all sampling strategies.
#[derive(Debug, Clone)]
pub struct SamplingConfig {
    /// Number of random in-column swap attempts used to increase the minimum pairwise
    /// distance of a Latin hypercube design; `0` disables the refinement.
    pub lhs_maximin_iters: usize,
    /// Seed for the random number generator; `None` draws a fresh, non-reproducible seed.
    pub seed: Option<u64>,
    /// When `true`, the Sobol and Halton sequences receive a random shift (mod 1) per
    /// dimension.
    pub scramble: bool,
}

impl Default for SamplingConfig {
    /// 100 maximin iterations, no fixed seed, scrambling enabled.
    fn default() -> Self {
        let lhs_maximin_iters = 100;
        let seed = None;
        let scramble = true;
        Self {
            lhs_maximin_iters,
            seed,
            scramble,
        }
    }
}
59
60pub fn generate_samples(
71 n_samples: usize,
72 bounds: &[(f64, f64)],
73 strategy: SamplingStrategy,
74 config: Option<SamplingConfig>,
75) -> OptimizeResult<Array2<f64>> {
76 let config = config.unwrap_or_default();
77 let n_dims = bounds.len();
78
79 if n_samples == 0 {
80 return Ok(Array2::zeros((0, n_dims)));
81 }
82 if n_dims == 0 {
83 return Err(OptimizeError::InvalidInput(
84 "Bounds must have at least one dimension".to_string(),
85 ));
86 }
87
88 for (i, &(lo, hi)) in bounds.iter().enumerate() {
90 if lo >= hi {
91 return Err(OptimizeError::InvalidInput(format!(
92 "Lower bound must be strictly less than upper bound for dimension {} (got [{}, {}])",
93 i, lo, hi
94 )));
95 }
96 if !lo.is_finite() || !hi.is_finite() {
97 return Err(OptimizeError::InvalidInput(format!(
98 "Bounds must be finite for dimension {} (got [{}, {}])",
99 i, lo, hi
100 )));
101 }
102 }
103
104 match strategy {
105 SamplingStrategy::Random => random_sampling(n_samples, bounds, &config),
106 SamplingStrategy::LatinHypercube => latin_hypercube_sampling(n_samples, bounds, &config),
107 SamplingStrategy::Sobol => sobol_sampling(n_samples, bounds, &config),
108 SamplingStrategy::Halton => halton_sampling(n_samples, bounds, &config),
109 }
110}
111
112fn random_sampling(
117 n_samples: usize,
118 bounds: &[(f64, f64)],
119 config: &SamplingConfig,
120) -> OptimizeResult<Array2<f64>> {
121 let n_dims = bounds.len();
122 let mut rng = make_rng(config.seed);
123 let mut samples = Array2::zeros((n_samples, n_dims));
124
125 for i in 0..n_samples {
126 for (j, &(lo, hi)) in bounds.iter().enumerate() {
127 samples[[i, j]] = lo + rng.random_range(0.0..1.0) * (hi - lo);
128 }
129 }
130
131 Ok(samples)
132}
133
134fn latin_hypercube_sampling(
145 n_samples: usize,
146 bounds: &[(f64, f64)],
147 config: &SamplingConfig,
148) -> OptimizeResult<Array2<f64>> {
149 let n_dims = bounds.len();
150 let mut rng = make_rng(config.seed);
151
152 let mut unit_samples = Array2::zeros((n_samples, n_dims));
155
156 for j in 0..n_dims {
157 let mut perm: Vec<usize> = (0..n_samples).collect();
158 for i in (1..n_samples).rev() {
160 let swap_idx = rng.random_range(0..=i);
161 perm.swap(i, swap_idx);
162 }
163 for i in 0..n_samples {
164 let u: f64 = rng.random_range(0.0..1.0);
166 unit_samples[[i, j]] = (perm[i] as f64 + u) / n_samples as f64;
167 }
168 }
169
170 if config.lhs_maximin_iters > 0 && n_samples > 2 {
172 let mut best_min_dist = compute_min_distance(&unit_samples);
173
174 for _ in 0..config.lhs_maximin_iters {
175 let dim = rng.random_range(0..n_dims);
177 let r1 = rng.random_range(0..n_samples);
179 let mut r2 = rng.random_range(0..n_samples.saturating_sub(1));
180 if r2 >= r1 {
181 r2 += 1;
182 }
183
184 let tmp = unit_samples[[r1, dim]];
186 unit_samples[[r1, dim]] = unit_samples[[r2, dim]];
187 unit_samples[[r2, dim]] = tmp;
188
189 let new_min_dist = compute_min_distance(&unit_samples);
190 if new_min_dist > best_min_dist {
191 best_min_dist = new_min_dist;
192 } else {
193 let tmp = unit_samples[[r1, dim]];
195 unit_samples[[r1, dim]] = unit_samples[[r2, dim]];
196 unit_samples[[r2, dim]] = tmp;
197 }
198 }
199 }
200
201 let mut result = Array2::zeros((n_samples, n_dims));
203 for i in 0..n_samples {
204 for (j, &(lo, hi)) in bounds.iter().enumerate() {
205 result[[i, j]] = lo + unit_samples[[i, j]] * (hi - lo);
206 }
207 }
208
209 Ok(result)
210}
211
212fn compute_min_distance(samples: &Array2<f64>) -> f64 {
214 let n = samples.nrows();
215 if n < 2 {
216 return f64::INFINITY;
217 }
218 let mut min_dist = f64::INFINITY;
219 for i in 0..n {
220 for j in (i + 1)..n {
221 let mut sq_dist = 0.0;
222 for k in 0..samples.ncols() {
223 let d = samples[[i, k]] - samples[[j, k]];
224 sq_dist += d * d;
225 }
226 if sq_dist < min_dist {
227 min_dist = sq_dist;
228 }
229 }
230 }
231 min_dist.sqrt()
232}
233
234fn sobol_sampling(
244 n_samples: usize,
245 bounds: &[(f64, f64)],
246 config: &SamplingConfig,
247) -> OptimizeResult<Array2<f64>> {
248 let n_dims = bounds.len();
249 let mut samples = Array2::zeros((n_samples, n_dims));
250
251 let direction_numbers = get_sobol_direction_numbers(n_dims)?;
255
256 for j in 0..n_dims {
257 let dirs = &direction_numbers[j];
258 let mut x: u64 = 0;
259 for i in 0..n_samples {
260 if j == 0 {
261 x = gray_code_sobol(i as u64 + 1);
263 } else {
264 if i == 0 {
266 x = 0;
267 } else {
268 let c = rightmost_zero_bit(i as u64);
270 let dir_idx = c.min(dirs.len() - 1);
271 x ^= dirs[dir_idx];
272 }
273 }
274
275 let value = x as f64 / (1u64 << 32) as f64;
276
277 let scrambled = if config.scramble {
279 let mut rng = make_rng(config.seed.map(|s| s.wrapping_add(j as u64 * 1000 + 7)));
280 let shift: f64 = rng.random_range(0.0..1.0);
281 (value + shift) % 1.0
282 } else {
283 value
284 };
285
286 let (lo, hi) = bounds[j];
287 samples[[i, j]] = lo + scrambled * (hi - lo);
288 }
289 }
290
291 Ok(samples)
292}
293
/// Reverses the low 32 bits of `n`: bit `k` of `n` (for `k < 32`) becomes bit `31 - k` of
/// the result, and higher bits of `n` are ignored.
///
/// Read as a 32-bit binary fraction, this is the base-2 radical inverse of
/// `n mod 2^32`. Despite the name, no Gray-code conversion happens here.
fn gray_code_sobol(n: u64) -> u64 {
    // Truncating to the low 32 bits is intentional: only those bits can land inside
    // the 32-bit output window.
    u64::from((n as u32).reverse_bits())
}
310
/// Position of the least-significant zero bit of `n`, i.e. the number of trailing one
/// bits (`64` for `u64::MAX`).
fn rightmost_zero_bit(n: u64) -> usize {
    n.trailing_ones() as usize
}
321
322fn get_sobol_direction_numbers(n_dims: usize) -> OptimizeResult<Vec<Vec<u64>>> {
329 let max_bits = 32usize;
331
332 let mut all_dirs = Vec::with_capacity(n_dims);
333
334 all_dirs.push(vec![0u64; max_bits]);
336
337 if n_dims <= 1 {
338 return Ok(all_dirs);
339 }
340
341 let primitive_polys: &[(u32, u32)] = &[
347 (1, 0), (2, 1), (3, 1), (3, 2), (4, 1), (4, 4), (5, 2), (5, 4), (5, 7), (5, 11), (5, 13), (5, 14), (6, 1), (6, 13), (6, 16), (6, 19), (6, 22), (6, 25), (7, 1), (7, 4), ];
368
369 let initial_m: &[&[u64]] = &[
373 &[1], &[1, 1], &[1, 1, 1], &[1, 3, 1], &[1, 1, 1, 1], &[1, 1, 3, 1], &[1, 3, 5, 1, 3], &[1, 3, 3, 1, 1], &[1, 3, 7, 7, 5], &[1, 1, 5, 1, 15], &[1, 3, 1, 3, 5], &[1, 3, 7, 7, 5], &[1, 1, 1, 1, 1, 1], &[1, 1, 5, 3, 13, 7], &[1, 3, 3, 1, 1, 1], &[1, 1, 1, 5, 7, 11], &[1, 1, 7, 3, 29, 3], &[1, 3, 7, 7, 21, 25], &[1, 1, 1, 1, 1, 1, 1], &[1, 3, 1, 1, 1, 7, 1], ];
394
395 for dim_idx in 1..n_dims {
396 let poly_idx = if dim_idx - 1 < primitive_polys.len() {
397 dim_idx - 1
398 } else {
399 (dim_idx - 1) % primitive_polys.len()
400 };
401
402 let (degree, poly_bits) = primitive_polys[poly_idx];
403 let s = degree as usize;
404
405 let mut dirs = vec![0u64; max_bits];
406
407 let init = if dim_idx - 1 < initial_m.len() {
409 initial_m[dim_idx - 1]
410 } else {
411 &[1u64; 1][..] };
414
415 for k in 0..s.min(max_bits) {
416 let m_k = if k < init.len() { init[k] } else { 1 };
417 dirs[k] = m_k << (max_bits - k - 1);
419 }
420
421 for k in s..max_bits {
425 let mut new_v = dirs[k - s] ^ (dirs[k - s] >> s);
426 for j in 1..s {
427 if (poly_bits >> (s - 1 - j)) & 1 == 1 {
428 new_v ^= dirs[k - j];
429 }
430 }
431 dirs[k] = new_v;
432 }
433
434 all_dirs.push(dirs);
435 }
436
437 Ok(all_dirs)
438}
439
440fn halton_sampling(
449 n_samples: usize,
450 bounds: &[(f64, f64)],
451 config: &SamplingConfig,
452) -> OptimizeResult<Array2<f64>> {
453 let n_dims = bounds.len();
454 let primes = first_n_primes(n_dims);
455 let mut samples = Array2::zeros((n_samples, n_dims));
456
457 let shifts: Vec<f64> = if config.scramble {
459 let mut rng = make_rng(config.seed);
460 (0..n_dims).map(|_| rng.random_range(0.0..1.0)).collect()
461 } else {
462 vec![0.0; n_dims]
463 };
464
465 for i in 0..n_samples {
466 for j in 0..n_dims {
467 let raw = radical_inverse(i as u64 + 1, primes[j]);
468 let value = if config.scramble {
469 (raw + shifts[j]) % 1.0
470 } else {
471 raw
472 };
473 let (lo, hi) = bounds[j];
474 samples[[i, j]] = lo + value * (hi - lo);
475 }
476 }
477
478 Ok(samples)
479}
480
/// Van der Corput radical inverse of `n` in the given `base`.
///
/// The digits of `n` are mirrored around the radix point, e.g. `6 = 110₂ -> 0.011₂`.
/// Returns `0.0` for `n == 0`.
fn radical_inverse(n: u64, base: u64) -> f64 {
    let radix = base as f64;
    let mut remaining = n;
    let mut place = 1.0;
    let mut total = 0.0;

    while remaining != 0 {
        let digit = remaining % base;
        remaining /= base;
        place *= radix;
        total += digit as f64 / place;
    }

    total
}
497
/// The first `n` prime numbers in increasing order (empty for `n == 0`).
///
/// Each candidate is trial-divided by the primes found so far, up to its square root.
fn first_n_primes(n: usize) -> Vec<u64> {
    let mut found: Vec<u64> = Vec::with_capacity(n);
    let mut candidate = 2u64;

    while found.len() < n {
        let composite = found
            .iter()
            .take_while(|&&p| p * p <= candidate)
            .any(|&p| candidate % p == 0);
        if !composite {
            found.push(candidate);
        }
        candidate += 1;
    }

    found
}
518
519fn make_rng(seed: Option<u64>) -> StdRng {
524 match seed {
525 Some(s) => StdRng::seed_from_u64(s),
526 None => {
527 let s: u64 = scirs2_core::random::rng().random();
528 StdRng::seed_from_u64(s)
529 }
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Every strategy, for tests that must hold across all of them.
    const ALL_STRATEGIES: [SamplingStrategy; 4] = [
        SamplingStrategy::Random,
        SamplingStrategy::LatinHypercube,
        SamplingStrategy::Sobol,
        SamplingStrategy::Halton,
    ];

    fn bounds_2d() -> Vec<(f64, f64)> {
        vec![(-5.0, 5.0), (0.0, 10.0)]
    }

    fn bounds_5d() -> Vec<(f64, f64)> {
        vec![
            (0.0, 1.0),
            (-1.0, 1.0),
            (0.0, 100.0),
            (-10.0, 10.0),
            (5.0, 15.0),
        ]
    }

    /// Asserts that every entry of `samples` lies within its column's closed bounds.
    fn assert_within_bounds(samples: &Array2<f64>, bounds: &[(f64, f64)], label: &str) {
        for i in 0..samples.nrows() {
            for (j, &(lo, hi)) in bounds.iter().enumerate() {
                let v = samples[[i, j]];
                assert!(
                    v >= lo && v <= hi,
                    "{} sample[{},{}] = {} not in [{}, {}]",
                    label,
                    i,
                    j,
                    v,
                    lo,
                    hi
                );
            }
        }
    }

    /// Asserts that, for unit-cube samples, every column has exactly one point in each
    /// of the `n` equal strata.
    fn assert_stratified(samples: &Array2<f64>) {
        let n = samples.nrows();
        for j in 0..samples.ncols() {
            let mut occupied = vec![false; n];
            for i in 0..n {
                let stratum = ((samples[[i, j]] * n as f64).floor() as usize).min(n - 1);
                occupied[stratum] = true;
            }
            for (s, &hit) in occupied.iter().enumerate() {
                assert!(hit, "Stratum {} in dimension {} is unoccupied", s, j);
            }
        }
    }

    #[test]
    fn test_random_sampling_shape() {
        let samples = generate_samples(20, &bounds_2d(), SamplingStrategy::Random, None)
            .expect("should succeed");
        assert_eq!(samples.nrows(), 20);
        assert_eq!(samples.ncols(), 2);
    }

    #[test]
    fn test_random_sampling_within_bounds() {
        let b = bounds_2d();
        let samples =
            generate_samples(100, &b, SamplingStrategy::Random, None).expect("should succeed");
        assert_within_bounds(&samples, &b, "Random");
    }

    #[test]
    fn test_lhs_shape_and_bounds() {
        let b = bounds_5d();
        let samples = generate_samples(30, &b, SamplingStrategy::LatinHypercube, None)
            .expect("should succeed");
        assert_eq!(samples.nrows(), 30);
        assert_eq!(samples.ncols(), 5);
        assert_within_bounds(&samples, &b, "LHS");
    }

    #[test]
    fn test_lhs_stratification() {
        let bounds = vec![(0.0, 1.0); 3];
        let cfg = SamplingConfig {
            lhs_maximin_iters: 0,
            seed: Some(42),
            scramble: false,
        };
        let samples = generate_samples(10, &bounds, SamplingStrategy::LatinHypercube, Some(cfg))
            .expect("should succeed");
        assert_stratified(&samples);
    }

    #[test]
    fn test_lhs_maximin_preserves_stratification() {
        // Maximin only swaps entries within a column, so strata must stay intact.
        let bounds = vec![(0.0, 1.0); 3];
        let cfg = SamplingConfig {
            lhs_maximin_iters: 200,
            seed: Some(5),
            scramble: false,
        };
        let samples = generate_samples(12, &bounds, SamplingStrategy::LatinHypercube, Some(cfg))
            .expect("should succeed");
        assert_stratified(&samples);
    }

    #[test]
    fn test_lhs_maximin_improves_spacing() {
        let n = 15;
        let bounds = vec![(0.0, 1.0); 2];

        let cfg0 = SamplingConfig {
            lhs_maximin_iters: 0,
            seed: Some(123),
            scramble: false,
        };
        let s0 = generate_samples(n, &bounds, SamplingStrategy::LatinHypercube, Some(cfg0))
            .expect("should succeed");

        let cfg1 = SamplingConfig {
            lhs_maximin_iters: 500,
            seed: Some(123),
            scramble: false,
        };
        let s1 = generate_samples(n, &bounds, SamplingStrategy::LatinHypercube, Some(cfg1))
            .expect("should succeed");

        let d0 = compute_min_distance(&s0);
        let d1 = compute_min_distance(&s1);

        assert!(
            d1 >= d0 - 1e-12,
            "Maximin LHS should not decrease min distance: d_opt={} < d_raw={}",
            d1,
            d0
        );
    }

    #[test]
    fn test_sobol_shape_and_bounds() {
        let b = bounds_2d();
        let samples =
            generate_samples(32, &b, SamplingStrategy::Sobol, None).expect("should succeed");
        assert_eq!(samples.nrows(), 32);
        assert_eq!(samples.ncols(), 2);
        assert_within_bounds(&samples, &b, "Sobol");
    }

    #[test]
    fn test_sobol_reproducibility() {
        let b = bounds_2d();
        let cfg = SamplingConfig {
            seed: Some(99),
            scramble: true,
            ..Default::default()
        };
        let s1 = generate_samples(16, &b, SamplingStrategy::Sobol, Some(cfg.clone()))
            .expect("should succeed");
        let s2 =
            generate_samples(16, &b, SamplingStrategy::Sobol, Some(cfg)).expect("should succeed");
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_halton_shape_and_bounds() {
        let b = bounds_5d();
        let samples =
            generate_samples(50, &b, SamplingStrategy::Halton, None).expect("should succeed");
        assert_eq!(samples.nrows(), 50);
        assert_eq!(samples.ncols(), 5);
        assert_within_bounds(&samples, &b, "Halton");
    }

    #[test]
    fn test_halton_low_discrepancy() {
        let bounds = vec![(0.0, 1.0)];
        let cfg = SamplingConfig {
            seed: None,
            scramble: false,
            ..Default::default()
        };
        let samples = generate_samples(4, &bounds, SamplingStrategy::Halton, Some(cfg))
            .expect("should succeed");

        let expected = [0.5, 0.25, 0.75, 0.125];
        for (i, &exp) in expected.iter().enumerate() {
            assert!(
                (samples[[i, 0]] - exp).abs() < 1e-10,
                "Halton[{}] = {}, expected {}",
                i,
                samples[[i, 0]],
                exp
            );
        }
    }

    #[test]
    fn test_seeded_strategies_are_reproducible() {
        for strategy in &ALL_STRATEGIES {
            let cfg = SamplingConfig {
                seed: Some(7),
                ..Default::default()
            };
            let a = generate_samples(12, &bounds_5d(), *strategy, Some(cfg.clone()))
                .expect("should succeed");
            let b = generate_samples(12, &bounds_5d(), *strategy, Some(cfg))
                .expect("should succeed");
            assert_eq!(a, b, "{:?} is not reproducible with a fixed seed", strategy);
        }
    }

    #[test]
    fn test_zero_samples() {
        let samples = generate_samples(0, &bounds_2d(), SamplingStrategy::Random, None)
            .expect("should succeed");
        assert_eq!(samples.nrows(), 0);
    }

    #[test]
    fn test_single_sample() {
        for strategy in &ALL_STRATEGIES {
            let samples =
                generate_samples(1, &bounds_2d(), *strategy, None).expect("should succeed");
            assert_eq!(samples.nrows(), 1);
            assert_eq!(samples.ncols(), 2);
        }
    }

    #[test]
    fn test_invalid_bounds_rejected() {
        // Degenerate, inverted, infinite and NaN bounds are all invalid.
        let invalid: [(f64, f64); 6] = [
            (5.0, 5.0),
            (2.0, 1.0),
            (f64::NEG_INFINITY, 1.0),
            (0.0, f64::INFINITY),
            (f64::NAN, 1.0),
            (0.0, f64::NAN),
        ];
        for &bound in &invalid {
            let result = generate_samples(10, &[bound], SamplingStrategy::Random, None);
            assert!(result.is_err(), "bound {:?} should be rejected", bound);
        }
    }

    #[test]
    fn test_empty_bounds_rejected() {
        for strategy in &ALL_STRATEGIES {
            assert!(generate_samples(5, &[], *strategy, None).is_err());
        }
    }

    #[test]
    fn test_first_n_primes() {
        let p = first_n_primes(10);
        assert_eq!(p, vec![2, 3, 5, 7, 11, 13, 17, 19, 23, 29]);
        assert!(first_n_primes(0).is_empty());
    }

    #[test]
    fn test_radical_inverse() {
        assert!((radical_inverse(1, 2) - 0.5).abs() < 1e-15);
        assert!((radical_inverse(2, 2) - 0.25).abs() < 1e-15);
        assert!((radical_inverse(3, 2) - 0.75).abs() < 1e-15);
        assert!((radical_inverse(1, 3) - 1.0 / 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_rightmost_zero_bit() {
        assert_eq!(rightmost_zero_bit(0), 0);
        assert_eq!(rightmost_zero_bit(1), 1);
        assert_eq!(rightmost_zero_bit(0b1011), 2);
        assert_eq!(rightmost_zero_bit(u64::MAX), 64);
    }

    #[test]
    fn test_gray_code_sobol_bit_reversal() {
        assert_eq!(gray_code_sobol(0), 0);
        assert_eq!(gray_code_sobol(1), 1 << 31);
        assert_eq!(gray_code_sobol(0b110), (1 << 30) | (1 << 29));
    }

    #[test]
    fn test_compute_min_distance() {
        let pts = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 3.0, 4.0, 0.0, 1.0])
            .expect("valid shape");
        assert!((compute_min_distance(&pts) - 1.0).abs() < 1e-12);

        let single = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("valid shape");
        assert!(compute_min_distance(&single).is_infinite());
    }

    #[test]
    fn test_high_dimensional_sampling() {
        let bounds: Vec<(f64, f64)> = (0..15).map(|_| (0.0, 1.0)).collect();
        for strategy in &ALL_STRATEGIES {
            let samples = generate_samples(20, &bounds, *strategy, None).expect("should succeed");
            assert_eq!(samples.nrows(), 20);
            assert_eq!(samples.ncols(), 15);
        }
    }
}