use std::sync::Arc;
use oxicuda_blas::types::{
DiagType, FillMode, GpuFloat, Layout, MatrixDesc, MatrixDescMut, Side, Transpose,
};
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::prelude::*;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
/// Panel width (columns factored per block step) of the blocked LU algorithm.
const LU_BLOCK_SIZE: u32 = 64;

/// Outcome of an LU factorization.
#[derive(Debug, Clone)]
pub struct LuResult {
    /// LAPACK-style status: 0 on success; a value `i > 0` means the `i`-th
    /// (1-based) pivot was zero, i.e. the factor `U` is exactly singular.
    /// (Argument errors are reported via `SolverError`, not negative `info`.)
    pub info: i32,
}
/// Validates the inputs, reserves device workspace, and runs a blocked LU
/// factorization (with partial pivoting) of the `n x n` column-major matrix
/// stored in `a` with leading dimension `lda`. Pivot indices are written to
/// `pivots`.
///
/// Returns a [`LuResult`] whose `info` is 0 on success, or the 1-based index
/// of the first zero pivot if the matrix is singular.
pub fn lu_factorize<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    n: u32,
    lda: u32,
    pivots: &mut DeviceBuffer<i32>,
) -> SolverResult<LuResult> {
    // An empty matrix factorizes trivially.
    if n == 0 {
        return Ok(LuResult { info: 0 });
    }
    // The leading dimension must cover at least one full column of rows.
    if lda < n {
        return Err(SolverError::DimensionMismatch(format!(
            "lu_factorize: lda ({lda}) must be >= n ({n})"
        )));
    }
    // Column-major storage: n columns, each lda elements apart.
    let required = (lda as usize) * (n as usize);
    if required > a.len() {
        return Err(SolverError::DimensionMismatch(format!(
            "lu_factorize: buffer too small ({} < {required})",
            a.len()
        )));
    }
    // One pivot slot per row/column of the matrix.
    if (n as usize) > pivots.len() {
        return Err(SolverError::DimensionMismatch(format!(
            "lu_factorize: pivots buffer too small ({} < {n})",
            pivots.len()
        )));
    }
    // Scratch space for a single panel: n rows x LU_BLOCK_SIZE columns.
    let panel_workspace = (LU_BLOCK_SIZE as usize) * (n as usize) * T::SIZE;
    handle.ensure_workspace(panel_workspace)?;
    blocked_lu::<T>(handle, a, n, lda, pivots)
}
/// Solves `A * X = B` given the packed LU factors of `A` and the row pivots
/// produced by [`lu_factorize`]. On entry `b` holds the `n x nrhs` right-hand
/// sides (column-major, leading dimension `n`); on exit it holds the solution.
pub fn lu_solve<T: GpuFloat>(
    handle: &SolverHandle,
    lu: &DeviceBuffer<T>,
    pivots: &DeviceBuffer<i32>,
    b: &mut DeviceBuffer<T>,
    n: u32,
    nrhs: u32,
) -> SolverResult<()> {
    // Nothing to solve for an empty system or zero right-hand sides.
    if n == 0 || nrhs == 0 {
        return Ok(());
    }
    let order = n as usize;
    if lu.len() < order * order {
        return Err(SolverError::DimensionMismatch(
            "lu_solve: LU buffer too small".into(),
        ));
    }
    if pivots.len() < order {
        return Err(SolverError::DimensionMismatch(
            "lu_solve: pivots buffer too small".into(),
        ));
    }
    if b.len() < order * nrhs as usize {
        return Err(SolverError::DimensionMismatch(
            "lu_solve: B buffer too small".into(),
        ));
    }
    // P * A = L * U, so first permute B's rows the same way the rows of A
    // were swapped during factorization.
    apply_pivots_to_rhs::<T>(handle, b, pivots, n, nrhs)?;
    // Both triangular factors live packed in the same buffer; one descriptor
    // covers both, and trsm's fill mode selects which half is read.
    let lu_desc = MatrixDesc::<T>::from_raw(lu.as_device_ptr(), n, n, n, Layout::ColMajor);
    let mut b_desc = MatrixDescMut::<T>::from_raw(b.as_device_ptr(), n, nrhs, n, Layout::ColMajor);
    // Forward substitution: L * Y = P * B (L carries an implicit unit diagonal).
    oxicuda_blas::level3::trsm(
        handle.blas(),
        Side::Left,
        FillMode::Lower,
        Transpose::NoTrans,
        DiagType::Unit,
        T::gpu_one(),
        &lu_desc,
        &mut b_desc,
    )?;
    // Back substitution: U * X = Y.
    oxicuda_blas::level3::trsm(
        handle.blas(),
        Side::Left,
        FillMode::Upper,
        Transpose::NoTrans,
        DiagType::NonUnit,
        T::gpu_one(),
        &lu_desc,
        &mut b_desc,
    )?;
    Ok(())
}
/// Right-looking blocked LU with partial pivoting.
///
/// For each `nb`-wide panel starting at column `j`:
/// 1. factorize the panel (rows `j..n`),
/// 2. apply the panel's row swaps to the columns left and right of it,
/// 3. triangular-solve the block row `U12 = L11^{-1} * A12`,
/// 4. rank-`jb` update the trailing submatrix `A22 -= L21 * U12`.
fn blocked_lu<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    n: u32,
    lda: u32,
    pivots: &mut DeviceBuffer<i32>,
) -> SolverResult<LuResult> {
    let nb = LU_BLOCK_SIZE.min(n);
    let num_blocks = n.div_ceil(nb);
    // -1.0 obtained by flipping the sign bit of 1.0; this works for both
    // 4-byte (f32) and 8-byte (f64) elements without a Neg bound on T.
    // Loop-invariant, so compute it once instead of per block iteration.
    let neg_one = T::from_bits_u64({
        let bits = T::gpu_one().to_bits_u64();
        if T::SIZE == 4 {
            bits ^ 0x8000_0000
        } else {
            bits ^ 0x8000_0000_0000_0000
        }
    });
    let mut info: i32 = 0;
    for block_idx in 0..num_blocks {
        let j = block_idx * nb;
        let jb = nb.min(n - j);
        // Factorize the current panel; a positive return is the panel-local
        // 1-based index of the first zero pivot.
        let panel_info = panel_lu::<T>(handle, a, n, lda, j, jb, pivots)?;
        if panel_info > 0 && info == 0 {
            // Record only the first singularity, in global 1-based terms.
            info = panel_info + j as i32;
        }
        // Apply this panel's row swaps to the already-factored columns on the
        // left of the panel...
        if j > 0 {
            apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, 0, j)?;
        }
        // ...and to the not-yet-factored columns on the right. The trailing
        // block is square: `trailing` is both its row and column count.
        let right_start = j + jb; // jb = min(nb, n - j) guarantees right_start <= n
        let trailing = n - right_start;
        if trailing > 0 {
            apply_panel_pivots::<T>(handle, a, lda, j, jb, pivots, right_start, trailing)?;
            // U12 = L11^{-1} * A12: unit lower-triangular solve on the block
            // row to the right of the panel. Offsets below are byte offsets
            // of A[row, col] in the column-major buffer.
            let l_desc = MatrixDesc::<T>::from_raw(
                a.as_device_ptr() + (j as u64 + j as u64 * lda as u64) * T::SIZE as u64,
                jb,
                jb,
                lda,
                Layout::ColMajor,
            );
            let mut u_desc = MatrixDescMut::<T>::from_raw(
                a.as_device_ptr() + (j as u64 + right_start as u64 * lda as u64) * T::SIZE as u64,
                jb,
                trailing,
                lda,
                Layout::ColMajor,
            );
            oxicuda_blas::level3::trsm(
                handle.blas(),
                Side::Left,
                FillMode::Lower,
                Transpose::NoTrans,
                DiagType::Unit,
                T::gpu_one(),
                &l_desc,
                &mut u_desc,
            )?;
            // A22 -= L21 * U12: rank-jb update of the trailing submatrix.
            let a21_desc = MatrixDesc::<T>::from_raw(
                a.as_device_ptr() + (right_start as u64 + j as u64 * lda as u64) * T::SIZE as u64,
                trailing,
                jb,
                lda,
                Layout::ColMajor,
            );
            let a12_desc = MatrixDesc::<T>::from_raw(
                a.as_device_ptr() + (j as u64 + right_start as u64 * lda as u64) * T::SIZE as u64,
                jb,
                trailing,
                lda,
                Layout::ColMajor,
            );
            let mut a22_desc = MatrixDescMut::<T>::from_raw(
                a.as_device_ptr()
                    + (right_start as u64 + right_start as u64 * lda as u64) * T::SIZE as u64,
                trailing,
                trailing,
                lda,
                Layout::ColMajor,
            );
            oxicuda_blas::level3::gemm_api::gemm(
                handle.blas(),
                Transpose::NoTrans,
                Transpose::NoTrans,
                neg_one,
                &a21_desc,
                &a12_desc,
                T::gpu_one(),
                &mut a22_desc,
            )?;
        }
    }
    Ok(LuResult { info })
}
/// Launches the panel-factorization kernel for the `jb`-column panel whose
/// top-left corner is `A[j, j]`, covering rows `j..n`.
///
/// Returns the panel-local singularity flag. NOTE(review): the emitted kernel
/// (see [`emit_panel_lu`]) is currently a stub and no flag is read back from
/// the device, so this always returns `Ok(0)` — confirm once the kernel body
/// is implemented.
fn panel_lu<T: GpuFloat>(
    handle: &SolverHandle,
    a: &mut DeviceBuffer<T>,
    n: u32,
    lda: u32,
    j: u32,
    jb: u32,
    pivots: &mut DeviceBuffer<i32>,
) -> SolverResult<i32> {
    let sm = handle.sm_version();
    let panel_rows = n - j;
    // JIT-build the kernel specialized for this element type and panel width.
    let ptx = emit_panel_lu::<T>(sm, jb)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &panel_lu_name::<T>(jb))?;
    // A single block cooperates on the whole panel staged in shared memory.
    let shared_bytes = panel_rows * jb * T::size_u32();
    let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);
    // Byte offset of A[j, j] in the column-major buffer.
    let panel_offset = (j as u64 + j as u64 * lda as u64) * T::SIZE as u64;
    let panel_ptr = a.as_device_ptr() + panel_offset;
    let args = (
        panel_ptr,
        // Pivot slots for this panel start at index j (i32 = 4 bytes each).
        pivots.as_device_ptr() + (j as u64 * 4),
        panel_rows,
        jb,
        lda,
    );
    // BUGFIX: was the mojibake `¶ms` (HTML-entity-corrupted `&params`),
    // which does not compile.
    kernel.launch(&params, handle.stream(), &args)?;
    Ok(0)
}
/// Applies the row swaps recorded for panel `j..j+jb` to columns
/// `col_start..col_start + col_count` of `a`, launching one thread per column.
#[allow(clippy::too_many_arguments)]
fn apply_panel_pivots<T: GpuFloat>(
    handle: &SolverHandle,
    a: &mut DeviceBuffer<T>,
    lda: u32,
    j: u32,
    jb: u32,
    pivots: &DeviceBuffer<i32>,
    col_start: u32,
    col_count: u32,
) -> SolverResult<()> {
    // Nothing to swap: no target columns or an empty pivot range.
    if col_count == 0 || jb == 0 {
        return Ok(());
    }
    let sm = handle.sm_version();
    // JIT-build the pivot-swap kernel for this element type.
    let ptx = emit_pivot_swap::<T>(sm)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &pivot_swap_name::<T>())?;
    // One thread per column in the target range.
    let grid = grid_size_for(col_count, SOLVER_BLOCK_SIZE);
    let params = LaunchParams::new(grid, SOLVER_BLOCK_SIZE);
    let args = (
        a.as_device_ptr(),
        pivots.as_device_ptr(),
        j,
        jb,
        col_start,
        col_count,
        lda,
    );
    // BUGFIX: was the mojibake `¶ms` (HTML-entity-corrupted `&params`),
    // which does not compile.
    kernel.launch(&params, handle.stream(), &args)?;
    Ok(())
}
/// Permutes the rows of the right-hand-side matrix `b` (`n x nrhs`,
/// column-major, leading dimension `n`) according to `pivots`, reusing the
/// pivot-swap kernel over the full row range (`j = 0`, `jb = n`).
fn apply_pivots_to_rhs<T: GpuFloat>(
    handle: &SolverHandle,
    b: &mut DeviceBuffer<T>,
    pivots: &DeviceBuffer<i32>,
    n: u32,
    nrhs: u32,
) -> SolverResult<()> {
    if n == 0 || nrhs == 0 {
        return Ok(());
    }
    let sm = handle.sm_version();
    let ptx = emit_pivot_swap::<T>(sm)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &pivot_swap_name::<T>())?;
    // One thread per right-hand-side column.
    let grid = grid_size_for(nrhs, SOLVER_BLOCK_SIZE);
    let params = LaunchParams::new(grid, SOLVER_BLOCK_SIZE);
    let args = (
        b.as_device_ptr(),
        pivots.as_device_ptr(),
        0u32, // j: swaps start at row 0
        n,    // jb: apply all n recorded pivots
        0u32, // col_start: from the first column
        nrhs, // col_count: all right-hand sides
        n,    // lda of b
    );
    // BUGFIX: was the mojibake `¶ms` (HTML-entity-corrupted `&params`),
    // which does not compile.
    kernel.launch(&params, handle.stream(), &args)?;
    Ok(())
}
/// Builds the panel-LU kernel name, unique per element type and panel width
/// (e.g. `solver_panel_lu_f32_64`).
fn panel_lu_name<T: GpuFloat>(block_size: u32) -> String {
    let mut name = String::from("solver_panel_lu_");
    name.push_str(T::NAME);
    name.push('_');
    name.push_str(&block_size.to_string());
    name
}
/// Builds the pivot-swap kernel name, unique per element type
/// (e.g. `solver_pivot_swap_f64`).
fn pivot_swap_name<T: GpuFloat>() -> String {
    let mut name = String::from("solver_pivot_swap_");
    name.push_str(T::NAME);
    name
}
/// Emits PTX for the shared-memory panel-LU kernel, specialized for element
/// type `T` and a fixed panel width of `panel_cols`.
///
/// NOTE(review): the kernel body is currently a stub — it loads its
/// parameters and immediately returns without factorizing anything (the
/// `let _ = (...)` only silences unused-value warnings). The actual panel
/// factorization with partial pivoting still needs to be written; until
/// then, `lu_factorize` produces no meaningful factors or pivots — confirm
/// before relying on its output.
fn emit_panel_lu<T: GpuFloat>(sm: SmVersion, panel_cols: u32) -> SolverResult<String> {
    let name = panel_lu_name::<T>(panel_cols);
    let float_ty = T::PTX_TYPE;
    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("panel_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("panel_rows", PtxType::U32)
        .param("panel_cols", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let panel_rows_reg = b.load_param_u32("panel_rows");
            let panel_cols_reg = b.load_param_u32("panel_cols");
            let lda_reg = b.load_param_u32("lda");
            let panel_ptr = b.load_param_u64("panel_ptr");
            // Keep the parameter loads (and the captured `float_ty`)
            // referenced until the real kernel body is implemented.
            let _ = (
                tid,
                panel_rows_reg,
                panel_cols_reg,
                lda_reg,
                panel_ptr,
                float_ty,
            );
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
/// Emits PTX for the pivot-swap kernel: one thread per column, each applying
/// the recorded row swaps to its column.
///
/// NOTE(review): the kernel body is incomplete — it bounds-checks the global
/// thread id against `col_count` and computes the column's base address, but
/// performs no actual row swaps (the pivot loop over `j..j+jb` is missing),
/// so pivoting is currently a no-op on the device. Confirm before relying on
/// `lu_factorize` / `lu_solve` results.
fn emit_pivot_swap<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    let name = pivot_swap_name::<T>();
    let float_ty = T::PTX_TYPE;
    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("a_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("j", PtxType::U32)
        .param("jb", PtxType::U32)
        .param("col_start", PtxType::U32)
        .param("col_count", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let gid = b.global_thread_id_x();
            let col_count_reg = b.load_param_u32("col_count");
            // Guard: only the first `col_count` threads do any work.
            b.if_lt_u32(gid.clone(), col_count_reg, |b| {
                let a_ptr = b.load_param_u64("a_ptr");
                let col_start = b.load_param_u32("col_start");
                let lda = b.load_param_u32("lda");
                // This thread owns column `col_start + gid`.
                let col_idx = b.add_u32(gid, col_start);
                // Element offset of the column's first entry (col-major).
                let col_elem_offset = b.mul_lo_u32(col_idx, lda);
                // Base address computed but unused — swap loop not yet emitted.
                let _col_base = b.byte_offset_addr(a_ptr, col_elem_offset, T::size_u32());
                let _ = float_ty;
            });
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU reference: Doolittle LU (no pivoting) of a 4x4 matrix. Returns
    /// `(L, U)` with unit-diagonal `L` such that `L * U == A`.
    fn doolittle_lu_4x4(a: &[[f64; 4]; 4]) -> ([[f64; 4]; 4], [[f64; 4]; 4]) {
        let mut l = [[0.0_f64; 4]; 4];
        let mut u = [[0.0_f64; 4]; 4];
        for i in 0..4 {
            l[i][i] = 1.0;
            // Row i of U: u[i][j] = a[i][j] - sum_{k<i} l[i][k] * u[k][j].
            for j in i..4 {
                let sum: f64 = (0..i).map(|k| l[i][k] * u[k][j]).sum();
                u[i][j] = a[i][j] - sum;
            }
            // Column i of L below the diagonal; skip if the pivot is ~zero
            // (no pivoting in this reference implementation).
            for j in (i + 1)..4 {
                let sum: f64 = (0..i).map(|k| l[j][k] * u[k][i]).sum();
                if u[i][i].abs() > 1e-15 {
                    l[j][i] = (a[j][i] - sum) / u[i][i];
                }
            }
        }
        (l, u)
    }

    /// Naive 4x4 matrix product, used only to verify L * U == A.
    fn matmul_4x4(a: &[[f64; 4]; 4], b: &[[f64; 4]; 4]) -> [[f64; 4]; 4] {
        let mut c = [[0.0_f64; 4]; 4];
        for i in 0..4 {
            for j in 0..4 {
                for k in 0..4 {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        c
    }

    /// Checks the math the blocked algorithm relies on: the LU factors have
    /// the expected triangular structure and reproduce A exactly.
    #[test]
    fn lu_trsm_trailing_update() {
        let a = [
            [4.0_f64, 3.0, 2.0, 1.0],
            [2.0, 5.0, 3.0, 2.0],
            [1.0, 2.0, 6.0, 3.0],
            [1.0, 1.0, 2.0, 7.0],
        ];
        let (l, u) = doolittle_lu_4x4(&a);
        // L must be unit lower triangular.
        for (i, l_row) in l.iter().enumerate() {
            assert!(
                (l_row[i] - 1.0).abs() < 1e-15,
                "L[{i},{i}] must be 1.0 (unit diagonal)"
            );
            for (j, &val) in l_row.iter().enumerate().filter(|(j, _)| *j > i) {
                assert!(
                    val.abs() < 1e-15,
                    "L[{i},{j}] = {val} must be 0.0 (upper triangle)",
                );
            }
        }
        // U must be upper triangular.
        for (i, u_row) in u.iter().enumerate() {
            for (j, &val) in u_row.iter().enumerate().filter(|(j, _)| *j < i) {
                assert!(
                    val.abs() < 1e-15,
                    "U[{i},{j}] = {val} must be 0.0 (lower triangle)",
                );
            }
        }
        // And the factors must reconstruct A.
        let reconstructed = matmul_4x4(&l, &u);
        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (reconstructed[i][j] - a[i][j]).abs() < 1e-10,
                    "LU[{i},{j}] = {} ≠ A[{i},{j}] = {} (diff = {})",
                    reconstructed[i][j],
                    a[i][j],
                    (reconstructed[i][j] - a[i][j]).abs()
                );
            }
        }
    }

    /// Checks one rank-1 trailing update (the gemm step of blocked LU) by
    /// hand on a 3x3 example: A22 -= l_col0 * u_row0.
    #[test]
    fn lu_gemm_rank_update_correctness() {
        let a = [[2.0_f64, 4.0, 6.0], [1.0, 3.0, 5.0], [1.0, 2.0, 4.0]];
        let l_col0 = [1.0_f64, a[1][0] / a[0][0], a[2][0] / a[0][0]];
        let u_row0 = [a[0][0], a[0][1], a[0][2]];
        let mut trailing = [[0.0_f64; 2]; 2];
        for i in 0..2 {
            for j in 0..2 {
                trailing[i][j] = a[i + 1][j + 1] - l_col0[i + 1] * u_row0[j + 1];
            }
        }
        assert!(
            (trailing[0][0] - 1.0).abs() < 1e-12,
            "trailing[0,0] should be 1"
        );
        assert!(
            (trailing[0][1] - 2.0).abs() < 1e-12,
            "trailing[0,1] should be 2"
        );
        assert!(trailing[1][0].abs() < 1e-12, "trailing[1,0] should be 0");
        assert!(
            (trailing[1][1] - 1.0).abs() < 1e-12,
            "trailing[1,1] should be 1"
        );
    }

    /// Sanity bound on the compile-time panel width.
    #[test]
    fn lu_block_size_positive() {
        let block_size = LU_BLOCK_SIZE;
        assert!(block_size > 0);
        assert!(block_size <= 256);
    }

    /// `info` semantics: 0 is success, positive flags singularity.
    #[test]
    fn lu_result_info() {
        let result = LuResult { info: 0 };
        assert_eq!(result.info, 0);
        let singular = LuResult { info: 3 };
        assert!(singular.info > 0);
    }

    /// Kernel names must encode element type and panel width.
    #[test]
    fn panel_lu_name_format() {
        let name = panel_lu_name::<f32>(64);
        assert!(name.contains("f32"));
        assert!(name.contains("64"));
    }

    /// Kernel names must encode the element type.
    #[test]
    fn pivot_swap_name_format() {
        let name = pivot_swap_name::<f64>();
        assert!(name.contains("f64"));
    }

    /// The sign-bit flip used by blocked_lu yields -1.0 for f32.
    #[test]
    fn neg_one_f32() {
        let neg = f32::from_bits_u64(f32::gpu_one().to_bits_u64() ^ 0x8000_0000);
        assert!((neg + 1.0).abs() < 1e-10);
    }

    /// The sign-bit flip used by blocked_lu yields -1.0 for f64.
    #[test]
    fn neg_one_f64() {
        let neg = f64::from_bits_u64(f64::gpu_one().to_bits_u64() ^ 0x8000_0000_0000_0000);
        assert!((neg + 1.0).abs() < 1e-15);
    }
}