use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::error::PtxGenError;
use oxicuda_ptx::ir::PtxType;
// MRG32k3a parameters. Component 1 works modulo M1 with multiplier A12 on
// x1_{n-2} and A13N subtracted on x1_{n-3}; component 2 works modulo M2 with
// multiplier A21 on x2_{n-1} and A23N subtracted on x2_{n-3}.
// NOTE(review): `#[allow(dead_code)]` is kept as found; every constant is
// referenced by non-test code in this file, so the attributes may be removable.

/// Modulus of component 1: 2^32 - 209.
#[allow(dead_code)]
const M1: u64 = (1_u64 << 32) - 209;
/// Modulus of component 2: 2^32 - 22853.
#[allow(dead_code)]
const M2: u64 = (1_u64 << 32) - 22_853;
/// Multiplier applied to x1_{n-2} in component 1.
#[allow(dead_code)]
const A12: u64 = 1_403_580;
/// Magnitude of the negative multiplier applied to x1_{n-3} in component 1.
#[allow(dead_code)]
const A13N: u64 = 810_728;
/// Multiplier applied to x2_{n-1} in component 2.
#[allow(dead_code)]
const A21: u64 = 527_612;
/// Magnitude of the negative multiplier applied to x2_{n-3} in component 2.
#[allow(dead_code)]
const A23N: u64 = 1_370_589;
pub fn generate_mrg32k3a_uniform_ptx(
precision: PtxType,
sm: SmVersion,
) -> Result<String, PtxGenError> {
let kernel_name = match precision {
PtxType::F32 => "mrg32k3a_uniform_f32",
PtxType::F64 => "mrg32k3a_uniform_f64",
_ => return Err(PtxGenError::InvalidType(format!("{precision:?}"))),
};
let stride_bytes: u32 = precision.size_bytes() as u32;
KernelBuilder::new(kernel_name)
.target(sm)
.param("out_ptr", PtxType::U64)
.param("n", PtxType::U32)
.param("seed", PtxType::U32)
.param("offset_lo", PtxType::U32)
.param("offset_hi", PtxType::U32)
.max_threads_per_block(256)
.body(move |b| {
let gid = b.global_thread_id_x();
let n_reg = b.load_param_u32("n");
b.if_lt_u32(gid.clone(), n_reg, move |b| {
let out_ptr = b.load_param_u64("out_ptr");
let seed = b.load_param_u32("seed");
let s10 = b.alloc_reg(PtxType::U32);
let s11 = b.alloc_reg(PtxType::U32);
let s12 = b.alloc_reg(PtxType::U32);
let s20 = b.alloc_reg(PtxType::U32);
let s21 = b.alloc_reg(PtxType::U32);
let s22 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("xor.b32 {s10}, {seed}, {gid};"));
let scr1 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr1}, {gid}, 1812433253;"));
b.raw_ptx(&format!("xor.b32 {s11}, {seed}, {scr1};"));
let scr2 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr2}, {gid}, 1566083941;"));
b.raw_ptx(&format!("xor.b32 {s12}, {seed}, {scr2};"));
let scr3 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr3}, {gid}, 1103515245;"));
b.raw_ptx(&format!("xor.b32 {s20}, {seed}, {scr3};"));
let scr4 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr4}, {gid}, 214013;"));
b.raw_ptx(&format!("xor.b32 {s21}, {seed}, {scr4};"));
let scr5 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr5}, {gid}, 2531011;"));
b.raw_ptx(&format!("xor.b32 {s22}, {seed}, {scr5};"));
emit_clamp_nonzero(b, &s10);
emit_clamp_nonzero(b, &s11);
emit_clamp_nonzero(b, &s12);
emit_clamp_nonzero(b, &s20);
emit_clamp_nonzero(b, &s21);
emit_clamp_nonzero(b, &s22);
emit_mrg32k3a_step(b, &s10, &s11, &s12, &s20, &s21, &s22);
let diff = b.alloc_reg(PtxType::U32);
let m1_const = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {m1_const}, {};", M1 as u32));
let pred_ge = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.ge.u32 {pred_ge}, {s10}, {s20};"));
let raw_diff = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("sub.u32 {raw_diff}, {s10}, {s20};"));
let rev_diff = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("sub.u32 {rev_diff}, {s20}, {s10};"));
let wrapped = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("sub.u32 {wrapped}, {m1_const}, {rev_diff};"));
b.raw_ptx(&format!(
"selp.u32 {diff}, {raw_diff}, {wrapped}, {pred_ge};"
));
match precision {
PtxType::F32 => {
let fval = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {fval}, {diff};"));
let scale = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {scale}, 0f2F800000;"));
let result = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {result}, {fval}, {scale};"));
let addr = b.byte_offset_addr(out_ptr, gid.clone(), stride_bytes);
b.store_global_f32(addr, result);
}
PtxType::F64 => {
let fval = b.alloc_reg(PtxType::F64);
b.raw_ptx(&format!("cvt.rn.f64.u32 {fval}, {diff};"));
let scale = b.alloc_reg(PtxType::F64);
b.raw_ptx(&format!("mov.f64 {scale}, 0d3DF0000000000000;"));
let result = b.alloc_reg(PtxType::F64);
b.raw_ptx(&format!("mul.rn.f64 {result}, {fval}, {scale};"));
let addr = b.byte_offset_addr(out_ptr, gid.clone(), stride_bytes);
b.store_global_f64(addr, result);
}
_ => {}
}
});
b.ret();
})
.build()
}
pub fn generate_mrg32k3a_normal_ptx(
precision: PtxType,
sm: SmVersion,
) -> Result<String, PtxGenError> {
let kernel_name = match precision {
PtxType::F32 => "mrg32k3a_normal_f32",
PtxType::F64 => "mrg32k3a_normal_f64",
_ => return Err(PtxGenError::InvalidType(format!("{precision:?}"))),
};
let stride_bytes: u32 = precision.size_bytes() as u32;
let mean_ty = precision;
let stddev_ty = precision;
KernelBuilder::new(kernel_name)
.target(sm)
.param("out_ptr", PtxType::U64)
.param("n", PtxType::U32)
.param("seed", PtxType::U32)
.param("offset_lo", PtxType::U32)
.param("offset_hi", PtxType::U32)
.param("mean", mean_ty)
.param("stddev", stddev_ty)
.max_threads_per_block(256)
.body(move |b| {
let gid = b.global_thread_id_x();
let n_reg = b.load_param_u32("n");
b.if_lt_u32(gid.clone(), n_reg, move |b| {
let out_ptr = b.load_param_u64("out_ptr");
let seed = b.load_param_u32("seed");
let s10 = b.alloc_reg(PtxType::U32);
let s11 = b.alloc_reg(PtxType::U32);
let s12 = b.alloc_reg(PtxType::U32);
let s20 = b.alloc_reg(PtxType::U32);
let s21 = b.alloc_reg(PtxType::U32);
let s22 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("xor.b32 {s10}, {seed}, {gid};"));
let scr1 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr1}, {gid}, 1812433253;"));
b.raw_ptx(&format!("xor.b32 {s11}, {seed}, {scr1};"));
let scr2 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr2}, {gid}, 1566083941;"));
b.raw_ptx(&format!("xor.b32 {s12}, {seed}, {scr2};"));
let scr3 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr3}, {gid}, 1103515245;"));
b.raw_ptx(&format!("xor.b32 {s20}, {seed}, {scr3};"));
let scr4 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr4}, {gid}, 214013;"));
b.raw_ptx(&format!("xor.b32 {s21}, {seed}, {scr4};"));
let scr5 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {scr5}, {gid}, 2531011;"));
b.raw_ptx(&format!("xor.b32 {s22}, {seed}, {scr5};"));
emit_clamp_nonzero(b, &s10);
emit_clamp_nonzero(b, &s11);
emit_clamp_nonzero(b, &s12);
emit_clamp_nonzero(b, &s20);
emit_clamp_nonzero(b, &s21);
emit_clamp_nonzero(b, &s22);
emit_mrg32k3a_step(b, &s10, &s11, &s12, &s20, &s21, &s22);
let u1_raw = emit_mrg32k3a_output(b, &s10, &s20);
emit_mrg32k3a_step(b, &s10, &s11, &s12, &s20, &s21, &s22);
let u2_raw = emit_mrg32k3a_output(b, &s10, &s20);
match precision {
PtxType::F32 => {
let mean_reg = b.load_param_f32("mean");
let stddev_reg = b.load_param_f32("stddev");
let scale = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {scale}, 0f2F800000;"));
let u1_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {u1_f}, {u1_raw};"));
let u1 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {u1}, {u1_f}, {scale};"));
let eps = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {eps}, 0f33800000;"));
let u1_safe = b.max_f32(u1, eps);
let u2_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {u2_f}, {u2_raw};"));
let u2 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {u2}, {u2_f}, {scale};"));
let lg2_u1 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("lg2.approx.f32 {lg2_u1}, {u1_safe};"));
let ln2 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {ln2}, 0f3F317218;"));
let ln_u1 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {ln_u1}, {lg2_u1}, {ln2};"));
let neg2 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {neg2}, 0fC0000000;"));
let neg2ln = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {neg2ln}, {neg2}, {ln_u1};"));
let radius = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("sqrt.approx.f32 {radius}, {neg2ln};"));
let two_pi = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {two_pi}, 0f40C90FDB;"));
let angle = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {angle}, {two_pi}, {u2};"));
let cos_val = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cos.approx.f32 {cos_val}, {angle};"));
let z = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {z}, {radius}, {cos_val};"));
let result = b.fma_f32(stddev_reg, z, mean_reg);
let addr = b.byte_offset_addr(out_ptr, gid.clone(), stride_bytes);
b.store_global_f32(addr, result);
}
PtxType::F64 => {
let mean_reg = b.load_param_f64("mean");
let stddev_reg = b.load_param_f64("stddev");
let scale = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {scale}, 0f2F800000;"));
let u1_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {u1_f}, {u1_raw};"));
let u1 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {u1}, {u1_f}, {scale};"));
let eps = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {eps}, 0f33800000;"));
let u1_safe = b.max_f32(u1, eps);
let u2_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {u2_f}, {u2_raw};"));
let u2 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {u2}, {u2_f}, {scale};"));
let lg2_u1 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("lg2.approx.f32 {lg2_u1}, {u1_safe};"));
let ln2 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {ln2}, 0f3F317218;"));
let ln_u1 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {ln_u1}, {lg2_u1}, {ln2};"));
let neg2 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {neg2}, 0fC0000000;"));
let neg2ln = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {neg2ln}, {neg2}, {ln_u1};"));
let radius = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("sqrt.approx.f32 {radius}, {neg2ln};"));
let two_pi = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {two_pi}, 0f40C90FDB;"));
let angle = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {angle}, {two_pi}, {u2};"));
let cos_val = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cos.approx.f32 {cos_val}, {angle};"));
let z32 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {z32}, {radius}, {cos_val};"));
let z64 = b.cvt_f32_to_f64(z32);
let result = b.fma_f64(stddev_reg, z64, mean_reg);
let addr = b.byte_offset_addr(out_ptr, gid.clone(), stride_bytes);
b.store_global_f64(addr, result);
}
_ => {}
}
});
b.ret();
})
.build()
}
/// Generates a PTX kernel that fills `out_ptr[0..n]` with raw 32-bit
/// MRG32k3a outputs `(x1 - x2) mod M1`.
///
/// Each thread with `gid < n` seeds its six state words from `seed` and `gid`
/// (zero words are replaced by 1), advances the generator once, and stores
/// the combined output as a `u32` at `out_ptr + 4 * gid`.
///
/// The `offset_lo` / `offset_hi` kernel parameters are declared but not read
/// by the generated code.
///
/// # Errors
///
/// Propagates any error from `KernelBuilder::build`.
pub fn generate_mrg32k3a_u32_ptx(sm: SmVersion) -> Result<String, PtxGenError> {
    KernelBuilder::new("mrg32k3a_u32")
        .target(sm)
        .param("out_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("seed", PtxType::U32)
        .param("offset_lo", PtxType::U32)
        .param("offset_hi", PtxType::U32)
        .max_threads_per_block(256)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let n_reg = b.load_param_u32("n");
            b.if_lt_u32(gid.clone(), n_reg, move |b| {
                let out_ptr = b.load_param_u64("out_ptr");
                let seed = b.load_param_u32("seed");

                // State words, oldest first: (s10, s11, s12) for component 1,
                // (s20, s21, s22) for component 2.
                let s10 = b.alloc_reg(PtxType::U32);
                let s11 = b.alloc_reg(PtxType::U32);
                let s12 = b.alloc_reg(PtxType::U32);
                let s20 = b.alloc_reg(PtxType::U32);
                let s21 = b.alloc_reg(PtxType::U32);
                let s22 = b.alloc_reg(PtxType::U32);

                // Seed: s10 = seed ^ gid, every other word = seed ^ (gid * k).
                b.raw_ptx(&format!("xor.b32 {s10}, {seed}, {gid};"));
                for (dst, mult) in [
                    (&s11, 1_812_433_253_u32),
                    (&s12, 1_566_083_941),
                    (&s20, 1_103_515_245),
                    (&s21, 214_013),
                    (&s22, 2_531_011),
                ] {
                    let scratch = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mul.lo.u32 {scratch}, {gid}, {mult};"));
                    b.raw_ptx(&format!("xor.b32 {dst}, {seed}, {scratch};"));
                }
                for reg in [&s10, &s11, &s12, &s20, &s21, &s22] {
                    emit_clamp_nonzero(b, reg);
                }

                emit_mrg32k3a_step(b, &s10, &s11, &s12, &s20, &s21, &s22);
                // The newly generated values live in s12/s22 after the step.
                let output = emit_mrg32k3a_output(b, &s12, &s22);
                let addr = b.byte_offset_addr(out_ptr, gid.clone(), 4);
                b.raw_ptx(&format!("st.global.u32 [{addr}], {output};"));
            });
            b.ret();
        })
        .build()
}
/// Emits PTX that advances both MRG32k3a components by one step in place.
///
/// Component 1 computes `p1 = (A12 * s11 - A13N * s10) mod M1` and shifts
/// `(s10, s11, s12) <- (s11, s12, p1)`. Component 2 computes
/// `p2 = (A21 * s22 - A23N * s20) mod M2` and shifts
/// `(s20, s21, s22) <- (s21, s22, p2)`. All six registers are `u32`; after
/// the call the newest values are in `s12` and `s22`.
fn emit_mrg32k3a_step(
    b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
    s10: &oxicuda_ptx::ir::Register,
    s11: &oxicuda_ptx::ir::Register,
    s12: &oxicuda_ptx::ir::Register,
    s20: &oxicuda_ptx::ir::Register,
    s21: &oxicuda_ptx::ir::Register,
    s22: &oxicuda_ptx::ir::Register,
) {
    b.comment("MRG32k3a step - component 1");
    // p1 must be computed before the shift overwrites s10/s11.
    let p1 = emit_mrg32k3a_component(b, A12, s11, A13N, s10, M1);
    b.raw_ptx(&format!("mov.u32 {s10}, {s11};"));
    b.raw_ptx(&format!("mov.u32 {s11}, {s12};"));
    b.raw_ptx(&format!("mov.u32 {s12}, {p1};"));

    b.comment("MRG32k3a step - component 2");
    let p2 = emit_mrg32k3a_component(b, A21, s22, A23N, s20, M2);
    b.raw_ptx(&format!("mov.u32 {s20}, {s21};"));
    b.raw_ptx(&format!("mov.u32 {s21}, {s22};"));
    b.raw_ptx(&format!("mov.u32 {s22}, {p2};"));
}

