use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
use scirs2_core::{parallel_ops::*, simd_ops::SimdUnifiedOps, validation::*};
use std::marker::PhantomData;
pub struct EnhancedQMCGenerator<F> {
pub sequence_type: EnhancedSequenceType,
pub dimension: usize,
pub config: EnhancedQMCConfig,
pub state: QMCGeneratorState,
_phantom: PhantomData<F>,
}
#[derive(Debug, Clone, PartialEq)]
pub enum EnhancedSequenceType {
SobolAdvanced {
owen_scrambling: bool,
digital_shift: bool,
nested_scrambling: bool,
},
Niederreiter {
base_strategy: BaseSelectionStrategy,
matrix_optimization: bool,
},
FaureImproved {
permutation_optimization: bool,
radical_inverse_improvements: bool,
},
DigitalNet {
net_params: DigitalNetParams,
construction_method: NetConstructionMethod,
},
Hybrid {
primary: Box<EnhancedSequenceType>,
secondary: Box<EnhancedSequenceType>,
combination: HybridCombinationStrategy,
},
}
#[derive(Debug, Clone, PartialEq)]
pub enum BaseSelectionStrategy {
FirstPrimes,
OptimizedPrimes,
PrimePowers,
Automatic,
}
#[derive(Debug, Clone, PartialEq)]
pub struct DigitalNetParams {
pub t: usize,
pub m: usize,
pub s: usize,
pub base: usize,
}
#[derive(Debug, Clone, PartialEq)]
pub enum NetConstructionMethod {
Sobol,
NiederreiterXing,
PolynomialLattice,
FiniteField,
}
#[derive(Debug, Clone, PartialEq)]
pub enum HybridCombinationStrategy {
Interleave,
Weighted(f64),
DimensionAlternation,
Adaptive,
}
#[derive(Debug, Clone)]
pub struct EnhancedQMCConfig {
pub parallel: bool,
pub chunksize: usize,
pub seed: Option<u64>,
pub use_simd: bool,
pub quality_threshold: f64,
pub max_assessment_length: usize,
pub adaptive_refinement: bool,
}
impl Default for EnhancedQMCConfig {
fn default() -> Self {
Self {
parallel: true,
chunksize: 1000,
seed: None,
use_simd: true,
quality_threshold: 1e-3,
max_assessment_length: 10000,
adaptive_refinement: false,
}
}
}
#[derive(Debug, Clone)]
pub struct QMCGeneratorState {
pub current_index: usize,
pub scrambling_matrices: Option<Vec<Array2<u32>>>,
pub digital_shifts: Option<Vec<Array1<u32>>>,
pub quality_metrics: QualityMetrics,
}
#[derive(Debug, Clone, Default)]
pub struct QualityMetrics {
pub star_discrepancy: f64,
pub wraparound_discrepancy: f64,
pub diaphony: f64,
pub figure_of_merit: f64,
}
impl<F> EnhancedQMCGenerator<F>
where
F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
for<'a> &'a F: std::iter::Product<&'a F>,
{
pub fn new(
sequence_type: EnhancedSequenceType,
dimension: usize,
config: EnhancedQMCConfig,
) -> StatsResult<Self> {
check_positive(dimension, "dimension")?;
if dimension > 1000 {
return Err(StatsError::InvalidArgument(
"Dimension cannot exceed 1000 for enhanced QMC sequences".to_string(),
));
}
let state = QMCGeneratorState {
current_index: 0,
scrambling_matrices: None,
digital_shifts: None,
quality_metrics: QualityMetrics::default(),
};
let mut generator = Self {
sequence_type,
dimension,
config,
state,
_phantom: PhantomData,
};
generator.initialize_randomization()?;
Ok(generator)
}
pub fn generate(&mut self, n: usize) -> StatsResult<Array2<F>> {
check_positive(n, "n")?;
if self.config.parallel && n >= self.config.chunksize {
self.generate_parallel(n)
} else {
self.generate_sequential(n)
}
}
fn generate_parallel(&mut self, n: usize) -> StatsResult<Array2<F>> {
let chunksize = self.config.chunksize;
let num_chunks = n.div_ceil(chunksize);
let chunks = parallel_map_result(
(0..num_chunks).collect::<Vec<_>>().as_slice(),
|&chunk_idx| {
let start = chunk_idx * chunksize;
let end = (start + chunksize).min(n);
let chunksize = end - start;
self.generate_chunk(start, chunksize)
},
)?;
let mut result = Array2::zeros((n, self.dimension));
let mut row_idx = 0;
for chunk in chunks {
let chunk = chunk;
let chunk_rows = chunk.nrows();
result
.slice_mut(scirs2_core::ndarray::s![row_idx..row_idx + chunk_rows, ..])
.assign(&chunk);
row_idx += chunk_rows;
}
if n <= self.config.max_assessment_length {
self.assess_quality(&result)?;
}
Ok(result)
}
fn generate_sequential(&mut self, n: usize) -> StatsResult<Array2<F>> {
let mut result = Array2::zeros((n, self.dimension));
for i in 0..n {
let point = self.next_point()?;
result.row_mut(i).assign(&point);
}
if n <= self.config.max_assessment_length {
self.assess_quality(&result)?;
}
Ok(result)
}
fn generate_chunk(&self, start_index: usize, chunksize: usize) -> StatsResult<Array2<F>> {
let mut chunk = Array2::zeros((chunksize, self.dimension));
for i in 0..chunksize {
let _index = start_index + i;
let point = self.compute_point_at_index(_index)?;
chunk.row_mut(i).assign(&point);
}
Ok(chunk)
}
fn next_point(&mut self) -> StatsResult<Array1<F>> {
let point = self.compute_point_at_index(self.state.current_index)?;
self.state.current_index += 1;
Ok(point)
}
fn compute_point_at_index(&self, index: usize) -> StatsResult<Array1<F>> {
match &self.sequence_type {
EnhancedSequenceType::SobolAdvanced {
owen_scrambling,
digital_shift,
nested_scrambling,
} => self.compute_sobol_advanced(
index,
*owen_scrambling,
*digital_shift,
*nested_scrambling,
),
EnhancedSequenceType::Niederreiter {
base_strategy,
matrix_optimization,
} => self.compute_niederreiter_enhanced(index, base_strategy, *matrix_optimization),
EnhancedSequenceType::FaureImproved {
permutation_optimization,
radical_inverse_improvements,
} => self.compute_faure_improved(
index,
*permutation_optimization,
*radical_inverse_improvements,
),
EnhancedSequenceType::DigitalNet {
net_params,
construction_method,
} => self.compute_digital_net(index, net_params, construction_method),
EnhancedSequenceType::Hybrid {
primary,
secondary,
combination,
} => self.compute_hybrid_sequence(index, primary, secondary, combination),
}
}
fn compute_sobol_advanced(
&self,
index: usize,
owen_scrambling: bool,
digital_shift: bool,
_nested_scrambling: bool,
) -> StatsResult<Array1<F>> {
let mut point = Array1::zeros(self.dimension);
for dim in 0..self.dimension {
let mut result = 0u32;
let mut idx = index;
let mut base_power = 1u32;
while idx > 0 {
if idx & 1 == 1 {
result ^= base_power << (31 - (base_power.trailing_zeros()));
}
idx >>= 1;
base_power <<= 1;
}
if owen_scrambling {
if let Some(ref matrices) = self.state.scrambling_matrices {
if dim < matrices.len() {
result = self.apply_owen_scrambling(result, &matrices[dim]);
}
}
}
if digital_shift {
if let Some(ref shifts) = self.state.digital_shifts {
if dim < shifts.len() {
result ^= shifts[dim][0]; }
}
}
point[dim] = F::from(result as f64 / (1u64 << 32) as f64).expect("Operation failed");
}
Ok(point)
}
fn compute_niederreiter_enhanced(
&self,
index: usize,
base_strategy: &BaseSelectionStrategy,
_matrix_optimization: bool,
) -> StatsResult<Array1<F>> {
let mut point = Array1::zeros(self.dimension);
for dim in 0..self.dimension {
let base = self.get_prime(dim);
point[dim] = F::from(self.radical_inverse(index, base)).expect("Operation failed");
}
Ok(point)
}
fn compute_faure_improved(
&self,
index: usize,
_permutation_optimization: bool,
_radical_inverse_improvements: bool,
) -> StatsResult<Array1<F>> {
let mut point = Array1::zeros(self.dimension);
let base = self.smallest_prime_geq(self.dimension as u32);
for dim in 0..self.dimension {
point[dim] = F::from(self.radical_inverse(index, base)).expect("Operation failed");
}
Ok(point)
}
fn compute_digital_net(
&self,
index: usize,
_net_params: &DigitalNetParams,
_construction_method: &NetConstructionMethod,
) -> StatsResult<Array1<F>> {
self.compute_sobol_advanced(index, false, false, false)
}
fn compute_hybrid_sequence(
&self,
index: usize,
_primary: &EnhancedSequenceType,
_secondary: &EnhancedSequenceType,
_combination: &HybridCombinationStrategy,
) -> StatsResult<Array1<F>> {
self.compute_sobol_advanced(index, true, true, false)
}
fn initialize_randomization(&mut self) -> StatsResult<()> {
let mut rng = match self.config.seed {
Some(seed) => StdRng::seed_from_u64(seed),
None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
};
if self.needs_scrambling() {
let mut matrices = Vec::with_capacity(self.dimension);
for _ in 0..self.dimension {
matrices.push(self.generate_scrambling_matrix(&mut rng)?);
}
self.state.scrambling_matrices = Some(matrices);
}
if self.needs_digital_shift() {
let mut shifts = Vec::with_capacity(self.dimension);
for _ in 0..self.dimension {
let shift = Array1::from_shape_fn(32, |_| rng.random::<u32>());
shifts.push(shift);
}
self.state.digital_shifts = Some(shifts);
}
Ok(())
}
fn needs_scrambling(&self) -> bool {
match &self.sequence_type {
EnhancedSequenceType::SobolAdvanced {
owen_scrambling, ..
} => *owen_scrambling,
_ => false,
}
}
fn needs_digital_shift(&self) -> bool {
match &self.sequence_type {
EnhancedSequenceType::SobolAdvanced { digital_shift, .. } => *digital_shift,
_ => false,
}
}
fn generate_scrambling_matrix<R: Rng>(&self, rng: &mut R) -> StatsResult<Array2<u32>> {
let mut matrix = Array2::zeros((32, 32));
for i in 0..32 {
let j = rng.random_range(0..32);
matrix[[i, j]] = 1;
}
Ok(matrix)
}
fn apply_owen_scrambling(&self, value: u32, matrix: &Array2<u32>) -> u32 {
let mut result = 0u32;
for i in 0..32 {
let bit = (value >> (31 - i)) & 1;
for j in 0..32 {
if matrix[[i, j]] == 1 && bit == 1 {
result |= 1u32 << (31 - j);
break;
}
}
}
result
}
fn radical_inverse(&self, index: usize, base: u32) -> f64 {
let mut result = 0.0;
let mut fraction = 1.0 / base as f64;
let mut i = index;
while i > 0 {
result += (i % base as usize) as f64 * fraction;
i /= base as usize;
fraction /= base as f64;
}
result
}
fn get_prime(&self, n: usize) -> u32 {
let primes = [
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
];
if n < primes.len() {
primes[n]
} else {
let mut candidate = primes[primes.len() - 1] + 2;
let mut count = primes.len();
while count <= n {
if self.is_prime(candidate) {
if count == n {
return candidate;
}
count += 1;
}
candidate += 2;
}
candidate
}
}
fn smallest_prime_geq(&self, n: u32) -> u32 {
if n <= 2 {
return 2;
}
let mut candidate = if n.is_multiple_of(2) { n + 1 } else { n };
while !self.is_prime(candidate) {
candidate += 2;
}
candidate
}
fn is_prime(&self, n: u32) -> bool {
if n < 2 {
return false;
}
if n == 2 {
return true;
}
if n.is_multiple_of(2) {
return false;
}
let sqrt_n = (n as f64).sqrt() as u32;
for i in (3..=sqrt_n).step_by(2) {
if n.is_multiple_of(i) {
return false;
}
}
true
}
fn assess_quality(&mut self, sequence: &Array2<F>) -> StatsResult<()> {
let n = sequence.nrows();
let d = sequence.ncols();
let mut max_discrepancy = 0.0;
let num_test_points = 50.min(n);
let mut rng = scirs2_core::random::thread_rng();
for _ in 0..num_test_points {
let mut test_point = Array1::zeros(d);
for j in 0..d {
test_point[j] = F::from(rng.random::<f64>()).expect("Operation failed");
}
let mut count = 0;
for i in 0..n {
let mut in_box = true;
for j in 0..d {
if sequence[[i, j]] > test_point[j] {
in_box = false;
break;
}
}
if in_box {
count += 1;
}
}
let volume: F = test_point.iter().fold(F::one(), |acc, &x| acc * x);
let expected = volume.to_f64().expect("Operation failed") * n as f64;
let discrepancy = (count as f64 - expected).abs() / n as f64;
max_discrepancy = max_discrepancy.max(discrepancy);
}
self.state.quality_metrics.star_discrepancy = max_discrepancy;
Ok(())
}
pub fn quality_metrics(&self) -> &QualityMetrics {
&self.state.quality_metrics
}
}
#[allow(dead_code)]
pub fn enhanced_sobol<F>(
n: usize,
dimension: usize,
scrambling: bool,
seed: Option<u64>,
) -> StatsResult<Array2<F>>
where
F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
for<'a> &'a F: std::iter::Product<&'a F>,
{
let sequence_type = EnhancedSequenceType::SobolAdvanced {
owen_scrambling: scrambling,
digital_shift: true,
nested_scrambling: false,
};
let config = EnhancedQMCConfig {
seed,
..Default::default()
};
let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
generator.generate(n)
}
#[allow(dead_code)]
pub fn enhanced_niederreiter<F>(
n: usize,
dimension: usize,
seed: Option<u64>,
) -> StatsResult<Array2<F>>
where
F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
for<'a> &'a F: std::iter::Product<&'a F>,
{
let sequence_type = EnhancedSequenceType::Niederreiter {
base_strategy: BaseSelectionStrategy::OptimizedPrimes,
matrix_optimization: true,
};
let config = EnhancedQMCConfig {
seed,
..Default::default()
};
let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
generator.generate(n)
}
#[allow(dead_code)]
pub fn enhanced_digital_net<F>(
n: usize,
dimension: usize,
t: usize,
seed: Option<u64>,
) -> StatsResult<Array2<F>>
where
F: Float + Zero + One + Copy + Send + Sync + SimdUnifiedOps + FromPrimitive + std::fmt::Display,
for<'a> &'a F: std::iter::Product<&'a F>,
{
let net_params = DigitalNetParams {
t,
m: 32,
s: dimension,
base: 2,
};
let sequence_type = EnhancedSequenceType::DigitalNet {
net_params,
construction_method: NetConstructionMethod::Sobol,
};
let config = EnhancedQMCConfig {
seed,
..Default::default()
};
let mut generator = EnhancedQMCGenerator::new(sequence_type, dimension, config)?;
generator.generate(n)
}