#![allow(dead_code)]
use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::dense::qr::{qr_factorize, qr_solve};
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Least-squares driver: validates the problem dimensions and buffer lengths,
/// then hands the work to a QR-based routine chosen by the shape of `A`.
///
/// When `m >= n` (square systems included) the call is forwarded to
/// `lstsq_overdetermined`; when `m < n` it is forwarded to
/// `lstsq_underdetermined`.
///
/// If any of `m`, `n` or `nrhs` is zero the function returns `Ok(())`
/// immediately, before any other validation is performed.
///
/// # Arguments
///
/// * `handle` - solver handle forwarded to the selected routine.
/// * `a` - buffer holding `A`; must contain at least `n * lda` elements.
///   NOTE(review): the `lda >= m` and `n * lda` checks suggest column-major
///   storage with `m` rows and `n` columns — confirm.
/// * `b` - right-hand-side buffer; must contain at least
///   `nrhs * max(m, n)` elements.
///   NOTE(review): presumably overwritten with the solution — confirm
///   against `qr_solve`.
/// * `m`, `n` - dimensions of the problem (see the note on `a`).
/// * `lda` - leading dimension of `a`; must be `>= m`.
/// * `nrhs` - number of right-hand sides.
///
/// # Errors
///
/// Returns `SolverError::DimensionMismatch` when `lda < m` or when either
/// buffer is shorter than required. Any error produced by the selected
/// routine is returned unchanged.
pub fn lstsq<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    b: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    nrhs: u32,
) -> SolverResult<()> {
    // Nothing to solve when any dimension is empty.
    if [m, n, nrhs].contains(&0) {
        return Ok(());
    }

    // The leading dimension has to cover all `m` rows.
    if m > lda {
        return Err(SolverError::DimensionMismatch(format!(
            "lstsq: lda ({lda}) must be >= m ({m})"
        )));
    }

    // `A` needs `n` columns of `lda` elements each.
    let a_required = lda as usize * n as usize;
    if a_required > a.len() {
        return Err(SolverError::DimensionMismatch(format!(
            "lstsq: A buffer too small ({} < {a_required})",
            a.len()
        )));
    }

    // `B` is sized for the larger of `m` and `n` per right-hand side.
    let b_required = nrhs as usize * std::cmp::max(m, n) as usize;
    if b_required > b.len() {
        return Err(SolverError::DimensionMismatch(format!(
            "lstsq: B buffer too small ({} < {b_required})",
            b.len()
        )));
    }

    // Square systems take the overdetermined route.
    if m < n {
        lstsq_underdetermined(handle, a, b, m, n, lda, nrhs)
    } else {
        lstsq_overdetermined(handle, a, b, m, n, lda, nrhs)
    }
}
/// Handles the `m >= n` branch of `lstsq`.
///
/// Allocates a zeroed `tau` buffer of `min(m, n)` elements, runs
/// `qr_factorize` on `a` with leading dimension `lda`, then passes the
/// factorized `a`, the `tau` buffer and `b` to `qr_solve`.
///
/// NOTE(review): `qr_solve` is not given `lda`; confirm it copes with
/// `lda > m`.
///
/// # Errors
///
/// Propagates any failure from the buffer allocation, `qr_factorize` or
/// `qr_solve`.
fn lstsq_overdetermined<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    b: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    nrhs: u32,
) -> SolverResult<()> {
    // One scalar per column/row pair handled by the factorization.
    let mut tau_buf = DeviceBuffer::<T>::zeroed(std::cmp::min(m, n) as usize)?;

    qr_factorize(handle, a, m, n, lda, &mut tau_buf)?;
    qr_solve(handle, a, &tau_buf, b, m, n, nrhs)?;

    Ok(())
}
/// Handles the `m < n` branch of `lstsq`.
///
/// Steps, in order:
/// 1. allocate a zeroed scratch buffer of `n * m` elements and ask
///    `transpose_matrix` to fill it from `a` (source leading dimension `lda`,
///    destination leading dimension `n`);
/// 2. allocate a zeroed `tau` buffer of `min(n, m)` elements and run
///    `qr_factorize` on the scratch buffer as an `n x m` matrix;
/// 3. call `solve_rt_forward` and then `apply_q_for_min_norm` on `b`.
///
/// NOTE(review): going by the helper names this is meant to produce a
/// minimum-norm solution. In this file `transpose_matrix`,
/// `solve_rt_forward` and `apply_q_for_min_norm` currently return `Ok(())`
/// without doing any work, so this path reports success without modifying
/// `b` — confirm and implement before relying on it.
///
/// # Errors
///
/// Propagates any failure from the buffer allocations or from the helper
/// calls.
fn lstsq_underdetermined<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    b: &mut DeviceBuffer<T>,
    m: u32,
    n: u32,
    lda: u32,
    nrhs: u32,
) -> SolverResult<()> {
    // Scratch storage for the transposed matrix (leading dimension `n`).
    let mut at = DeviceBuffer::<T>::zeroed(m as usize * n as usize)?;
    transpose_matrix(handle, a, &mut at, m, n, lda, n)?;

    // Factorize the transposed copy; `a` itself is left alone here.
    let mut tau_buf = DeviceBuffer::<T>::zeroed(std::cmp::min(n, m) as usize)?;
    qr_factorize(handle, &mut at, n, m, n, &mut tau_buf)?;

    // Two-stage solve operating on `b`.
    solve_rt_forward(handle, &at, b, m, n, nrhs)?;
    apply_q_for_min_norm(handle, &at, &tau_buf, b, m, n, nrhs)?;

    Ok(())
}
/// Placeholder for the transpose step used by `lstsq_underdetermined`.
///
/// The current body ignores every argument (all are `_`-prefixed) and
/// returns `Ok(())` without writing anything to `_dst`.
///
/// NOTE(review): the caller passes `(a, at, m, n, lda, n)`, so this is
/// presumably meant to write the transpose of an `_m x _n` source (leading
/// dimension `_ld_src`) into `_dst` (leading dimension `_ld_dst`) — confirm
/// and implement.
fn transpose_matrix<T: GpuFloat>(
_handle: &SolverHandle,
_src: &DeviceBuffer<T>,
_dst: &mut DeviceBuffer<T>,
_m: u32,
_n: u32,
_ld_src: u32,
_ld_dst: u32,
) -> SolverResult<()> {
// No-op: always reports success.
Ok(())
}
/// Placeholder for the first solve stage used by `lstsq_underdetermined`.
///
/// The current body ignores every argument (all are `_`-prefixed) and
/// returns `Ok(())` without modifying `_b`.
///
/// NOTE(review): going by the name, this is presumably meant to solve a
/// triangular system with the transposed `R` factor stored in `_at`,
/// updating `_b` — confirm and implement.
fn solve_rt_forward<T: GpuFloat>(
_handle: &SolverHandle,
_at: &DeviceBuffer<T>,
_b: &mut DeviceBuffer<T>,
_m: u32,
_n: u32,
_nrhs: u32,
) -> SolverResult<()> {
// No-op: always reports success.
Ok(())
}
/// Placeholder for the second solve stage used by `lstsq_underdetermined`.
///
/// The current body ignores every argument (all are `_`-prefixed) and
/// returns `Ok(())` without modifying `_b`.
///
/// NOTE(review): going by the name, this is presumably meant to apply the
/// `Q` factor described by `_at` and `_tau` to `_b` to obtain a minimum-norm
/// solution — confirm and implement.
fn apply_q_for_min_norm<T: GpuFloat>(
_handle: &SolverHandle,
_at: &DeviceBuffer<T>,
_tau: &DeviceBuffer<T>,
_b: &mut DeviceBuffer<T>,
_m: u32,
_n: u32,
_nrhs: u32,
) -> SolverResult<()> {
// No-op: always reports success.
Ok(())
}
#[cfg(test)]
mod tests {
    // NOTE(review): these tests only pin the integer comparisons that decide
    // which solver path is taken; none of them calls `lstsq` itself.

    /// More rows than columns satisfies the `m >= n` condition.
    #[test]
    fn lstsq_overdetermined_path() {
        let (m, n) = (10_u32, 5_u32);
        assert!(m >= n, "overdetermined");
    }

    /// Fewer rows than columns satisfies the `m < n` condition.
    #[test]
    fn lstsq_underdetermined_path() {
        let (m, n) = (3_u32, 8_u32);
        assert!(m < n, "underdetermined");
    }

    /// A square shape satisfies the `m >= n` condition.
    #[test]
    fn lstsq_square_is_overdetermined() {
        let (m, n) = (5_u32, 5_u32);
        assert!(m >= n);
    }

    /// Intentionally empty: no assertions are made for zero dimensions yet.
    #[test]
    fn lstsq_zero_dimensions() {}
}