use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::uniform::SampleUniform;
use scirs2_core::random::{Distribution, Uniform};
/// Interface shared by the point generators defined in this file.
///
/// A generator hands out one point at a time through
/// [`next_point`](QmcSequence::next_point); [`generate`](QmcSequence::generate)
/// collects a batch of consecutive points into a matrix.
pub trait QmcSequence {
    /// Returns the next point and advances the generator.
    fn next_point(&mut self) -> Vec<f64>;

    /// Number of coordinates in each point.
    fn dim(&self) -> usize;

    /// Resets the generator; the implementations in this file restart from
    /// their first point.
    fn reset(&mut self);

    /// Draws `n` consecutive points and stores them row by row in an
    /// `n x dim()` array.
    ///
    /// # Panics
    ///
    /// Panics if `next_point` returns fewer than `dim()` coordinates.
    fn generate(&mut self, n: usize) -> Array2<f64> {
        let width = self.dim();
        let mut points = Array2::<f64>::zeros((n, width));
        for row in 0..n {
            let point = self.next_point();
            (0..width).for_each(|col| points[[row, col]] = point[col]);
        }
        points
    }
}
/// State of a Halton point generator.
///
/// Every coordinate has its own radix in `base`; `index` is the counter whose
/// digits are mirrored about the radix point to form each coordinate.
#[derive(Debug, Clone)]
pub struct HaltonSequence {
    /// Radix used for each coordinate (one entry per dimension, each `>= 2`).
    pub base: Vec<usize>,
    /// Index of the most recently generated point (`0` before the first point).
    pub index: usize,
}
impl HaltonSequence {
    /// Creates a Halton generator with one radix per coordinate.
    ///
    /// # Panics
    ///
    /// Panics if any entry of `base` is smaller than 2. The check is active in
    /// every build profile: a radix of 1 would make the digit expansion loop
    /// forever and a radix of 0 would divide by zero.
    pub fn new(base: Vec<usize>) -> Self {
        for (i, &b) in base.iter().enumerate() {
            assert!(b >= 2, "base[{i}] must be >= 2");
        }
        Self { base, index: 0 }
    }

    /// Creates a `dim`-dimensional generator that uses the first `dim` primes
    /// as radices.
    pub fn with_primes(dim: usize) -> Self {
        Self::new(first_n_primes(dim))
    }

    /// Radical inverse of `n` in `base`: the base-`base` digits of `n` are
    /// mirrored about the radix point, giving a value in `[0, 1)`.
    fn van_der_corput(mut n: usize, base: usize) -> f64 {
        let mut result = 0.0f64;
        let mut denominator = 1.0f64;
        while n > 0 {
            // The least significant remaining digit lands on the next
            // fractional place.
            denominator *= base as f64;
            result += (n % base) as f64 / denominator;
            n /= base;
        }
        result
    }
}
impl QmcSequence for HaltonSequence {
    /// Advances the counter and returns the radical inverse of the new index
    /// in every radix. The counter is bumped first, so index 0 is never
    /// emitted.
    fn next_point(&mut self) -> Vec<f64> {
        self.index += 1;
        let current = self.index;
        let mut point = Vec::with_capacity(self.base.len());
        for &radix in &self.base {
            point.push(Self::van_der_corput(current, radix));
        }
        point
    }

    /// One coordinate per configured radix.
    fn dim(&self) -> usize {
        self.base.len()
    }

