// scirs2_optimize/bayesian/sampling.rs

1//! Sampling strategies for initial experimental design in Bayesian optimization.
2//!
3//! This module provides various sampling methods for generating initial points
4//! in the search space before the surrogate model takes over. Proper space-filling
5//! designs are critical for Bayesian optimization performance.
6//!
7//! # Available Methods
8//!
9//! - **Latin Hypercube Sampling (LHS)**: Stratified sampling with maximin optimization
10//! - **Sobol sequences**: Quasi-random low-discrepancy sequences
11//! - **Halton sequences**: Multi-dimensional quasi-random sequences using prime bases
12//! - **Random sampling**: Uniform random baseline
13
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::random::rngs::StdRng;
16use scirs2_core::random::{Rng, RngExt, SeedableRng};
17
18use crate::error::{OptimizeError, OptimizeResult};
19
/// Strategy for generating initial sample points.
///
/// `Random` is a uniform baseline; `LatinHypercube` guarantees one sample per
/// stratum in each dimension; `Sobol` and `Halton` are deterministic
/// low-discrepancy sequences (optionally scrambled via `SamplingConfig`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingStrategy {
    /// Uniform random sampling
    Random,
    /// Latin Hypercube Sampling with optional maximin optimization
    LatinHypercube,
    /// Sobol quasi-random sequence
    Sobol,
    /// Halton quasi-random sequence
    Halton,
}
32
33impl Default for SamplingStrategy {
34    fn default() -> Self {
35        Self::LatinHypercube
36    }
37}
38
/// Configuration for sampling methods.
///
/// All fields have sensible defaults via [`Default`]; construct with
/// struct-update syntax to override selectively.
#[derive(Debug, Clone)]
pub struct SamplingConfig {
    /// Number of maximin iterations for LHS optimization (default: 100).
    /// Set to 0 to disable maximin and use the raw LHS design.
    pub lhs_maximin_iters: usize,
    /// Random seed for reproducibility; `None` draws a fresh seed per call.
    pub seed: Option<u64>,
    /// Scramble Sobol/Halton sequences for better uniformity
    pub scramble: bool,
}
49
50impl Default for SamplingConfig {
51    fn default() -> Self {
52        Self {
53            lhs_maximin_iters: 100,
54            seed: None,
55            scramble: true,
56        }
57    }
58}
59
60/// Generate sample points within given bounds.
61///
62/// # Arguments
63/// * `n_samples` - Number of points to generate
64/// * `bounds` - Lower and upper bounds for each dimension: `[(low, high), ...]`
65/// * `strategy` - Sampling strategy to use
66/// * `config` - Optional sampling configuration
67///
68/// # Returns
69/// A 2D array of shape `(n_samples, n_dims)` with sample points.
70pub fn generate_samples(
71    n_samples: usize,
72    bounds: &[(f64, f64)],
73    strategy: SamplingStrategy,
74    config: Option<SamplingConfig>,
75) -> OptimizeResult<Array2<f64>> {
76    let config = config.unwrap_or_default();
77    let n_dims = bounds.len();
78
79    if n_samples == 0 {
80        return Ok(Array2::zeros((0, n_dims)));
81    }
82    if n_dims == 0 {
83        return Err(OptimizeError::InvalidInput(
84            "Bounds must have at least one dimension".to_string(),
85        ));
86    }
87
88    // Validate bounds
89    for (i, &(lo, hi)) in bounds.iter().enumerate() {
90        if lo >= hi {
91            return Err(OptimizeError::InvalidInput(format!(
92                "Lower bound must be strictly less than upper bound for dimension {} (got [{}, {}])",
93                i, lo, hi
94            )));
95        }
96        if !lo.is_finite() || !hi.is_finite() {
97            return Err(OptimizeError::InvalidInput(format!(
98                "Bounds must be finite for dimension {} (got [{}, {}])",
99                i, lo, hi
100            )));
101        }
102    }
103
104    match strategy {
105        SamplingStrategy::Random => random_sampling(n_samples, bounds, &config),
106        SamplingStrategy::LatinHypercube => latin_hypercube_sampling(n_samples, bounds, &config),
107        SamplingStrategy::Sobol => sobol_sampling(n_samples, bounds, &config),
108        SamplingStrategy::Halton => halton_sampling(n_samples, bounds, &config),
109    }
110}
111
112// ---------------------------------------------------------------------------
113// Random sampling
114// ---------------------------------------------------------------------------
115
116fn random_sampling(
117    n_samples: usize,
118    bounds: &[(f64, f64)],
119    config: &SamplingConfig,
120) -> OptimizeResult<Array2<f64>> {
121    let n_dims = bounds.len();
122    let mut rng = make_rng(config.seed);
123    let mut samples = Array2::zeros((n_samples, n_dims));
124
125    for i in 0..n_samples {
126        for (j, &(lo, hi)) in bounds.iter().enumerate() {
127            samples[[i, j]] = lo + rng.random_range(0.0..1.0) * (hi - lo);
128        }
129    }
130
131    Ok(samples)
132}
133
134// ---------------------------------------------------------------------------
135// Latin Hypercube Sampling with maximin optimization
136// ---------------------------------------------------------------------------
137
/// Latin Hypercube Sampling (LHS) with maximin distance optimization.
///
/// LHS divides each dimension into `n_samples` equal strata and places exactly
/// one sample in each stratum per dimension. The maximin optimization then
/// iteratively swaps elements to maximize the minimum pairwise distance,
/// yielding better space-filling properties.
///
/// Note: each maximin iteration recomputes all pairwise distances, so the
/// optimization costs O(`lhs_maximin_iters` * n_samples^2 * n_dims).
fn latin_hypercube_sampling(
    n_samples: usize,
    bounds: &[(f64, f64)],
    config: &SamplingConfig,
) -> OptimizeResult<Array2<f64>> {
    let n_dims = bounds.len();
    let mut rng = make_rng(config.seed);

    // Step 1: Generate basic LHS in [0,1]^d
    // For each dimension, create a random permutation of {0, 1, ..., n-1}
    let mut unit_samples = Array2::zeros((n_samples, n_dims));

    for j in 0..n_dims {
        let mut perm: Vec<usize> = (0..n_samples).collect();
        // Fisher-Yates shuffle
        for i in (1..n_samples).rev() {
            let swap_idx = rng.random_range(0..=i);
            perm.swap(i, swap_idx);
        }
        for i in 0..n_samples {
            // Place sample uniformly within its stratum
            let u: f64 = rng.random_range(0.0..1.0);
            unit_samples[[i, j]] = (perm[i] as f64 + u) / n_samples as f64;
        }
    }

    // Step 2: Maximin optimization by column-wise pair swaps.
    // Greedy hill-climbing: a swap is kept only if it strictly increases the
    // minimum pairwise distance, otherwise it is reverted. Swapping values
    // within a single column preserves the LHS stratification property.
    if config.lhs_maximin_iters > 0 && n_samples > 2 {
        let mut best_min_dist = compute_min_distance(&unit_samples);

        for _ in 0..config.lhs_maximin_iters {
            // Pick a random dimension
            let dim = rng.random_range(0..n_dims);
            // Pick two distinct random rows: r2 is drawn from a range one
            // smaller and bumped past r1, guaranteeing r2 != r1.
            let r1 = rng.random_range(0..n_samples);
            let mut r2 = rng.random_range(0..n_samples.saturating_sub(1));
            if r2 >= r1 {
                r2 += 1;
            }

            // Tentatively swap
            let tmp = unit_samples[[r1, dim]];
            unit_samples[[r1, dim]] = unit_samples[[r2, dim]];
            unit_samples[[r2, dim]] = tmp;

            let new_min_dist = compute_min_distance(&unit_samples);
            if new_min_dist > best_min_dist {
                best_min_dist = new_min_dist;
            } else {
                // Revert swap
                let tmp = unit_samples[[r1, dim]];
                unit_samples[[r1, dim]] = unit_samples[[r2, dim]];
                unit_samples[[r2, dim]] = tmp;
            }
        }
    }

    // Step 3: Scale to bounds
    let mut result = Array2::zeros((n_samples, n_dims));
    for i in 0..n_samples {
        for (j, &(lo, hi)) in bounds.iter().enumerate() {
            result[[i, j]] = lo + unit_samples[[i, j]] * (hi - lo);
        }
    }

    Ok(result)
}
211
212/// Compute the minimum pairwise Euclidean distance in a sample set.
213fn compute_min_distance(samples: &Array2<f64>) -> f64 {
214    let n = samples.nrows();
215    if n < 2 {
216        return f64::INFINITY;
217    }
218    let mut min_dist = f64::INFINITY;
219    for i in 0..n {
220        for j in (i + 1)..n {
221            let mut sq_dist = 0.0;
222            for k in 0..samples.ncols() {
223                let d = samples[[i, k]] - samples[[j, k]];
224                sq_dist += d * d;
225            }
226            if sq_dist < min_dist {
227                min_dist = sq_dist;
228            }
229        }
230    }
231    min_dist.sqrt()
232}
233
234// ---------------------------------------------------------------------------
235// Sobol quasi-random sequence
236// ---------------------------------------------------------------------------
237
238/// Sobol sequence generator using direction numbers.
239///
240/// Implements the Joe-Kuo direction numbers for up to 21201 dimensions.
241/// Here we provide a compact implementation for the first several dimensions
242/// using hardcoded primitive polynomials and initial direction numbers.
243fn sobol_sampling(
244    n_samples: usize,
245    bounds: &[(f64, f64)],
246    config: &SamplingConfig,
247) -> OptimizeResult<Array2<f64>> {
248    let n_dims = bounds.len();
249    let mut samples = Array2::zeros((n_samples, n_dims));
250
251    // We need direction numbers for each dimension.
252    // Dimension 0 uses the Van der Corput sequence in base 2.
253    // Higher dimensions use Joe-Kuo direction numbers.
254    let direction_numbers = get_sobol_direction_numbers(n_dims)?;
255
256    for j in 0..n_dims {
257        let dirs = &direction_numbers[j];
258        let mut x: u64 = 0;
259        for i in 0..n_samples {
260            if j == 0 {
261                // Dimension 0: Van der Corput in base 2 using gray code
262                x = gray_code_sobol(i as u64 + 1);
263            } else {
264                // Use direction numbers with gray code enumeration
265                if i == 0 {
266                    x = 0;
267                } else {
268                    // Find the rightmost zero bit of i
269                    let c = rightmost_zero_bit(i as u64);
270                    let dir_idx = c.min(dirs.len() - 1);
271                    x ^= dirs[dir_idx];
272                }
273            }
274
275            let value = x as f64 / (1u64 << 32) as f64;
276
277            // Optional scramble: Owen's scrambling approximation via random shift
278            let scrambled = if config.scramble {
279                let mut rng = make_rng(config.seed.map(|s| s.wrapping_add(j as u64 * 1000 + 7)));
280                let shift: f64 = rng.random_range(0.0..1.0);
281                (value + shift) % 1.0
282            } else {
283                value
284            };
285
286            let (lo, hi) = bounds[j];
287            samples[[i, j]] = lo + scrambled * (hi - lo);
288        }
289    }
290
291    Ok(samples)
292}
293
/// Reverse the low 32 bits of `n` about the radix point.
///
/// Bit `k` of `n` maps to bit `31 - k` of the result, so interpreting the
/// result as a fraction of 2^32 yields the base-2 radical inverse of `n`
/// (the Van der Corput construction used for Sobol dimension 0). Bits of `n`
/// above position 31 are ignored, matching 32-bit precision.
fn gray_code_sobol(n: u64) -> u64 {
    (0..32)
        .filter(|k| (n >> k) & 1 != 0)
        .fold(0u64, |acc, k| acc | (1u64 << (31 - k)))
}
310
/// Index (0-based) of the lowest-order zero bit of `n`.
///
/// Equivalently, the count of consecutive one-bits at the bottom of `n`,
/// which is exactly what `u64::trailing_ones` computes.
fn rightmost_zero_bit(n: u64) -> usize {
    n.trailing_ones() as usize
}
321
/// Get Sobol direction numbers for up to `n_dims` dimensions.
///
/// Dimension 0 is handled by the Van der Corput sequence.
/// For dimensions 1..n_dims, we use hardcoded Joe-Kuo direction numbers
/// for the first 20 dimensions, and fall back to a deterministic
/// construction for higher dimensions.
///
/// NOTE(review): dimensions beyond 20 reuse primitive polynomials cyclically
/// with all-ones initial values, which can correlate those dimensions —
/// acceptable for moderate-dimensional initial designs, but worth confirming
/// against the full Joe-Kuo tables for high-dimensional use.
fn get_sobol_direction_numbers(n_dims: usize) -> OptimizeResult<Vec<Vec<u64>>> {
    // Maximum bits of precision (32-bit)
    let max_bits = 32usize;

    let mut all_dirs = Vec::with_capacity(n_dims);

    // Dimension 0: placeholder (handled specially)
    all_dirs.push(vec![0u64; max_bits]);

    if n_dims <= 1 {
        return Ok(all_dirs);
    }

    // Primitive polynomials (degree, polynomial coefficients as bits)
    // These are from the Joe-Kuo tables.
    // Format: (degree, poly_coeffs_bits)
    // The polynomial x^s + c_{s-1}*x^{s-1} + ... + c_1*x + 1
    // is stored as the integer with bit pattern c_{s-1}...c_1
    let primitive_polys: &[(u32, u32)] = &[
        (1, 0),  // x + 1
        (2, 1),  // x^2 + x + 1
        (3, 1),  // x^3 + x + 1
        (3, 2),  // x^3 + x^2 + 1
        (4, 1),  // x^4 + x + 1
        (4, 4),  // x^4 + x^3 + 1
        (5, 2),  // x^5 + x^2 + 1
        (5, 4),  // x^5 + x^3 + 1
        (5, 7),  // x^5 + x^3 + x^2 + x + 1
        (5, 11), // x^5 + x^4 + x^2 + x + 1
        (5, 13), // x^5 + x^4 + x^3 + x + 1
        (5, 14), // x^5 + x^4 + x^3 + x^2 + 1
        (6, 1),  // x^6 + x + 1
        (6, 13), // x^6 + x^4 + x^3 + x + 1
        (6, 16), // x^6 + x^5 + 1
        (6, 19), // x^6 + x^5 + x^2 + x + 1
        (6, 22), // x^6 + x^5 + x^3 + x^2 + 1
        (6, 25), // x^6 + x^5 + x^4 + x + 1
        (7, 1),  // x^7 + x + 1
        (7, 4),  // x^7 + x^3 + 1
    ];

    // Initial direction numbers (m_i values, 1-indexed) from Joe-Kuo
    // Each row corresponds to a dimension (starting from dimension 1).
    // The first `degree` values are the initial direction numbers.
    let initial_m: &[&[u64]] = &[
        &[1],                   // dim 1
        &[1, 1],                // dim 2
        &[1, 1, 1],             // dim 3
        &[1, 3, 1],             // dim 4
        &[1, 1, 1, 1],          // dim 5
        &[1, 1, 3, 1],          // dim 6
        &[1, 3, 5, 1, 3],       // dim 7
        &[1, 3, 3, 1, 1],       // dim 8
        &[1, 3, 7, 7, 5],       // dim 9
        &[1, 1, 5, 1, 15],      // dim 10
        &[1, 3, 1, 3, 5],       // dim 11
        &[1, 3, 7, 7, 5],       // dim 12
        &[1, 1, 1, 1, 1, 1],    // dim 13
        &[1, 1, 5, 3, 13, 7],   // dim 14
        &[1, 3, 3, 1, 1, 1],    // dim 15
        &[1, 1, 1, 5, 7, 11],   // dim 16
        &[1, 1, 7, 3, 29, 3],   // dim 17
        &[1, 3, 7, 7, 21, 25],  // dim 18
        &[1, 1, 1, 1, 1, 1, 1], // dim 19
        &[1, 3, 1, 1, 1, 7, 1], // dim 20
    ];

    for dim_idx in 1..n_dims {
        // Cycle through the polynomial table for dimensions beyond it.
        let poly_idx = if dim_idx - 1 < primitive_polys.len() {
            dim_idx - 1
        } else {
            (dim_idx - 1) % primitive_polys.len()
        };

        let (degree, poly_bits) = primitive_polys[poly_idx];
        let s = degree as usize;

        let mut dirs = vec![0u64; max_bits];

        // Set initial direction numbers
        let init = if dim_idx - 1 < initial_m.len() {
            initial_m[dim_idx - 1]
        } else {
            // Fall back to all-ones
            &[1u64; 1][..] // Will be extended below
        };

        // Missing m_k entries default to 1 (the all-ones fallback).
        for k in 0..s.min(max_bits) {
            let m_k = if k < init.len() { init[k] } else { 1 };
            // Direction number v_k = m_k * 2^(32 - k - 1)
            dirs[k] = m_k << (max_bits - k - 1);
        }

        // Generate remaining direction numbers using the recurrence:
        // v_k = c_1 * v_{k-1} XOR c_2 * v_{k-2} XOR ... XOR c_{s-1} * v_{k-s+1}
        //       XOR v_{k-s} XOR (v_{k-s} >> s)
        // where c_j is bit (s-1-j) of `poly_bits` (MSB-first coefficient order).
        for k in s..max_bits {
            let mut new_v = dirs[k - s] ^ (dirs[k - s] >> s);
            for j in 1..s {
                if (poly_bits >> (s - 1 - j)) & 1 == 1 {
                    new_v ^= dirs[k - j];
                }
            }
            dirs[k] = new_v;
        }

        all_dirs.push(dirs);
    }

    Ok(all_dirs)
}
439
440// ---------------------------------------------------------------------------
441// Halton quasi-random sequence
442// ---------------------------------------------------------------------------
443
444/// Halton sequence using prime bases per dimension.
445///
446/// The Halton sequence is a generalization of the Van der Corput sequence
447/// to multiple dimensions, using a different prime base for each dimension.
448fn halton_sampling(
449    n_samples: usize,
450    bounds: &[(f64, f64)],
451    config: &SamplingConfig,
452) -> OptimizeResult<Array2<f64>> {
453    let n_dims = bounds.len();
454    let primes = first_n_primes(n_dims);
455    let mut samples = Array2::zeros((n_samples, n_dims));
456
457    // Optional random shift for scrambling
458    let shifts: Vec<f64> = if config.scramble {
459        let mut rng = make_rng(config.seed);
460        (0..n_dims).map(|_| rng.random_range(0.0..1.0)).collect()
461    } else {
462        vec![0.0; n_dims]
463    };
464
465    for i in 0..n_samples {
466        for j in 0..n_dims {
467            let raw = radical_inverse(i as u64 + 1, primes[j]);
468            let value = if config.scramble {
469                (raw + shifts[j]) % 1.0
470            } else {
471                raw
472            };
473            let (lo, hi) = bounds[j];
474            samples[[i, j]] = lo + value * (hi - lo);
475        }
476    }
477
478    Ok(samples)
479}
480
/// Radical inverse of `n` in the given `base`.
///
/// Reflects the base-`base` digits of `n` about the radix point: e.g.
/// `radical_inverse(6, 2)` turns binary `110` into `0.011` = 0.375.
/// Returns 0.0 for `n == 0`.
fn radical_inverse(n: u64, base: u64) -> f64 {
    let mut inv = 0.0;
    let mut scale = 1.0;
    let mut digits = n;

    // Peel digits from least- to most-significant; each lands one place
    // further right of the radix point.
    while digits > 0 {
        scale *= base as f64;
        inv += (digits % base) as f64 / scale;
        digits /= base;
    }
    inv
}
497
/// Return the first `n` prime numbers (2, 3, 5, ...).
///
/// Uses trial division against the primes found so far, stopping at sqrt of
/// the candidate. Returns an empty vector for `n == 0`.
fn first_n_primes(n: usize) -> Vec<u64> {
    let mut found: Vec<u64> = Vec::with_capacity(n);
    let mut c = 2u64;

    while found.len() < n {
        // A candidate is composite iff some already-found prime <= sqrt(c)
        // divides it; found primes are in increasing order, so take_while
        // bounds the scan.
        let composite = found
            .iter()
            .take_while(|&&p| p * p <= c)
            .any(|&p| c % p == 0);
        if !composite {
            found.push(c);
        }
        c += 1;
    }
    found
}
518
519// ---------------------------------------------------------------------------
520// Helpers
521// ---------------------------------------------------------------------------
522
523fn make_rng(seed: Option<u64>) -> StdRng {
524    match seed {
525        Some(s) => StdRng::seed_from_u64(s),
526        None => {
527            let s: u64 = scirs2_core::random::rng().random();
528            StdRng::seed_from_u64(s)
529        }
530    }
531}
532
533// ---------------------------------------------------------------------------
534// Tests
535// ---------------------------------------------------------------------------
536
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a simple 2-D search box.
    fn bounds_2d() -> Vec<(f64, f64)> {
        vec![(-5.0, 5.0), (0.0, 10.0)]
    }

    // Shared fixture: a 5-D box with heterogeneous ranges.
    fn bounds_5d() -> Vec<(f64, f64)> {
        vec![
            (0.0, 1.0),
            (-1.0, 1.0),
            (0.0, 100.0),
            (-10.0, 10.0),
            (5.0, 15.0),
        ]
    }

    // ---- random sampling ----

    #[test]
    fn test_random_sampling_shape() {
        let samples = generate_samples(20, &bounds_2d(), SamplingStrategy::Random, None)
            .expect("should succeed");
        assert_eq!(samples.nrows(), 20);
        assert_eq!(samples.ncols(), 2);
    }

    #[test]
    fn test_random_sampling_within_bounds() {
        let b = bounds_2d();
        let samples =
            generate_samples(100, &b, SamplingStrategy::Random, None).expect("should succeed");
        for i in 0..samples.nrows() {
            for (j, &(lo, hi)) in b.iter().enumerate() {
                assert!(
                    samples[[i, j]] >= lo && samples[[i, j]] <= hi,
                    "sample[{},{}] = {} not in [{}, {}]",
                    i,
                    j,
                    samples[[i, j]],
                    lo,
                    hi
                );
            }
        }
    }

    // ---- LHS ----

    #[test]
    fn test_lhs_shape_and_bounds() {
        let b = bounds_5d();
        let samples = generate_samples(30, &b, SamplingStrategy::LatinHypercube, None)
            .expect("should succeed");
        assert_eq!(samples.nrows(), 30);
        assert_eq!(samples.ncols(), 5);

        for i in 0..samples.nrows() {
            for (j, &(lo, hi)) in b.iter().enumerate() {
                assert!(
                    samples[[i, j]] >= lo && samples[[i, j]] <= hi,
                    "LHS sample[{},{}] = {} not in [{}, {}]",
                    i,
                    j,
                    samples[[i, j]],
                    lo,
                    hi
                );
            }
        }
    }

    #[test]
    fn test_lhs_stratification() {
        // Each dimension should have exactly one sample per stratum.
        let n = 10;
        let bounds = vec![(0.0, 1.0); 3];
        let cfg = SamplingConfig {
            lhs_maximin_iters: 0, // No optimization, raw LHS
            seed: Some(42),
            scramble: false,
        };
        let samples = generate_samples(n, &bounds, SamplingStrategy::LatinHypercube, Some(cfg))
            .expect("should succeed");

        for j in 0..3 {
            let mut strata = vec![false; n];
            for i in 0..n {
                // Map each sample back to its stratum index; clamp for the
                // edge case where a sample lands exactly on 1.0.
                let stratum = (samples[[i, j]] * n as f64).floor() as usize;
                let stratum = stratum.min(n - 1);
                strata[stratum] = true;
            }
            // Every stratum should be occupied
            for (s, &occupied) in strata.iter().enumerate() {
                assert!(occupied, "Stratum {} in dimension {} is unoccupied", s, j);
            }
        }
    }

    #[test]
    fn test_lhs_maximin_improves_spacing() {
        let n = 15;
        let bounds = vec![(0.0, 1.0); 2];

        // Without maximin
        let cfg0 = SamplingConfig {
            lhs_maximin_iters: 0,
            seed: Some(123),
            scramble: false,
        };
        let s0 = generate_samples(n, &bounds, SamplingStrategy::LatinHypercube, Some(cfg0))
            .expect("should succeed");

        // With maximin
        let cfg1 = SamplingConfig {
            lhs_maximin_iters: 500,
            seed: Some(123),
            scramble: false,
        };
        let s1 = generate_samples(n, &bounds, SamplingStrategy::LatinHypercube, Some(cfg1))
            .expect("should succeed");

        let d0 = compute_min_distance(&s0);
        let d1 = compute_min_distance(&s1);

        // Maximin should give equal or better minimum distance
        // (greedy optimization only ever accepts improving swaps).
        assert!(
            d1 >= d0 - 1e-12,
            "Maximin LHS should not decrease min distance: d_opt={} < d_raw={}",
            d1,
            d0
        );
    }

    // ---- Sobol ----

    #[test]
    fn test_sobol_shape_and_bounds() {
        let b = bounds_2d();
        let samples =
            generate_samples(32, &b, SamplingStrategy::Sobol, None).expect("should succeed");
        assert_eq!(samples.nrows(), 32);
        assert_eq!(samples.ncols(), 2);

        for i in 0..samples.nrows() {
            for (j, &(lo, hi)) in b.iter().enumerate() {
                assert!(
                    samples[[i, j]] >= lo && samples[[i, j]] <= hi,
                    "Sobol sample[{},{}] = {} not in [{}, {}]",
                    i,
                    j,
                    samples[[i, j]],
                    lo,
                    hi
                );
            }
        }
    }

    #[test]
    fn test_sobol_reproducibility() {
        // Same seed + scramble must yield byte-identical output.
        let b = bounds_2d();
        let cfg = SamplingConfig {
            seed: Some(99),
            scramble: true,
            ..Default::default()
        };
        let s1 = generate_samples(16, &b, SamplingStrategy::Sobol, Some(cfg.clone()))
            .expect("should succeed");
        let s2 =
            generate_samples(16, &b, SamplingStrategy::Sobol, Some(cfg)).expect("should succeed");
        assert_eq!(s1, s2);
    }

    // ---- Halton ----

    #[test]
    fn test_halton_shape_and_bounds() {
        let b = bounds_5d();
        let samples =
            generate_samples(50, &b, SamplingStrategy::Halton, None).expect("should succeed");
        assert_eq!(samples.nrows(), 50);
        assert_eq!(samples.ncols(), 5);

        for i in 0..samples.nrows() {
            for (j, &(lo, hi)) in b.iter().enumerate() {
                assert!(
                    samples[[i, j]] >= lo && samples[[i, j]] <= hi,
                    "Halton sample[{},{}] = {} not in [{}, {}]",
                    i,
                    j,
                    samples[[i, j]],
                    lo,
                    hi
                );
            }
        }
    }

    #[test]
    fn test_halton_low_discrepancy() {
        // Halton in 1D with base 2 should produce Van der Corput sequence:
        // 1/2, 1/4, 3/4, 1/8, 5/8, 3/8, 7/8, ...
        let bounds = vec![(0.0, 1.0)];
        let cfg = SamplingConfig {
            seed: None,
            scramble: false,
            ..Default::default()
        };
        let samples = generate_samples(4, &bounds, SamplingStrategy::Halton, Some(cfg))
            .expect("should succeed");

        let expected = [0.5, 0.25, 0.75, 0.125];
        for (i, &exp) in expected.iter().enumerate() {
            assert!(
                (samples[[i, 0]] - exp).abs() < 1e-10,
                "Halton[{}] = {}, expected {}",
                i,
                samples[[i, 0]],
                exp
            );
        }
    }

    // ---- edge cases ----

    #[test]
    fn test_zero_samples() {
        let samples = generate_samples(0, &bounds_2d(), SamplingStrategy::Random, None)
            .expect("should succeed");
        assert_eq!(samples.nrows(), 0);
    }

    #[test]
    fn test_single_sample() {
        for strategy in &[
            SamplingStrategy::Random,
            SamplingStrategy::LatinHypercube,
            SamplingStrategy::Sobol,
            SamplingStrategy::Halton,
        ] {
            let samples =
                generate_samples(1, &bounds_2d(), *strategy, None).expect("should succeed");
            assert_eq!(samples.nrows(), 1);
            assert_eq!(samples.ncols(), 2);
        }
    }

    #[test]
    fn test_invalid_bounds_rejected() {
        // lo >= hi
        let result = generate_samples(10, &[(5.0, 5.0)], SamplingStrategy::Random, None);
        assert!(result.is_err());

        // infinite
        let result = generate_samples(
            10,
            &[(f64::NEG_INFINITY, 1.0)],
            SamplingStrategy::Random,
            None,
        );
        assert!(result.is_err());
    }

    // ---- prime generation ----

    #[test]
    fn test_first_n_primes() {
        let p = first_n_primes(10);
        assert_eq!(p, vec![2, 3, 5, 7, 11, 13, 17, 19, 23, 29]);
    }

    #[test]
    fn test_radical_inverse() {
        // radical_inverse(1, 2) = 0.5
        assert!((radical_inverse(1, 2) - 0.5).abs() < 1e-15);
        // radical_inverse(2, 2) = 0.25
        assert!((radical_inverse(2, 2) - 0.25).abs() < 1e-15);
        // radical_inverse(3, 2) = 0.75
        assert!((radical_inverse(3, 2) - 0.75).abs() < 1e-15);
        // radical_inverse(1, 3) = 1/3
        assert!((radical_inverse(1, 3) - 1.0 / 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_high_dimensional_sampling() {
        let bounds: Vec<(f64, f64)> = (0..15).map(|_| (0.0, 1.0)).collect();
        for strategy in &[
            SamplingStrategy::Random,
            SamplingStrategy::LatinHypercube,
            SamplingStrategy::Sobol,
            SamplingStrategy::Halton,
        ] {
            let samples = generate_samples(20, &bounds, *strategy, None).expect("should succeed");
            assert_eq!(samples.nrows(), 20);
            assert_eq!(samples.ncols(), 15);
        }
    }
}