/// Emits PTX computing `(pos_coef * pos - neg_coef * neg) mod modulus` for the
/// `u32` registers `pos` / `neg` and returns a fresh `u32` register holding the
/// result in `[0, modulus)`.
///
/// Both products are formed exactly with `mul.wide.u32`, so both coefficients
/// must fit in 32 bits. The magnitude of the difference is reduced with a
/// single `rem.u64`; 64-bit division is expensive on GPUs, so only one is
/// emitted. A negative difference is then mapped back into range as
/// `modulus - r`, or 0 when `r == 0`.
fn emit_mrg32k3a_component(
    b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
    pos_coef: u64,
    pos: &oxicuda_ptx::ir::Register,
    neg_coef: u64,
    neg: &oxicuda_ptx::ir::Register,
    modulus: u64,
) -> oxicuda_ptx::ir::Register {
    let pos_coef_reg = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("mov.u32 {pos_coef_reg}, {pos_coef};"));
    let neg_coef_reg = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("mov.u32 {neg_coef_reg}, {neg_coef};"));

    let prod_pos = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!("mul.wide.u32 {prod_pos}, {pos_coef_reg}, {pos};"));
    let prod_neg = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!("mul.wide.u32 {prod_neg}, {neg_coef_reg}, {neg};"));
    let modulus_reg = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!("mov.u64 {modulus_reg}, {modulus};"));

    // |prod_pos - prod_neg|, plus the sign of the difference.
    let non_negative = b.alloc_reg(PtxType::Pred);
    b.raw_ptx(&format!(
        "setp.ge.u64 {non_negative}, {prod_pos}, {prod_neg};"
    ));
    let fwd = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!("sub.u64 {fwd}, {prod_pos}, {prod_neg};"));
    let bwd = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!("sub.u64 {bwd}, {prod_neg}, {prod_pos};"));
    let magnitude = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!(
        "selp.u64 {magnitude}, {fwd}, {bwd}, {non_negative};"
    ));

    // r = |diff| mod m
    let rem = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!("rem.u64 {rem}, {magnitude}, {modulus_reg};"));

    // Negative difference: result is (m - r) mod m, i.e. 0 when r == 0.
    let complement = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!("sub.u64 {complement}, {modulus_reg}, {rem};"));
    let rem_is_zero = b.alloc_reg(PtxType::Pred);
    b.raw_ptx(&format!("setp.eq.u64 {rem_is_zero}, {rem}, 0;"));
    let neg_result = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!(
        "selp.u64 {neg_result}, 0, {complement}, {rem_is_zero};"
    ));

    let result64 = b.alloc_reg(PtxType::U64);
    b.raw_ptx(&format!(
        "selp.u64 {result64}, {rem}, {neg_result}, {non_negative};"
    ));
    let result = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("cvt.u32.u64 {result}, {result64};"));
    result
}
/// Emits PTX combining the two MRG32k3a components into `(s12 - s22) mod M1`
/// and returns the `u32` register that holds it.
///
/// When `s12 >= s22` the plain difference is selected; otherwise the result
/// is `M1 - (s22 - s12)`. This lands in `[0, M1)` as long as `s22 - s12 <= M1`
/// (assumed: callers pass freshly generated component values, each below its
/// modulus).
fn emit_mrg32k3a_output(
    b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
    s12: &oxicuda_ptx::ir::Register,
    s22: &oxicuda_ptx::ir::Register,
) -> oxicuda_ptx::ir::Register {
    // M1 = 2^32 - 209 fits in 32 bits, so this narrowing cast is lossless.
    let m1_u32 = M1 as u32;
    let modulus = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("mov.u32 {modulus}, {m1_u32};"));

    let no_wrap = b.alloc_reg(PtxType::Pred);
    b.raw_ptx(&format!("setp.ge.u32 {no_wrap}, {s12}, {s22};"));

    let forward = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("sub.u32 {forward}, {s12}, {s22};"));

    let backward = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("sub.u32 {backward}, {s22}, {s12};"));

    let wrapped = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!("sub.u32 {wrapped}, {modulus}, {backward};"));

    let combined = b.alloc_reg(PtxType::U32);
    b.raw_ptx(&format!(
        "selp.u32 {combined}, {forward}, {wrapped}, {no_wrap};"
    ));
    combined
}
/// A 3x3 matrix of `u64` entries with arithmetic modulo a caller-supplied
/// modulus, used for MRG32k3a transition matrices and their powers.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ModMatrix3x3 {
    /// Row-major entries: `data[row][col]`.
    data: [[u64; 3]; 3],
}

