use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
use scirs2_core::validation::*;
/// Generates the first `n` points of a `d`-dimensional Sobol sequence.
///
/// # Arguments
/// * `n` - number of points to generate; validated with `check_positive`.
/// * `d` - number of dimensions; validated with `check_positive` and capped at 32.
/// * `scramble` - forwarded to [`SobolSequence::new`] to enable scrambling.
/// * `seed` - optional seed forwarded to [`SobolSequence::new`].
///
/// # Errors
/// Returns an error when `n` or `d` fails validation, when `d > 32`, or when
/// constructing or running the underlying [`SobolSequence`] fails.
#[allow(dead_code)]
pub fn sobol(n: usize, d: usize, scramble: bool, seed: Option<u64>) -> StatsResult<Array2<f64>> {
    check_positive(n, "n")?;
    check_positive(d, "d")?;
    if d <= 32 {
        SobolSequence::new(d, scramble, seed)?.generate(n)
    } else {
        Err(StatsError::InvalidArgument(
            "Dimension cannot exceed 32 for Sobol sequence".to_string(),
        ))
    }
}
/// Generates the first `n` points of a `d`-dimensional Halton sequence.
///
/// # Arguments
/// * `n` - number of points to generate; validated with `check_positive`.
/// * `d` - number of dimensions; validated with `check_positive` and capped at 100.
/// * `scramble` - forwarded to [`HaltonSequence::new`] to enable digit scrambling.
/// * `seed` - optional seed forwarded to [`HaltonSequence::new`].
///
/// # Errors
/// Returns an error when `n` or `d` fails validation, when `d > 100`, or when
/// constructing or running the underlying [`HaltonSequence`] fails.
#[allow(dead_code)]
pub fn halton(n: usize, d: usize, scramble: bool, seed: Option<u64>) -> StatsResult<Array2<f64>> {
    check_positive(n, "n")?;
    check_positive(d, "d")?;
    if d <= 100 {
        HaltonSequence::new(d, scramble, seed)?.generate(n)
    } else {
        Err(StatsError::InvalidArgument(
            "Dimension cannot exceed 100 for Halton sequence".to_string(),
        ))
    }
}
/// Draws an `n`-point Latin hypercube sample in `d` dimensions.
///
/// For every dimension the unit interval is split into `n` equal strata, the
/// strata are assigned to the `n` points by a random permutation, and each
/// point is placed uniformly inside its stratum. The result has shape `(n, d)`.
///
/// # Arguments
/// * `n` - number of points; validated with `check_positive`.
/// * `d` - number of dimensions; validated with `check_positive`.
/// * `seed` - seed for the random generator. When `None`, the generator is
///   seeded from the wall clock in whole seconds, so two unseeded calls within
///   the same second draw the same stream.
///
/// # Errors
/// Returns an error when `n` or `d` fails validation.
#[allow(dead_code)]
pub fn latin_hypercube(n: usize, d: usize, seed: Option<u64>) -> StatsResult<Array2<f64>> {
    check_positive(n, "n")?;
    check_positive(d, "d")?;
    let mut rng = match seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => {
            use std::time::{SystemTime, UNIX_EPOCH};
            let s = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            StdRng::seed_from_u64(s)
        }
    };
    let mut samples = Array2::zeros((n, d));
    // One scratch buffer is reused for every dimension's stratum permutation.
    let mut intervals: Vec<usize> = Vec::with_capacity(n);
    for dim in 0..d {
        intervals.clear();
        intervals.extend(0..n);
        // Fisher-Yates shuffle. The partner index must be drawn from the
        // inclusive range `0..=i`: drawing from `0..i` only ever produces
        // cyclic permutations, so point `i` could never land in stratum `i`.
        for i in (1..n).rev() {
            let j = rng.random_range(0..=i);
            intervals.swap(i, j);
        }
        for (i, &interval) in intervals.iter().enumerate() {
            // Jitter the point uniformly inside its assigned stratum.
            let u: f64 = rng.random();
            samples[[i, dim]] = (interval as f64 + u) / n as f64;
        }
    }
    Ok(samples)
}
/// Stateful generator for a Sobol-style quasi-random sequence.
///
/// The generator remembers how many points it has already emitted, so
/// successive calls continue the sequence instead of restarting it.
pub struct SobolSequence {
/// Number of coordinates in every generated point.
dimension: usize,
/// One vector of 32 direction numbers per dimension.
direction_numbers: Vec<Vec<u32>>,
/// Index of the next point to emit.
current_index: usize,
/// Whether scrambling was requested at construction time.
#[allow(dead_code)]
scramble: bool,
/// One 32x32 scramble matrix per dimension when scrambling is enabled,
/// `None` otherwise.
scramble_matrices: Option<Vec<Array2<u32>>>,
}
impl SobolSequence {
/// Creates a Sobol generator positioned at the start of the sequence.
///
/// # Arguments
/// * `dimension` - number of coordinates per point; must lie in `1..=32`.
/// * `scramble` - when `true`, scramble matrices are generated and applied
///   to every generated value.
/// * `seed` - optional seed used only for the scramble matrices.
///
/// # Errors
/// Returns `StatsError::InvalidArgument` when `dimension` is outside
/// `1..=32`, and propagates any error from the internal initializers.
pub fn new(dimension: usize, scramble: bool, seed: Option<u64>) -> StatsResult<Self> {
    if !(1..=32).contains(&dimension) {
        return Err(StatsError::InvalidArgument(
            "Dimension must be between 1 and 32".to_string(),
        ));
    }
    let direction_numbers = Self::initialize_direction_numbers(dimension)?;
    // Scramble matrices are only built (and the RNG only seeded) on request.
    let scramble_matrices = scramble
        .then(|| Self::generate_scramble_matrices(dimension, seed))
        .transpose()?;
    Ok(Self {
        dimension,
        direction_numbers,
        current_index: 0,
        scramble,
        scramble_matrices,
    })
}
/// Generates the next `n` points of the sequence as an `(n, dimension)` array.
///
/// Requests with `n >= 64` and at most 16 dimensions are forwarded to
/// [`Self::generate_simd_ultra`]; everything else is produced one point at a
/// time through [`Self::next_point`]. Either way the generator's position
/// advances by `n`.
///
/// # Errors
/// Propagates any error from the underlying point generation.
pub fn generate(&mut self, n: usize) -> StatsResult<Array2<f64>> {
    // Dispatch before allocating: the batch path builds its own output array,
    // so allocating one here first would only be thrown away.
    if n >= 64 && self.dimension <= 16 {
        return self.generate_simd_ultra(n);
    }
    let mut samples = Array2::zeros((n, self.dimension));
    for i in 0..n {
        let point = self.next_point()?;
        for (j, &val) in point.iter().enumerate() {
            samples[[i, j]] = val;
        }
    }
    Ok(samples)
}
/// Batch variant of [`Self::generate`]: produces the next `n` points as an
/// `(n, dimension)` array and advances the generator by `n`.
///
/// Each value is computed exactly as in [`Self::next_point`], so the output
/// does not depend on the batch size or on the CPU the code runs on. Despite
/// the name, no explicit SIMD instructions are used.
///
/// # Errors
/// Currently always returns `Ok`; the `StatsResult` return type is kept for
/// interface compatibility.
pub fn generate_simd_ultra(&mut self, n: usize) -> StatsResult<Array2<f64>> {
    let mut samples = Array2::zeros((n, self.dimension));
    for row in 0..n {
        let index = self.current_index + row;
        for dim in 0..self.dimension {
            // XOR together the direction numbers selected by the set bits of
            // the point index (only the low 32 bits participate).
            let mut bits = 0u32;
            for (bit, &direction) in self.direction_numbers[dim].iter().enumerate().take(32) {
                if (index >> bit) & 1 == 1 {
                    bits ^= direction;
                }
            }
            if let Some(ref matrices) = self.scramble_matrices {
                bits = Self::apply_scrambling(bits, &matrices[dim]);
            }
            // Stay in f64 and write straight to (row, dim). Rounding through
            // f32 would drop the low bits and can round values just below one
            // up to exactly 1.0; staging values in a flat buffer risks mixing
            // up sample-major and dimension-major layouts.
            samples[[row, dim]] = bits as f64 / (1u64 << 32) as f64;
        }
    }
    self.current_index += n;
    Ok(samples)
}
/// Returns the next point of the sequence and advances the generator by one.
///
/// For each dimension the direction numbers selected by the set bits of the
/// current index (low 32 bits) are XOR-ed together, optionally scrambled, and
/// divided by 2^32 to obtain a value in `[0, 1)`.
///
/// # Errors
/// Currently always returns `Ok`.
pub fn next_point(&mut self) -> StatsResult<Array1<f64>> {
    let index = self.current_index;
    let scale = (1u64 << 32) as f64;
    let mut point = Array1::zeros(self.dimension);
    for dim in 0..self.dimension {
        let directions = &self.direction_numbers[dim];
        let raw = (0..32)
            .filter(|&bit| (index >> bit) & 1 == 1)
            .fold(0u32, |acc, bit| acc ^ directions[bit]);
        let bits = match self.scramble_matrices {
            Some(ref matrices) => Self::apply_scrambling(raw, &matrices[dim]),
            None => raw,
        };
        point[dim] = bits as f64 / scale;
    }
    self.current_index += 1;
    Ok(point)
}
/// Builds the table of 32 direction numbers for each of `dimension` dimensions.
///
/// Dimension 0 uses the single-bit values `1 << (31 - i)`. Every further
/// dimension picks a `(degree, taps)` entry from a fixed eight-entry table,
/// seeds its first `degree` numbers, and fills the rest with an XOR recurrence
/// over earlier entries.
///
/// NOTE(review): the table is indexed with `(dim - 1) % 8`, so dimensions 9 and
/// above reuse the entries of dimensions 1-8 and end up with identical
/// direction numbers. Confirm whether a full direction-number table (e.g.
/// Joe-Kuo) was intended.
///
/// # Panics
/// Panics when `dimension` is 0, because dimension 0 is filled unconditionally.
fn initialize_direction_numbers(dimension: usize) -> StatsResult<Vec<Vec<u32>>> {
    const BITS: usize = 32;
    // (degree, taps) pairs consumed cyclically by dimensions 1, 2, 3, ...
    const POLYNOMIALS: [(usize, &[usize]); 8] = [
        (1, &[]),
        (2, &[1]),
        (3, &[1, 3]),
        (3, &[2, 3]),
        (4, &[1, 4]),
        (4, &[3, 4]),
        (4, &[1, 2, 4]),
        (4, &[1, 3, 4]),
    ];
    let mut table = vec![vec![0u32; BITS]; dimension];
    for (i, slot) in table[0].iter_mut().enumerate() {
        *slot = 1u32 << (31 - i);
    }
    for (dim, row) in table.iter_mut().enumerate().skip(1) {
        let (degree, taps) = POLYNOMIALS[(dim - 1) % POLYNOMIALS.len()];
        // Seed values: bit for position `i` plus the bit for position `degree`.
        for i in 0..degree {
            row[i] = (1u32 << (31 - i)) | (1u32 << (31 - degree));
        }
        // Recurrence: combine the entry `degree` steps back with its shifted
        // copy and with every tap that does not reach before the row start.
        for i in degree..BITS {
            let back = row[i - degree];
            let feedback = taps
                .iter()
                .filter(|&&tap| tap <= i)
                .fold(0u32, |acc, &tap| acc ^ row[i - tap]);
            row[i] = back ^ (back >> degree) ^ feedback;
        }
    }
    Ok(table)
}
/// Builds one 32x32 scramble matrix per dimension.
///
/// Every row of every matrix receives exactly one entry set to 1, in a column
/// drawn uniformly from `0..32`. When `seed` is `None` the generator is seeded
/// from the wall clock in whole seconds.
///
/// NOTE(review): columns are drawn independently per row, so two rows may pick
/// the same column and the resulting bit mapping need not be one-to-one —
/// confirm this is the intended scrambling scheme.
fn generate_scramble_matrices(
    dimension: usize,
    seed: Option<u64>,
) -> StatsResult<Vec<Array2<u32>>> {
    let seed_value = seed.unwrap_or_else(|| {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    });
    let mut rng = StdRng::seed_from_u64(seed_value);
    let matrices = (0..dimension)
        .map(|_| {
            let mut matrix: Array2<u32> = Array2::zeros((32, 32));
            for row in 0..32 {
                let col: usize = rng.random_range(0..32);
                matrix[[row, col]] = 1;
            }
            matrix
        })
        .collect();
    Ok(matrices)
}
/// Remaps the bits of `value` according to `matrix`.
///
/// Bits are numbered from the most significant end. For every set bit `i` of
/// `value`, the first column `j` with `matrix[[i, j]] == 1` is looked up and
/// bit `j` is set in the result; rows without such a column contribute nothing.
fn apply_scrambling(value: u32, matrix: &Array2<u32>) -> u32 {
    (0..32usize)
        .filter(|&row| (value >> (31 - row)) & 1 == 1)
        .filter_map(|row| (0..32usize).find(|&col| matrix[[row, col]] == 1))
        .fold(0u32, |acc, col| acc | (1u32 << (31 - col)))
}
}
/// Stateful generator for a Halton quasi-random sequence.
///
/// The generator remembers how many points it has already emitted, so
/// successive calls continue the sequence instead of restarting it.
pub struct HaltonSequence {
/// Number of coordinates in every generated point.
dimension: usize,
/// Radical-inverse base for each dimension (the first `dimension` primes).
bases: Vec<u32>,
/// Index of the next point to emit.
current_index: usize,
/// Whether digit scrambling is applied.
scramble: bool,
/// One digit permutation per base when scrambling is enabled, `None` otherwise.
permutations: Option<Vec<Vec<u32>>>,
}
impl HaltonSequence {
/// Creates a Halton generator positioned at the start of the sequence.
///
/// # Arguments
/// * `dimension` - number of coordinates per point; must be at least 1.
/// * `scramble` - when `true`, a random digit permutation is generated for
///   every base and applied to every generated value.
/// * `seed` - optional seed used only for the digit permutations.
///
/// # Errors
/// Returns `StatsError::InvalidArgument` when `dimension` is 0, and
/// propagates any error from the internal initializers.
pub fn new(dimension: usize, scramble: bool, seed: Option<u64>) -> StatsResult<Self> {
    if dimension == 0 {
        return Err(StatsError::InvalidArgument(
            "Dimension must be at least 1".to_string(),
        ));
    }
    let bases = Self::first_primes(dimension)?;
    // Permutations are only built (and the RNG only seeded) on request.
    let permutations = scramble
        .then(|| Self::generate_permutations(&bases, seed))
        .transpose()?;
    Ok(Self {
        dimension,
        bases,
        current_index: 0,
        scramble,
        permutations,
    })
}
/// Generates the next `n` points of the sequence as an `(n, dimension)` array.
///
/// Requests with `n >= 64` and at most 32 dimensions are forwarded to
/// [`Self::generate_halton_simd_ultra`]; everything else is produced one point
/// at a time through [`Self::next_point`]. Either way the generator's position
/// advances by `n`.
///
/// # Errors
/// Propagates any error from the underlying point generation.
pub fn generate(&mut self, n: usize) -> StatsResult<Array2<f64>> {
    // Dispatch before allocating: the batch path builds its own output array,
    // so allocating one here first would only be thrown away.
    if n >= 64 && self.dimension <= 32 {
        return self.generate_halton_simd_ultra(n);
    }
    let mut samples = Array2::zeros((n, self.dimension));
    for i in 0..n {
        let point = self.next_point()?;
        for (j, &val) in point.iter().enumerate() {
            samples[[i, j]] = val;
        }
    }
    Ok(samples)
}
/// Batch variant of [`Self::generate`]: produces the next `n` points as an
/// `(n, dimension)` array and advances the generator by `n`.
///
/// Each value is computed with the same radical-inverse routines as
/// [`Self::next_point`], so the output does not depend on the batch size or on
/// the CPU the code runs on. Despite the name, no explicit SIMD instructions
/// are used.
///
/// # Errors
/// Propagates any error from the radical-inverse helpers.
///
/// # Panics
/// Panics if scrambling is enabled but no permutations are stored; `new`
/// always stores them together, so this indicates a broken invariant.
pub fn generate_halton_simd_ultra(&mut self, n: usize) -> StatsResult<Array2<f64>> {
    let mut samples = Array2::zeros((n, self.dimension));
    let permutations = if self.scramble {
        Some(
            self.permutations
                .as_ref()
                .expect("a scrambled HaltonSequence always stores its permutations"),
        )
    } else {
        None
    };
    for dim in 0..self.dimension {
        let base = self.bases[dim];
        let permutation = permutations.map(|all| all[dim].as_slice());
        for row in 0..n {
            let index = self.current_index + row;
            // Values stay in f64 end to end: rounding through f32 would make
            // the result depend on which code path produced it and can round
            // values just below one up to exactly 1.0.
            samples[[row, dim]] = match permutation {
                Some(digits) => Self::scrambled_radical_inverse(index, base, digits)?,
                None => Self::radical_inverse(index, base)?,
            };
        }
    }
    self.current_index += n;
    Ok(samples)
}
/// Computes the radical inverse of `index` in the given `base`.
///
/// The base-`base` digits of `index` are mirrored around the radix point:
/// the least significant digit is weighted by `1/base`, the next by
/// `1/base^2`, and so on. An `index` of 0 yields `0.0`.
///
/// # Panics
/// Panics on division by zero when `base` is 0 and `index` is non-zero.
fn radical_inverse_simd_ultra(index: usize, base: u32) -> StatsResult<f64> {
    let radix = base as usize;
    let radix_f = base as f64;
    let mut weight = 1.0 / radix_f;
    let mut remaining = index;
    let mut total = 0.0;
    while remaining > 0 {
        let digit = remaining % radix;
        remaining /= radix;
        total += digit as f64 * weight;
        weight /= radix_f;
    }
    Ok(total)
}
/// Computes the radical inverse of `index` in `base`, replacing every digit
/// `d` by `permutation[d]` before it is weighted.
///
/// Only the digits actually present in `index` are processed, so an `index`
/// of 0 yields `0.0` regardless of the permutation.
///
/// # Panics
/// Panics when a digit has no entry in `permutation`, or on division by zero
/// when `base` is 0 and `index` is non-zero.
fn scrambled_radical_inverse_simd_ultra(
    index: usize,
    base: u32,
    permutation: &[u32],
) -> StatsResult<f64> {
    let radix = base as usize;
    let radix_f = base as f64;
    let mut weight = 1.0 / radix_f;
    let mut remaining = index;
    let mut total = 0.0;
    while remaining > 0 {
        let mapped = permutation[remaining % radix];
        total += mapped as f64 * weight;
        remaining /= radix;
        weight /= radix_f;
    }
    Ok(total)
}
/// Returns the next point of the sequence and advances the generator by one.
///
/// Coordinate `dim` is the radical inverse of the current index in base
/// `bases[dim]`, with the digit permutation applied when scrambling is on.
///
/// # Errors
/// Propagates any error from the radical-inverse helpers.
///
/// # Panics
/// Panics if scrambling is enabled but no permutations are stored.
pub fn next_point(&mut self) -> StatsResult<Array1<f64>> {
    let index = self.current_index;
    let mut point = Array1::zeros(self.dimension);
    for dim in 0..self.dimension {
        let base = self.bases[dim];
        point[dim] = if self.scramble {
            let permutation =
                self.permutations.as_ref().expect("Operation failed")[dim].as_slice();
            Self::scrambled_radical_inverse(index, base, permutation)?
        } else {
            Self::radical_inverse(index, base)?
        };
    }
    self.current_index += 1;
    Ok(point)
}
/// Computes the radical inverse of `index` in the given `base`.
///
/// The base-`base` digits of `index` are reflected about the radix point,
/// least significant digit first, giving a value in `[0, 1)`. An `index` of 0
/// yields `0.0`.
///
/// # Panics
/// Panics on division by zero when `base` is 0 and `index` is non-zero.
fn radical_inverse(index: usize, base: u32) -> StatsResult<f64> {
    let radix = base as usize;
    let radix_f = base as f64;
    let mut remaining = index;
    let mut weight = 1.0 / radix_f;
    let mut acc = 0.0;
    loop {
        if remaining == 0 {
            break Ok(acc);
        }
        acc += (remaining % radix) as f64 * weight;
        remaining /= radix;
        weight /= radix_f;
    }
}
/// Computes the radical inverse of `index` in `base` after mapping every
/// digit `d` to `permutation[d]`.
///
/// Only the digits actually present in `index` are processed, so an `index`
/// of 0 yields `0.0` regardless of the permutation.
///
/// # Panics
/// Panics when a digit has no entry in `permutation`, or on division by zero
/// when `base` is 0 and `index` is non-zero.
fn scrambled_radical_inverse(index: usize, base: u32, permutation: &[u32]) -> StatsResult<f64> {
    let radix = base as usize;
    let radix_f = base as f64;
    let mut remaining = index;
    let mut weight = 1.0 / radix_f;
    let mut acc = 0.0;
    loop {
        if remaining == 0 {
            break Ok(acc);
        }
        let mapped = permutation[remaining % radix];
        acc += mapped as f64 * weight;
        remaining /= radix;
        weight /= radix_f;
    }
}
/// Returns the first `n` prime numbers in ascending order, starting at 2.
///
/// An `n` of 0 yields an empty vector. Candidates are tested one by one with
/// [`Self::is_prime`].
fn first_primes(n: usize) -> StatsResult<Vec<u32>> {
    let mut primes = Vec::with_capacity(n);
    primes.extend((2u32..).filter(|&candidate| Self::is_prime(candidate)).take(n));
    Ok(primes)
}
/// Reports whether `n` is prime, using trial division by odd numbers up to
/// the integer square root of `n`.
fn is_prime(n: u32) -> bool {
    match n {
        0 | 1 => false,
        2 => true,
        _ if n.is_multiple_of(2) => false,
        _ => {
            // Any composite odd number has an odd factor no larger than its
            // square root, so the search can stop there.
            let limit = (n as f64).sqrt() as u32;
            !(3..=limit).step_by(2).any(|divisor| n.is_multiple_of(divisor))
        }
    }
}
/// Builds one random digit permutation per base.
///
/// For a base `b` the permutation is a uniformly shuffled copy of
/// `0, 1, ..., b - 1`. When `seed` is `None` the generator is seeded from the
/// wall clock in whole seconds.
///
/// # Arguments
/// * `bases` - the radical-inverse bases, one per dimension.
/// * `seed` - optional seed for reproducible permutations.
fn generate_permutations(bases: &[u32], seed: Option<u64>) -> StatsResult<Vec<Vec<u32>>> {
    let mut rng = match seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => {
            use std::time::{SystemTime, UNIX_EPOCH};
            let s = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            StdRng::seed_from_u64(s)
        }
    };
    let mut permutations = Vec::with_capacity(bases.len());
    for &base in bases {
        let mut perm: Vec<u32> = (0..base).collect();
        // Fisher-Yates shuffle. The partner index must come from the inclusive
        // range `0..=i`: drawing from `0..i` yields only cyclic permutations,
        // in which no digit can ever map to itself (base 2 would always give
        // `[1, 0]`).
        for i in (1..base).rev() {
            let j = rng.random_range(0..=i);
            perm.swap(i as usize, j as usize);
        }
        permutations.push(perm);
    }
    Ok(permutations)
}
}
/// Estimates the star discrepancy of a point set by random probing.
///
/// 100 test points are drawn with every coordinate uniform in `[0.05, 0.95)`.
/// For each test point `t`, the fraction of samples inside the anchored box
/// `[0, t]` is compared with the box volume; the largest absolute difference
/// is returned. Because the probes come from the thread-local random
/// generator, the result is an estimate and varies between calls.
///
/// # Arguments
/// * `samples` - one `Array1<f64>` per point; all points must have the same
///   length.
///
/// # Errors
/// Returns `StatsError::InvalidArgument` when `samples` is empty or when the
/// points do not all have the same dimension.
#[allow(dead_code)]
pub fn star_discrepancy(samples: &ArrayView1<Array1<f64>>) -> StatsResult<f64> {
    const NUM_TEST_POINTS: usize = 100;

    if samples.is_empty() {
        return Err(StatsError::InvalidArgument(
            "samples array cannot be empty".to_string(),
        ));
    }
    let n = samples.len();
    let d = samples[0].len();
    // Reject ragged input up front instead of panicking on an out-of-bounds
    // coordinate access further down.
    if samples.iter().any(|sample| sample.len() != d) {
        return Err(StatsError::InvalidArgument(
            "all samples must have the same dimension".to_string(),
        ));
    }
    let mut max_discrepancy: f64 = 0.0;
    let mut rng = scirs2_core::random::thread_rng();
    for _ in 0..NUM_TEST_POINTS {
        let mut test_point = Array1::<f64>::zeros(d);
        for j in 0..d {
            test_point[j] = (rng.random::<f64>() * 0.9) + 0.05;
        }
        // A sample lies in the box unless some coordinate exceeds the corner.
        let count = samples
            .iter()
            .filter(|sample| {
                !sample
                    .iter()
                    .zip(test_point.iter())
                    .any(|(&x, &corner)| x > corner)
            })
            .count();
        let volume: f64 = test_point.iter().product();
        let expected = volume * n as f64;
        let discrepancy = (count as f64 - expected).abs() / n as f64;
        max_discrepancy = max_discrepancy.max(discrepancy);
    }
    Ok(max_discrepancy)
}
pub mod advanced;
pub mod enhanced_sequences;
pub use advanced::*;
pub use enhanced_sequences::*;