use oxicuda_ptx::builder::BodyBuilder;
use oxicuda_ptx::ir::{PtxType, Register};
/// Lambda threshold value (`30.0`) for Poisson sampling.
///
/// NOTE(review): nothing in this part of the file reads this constant. It is
/// presumably the cutover at which callers switch from
/// `emit_poisson_small_f32` to `emit_poisson_large_f32` — confirm at the call
/// sites.
// `allow(dead_code)`: the constant has no user in the visible code.
#[allow(dead_code)]
pub const POISSON_LAMBDA_THRESHOLD: f32 = 30.0;
/// Emits PTX for a Poisson draw using the product-of-uniforms loop.
///
/// The generated code computes `limit = 2^(-lambda * log2(e))` (that is,
/// `exp(-lambda)` through the approximate `ex2` instruction), then repeatedly
/// increments a counter and multiplies a running product (starting at `1.0`)
/// by a fresh value from `uniform_gen` while the product stays strictly
/// greater than `limit`. The sample is the counter minus one.
///
/// # Parameters
/// * `b` - body builder that receives the instructions.
/// * `lambda_reg` - register holding the rate as `f32`.
/// * `uniform_gen` - callback that emits the code producing one random value
///   and returns its register. It is invoked once, at emission time, between
///   the loop label and the back-edge, so the code it emits runs on every
///   iteration. Presumably the value is uniform on the unit interval —
///   confirm against callers.
///
/// # Returns
/// A `U32` register holding the sampled count.
// `allow(dead_code)`: no caller outside the tests is visible in this file.
#[allow(dead_code)]
pub fn emit_poisson_small_f32<F>(
    b: &mut BodyBuilder<'_>,
    lambda_reg: Register,
    uniform_gen: F,
) -> Register
where
    F: Fn(&mut BodyBuilder<'_>) -> Register,
{
    // 0f3FB8AA3B is the f32 bit pattern of log2(e) ~= 1.442695.
    let log2_of_e = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!("mov.f32 {log2_of_e}, 0f3FB8AA3B;"));
    let minus_lambda = b.neg_f32(lambda_reg);

    // exponent = -lambda * log2(e); threshold = 2^exponent.
    let exponent = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!(
        "mul.rn.f32 {exponent}, {minus_lambda}, {log2_of_e};"
    ));
    let threshold = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!("ex2.approx.f32 {threshold}, {exponent};"));

    // Loop state: iteration counter starts at 0, running product at 1.0
    // (0f3F800000).
    let draws = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("mov.u32 {draws}, 0;"));
    let product = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!("mov.f32 {product}, 0f3F800000;"));

    let loop_head = b.fresh_label("poisson_loop");
    let loop_exit = b.fresh_label("poisson_end");

    // Loop body: count the draw, fold it into the product, and jump back
    // while the product is still above the threshold.
    b.label(&loop_head);
    b.raw_ptx(&format!("add.u32 {draws}, {draws}, 1;"));
    let sample = uniform_gen(b);
    b.raw_ptx(&format!("mul.rn.f32 {product}, {product}, {sample};"));
    let keep_going = b.alloc_reg(PtxType::Pred);
    b.raw_ptx(&format!(
        "setp.gt.f32 {keep_going}, {product}, {threshold};"
    ));
    b.branch_if(keep_going, &loop_head);

    // The exit label is emitted as a fall-through marker; nothing in this
    // function branches to it.
    b.label(&loop_exit);

    // The draw that crossed the threshold is not counted: result = draws - 1.
    let count = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("sub.u32 {count}, {draws}, 1;"));
    count
}
/// Emits PTX for a Poisson draw derived from a single normal variate.
///
/// The generated code takes the approximate square root of the rate, combines
/// it with `normal_reg` and `lambda_reg` through `fma_f32` (presumably
/// `sqrt(lambda) * z + lambda` — confirm the builder's operand order), clamps
/// the result against `0.0` with `max_f32`, rounds it to the nearest integer
/// while still in `f32`, and finally converts it to `u32`.
///
/// # Parameters
/// * `b` - body builder that receives the instructions.
/// * `lambda_reg` - register holding the rate as `f32`.
/// * `normal_reg` - register holding the `f32` variate; presumably a standard
///   normal sample — confirm against callers.
///
/// # Returns
/// A `U32` register holding the sampled count.
// `allow(dead_code)`: no caller outside the tests is visible in this file.
#[allow(dead_code)]
pub fn emit_poisson_large_f32(
    b: &mut BodyBuilder<'_>,
    lambda_reg: Register,
    normal_reg: Register,
) -> Register {
    let root_lambda = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!("sqrt.approx.f32 {root_lambda}, {lambda_reg};"));
    let estimate = b.fma_f32(root_lambda, normal_reg, lambda_reg);

    // Clamp against 0.0 (0f00000000) before the unsigned conversion.
    let floor_value = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!("mov.f32 {floor_value}, 0f00000000;"));
    let non_negative = b.max_f32(estimate, floor_value);

    // Round to nearest-even in f32 first, then convert the already-integral
    // value to u32.
    let nearest = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!("cvt.rni.f32.f32 {nearest}, {non_negative};"));
    let count = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("cvt.rzi.u32.f32 {count}, {nearest};"));
    count
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_ptx::arch::SmVersion;
    use oxicuda_ptx::builder::KernelBuilder;

    /// Builds a kernel around the small-lambda emitter and checks that the
    /// exp2 and compare instructions appear in the output.
    #[test]
    fn poisson_small_compiles() {
        let kernel_text = KernelBuilder::new("test_poisson_small")
            .target(SmVersion::Sm80)
            .param("lambda", PtxType::F32)
            .body(|builder| {
                let rate = builder.load_param_f32("lambda");
                // Stand-in random source: always yields the constant 0.5
                // (0f3F000000).
                let _sample = emit_poisson_small_f32(builder, rate, |inner| {
                    let one_half = inner.alloc_reg(PtxType::F32);
                    inner.raw_ptx(&format!("mov.f32 {one_half}, 0f3F000000;"));
                    one_half
                });
                builder.ret();
            })
            .build()
            .expect("should compile");

        assert!(kernel_text.contains("ex2.approx.f32"));
        assert!(kernel_text.contains("setp.gt.f32"));
    }

    /// Builds a kernel around the large-lambda emitter and checks that the
    /// square-root and rounding instructions appear in the output.
    #[test]
    fn poisson_large_compiles() {
        let kernel_text = KernelBuilder::new("test_poisson_large")
            .target(SmVersion::Sm80)
            .param("lambda", PtxType::F32)
            .param("z", PtxType::F32)
            .body(|builder| {
                let rate = builder.load_param_f32("lambda");
                let variate = builder.load_param_f32("z");
                let _sample = emit_poisson_large_f32(builder, rate, variate);
                builder.ret();
            })
            .build()
            .expect("should compile");

        assert!(kernel_text.contains("sqrt.approx.f32"));
        assert!(kernel_text.contains("cvt.rni.f32.f32"));
    }
}