impl ModMatrix3x3 {
    /// The 3x3 identity matrix.
    fn identity() -> Self {
        let mut data = [[0; 3]; 3];
        for (diag, row) in data.iter_mut().enumerate() {
            row[diag] = 1;
        }
        Self { data }
    }

    /// The 3x3 all-zero matrix.
    fn zero() -> Self {
        Self { data: [[0; 3]; 3] }
    }

    /// Returns `self * other` with every entry reduced modulo `modulus`.
    ///
    /// Each dot product is accumulated in `u128` before reduction.
    fn mul_mod(&self, other: &Self, modulus: u64) -> Self {
        let m = u128::from(modulus);
        let mut product = Self::zero();
        for (lhs_row, out_row) in self.data.iter().zip(product.data.iter_mut()) {
            for (col, cell) in out_row.iter_mut().enumerate() {
                let dot: u128 = lhs_row
                    .iter()
                    .zip(other.data.iter())
                    .map(|(&lhs, rhs_row)| u128::from(lhs) * u128::from(rhs_row[col]))
                    .sum();
                *cell = (dot % m) as u64;
            }
        }
        product
    }

    /// Returns `self^exp` modulo `modulus` by binary exponentiation.
    ///
    /// `exp == 0` yields the identity matrix.
    fn pow_mod(&self, exp: u64, modulus: u64) -> Self {
        let mut acc = Self::identity();
        let mut square = self.clone();
        let mut remaining = exp;
        while remaining != 0 {
            if remaining & 1 != 0 {
                acc = acc.mul_mod(&square, modulus);
            }
            remaining >>= 1;
            // The last squaring would be unused, so skip it.
            if remaining != 0 {
                square = square.mul_mod(&square, modulus);
            }
        }
        acc
    }

