use oxicuda_ptx::builder::BodyBuilder;
use oxicuda_ptx::ir::{PtxType, Register};
/// Emits PTX that exponentiates an `f32` register, producing `e^normal_val`.
///
/// Applying this to a normally distributed value yields a log-normally
/// distributed one. The natural exponential is rewritten as
/// `2^(normal_val * log2(e))` and evaluated with the `ex2.approx.f32`
/// instruction, so the result carries that instruction's approximation error.
///
/// Three `f32` registers are allocated, in this order:
/// 1. the log2(e) constant;
/// 2. the scaled exponent;
/// 3. the result.
///
/// The last one is returned.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file —
// presumably there is no in-crate caller yet; confirm before removing it.
#[allow(dead_code)]
pub fn emit_log_normal_transform_f32(b: &mut BodyBuilder<'_>, normal_val: Register) -> Register {
    // log2(e) ≈ 1.442695 as an IEEE-754 single-precision hex float (PTX `0f` syntax).
    const LOG2_E_F32: &str = "0f3FB8AA3B";

    // Materialise log2(e) in a register.
    let log2_e_reg = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!("mov.f32 {log2_e_reg}, {LOG2_E_F32};"));

    // Convert the natural-base exponent to a base-2 exponent.
    let base2_exponent = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!(
        "mul.rn.f32 {base2_exponent}, {normal_val}, {log2_e_reg};"
    ));

    // 2^(z * log2 e) == e^z.
    let exp_z = b.alloc_reg(PtxType::F32);
    b.raw_ptx(&format!("ex2.approx.f32 {exp_z}, {base2_exponent};"));

    exp_z
}
/// Emits PTX that exponentiates an `f64` register, producing `e^normal_val` as `f64`.
///
/// The computation is not done in double precision. The input is narrowed to
/// `f32`, passed through [`emit_log_normal_transform_f32`], and the result is
/// widened back to `f64`.
///
/// Consequences of the f32 round trip:
/// - The result carries only f32 precision plus the `ex2.approx.f32` error.
/// - The result is limited to the f32 range. Inputs above roughly 88.7 saturate
///   to `+inf`, and very negative inputs underflow towards `0`, even though
///   `f64` could represent those results.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file —
// presumably there is no in-crate caller yet; confirm before removing it.
#[allow(dead_code)]
pub fn emit_log_normal_transform_f64(b: &mut BodyBuilder<'_>, normal_val: Register) -> Register {
    let exp_single = {
        // f64 -> f32 narrowing; from here on precision and range are single.
        let single_in = b.cvt_f64_to_f32(normal_val);
        emit_log_normal_transform_f32(b, single_in)
    };
    // Widen the single-precision result back to the caller's f64 type.
    b.cvt_f32_to_f64(exp_single)
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_ptx::arch::SmVersion;
    use oxicuda_ptx::builder::KernelBuilder;

    #[test]
    fn log_normal_f32_compiles() {
        let ptx = KernelBuilder::new("test_lognormal_f32")
            .target(SmVersion::Sm80)
            .param("z", PtxType::F32)
            .body(|b| {
                let z = b.load_param_f32("z");
                let _result = emit_log_normal_transform_f32(b, z);
                b.ret();
            })
            .build()
            .expect("should compile");

        // Every step of e^z = 2^(z * log2 e) must be emitted:
        // the constant, the scaling multiply and the base-2 exponential.
        assert!(ptx.contains("0f3FB8AA3B"));
        assert!(ptx.contains("mul.rn.f32"));
        assert!(ptx.contains("ex2.approx.f32"));
    }

    #[test]
    fn log_normal_f64_compiles() {
        let ptx = KernelBuilder::new("test_lognormal_f64")
            .target(SmVersion::Sm80)
            .param("z", PtxType::F64)
            .body(|b| {
                let z = b.load_param_f64("z");
                let _result = emit_log_normal_transform_f64(b, z);
                b.ret();
            })
            .build()
            .expect("should compile");

        // The f64 path narrows, runs the f32 exponential core, then widens —
        // check that all three parts exist and appear in that order.
        let narrow = ptx
            .find("cvt.rn.f32.f64")
            .expect("f64 -> f32 narrowing should be emitted");
        let exp = ptx
            .find("ex2.approx.f32")
            .expect("f32 exponential core should be emitted");
        let widen = ptx
            .find("cvt.f64.f32")
            .expect("f32 -> f64 widening should be emitted");
        assert!(
            narrow < exp && exp < widen,
            "expected narrow -> ex2 -> widen ordering"
        );
    }
}