use oxicuda_ptx::builder::BodyBuilder;
use oxicuda_ptx::ir::{PtxType, Register};
/// Emits PTX for a single-precision Box-Muller transform.
///
/// From two registers holding uniform samples, the generated instructions
/// compute
///
/// ```text
/// r  = sqrt(-2 * ln(max(u1, 2^-24)))
/// z0 = r * cos(2*pi * u2)
/// z1 = r * sin(2*pi * u2)
/// ```
///
/// The natural logarithm is formed as `lg2(u1) * ln(2)`, and the `.approx`
/// forms of `lg2`, `sqrt`, `cos` and `sin` are used throughout.
///
/// # Parameters
/// - `b`: body builder that receives the instructions.
/// - `u1`, `u2`: input registers; every emitted instruction is typed `.f32`,
///   so both are expected to be F32 registers.
///
/// NOTE(review): `u1` is clamped from below only. Callers presumably supply
/// values no greater than 1 so the square-root argument stays non-negative —
/// confirm against call sites.
///
/// # Returns
/// The registers holding `(z0, z1)`.
#[allow(dead_code)]
pub fn emit_box_muller_f32(
    b: &mut BodyBuilder<'_>,
    u1: Register,
    u2: Register,
) -> (Register, Register) {
    // Allocates a fresh F32 destination register, then emits the single
    // instruction that `render` builds for it. Allocation happens before the
    // instruction is emitted, one register per instruction.
    fn emit_f32(
        b: &mut BodyBuilder<'_>,
        render: impl FnOnce(&Register) -> String,
    ) -> Register {
        let dst = b.alloc_reg(PtxType::F32);
        let instruction = render(&dst);
        b.raw_ptx(&instruction);
        dst
    }

    // 0f33800000 is 2^-24: keeps the logarithm's argument strictly positive.
    let floor = emit_f32(b, |dst| format!("mov.f32 {dst}, 0f33800000;"));
    let clamped_u1 = b.max_f32(u1, floor);

    // ln(u1) = lg2(u1) * ln(2); 0f3F317218 is ln(2).
    let log2_u1 = emit_f32(b, |dst| format!("lg2.approx.f32 {dst}, {clamped_u1};"));
    let ln_two = emit_f32(b, |dst| format!("mov.f32 {dst}, 0f3F317218;"));
    let natural_log = emit_f32(b, |dst| format!("mul.rn.f32 {dst}, {log2_u1}, {ln_two};"));

    // r = sqrt(-2 * ln(u1)); 0fC0000000 is -2.0.
    let minus_two = emit_f32(b, |dst| format!("mov.f32 {dst}, 0fC0000000;"));
    let radicand = emit_f32(b, |dst| format!("mul.rn.f32 {dst}, {minus_two}, {natural_log};"));
    let r = emit_f32(b, |dst| format!("sqrt.approx.f32 {dst}, {radicand};"));

    // theta = 2*pi * u2; 0f40C90FDB is 2*pi.
    let tau = emit_f32(b, |dst| format!("mov.f32 {dst}, 0f40C90FDB;"));
    let theta = emit_f32(b, |dst| format!("mul.rn.f32 {dst}, {tau}, {u2};"));

    // Project the radius onto the cosine and sine of the same angle.
    let cosine = emit_f32(b, |dst| format!("cos.approx.f32 {dst}, {theta};"));
    let first = emit_f32(b, |dst| format!("mul.rn.f32 {dst}, {r}, {cosine};"));
    let sine = emit_f32(b, |dst| format!("sin.approx.f32 {dst}, {theta};"));
    let second = emit_f32(b, |dst| format!("mul.rn.f32 {dst}, {r}, {sine};"));

    (first, second)
}
/// Emits a Box-Muller transform whose two outputs are widened to F64.
///
/// The transform itself is delegated to [`emit_box_muller_f32`], so `u1` and
/// `u2` are handed to the single-precision emitter unchanged; each result is
/// then converted with the builder's `cvt_f32_to_f64`.
///
/// # Returns
/// The two converted registers, in the same order the single-precision
/// emitter returned them.
#[allow(dead_code)]
pub fn emit_box_muller_f64(
    b: &mut BodyBuilder<'_>,
    u1: Register,
    u2: Register,
) -> (Register, Register) {
    let single = emit_box_muller_f32(b, u1, u2);

    // Convert in order: first output, then second.
    let wide_first = b.cvt_f32_to_f64(single.0);
    let wide_second = b.cvt_f32_to_f64(single.1);

    (wide_first, wide_second)
}
/// Emits the scaling step that maps a sample `z` using `stddev` and `mean`.
///
/// For `PtxType::F32` and `PtxType::F64` this issues the builder's fused
/// multiply-add of the matching width with operands `(stddev, z, mean)` —
/// presumably `stddev * z + mean` — and returns its result register.
///
/// For any other `precision` no instruction is emitted and `z` is returned
/// unchanged.
#[allow(dead_code)]
pub fn emit_normal_scale(
    b: &mut BodyBuilder<'_>,
    z: Register,
    mean: Register,
    stddev: Register,
    precision: PtxType,
) -> Register {
    if matches!(precision, PtxType::F32) {
        return b.fma_f32(stddev, z, mean);
    }
    if matches!(precision, PtxType::F64) {
        return b.fma_f64(stddev, z, mean);
    }

    // Unsupported precision: pass the sample through untouched.
    z
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_ptx::arch::SmVersion;
    use oxicuda_ptx::builder::KernelBuilder;

    // Loads the `u1` and `u2` u32 kernel parameters and converts each into a
    // freshly allocated F32 register. Both loads are issued before either
    // conversion.
    fn load_uniform_inputs(b: &mut BodyBuilder<'_>) -> (Register, Register) {
        let raw_first = b.load_param_u32("u1");
        let raw_second = b.load_param_u32("u2");

        let first = b.alloc_reg(PtxType::F32);
        b.raw_ptx(&format!("cvt.rn.f32.u32 {first}, {raw_first};"));
        let second = b.alloc_reg(PtxType::F32);
        b.raw_ptx(&format!("cvt.rn.f32.u32 {second}, {raw_second};"));

        (first, second)
    }

    // The f32 emitter must build and must contain each approximate
    // transcendental instruction it relies on.
    #[test]
    fn box_muller_f32_compiles() {
        let ptx = KernelBuilder::new("test_bm_f32")
            .target(SmVersion::Sm80)
            .param("u1", PtxType::U32)
            .param("u2", PtxType::U32)
            .body(|b| {
                let (u1, u2) = load_uniform_inputs(b);
                let (_z0, _z1) = emit_box_muller_f32(b, u1, u2);
                b.ret();
            })
            .build()
            .expect("should compile");

        assert!(ptx.contains("lg2.approx.f32"));
        assert!(ptx.contains("cos.approx.f32"));
        assert!(ptx.contains("sin.approx.f32"));
        assert!(ptx.contains("sqrt.approx.f32"));
    }

    // The f64 emitter must build and must contain the f32 -> f64 conversion.
    #[test]
    fn box_muller_f64_compiles() {
        let ptx = KernelBuilder::new("test_bm_f64")
            .target(SmVersion::Sm80)
            .param("u1", PtxType::U32)
            .param("u2", PtxType::U32)
            .body(|b| {
                let (u1, u2) = load_uniform_inputs(b);
                let (_z0, _z1) = emit_box_muller_f64(b, u1, u2);
                b.ret();
            })
            .build()
            .expect("should compile");

        assert!(ptx.contains("cvt.f64.f32"));
    }

    // Scaling with F32 operands only needs to build successfully.
    #[test]
    fn normal_scale_f32_compiles() {
        let built = KernelBuilder::new("test_nscale_f32")
            .target(SmVersion::Sm80)
            .param("z", PtxType::F32)
            .param("mean", PtxType::F32)
            .param("stddev", PtxType::F32)
            .body(|b| {
                let z = b.load_param_f32("z");
                let mean = b.load_param_f32("mean");
                let stddev = b.load_param_f32("stddev");
                let _result = emit_normal_scale(b, z, mean, stddev, PtxType::F32);
                b.ret();
            })
            .build();

        assert!(built.is_ok());
    }
}