use std::sync::Arc;
use oxicuda_driver::context::Context;
use oxicuda_driver::module::Module;
use oxicuda_driver::stream::Stream;
use oxicuda_launch::grid::grid_size_for;
use oxicuda_launch::kernel::Kernel;
use oxicuda_launch::params::LaunchParams;
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::error::PtxGenError;
use oxicuda_ptx::ir::PtxType;
use crate::error::{RandError, RandResult};
/// Largest Sobol dimension accepted by `compute_direction_numbers` and
/// `SobolGenerator::new`; requests outside `1..=MAX_SOBOL_DIMENSION` are
/// rejected with `RandError::InvalidSize`.
pub const MAX_SOBOL_DIMENSION: u32 = 20;
/// Number of direction numbers stored per dimension. It is the length of every
/// direction-number array in this file and the unroll count of the generated
/// kernel's bit loop.
const DIRECTION_BITS: usize = 32;
/// Returns the direction numbers used for the first Sobol dimension.
///
/// Entry `i` has exactly one bit set, at position `31 - i`: the first entry is
/// `0x8000_0000` and each following entry is the previous one shifted right
/// by one bit.
#[allow(dead_code)]
fn van_der_corput_directions() -> [u32; DIRECTION_BITS] {
    let mut table = [0u32; DIRECTION_BITS];
    // Walk a single set bit down from the most significant position.
    let mut bit = 1u32 << 31;
    for slot in table.iter_mut() {
        *slot = bit;
        bit >>= 1;
    }
    table
}
/// Initialisation data for Sobol dimensions 2 and above.
///
/// Row `k` describes dimension `k + 2` and is read by
/// `compute_direction_numbers` as `[degree, poly_coeff, m_1, m_2]`:
///
/// * `degree` — degree of the recurrence polynomial; also the number of
///   direction numbers that are seeded directly rather than derived.
/// * `poly_coeff` — bit mask of the polynomial's interior coefficients.
/// * `m_1`, `m_2` — the first two initial numbers. `m_2` is only read when
///   `degree >= 2`; seeds beyond the second are synthesised by the consumer.
///
/// Every initial number that is actually read must be odd and satisfy
/// `m_k < 2^k`. In particular it must not be zero: with `degree == 1` a zero
/// `m_1` makes the recurrence produce only zeros, so the whole dimension
/// collapses to the constant 0, and a zero `m_2` leaves a zero direction
/// number so that distinct indices map to the same point.
///
/// NOTE(review): the `(degree, poly_coeff)` pairs appear to follow the usual
/// published primitive-polynomial ordering — confirm against the reference
/// table this was transcribed from.
#[allow(dead_code)]
static SOBOL_INIT_NUMBERS: &[[u32; 4]] = &[
    // degree 1 — `m_2` is unused, `m_1` must be 1.
    [1, 0, 1, 0],
    // degree 2
    [2, 1, 1, 3],
    // degree 3
    [3, 1, 1, 1],
    [3, 2, 1, 3],
    // degree 4
    [4, 1, 1, 1],
    [4, 4, 1, 3],
    // degree 5
    [5, 2, 1, 1],
    [5, 4, 1, 3],
    [5, 7, 1, 1],
    [5, 11, 1, 3],
    [5, 13, 1, 1],
    [5, 14, 1, 3],
    // degree 6
    [6, 1, 1, 1],
    [6, 13, 1, 3],
    [6, 16, 1, 1],
    // degree 7
    [7, 1, 1, 1],
    [7, 4, 1, 3],
    [7, 7, 1, 1],
    [7, 8, 1, 3],
];
/// Computes the `DIRECTION_BITS` direction numbers for one Sobol dimension.
///
/// `dimension` is 1-based. Dimension 1 returns the table produced by
/// `van_der_corput_directions`; dimension `d >= 2` is derived from row
/// `d - 2` of `SOBOL_INIT_NUMBERS`.
///
/// For a table row `[degree, poly_coeff, m_1, m_2]` the first `degree`
/// entries are seeded as `m << (31 - i)`, taking `m` from the row for
/// `i < 2` and using `1 | 2 * i` afterwards. Every later entry is the XOR of
/// the entry `degree` places back, that same entry shifted right by `degree`,
/// and each entry `j` places back (`1 <= j < degree`) whose coefficient bit
/// `degree - 1 - j` is set in `poly_coeff`.
///
/// # Errors
///
/// Returns `RandError::InvalidSize` when `dimension` is outside
/// `1..=MAX_SOBOL_DIMENSION`, or when the table has no row for it.
pub(crate) fn compute_direction_numbers(dimension: u32) -> RandResult<[u32; DIRECTION_BITS]> {
    if !(1..=MAX_SOBOL_DIMENSION).contains(&dimension) {
        return Err(RandError::InvalidSize(format!(
            "Sobol dimension must be 1..={MAX_SOBOL_DIMENSION}, got {dimension}"
        )));
    }
    if dimension == 1 {
        return Ok(van_der_corput_directions());
    }

    // Row 0 of the table belongs to dimension 2.
    let row = match SOBOL_INIT_NUMBERS.get((dimension - 2) as usize) {
        Some(row) => row,
        None => {
            return Err(RandError::InvalidSize(format!(
                "Sobol dimension {dimension} exceeds supported range"
            )));
        }
    };
    let degree = row[0] as usize;
    let poly_coeff = row[1];

    let mut table = [0u32; DIRECTION_BITS];

    // Seed the leading `degree` entries from the initial numbers.
    for (pos, slot) in table.iter_mut().enumerate().take(degree) {
        let initial = match pos {
            0 | 1 => row[2 + pos],
            _ => (2 * pos as u32) | 1,
        };
        *slot = initial << (31 - pos);
    }

    // Derive the remaining entries from the recurrence.
    for pos in degree..DIRECTION_BITS {
        let anchor = table[pos - degree];
        let seed = anchor ^ (anchor >> degree);
        let value = (1..degree)
            .filter(|&back| (poly_coeff >> (degree - 1 - back)) & 1 == 1)
            .fold(seed, |acc, back| acc ^ table[pos - back]);
        table[pos] = value;
    }

    Ok(table)
}
/// Builds the PTX source of the `sobol_generate` kernel for the target `sm`.
///
/// Kernel parameters, in declaration order: `out_ptr` (u64), `dir_ptr` (u64),
/// `n_points` (u32) and `base_index` (u32). The kernel is declared with a
/// maximum of 256 threads per block.
///
/// For each thread whose global x id `gid` is below `n_points`, the emitted
/// code:
/// 1. forms the sequence index `base_index + gid` and its Gray code
///    `index ^ (index >> 1)`;
/// 2. for each of the `DIRECTION_BITS` bit positions, loads the 32-bit word at
///    `dir_ptr + 4 * bit` and XORs it into the accumulator when that bit of
///    the Gray code is set;
/// 3. converts the accumulator to `f32`, multiplies it by the constant
///    `0f2F800000` (2^-32) and stores it at the address obtained from
///    `byte_offset_addr(out_ptr, gid, 4)`.
///
/// Returns the PTX text, or the `PtxGenError` produced by the builder.
///
/// NOTE(review): `cvt.rn` rounds to nearest, so the largest accumulator values
/// convert to 2^32 and the stored result can be exactly 1.0 — confirm whether
/// callers expect a half-open `[0, 1)` range.
#[allow(dead_code)]
fn generate_sobol_ptx(sm: SmVersion) -> Result<String, PtxGenError> {
KernelBuilder::new("sobol_generate")
.target(sm)
.param("out_ptr", PtxType::U64)
.param("dir_ptr", PtxType::U64)
.param("n_points", PtxType::U32)
.param("base_index", PtxType::U32)
.max_threads_per_block(256)
.body(move |b| {
let gid = b.global_thread_id_x();
let n_reg = b.load_param_u32("n_points");
// Bounds guard: only threads with gid < n_points run the body below.
b.if_lt_u32(gid.clone(), n_reg, move |b| {
let out_ptr = b.load_param_u64("out_ptr");
let dir_ptr = b.load_param_u64("dir_ptr");
let base_index = b.load_param_u32("base_index");
// Sequence index of this thread; the addition is 32-bit.
let index = b.add_u32(base_index, gid.clone());
// Gray code of the index: index ^ (index >> 1).
let shifted = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shr.u32 {shifted}, {index}, 1;"));
let gray = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("xor.b32 {gray}, {index}, {shifted};"));
// Accumulator for the XOR of the selected direction numbers.
let result = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {result}, 0;"));
// One group of instructions per bit position (presumably `unroll` calls the
// closure once for each `bit_idx` in `0..DIRECTION_BITS` — confirm in the builder).
b.unroll(DIRECTION_BITS as u32, |b, bit_idx| {
// Predicate: is bit `bit_idx` of the Gray code set?
let bit_pred = b.alloc_reg(PtxType::Pred);
let mask = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {mask}, {};", 1u32 << bit_idx));
let masked = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("and.b32 {masked}, {gray}, {mask};"));
b.raw_ptx(&format!("setp.ne.u32 {bit_pred}, {masked}, 0;"));
// Address of direction number `bit_idx`: 4 bytes per entry.
let dir_offset = (bit_idx as u64) * 4; let dir_addr = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("add.u64 {dir_addr}, {dir_ptr}, {dir_offset};"));
// The load itself is unconditional; only the final move is predicated.
let dir_val = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("ld.global.u32 {dir_val}, [{dir_addr}];"));
let xored = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("xor.b32 {xored}, {result}, {dir_val};"));
// Commit the XOR only when the Gray-code bit is set.
b.raw_ptx(&format!("@{bit_pred} mov.u32 {result}, {xored};"));
});
// Convert the 32-bit accumulator to f32 and scale it.
let fval = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {fval}, {result};"));
let scale = b.alloc_reg(PtxType::F32);
// 0f2F800000 is the f32 bit pattern of 2^-32.
b.raw_ptx(&format!("mov.f32 {scale}, 0f2F800000;")); let fresult = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {fresult}, {fval}, {scale};"));
// Output slot is addressed by gid (not by the sequence index), 4 bytes per element.
let addr = b.byte_offset_addr(out_ptr, gid.clone(), 4);
b.store_global_f32(addr, fresult);
});
b.ret();
})
.build()
}
/// Sobol quasi-random number generator that produces points on a CUDA device.
///
/// Created with `SobolGenerator::new`; `generate` launches the
/// `sobol_generate` kernel and advances the running point counter.
pub struct SobolGenerator {
// Number of Sobol dimensions requested at construction.
#[allow(dead_code)]
dimension: u32,
// Points produced so far; used as the kernel's `base_index` and cleared by `reset`.
#[allow(dead_code)]
n_generated: u64,
// One direction-number table per dimension, in dimension order (index 0 = dimension 1).
#[allow(dead_code)]
direction_numbers: Vec<[u32; DIRECTION_BITS]>,
// Shared handle to the context passed to `new`.
#[allow(dead_code)]
context: Arc<Context>,
// Stream created in `new`; kernels are launched on it and it is synchronized after each launch.
#[allow(dead_code)]
stream: Stream,
// PTX target passed to `generate_sobol_ptx`; `new` always sets `SmVersion::Sm80`.
#[allow(dead_code)]
sm_version: SmVersion,
}
impl SobolGenerator {
pub fn new(dimension: u32, ctx: &Arc<Context>) -> RandResult<Self> {
if dimension == 0 || dimension > MAX_SOBOL_DIMENSION {
return Err(RandError::InvalidSize(format!(
"Sobol dimension must be 1..={MAX_SOBOL_DIMENSION}, got {dimension}"
)));
}
let mut dir_numbers = Vec::with_capacity(dimension as usize);
for d in 1..=dimension {
dir_numbers.push(compute_direction_numbers(d)?);
}
let stream = Stream::new(ctx).map_err(RandError::Cuda)?;
Ok(Self {
dimension,
n_generated: 0,
direction_numbers: dir_numbers,
context: Arc::clone(ctx),
stream,
sm_version: SmVersion::Sm80,
})
}
pub fn generate(&mut self, output: &mut DeviceBuffer<f32>, n_points: u32) -> RandResult<()> {
if output.len() < n_points as usize {
return Err(RandError::InvalidSize(format!(
"output buffer has {} elements but {} requested",
output.len(),
n_points
)));
}
let ptx_source = generate_sobol_ptx(self.sm_version)?;
let module = Arc::new(Module::from_ptx(&ptx_source).map_err(RandError::Cuda)?);
let kernel = Kernel::from_module(module, "sobol_generate").map_err(RandError::Cuda)?;
let dirs = &self.direction_numbers[0];
let dir_buf = DeviceBuffer::<u32>::from_host(dirs).map_err(RandError::Cuda)?;
let base_index = self.n_generated as u32;
let grid = grid_size_for(n_points, 256);
let params = LaunchParams::new(grid, 256u32);
let args = (
output.as_device_ptr(),
dir_buf.as_device_ptr(),
n_points,
base_index,
);
kernel
.launch(¶ms, &self.stream, &args)
.map_err(RandError::Cuda)?;
self.stream.synchronize().map_err(RandError::Cuda)?;
self.n_generated += n_points as u64;
Ok(())
}
#[allow(dead_code)]
pub fn dimension(&self) -> u32 {
self.dimension
}
#[allow(dead_code)]
pub fn points_generated(&self) -> u64 {
self.n_generated
}
#[allow(dead_code)]
pub fn reset(&mut self) {
self.n_generated = 0;
}
}
/// Returns the position of the lowest zero bit of `n`, i.e. the number of
/// consecutive one bits at the low end of `n`.
///
/// Returns 32 when every bit of `n` is set.
#[allow(dead_code)]
pub fn gray_code_rank(n: u32) -> u32 {
    n.trailing_ones()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks `gray_code_rank` against hand-computed values.
    #[test]
    fn gray_code_rank_values() {
        let cases: [(u32, u32); 5] = [(0, 0), (1, 1), (2, 0), (3, 2), (7, 3)];
        for &(input, expected) in cases.iter() {
            assert_eq!(gray_code_rank(input), expected);
        }
    }

    /// Entry `i` of the first-dimension table must be the single bit `31 - i`.
    #[test]
    fn van_der_corput_directions_are_powers_of_two() {
        let table = van_der_corput_directions();
        let expected = (0..DIRECTION_BITS).map(|i| 1u32 << (31 - i));
        for (&actual, want) in table.iter().zip(expected) {
            assert_eq!(actual, want);
        }
    }

    /// Dimension 1 is served by the van der Corput table.
    #[test]
    fn compute_dimension_1_is_van_der_corput() {
        assert_eq!(
            compute_direction_numbers(1).expect("dim 1 should succeed"),
            van_der_corput_directions()
        );
    }

    /// Dimensions just outside the supported range are rejected.
    #[test]
    fn compute_dimension_out_of_range() {
        for &bad in [0, MAX_SOBOL_DIMENSION + 1].iter() {
            assert!(compute_direction_numbers(bad).is_err());
        }
    }

    /// The generated PTX declares the kernel entry point and uses XOR.
    #[test]
    fn sobol_ptx_generates() {
        let ptx = generate_sobol_ptx(SmVersion::Sm80).expect("should generate Sobol PTX");
        assert!(ptx.contains(".entry sobol_generate"));
        assert!(ptx.contains("xor.b32"));
    }

    /// Every dimension up to the advertised maximum can be computed.
    #[test]
    fn max_dimension_computable() {
        for d in 1..=MAX_SOBOL_DIMENSION {
            assert!(
                compute_direction_numbers(d).is_ok(),
                "dimension {d} should compute"
            );
        }
    }
}