#![allow(dead_code)]
use oxicuda_ptx::ir::PtxType;
use oxicuda_ptx::prelude::*;
use crate::error::{SolverError, SolverResult};
/// Padé numerator coefficients `b_k` for the matrix exponential at the
/// given approximation `order`, listed in ascending powers (`b_0` first).
///
/// Only the orders used by scaling-and-squaring (3, 5, 7, 9, 13) are
/// supported.
///
/// # Errors
/// Returns `SolverError::InternalError` for any other order.
fn pade_coefficients(order: u32) -> SolverResult<Vec<f64>> {
    let table: &[f64] = match order {
        3 => &[120.0, 60.0, 12.0, 1.0],
        5 => &[30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0],
        7 => &[
            17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0,
        ],
        9 => &[
            17643225600.0,
            8821612800.0,
            2075673600.0,
            302702400.0,
            30270240.0,
            2162160.0,
            110880.0,
            3960.0,
            90.0,
            1.0,
        ],
        13 => &[
            64764752532480000.0,
            32382376266240000.0,
            7771770303897600.0,
            1187353796428800.0,
            129060195264000.0,
            10559470521600.0,
            670442572800.0,
            33522128640.0,
            1323241920.0,
            40840800.0,
            960960.0,
            16380.0,
            182.0,
            1.0,
        ],
        _ => {
            return Err(SolverError::InternalError(format!(
                "unsupported Padé order {order}; valid orders are 3, 5, 7, 9, 13"
            )))
        }
    };
    Ok(table.to_vec())
}
/// Scaling threshold `theta_m` associated with a Padé order: the 1-norm
/// bound below which that order is accurate without further scaling.
///
/// # Errors
/// Returns `SolverError::InternalError` when the order has no tabulated
/// threshold (i.e. is not 3, 5, 7, 9, or 13).
#[allow(clippy::excessive_precision)]
fn pade_theta(order: u32) -> SolverResult<f64> {
    let theta = match order {
        3 => 1.495_585_217_958_292e-2,
        5 => 2.539_398_330_063_230e-1,
        7 => 9.504_178_996_162_932e-1,
        9 => 2.097_847_961_257_068,
        13 => 5.371_920_351_148_152,
        _ => {
            return Err(SolverError::InternalError(format!(
                "no theta for Padé order {order}"
            )))
        }
    };
    Ok(theta)
}
/// Configuration for the matrix-exponential (expm) kernel generator.
#[derive(Debug, Clone)]
pub struct MatrixExpConfig {
    /// Matrix dimension; the operand is `n` x `n`. Must be > 0.
    pub n: u32,
    /// Element precision, either "f32" or "f64".
    pub precision: String,
    /// Padé approximation order; one of 3, 5, 7, 9, 13 (default 13).
    pub pade_order: u32,
}
impl MatrixExpConfig {
    /// Creates a configuration for an `n` x `n` matrix in the given
    /// precision, defaulting to Padé order 13.
    pub fn new(n: u32, precision: &str) -> Self {
        Self {
            n,
            precision: precision.to_string(),
            pade_order: 13,
        }
    }

    /// Builder-style override of the Padé approximation order.
    pub fn with_pade_order(mut self, order: u32) -> Self {
        self.pade_order = order;
        self
    }

    /// Validates dimension, precision string, and Padé order.
    fn validate(&self) -> SolverResult<()> {
        if self.n == 0 {
            return Err(SolverError::DimensionMismatch(
                "expm: matrix dimension must be > 0".into(),
            ));
        }
        match self.precision.as_str() {
            "f32" | "f64" => {}
            _ => {
                return Err(SolverError::InternalError(format!(
                    "expm: unsupported precision '{}'; use 'f32' or 'f64'",
                    self.precision
                )))
            }
        }
        // The coefficient lookup doubles as the order check.
        pade_coefficients(self.pade_order).map(|_| ())
    }
}
/// Precomputed plan for the matrix exponential: a validated configuration
/// plus the Padé data cached for its order.
#[derive(Debug, Clone)]
pub struct MatrixExpPlan {
    // Validated expm configuration (dimension, precision, Padé order).
    config: MatrixExpConfig,
    // Numerator coefficients b_0..b_m for the configured Padé order.
    pade_coeffs: Vec<f64>,
    // Scaling threshold theta_m for the configured Padé order.
    theta: f64,
}
impl MatrixExpPlan {
    /// Builds an expm plan: validates `config` and caches the Padé
    /// coefficients and scaling threshold for its order.
    ///
    /// # Errors
    /// Propagates any validation failure from the configuration.
    pub fn new(config: MatrixExpConfig) -> SolverResult<Self> {
        config.validate()?;
        let pade_coeffs = pade_coefficients(config.pade_order)?;
        let theta = pade_theta(config.pade_order)?;
        Ok(Self {
            config,
            pade_coeffs,
            theta,
        })
    }

    /// Cached Padé numerator coefficients (ascending powers).
    pub fn pade_coefficients(&self) -> &[f64] {
        &self.pade_coeffs
    }

    /// Scaling threshold theta for the configured Padé order.
    pub fn theta(&self) -> f64 {
        self.theta
    }