    /// Rewinds the counter so the next point is again the first one.
    fn reset(&mut self) {
        self.index = 0;
    }
}
/// Initialisation data for the Sobol direction numbers of every dimension
/// after the first (20 entries).
///
/// `build_sobol_direction_numbers` destructures each tuple as `(s, a, m_init)`:
/// * `s` - how many initial values are supplied and how far back the
///   recurrence reaches,
/// * `a` - bit mask selecting which of the intermediate earlier direction
///   numbers are XORed into the recurrence,
/// * `m_init` - the initial values, each shifted into the top bits of a
///   direction number.
///
/// Entry `k` seeds the dimension with 0-based index `k + 1`.
static JOE_KUO_DIRECTION_INIT: &[(
u32, u32, &[u32], )] = &[
(1, 0, &[1]),
// NOTE(review): the name suggests the Joe-Kuo tables; the widely distributed
// new-joe-kuo-6 file lists m = [1, 3] for s = 2, a = 1 - confirm which table
// this entry was transcribed from.
(2, 1, &[1, 1]),
(3, 1, &[1, 3, 1]),
(3, 2, &[1, 1, 1]),
(4, 1, &[1, 1, 3, 3]),
(4, 4, &[1, 3, 5, 13]),
(5, 2, &[1, 1, 5, 5, 17]),
(5, 4, &[1, 1, 5, 5, 5]),
(5, 7, &[1, 1, 7, 11, 19]),
(5, 11, &[1, 1, 5, 1, 1]),
(5, 13, &[1, 1, 1, 3, 11]),
(5, 14, &[1, 3, 5, 5, 31]),
(6, 1, &[1, 3, 3, 9, 7, 49]),
(6, 13, &[1, 1, 1, 15, 21, 21]),
(6, 16, &[1, 3, 1, 13, 27, 49]),
(6, 19, &[1, 1, 1, 15, 3, 13]),
(6, 22, &[1, 3, 1, 15, 13, 17]),
(6, 25, &[1, 1, 5, 5, 19, 45]),
(7, 1, &[1, 3, 5, 5, 19, 27, 97]),
(7, 4, &[1, 3, 7, 5, 13, 29, 91]),
];
/// State of a Sobol point generator.
#[derive(Debug, Clone)]
pub struct SobolSequence {
    /// Number of coordinates per point.
    dim: usize,
    /// Number of points handed out so far.
    index: u64,
    /// Integer state of every coordinate; dividing by `2^32` gives the
    /// coordinate value.
    current: Vec<u64>,
    /// Direction numbers, indexed as `dir[dimension][bit]`.
    dir: Vec<Vec<u64>>,
}
impl SobolSequence {
    /// Largest dimension for which direction numbers are available.
    pub const MAX_DIM: usize = 21;
    /// Number of bits of resolution in each coordinate.
    const BITS: u32 = 32;

    /// Creates a generator for `dim` coordinates.
    ///
    /// Requests above [`Self::MAX_DIM`] are silently clamped, so `dim()` of
    /// the returned generator may be smaller than the value passed in.
    pub fn new(dim: usize) -> Self {
        let dim = dim.min(Self::MAX_DIM);
        Self {
            dim,
            index: 0,
            current: vec![0u64; dim],
            dir: build_sobol_direction_numbers(dim),
        }
    }

    /// Produces the next point in Gray-code order.
    ///
    /// The very first call returns the origin. Every later call XORs one
    /// direction number per coordinate into the running state; the direction
    /// number is selected by the position of the lowest zero bit of the
    /// previous point's index.
    fn advance(&mut self) -> Vec<f64> {
        if self.index == 0 {
            self.index = 1;
            return vec![0.0f64; self.dim];
        }

        let flip = (!(self.index - 1)).trailing_zeros() as usize;
        for (state, directions) in self.current.iter_mut().zip(self.dir.iter()) {
            // Beyond the stored resolution there is nothing left to flip.
            if let Some(&v) = directions.get(flip) {
                *state ^= v;
            }
        }
        self.index += 1;

        let scale = 2.0f64.powi(Self::BITS as i32);
        self.current
            .iter()
            .map(|&bits| (bits as f64) / scale)
            .collect()
    }
}
impl QmcSequence for SobolSequence {
    /// Delegates to the internal Gray-code update.
    fn next_point(&mut self) -> Vec<f64> {
        self.advance()
    }

