use crate::error::QuadratureError;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// The first 100 primes in increasing order (2 through 541).
///
/// Entry `j` (0-based) is used as the radical-inverse base for coordinate `j`
/// of a [`HaltonSequence`], so the length of this table is the maximum number
/// of dimensions `HaltonSequence::new` accepts.
const PRIMES: [u32; 100] = [
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193,
197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307,
311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421,
431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541,
];
/// Generator for the Halton low-discrepancy sequence in `dim` dimensions.
///
/// Point `i` (starting at `i = 1`) has coordinate `j` equal to the radical
/// inverse of `i` in the `j`-th prime base. Construct one with
/// [`HaltonSequence::new`] and draw points with [`HaltonSequence::next_point`].
///
/// Cloning a generator forks it: both copies continue from the same index.
#[derive(Debug, Clone)]
pub struct HaltonSequence {
    /// Number of coordinates written per point.
    dim: usize,
    /// Radical-inverse base for each coordinate (the first `dim` primes).
    bases: Vec<u32>,
    /// Index of the most recently generated point (0 before the first point).
    index: u64,
}
impl HaltonSequence {
    /// Creates a Halton sequence generator over `dim` dimensions.
    ///
    /// Coordinate `j` (0-based) uses the `j`-th prime (`2, 3, 5, …`) as its
    /// radical-inverse base. The first call to [`HaltonSequence::next_point`]
    /// produces the point for index 1; index 0 (the origin) is skipped.
    ///
    /// # Errors
    ///
    /// Returns [`QuadratureError::InvalidInput`] if `dim` is 0 or exceeds the
    /// number of built-in prime bases (100).
    pub fn new(dim: usize) -> Result<Self, QuadratureError> {
        if dim == 0 {
            return Err(QuadratureError::InvalidInput("dimension must be >= 1"));
        }
        if dim > PRIMES.len() {
            return Err(QuadratureError::InvalidInput(
                "Halton sequence supports at most 100 dimensions",
            ));
        }
        Ok(Self {
            dim,
            bases: PRIMES[..dim].to_vec(),
            index: 0,
        })
    }

    /// Advances the sequence and writes the next point into `point[..dim]`.
    ///
    /// Coordinates are radical inverses and therefore lie in `[0, 1)` (up to
    /// floating-point rounding at astronomically large indices). Entries of
    /// `point` beyond `dim` are left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `point.len()` is smaller than [`HaltonSequence::dim`].
    pub fn next_point(&mut self, point: &mut [f64]) {
        assert!(
            point.len() >= self.dim,
            "point buffer has length {} but the Halton sequence has {} dimensions",
            point.len(),
            self.dim
        );
        self.index += 1;
        // `bases` holds exactly `dim` entries, so the zip stops after `dim`
        // coordinates without indexing (and bounds-checking) into `bases`.
        for (p, &base) in point.iter_mut().zip(&self.bases) {
            *p = radical_inverse(self.index, base);
        }
    }

    /// Returns the index of the most recently generated point, which equals
    /// the number of points generated so far (0 before the first call to
    /// [`HaltonSequence::next_point`]).
    #[must_use]
    pub fn index(&self) -> u64 {
        self.index
    }

    /// Returns the number of coordinates written per point.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }
}
/// Computes the radical inverse of `n` in the given `base`.
///
/// Writes `n` in base `base` as digits `d_0 d_1 … d_k` (least significant
/// first) and mirrors them about the radix point, returning
/// `Σ d_i · base^-(i+1)`. For example, in base 2, `n = 6 = 110₂` maps to
/// `0.011₂ = 0.375`. Returns `0.0` for `n == 0`.
///
/// `base` must be at least 2: base 1 would never terminate (`n / 1 == n`) and
/// base 0 would divide by zero. Callers in this module pass primes only.
fn radical_inverse(mut n: u64, base: u32) -> f64 {
    debug_assert!(base >= 2, "radical_inverse requires base >= 2, got {base}");
    let base_f = f64::from(base);
    let base_u64 = u64::from(base);
    let mut result = 0.0;
    // Weight of the next digit: base^-1, base^-2, ...
    let mut factor = 1.0 / base_f;
    while n > 0 {
        // The digit is `< base <= u32::MAX < 2^53`, so this cast is exact.
        #[allow(clippy::cast_precision_loss)]
        let digit = (n % base_u64) as f64;
        result += digit * factor;
        n /= base_u64;
        factor /= base_f;
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "std"))]
    use alloc::vec;

    /// Absolute tolerance when comparing coordinates against exact fractions.
    const TOL: f64 = 1e-14;

    #[test]
    fn first_few_points() {
        let mut hal = HaltonSequence::new(2).unwrap();
        let mut pt = [0.0; 2];
        // Expected (base-2, base-3) coordinates for indices 1, 2 and 3.
        let expected = [(0.5, 1.0 / 3.0), (0.25, 2.0 / 3.0), (0.75, 1.0 / 9.0)];
        for &(x, y) in &expected {
            hal.next_point(&mut pt);
            assert!((pt[0] - x).abs() < TOL, "got x={}, expected {x}", pt[0]);
            assert!((pt[1] - y).abs() < TOL, "got y={}, expected {y}", pt[1]);
        }
    }

    #[test]
    fn points_in_unit_cube() {
        let mut hal = HaltonSequence::new(5).unwrap();
        let mut pt = vec![0.0; 5];
        for _ in 0..100 {
            hal.next_point(&mut pt);
            for &x in &pt {
                assert!((0.0..1.0).contains(&x), "x={x} out of [0,1)");
            }
        }
    }

    #[test]
    fn invalid_dim() {
        assert!(HaltonSequence::new(0).is_err());
        assert!(HaltonSequence::new(101).is_err());
        assert!(HaltonSequence::new(usize::MAX).is_err());
    }

    #[test]
    fn max_dim_is_accepted() {
        let mut hal = HaltonSequence::new(100).unwrap();
        assert_eq!(hal.dim(), 100);
        let mut pt = [0.0; 100];
        hal.next_point(&mut pt);
        // Index 1 is the single digit 1 in every base, so coordinate j is 1 / p_j.
        assert!((pt[0] - 0.5).abs() < TOL);
        assert!((pt[99] - 1.0 / 541.0).abs() < TOL);
    }

    #[test]
    fn index_counts_generated_points() {
        let mut hal = HaltonSequence::new(3).unwrap();
        assert_eq!(hal.index(), 0);
        let mut pt = [0.0; 3];
        for expected in 1..=4 {
            hal.next_point(&mut pt);
            assert_eq!(hal.index(), expected);
        }
    }

    #[test]
    fn extra_buffer_entries_are_untouched() {
        let mut hal = HaltonSequence::new(2).unwrap();
        let mut pt = [9.0; 4];
        hal.next_point(&mut pt);
        assert!((pt[0] - 0.5).abs() < TOL);
        assert!(pt[2..].iter().all(|&v| (v - 9.0).abs() < TOL));
    }

    #[test]
    #[should_panic]
    fn short_buffer_panics() {
        let mut hal = HaltonSequence::new(3).unwrap();
        let mut pt = [0.0; 2];
        hal.next_point(&mut pt);
    }
}