    /// Emits the three expm kernels (scale, Padé evaluation, squaring)
    /// and returns them concatenated into a single PTX module string.
    pub fn generate_ptx(&self) -> SolverResult<String> {
        let n = self.config.n;
        let float_ty = precision_to_ptx_type(&self.config.precision)?;
        // Fixed target architecture for every emitted kernel.
        let sm = SmVersion::Sm75;
        let mut all_ptx = Vec::new();
        let scale_ptx = self.emit_scale_kernel(n, float_ty, sm)?;
        all_ptx.push(scale_ptx);
        let pade_ptx = self.emit_pade_kernel(n, float_ty, sm)?;
        all_ptx.push(pade_ptx);
        let square_ptx = self.emit_squaring_kernel(n, float_ty, sm)?;
        all_ptx.push(square_ptx);
        Ok(all_ptx.join("\n"))
    }

    /// Emits an elementwise kernel computing `out[i] = a[i] / 2^scale_exp`
    /// (the "scaling" half of scaling-and-squaring).
    ///
    /// The power-of-two divisor is constructed from its IEEE-754 bit
    /// pattern: biased exponent in the exponent field, mantissa zero.
    fn emit_scale_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_expm_scale_{}_n{}", ptx_type_suffix(float_ty), n);
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("scale_exp", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                // One thread per matrix element; guard the n*n tail.
                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let scale_exp = b.load_param_u32("scale_exp");
                    let gid_repeat = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    let addr = b.byte_offset_addr(a_ptr, gid_repeat.clone(), elem_size);
                    let val = load_float(b, float_ty, addr);
                    let out_addr = b.byte_offset_addr(out_ptr, gid_repeat, elem_size);
                    let result = if float_ty == PtxType::F64 {
                        // f64: exponent bias 1023, mantissa width 52.
                        let se64 = b.cvt_u32_to_u64(scale_exp);
                        let biased = b.alloc_reg(PtxType::U64);
                        b.raw_ptx(&format!("add.u64 {biased}, {se64}, 1023;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 52;"));
                        let bits = b.shl_b64(biased, shift_amt);
                        // Reinterpret the integer bit pattern as 2^scale_exp.
                        let divisor = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mov.b64 {divisor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("div.rn.f64 {res}, {val}, {divisor};"));
                        res
                    } else {
                        // f32: exponent bias 127, mantissa width 23.
                        let biased = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("add.u32 {biased}, {scale_exp}, 127;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 23;"));
                        let bits = b.shl_b32(biased, shift_amt);
                        let divisor = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mov.b32 {divisor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("div.rn.f32 {res}, {val}, {divisor};"));
                        res
                    };
                    store_float(b, float_ty, out_addr, result);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }

    /// Emits a kernel that Horner-evaluates the Padé numerator p and
    /// denominator q at each element of `a`, writing them to `p_ptr` /
    /// `q_ptr`. Coefficients are read from `coeffs_ptr` as f64 (8 bytes
    /// each) and converted down for f32 kernels. q uses the same
    /// coefficients with odd indices negated, i.e. q(x) = p(-x).
    ///
    /// NOTE(review): evaluation is per element, not a matrix polynomial
    /// (no matrix powers) — confirm this is the intended semantics for
    /// the surrounding expm pipeline.
    fn emit_pade_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let order = self.config.pade_order;
        let name = format!(
            "solver_expm_pade_{}_n{}_p{}",
            ptx_type_suffix(float_ty),
            n,
            order
        );
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("p_ptr", PtxType::U64)
            .param("q_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("coeffs_ptr", PtxType::U64)
            .param("num_coeffs", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);
                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let p_ptr = b.load_param_u64("p_ptr");
                    let q_ptr = b.load_param_u64("q_ptr");
                    let coeffs_ptr = b.load_param_u64("coeffs_ptr");
                    let num_coeffs = b.load_param_u32("num_coeffs");
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    // Coefficients are always stored as f64 on device.
                    const COEFF_SIZE: u32 = 8u32;
                    let gid_r = b.global_thread_id_x();
                    let a_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let a_val = load_float(b, float_ty, a_addr);
                    let acc_p = zero_const(b, float_ty);
                    let acc_q = zero_const(b, float_ty);
                    // idx counts num_coeffs down to 0; coefficient idx-1 is
                    // consumed each pass (highest power first — Horner).
                    let idx_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {idx_reg}, {num_coeffs};"));
                    let horner_loop = b.fresh_label("horner_loop");
                    let horner_exit = b.fresh_label("horner_exit");
                    b.raw_ptx(&format!("{horner_loop}:"));
                    let done_pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {done_pred}, {idx_reg}, 0;"));
                    b.raw_ptx(&format!("@{done_pred} bra {horner_exit};"));
                    b.raw_ptx(&format!("sub.u32 {idx_reg}, {idx_reg}, 1;"));
                    let coeff_addr =
                        b.byte_offset_addr(coeffs_ptr.clone(), idx_reg.clone(), COEFF_SIZE);
                    let coeff_f64 = load_float(b, PtxType::F64, coeff_addr);
                    // Narrow the coefficient for f32 kernels.
                    let c_k = if float_ty == PtxType::F64 {
                        coeff_f64.clone()
                    } else {
                        let dst = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("cvt.rn.f32.f64 {dst}, {coeff_f64};"));
                        dst
                    };
                    // p: acc_p = acc_p * a + c_k
                    let new_acc_p = if float_ty == PtxType::F64 {
                        b.fma_f64(acc_p.clone(), a_val.clone(), c_k.clone())
                    } else {
                        b.fma_f32(acc_p.clone(), a_val.clone(), c_k.clone())
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_p}, {new_acc_p};",
                        float_ty.as_ptx_str()
                    ));
                    // q: same recurrence with sign flipped on odd indices.
                    let odd_pred = b.alloc_reg(PtxType::Pred);
                    let lsb = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("and.b32 {lsb}, {idx_reg}, 1;"));
                    b.raw_ptx(&format!("setp.ne.u32 {odd_pred}, {lsb}, 0;"));
                    let neg_c_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!("neg{} {neg_c_k}, {c_k};", float_ty.as_ptx_str()));
                    let q_coeff = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {q_coeff}, {neg_c_k}, {c_k}, {odd_pred};",
                        float_ty.as_ptx_str()
                    ));
                    let new_acc_q = if float_ty == PtxType::F64 {
                        b.fma_f64(acc_q.clone(), a_val.clone(), q_coeff)
                    } else {
                        b.fma_f32(acc_q.clone(), a_val.clone(), q_coeff)
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_q}, {new_acc_q};",
                        float_ty.as_ptx_str()
                    ));
                    b.raw_ptx(&format!("bra {horner_loop};"));
                    b.raw_ptx(&format!("{horner_exit}:"));
                    let p_addr = b.byte_offset_addr(p_ptr, gid_r.clone(), elem_size);
                    let q_addr = b.byte_offset_addr(q_ptr, gid_r, elem_size);
                    store_float(b, float_ty, p_addr, acc_p);
                    store_float(b, float_ty, q_addr, acc_q);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }

    /// Emits a naive matrix-multiply kernel computing `tmp = F * F`
    /// (one output element per thread, inner loop over k) for the
    /// "squaring" half of scaling-and-squaring.
    ///
    /// Indexing is column-major: element (row, col) lives at
    /// `col * n + row`.
    fn emit_squaring_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_expm_square_{}_n{}", ptx_type_suffix(float_ty), n);
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("f_ptr", PtxType::U64)
            .param("tmp_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                b.if_lt_u32(gid, total, |b| {
                    let f_ptr = b.load_param_u64("f_ptr");
                    let tmp_ptr = b.load_param_u64("tmp_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();
                    // Decompose the linear id: row = gid % n, col = gid / n.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    let acc = zero_const(b, float_ty);
                    let k_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {k_reg}, 0;"));
                    let loop_label = b.fresh_label("sq_loop");
                    let exit_label = b.fresh_label("sq_exit");
                    b.raw_ptx(&format!("{loop_label}:"));
                    let pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.ge.u32 {pred}, {k_reg}, {n_inner};"));
                    b.raw_ptx(&format!("@{pred} bra {exit_label};"));
                    // F[row, k] at k*n + row (column-major).
                    let a_idx_base = b.mul_lo_u32(k_reg.clone(), n_inner.clone());
                    let a_idx = b.add_u32(a_idx_base, row.clone());
                    let a_addr = b.byte_offset_addr(f_ptr.clone(), a_idx, elem_size);
                    let a_val = load_float(b, float_ty, a_addr);
                    // F[k, col] at col*n + k (column-major).
                    let b_idx_base = b.mul_lo_u32(col.clone(), n_inner.clone());
                    let b_idx = b.add_u32(b_idx_base, k_reg.clone());
                    let b_addr = b.byte_offset_addr(f_ptr.clone(), b_idx, elem_size);
                    let b_val = load_float(b, float_ty, b_addr);
                    // acc += F[row, k] * F[k, col]
                    let new_acc = if float_ty == PtxType::F64 {
                        b.fma_f64(a_val, b_val, acc.clone())
                    } else {
                        b.fma_f32(a_val, b_val, acc.clone())
                    };
                    b.raw_ptx(&format!("mov{} {acc}, {new_acc};", float_ty.as_ptx_str()));
                    b.raw_ptx(&format!("add.u32 {k_reg}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("bra {loop_label};"));
                    b.raw_ptx(&format!("{exit_label}:"));
                    let out_idx_base = b.mul_lo_u32(col, n_inner);
                    let out_idx = b.add_u32(out_idx_base, row);
                    let out_addr = b.byte_offset_addr(tmp_ptr, out_idx, elem_size);
                    store_float(b, float_ty, out_addr, acc);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }
}
/// Configuration for the matrix-logarithm (logm) kernel generator.
#[derive(Debug, Clone)]
pub struct MatrixLogConfig {
    /// Matrix dimension; the operand is `n` x `n`. Must be > 0.
    pub n: u32,
    /// Element precision, either "f32" or "f64".
    pub precision: String,
    /// Cap on inverse-scaling square-root iterations; must be > 0.
    pub max_sqrt_iters: u32,
}
impl MatrixLogConfig {
    /// Creates a logm configuration for an `n` x `n` matrix in the given
    /// precision, defaulting to at most 100 square-root iterations.
    pub fn new(n: u32, precision: &str) -> Self {
        Self {
            n,
            precision: precision.to_string(),
            max_sqrt_iters: 100,
        }
    }

    /// Builder-style override of the square-root iteration cap.
    pub fn with_max_sqrt_iters(mut self, iters: u32) -> Self {
        self.max_sqrt_iters = iters;
        self
    }

    /// Validates dimension, precision string, and iteration cap.
    fn validate(&self) -> SolverResult<()> {
        if self.n == 0 {
            return Err(SolverError::DimensionMismatch(
                "logm: matrix dimension must be > 0".into(),
            ));
        }
        match self.precision.as_str() {
            "f32" | "f64" => {}
            _ => {
                return Err(SolverError::InternalError(format!(
                    "logm: unsupported precision '{}'; use 'f32' or 'f64'",
                    self.precision
                )))
            }
        }
        if self.max_sqrt_iters == 0 {
            return Err(SolverError::InternalError(
                "logm: max_sqrt_iters must be > 0".into(),
            ));
        }
        Ok(())
    }
}
/// Precomputed plan for the matrix logarithm: holds a validated
/// configuration and emits the logm kernel PTX on demand.
#[derive(Debug, Clone)]
pub struct MatrixLogPlan {
    // Validated logm configuration.
    config: MatrixLogConfig,
}
impl MatrixLogPlan {
    /// Builds a logm plan after validating the configuration.
    ///
    /// # Errors
    /// Propagates any validation failure from the configuration.
    pub fn new(config: MatrixLogConfig) -> SolverResult<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Configured cap on square-root iterations.
    pub fn max_sqrt_iters(&self) -> u32 {
        self.config.max_sqrt_iters
    }

    /// Emits the four logm kernels (shift, sqrt step, Padé/series log,
    /// scale back) concatenated into a single PTX module string.
    pub fn generate_ptx(&self) -> SolverResult<String> {
        let n = self.config.n;
        let float_ty = precision_to_ptx_type(&self.config.precision)?;
        // Fixed target architecture for every emitted kernel.
        let sm = SmVersion::Sm75;
        let mut all_ptx = Vec::new();
        let shift_ptx = self.emit_shift_kernel(n, float_ty, sm)?;
        all_ptx.push(shift_ptx);
        let sqrt_step_ptx = self.emit_sqrt_step_kernel(n, float_ty, sm)?;
        all_ptx.push(sqrt_step_ptx);
        let pade_log_ptx = self.emit_pade_log_kernel(n, float_ty, sm)?;
        all_ptx.push(pade_log_ptx);
        let scale_ptx = self.emit_scale_back_kernel(n, float_ty, sm)?;
        all_ptx.push(scale_ptx);
        Ok(all_ptx.join("\n"))
    }

    /// Emits an elementwise kernel computing `out = A - I`: diagonal
    /// elements (row == col, column-major decomposition of the linear
    /// index) have 1 subtracted, off-diagonals are copied unchanged.
    fn emit_shift_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_logm_shift_{}_n{}", ptx_type_suffix(float_ty), n);
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                // One thread per matrix element; guard the n*n tail.
                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    let src_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let val = load_float(b, float_ty, src_addr);
                    // row = gid % n, col = gid / n.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);
                    // diag_sub = 1 on the diagonal, 0 elsewhere.
                    let diag_sub = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_sub}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    let result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "sub{} {result}, {val}, {diag_sub};",
                        float_ty.as_ptx_str()
                    ));
                    let dst_addr = b.byte_offset_addr(out_ptr, gid_r, elem_size);
                    store_float(b, float_ty, dst_addr, result);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }

    /// Emits an elementwise kernel updating both square-root iterates:
    /// `y_next = (y + I) / 2` and `z_next = (z + I) / 2` (adding 1 only
    /// on the diagonal).
    ///
    /// NOTE(review): this looks like one step of a Denman–Beavers-style
    /// iteration where the inverse factors are applied by other kernels;
    /// confirm against the host-side driver.
    fn emit_sqrt_step_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_logm_sqrt_step_{}_n{}", ptx_type_suffix(float_ty), n);
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("y_ptr", PtxType::U64)
            .param("z_ptr", PtxType::U64)
            .param("y_next_ptr", PtxType::U64)
            .param("z_next_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);
                b.if_lt_u32(gid, total, |b| {
                    let y_ptr = b.load_param_u64("y_ptr");
                    let z_ptr = b.load_param_u64("z_ptr");
                    let y_next_ptr = b.load_param_u64("y_next_ptr");
                    let z_next_ptr = b.load_param_u64("z_next_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    // row = gid % n, col = gid / n; diagonal test.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);
                    // diag_add = 1 on the diagonal, 0 elsewhere.
                    let diag_add = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_add}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    let half = half_const(b, float_ty);
                    // y_next = (y + diag_add) * 0.5
                    let y_src = b.byte_offset_addr(y_ptr, gid_r.clone(), elem_size);
                    let y_val = load_float(b, float_ty, y_src);
                    let y_sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {y_sum}, {y_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));
                    let y_result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {y_result}, {y_sum}, {half};",
                        float_ty.as_ptx_str()
                    ));
                    let y_dst = b.byte_offset_addr(y_next_ptr, gid_r.clone(), elem_size);
                    store_float(b, float_ty, y_dst, y_result);
                    // z_next = (z + diag_add) * 0.5
                    let z_src = b.byte_offset_addr(z_ptr, gid_r.clone(), elem_size);
                    let z_val = load_float(b, float_ty, z_src);
                    let z_sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {z_sum}, {z_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));
                    let z_result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {z_result}, {z_sum}, {half};",
                        float_ty.as_ptx_str()
                    ));
                    let z_dst = b.byte_offset_addr(z_next_ptr, gid_r, elem_size);
                    store_float(b, float_ty, z_dst, z_result);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }

    /// Emits an elementwise kernel evaluating the truncated Mercator
    /// series log(1 + x) = sum_{k=1..num_terms} (-1)^(k+1) x^k / k by
    /// Horner's scheme: acc = x*acc + sign/k for k = num_terms .. 1,
    /// then result = x * acc.
    fn emit_pade_log_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_logm_pade_{}_n{}", ptx_type_suffix(float_ty), n);
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("x_ptr", PtxType::U64)
            .param("result_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("num_terms", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);
                b.if_lt_u32(gid, total, |b| {
                    let x_ptr = b.load_param_u64("x_ptr");
                    let result_ptr = b.load_param_u64("result_ptr");
                    let num_terms = b.load_param_u32("num_terms");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    let src = b.byte_offset_addr(x_ptr, gid_r.clone(), elem_size);
                    let x_val = load_float(b, float_ty, src);
                    let acc_reg = b.alloc_reg(float_ty);
                    let zero = zero_const(b, float_ty);
                    b.raw_ptx(&format!("mov{} {acc_reg}, {zero};", float_ty.as_ptx_str()));
                    // k runs from num_terms down to 1.
                    let k_reg = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("mov.u32 {k_reg}, {num_terms};"));
                    let log_loop = b.fresh_label("log_loop");
                    let log_exit = b.fresh_label("log_exit");
                    b.raw_ptx(&format!("{log_loop}:"));
                    let done_pred = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {done_pred}, {k_reg}, 0;"));
                    b.raw_ptx(&format!("@{done_pred} bra {log_exit};"));
                    // k_f = (float)k, then 1/k via reciprocal.
                    let k_f = b.alloc_reg(float_ty);
                    if float_ty == PtxType::F64 {
                        b.raw_ptx(&format!("cvt.rn.f64.u32 {k_f}, {k_reg};"));
                    } else {
                        b.raw_ptx(&format!("cvt.rn.f32.u32 {k_f}, {k_reg};"));
                    }
                    let inv_k = if float_ty == PtxType::F64 {
                        b.rcp_f64(k_f)
                    } else {
                        b.rcp_f32(k_f)
                    };
                    // Alternating sign: +1/k for odd k, -1/k for even k.
                    let odd_pred = b.alloc_reg(PtxType::Pred);
                    let lsb = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("and.b32 {lsb}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("setp.ne.u32 {odd_pred}, {lsb}, 0;"));
                    let neg_inv_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "neg{} {neg_inv_k}, {inv_k};",
                        float_ty.as_ptx_str()
                    ));
                    let signed_inv_k = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {signed_inv_k}, {inv_k}, {neg_inv_k}, {odd_pred};",
                        float_ty.as_ptx_str()
                    ));
                    // acc = x * acc + sign/k
                    let new_acc = if float_ty == PtxType::F64 {
                        b.fma_f64(x_val.clone(), acc_reg.clone(), signed_inv_k)
                    } else {
                        b.fma_f32(x_val.clone(), acc_reg.clone(), signed_inv_k)
                    };
                    b.raw_ptx(&format!(
                        "mov{} {acc_reg}, {new_acc};",
                        float_ty.as_ptx_str()
                    ));
                    b.raw_ptx(&format!("sub.u32 {k_reg}, {k_reg}, 1;"));
                    b.raw_ptx(&format!("bra {log_loop};"));
                    b.raw_ptx(&format!("{log_exit}:"));
                    // Final multiply by x gives the series sum.
                    let result = if float_ty == PtxType::F64 {
                        let r = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mul.rn.f64 {r}, {x_val}, {acc_reg};"));
                        r
                    } else {
                        let r = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mul.rn.f32 {r}, {x_val}, {acc_reg};"));
                        r
                    };
                    let dst = b.byte_offset_addr(result_ptr, gid_r, elem_size);
                    store_float(b, float_ty, dst, result);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }

    /// Emits an in-place elementwise kernel computing
    /// `result[i] *= 2^scale_exp`, undoing the inverse-scaling halvings.
    /// The power-of-two factor is built from its IEEE-754 bit pattern,
    /// mirroring `MatrixExpPlan::emit_scale_kernel` but multiplying
    /// instead of dividing.
    fn emit_scale_back_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!(
            "solver_logm_scale_back_{}_n{}",
            ptx_type_suffix(float_ty),
            n
        );
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("result_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .param("scale_exp", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);
                b.if_lt_u32(gid, total, |b| {
                    let result_ptr = b.load_param_u64("result_ptr");
                    let scale_exp = b.load_param_u32("scale_exp");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    // Same address is read and written (in-place update).
                    let addr = b.byte_offset_addr(result_ptr, gid_r, elem_size);
                    let val = load_float(b, float_ty, addr.clone());
                    let result = if float_ty == PtxType::F64 {
                        // factor = 2^scale_exp via exponent-field construction
                        // (bias 1023, mantissa width 52).
                        let se64 = b.cvt_u32_to_u64(scale_exp);
                        let biased = b.alloc_reg(PtxType::U64);
                        b.raw_ptx(&format!("add.u64 {biased}, {se64}, 1023;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 52;"));
                        let bits = b.shl_b64(biased, shift_amt);
                        let factor = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mov.b64 {factor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F64);
                        b.raw_ptx(&format!("mul.rn.f64 {res}, {val}, {factor};"));
                        res
                    } else {
                        // f32 variant: bias 127, mantissa width 23.
                        let biased = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("add.u32 {biased}, {scale_exp}, 127;"));
                        let shift_amt = b.alloc_reg(PtxType::U32);
                        b.raw_ptx(&format!("mov.u32 {shift_amt}, 23;"));
                        let bits = b.shl_b32(biased, shift_amt);
                        let factor = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mov.b32 {factor}, {bits};"));
                        let res = b.alloc_reg(PtxType::F32);
                        b.raw_ptx(&format!("mul.rn.f32 {res}, {val}, {factor};"));
                        res
                    };
                    store_float(b, float_ty, addr, result);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }
}
/// Configuration for the matrix-square-root (sqrtm) kernel generator.
#[derive(Debug, Clone)]
pub struct MatrixSqrtConfig {
    /// Matrix dimension; the operand is `n` x `n`. Must be > 0.
    pub n: u32,
    /// Element precision, either "f32" or "f64".
    pub precision: String,
    /// Cap on Newton-style iterations; must be > 0.
    pub max_iters: u32,
    /// Convergence tolerance; must be positive and finite.
    pub tol: f64,
}
impl MatrixSqrtConfig {
    /// Creates a sqrtm configuration for an `n` x `n` matrix in the
    /// given precision, defaulting to 50 iterations and tolerance 1e-12.
    pub fn new(n: u32, precision: &str) -> Self {
        Self {
            n,
            precision: precision.to_string(),
            max_iters: 50,
            tol: 1e-12,
        }
    }