    /// Number of coordinates per point.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Clears the point counter and the per-coordinate state so the sequence
    /// starts over.
    fn reset(&mut self) {
        self.index = 0;
        self.current.fill(0);
    }
}
/// Builds the direction numbers for `dim` Sobol dimensions.
///
/// Returns one row of `SobolSequence::BITS` direction numbers per dimension;
/// the result is empty when `dim == 0`.
///
/// * Dimension 0 uses the single-bit numbers `1 << (bits - 1 - j)`.
/// * Dimension `d >= 1` is seeded from entry `d - 1` of
///   `JOE_KUO_DIRECTION_INIT` and extended with the XOR recurrence below.
/// * A dimension without a table entry falls back to the dimension-0 numbers.
fn build_sobol_direction_numbers(dim: usize) -> Vec<Vec<u64>> {
    let bits = SobolSequence::BITS as usize;
    let single_bit: Vec<u64> = (0..bits).map(|j| 1u64 << (bits - 1 - j)).collect();

    let mut dir = Vec::with_capacity(dim);
    for d in 0..dim {
        // `checked_sub` maps dimension 0 to "no table entry", which selects
        // the single-bit numbers without indexing into an empty vector.
        let init = d
            .checked_sub(1)
            .and_then(|idx| JOE_KUO_DIRECTION_INIT.get(idx));
        let (s, a, m_init) = match init {
            Some(&(s, a, m_init)) => (s as usize, a, m_init),
            None => {
                dir.push(single_bit.clone());
                continue;
            }
        };

        let mut v = vec![0u64; bits];
        // Initial values are left-aligned so that bit `bits - 1` is the first
        // binary fraction digit.
        for (i, &m) in m_init.iter().enumerate().take(s) {
            v[i] = (m as u64) << (bits - 1 - i);
        }
        for j in s..bits {
            let mut next = v[j - s] ^ (v[j - s] >> s);
            for k in 1..s {
                // Bit `s - 1 - k` of `a` decides whether the k-th previous
                // direction number takes part.
                if (a >> (s - 1 - k)) & 1 == 1 {
                    next ^= v[j - k];
                }
            }
            v[j] = next;
        }
        dir.push(v);
    }
    dir
}
/// State of a rank-1 lattice point generator.
///
/// Point `i` has the coordinates `frac(i * generator[k])`.
#[derive(Debug, Clone)]
pub struct LatticeRule {
    /// Per-coordinate multiplier applied to the point index.
    pub generator: Vec<f64>,
    /// Number of lattice points the rule was built for (the modulus used by
    /// `korobov`).
    pub n: usize,
    /// Index of the next point to be generated.
    index: usize,
}
impl LatticeRule {
    /// Creates a lattice rule from an explicit generator vector, positioned
    /// at its first point.
    pub fn new(generator: Vec<f64>, n: usize) -> Self {
        Self {
            generator,
            n,
            index: 0,
        }
    }

    /// Creates a Korobov lattice rule whose generator is
    /// `(1, a, a^2, ..., a^(dim-1)) mod n`, each divided by `n`.
    ///
    /// Powers are accumulated incrementally, and the product is formed in
    /// 128-bit arithmetic so that large `a` and `n` cannot overflow.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`.
    pub fn korobov(dim: usize, a: u64, n: usize) -> Self {
        assert!(n > 0, "n must be positive");
        let modulus = n as u128;
        let n_f = n as f64;
        let mut power: u128 = 1;
        let mut generator = Vec::with_capacity(dim);
        for _ in 0..dim {
            generator.push(power as f64 / n_f);
            // power < 2^64 and a < 2^64, so the product fits in u128.
            power = power * u128::from(a) % modulus;
        }
        Self::new(generator, n)
    }
}
impl QmcSequence for LatticeRule {
    /// Returns point `index` of the lattice, `frac(index * generator[k])` for
    /// every coordinate, and then advances the index.
    ///
    /// The first point is therefore the origin.
    fn next_point(&mut self) -> Vec<f64> {
        let i_f = self.index as f64;
        self.index += 1;
        self.generator
            .iter()
            .map(|&g| {
                let wrapped = (i_f * g).fract().rem_euclid(1.0);
                // `rem_euclid` can round a tiny negative remainder up to
                // exactly 1.0; on the unit torus that is the same as 0.0.
                if wrapped >= 1.0 {
                    0.0
                } else {
                    wrapped
                }
            })
            .collect()
    }

    /// One coordinate per generator entry.
    fn dim(&self) -> usize {
        self.generator.len()
    }

