use oxicuda_blas::GpuFloat;
use oxicuda_ptx::builder::BodyBuilder;
use oxicuda_ptx::ir::{PtxType, Register};
pub const SOLVER_BLOCK_SIZE: u32 = 256;
pub fn load_global_float<T: GpuFloat>(b: &mut BodyBuilder<'_>, addr: Register) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
b.raw_ptx(&format!(
"ld.global{} {dst}, [{addr}];",
float_ty.as_ptx_str()
));
dst
}
pub fn store_global_float<T: GpuFloat>(b: &mut BodyBuilder<'_>, addr: Register, val: Register) {
let float_ty = T::PTX_TYPE;
b.raw_ptx(&format!(
"st.global{} [{addr}], {val};",
float_ty.as_ptx_str()
));
}
pub fn mul_float<T: GpuFloat>(b: &mut BodyBuilder<'_>, a: Register, x: Register) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
b.raw_ptx(&format!("mul{} {dst}, {a}, {x};", float_ty.as_ptx_str()));
dst
}
pub fn add_float<T: GpuFloat>(b: &mut BodyBuilder<'_>, a: Register, x: Register) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
b.raw_ptx(&format!("add{} {dst}, {a}, {x};", float_ty.as_ptx_str()));
dst
}
pub fn sub_float<T: GpuFloat>(b: &mut BodyBuilder<'_>, x: Register, y: Register) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
b.raw_ptx(&format!("sub{} {dst}, {x}, {y};", float_ty.as_ptx_str()));
dst
}
pub fn div_float<T: GpuFloat>(b: &mut BodyBuilder<'_>, x: Register, y: Register) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
b.raw_ptx(&format!("div.rn{} {dst}, {x}, {y};", float_ty.as_ptx_str()));
dst
}
pub fn abs_float<T: GpuFloat>(b: &mut BodyBuilder<'_>, x: Register) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
b.raw_ptx(&format!("abs{} {dst}, {x};", float_ty.as_ptx_str()));
dst
}
pub fn sqrt_float<T: GpuFloat>(b: &mut BodyBuilder<'_>, x: Register) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
b.raw_ptx(&format!("sqrt.rn{} {dst}, {x};", float_ty.as_ptx_str()));
dst
}
pub fn fma_float<T: GpuFloat>(
b: &mut BodyBuilder<'_>,
a_reg: Register,
b_reg: Register,
c_reg: Register,
) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
b.raw_ptx(&format!(
"fma.rn{} {dst}, {a_reg}, {b_reg}, {c_reg};",
float_ty.as_ptx_str()
));
dst
}
pub fn zero_float<T: GpuFloat>(b: &mut BodyBuilder<'_>) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
let zero_bits: u64 = T::gpu_zero().to_bits_u64();
let bits_reg = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("mov.u64 {bits_reg}, {zero_bits};"));
if float_ty == PtxType::F32 {
let bits32 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("cvt.u32.u64 {bits32}, {bits_reg};"));
b.raw_ptx(&format!("mov.b32 {dst}, {bits32};"));
} else {
b.raw_ptx(&format!("mov.b64 {dst}, {bits_reg};"));
}
dst
}
pub fn one_float<T: GpuFloat>(b: &mut BodyBuilder<'_>) -> Register {
let float_ty = T::PTX_TYPE;
let dst = b.alloc_reg(float_ty);
let one_bits: u64 = T::gpu_one().to_bits_u64();
let bits_reg = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("mov.u64 {bits_reg}, {one_bits};"));
if float_ty == PtxType::F32 {
let bits32 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("cvt.u32.u64 {bits32}, {bits_reg};"));
b.raw_ptx(&format!("mov.b32 {dst}, {bits32};"));
} else {
b.raw_ptx(&format!("mov.b64 {dst}, {bits_reg};"));
}
dst
}
pub fn byte_offset<T: GpuFloat>(
b: &mut BodyBuilder<'_>,
base: Register,
index: Register,
) -> Register {
b.byte_offset_addr(base, index, T::size_u32())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn solver_block_size_reasonable() {
let block_size = SOLVER_BLOCK_SIZE;
assert!(block_size > 0);
assert!(block_size <= 1024);
assert_eq!(
SOLVER_BLOCK_SIZE & (SOLVER_BLOCK_SIZE - 1),
0,
"must be power of 2"
);
}
}