    /// Builder-style override of the iteration cap.
    pub fn with_max_iters(mut self, iters: u32) -> Self {
        self.max_iters = iters;
        self
    }

    /// Builder-style override of the convergence tolerance.
    pub fn with_tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Validates dimension, precision string, iteration cap, and
    /// tolerance (which must be positive and finite; NaN is rejected).
    fn validate(&self) -> SolverResult<()> {
        if self.n == 0 {
            return Err(SolverError::DimensionMismatch(
                "sqrtm: matrix dimension must be > 0".into(),
            ));
        }
        match self.precision.as_str() {
            "f32" | "f64" => {}
            _ => {
                return Err(SolverError::InternalError(format!(
                    "sqrtm: unsupported precision '{}'; use 'f32' or 'f64'",
                    self.precision
                )))
            }
        }
        if self.max_iters == 0 {
            return Err(SolverError::InternalError(
                "sqrtm: max_iters must be > 0".into(),
            ));
        }
        // Rejects non-positive, infinite, and NaN tolerances alike.
        if !(self.tol > 0.0 && self.tol.is_finite()) {
            return Err(SolverError::InternalError(format!(
                "sqrtm: tolerance must be positive and finite, got {}",
                self.tol
            )));
        }
        Ok(())
    }
}
/// Precomputed plan for the matrix square root: holds a validated
/// configuration and emits the sqrtm kernel PTX on demand.
#[derive(Debug, Clone)]
pub struct MatrixSqrtPlan {
    // Validated sqrtm configuration.
    config: MatrixSqrtConfig,
}
impl MatrixSqrtPlan {
    /// Builds a sqrtm plan after validating the configuration.
    ///
    /// # Errors
    /// Propagates any validation failure from the configuration.
    pub fn new(config: MatrixSqrtConfig) -> SolverResult<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Configured convergence tolerance.
    pub fn tolerance(&self) -> f64 {
        self.config.tol
    }