    /// Rewinds to the first lattice point.
    fn reset(&mut self) {
        self.index = 0;
    }
}
/// Estimates the integral of `f` over an axis-aligned box with a
/// quasi-Monte Carlo point sequence.
///
/// Every point drawn from `sequence` is mapped affinely from the unit cube
/// into `bounds`; the result is the mean of `f` over `n_samples` points
/// multiplied by the volume of the box.
///
/// # Arguments
///
/// * `f` - integrand, evaluated on a slice with one entry per dimension
/// * `bounds` - `(lower, upper)` limits for every dimension
/// * `sequence` - point generator; it is advanced by `n_samples` points
/// * `n_samples` - number of points to evaluate
///
/// # Errors
///
/// * `IntegrateError::ValueError` if `bounds` is empty or `n_samples` is zero
/// * `IntegrateError::DimensionMismatch` if `sequence.dim()` differs from
///   `bounds.len()`
pub fn qmc_integrate<F, S>(
    f: F,
    bounds: &[(f64, f64)],
    sequence: &mut S,
    n_samples: usize,
) -> IntegrateResult<f64>
where
    F: Fn(&[f64]) -> f64,
    S: QmcSequence,
{
    if bounds.is_empty() {
        return Err(IntegrateError::ValueError(
            "bounds must not be empty".to_string(),
        ));
    }
    if n_samples == 0 {
        return Err(IntegrateError::ValueError(
            "n_samples must be positive".to_string(),
        ));
    }

    let dim = bounds.len();
    if sequence.dim() != dim {
        let message = format!(
            "sequence dimension ({}) != bounds dimension ({})",
            sequence.dim(),
            dim
        );
        return Err(IntegrateError::DimensionMismatch(message));
    }

    let volume: f64 = bounds.iter().map(|&(lo, hi)| hi - lo).product();

    // One buffer is reused for every mapped point.
    let mut mapped = vec![0.0f64; dim];
    let mut total = 0.0f64;
    for _ in 0..n_samples {
        let unit = sequence.next_point();
        for (k, slot) in mapped.iter_mut().enumerate() {
            let (lo, hi) = bounds[k];
            *slot = lo + unit[k] * (hi - lo);
        }
        total += f(&mapped);
    }

    Ok(total / (n_samples as f64) * volume)
}
/// Computes a star-discrepancy figure for a point set stored row-wise.
///
/// For at most 2000 points in at most 4 dimensions the value comes from the
/// grid search in `star_discrepancy_recursive`. Larger inputs use a cheaper
/// substitute: the largest one-dimensional discrepancy over the individual
/// coordinate columns.
///
/// # Errors
///
/// Returns `IntegrateError::ValueError` if the array has no rows or no
/// columns.
pub fn star_discrepancy(points: &Array2<f64>) -> IntegrateResult<f64> {
    let n = points.nrows();
    let d = points.ncols();
    if n == 0 || d == 0 {
        return Err(IntegrateError::ValueError(
            "point set must be non-empty".to_string(),
        ));
    }
    let n_f = n as f64;

    let use_grid_search = n <= 2000 && d <= 4;
    if use_grid_search {
        let mut worst = 0.0f64;
        let mut corner = vec![0.0f64; d];
        star_discrepancy_recursive(points, n, d, n_f, &mut corner, 0, &mut worst);
        return Ok(worst);
    }

    // Projection fallback: examine each coordinate on its own.
    let worst = (0..d)
        .map(|j| {
            let mut column: Vec<f64> = (0..n).map(|i| points[[i, j]]).collect();
            column.sort_unstable_by(|a, b| {
                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
            });
            star_discrepancy_1d(&column, n_f)
        })
        .fold(0.0f64, f64::max);
    Ok(worst)
}
/// Grid search for the star discrepancy, one coordinate per recursion level.
///
/// Candidate box corners are taken from the grid whose `dim`-th axis holds
/// every distinct `dim`-th point coordinate plus the upper boundary `1.0`.
/// At each corner `x` both the open box `[0, x)` and the closed box `[0, x]`
/// are compared with the box volume. Both counts and the boundary value are
/// needed: without them the discrepancy of, for example, a single point at
/// the origin would be reported as 0 instead of 1.
///
/// # Arguments
///
/// * `points` - point set, one point per row
/// * `n`, `d` - number of rows and columns of `points`
/// * `n_f` - `n` as a float
/// * `x` - scratch buffer of length `d` holding the corner under construction
/// * `dim` - coordinate being enumerated at this level
/// * `disc` - running maximum, updated in place
fn star_discrepancy_recursive(
    points: &Array2<f64>,
    n: usize,
    d: usize,
    n_f: f64,
    x: &mut Vec<f64>,
    dim: usize,
    disc: &mut f64,
) {
    if dim == d {
        let mut open = 0usize;
        let mut closed = 0usize;
        for i in 0..n {
            if (0..d).all(|j| points[[i, j]] <= x[j]) {
                closed += 1;
                if (0..d).all(|j| points[[i, j]] < x[j]) {
                    open += 1;
                }
            }
        }
        let vol: f64 = x.iter().product();
        let deficit = (vol - open as f64 / n_f).abs();
        let excess = (closed as f64 / n_f - vol).abs();
        let gap = deficit.max(excess);
        if gap > *disc {
            *disc = gap;
        }
        return;
    }

    let mut vals: Vec<f64> = (0..n).map(|i| points[[i, dim]]).collect();
    vals.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    vals.dedup_by(|a, b| (*a - *b).abs() < 1e-15);
    // The extremal box may extend to the face of the unit cube.
    if vals.last().map_or(true, |&largest| largest < 1.0) {
        vals.push(1.0);
    }
    for &v in &vals {
        x[dim] = v;
        star_discrepancy_recursive(points, n, d, n_f, x, dim + 1, disc);
    }
}
/// One-dimensional star discrepancy of an ascending sample.
///
/// For the value at rank `i` (0-based) the interval length `x` is compared
/// with both `i / n_f` and `(i + 1) / n_f`; the largest absolute difference
/// over all ranks is returned. An empty slice yields `0.0`.
fn star_discrepancy_1d(sorted: &[f64], n_f: f64) -> f64 {
    sorted
        .iter()
        .enumerate()
        .fold(0.0f64, |worst, (rank, &x)| {
            let before = (rank as f64 / n_f - x).abs();
            let after = ((rank + 1) as f64 / n_f - x).abs();
            worst.max(before).max(after)
        })
}
/// Generates `n_points` Halton points with randomly permuted digits.
///
/// For every radix in `base`, 32 digit permutations (one per digit position)
/// are drawn from an `StdRng` seeded with `scramble_key`, so the same key
/// always gives the same points. Each digit of the point index is replaced
/// through the permutation for its position before the radical inverse is
/// accumulated; digits beyond position 31 are used unchanged. Values are
/// capped at `1.0 - f64::EPSILON`.
///
/// Point indices start at 1.
///
/// # Errors
///
/// Returns `IntegrateError::ValueError` if `base` is empty or contains an
/// entry smaller than 2.
pub fn scrambled_halton(
    base: &[usize],
    scramble_key: u64,
    n_points: usize,
) -> IntegrateResult<Array2<f64>> {
    if base.is_empty() {
        return Err(IntegrateError::ValueError(
            "base must not be empty".to_string(),
        ));
    }
    if let Some((i, &b)) = base.iter().enumerate().find(|&(_, &b)| b < 2) {
        return Err(IntegrateError::ValueError(format!(
            "base[{i}] must be >= 2, got {b}"
        )));
    }

    const MAX_DIGITS: usize = 32;
    let mut rng = StdRng::seed_from_u64(scramble_key);

    // tables[dimension][digit position] is a permutation of 0..radix.
    let mut tables: Vec<Vec<Vec<usize>>> = Vec::with_capacity(base.len());
    for &radix in base {
        let mut per_position = Vec::with_capacity(MAX_DIGITS);
        for _ in 0..MAX_DIGITS {
            let mut shuffled: Vec<usize> = (0..radix).collect();
            // Fisher-Yates shuffle, walking from the back.
            for k in (1..radix).rev() {
                let j = rng.random_range(0..=(k as u64)) as usize;
                shuffled.swap(k, j);
            }
            per_position.push(shuffled);
        }
        tables.push(per_position);
    }

    let cap = 1.0 - f64::EPSILON;
    let mut out = Array2::<f64>::zeros((n_points, base.len()));
    for row in 0..n_points {
        for (col, (&radix, table)) in base.iter().zip(tables.iter()).enumerate() {
            let mut rest = row + 1;
            let mut value = 0.0f64;
            let mut weight = 1.0f64;
            let mut position = 0usize;
            while rest > 0 {
                weight *= radix as f64;
                let digit = rest % radix;
                let mapped = match table.get(position) {
                    Some(permutation) => permutation[digit],
                    None => digit,
                };
                value += mapped as f64 / weight;
                rest /= radix;
                position += 1;
            }
            out[[row, col]] = value.min(cap);
        }
    }
    Ok(out)
}
/// Returns the `n` smallest primes in ascending order (empty for `n == 0`).
fn first_n_primes(n: usize) -> Vec<usize> {
    (2usize..)
        .filter(|&candidate| is_prime(candidate))
        .take(n)
        .collect()
}
/// Trial-division primality test.
///
/// After handling 0, 1, 2 and even numbers, odd divisors are tried up to
/// `floor(sqrt(n)) + 1`.
fn is_prime(n: usize) -> bool {
    match n {
        0 | 1 => false,
        2 => true,
        _ if n % 2 == 0 => false,
        _ => {
            let limit = (n as f64).sqrt() as usize + 1;
            !(3..=limit).step_by(2).any(|k| n % k == 0)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_halton_unit_cube() {
        let mut seq = HaltonSequence::with_primes(3);
        for _ in 0..100 {
            let p = seq.next_point();
            assert_eq!(p.len(), 3);
            assert!(p.iter().all(|&x| x >= 0.0 && x < 1.0), "out of [0,1): {p:?}");
        }
    }

    #[test]
    fn test_halton_known_values_base2() {
        let mut seq = HaltonSequence::new(vec![2]);
        let p1 = seq.next_point();
        let p2 = seq.next_point();
        let p3 = seq.next_point();
        assert!((p1[0] - 0.5).abs() < 1e-12, "p1={}", p1[0]);
        assert!((p2[0] - 0.25).abs() < 1e-12, "p2={}", p2[0]);
        assert!((p3[0] - 0.75).abs() < 1e-12, "p3={}", p3[0]);
    }

    #[test]
    fn test_halton_reset() {
        let mut seq = HaltonSequence::with_primes(2);
        let first: Vec<Vec<f64>> = (0..5).map(|_| seq.next_point()).collect();
        seq.reset();
        let second: Vec<Vec<f64>> = (0..5).map(|_| seq.next_point()).collect();
        for (a, b) in first.iter().zip(second.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-15);
            }
        }
    }

    #[test]
    fn test_halton_generate_shape() {
        let mut seq = HaltonSequence::with_primes(4);
        let pts = seq.generate(128);
        assert_eq!(pts.nrows(), 128);
        assert_eq!(pts.ncols(), 4);
    }

    #[test]
    fn test_sobol_unit_cube() {
        let mut seq = SobolSequence::new(5);
        for _ in 0..200 {
            let p = seq.next_point();
            assert_eq!(p.len(), 5);
            assert!(
                p.iter().all(|&x| x >= 0.0 && x <= 1.0),
                "out of [0,1]: {p:?}"
            );
        }
    }

    #[test]
    fn test_sobol_first_dim_van_der_corput() {
        // The generator walks the sequence in Gray-code order, so the first
        // coordinate visits the base-2 van der Corput values as
        // 0, 1/2, 3/4, 1/4 rather than 0, 1/2, 1/4, 3/4.
        let mut seq = SobolSequence::new(1);
        let p0 = seq.next_point();
        let p1 = seq.next_point();
        let p2 = seq.next_point();
        let p3 = seq.next_point();
        assert!(p0[0].abs() < 1e-6, "p0={}", p0[0]);
        assert!((p1[0] - 0.5).abs() < 1e-6, "p1={}", p1[0]);
        assert!((p2[0] - 0.75).abs() < 1e-6, "p2={}", p2[0]);
        assert!((p3[0] - 0.25).abs() < 1e-6, "p3={}", p3[0]);
    }

    #[test]
    fn test_sobol_reset() {
        let mut seq = SobolSequence::new(3);
        let before: Vec<Vec<f64>> = (0..10).map(|_| seq.next_point()).collect();
        seq.reset();
        let after: Vec<Vec<f64>> = (0..10).map(|_| seq.next_point()).collect();
        for (a, b) in before.iter().zip(after.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-12, "reset failed: {x} != {y}");
            }
        }
    }

    #[test]
    fn test_lattice_rule_unit_cube() {
        let mut lat = LatticeRule::korobov(3, 1973, 1024);
        for _ in 0..1024 {
            let p = lat.next_point();
            assert_eq!(p.len(), 3);
            assert!(
                p.iter().all(|&x| x >= 0.0 && x <= 1.0),
                "out of [0,1]: {p:?}"
            );
        }
    }

    #[test]
    fn test_lattice_first_point_zero() {
        let mut lat = LatticeRule::korobov(2, 3, 8);
        let p0 = lat.next_point();
        assert!(p0.iter().all(|&x| x == 0.0), "first point should be zero: {p0:?}");
    }

    #[test]
    fn test_qmc_integrate_halton_1d() {
        let mut seq = HaltonSequence::with_primes(1);
        let val = qmc_integrate(
            |x| x[0] * x[0],
            &[(0.0, 1.0)],
            &mut seq,
            4096,
        )
        .expect("qmc_integrate failed");
        assert!((val - 1.0 / 3.0).abs() < 0.01, "val={val}");
    }

    #[test]
    fn test_qmc_integrate_sobol_2d() {
        let mut seq = SobolSequence::new(2);
        let val = qmc_integrate(
            |x| x[0] * x[1],
            &[(0.0, 1.0), (0.0, 1.0)],
            &mut seq,
            4096,
        )
        .expect("qmc_integrate 2d failed");
        assert!((val - 0.25).abs() < 0.01, "val={val}");
    }

    #[test]
    fn test_qmc_integrate_lattice_3d() {
        let mut seq = LatticeRule::korobov(3, 1973, 4096);
        let val = qmc_integrate(
            |x| (x[0] + x[1] + x[2]) / 3.0,
            &[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
            &mut seq,
            4096,
        )
        .expect("qmc_integrate 3d failed");
        assert!((val - 0.5).abs() < 0.02, "val={val}");
    }

    #[test]
    fn test_qmc_integrate_non_unit_bounds() {
        let mut seq = HaltonSequence::with_primes(2);
        let val = qmc_integrate(
            |_| 1.0,
            &[(0.0, 2.0), (0.0, 3.0)],
            &mut seq,
            1000,
        )
        .expect("qmc non-unit bounds failed");
        assert!((val - 6.0).abs() < 1e-10, "val={val}");
    }

    #[test]
    fn test_qmc_integrate_dim_mismatch() {
        let mut seq = HaltonSequence::with_primes(2);
        assert!(qmc_integrate(|x| x[0], &[(0.0, 1.0)], &mut seq, 100).is_err());
    }

    #[test]
    fn test_star_discrepancy_halton() {
        let mut seq = HaltonSequence::with_primes(2);
        let pts = seq.generate(256);
        let d = star_discrepancy(&pts).expect("discrepancy failed");
        assert!(d < 0.15, "D*={d}");
        assert!(d >= 0.0, "D* must be non-negative");
    }

    #[test]
    fn test_star_discrepancy_1d_perfect() {
        // Midpoints of n equal cells.
        let n = 100usize;
        let arr = Array2::from_shape_fn((n, 1), |(i, _)| (i as f64 + 0.5) / n as f64);
        let d = star_discrepancy(&arr).expect("discrepancy failed");
        assert!(d < 0.02, "D*={d}");
    }

    #[test]
    fn test_star_discrepancy_empty() {
        let empty = Array2::<f64>::zeros((0, 2));
        assert!(star_discrepancy(&empty).is_err());
    }

    #[test]
    fn test_scrambled_halton_shape() {
        let pts = scrambled_halton(&[2, 3, 5], 777, 128).expect("scramble failed");
        assert_eq!(pts.nrows(), 128);
        assert_eq!(pts.ncols(), 3);
    }

    #[test]
    fn test_scrambled_halton_unit_cube() {
        let pts = scrambled_halton(&[2, 3, 5, 7], 999, 256).expect("scramble failed");
        assert!(
            pts.iter().all(|&x| x >= 0.0 && x < 1.0),
            "some points outside [0,1)"
        );
    }

    #[test]
    fn test_scrambled_halton_reproducible() {
        let pts1 = scrambled_halton(&[2, 3], 42, 64).expect("scramble1 failed");
        let pts2 = scrambled_halton(&[2, 3], 42, 64).expect("scramble2 failed");
        assert_eq!(pts1, pts2, "scrambled Halton must be reproducible for same key");
    }

    #[test]
    fn test_scrambled_halton_different_keys_differ() {
        let pts1 = scrambled_halton(&[2, 3], 1, 32).expect("scramble1 failed");
        let pts2 = scrambled_halton(&[2, 3], 2, 32).expect("scramble2 failed");
        let same = pts1
            .iter()
            .zip(pts2.iter())
            .all(|(a, b)| (a - b).abs() < 1e-15);
        assert!(!same, "different keys produced identical points");
    }

    #[test]
    fn test_scrambled_halton_invalid_base() {
        assert!(scrambled_halton(&[1], 0, 10).is_err());
        assert!(scrambled_halton(&[], 0, 10).is_err());
    }

    #[test]
    fn test_first_n_primes() {
        let p = first_n_primes(5);
        assert_eq!(p, vec![2, 3, 5, 7, 11]);
    }

    #[test]
    fn test_is_prime() {
        assert!(!is_prime(0));
        assert!(!is_prime(1));
        assert!(is_prime(2));
        assert!(is_prime(3));
        assert!(!is_prime(4));
        assert!(is_prime(97));
        assert!(!is_prime(100));
    }
}