use std::sync::Arc;
use oxicuda_blas::types::{GpuFloat, Layout, MatrixDesc, MatrixDescMut, Transpose};
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::prelude::*;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
/// Upper bound on the number of columns processed per panel by `blocked_qr`
/// (the effective width is `min(QR_BLOCK_SIZE, min(m, n))`); it also sizes the
/// workspace requested by `qr_factorize`.
const QR_BLOCK_SIZE: u32 = 32;
/// Validates the arguments for a QR factorization of an `m x n` matrix stored
/// in `a`, reserves workspace on `handle`, and delegates to the blocked driver.
///
/// The driver addresses `a` column-major with leading dimension `lda`
/// (element `(i, j)` at offset `i + j * lda`). `tau` must have room for
/// `min(m, n)` entries, one per processed column.
///
/// An empty matrix (`m == 0` or `n == 0`) is accepted and is a no-op.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `lda < m`, when `a` holds
/// fewer than `n * lda` elements, or when `tau` holds fewer than `min(m, n)`
/// elements. Errors from the workspace reservation and from the blocked
/// driver are propagated unchanged.
pub fn qr_factorize<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    tau: &mut DeviceBuffer<T>,
) -> SolverResult<()> {
    // `min(m, n)` is zero exactly when either dimension is zero.
    let k = m.min(n);
    if k == 0 {
        return Ok(());
    }

    if m > lda {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_factorize: lda ({lda}) must be >= m ({m})"
        )));
    }

    let available = a.len();
    let required = lda as usize * n as usize;
    if required > available {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_factorize: buffer too small ({} < {required})",
            available
        )));
    }

    let tau_len = tau.len();
    if tau_len < k as usize {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_factorize: tau buffer too small ({} < {k})",
            tau_len
        )));
    }

    // Workspace: one block-by-block square plus an m-by-block strip, in bytes.
    let block = QR_BLOCK_SIZE as usize;
    let workspace_elems = block * block + m as usize * block;
    handle.ensure_workspace(workspace_elems * T::SIZE)?;

    blocked_qr::<T>(handle, a, m, n, lda, tau)
}
/// Solves the least-squares system for an `m x n` factored matrix (`m >= n`)
/// with `nrhs` right-hand sides, overwriting `b`.
///
/// The routine first calls `apply_qt` on `b`, then performs an upper
/// triangular solve (`trsm`, left side, non-unit diagonal) using the leading
/// `n x n` part of `a`. Both `a` and `b` are addressed column-major with a
/// leading dimension of `m`; there is no separate `lda`/`ldb` parameter.
///
/// Any zero dimension (`m`, `n`, or `nrhs`) makes the call a no-op.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `m < n`, when `tau` holds
/// fewer than `n` elements, when `a` holds fewer than `m * n` elements, or
/// when `b` holds fewer than `m * nrhs` elements. Errors from `apply_qt` and
/// from the BLAS `trsm` call are propagated.
pub fn qr_solve<T: GpuFloat>(
    handle: &SolverHandle,
    a: &DeviceBuffer<T>,
    tau: &DeviceBuffer<T>,
    b: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    nrhs: u32,
) -> SolverResult<()> {
    if m == 0 || n == 0 || nrhs == 0 {
        return Ok(());
    }
    if m < n {
        return Err(SolverError::DimensionMismatch(
            "qr_solve: requires m >= n (overdetermined system)".into(),
        ));
    }
    // With m >= n established, min(m, n) is simply n.
    if tau.len() < n as usize {
        return Err(SolverError::DimensionMismatch(
            "qr_solve: tau buffer too small".into(),
        ));
    }

    // The descriptors below hand raw device pointers with leading dimension
    // `m` to the BLAS layer, so undersized buffers must be rejected here
    // rather than being read or written out of bounds on the device.
    let a_required = m as usize * n as usize;
    if a.len() < a_required {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_solve: A buffer too small ({} < {a_required})",
            a.len()
        )));
    }
    let b_required = m as usize * nrhs as usize;
    if b.len() < b_required {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_solve: B buffer too small ({} < {b_required})",
            b.len()
        )));
    }

    apply_qt::<T>(handle, a, tau, b, m, n, nrhs)?;

    // Leading n x n blocks of `a` and the top n rows of `b`.
    let r_desc = MatrixDesc::<T>::from_raw(a.as_device_ptr(), n, n, m, Layout::ColMajor);
    let mut b_desc = MatrixDescMut::<T>::from_raw(b.as_device_ptr(), n, nrhs, m, Layout::ColMajor);
    oxicuda_blas::level3::trsm(
        handle.blas(),
        oxicuda_blas::Side::Left,
        oxicuda_blas::FillMode::Upper,
        Transpose::NoTrans,
        oxicuda_blas::DiagType::NonUnit,
        T::gpu_one(),
        &r_desc,
        &mut b_desc,
    )?;
    Ok(())
}
/// Checks the buffer sizes needed to form an explicit `m x m` matrix in `q`.
///
/// NOTE(review): the body currently only validates its arguments; it issues
/// no device work and leaves `q` untouched before returning `Ok(())`.
///
/// `m == 0` is accepted as a no-op before any other check runs.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `tau` holds fewer than
/// `min(m, n)` elements or when `q` holds fewer than `m * m` elements.
pub fn qr_generate_q<T: GpuFloat>(
    handle: &SolverHandle,
    a: &DeviceBuffer<T>,
    tau: &DeviceBuffer<T>,
    q: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
) -> SolverResult<()> {
    if m == 0 {
        return Ok(());
    }

    let reflector_count = m.min(n) as usize;
    if reflector_count > tau.len() {
        return Err(SolverError::DimensionMismatch(
            "qr_generate_q: tau buffer too small".into(),
        ));
    }

    let rows = m as usize;
    let q_elems = rows * rows;
    if q_elems > q.len() {
        return Err(SolverError::DimensionMismatch(
            "qr_generate_q: Q buffer too small".into(),
        ));
    }

    // Nothing is launched yet; the arguments are only marked as used.
    let _ = (handle, a, tau, q);
    Ok(())
}
/// Walks the first `min(m, n)` columns of `a` in panels of at most
/// `QR_BLOCK_SIZE` columns and factors each panel with `panel_qr`.
///
/// `a` is addressed column-major with leading dimension `lda`: the panel that
/// starts at column `j` begins at element offset `j + j * lda`, and the
/// trailing block to its right at `j + (j + jb) * lda`.
///
/// NOTE(review): descriptors for the panel and the trailing block are built
/// but then discarded, so no trailing-matrix update is applied here yet.
///
/// Returns `Ok(())` immediately when `min(m, n) == 0`; errors from
/// `panel_qr` are propagated.
fn blocked_qr<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    tau: &mut DeviceBuffer<T>,
) -> SolverResult<()> {
    let k = m.min(n);
    // An empty problem would make the panel width zero, and a zero step (or a
    // zero divisor in `div_ceil`) panics; there is nothing to factor anyway.
    if k == 0 {
        return Ok(());
    }
    let nb = QR_BLOCK_SIZE.min(k);

    // `j` is the first column (and first row) of the current panel.
    for j in (0..k).step_by(nb as usize) {
        // The last panel may be narrower than `nb`.
        let jb = nb.min(k - j);
        let remaining_rows = m - j;

        panel_qr::<T>(handle, a, m, lda, j, jb, tau)?;

        let trailing_cols = n.saturating_sub(j + jb);
        if trailing_cols > 0 {
            let v_desc = MatrixDesc::<T>::from_raw(
                a.as_device_ptr() + (j as u64 + j as u64 * lda as u64) * T::SIZE as u64,
                remaining_rows,
                jb,
                lda,
                Layout::ColMajor,
            );
            let trailing_desc = MatrixDesc::<T>::from_raw(
                a.as_device_ptr() + (j as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
                remaining_rows,
                trailing_cols,
                lda,
                Layout::ColMajor,
            );
            // Placeholder: the descriptors are not consumed by any kernel yet.
            let _ = (v_desc, trailing_desc);
        }
    }
    Ok(())
}
/// Processes the `jb` columns of the panel that starts at column `j`,
/// launching the Householder kernel once per column.
///
/// For panel column `col` the kernel receives a pointer to the diagonal
/// element of global column `j + col` (element offset
/// `global_col + global_col * lda`), a pointer to `tau[global_col]`, the
/// number of rows from the diagonal down (`m - global_col`), and `lda`.
/// Each launch uses a grid of one block and requests
/// `rows_below * size_of::<T>()` bytes of dynamic shared memory.
///
/// The kernel is generated and loaded once per call and reused for every
/// column, since neither the PTX nor the kernel name depends on the column.
///
/// After each launch, `apply_householder_to_panel` is invoked for the columns
/// still remaining in the panel.
///
/// # Errors
///
/// Propagates failures from PTX generation, module loading, kernel lookup,
/// kernel launch, and `apply_householder_to_panel`.
fn panel_qr<T: GpuFloat>(
    handle: &SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    lda: u32,
    j: u32,
    jb: u32,
    tau: &mut DeviceBuffer<T>,
) -> SolverResult<()> {
    // An empty panel needs no kernel at all.
    if jb == 0 {
        return Ok(());
    }

    // Loop-invariant setup: build and load the kernel once for the whole panel
    // instead of regenerating and recompiling the PTX for every column.
    let ptx = emit_householder_vector::<T>(handle.sm_version())?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &householder_name::<T>())?;

    for col in 0..jb {
        let global_col = j + col;
        let rows_below = m - global_col;
        if rows_below == 0 {
            continue;
        }

        let shared_bytes = rows_below * T::size_u32();
        let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);

        // Byte offset of the diagonal element (global_col, global_col).
        let col_offset = (global_col as u64 + global_col as u64 * lda as u64) * T::SIZE as u64;
        let col_ptr = a.as_device_ptr() + col_offset;
        let tau_ptr = tau.as_device_ptr() + (global_col as u64 * T::SIZE as u64);

        let args = (col_ptr, tau_ptr, rows_below, lda);
        kernel.launch(&params, handle.stream(), &args)?;

        let remaining_panel_cols = jb - col - 1;
        if remaining_panel_cols > 0 {
            apply_householder_to_panel::<T>(handle, a, m, lda, global_col, remaining_panel_cols)?;
        }
    }
    Ok(())
}
/// Hook called by `panel_qr` after each column's kernel launch, with the
/// column index and the number of panel columns to its right.
///
/// NOTE(review): this is currently a placeholder. It performs no device work,
/// ignores every argument, and always returns `Ok(())`.
fn apply_householder_to_panel<T: GpuFloat>(
    handle: &SolverHandle,
    _a: &mut DeviceBuffer<T>,
    _m: u32,
    _lda: u32,
    _col: u32,
    _remaining_cols: u32,
) -> SolverResult<()> {
    // Only marks the handle as used; nothing is launched.
    let _: &SolverHandle = handle;
    Ok(())
}
/// Hook called by `qr_solve` on the right-hand sides `b` before the
/// triangular solve.
///
/// NOTE(review): this is currently a placeholder. It performs no device work,
/// leaves `b` untouched, and always returns `Ok(())`.
fn apply_qt<T: GpuFloat>(
    handle: &SolverHandle,
    _a: &DeviceBuffer<T>,
    _tau: &DeviceBuffer<T>,
    _b: &mut DeviceBuffer<T>,
    _m: u32,
    _n: u32,
    _nrhs: u32,
) -> SolverResult<()> {
    // Only marks the handle as used; nothing is launched.
    let _: &SolverHandle = handle;
    Ok(())
}
/// Builds the kernel entry-point name for element type `T`:
/// `"solver_householder_"` followed by `T::NAME`.
fn householder_name<T: GpuFloat>() -> String {
    let type_suffix = T::NAME;
    format!("solver_householder_{}", type_suffix)
}
/// Generates PTX text for the Householder kernel of element type `T`,
/// targeting the given SM version.
///
/// The kernel is named by `householder_name::<T>()` and declares four
/// parameters: `col_ptr` (u64), `tau_ptr` (u64), `n` (u32) and `lda` (u32).
///
/// NOTE(review): the emitted body only reads the thread id and the `n`
/// parameter and then returns; no reflector arithmetic is generated yet.
///
/// # Errors
///
/// Propagates any error reported by the kernel builder.
fn emit_householder_vector<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    let kernel_name = householder_name::<T>();
    let elem_ty = T::PTX_TYPE;

    let generated = KernelBuilder::new(&kernel_name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("col_ptr", PtxType::U64)
        .param("tau_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |builder| {
            // Builder calls are kept in their original order.
            let thread_idx = builder.thread_id_x();
            let row_count = builder.load_param_u32("n");
            let _ = (thread_idx, row_count, elem_ty);
            builder.ret();
        })
        .build()?;

    Ok(generated)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The panel width must be non-zero and no larger than 256.
    #[test]
    fn qr_block_size_positive() {
        let width = QR_BLOCK_SIZE;
        assert!(width >= 1);
        assert!(width <= 256);
    }

    /// Pins the panel width to its current value.
    #[test]
    fn test_qr_block_size_is_32() {
        assert_eq!(QR_BLOCK_SIZE, 32, "QR panel block size must be 32");
    }

    /// The single-precision kernel name mentions `f32`.
    #[test]
    fn householder_name_format() {
        assert!(householder_name::<f32>().contains("f32"));
    }

    /// The double-precision kernel name mentions `f64`.
    #[test]
    fn householder_name_f64() {
        assert!(householder_name::<f64>().contains("f64"));
    }

    /// Host-only sanity check: the determinant of a fixed 2x2 matrix and the
    /// panel width. It does not run the device factorization.
    #[test]
    fn qr_backward_error_2x2() {
        let [[a00, a01], [a10, a11]] = [[2.0_f64, 1.0], [1.0, 3.0]];
        let det_a = a00 * a11 - a01 * a10;
        assert!((det_a - 5.0).abs() < 1e-14, "det(A) must be 5, got {det_a}");
        assert_eq!(QR_BLOCK_SIZE, 32, "QR panel block size must be 32");
    }
}