#![allow(dead_code)]
use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::dense::lu::{lu_factorize, lu_solve};
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Replaces the `n x n` matrix stored in `a` with its inverse.
///
/// The routine runs four stages:
///
/// 1. [`lu_factorize`] is applied to `a` (leading dimension `lda`), producing
///    the pivot vector.
/// 2. An `n x n` identity matrix is built in a temporary device buffer.
/// 3. [`lu_solve`] is called with the factorized `a`, the pivots and the
///    identity as the right-hand side.
/// 4. The solution is copied back into `a`, honouring `lda`.
///
/// # Arguments
///
/// * `handle` - solver handle forwarded to every stage.
/// * `a` - device buffer holding the matrix; it must contain at least
///   `n * lda` elements and is overwritten with the result.
/// * `n` - order of the matrix. `n == 0` returns `Ok(())` without touching `a`.
/// * `lda` - leading dimension of `a`; must be at least `n`.
///
/// # Errors
///
/// * [`SolverError::DimensionMismatch`] when `lda < n` or when `a` holds fewer
///   than `n * lda` elements.
/// * [`SolverError::SingularMatrix`] when the factorization reports
///   `info > 0`.
/// * Any error propagated from buffer allocation, [`lu_factorize`],
///   [`lu_solve`] or the host/device transfers of the helper routines.
///
/// NOTE(review): when an error is returned after the factorization step, `a`
/// has already been handed to `lu_factorize` mutably and presumably no longer
/// holds the original matrix - confirm against `lu_factorize`.
pub fn inverse<T: GpuFloat>(
    handle: &mut SolverHandle,
    a: &mut DeviceBuffer<T>,
    n: u32,
    lda: u32,
) -> SolverResult<()> {
    // An empty matrix has nothing to invert.
    if n == 0 {
        return Ok(());
    }
    if lda < n {
        return Err(SolverError::DimensionMismatch(format!(
            "inverse: lda ({lda}) must be >= n ({n})"
        )));
    }
    let required = n as usize * lda as usize;
    if a.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "inverse: buffer too small ({} < {required})",
            a.len()
        )));
    }

    // Stage 1: factorize `a`, one pivot entry per row.
    let mut pivots = DeviceBuffer::<i32>::zeroed(n as usize)?;
    let lu_result = lu_factorize(handle, a, n, lda, &mut pivots)?;
    if lu_result.info > 0 {
        return Err(SolverError::SingularMatrix);
    }

    // Stage 2: densely packed n x n identity used as the right-hand side.
    let identity_size = n as usize * n as usize;
    let mut identity = DeviceBuffer::<T>::zeroed(identity_size)?;
    set_identity_diagonal(handle, &mut identity, n)?;

    // Stage 3: after this call `identity` holds the solution.
    // NOTE(review): the trailing `n, n` is presumably (order, number of
    // right-hand-side columns); `lda` is not forwarded - confirm that
    // `lu_solve` reads the factors with the correct stride when `lda > n`.
    lu_solve(handle, a, &pivots, &mut identity, n, n)?;

    // Stage 4: write the packed result back using the caller's leading
    // dimension. Passing `n` here would store the inverse with stride `n`,
    // which is the wrong layout whenever `lda > n`.
    copy_matrix(handle, &identity, a, n, lda)
}
/// Overwrites `identity` with a densely packed `n x n` identity matrix.
///
/// A host-side image covering the *entire* buffer is created (all zeros),
/// the `n` diagonal entries at offsets `i * n + i` are set to one, and the
/// image is uploaded with a single `copy_from_host` call. Elements beyond
/// the first `n * n` are therefore zeroed as well.
///
/// `_handle` is accepted for signature parity with the other helpers and is
/// not used.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `identity` holds fewer
/// than `n * n` elements, and propagates any error from the upload.
fn set_identity_diagonal<T: GpuFloat>(
    _handle: &SolverHandle,
    identity: &mut DeviceBuffer<T>,
    n: u32,
) -> SolverResult<()> {
    let order = n as usize;
    let required = order * order;
    let available = identity.len();
    if available < required {
        return Err(SolverError::DimensionMismatch(format!(
            "set_identity_diagonal: buffer too small ({} < {required})",
            identity.len()
        )));
    }

    let mut staging = vec![T::gpu_zero(); available];
    // Consecutive diagonal entries are `order + 1` elements apart; `take`
    // stops after `order` of them so any slack past `n * n` stays zero.
    staging
        .iter_mut()
        .step_by(order + 1)
        .take(order)
        .for_each(|slot| *slot = T::gpu_one());

    identity.copy_from_host(&staging)?;
    Ok(())
}
/// Copies a packed `n x n` matrix from `src` into `dst`, where `dst` uses
/// leading dimension `lda`.
///
/// Element `(row, col)` is read from `src[col * n + row]` and written to
/// `dst[col * lda + row]`. The transfer is staged through host memory: both
/// buffers are downloaded in full, the `n x n` block is overwritten in the
/// host image of `dst`, and that image is uploaded again, so elements of
/// `dst` outside the block keep their previous values.
///
/// `_handle` is accepted for signature parity with the other helpers and is
/// not used.
///
/// # Errors
///
/// Returns [`SolverError::DimensionMismatch`] when `lda < n`, when `src`
/// holds fewer than `n * n` elements, or when `dst` holds fewer than
/// `n * lda` elements. Errors from the host/device transfers are propagated.
fn copy_matrix<T: GpuFloat>(
    _handle: &SolverHandle,
    src: &DeviceBuffer<T>,
    dst: &mut DeviceBuffer<T>,
    n: u32,
    lda: u32,
) -> SolverResult<()> {
    let order = n as usize;
    let stride = lda as usize;

    if stride < order {
        return Err(SolverError::DimensionMismatch(format!(
            "copy_matrix: lda ({lda}) must be >= n ({n})"
        )));
    }

    let src_required = order * order;
    if src.len() < src_required {
        return Err(SolverError::DimensionMismatch(format!(
            "copy_matrix: src buffer too small ({} < {src_required})",
            src.len()
        )));
    }

    let dst_required = order * stride;
    if dst.len() < dst_required {
        return Err(SolverError::DimensionMismatch(format!(
            "copy_matrix: dst buffer too small ({} < {dst_required})",
            dst.len()
        )));
    }

    // Download both buffers so the untouched part of `dst` survives the
    // round trip.
    let mut packed = vec![T::gpu_zero(); src.len()];
    src.copy_to_host(&mut packed)?;
    let mut strided = vec![T::gpu_zero(); dst.len()];
    dst.copy_to_host(&mut strided)?;

    // Move one column at a time: columns are contiguous in both layouts,
    // only their starting offsets differ.
    for col in 0..order {
        let from = col * order;
        let to = col * stride;
        strided[to..to + order].copy_from_slice(&packed[from..from + order]);
    }

    dst.copy_from_host(&strided)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    /// Placeholder: this test contains no assertions and does not call
    /// `inverse`; it only reserves the test name.
    #[test]
    fn inverse_validates_zero_dimension() {}

    /// Checks that the list of named pipeline stages has four entries. It
    /// does not exercise `inverse` itself.
    #[test]
    fn inverse_structure() {
        let stage_count = ["lu_factorize", "set_identity", "lu_solve", "copy"].len();
        assert_eq!(stage_count, 4);
    }
}