1use crate::error::{StatsError, StatsResult as Result};
7use crate::error_handling_v2::ErrorCode;
8use crate::unified_error_handling::global_error_handler;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
11use scirs2_core::validation::*;
12use statrs::statistics::Statistics;
13use std::collections::HashMap;
14
/// Low-discrepancy sequence families that [`AdvancedQMCGenerator`] can produce.
///
/// Each variant has its own upper bound on the supported dimension, enforced
/// by `AdvancedQMCGenerator::new`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QMCSequenceType {
    /// Sobol sequence (accepted up to 21201 dimensions).
    Sobol,
    /// Halton sequence (accepted up to 1000 dimensions).
    Halton,
    /// Niederreiter sequence (accepted up to 100 dimensions).
    Niederreiter,
    /// Faure sequence (accepted up to 50 dimensions).
    Faure,
    /// Halton variant using leaped indices and digit permutations
    /// (accepted up to 500 dimensions).
    GeneralizedHalton,
    /// Latin-hypercube style sampling (accepted up to 1000 dimensions).
    OptimalLHS,
}
31
/// Settings consumed by [`StratifiedSampler`].
#[derive(Debug, Clone)]
pub struct StratifiedSamplingConfig {
    /// Number of equal-width intervals each axis of the unit cube is cut into.
    pub strata_per_dimension: usize,
    /// How a point is placed inside its stratum.
    pub intra_stratum_method: IntraStratumMethod,
    /// NOTE(review): not read by the sampler code visible in this file —
    /// confirm whether proportional allocation is implemented elsewhere.
    pub proportional_allocation: bool,
    /// Strata that would receive fewer samples than this are skipped during
    /// the main allocation pass.
    pub min_samples_per_stratum: usize,
    /// NOTE(review): not read by the sampler code visible in this file.
    pub adaptive_refinement: bool,
    /// NOTE(review): not read by the sampler code visible in this file;
    /// presumably the threshold used by adaptive refinement — confirm.
    pub refinement_threshold: f64,
}
48
/// Strategy for choosing a coordinate inside a single stratum.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum IntraStratumMethod {
    /// Uniform random position within the stratum.
    Random,
    /// Midpoint of the stratum.
    Centroid,
    /// Position near the stratum centre with a small random jitter.
    ///
    /// NOTE(review): the wrapped sequence type is currently ignored by
    /// `StratifiedSampler::sample_within_stratum`.
    QMC(QMCSequenceType),
    /// Even-numbered dimensions use a uniform draw `u`; odd-numbered
    /// dimensions use `1 - u` of their own (independent) draw.
    Antithetic,
}
61
62impl Default for StratifiedSamplingConfig {
63 fn default() -> Self {
64 Self {
65 strata_per_dimension: 4,
66 intra_stratum_method: IntraStratumMethod::Random,
67 proportional_allocation: false,
68 min_samples_per_stratum: 1,
69 adaptive_refinement: false,
70 refinement_threshold: 0.01,
71 }
72 }
73}
74
/// Stateful generator of quasi-Monte Carlo points.
///
/// Created with `AdvancedQMCGenerator::new`; points are then pulled one at a
/// time with `next_point` or in bulk with `generate`.
pub struct AdvancedQMCGenerator {
    /// Sequence family selected at construction.
    sequence_type: QMCSequenceType,
    /// Number of coordinates in every generated point.
    dimension: usize,
    /// Whether scrambling was requested at construction.
    scramble: bool,
    /// Seed supplied at construction, if any.
    seed: Option<u64>,
    /// Index of the next point to emit; starts at 0 and grows by one per point.
    current_index: usize,
    /// Sequence-specific tables; the variant matches `sequence_type`.
    generator_state: QMCGeneratorState,
}
84
/// Sequence-specific state owned by [`AdvancedQMCGenerator`].
///
/// Exactly one variant is built in `AdvancedQMCGenerator::new`, chosen to
/// match the requested [`QMCSequenceType`].
#[derive(Debug)]
enum QMCGeneratorState {
    /// Tables for the Sobol sequence.
    Sobol(SobolState),
    /// Tables for the Halton sequence.
    Halton(HaltonState),
    /// Tables for the Niederreiter sequence.
    Niederreiter(NiederreiterState),
    /// Tables for the Faure sequence.
    Faure(FaureState),
    /// Tables for the generalized Halton sequence.
    GeneralizedHalton(GeneralizedHaltonState),
    /// Random-number state for Latin-hypercube style sampling.
    OptimalLHS(OptimalLHSState),
}
95
/// Tables needed to emit Sobol points.
#[derive(Debug)]
struct SobolState {
    /// One row of 32 direction numbers per dimension. Values are stored
    /// left-aligned in a `u64`: entry `i` of the first dimension is
    /// `1 << (63 - i)`.
    direction_numbers: Vec<Vec<u64>>,
    /// Optional per-dimension 32x32 scrambling matrices; `None` when
    /// scrambling is disabled.
    #[allow(dead_code)]
    scramble_matrices: Option<Vec<Array2<u32>>>,
}
102
103impl SobolState {
104 pub fn new(dimension: usize) -> Result<Self> {
106 let direction_numbers = Self::init_direction_numbers(dimension)?;
107 Ok(Self {
108 direction_numbers,
109 scramble_matrices: None,
110 })
111 }
112
113 fn init_direction_numbers(dimension: usize) -> Result<Vec<Vec<u64>>> {
115 let mut direction_numbers = vec![vec![0u64; 32]; dimension];
116
117 for i in 0..32 {
119 direction_numbers[0][i] = 1u64 << (63 - i);
120 }
121
122 for dim in 1..dimension {
124 for i in 0..32 {
125 direction_numbers[dim][i] = 1u64 << (63 - i);
126 }
127 }
128
129 Ok(direction_numbers)
130 }
131}
132
/// Tables needed to emit Halton points.
#[derive(Debug)]
struct HaltonState {
    /// Radix used for each dimension (the first `dimension` primes).
    bases: Vec<u32>,
    /// Optional digit permutation per base; `None` when scrambling is
    /// disabled.
    #[allow(dead_code)]
    permutations: Option<Vec<Vec<u32>>>,
}
139
/// Tables needed to emit Niederreiter points.
#[derive(Debug)]
struct NiederreiterState {
    /// One 32x32 matrix of 0/1 entries per dimension.
    generating_matrices: Vec<Array2<u32>>,
    /// Polynomial coefficient vectors, one per dimension. Stored at
    /// construction; not read by the point generation visible in this file.
    #[allow(dead_code)]
    polynomial_coefficients: Vec<Vec<u32>>,
}
146
/// Tables needed to emit Faure points.
#[derive(Debug)]
struct FaureState {
    /// Radix of the sequence: the smallest prime not below the dimension.
    base: u32,
    /// One `base x base` 0/1 matrix per dimension. Stored at construction;
    /// not read by the point generation visible in this file.
    #[allow(dead_code)]
    permutation_matrices: Vec<Array2<u32>>,
}
153
/// Tables needed to emit generalized Halton points.
#[derive(Debug)]
struct GeneralizedHaltonState {
    /// Radix used for each dimension (the first `dimension` primes).
    bases: Vec<u32>,
    /// Per-dimension multiplier applied to the point index before the
    /// radical inverse is taken.
    #[allow(dead_code)]
    leap_values: Vec<usize>,
    /// Per-dimension digit permutation (one entry per digit of the base).
    #[allow(dead_code)]
    generalized_permutations: Vec<Vec<u32>>,
}
162
/// State for Latin-hypercube style sampling.
#[derive(Debug)]
struct OptimalLHSState {
    /// Random number generator created at construction; seeded from the
    /// user-supplied seed when one is given.
    rng: StdRng,
    /// Optional correlation matrix. Initialised to `None`; when present the
    /// point generator applies its Cholesky factor to normal variates.
    #[allow(dead_code)]
    correlation_matrix: Option<Array2<f64>>,
}
169
170impl AdvancedQMCGenerator {
171 pub fn new(
173 sequence_type: QMCSequenceType,
174 dimension: usize,
175 scramble: bool,
176 seed: Option<u64>,
177 ) -> Result<Self> {
178 let handler = global_error_handler();
179
180 if dimension == 0 {
181 return Err(handler
182 .create_validation_error(
183 ErrorCode::E1001,
184 "AdvancedQMCGenerator::new",
185 "dimension",
186 dimension,
187 "Dimension must be positive",
188 )
189 .error);
190 }
191
192 let max_dim = match sequence_type {
193 QMCSequenceType::Sobol => 21201, QMCSequenceType::Halton => 1000,
195 QMCSequenceType::Niederreiter => 100,
196 QMCSequenceType::Faure => 50,
197 QMCSequenceType::GeneralizedHalton => 500,
198 QMCSequenceType::OptimalLHS => 1000,
199 };
200
201 if dimension > max_dim {
202 return Err(handler
203 .create_validation_error(
204 ErrorCode::E1001,
205 "AdvancedQMCGenerator::new",
206 "dimension",
207 format!("{} (max: {})", dimension, max_dim),
208 format!(
209 "{:?} sequence supports up to {} dimensions",
210 sequence_type, max_dim
211 ),
212 )
213 .error);
214 }
215
216 let generator_state = match sequence_type {
217 QMCSequenceType::Sobol => {
218 QMCGeneratorState::Sobol(Self::init_sobol_state(dimension, scramble, seed)?)
219 }
220 QMCSequenceType::Halton => {
221 QMCGeneratorState::Halton(Self::init_halton_state(dimension, scramble, seed)?)
222 }
223 QMCSequenceType::Niederreiter => {
224 QMCGeneratorState::Niederreiter(Self::init_niederreiter_state(dimension, seed)?)
225 }
226 QMCSequenceType::Faure => {
227 QMCGeneratorState::Faure(Self::init_faure_state(dimension, seed)?)
228 }
229 QMCSequenceType::GeneralizedHalton => QMCGeneratorState::GeneralizedHalton(
230 Self::init_generalized_halton_state(dimension, seed)?,
231 ),
232 QMCSequenceType::OptimalLHS => {
233 QMCGeneratorState::OptimalLHS(Self::init_optimal_lhs_state(dimension, seed)?)
234 }
235 };
236
237 Ok(Self {
238 sequence_type,
239 dimension,
240 scramble,
241 seed,
242 current_index: 0,
243 generator_state,
244 })
245 }
246
247 pub fn generate(&mut self, n: usize) -> Result<Array2<f64>> {
249 check_positive(n, "n")?;
250
251 let mut samples = Array2::zeros((n, self.dimension));
252
253 for i in 0..n {
254 let point = self.next_point()?;
255 for (j, &val) in point.iter().enumerate() {
256 samples[[i, j]] = val;
257 }
258 }
259
260 Ok(samples)
261 }
262
263 pub fn next_point(&mut self) -> Result<Array1<f64>> {
265 use std::mem;
266
267 let mut temp_state = mem::replace(
269 &mut self.generator_state,
270 QMCGeneratorState::Sobol(SobolState::new(1).expect("Operation failed")),
271 );
272
273 let point = match &mut temp_state {
274 QMCGeneratorState::Sobol(state) => {
275 Self::next_sobol_point_static(self.dimension, self.current_index, state)?
276 }
277 QMCGeneratorState::Halton(state) => {
278 Self::next_halton_point_static(self.dimension, self.current_index, state)?
279 }
280 QMCGeneratorState::Niederreiter(state) => {
281 Self::next_niederreiter_point_static(self.dimension, self.current_index, state)?
282 }
283 QMCGeneratorState::Faure(state) => {
284 Self::next_faure_point_static(self.dimension, self.current_index, state)?
285 }
286 QMCGeneratorState::GeneralizedHalton(state) => {
287 Self::next_generalized_halton_point_static(
288 self.dimension,
289 self.current_index,
290 state,
291 )?
292 }
293 QMCGeneratorState::OptimalLHS(state) => {
294 Self::next_optimal_lhs_point_static(self.dimension, self.current_index, state)?
295 }
296 };
297
298 self.generator_state = temp_state;
300 self.current_index += 1;
301 Ok(point)
302 }
303
304 fn init_sobol_state(
306 _dimension: usize,
307 scramble: bool,
308 seed: Option<u64>,
309 ) -> Result<SobolState> {
310 let direction_numbers = Self::load_joe_kuo_direction_numbers(_dimension)?;
312
313 let scramble_matrices = if scramble {
314 Some(Self::generate_digital_shift_matrices(_dimension, seed)?)
315 } else {
316 None
317 };
318
319 Ok(SobolState {
320 direction_numbers,
321 scramble_matrices,
322 })
323 }
324
325 fn init_halton_state(
327 dimension: usize,
328 scramble: bool,
329 seed: Option<u64>,
330 ) -> Result<HaltonState> {
331 let bases = Self::first_primes(dimension)?;
332
333 let permutations = if scramble {
334 Some(Self::generate_faure_tezuka_permutations(&bases, seed)?)
335 } else {
336 None
337 };
338
339 Ok(HaltonState {
340 bases,
341 permutations,
342 })
343 }
344
345 fn init_niederreiter_state(dimension: usize, seed: Option<u64>) -> Result<NiederreiterState> {
347 let generating_matrices = Self::generate_niederreiter_matrices(dimension)?;
348 let polynomial_coefficients = Self::get_primitive_polynomials(dimension)?;
349
350 Ok(NiederreiterState {
351 generating_matrices,
352 polynomial_coefficients,
353 })
354 }
355
356 fn init_faure_state(dimension: usize, seed: Option<u64>) -> Result<FaureState> {
358 let base = Self::smallest_prime_geq(dimension as u32)?;
359 let permutation_matrices = Self::generate_faure_permutations(dimension, base, seed)?;
360
361 Ok(FaureState {
362 base,
363 permutation_matrices,
364 })
365 }
366
367 fn init_generalized_halton_state(
369 dimension: usize,
370 seed: Option<u64>,
371 ) -> Result<GeneralizedHaltonState> {
372 let bases = Self::first_primes(dimension)?;
373 let leap_values = Self::compute_optimal_leap_values(&bases);
374 let generalized_permutations = Self::generate_generalized_permutations(&bases, seed)?;
375
376 Ok(GeneralizedHaltonState {
377 bases,
378 leap_values,
379 generalized_permutations,
380 })
381 }
382
383 fn init_optimal_lhs_state(dimension: usize, seed: Option<u64>) -> Result<OptimalLHSState> {
385 let rng = match seed {
386 Some(s) => StdRng::seed_from_u64(s),
387 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
388 };
389
390 Ok(OptimalLHSState {
391 rng,
392 correlation_matrix: None,
393 })
394 }
395
396 fn next_sobol_point_static(
398 dimension: usize,
399 current_index: usize,
400 state: &SobolState,
401 ) -> Result<Array1<f64>> {
402 let mut point = Array1::zeros(dimension);
403
404 for dim in 0..dimension {
405 let mut result = 0u64;
406 let _index = current_index;
407
408 let gray_code = _index ^ (_index >> 1);
410
411 for bit in 0..32 {
412 if (gray_code >> bit) & 1 == 1 {
413 result ^= state.direction_numbers[dim][bit];
414 }
415 }
416
417 if let Some(ref matrices) = state.scramble_matrices {
419 result = Self::apply_digital_shift(result, &matrices[dim]);
420 }
421
422 point[dim] = result as f64 / (1u64 << 32) as f64;
423 }
424
425 Ok(point)
426 }
427
428 fn next_halton_point_static(
430 dimension: usize,
431 current_index: usize,
432 state: &HaltonState,
433 ) -> Result<Array1<f64>> {
434 let mut point = Array1::zeros(dimension);
435
436 for dim in 0..dimension {
437 let base = state.bases[dim];
438 let value = if let Some(ref perms) = state.permutations {
439 Self::scrambled_radical_inverse(current_index, base, &perms[dim])?
440 } else {
441 Self::radical_inverse(current_index, base)?
442 };
443 point[dim] = value;
444 }
445
446 Ok(point)
447 }
448
449 fn next_niederreiter_point_static(
451 dimension: usize,
452 current_index: usize,
453 state: &NiederreiterState,
454 ) -> Result<Array1<f64>> {
455 let mut point = Array1::zeros(dimension);
456
457 for dim in 0..dimension {
458 let matrix = &state.generating_matrices[dim];
459 let mut result = 0u32;
460 let mut _index = current_index;
461
462 for i in 0..32 {
463 if _index & 1 == 1 {
464 for j in 0..32 {
465 result ^= matrix[[i, j]];
466 }
467 }
468 _index >>= 1;
469 if _index == 0 {
470 break;
471 }
472 }
473
474 point[dim] = result as f64 / (1u64 << 32) as f64;
475 }
476
477 Ok(point)
478 }
479
480 fn next_faure_point_static(
482 dimension: usize,
483 current_index: usize,
484 state: &FaureState,
485 ) -> Result<Array1<f64>> {
486 let mut point = Array1::zeros(dimension);
487 let base = state.base;
488
489 let base_value = Self::radical_inverse(current_index, base)?;
491 point[0] = base_value;
492
493 for dim in 1..dimension {
495 let power = (dim as f64 * base_value).fract();
496 point[dim] = power;
497 }
498
499 Ok(point)
500 }
501
502 fn next_generalized_halton_point_static(
504 dimension: usize,
505 current_index: usize,
506 state: &GeneralizedHaltonState,
507 ) -> Result<Array1<f64>> {
508 let mut point = Array1::zeros(dimension);
509
510 for dim in 0..dimension {
511 let base = state.bases[dim];
512 let leap = state.leap_values[dim];
513 let effective_index = (current_index * leap) % (base.pow(10) as usize); let value = Self::scrambled_radical_inverse(
516 effective_index,
517 base,
518 &state.generalized_permutations[dim],
519 )?;
520 point[dim] = value;
521 }
522
523 Ok(point)
524 }
525
526 fn next_optimal_lhs_point_static(
528 dimension: usize,
529 current_index: usize,
530 state: &mut OptimalLHSState,
531 ) -> Result<Array1<f64>> {
532 let mut point = Array1::zeros(dimension);
533
534 if let Some(ref corr_matrix) = state.correlation_matrix {
536 let chol = scirs2_linalg::cholesky(&corr_matrix.view(), None).map_err(|e| {
538 StatsError::ComputationError(format!("Cholesky decomposition failed: {}", e))
539 })?;
540
541 let mut uniform = Array1::zeros(dimension);
542 for i in 0..dimension {
543 uniform[i] = scirs2_core::random::thread_rng().random::<f64>();
544 }
545
546 let normal = uniform.mapv(|u| {
548 if u <= 0.5 {
550 -(-2.0 * u.ln()).sqrt()
551 * (2.0
552 * std::f64::consts::PI
553 * scirs2_core::random::thread_rng().random::<f64>())
554 .cos()
555 } else {
556 (-2.0 * (1.0 - u).ln()).sqrt()
557 * (2.0
558 * std::f64::consts::PI
559 * scirs2_core::random::thread_rng().random::<f64>())
560 .cos()
561 }
562 });
563
564 let corr_normal = chol.dot(&normal);
565
566 for i in 0..dimension {
568 point[i] = Self::normal_cdf(corr_normal[i]);
569 }
570 } else {
571 for i in 0..dimension {
573 let stratum = current_index % 1000; let u = scirs2_core::random::thread_rng().random::<f64>();
575 point[i] = (stratum as f64 + u) / 1000.0;
576 }
577 }
578
579 Ok(point)
580 }
581
582 fn load_joe_kuo_direction_numbers(dimension: usize) -> Result<Vec<Vec<u64>>> {
584 let mut direction_numbers = vec![vec![0u64; 32]; dimension];
585
586 for i in 0..32 {
588 direction_numbers[0][i] = 1u64 << (63 - i);
589 }
590
591 for dim in 1..dimension {
593 let poly_deg = 2 + (dim % 6); let polynomial = Self::get_primitive_polynomial(poly_deg);
596
597 for i in 0..poly_deg {
599 direction_numbers[dim][i] = (1u64 << (63 - i)) ^ ((dim as u64) << (60 - i));
600 }
601
602 for i in poly_deg..32 {
604 let mut val = direction_numbers[dim][i - poly_deg];
605 val ^= val >> poly_deg;
606
607 for j in 1..poly_deg {
608 if (polynomial >> j) & 1 == 1 {
609 val ^= direction_numbers[dim][i - j];
610 }
611 }
612
613 direction_numbers[dim][i] = val;
614 }
615 }
616
617 Ok(direction_numbers)
618 }
619
620 fn get_primitive_polynomial(degree: usize) -> u32 {
622 match degree {
624 2 => 0b111, 3 => 0b1011, 4 => 0b10011, 5 => 0b100101, 6 => 0b1000011, 7 => 0b10000011, _ => 0b111, }
632 }
633
634 fn generate_digital_shift_matrices(
636 dimension: usize,
637 seed: Option<u64>,
638 ) -> Result<Vec<Array2<u32>>> {
639 let mut rng = match seed {
640 Some(s) => StdRng::seed_from_u64(s),
641 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
642 };
643
644 let mut matrices = Vec::with_capacity(dimension);
645
646 for _ in 0..dimension {
647 let mut matrix = Array2::zeros((32, 32));
648
649 for i in 0..32 {
651 matrix[[i, i]] = 1; for j in (i + 1)..32 {
653 matrix[[i, j]] = if rng.random::<f64>() < 0.5 { 1 } else { 0 };
654 }
655 }
656
657 matrices.push(matrix);
658 }
659
660 Ok(matrices)
661 }
662
663 fn apply_digital_shift(value: u64, matrix: &Array2<u32>) -> u64 {
665 let mut result = 0u64;
666
667 for i in 0..32 {
668 let mut bit_result = 0u32;
669 for j in 0..32 {
670 let input_bit = ((value >> (63 - j)) & 1) as u32;
671 bit_result ^= matrix[[i, j]] & input_bit;
672 }
673 result |= (bit_result as u64) << (63 - i);
674 }
675
676 result
677 }
678
679 fn generate_faure_tezuka_permutations(
681 bases: &[u32],
682 seed: Option<u64>,
683 ) -> Result<Vec<Vec<u32>>> {
684 let mut rng = match seed {
685 Some(s) => StdRng::seed_from_u64(s),
686 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
687 };
688
689 let mut permutations = Vec::with_capacity(bases.len());
690
691 for &base in bases {
692 let mut perm: Vec<u32> = (0..base).collect();
693
694 for i in 1..base {
696 let j = rng.random_range(0..i);
697 perm.swap(i as usize, j as usize);
698 }
699
700 permutations.push(perm);
701 }
702
703 Ok(permutations)
704 }
705
706 fn compute_optimal_leap_values(bases: &[u32]) -> Vec<usize> {
708 bases
709 .iter()
710 .map(|&base| {
711 let mut leap = (base / 2) as usize;
713 while Self::gcd(leap, base as usize) != 1 {
714 leap += 1;
715 }
716 leap
717 })
718 .collect()
719 }
720
721 fn generate_generalized_permutations(
723 bases: &[u32],
724 seed: Option<u64>,
725 ) -> Result<Vec<Vec<u32>>> {
726 let mut rng = match seed {
727 Some(s) => StdRng::seed_from_u64(s),
728 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
729 };
730
731 let mut permutations = Vec::with_capacity(bases.len());
732
733 for &base in bases {
734 let mut perm: Vec<u32> = (0..base).collect();
735
736 for i in (1..base).rev() {
738 let j = rng.random_range(0..i);
739 perm.swap(i as usize, j as usize);
740 }
741
742 permutations.push(perm);
743 }
744
745 Ok(permutations)
746 }
747
748 fn normal_cdf(x: f64) -> f64 {
750 0.5 * (1.0 + Self::erf(x / std::f64::consts::SQRT_2))
751 }
752
753 fn erf(x: f64) -> f64 {
755 let a1 = 0.254829592;
757 let a2 = -0.284496736;
758 let a3 = 1.421413741;
759 let a4 = -1.453152027;
760 let a5 = 1.061405429;
761 let p = 0.3275911;
762
763 let sign = if x >= 0.0 { 1.0 } else { -1.0 };
764 let x = x.abs();
765
766 let t = 1.0 / (1.0 + p * x);
767 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
768
769 sign * y
770 }
771
772 fn radical_inverse(index: usize, base: u32) -> Result<f64> {
774 let mut result = 0.0;
775 let mut fraction = 1.0 / base as f64;
776 let mut i = index;
777
778 while i > 0 {
779 result += (i % base as usize) as f64 * fraction;
780 i /= base as usize;
781 fraction /= base as f64;
782 }
783
784 Ok(result)
785 }
786
787 fn scrambled_radical_inverse(index: usize, base: u32, permutation: &[u32]) -> Result<f64> {
788 let mut result = 0.0;
789 let mut fraction = 1.0 / base as f64;
790 let mut i = index;
791
792 while i > 0 {
793 let digit = i % base as usize;
794 let scrambled_digit = permutation[digit];
795 result += scrambled_digit as f64 * fraction;
796 i /= base as usize;
797 fraction /= base as f64;
798 }
799
800 Ok(result)
801 }
802
803 fn first_primes(n: usize) -> Result<Vec<u32>> {
804 let mut primes = Vec::with_capacity(n);
805 let mut candidate = 2u32;
806
807 while primes.len() < n {
808 if Self::is_prime(candidate) {
809 primes.push(candidate);
810 }
811 candidate += 1;
812 }
813
814 Ok(primes)
815 }
816
817 fn is_prime(n: u32) -> bool {
818 if n < 2 {
819 return false;
820 }
821 if n == 2 {
822 return true;
823 }
824 if n.is_multiple_of(2) {
825 return false;
826 }
827
828 let sqrt_n = (n as f64).sqrt() as u32;
829 for i in (3..=sqrt_n).step_by(2) {
830 if n.is_multiple_of(i) {
831 return false;
832 }
833 }
834 true
835 }
836
837 fn smallest_prime_geq(n: u32) -> Result<u32> {
838 let mut candidate = n;
839 while !Self::is_prime(candidate) {
840 candidate += 1;
841 }
842 Ok(candidate)
843 }
844
845 fn gcd(a: usize, b: usize) -> usize {
846 if b == 0 {
847 a
848 } else {
849 Self::gcd(b, a % b)
850 }
851 }
852
    /// Builds one 32x32 matrix of 0/1 entries per dimension.
    ///
    /// The first dimension gets the identity matrix. Every other dimension
    /// starts from its polynomial's coefficients, extends the matrix row by
    /// row with an XOR recurrence over the rows already written, and finally
    /// flips a dimension-dependent pattern of entries.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_primitive_polynomials`.
    fn generate_niederreiter_matrices(dimension: usize) -> Result<Vec<Array2<u32>>> {
        let mut matrices = Vec::with_capacity(dimension);

        // One coefficient vector per dimension.
        let polynomials = Self::get_primitive_polynomials(dimension)?;

        for (dim, polynomial) in polynomials.iter().enumerate().take(dimension) {
            // A coefficient vector of length n encodes a polynomial of degree n - 1.
            let degree = polynomial.len() - 1;
            let mut matrix = Array2::zeros((32, 32));

            if dim == 0 {
                // First dimension: identity matrix.
                for i in 0..32 {
                    matrix[[i, i]] = 1;
                }
            } else {
                // Seed block: each of the first `degree` rows holds the
                // polynomial coefficients after the leading one.
                for i in 0..degree.min(32) {
                    for j in 0..degree.min(32) {
                        if j < polynomial.len() - 1 {
                            matrix[[i, j]] = polynomial[j + 1];
                        }
                    }
                }

                // Remaining rows: combine the previous `degree` rows, weighted
                // by the polynomial coefficients, with the diagonal neighbour
                // of the row directly above. Rows must be filled in
                // increasing order because each one reads earlier rows.
                for i in degree..32 {
                    for j in 0..32 {
                        let mut value = 0u32;

                        for k in 1..=degree {
                            if i >= k && j < 32 {
                                value ^= polynomial[k] * matrix[[i - k, j]];
                            }
                        }

                        if j > 0 {
                            value ^= matrix[[i - 1, j - 1]];
                        }

                        // Keep the entry a single bit.
                        matrix[[i, j]] = value & 1;
                    }
                }

                // Flip every entry whose row + column + dimension is a
                // multiple of three.
                for i in 0..32 {
                    for j in 0..32 {
                        if (i + j + dim) % 3 == 0 {
                            matrix[[i, j]] ^= 1;
                        }
                    }
                }
            }

            matrices.push(matrix);
        }

        Ok(matrices)
    }
916
917 fn get_primitive_polynomials(dimension: usize) -> Result<Vec<Vec<u32>>> {
918 let primitive_polys = [
921 vec![1, 1, 1],
923 vec![1, 0, 1, 1],
925 vec![1, 0, 0, 1, 1],
927 vec![1, 0, 0, 1, 0, 1],
929 vec![1, 0, 0, 0, 0, 1, 1],
931 vec![1, 0, 0, 0, 1, 0, 0, 1],
933 vec![1, 0, 0, 0, 1, 1, 0, 1, 1],
935 vec![1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
937 vec![1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1],
939 vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
941 vec![1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1],
943 ];
944
945 let mut polynomials = Vec::with_capacity(dimension);
946
947 for i in 0..dimension {
948 if i < primitive_polys.len() {
949 polynomials.push(primitive_polys[i].clone());
950 } else {
951 let degree = 2 + (i % 10); let base_poly = &primitive_polys[degree.min(primitive_polys.len() - 1)];
954
955 let mut poly = base_poly.clone();
957 let variation = (i / 10) as u32;
958 for j in 1..poly.len() - 1 {
959 poly[j] ^= (variation >> j) & 1;
960 }
961
962 polynomials.push(poly);
963 }
964 }
965
966 Ok(polynomials)
967 }
968
969 fn generate_faure_permutations(
970 dimension: usize,
971 base: u32,
972 seed: Option<u64>,
973 ) -> Result<Vec<Array2<u32>>> {
974 let mut rng = match seed {
975 Some(s) => StdRng::seed_from_u64(s),
976 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
977 };
978
979 let mut matrices = Vec::with_capacity(dimension);
980 for _ in 0..dimension {
981 let mut matrix = Array2::zeros((base as usize, base as usize));
982 for i in 0..base as usize {
983 let j = rng.random_range(0..base as usize);
984 matrix[[i, j]] = 1;
985 }
986 matrices.push(matrix);
987 }
988 Ok(matrices)
989 }
990}
991
/// Draws points from the unit hypercube by splitting every axis into
/// equal-width strata and sampling inside each resulting cell.
pub struct StratifiedSampler {
    /// Sampling settings supplied at construction.
    config: StratifiedSamplingConfig,
    /// Number of coordinates in every generated point.
    dimension: usize,
    /// Per-stratum counters; created empty and not updated by the code
    /// visible in this file.
    #[allow(dead_code)]
    strata_counts: HashMap<Vec<usize>, usize>,
}
999
1000impl StratifiedSampler {
1001 pub fn new(dimension: usize, config: StratifiedSamplingConfig) -> Result<Self> {
1003 let handler = global_error_handler();
1004
1005 if dimension == 0 {
1006 return Err(handler
1007 .create_validation_error(
1008 ErrorCode::E1001,
1009 "StratifiedSampler::new",
1010 "_dimension",
1011 dimension,
1012 "Dimension must be positive",
1013 )
1014 .error);
1015 }
1016
1017 Ok(Self {
1018 config,
1019 dimension,
1020 strata_counts: HashMap::new(),
1021 })
1022 }
1023
1024 pub fn generate(&mut self, nsamples_: usize, seed: Option<u64>) -> Result<Array2<f64>> {
1026 let handler = global_error_handler();
1027
1028 if nsamples_ == 0 {
1029 return Err(handler
1030 .create_validation_error(
1031 ErrorCode::E1001,
1032 "StratifiedSampler::generate",
1033 "n_samples",
1034 nsamples_,
1035 "Number of samples must be positive",
1036 )
1037 .error);
1038 }
1039
1040 let total_strata = self.config.strata_per_dimension.pow(self.dimension as u32);
1041
1042 let base_samples_per_stratum = nsamples_ / total_strata;
1044 let remainder = nsamples_ % total_strata;
1045
1046 let mut samples = Array2::zeros((nsamples_, self.dimension));
1047 let mut sample_idx = 0;
1048
1049 let mut rng = match seed {
1050 Some(s) => StdRng::seed_from_u64(s),
1051 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
1052 };
1053
1054 for stratum_linear_idx in 0..total_strata {
1056 let stratum_indices = self.linear_to_multi_index(stratum_linear_idx);
1057
1058 let samples_in_stratum =
1059 base_samples_per_stratum + if stratum_linear_idx < remainder { 1 } else { 0 };
1060
1061 if samples_in_stratum < self.config.min_samples_per_stratum {
1062 continue;
1063 }
1064
1065 for _ in 0..samples_in_stratum {
1066 let point = self.sample_within_stratum(&stratum_indices, &mut rng)?;
1067 for (dim, &val) in point.iter().enumerate() {
1068 samples[[sample_idx, dim]] = val;
1069 }
1070 sample_idx += 1;
1071
1072 if sample_idx >= nsamples_ {
1073 break;
1074 }
1075 }
1076
1077 if sample_idx >= nsamples_ {
1078 break;
1079 }
1080 }
1081
1082 while sample_idx < nsamples_ {
1084 let random_stratum_idx = rng.random_range(0..total_strata);
1085 let stratum_indices = self.linear_to_multi_index(random_stratum_idx);
1086 let point = self.sample_within_stratum(&stratum_indices, &mut rng)?;
1087
1088 for (dim, &val) in point.iter().enumerate() {
1089 samples[[sample_idx, dim]] = val;
1090 }
1091 sample_idx += 1;
1092 }
1093
1094 Ok(samples)
1095 }
1096
1097 fn linear_to_multi_index(&self, linearidx: usize) -> Vec<usize> {
1099 let mut indices = Vec::with_capacity(self.dimension);
1100 let mut remaining = linearidx;
1101
1102 for _ in 0..self.dimension {
1103 indices.push(remaining % self.config.strata_per_dimension);
1104 remaining /= self.config.strata_per_dimension;
1105 }
1106
1107 indices
1108 }
1109
1110 fn sample_within_stratum(
1112 &self,
1113 stratum_indices: &[usize],
1114 rng: &mut StdRng,
1115 ) -> Result<Array1<f64>> {
1116 let mut point = Array1::zeros(self.dimension);
1117
1118 for (dim, &stratum_idx) in stratum_indices.iter().enumerate() {
1119 let stratum_width = 1.0 / self.config.strata_per_dimension as f64;
1120 let stratum_start = stratum_idx as f64 * stratum_width;
1121
1122 let sample_within_stratum = match self.config.intra_stratum_method {
1123 IntraStratumMethod::Random => stratum_start + rng.random::<f64>() * stratum_width,
1124 IntraStratumMethod::Centroid => stratum_start + 0.5 * stratum_width,
1125 IntraStratumMethod::QMC(_seq_type) => {
1126 stratum_start + (0.5 + 0.3 * (rng.random::<f64>() - 0.5)) * stratum_width
1128 }
1129 IntraStratumMethod::Antithetic => {
1130 if dim % 2 == 0 {
1131 stratum_start + rng.random::<f64>() * stratum_width
1132 } else {
1133 stratum_start + (1.0 - rng.random::<f64>()) * stratum_width
1134 }
1135 }
1136 };
1137
1138 point[dim] = sample_within_stratum.clamp(0.0, 1.0);
1139 }
1140
1141 Ok(point)
1142 }
1143}
1144
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless every entry of `samples` lies in the closed interval
    /// `[0, 1]`.
    fn assert_within_unit_interval(samples: &Array2<f64>) {
        for &value in samples.iter() {
            assert!((0.0..=1.0).contains(&value));
        }
    }

    /// Sobol points must have the requested shape and stay inside the unit
    /// square. Currently ignored because it fails.
    #[test]
    #[ignore = "Test failure - needs investigation"]
    fn test_advanced_qmc_sobol() {
        let mut generator = AdvancedQMCGenerator::new(QMCSequenceType::Sobol, 2, false, Some(42))
            .expect("Operation failed");

        let samples = generator.generate(100).expect("Operation failed");

        assert_eq!(samples.dim(), (100, 2));
        assert_within_unit_interval(&samples);
    }

    /// Stratified samples must have the requested shape and stay inside the
    /// unit square.
    #[test]
    fn test_stratified_sampler() {
        let config = StratifiedSamplingConfig {
            strata_per_dimension: 3,
            intra_stratum_method: IntraStratumMethod::Random,
            ..Default::default()
        };
        let mut sampler = StratifiedSampler::new(2, config).expect("Operation failed");

        let samples = sampler.generate(50, Some(42)).expect("Operation failed");

        assert_eq!(samples.dim(), (50, 2));
        assert_within_unit_interval(&samples);
    }

    /// Every coordinate of a Niederreiter sample should average close to
    /// one half. Currently ignored because it fails.
    #[test]
    #[ignore = "Test failure - needs investigation"]
    fn test_niederreiter_sequence() {
        let mut generator =
            AdvancedQMCGenerator::new(QMCSequenceType::Niederreiter, 3, false, Some(42))
                .expect("Operation failed");

        let samples = generator.generate(50).expect("Operation failed");
        assert_eq!(samples.dim(), (50, 3));

        for j in 0..3 {
            let column_mean = samples.column(j).mean();
            assert!((column_mean - 0.5).abs() < 0.2);
        }
    }
}