use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::cpu::kernels::sobol_data::ArchivedSobolPolynomial;
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::runtime::cpu::kernels::sobol_data::get_polynomial;
/// Largest `dimension` accepted by [`validate_sobol_params`].
///
/// NOTE(review): presumably equals the number of dimensions covered by the
/// bundled `sobol_data` polynomial table — confirm against that table.
pub const SOBOL_MAX_DIMENSIONS: usize = 21201;
/// Largest `dimension` accepted by [`validate_halton_params`].
pub const HALTON_MAX_DIMENSIONS: usize = 100;
/// Number of direction vectors per Sobol dimension. Each vector is a `u32`, and
/// shifts such as `SOBOL_BITS - 1 - i` rely on this matching the `u32` bit width.
pub const SOBOL_BITS: usize = 32;
/// Computes the `SOBOL_BITS` direction vectors of one Sobol dimension from its
/// primitive polynomial.
///
/// * `poly.degree` is the polynomial degree `s`.
/// * `poly.coeff` packs the interior coefficients `a_1 .. a_{s-1}`. `a_1` is the
///   most significant of the `s - 1` bits, so `a_k` is `(coeff >> (s - 1 - k)) & 1`.
/// * `poly.m_values` holds the initial direction numbers `m_1, m_2, ...`. The
///   caller is expected to supply `s` odd values; this is not checked here.
///
/// The returned `v[i]` is `m_{i+1}` left-aligned in a `u32`, i.e.
/// `m_{i+1} << (SOBOL_BITS - 1 - i)`. Entries from index `s` on follow the
/// Bratley–Fox / Joe–Kuo recurrence
/// `v_i = a_1 v_{i-1} ^ ... ^ a_{s-1} v_{i-s+1} ^ v_{i-s} ^ (v_{i-s} >> s)`.
///
/// # Panics
///
/// Panics if `poly.m_values` has more than `SOBOL_BITS` entries (index out of bounds).
#[inline]
pub fn compute_direction_vectors(poly: &ArchivedSobolPolynomial) -> [u32; SOBOL_BITS] {
    let s = poly.degree as usize;
    let a: u32 = poly.coeff.into();
    let mut v = [0u32; SOBOL_BITS];
    // Seed: left-align each initial direction number so that its least
    // significant bit sits at bit (SOBOL_BITS - 1 - i).
    for (i, m) in poly.m_values.iter().enumerate() {
        let m_native: u32 = (*m).into();
        v[i] = m_native << (SOBOL_BITS - 1 - i);
    }
    for i in s..SOBOL_BITS {
        // The v_{i-s} term is the only one that takes an extra right shift.
        let mut vi = v[i - s] ^ (v[i - s] >> s);
        for k in 1..s {
            if (a >> (s - 1 - k)) & 1 != 0 {
                // v[] is already left-aligned, so the term 2^k * m_{i-k} of the
                // m-recurrence is exactly v[i - k]. Shifting it again would make
                // the resulting m_i even and break the sequence.
                vi ^= v[i - k];
            }
        }
        v[i] = vi;
    }
    v
}
/// Returns the direction vectors of the first Sobol dimension.
///
/// Entry `i` has exactly one bit set, at position `SOBOL_BITS - 1 - i`. The first
/// vector is the most significant bit, and each later vector moves one bit right.
#[inline]
pub fn dimension_zero_vectors() -> [u32; SOBOL_BITS] {
    let mut vectors = [0u32; SOBOL_BITS];
    let mut single_bit = 1u32 << (SOBOL_BITS - 1);
    for slot in vectors.iter_mut() {
        *slot = single_bit;
        single_bit >>= 1;
    }
    vectors
}
/// Builds a flat table of Sobol direction vectors for dimensions `0..dimension`.
///
/// The result holds `dimension * SOBOL_BITS` values. Dimension `d` occupies
/// `[d * SOBOL_BITS, (d + 1) * SOBOL_BITS)`. Dimension 0 uses
/// [`dimension_zero_vectors`]. Every other dimension `d` is derived from
/// `get_polynomial(d + 1)`.
///
/// NOTE(review): the `d + 1` offset presumably maps to the polynomial table's own
/// dimension numbering — confirm against `sobol_data`.
///
/// # Panics
///
/// Panics if `get_polynomial` has no entry for a requested dimension. Callers are
/// expected to have bounded `dimension` first, e.g. via `validate_sobol_params`.
#[inline]
#[cfg(any(feature = "cuda", feature = "wgpu"))]
pub fn compute_all_direction_vectors(dimension: usize) -> Vec<u32> {
    let mut table = Vec::with_capacity(dimension * SOBOL_BITS);
    let per_dimension = (0..dimension).map(|index| match index {
        0 => dimension_zero_vectors(),
        _ => {
            let polynomial = get_polynomial(index + 1)
                .expect("INTERNAL: sobol_data.bin corrupted - missing polynomial");
            compute_direction_vectors(polynomial)
        }
    });
    for vectors in per_dimension {
        table.extend_from_slice(&vectors);
    }
    table
}
#[inline]
pub fn validate_basic_params(
n_points: usize,
dimension: usize,
max_dim: usize,
op: &'static str,
) -> Result<()> {
if n_points == 0 {
return Err(Error::InvalidArgument {
arg: "n_points",
reason: format!("{} requires at least 1 point", op),
});
}
if dimension == 0 {
return Err(Error::InvalidArgument {
arg: "dimension",
reason: format!("{} requires at least 1 dimension", op),
});
}
if dimension > max_dim {
return Err(Error::InvalidArgument {
arg: "dimension",
reason: format!(
"{} supports up to {} dimensions, got {}",
op, max_dim, dimension
),
});
}
Ok(())
}
#[inline]
pub fn validate_dtype(dtype: DType, supported_dtypes: &[DType], op: &'static str) -> Result<()> {
if !supported_dtypes.contains(&dtype) {
return Err(Error::UnsupportedDType { dtype, op });
}
Ok(())
}
/// Validates the arguments of a Sobol sequence request.
///
/// Shape checks come first, bounded by [`SOBOL_MAX_DIMENSIONS`], so a bad
/// `n_points` or `dimension` is reported even when `dtype` is also unsupported.
///
/// # Errors
///
/// Propagates the first error from [`validate_basic_params`] or [`validate_dtype`].
#[inline]
pub fn validate_sobol_params(
    n_points: usize,
    dimension: usize,
    dtype: DType,
    supported_dtypes: &[DType],
    op: &'static str,
) -> Result<()> {
    validate_basic_params(n_points, dimension, SOBOL_MAX_DIMENSIONS, op)
        .and_then(|()| validate_dtype(dtype, supported_dtypes, op))
}
/// Validates the arguments of a Halton sequence request.
///
/// Shape checks come first, bounded by [`HALTON_MAX_DIMENSIONS`], so a bad
/// `n_points` or `dimension` is reported even when `dtype` is also unsupported.
///
/// # Errors
///
/// Propagates the first error from [`validate_basic_params`] or [`validate_dtype`].
#[inline]
pub fn validate_halton_params(
    n_points: usize,
    dimension: usize,
    dtype: DType,
    supported_dtypes: &[DType],
    op: &'static str,
) -> Result<()> {
    validate_basic_params(n_points, dimension, HALTON_MAX_DIMENSIONS, op)
        .and_then(|()| validate_dtype(dtype, supported_dtypes, op))
}
#[inline]
pub fn validate_latin_hypercube_params(
n_samples: usize,
dimension: usize,
dtype: DType,
supported_dtypes: &[DType],
op: &'static str,
) -> Result<()> {
if n_samples == 0 {
return Err(Error::InvalidArgument {
arg: "n_samples",
reason: format!("{} requires at least 1 sample", op),
});
}
if dimension == 0 {
return Err(Error::InvalidArgument {
arg: "dimension",
reason: format!("{} requires at least 1 dimension", op),
});
}
validate_dtype(dtype, supported_dtypes, op)?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Floating-point dtypes accepted by every validator exercised below.
    fn float_dtypes() -> [DType; 2] {
        [DType::F32, DType::F64]
    }

    #[test]
    fn test_validate_basic_params() {
        // (n_points, dimension, expected_ok) against a limit of 100 dimensions.
        let cases = [(10, 5, true), (0, 5, false), (10, 0, false), (10, 101, false)];
        for &(n_points, dimension, expected_ok) in &cases {
            assert_eq!(
                validate_basic_params(n_points, dimension, 100, "test").is_ok(),
                expected_ok,
                "n_points={}, dimension={}",
                n_points,
                dimension
            );
        }
    }

    #[test]
    fn test_validate_dtype() {
        let supported = float_dtypes();
        for dtype in [DType::F32, DType::F64] {
            assert!(validate_dtype(dtype, &supported, "test").is_ok());
        }
        assert!(validate_dtype(DType::I32, &supported, "test").is_err());
    }

    #[test]
    fn test_validate_sobol_params() {
        let supported = float_dtypes();
        for &dimension in &[5usize, 1000, 21201] {
            assert!(validate_sobol_params(10, dimension, DType::F32, &supported, "sobol").is_ok());
        }
        assert!(validate_sobol_params(10, 21202, DType::F32, &supported, "sobol").is_err());
        assert!(validate_sobol_params(10, 5, DType::I32, &supported, "sobol").is_err());
    }

    #[test]
    fn test_validate_halton_params() {
        let supported = float_dtypes();
        let accepts = |dimension: usize| {
            validate_halton_params(10, dimension, DType::F32, &supported, "halton").is_ok()
        };
        assert!(accepts(5));
        assert!(!accepts(101));
    }

    #[test]
    fn test_validate_latin_hypercube_params() {
        let supported = float_dtypes();
        let check = |n_samples: usize, dimension: usize| {
            validate_latin_hypercube_params(
                n_samples,
                dimension,
                DType::F32,
                &supported,
                "latin_hypercube",
            )
        };
        assert!(check(10, 5).is_ok());
        assert!(check(0, 5).is_err());
        assert!(check(10, 0).is_err());
    }

    #[test]
    fn test_dimension_zero_vectors() {
        let v = dimension_zero_vectors();
        assert_eq!(v.len(), SOBOL_BITS);
        // Entry i must be a single bit located i places from the top.
        for (i, &bits) in v.iter().enumerate() {
            assert_eq!(bits.count_ones(), 1);
            assert_eq!(bits.leading_zeros() as usize, i);
        }
    }

    #[test]
    #[cfg(any(feature = "cuda", feature = "wgpu"))]
    fn test_compute_all_direction_vectors_length() {
        let dim = 5;
        assert_eq!(compute_all_direction_vectors(dim).len(), dim * SOBOL_BITS);
    }

    #[test]
    #[cfg(any(feature = "cuda", feature = "wgpu"))]
    fn test_compute_all_direction_vectors_dimension_0() {
        let vectors = compute_all_direction_vectors(3);
        assert_eq!(&vectors[..SOBOL_BITS], &dimension_zero_vectors()[..]);
    }
}