    /// Returns `self * vec` with every component reduced modulo `modulus`.
    fn mul_vec_mod(&self, vec: &[u64; 3], modulus: u64) -> [u64; 3] {
        let m = u128::from(modulus);
        self.data.map(|row| {
            let dot: u128 = row
                .iter()
                .zip(vec.iter())
                .map(|(&entry, &x)| u128::from(entry) * u128::from(x))
                .sum();
            (dot % m) as u64
        })
    }
}
/// CPU reference implementation of the MRG32k3a combined multiple-recursive
/// generator.
///
/// `comp1` and `comp2` hold the last three values of each component, oldest
/// first: `step` shifts index 1 into 0 and index 2 into 1, then writes the new
/// value to index 2. `new` reduces the values modulo `M1` and `M2`
/// respectively.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Mrg32k3aState {
    /// Component 1 history `[x1_{n-3}, x1_{n-2}, x1_{n-1}]`.
    pub comp1: [u64; 3],
    /// Component 2 history `[x2_{n-3}, x2_{n-2}, x2_{n-1}]`.
    pub comp2: [u64; 3],
}

impl Mrg32k3aState {
    /// log2 of the number of steps between consecutive streams in `stream`.
    const STREAM_SPACING_LOG2: u32 = 76;

    /// Builds a state from raw component values.
    ///
    /// Each value of `comp1` is reduced modulo `M1` and each value of `comp2`
    /// modulo `M2`. Any value that reduces to zero is replaced with 1.
    pub fn new(comp1: [u64; 3], comp2: [u64; 3]) -> Self {
        let mut s = Self { comp1, comp2 };
        for v in &mut s.comp1 {
            *v %= M1;
            if *v == 0 {
                *v = 1;
            }
        }
        for v in &mut s.comp2 {
            *v %= M2;
            if *v == 0 {
                *v = 1;
            }
        }
        s
    }