    /// Configured iteration cap.
    pub fn max_iters(&self) -> u32 {
        self.config.max_iters
    }

    /// Emits the three sqrtm kernels (init, iteration step, convergence
    /// check) concatenated into a single PTX module string.
    pub fn generate_ptx(&self) -> SolverResult<String> {
        let n = self.config.n;
        let float_ty = precision_to_ptx_type(&self.config.precision)?;
        // Fixed target architecture for every emitted kernel.
        let sm = SmVersion::Sm75;
        let mut all_ptx = Vec::new();
        let init_ptx = self.emit_init_kernel(n, float_ty, sm)?;
        all_ptx.push(init_ptx);
        let iter_ptx = self.emit_iteration_kernel(n, float_ty, sm)?;
        all_ptx.push(iter_ptx);
        let conv_ptx = self.emit_convergence_kernel(n, float_ty, sm)?;
        all_ptx.push(conv_ptx);
        Ok(all_ptx.join("\n"))
    }

    /// Emits the iteration-setup kernel: copies `Y = A` and writes
    /// `Z = I` (1 on the diagonal, 0 elsewhere, using row = gid % n,
    /// col = gid / n on the linear index).
    fn emit_init_kernel(&self, n: u32, float_ty: PtxType, sm: SmVersion) -> SolverResult<String> {
        let name = format!("solver_sqrtm_init_{}_n{}", ptx_type_suffix(float_ty), n);
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("a_ptr", PtxType::U64)
            .param("y_ptr", PtxType::U64)
            .param("z_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                // One thread per matrix element; guard the n*n tail.
                b.if_lt_u32(gid, total, |b| {
                    let a_ptr = b.load_param_u64("a_ptr");
                    let y_ptr = b.load_param_u64("y_ptr");
                    let z_ptr = b.load_param_u64("z_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    // Y <- A (straight elementwise copy).
                    let a_addr = b.byte_offset_addr(a_ptr, gid_r.clone(), elem_size);
                    let val = load_float(b, float_ty, a_addr);
                    let y_addr = b.byte_offset_addr(y_ptr, gid_r.clone(), elem_size);
                    store_float(b, float_ty, y_addr, val);
                    // Z <- I via a diagonal predicate.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));
                    let z_addr = b.byte_offset_addr(z_ptr, gid_r, elem_size);
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let z_val = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {z_val}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    store_float(b, float_ty, z_addr, z_val);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }

    /// Emits an elementwise kernel computing `out = (M + I) / 2`
    /// (adding 1 only on the diagonal before halving).
    ///
    /// NOTE(review): like the logm sqrt step, this appears to be the
    /// elementwise half of a Denman–Beavers-style update; the matrix
    /// inverses must come from other kernels — confirm with the host
    /// driver code.
    fn emit_iteration_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_sqrtm_iter_{}_n{}", ptx_type_suffix(float_ty), n);
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("m_ptr", PtxType::U64)
            .param("out_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                b.if_lt_u32(gid, total, |b| {
                    let m_ptr = b.load_param_u64("m_ptr");
                    let out_ptr = b.load_param_u64("out_ptr");
                    let n_inner = b.load_param_u32("n");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    let m_addr = b.byte_offset_addr(m_ptr, gid_r.clone(), elem_size);
                    let m_val = load_float(b, float_ty, m_addr);
                    // row = gid % n, col = gid / n; diagonal test.
                    let row = b.alloc_reg(PtxType::U32);
                    let col = b.alloc_reg(PtxType::U32);
                    b.raw_ptx(&format!("rem.u32 {row}, {gid_r}, {n_inner};"));
                    b.raw_ptx(&format!("div.u32 {col}, {gid_r}, {n_inner};"));
                    let is_diag = b.alloc_reg(PtxType::Pred);
                    b.raw_ptx(&format!("setp.eq.u32 {is_diag}, {row}, {col};"));
                    let one = one_const(b, float_ty);
                    let zero = zero_const(b, float_ty);
                    let diag_add = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "selp{} {diag_add}, {one}, {zero}, {is_diag};",
                        float_ty.as_ptx_str()
                    ));
                    // out = (m + diag_add) * 0.5
                    let sum = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "add{} {sum}, {m_val}, {diag_add};",
                        float_ty.as_ptx_str()
                    ));
                    let half = half_const(b, float_ty);
                    let result = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {result}, {sum}, {half};",
                        float_ty.as_ptx_str()
                    ));
                    let out_addr = b.byte_offset_addr(out_ptr, gid_r, elem_size);
                    store_float(b, float_ty, out_addr, result);
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }

    /// Emits a kernel accumulating the squared Frobenius distance
    /// between successive iterates: each thread atomically adds
    /// `(y_new[i] - y_old[i])^2` into the scalar at `norm_ptr`.
    ///
    /// NOTE(review): presumably the host zeroes `*norm_ptr` before each
    /// launch and square-roots/compares against `tol` afterwards —
    /// confirm against the driver.
    fn emit_convergence_kernel(
        &self,
        n: u32,
        float_ty: PtxType,
        sm: SmVersion,
    ) -> SolverResult<String> {
        let name = format!("solver_sqrtm_conv_{}_n{}", ptx_type_suffix(float_ty), n);
        let ptx = KernelBuilder::new(&name)
            .target(sm)
            .max_threads_per_block(256)
            .param("y_new_ptr", PtxType::U64)
            .param("y_old_ptr", PtxType::U64)
            .param("norm_ptr", PtxType::U64)
            .param("n", PtxType::U32)
            .body(move |b| {
                let gid = b.global_thread_id_x();
                let n_reg = b.load_param_u32("n");
                let total = b.mul_lo_u32(n_reg.clone(), n_reg);
                b.if_lt_u32(gid, total, |b| {
                    let y_new_ptr = b.load_param_u64("y_new_ptr");
                    let y_old_ptr = b.load_param_u64("y_old_ptr");
                    let gid_r = b.global_thread_id_x();
                    let elem_size = if float_ty == PtxType::F32 { 4u32 } else { 8u32 };
                    let new_addr = b.byte_offset_addr(y_new_ptr, gid_r.clone(), elem_size);
                    let old_addr = b.byte_offset_addr(y_old_ptr, gid_r, elem_size);
                    let new_val = load_float(b, float_ty, new_addr);
                    let old_val = load_float(b, float_ty, old_addr);
                    let diff = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "sub{} {diff}, {new_val}, {old_val};",
                        float_ty.as_ptx_str()
                    ));
                    let diff_sq = b.alloc_reg(float_ty);
                    b.raw_ptx(&format!(
                        "mul{} {diff_sq}, {diff}, {diff};",
                        float_ty.as_ptx_str()
                    ));
                    let norm_ptr = b.load_param_u64("norm_ptr");
                    // Atomic reduction into a single global accumulator;
                    // the returned previous value is unused.
                    if float_ty == PtxType::F64 {
                        let _old = b.atom_global_add_f64(norm_ptr, diff_sq);
                    } else {
                        let _old = b.atom_global_add_f32(norm_ptr, diff_sq);
                    }
                });
                b.ret();
            })
            .build()?;
        Ok(ptx)
    }
}
/// Maps a precision string ("f32" / "f64") to the matching PTX scalar
/// type, rejecting anything else.
fn precision_to_ptx_type(precision: &str) -> SolverResult<PtxType> {
    if precision == "f32" {
        Ok(PtxType::F32)
    } else if precision == "f64" {
        Ok(PtxType::F64)
    } else {
        Err(SolverError::InternalError(format!(
            "unsupported precision '{precision}'"
        )))
    }
}
/// Short name used in generated kernel identifiers for a PTX float
/// type; non-float types yield "unknown".
fn ptx_type_suffix(ty: PtxType) -> &'static str {
    if ty == PtxType::F32 {
        "f32"
    } else if ty == PtxType::F64 {
        "f64"
    } else {
        "unknown"
    }
}
/// Emits a global-memory load of one float at `addr` and returns the
/// freshly allocated destination register.
fn load_float(b: &mut BodyBuilder<'_>, float_ty: PtxType, addr: Register) -> Register {
    let dst = b.alloc_reg(float_ty);
    let instr = format!("ld.global{} {dst}, [{addr}];", float_ty.as_ptx_str());
    b.raw_ptx(&instr);
    dst
}
/// Emits a global-memory store of `val` to `addr`.
fn store_float(b: &mut BodyBuilder<'_>, float_ty: PtxType, addr: Register, val: Register) {
    let instr = format!("st.global{} [{addr}], {val};", float_ty.as_ptx_str());
    b.raw_ptx(&instr);
}
/// Materializes the floating-point constant 0.0 by moving the all-zero
/// bit pattern through an integer register. F32 uses the 32-bit path;
/// any other type takes the 64-bit path (matching F64).
fn zero_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
    let dst = b.alloc_reg(float_ty);
    match float_ty {
        PtxType::F32 => {
            let bits = b.alloc_reg(PtxType::U32);
            b.raw_ptx(&format!("mov.u32 {bits}, 0;"));
            b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
        }
        _ => {
            let bits = b.alloc_reg(PtxType::U64);
            b.raw_ptx(&format!("mov.u64 {bits}, 0;"));
            b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
        }
    }
    dst
}
/// Materializes the floating-point constant 1.0 from its IEEE-754 bit
/// pattern (0x3F800000 for f32, 0x3FF0000000000000 for f64). F32 uses
/// the 32-bit path; any other type takes the 64-bit path.
fn one_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
    let dst = b.alloc_reg(float_ty);
    match float_ty {
        PtxType::F32 => {
            let bits = b.alloc_reg(PtxType::U32);
            b.raw_ptx(&format!("mov.u32 {bits}, 1065353216;"));
            b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
        }
        _ => {
            let bits = b.alloc_reg(PtxType::U64);
            b.raw_ptx(&format!("mov.u64 {bits}, 4607182418800017408;"));
            b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
        }
    }
    dst
}
/// Materializes the floating-point constant 0.5 from its IEEE-754 bit
/// pattern (0x3F000000 for f32, 0x3FE0000000000000 for f64). F32 uses
/// the 32-bit path; any other type takes the 64-bit path.
fn half_const(b: &mut BodyBuilder<'_>, float_ty: PtxType) -> Register {
    let dst = b.alloc_reg(float_ty);
    match float_ty {
        PtxType::F32 => {
            let bits = b.alloc_reg(PtxType::U32);
            b.raw_ptx(&format!("mov.u32 {bits}, 1056964608;"));
            b.raw_ptx(&format!("mov.b32 {dst}, {bits};"));
        }
        _ => {
            let bits = b.alloc_reg(PtxType::U64);
            b.raw_ptx(&format!("mov.u64 {bits}, 4602678819172646912;"));
            b.raw_ptx(&format!("mov.b64 {dst}, {bits};"));
        }
    }
    dst
}
#[cfg(test)]
#[path = "matrix_functions_tests.rs"]
mod tests;