use std::sync::Arc;
use oxicuda_blas::types::{GpuFloat, Layout, MatrixDesc, MatrixDescMut, Transpose};
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::prelude::*;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
// Panel width (in columns) used by the blocked QR factorization; also
// sizes the triangular workspace reserved in `qr_factorize`.
const QR_BLOCK_SIZE: u32 = 32;
/// In-place Householder QR factorization of the `m x n` column-major
/// matrix stored in `a` with leading dimension `lda`.
///
/// On return the upper triangle of `a` holds `R`, the strict lower
/// triangle holds the Householder vectors, and `tau` receives the
/// `min(m, n)` reflector coefficients.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `lda < m` or when
/// `a` / `tau` are too small for the requested dimensions.
pub fn qr_factorize<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    tau: &mut DeviceBuffer<T>,
) -> SolverResult<()> {
    // An empty matrix is trivially factorized.
    if m == 0 || n == 0 {
        return Ok(());
    }
    if lda < m {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_factorize: lda ({lda}) must be >= m ({m})"
        )));
    }
    // Column-major storage: `n` columns of `lda` elements each.
    let min_len = n as usize * lda as usize;
    if a.len() < min_len {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_factorize: buffer too small ({} < {})",
            a.len(),
            min_len
        )));
    }
    let k = m.min(n);
    if tau.len() < k as usize {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_factorize: tau buffer too small ({} < {})",
            tau.len(),
            k
        )));
    }
    // Workspace: one QR_BLOCK_SIZE x QR_BLOCK_SIZE triangular factor
    // plus an m x QR_BLOCK_SIZE panel buffer, in bytes.
    let nb = QR_BLOCK_SIZE as usize;
    let ws_elems = nb * nb + m as usize * nb;
    handle.ensure_workspace(ws_elems * T::SIZE)?;
    blocked_qr::<T>(handle, a, m, n, lda, tau)
}
/// Solves the overdetermined least-squares system `min ||A x - b||`
/// using a factorization previously produced by `qr_factorize`
/// (`a` holds the packed factors, `tau` the reflector coefficients).
///
/// On entry `b` holds `nrhs` right-hand sides of length `m`
/// (column-major, leading dimension `m`); on exit its leading `n` rows
/// contain the solution(s).
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `m < n` or `tau` is
/// too small for `min(m, n)` reflectors.
pub fn qr_solve<T: GpuFloat>(
    handle: &SolverHandle,
    a: &DeviceBuffer<T>,
    tau: &DeviceBuffer<T>,
    b: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    nrhs: u32,
) -> SolverResult<()> {
    // Degenerate systems have nothing to solve.
    if m == 0 || n == 0 || nrhs == 0 {
        return Ok(());
    }
    if m < n {
        return Err(SolverError::DimensionMismatch(
            "qr_solve: requires m >= n (overdetermined system)".into(),
        ));
    }
    let reflectors = m.min(n);
    if tau.len() < reflectors as usize {
        return Err(SolverError::DimensionMismatch(
            "qr_solve: tau buffer too small".into(),
        ));
    }
    // Step 1: b <- Q^T b.
    apply_qt::<T>(handle, a, tau, b, m, n, nrhs)?;
    // Step 2: back-substitute against the upper-triangular R factor
    // stored in the top n x n corner of `a` (leading dimension m).
    let r = MatrixDesc::<T>::from_raw(a.as_device_ptr(), n, n, m, Layout::ColMajor);
    let mut rhs =
        MatrixDescMut::<T>::from_raw(b.as_device_ptr(), n, nrhs, m, Layout::ColMajor);
    oxicuda_blas::level3::trsm(
        handle.blas(),
        oxicuda_blas::Side::Left,
        oxicuda_blas::FillMode::Upper,
        Transpose::NoTrans,
        oxicuda_blas::DiagType::NonUnit,
        T::gpu_one(),
        &r,
        &mut rhs,
    )?;
    Ok(())
}
/// Materializes the explicit `m x m` orthogonal factor `Q` from the
/// packed Householder representation produced by `qr_factorize`.
///
/// NOTE(review): this currently runs on the host — the packed factors
/// are copied down, `Q` is accumulated on the CPU, and the result is
/// copied back into `q`. The `handle` is unused on this path.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `tau`, `q`, or `a`
/// are too small, and propagates device copy errors.
pub fn qr_generate_q<T: GpuFloat>(
    handle: &SolverHandle,
    a: &DeviceBuffer<T>,
    tau: &DeviceBuffer<T>,
    q: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
) -> SolverResult<()> {
    if m == 0 {
        return Ok(());
    }
    let reflectors = m.min(n) as usize;
    if tau.len() < reflectors {
        return Err(SolverError::DimensionMismatch(
            "qr_generate_q: tau buffer too small".into(),
        ));
    }
    if q.len() < (m as usize * m as usize) {
        return Err(SolverError::DimensionMismatch(
            "qr_generate_q: Q buffer too small".into(),
        ));
    }
    let a_elems = m as usize * n as usize;
    if a.len() < a_elems {
        return Err(SolverError::DimensionMismatch(format!(
            "qr_generate_q: A buffer too small ({} < {})",
            a.len(),
            a_elems
        )));
    }
    // Download the packed factorization to the host.
    let mut a_host = vec![T::gpu_zero(); a_elems];
    a.copy_to_host(&mut a_host)?;
    let mut tau_host = vec![T::gpu_zero(); reflectors];
    tau.copy_to_host(&mut tau_host)?;
    // Accumulate Q on the CPU (lda == m here), then upload the result.
    let q_host = form_explicit_q_from_householder_host(&a_host, &tau_host, m, n, m)?;
    q.copy_from_host(&q_host)?;
    // Handle is intentionally unused on this host-side path.
    let _ = handle;
    Ok(())
}
/// Forms the explicit `m x m` orthogonal matrix
/// `Q = H_0 * H_1 * ... * H_{k-1}` from packed Householder data,
/// where `k = min(m, n)` and reflector `i` is
/// `H_i = I - tau_i * v_i * v_i^T` with `v_i[i] = 1` implicit and
/// `v_i[i+1..m]` stored below the diagonal of column `i` of `a_host`
/// (column-major, leading dimension `lda`).
///
/// Mirrors LAPACK's `org2r`: starting from the identity, reflectors are
/// applied in REVERSE order (`H_{k-1}` first), so the accumulated
/// product is `H_0 * ... * H_{k-1}`. Applying them forward would
/// instead build `H_{k-1} * ... * H_0 = Q^T` (each `H_i` is symmetric),
/// which was the bug fixed here.
///
/// Dispatches on `T::SIZE` and does the arithmetic in native `f32` /
/// `f64` via bit-level reinterpretation of the generic storage.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] for undersized host
/// buffers and [`SolverError::InternalError`] for an unsupported
/// element size.
fn form_explicit_q_from_householder_host<T: GpuFloat>(
    a_host: &[T],
    tau_host: &[T],
    m: u32,
    n: u32,
    lda: u32,
) -> SolverResult<Vec<T>> {
    let m_usize = m as usize;
    let n_usize = n as usize;
    let lda_usize = lda as usize;
    let k = m_usize.min(n_usize);
    if a_host.len() < lda_usize * n_usize {
        return Err(SolverError::DimensionMismatch(
            "form_explicit_q: A host buffer too small".into(),
        ));
    }
    if tau_host.len() < k {
        return Err(SolverError::DimensionMismatch(
            "form_explicit_q: tau host buffer too small".into(),
        ));
    }
    if T::SIZE == 4 {
        // f32 path: reinterpret the generic storage bit-for-bit.
        let a_f32: Vec<f32> = a_host
            .iter()
            .map(|&v| f32::from_bits(v.to_bits_u64() as u32))
            .collect();
        let tau_f32: Vec<f32> = tau_host
            .iter()
            .map(|&v| f32::from_bits(v.to_bits_u64() as u32))
            .collect();
        // Start from the identity (column-major: q[col * m + row]).
        let mut q_f32 = vec![0.0_f32; m_usize * m_usize];
        for col in 0..m_usize {
            q_f32[col * m_usize + col] = 1.0;
        }
        // Apply reflectors in reverse so the product accumulates as
        // Q = H_0 * H_1 * ... * H_{k-1} (LAPACK org2r order).
        for i in (0..k).rev() {
            let tau_i = tau_f32[i];
            // tau == 0 means H_i is the identity; skip it.
            if tau_i == 0.0 {
                continue;
            }
            for col in 0..m_usize {
                // dot = v_i^T * Q[:, col], with the implicit v_i[i] = 1.
                let mut dot = q_f32[col * m_usize + i];
                for row in (i + 1)..m_usize {
                    let v_row = a_f32[i * lda_usize + row];
                    dot += v_row * q_f32[col * m_usize + row];
                }
                // Q[:, col] -= tau_i * dot * v_i.
                let scale = tau_i * dot;
                q_f32[col * m_usize + i] -= scale;
                for row in (i + 1)..m_usize {
                    let v_row = a_f32[i * lda_usize + row];
                    q_f32[col * m_usize + row] -= scale * v_row;
                }
            }
        }
        let q_t: Vec<T> = q_f32
            .into_iter()
            .map(|x| T::from_bits_u64(u64::from(x.to_bits())))
            .collect();
        return Ok(q_t);
    }
    if T::SIZE == 8 {
        // f64 path: identical algorithm at double precision.
        let a_f64: Vec<f64> = a_host
            .iter()
            .map(|&v| f64::from_bits(v.to_bits_u64()))
            .collect();
        let tau_f64: Vec<f64> = tau_host
            .iter()
            .map(|&v| f64::from_bits(v.to_bits_u64()))
            .collect();
        let mut q_f64 = vec![0.0_f64; m_usize * m_usize];
        for col in 0..m_usize {
            q_f64[col * m_usize + col] = 1.0;
        }
        // Reverse application — see the f32 branch for the rationale.
        for i in (0..k).rev() {
            let tau_i = tau_f64[i];
            if tau_i == 0.0 {
                continue;
            }
            for col in 0..m_usize {
                let mut dot = q_f64[col * m_usize + i];
                for row in (i + 1)..m_usize {
                    let v_row = a_f64[i * lda_usize + row];
                    dot += v_row * q_f64[col * m_usize + row];
                }
                let scale = tau_i * dot;
                q_f64[col * m_usize + i] -= scale;
                for row in (i + 1)..m_usize {
                    let v_row = a_f64[i * lda_usize + row];
                    q_f64[col * m_usize + row] -= scale * v_row;
                }
            }
        }
        let q_t: Vec<T> = q_f64
            .into_iter()
            .map(|x| T::from_bits_u64(x.to_bits()))
            .collect();
        return Ok(q_t);
    }
    Err(SolverError::InternalError(format!(
        "form_explicit_q: unsupported precision size {}",
        T::SIZE
    )))
}
/// Blocked Householder QR driver: walks the matrix left to right in
/// `QR_BLOCK_SIZE`-column panels, factorizing each panel in place.
///
/// NOTE(review): the trailing-matrix update after each panel is NOT
/// implemented yet — `_t_size`, `v_desc`, and `trailing_desc` are built
/// and immediately discarded, so only the panel factorization itself
/// runs. Confirm before relying on full-matrix results.
fn blocked_qr<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    tau: &mut DeviceBuffer<T>,
) -> SolverResult<()> {
    let k = m.min(n);
    // Panel width, never wider than the factorable dimension.
    let nb = QR_BLOCK_SIZE.min(k);
    let num_blocks = k.div_ceil(nb);
    for block_idx in 0..num_blocks {
        // j = first column of this panel; jb = its width (last panel
        // may be narrower).
        let j = block_idx * nb;
        let jb = nb.min(k - j);
        let remaining_rows = m - j;
        // Factorize panel A[j..m, j..j+jb] in place.
        panel_qr::<T>(handle, a, m, lda, j, jb, tau)?;
        // Size of the jb x jb triangular T factor a blocked update
        // would need (currently unused placeholder).
        let _t_size = jb as usize * jb as usize;
        let trailing_cols = n.saturating_sub(j + jb);
        if trailing_cols > 0 {
            // Descriptor over this panel's Householder vectors, i.e.
            // the unit-lower-trapezoidal block at A[j.., j..j+jb].
            let v_desc = MatrixDesc::<T>::from_raw(
                a.as_device_ptr() + (j as u64 + j as u64 * lda as u64) * T::SIZE as u64,
                remaining_rows,
                jb,
                lda,
                Layout::ColMajor,
            );
            // Descriptor over the trailing submatrix A[j.., j+jb..n];
            // presumably the target of a future compact-WY update —
            // TODO confirm once the update kernel lands.
            let trailing_desc = MatrixDesc::<T>::from_raw(
                a.as_device_ptr() + (j as u64 + (j + jb) as u64 * lda as u64) * T::SIZE as u64,
                remaining_rows,
                trailing_cols,
                lda,
                Layout::ColMajor,
            );
            // Placeholder: descriptors are dropped unused for now.
            let _ = (v_desc, trailing_desc);
        }
    }
    Ok(())
}
/// Factorizes one `jb`-column panel starting at column `j`: for each
/// column a Householder reflector is generated on the GPU, then
/// `apply_householder_to_panel` updates the remaining panel columns.
///
/// Fixes vs. the previous version:
/// - `kernel.launch(¶ms, ...)` contained a mojibake-corrupted
///   `&params` (a literal pilcrow), which did not compile.
/// - The PTX emission / module load / kernel lookup depend only on the
///   SM version and element type, so they are hoisted out of the
///   per-column loop instead of recompiling the same kernel `jb` times.
fn panel_qr<T: GpuFloat>(
    handle: &SolverHandle,
    a: &mut DeviceBuffer<T>,
    m: u32,
    lda: u32,
    j: u32,
    jb: u32,
    tau: &mut DeviceBuffer<T>,
) -> SolverResult<()> {
    // Empty panel: nothing to do (and no kernel to compile).
    if jb == 0 {
        return Ok(());
    }
    let sm = handle.sm_version();
    // Loop-invariant: compile and load the reflector kernel once.
    let ptx = emit_householder_vector::<T>(sm)?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &householder_name::<T>())?;
    for col in 0..jb {
        let global_col = j + col;
        let rows_below = m - global_col;
        if rows_below == 0 {
            continue;
        }
        // Dynamic shared memory: one element per row of the reflector.
        let shared_bytes = rows_below * T::size_u32();
        let params = LaunchParams::new(1u32, SOLVER_BLOCK_SIZE).with_shared_mem(shared_bytes);
        // Pointer to the diagonal element A[global_col, global_col]
        // (column-major with leading dimension `lda`).
        let col_offset = (global_col as u64 + global_col as u64 * lda as u64) * T::SIZE as u64;
        let col_ptr = a.as_device_ptr() + col_offset;
        let tau_ptr = tau.as_device_ptr() + (global_col as u64 * T::SIZE as u64);
        let args = (col_ptr, tau_ptr, rows_below, lda);
        kernel.launch(&params, handle.stream(), &args)?;
        // Apply the new reflector to the panel columns to its right.
        let remaining_panel_cols = jb - col - 1;
        if remaining_panel_cols > 0 {
            apply_householder_to_panel::<T>(handle, a, m, lda, global_col, remaining_panel_cols)?;
        }
    }
    Ok(())
}
/// Applies the reflector generated at `_col` to the `_remaining_cols`
/// panel columns to its right.
///
/// NOTE(review): currently an unimplemented stub — it returns `Ok(())`
/// without touching `_a`, so panel columns after the first are never
/// updated. TODO: implement the GPU panel update.
fn apply_householder_to_panel<T: GpuFloat>(
    handle: &SolverHandle,
    _a: &mut DeviceBuffer<T>,
    _m: u32,
    _lda: u32,
    _col: u32,
    _remaining_cols: u32,
) -> SolverResult<()> {
    // Silence the otherwise-unused handle until the kernel lands.
    let _ = handle;
    Ok(())
}
/// Applies `Q^T` (from the packed factorization in `_a` / `_tau`) to
/// the `_m x _nrhs` right-hand-side block `_b`, i.e. `b <- Q^T b`.
///
/// NOTE(review): currently an unimplemented stub — it returns `Ok(())`
/// without modifying `_b`. Since `qr_solve` relies on this step, the
/// solve path presently performs only the triangular solve against `R`.
/// TODO: implement the reflector application on the GPU.
fn apply_qt<T: GpuFloat>(
    handle: &SolverHandle,
    _a: &DeviceBuffer<T>,
    _tau: &DeviceBuffer<T>,
    _b: &mut DeviceBuffer<T>,
    _m: u32,
    _n: u32,
    _nrhs: u32,
) -> SolverResult<()> {
    // Silence the otherwise-unused handle until the kernel lands.
    let _ = handle;
    Ok(())
}
/// Returns the PTX entry-point name of the Householder reflector kernel
/// specialized for element type `T` (e.g. `solver_householder_f32`).
fn householder_name<T: GpuFloat>() -> String {
    let mut name = String::from("solver_householder_");
    name.push_str(T::NAME);
    name
}
/// Builds the PTX source for the Householder-vector kernel specialized
/// for element type `T`, targeting SM version `sm`.
///
/// NOTE(review): the emitted kernel body is a stub — it declares the
/// `col_ptr` / `tau_ptr` / `n` / `lda` parameters, binds the thread id
/// and `n`, and returns immediately without computing a reflector.
/// TODO: emit the actual norm / scale instruction sequence.
fn emit_householder_vector<T: GpuFloat>(sm: SmVersion) -> SolverResult<String> {
    let name = householder_name::<T>();
    let float_ty = T::PTX_TYPE;
    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("col_ptr", PtxType::U64)
        .param("tau_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("lda", PtxType::U32)
        .body(move |b| {
            let tid = b.thread_id_x();
            let n_reg = b.load_param_u32("n");
            // Placeholder body: values bound but intentionally unused.
            let _ = (tid, n_reg, float_ty);
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panel width must be positive and fit in one thread block.
    #[test]
    fn qr_block_size_positive() {
        let block_size = QR_BLOCK_SIZE;
        assert!(block_size > 0);
        assert!(block_size <= 256);
    }

    #[test]
    fn test_qr_block_size_is_32() {
        assert_eq!(QR_BLOCK_SIZE, 32, "QR panel block size must be 32");
    }

    #[test]
    fn householder_name_format() {
        let name = householder_name::<f32>();
        assert!(name.contains("f32"));
    }

    #[test]
    fn householder_name_f64() {
        let name = householder_name::<f64>();
        assert!(name.contains("f64"));
    }

    /// Sanity check on the fixed 2x2 fixture matrix.
    ///
    /// Renamed from `qr_backward_error_2x2`: the test never computed a
    /// QR backward error — it only validates the fixture's determinant
    /// and re-checks the panel block size.
    #[test]
    fn fixture_2x2_determinant() {
        let a = [[2.0_f64, 1.0], [1.0, 3.0]];
        let det_a = a[0][0] * a[1][1] - a[0][1] * a[1][0];
        assert!((det_a - 5.0).abs() < 1e-14, "det(A) must be 5, got {det_a}");
        assert_eq!(QR_BLOCK_SIZE, 32, "QR panel block size must be 32");
    }

    /// With all tau == 0 every reflector is the identity, so the
    /// explicit Q must come out as the m x m identity matrix.
    #[test]
    fn form_explicit_q_identity_when_tau_zero() {
        let m = 3_u32;
        let n = 2_u32;
        let lda = m;
        let a_host = vec![0.0_f64; (lda * n) as usize];
        let tau_host = vec![0.0_f64; m.min(n) as usize];
        let q = form_explicit_q_from_householder_host(&a_host, &tau_host, m, n, lda)
            .expect("Q generation should succeed");
        for col in 0..m as usize {
            for row in 0..m as usize {
                let expected = if row == col { 1.0 } else { 0.0 };
                assert!(
                    (q[col * m as usize + row] - expected).abs() < 1e-12,
                    "Q({}, {}) mismatch",
                    row,
                    col
                );
            }
        }
    }
}