    /// Builds a state whose six words are all derived from `seed`.
    ///
    /// Component 1 words are `seed mod M1` and component 2 words are
    /// `seed mod M2`, with zero replaced by 1 as in `new`. All 64 bits of
    /// `seed` participate. Seeds below 2^32 give the same state as before;
    /// larger seeds are no longer aliased onto their low 32 bits.
    pub fn from_seed(seed: u64) -> Self {
        Self::new([seed; 3], [seed; 3])
    }

    /// Returns `(pos - neg) mod modulus` in `[0, modulus)` without signed
    /// arithmetic.
    fn mod_sub(pos: u128, neg: u128, modulus: u64) -> u64 {
        let m = u128::from(modulus);
        if pos >= neg {
            ((pos - neg) % m) as u64
        } else {
            let r = ((neg - pos) % m) as u64;
            if r == 0 { 0 } else { modulus - r }
        }
    }

    /// Advances both components by one step:
    /// `x1_n = (A12 * x1_{n-2} - A13N * x1_{n-3}) mod M1` and
    /// `x2_n = (A21 * x2_{n-1} - A23N * x2_{n-3}) mod M2`.
    pub fn step(&mut self) {
        let [x1_3, x1_2, x1_1] = self.comp1;
        let p1 = Self::mod_sub(
            u128::from(A12) * u128::from(x1_2),
            u128::from(A13N) * u128::from(x1_3),
            M1,
        );
        self.comp1 = [x1_2, x1_1, p1];

        let [x2_3, x2_2, x2_1] = self.comp2;
        let p2 = Self::mod_sub(
            u128::from(A21) * u128::from(x2_1),
            u128::from(A23N) * u128::from(x2_3),
            M2,
        );
        self.comp2 = [x2_2, x2_1, p2];
    }

    /// Returns the combined output `(x1 - x2) mod M1` from the newest values
    /// of both components. The result is 0 when they are equal.
    pub fn output(&self) -> u64 {
        if self.comp1[2] >= self.comp2[2] {
            self.comp1[2] - self.comp2[2]
        } else {
            M1 - (self.comp2[2] - self.comp1[2])
        }
    }

    /// Companion matrix of component 1: `A1 * [x_{n-3}, x_{n-2}, x_{n-1}]`
    /// gives the history after one step (mod `M1`). `-A13N` is represented
    /// as `M1 - A13N`.
    fn transition_matrix_1() -> ModMatrix3x3 {
        ModMatrix3x3 {
            data: [[0, 1, 0], [0, 0, 1], [M1 - A13N, A12, 0]],
        }
    }

    /// Companion matrix of component 2 (mod `M2`); `-A23N` is represented as
    /// `M2 - A23N`.
    fn transition_matrix_2() -> ModMatrix3x3 {
        ModMatrix3x3 {
            data: [[0, 1, 0], [0, 0, 1], [M2 - A23N, 0, A21]],
        }
    }

    /// Returns `(A1^(2^e) mod M1, A2^(2^e) mod M2)` using `e` squarings.
    fn transition_pow2(e: u32) -> (ModMatrix3x3, ModMatrix3x3) {
        let mut a1 = Self::transition_matrix_1();
        let mut a2 = Self::transition_matrix_2();
        for _ in 0..e {
            a1 = a1.mul_mod(&a1, M1);
            a2 = a2.mul_mod(&a2, M2);
        }
        (a1, a2)
    }

    /// Advances the state by `n` steps in `O(log n)` matrix multiplications.
    /// `n == 0` leaves the state unchanged.
    pub fn skip_ahead(&mut self, n: u64) {
        if n == 0 {
            return;
        }
        let a1n = Self::transition_matrix_1().pow_mod(n, M1);
        self.comp1 = a1n.mul_vec_mod(&self.comp1, M1);
        let a2n = Self::transition_matrix_2().pow_mod(n, M2);
        self.comp2 = a2n.mul_vec_mod(&self.comp2, M2);
    }

    /// Advances the state by `2^e` steps (a single `step` when `e == 0`).
    pub fn skip_ahead_pow2(&mut self, e: u32) {
        if e == 0 {
            self.step();
            return;
        }
        let (a1, a2) = Self::transition_pow2(e);
        self.comp1 = a1.mul_vec_mod(&self.comp1, M1);
        self.comp2 = a2.mul_vec_mod(&self.comp2, M2);
    }

    /// Returns the starting state of stream `stream_id` for `seed`. This is
    /// `from_seed(seed)` advanced by `stream_id * 2^76` steps; stream 0 is
    /// `from_seed(seed)` itself.
    pub fn stream(seed: u64, stream_id: u64) -> Self {
        let mut state = Self::from_seed(seed);
        if stream_id == 0 {
            return state;
        }
        let (a1_base, a2_base) = Self::transition_pow2(Self::STREAM_SPACING_LOG2);
        let a1_skip = a1_base.pow_mod(stream_id, M1);
        let a2_skip = a2_base.pow_mod(stream_id, M2);
        state.comp1 = a1_skip.mul_vec_mod(&state.comp1, M1);
        state.comp2 = a2_skip.mul_vec_mod(&state.comp2, M2);
        state
    }
}
/// Emits PTX that overwrites the `u32` register `reg` with 1 if it currently
/// holds 0. Non-zero values are left untouched.
fn emit_clamp_nonzero(
    b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
    reg: &oxicuda_ptx::ir::Register,
) {
    let is_zero = b.alloc_reg(PtxType::Pred);
    let test_zero = format!("setp.eq.u32 {is_zero}, {reg}, 0;");
    let force_one = format!("@{is_zero} mov.u32 {reg}, 1;");
    b.raw_ptx(&test_zero);
    b.raw_ptx(&force_one);
}
// Unit tests. The PTX generator tests only check that the expected entry
// names and opcodes appear in the generated text; no kernel is executed. The
// remaining tests cover ModMatrix3x3 modular arithmetic and the CPU-side
// Mrg32k3aState (stepping, skip-ahead, stream splitting, output range).
#[cfg(test)]
mod tests {
use super::*;
use oxicuda_ptx::arch::SmVersion;
// --- PTX generation (substring checks only) ---
#[test]
fn generate_uniform_f32_ptx() {
let ptx = generate_mrg32k3a_uniform_ptx(PtxType::F32, SmVersion::Sm80);
let ptx = ptx.expect("should generate PTX");
assert!(ptx.contains(".entry mrg32k3a_uniform_f32"));
assert!(ptx.contains("mul.wide.u32"));
}
#[test]
fn generate_uniform_f64_ptx() {
let ptx = generate_mrg32k3a_uniform_ptx(PtxType::F64, SmVersion::Sm80);
let ptx = ptx.expect("should generate PTX");
assert!(ptx.contains(".entry mrg32k3a_uniform_f64"));
}
#[test]
fn generate_normal_f32_ptx() {
let ptx = generate_mrg32k3a_normal_ptx(PtxType::F32, SmVersion::Sm80);
let ptx = ptx.expect("should generate PTX");
assert!(ptx.contains(".entry mrg32k3a_normal_f32"));
assert!(ptx.contains("lg2.approx"));
assert!(ptx.contains("cos.approx"));
}
#[test]
fn generate_normal_f64_ptx() {
let ptx = generate_mrg32k3a_normal_ptx(PtxType::F64, SmVersion::Sm80);
let ptx = ptx.expect("should generate PTX");
assert!(ptx.contains(".entry mrg32k3a_normal_f64"));
}
#[test]
fn generate_u32_ptx() {
let ptx = generate_mrg32k3a_u32_ptx(SmVersion::Sm80);
let ptx = ptx.expect("should generate PTX");
assert!(ptx.contains(".entry mrg32k3a_u32"));
assert!(ptx.contains("st.global.u32"));
}
// Only F32 / F64 are accepted by the float generators.
#[test]
fn invalid_precision_returns_error() {
let result = generate_mrg32k3a_uniform_ptx(PtxType::U32, SmVersion::Sm80);
assert!(result.is_err());
}
// --- ModMatrix3x3 arithmetic (modulus M1 unless noted) ---
#[test]
fn matrix_identity_is_identity() {
let id = ModMatrix3x3::identity();
assert_eq!(id.data[0], [1, 0, 0]);
assert_eq!(id.data[1], [0, 1, 0]);
assert_eq!(id.data[2], [0, 0, 1]);
}
#[test]
fn matrix_mul_identity_left() {
let id = ModMatrix3x3::identity();
let a = ModMatrix3x3 {
data: [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
};
let result = id.mul_mod(&a, M1);
assert_eq!(result, a);
}
#[test]
fn matrix_mul_identity_right() {
let id = ModMatrix3x3::identity();
let a = ModMatrix3x3 {
data: [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
};
let result = a.mul_mod(&id, M1);
assert_eq!(result, a);
}
#[test]
fn matrix_mul_known_result() {
let a = ModMatrix3x3 {
data: [[1, 2, 0], [0, 1, 0], [0, 0, 1]],
};
let b = ModMatrix3x3 {
data: [[1, 0, 0], [3, 1, 0], [0, 0, 1]],
};
let c = a.mul_mod(&b, M1);
assert_eq!(c.data[0], [7, 2, 0]);
assert_eq!(c.data[1], [3, 1, 0]);
assert_eq!(c.data[2], [0, 0, 1]);
}
// (M1 - 1)^2 = (-1)^2 = 1 (mod M1).
#[test]
fn matrix_mul_mod_reduces() {
let a = ModMatrix3x3 {
data: [[M1 - 1, 0, 0], [0, 1, 0], [0, 0, 1]],
};
let b = ModMatrix3x3 {
data: [[M1 - 1, 0, 0], [0, 1, 0], [0, 0, 1]],
};
let c = a.mul_mod(&b, M1);
assert_eq!(c.data[0][0], 1);
}
#[test]
fn matrix_pow_zero_is_identity() {
let a = Mrg32k3aState::transition_matrix_1();
let result = a.pow_mod(0, M1);
assert_eq!(result, ModMatrix3x3::identity());
}
#[test]
fn matrix_pow_one_is_self() {
let a = Mrg32k3aState::transition_matrix_1();
let result = a.pow_mod(1, M1);
assert_eq!(result, a);
}
#[test]
fn matrix_pow_two_equals_mul_self() {
let a = Mrg32k3aState::transition_matrix_1();
let squared = a.mul_mod(&a, M1);
let pow2 = a.pow_mod(2, M1);
assert_eq!(pow2, squared);
}
#[test]
fn matrix_mul_vec_identity() {
let id = ModMatrix3x3::identity();
let v = [100, 200, 300];
let result = id.mul_vec_mod(&v, M1);
assert_eq!(result, v);
}
#[test]
fn matrix_mul_vec_known() {
let a = ModMatrix3x3 {
data: [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
};
let v = [1, 2, 3];
let result = a.mul_vec_mod(&v, M1);
assert_eq!(result, [14, 32, 50]);
}
// --- Mrg32k3aState construction and skip-ahead ---
// Seed 0 must not yield an all-zero (degenerate) state.
#[test]
fn state_from_seed_nonzero() {
let state = Mrg32k3aState::from_seed(0);
for v in &state.comp1 {
assert!(*v > 0);
}
for v in &state.comp2 {
assert!(*v > 0);
}
}
#[test]
fn state_new_clamps_modulus() {
let state = Mrg32k3aState::new([M1 + 5, M1 + 10, M1 + 15], [M2 + 3, M2 + 7, M2 + 11]);
assert_eq!(state.comp1[0], 5);
assert_eq!(state.comp1[1], 10);
assert_eq!(state.comp1[2], 15);
assert_eq!(state.comp2[0], 3);
assert_eq!(state.comp2[1], 7);
assert_eq!(state.comp2[2], 11);
}
#[test]
fn skip_ahead_zero_is_identity() {
let original = Mrg32k3aState::from_seed(42);
let mut skipped = original.clone();
skipped.skip_ahead(0);
assert_eq!(original, skipped);
}
#[test]
fn skip_ahead_one_matches_single_step() {
let mut stepped = Mrg32k3aState::from_seed(42);
stepped.step();
let mut skipped = Mrg32k3aState::from_seed(42);
skipped.skip_ahead(1);
assert_eq!(stepped, skipped);
}
#[test]
fn skip_ahead_n_matches_n_steps() {
let n = 100;
let mut stepped = Mrg32k3aState::from_seed(42);
for _ in 0..n {
stepped.step();
}
let mut skipped = Mrg32k3aState::from_seed(42);
skipped.skip_ahead(n);
assert_eq!(stepped, skipped);
}
#[test]
fn skip_ahead_determinism() {
let mut s1 = Mrg32k3aState::from_seed(123);
let mut s2 = Mrg32k3aState::from_seed(123);
s1.skip_ahead(1000);
s2.skip_ahead(1000);
assert_eq!(s1, s2);
}
#[test]
fn skip_ahead_composable() {
let mut combined = Mrg32k3aState::from_seed(42);
combined.skip_ahead(150);
let mut sequential = Mrg32k3aState::from_seed(42);
sequential.skip_ahead(100);
sequential.skip_ahead(50);
assert_eq!(combined, sequential);
}
#[test]
fn skip_ahead_large_exponent() {
let mut state = Mrg32k3aState::from_seed(42);
state.skip_ahead(1_000_000);
for v in &state.comp1 {
assert!(*v > 0);
assert!(*v < M1);
}
for v in &state.comp2 {
assert!(*v > 0);
assert!(*v < M2);
}
}
// skip_ahead_pow2(e) must equal skip_ahead(2^e) for e in 0..10 (e = 0 uses the step() path).
#[test]
fn skip_ahead_pow2_matches_skip_ahead() {
for e in 0..10 {
let n = 1u64 << e;
let mut via_pow2 = Mrg32k3aState::from_seed(42);
via_pow2.skip_ahead_pow2(e);
let mut via_skip = Mrg32k3aState::from_seed(42);
via_skip.skip_ahead(n);
assert_eq!(
via_pow2, via_skip,
"skip_ahead_pow2({e}) != skip_ahead({n})"
);
}
}
// --- Stream splitting via Mrg32k3aState::stream ---
#[test]
fn stream_zero_equals_from_seed() {
let from_seed = Mrg32k3aState::from_seed(42);
let stream0 = Mrg32k3aState::stream(42, 0);
assert_eq!(from_seed, stream0);
}
#[test]
fn stream_different_ids_produce_different_states() {
let s0 = Mrg32k3aState::stream(42, 0);
let s1 = Mrg32k3aState::stream(42, 1);
let s2 = Mrg32k3aState::stream(42, 2);
let s3 = Mrg32k3aState::stream(42, 3);
assert_ne!(s0, s1);
assert_ne!(s0, s2);
assert_ne!(s0, s3);
assert_ne!(s1, s2);
assert_ne!(s1, s3);
assert_ne!(s2, s3);
}
#[test]
fn stream_different_seeds_produce_different_states() {
let a = Mrg32k3aState::stream(1, 1);
let b = Mrg32k3aState::stream(2, 1);
assert_ne!(a, b);
}
#[test]
fn stream_outputs_differ() {
let mut outputs = Vec::new();
for id in 0..5 {
let mut s = Mrg32k3aState::stream(42, id);
s.step();
outputs.push(s.output());
}
for i in 0..outputs.len() {
for j in (i + 1)..outputs.len() {
assert_ne!(
outputs[i], outputs[j],
"stream {i} and {j} produced same output"
);
}
}
}
// --- Transition matrices: structure and agreement with step() ---
#[test]
fn transition_matrix_1_structure() {
let a1 = Mrg32k3aState::transition_matrix_1();
assert_eq!(a1.data[0], [0, 1, 0]);
assert_eq!(a1.data[1], [0, 0, 1]);
assert_eq!(a1.data[2][0], M1 - A13N);
assert_eq!(a1.data[2][1], A12);
assert_eq!(a1.data[2][2], 0);
}
#[test]
fn transition_matrix_2_structure() {
let a2 = Mrg32k3aState::transition_matrix_2();
assert_eq!(a2.data[0], [0, 1, 0]);
assert_eq!(a2.data[1], [0, 0, 1]);
assert_eq!(a2.data[2][0], M2 - A23N);
assert_eq!(a2.data[2][1], 0);
assert_eq!(a2.data[2][2], A21);
}
#[test]
fn transition_matrix_1_step_matches_manual_step() {
let state = Mrg32k3aState::new([100, 200, 300], [400, 500, 600]);
let mut stepped = state.clone();
stepped.step();
let a1 = Mrg32k3aState::transition_matrix_1();
let new_comp1 = a1.mul_vec_mod(&state.comp1, M1);
assert_eq!(new_comp1, stepped.comp1);
}
#[test]
fn transition_matrix_2_step_matches_manual_step() {
let state = Mrg32k3aState::new([100, 200, 300], [400, 500, 600]);
let mut stepped = state.clone();
stepped.step();
let a2 = Mrg32k3aState::transition_matrix_2();
let new_comp2 = a2.mul_vec_mod(&state.comp2, M2);
assert_eq!(new_comp2, stepped.comp2);
}
// --- Output range and step behaviour ---
#[test]
fn output_within_valid_range() {
let mut state = Mrg32k3aState::from_seed(42);
for _ in 0..100 {
state.step();
let out = state.output();
assert!(out < M1, "output {out} exceeds m1");
}
}
#[test]
fn step_changes_state() {
let mut state = Mrg32k3aState::from_seed(42);
let before = state.clone();
state.step();
assert_ne!(before, state);
}
// --- Output-level checks for skip-ahead and stream independence ---
#[test]
fn test_mrg32k3a_skip_ahead_output_reproducible() {
let mut rng1 = Mrg32k3aState::from_seed(12345);
let mut sequential_outputs: Vec<u64> = Vec::with_capacity(110);
for _ in 0..110 {
rng1.step();
sequential_outputs.push(rng1.output());
}
// sequential_outputs[100] is the output after 101 steps: skip 100, then step once.
let mut rng2 = Mrg32k3aState::from_seed(12345);
rng2.skip_ahead(100);
rng2.step();
let value_after_skip = rng2.output();
assert_eq!(
value_after_skip, sequential_outputs[100],
"Skip-ahead by 100 then step should give same output as sequential[100]: \
skip={value_after_skip}, seq={}",
sequential_outputs[100],
);
}
// Heuristic: fewer than 10 of 100 positionally equal outputs between stream 0 and stream 1.
#[test]
fn test_mrg32k3a_parallel_streams_independent() {
let mut rng1 = Mrg32k3aState::from_seed(42);
let mut rng2 = Mrg32k3aState::stream(42, 1);
let v1: Vec<u64> = (0..100)
.map(|_| {
rng1.step();
rng1.output()
})
.collect();
let v2: Vec<u64> = (0..100)
.map(|_| {
rng2.step();
rng2.output()
})
.collect();
let same_count = v1.iter().zip(&v2).filter(|(a, b)| a == b).count();
assert!(
same_count < 10,
"Parallel streams appear correlated: {same_count}/100 values identical"
);
}
#[test]
fn test_mrg32k3a_skip_ahead_composability_output() {
let seed = 99_u64;
let n = 50_u64;
let m = 75_u64;
let mut combined = Mrg32k3aState::from_seed(seed);
combined.skip_ahead(n + m);
combined.step();
let combined_output = combined.output();
let mut sequential = Mrg32k3aState::from_seed(seed);
sequential.skip_ahead(n);
sequential.skip_ahead(m);
sequential.step();
let sequential_output = sequential.output();
assert_eq!(
combined_output,
sequential_output,
"skip_ahead({}) + skip_ahead({}) should equal skip_ahead({}): \
combined={combined_output}, sequential={sequential_output}",
n,
m,
n + m,
);
}
#[test]
fn test_mrg32k3a_skip_ahead_determinism_output() {
let mut s1 = Mrg32k3aState::from_seed(777);
let mut s2 = Mrg32k3aState::from_seed(777);
s1.skip_ahead(500);
s2.skip_ahead(500);
s1.step();
s2.step();
assert_eq!(
s1.output(),
s2.output(),
"Same seed and skip must give identical output"
);
}
// Outputs 101..=110 reached by stepping must match those reached via skip_ahead(100).
#[test]
fn test_mrg32k3a_skip_ahead_generate_10_values() {
let mut seq = Mrg32k3aState::from_seed(31415);
let mut sequential_values = [0_u64; 10];
for _ in 0..100 {
seq.step();
}
for slot in &mut sequential_values {
seq.step();
*slot = seq.output();
}
let mut skip = Mrg32k3aState::from_seed(31415);
skip.skip_ahead(100);
let mut skip_values = [0_u64; 10];
for slot in &mut skip_values {
skip.step();
*slot = skip.output();
}
for (i, (s, q)) in sequential_values.iter().zip(skip_values.iter()).enumerate() {
assert_eq!(
s, q,
"Value {i} after skip_ahead(100): sequential={s} != skip={q}"
);
}
}
// Pairwise check over streams 0..4: fewer than 10/100 equal outputs, and distinct first outputs.
#[test]
fn test_mrg32k3a_four_stream_independence() {
let seed = 12_345_678_u64;
let mut streams: Vec<Mrg32k3aState> =
(0..4).map(|id| Mrg32k3aState::stream(seed, id)).collect();
let samples = 100_usize;
let values: Vec<Vec<u64>> = streams
.iter_mut()
.map(|s| {
(0..samples)
.map(|_| {
s.step();
s.output()
})
.collect()
})
.collect();
for i in 0..4 {
for j in (i + 1)..4 {
let same_count = values[i]
.iter()
.zip(&values[j])
.filter(|(a, b)| a == b)
.count();
assert!(
same_count < 10,
"Streams {i} and {j} share {same_count}/{samples} identical values — \
streams appear correlated"
);
}
}
let first_outputs: Vec<u64> = values.iter().map(|v| v[0]).collect();
let unique_first: std::collections::HashSet<u64> = first_outputs.iter().cloned().collect();
assert_eq!(
unique_first.len(),
4,
"All four stream initial outputs must be distinct: {first_outputs:?}"
);
}
}