#![allow(dead_code)]
use oxicuda_blas::GpuFloat;
use oxicuda_memory::DeviceBuffer;
use crate::dense::lu::{lu_factorize, lu_solve};
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
/// Solves a dense linear system by running an LU factorization of `a_dense`
/// and then an LU solve against `b`.
///
/// # Arguments
///
/// * `handle` - solver handle forwarded to `lu_factorize` and `lu_solve`.
/// * `a_dense` - buffer for the `n x n` coefficient matrix; it must hold at
///   least `n * n` elements. It is passed mutably to the factorization
///   (presumably overwritten with the LU factors — confirm against
///   `lu_factorize`).
/// * `n` - order of the system.
/// * `b` - right-hand-side buffer; it must hold at least `n * nrhs` elements.
///   It is passed mutably to the solve step (presumably overwritten with the
///   solution — confirm against `lu_solve`).
/// * `nrhs` - number of right-hand sides.
///
/// # Returns
///
/// `Ok(())` once both steps complete. When `n` or `nrhs` is zero the function
/// returns `Ok(())` immediately, before any buffer check or solver call.
///
/// # Errors
///
/// * `SolverError::DimensionMismatch` if `a_dense` or `b` is shorter than
///   required.
/// * `SolverError::SingularMatrix` if the factorization reports a positive
///   `info` value.
/// * Any error propagated from the pivot allocation, `lu_factorize`, or
///   `lu_solve`.
pub fn direct_solve<T: GpuFloat>(
    handle: &mut SolverHandle,
    a_dense: &mut DeviceBuffer<T>,
    n: u32,
    b: &mut DeviceBuffer<T>,
    nrhs: u32,
) -> SolverResult<()> {
    // An empty system or an empty set of right-hand sides is a no-op.
    if matches!((n, nrhs), (0, _) | (_, 0)) {
        return Ok(());
    }

    let order = n as usize;

    // The matrix buffer is validated first so its error takes precedence.
    let a_required = order * order;
    let a_len = a_dense.len();
    if a_required > a_len {
        let msg = format!(
            "direct_solve: A buffer too small ({} < {a_required})",
            a_len
        );
        return Err(SolverError::DimensionMismatch(msg));
    }

    let b_required = order * nrhs as usize;
    let b_len = b.len();
    if b_required > b_len {
        let msg = format!(
            "direct_solve: B buffer too small ({} < {b_required})",
            b_len
        );
        return Err(SolverError::DimensionMismatch(msg));
    }

    // One pivot entry per row, filled in by the factorization.
    let mut pivots = DeviceBuffer::<i32>::zeroed(order)?;

    let factorization = lu_factorize(handle, a_dense, n, n, &mut pivots)?;
    // A positive `info` is reported to the caller as a singular matrix.
    if factorization.info > 0 {
        return Err(SolverError::SingularMatrix);
    }

    lu_solve(handle, a_dense, &pivots, b, n, nrhs)?;
    Ok(())
}
/// Heuristic that decides whether a direct solver should be preferred over an
/// iterative one.
///
/// # Arguments
///
/// * `n` - order of the system.
/// * `density` - density of the matrix (presumably the fraction of non-zero
///   entries in `[0, 1]` — confirm against callers).
///
/// # Returns
///
/// `true` when `n <= 100` or when `density > 0.3`; `false` otherwise. A
/// density of exactly `0.3` does not count as dense.
pub fn prefer_direct_solver(n: usize, density: f64) -> bool {
    /// Systems of at most this order always use the direct path.
    const SMALL_SYSTEM_LIMIT: usize = 100;
    /// Densities strictly above this value use the direct path.
    const DENSITY_CUTOFF: f64 = 0.3;

    if n <= SMALL_SYSTEM_LIMIT {
        return true;
    }
    density > DENSITY_CUTOFF
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder: the body is empty, so this test makes no assertions and
    /// does not call `direct_solve`.
    // NOTE(review): presumably left empty because `direct_solve` needs a
    // `SolverHandle` and device buffers — confirm and fill in or remove.
    #[test]
    fn direct_solve_zero_dimension() {}

    /// Checks only that the list of pipeline step names has two entries.
    #[test]
    fn direct_solve_structure() {
        let pipeline = ["lu_factorize", "lu_solve"];
        assert_eq!(pipeline.len(), 2);
    }

    /// Table-driven check of the direct/iterative heuristic on mixed inputs.
    #[test]
    fn sparse_direct_vs_iterative_selection() {
        // Each row: (n, density, expected result, failure message).
        let cases: [(usize, f64, bool, &str); 5] = [
            (100_001, 0.009, false, "large sparse system should prefer iterative"),
            (50, 0.5, true, "small system should prefer direct"),
            (100, 0.01, true, "n=100 is within direct solver range"),
            (500, 0.4, true, "density 0.4 > 0.3 → prefer direct"),
            (10_000, 0.001, false, "n=10000 with density 0.001 should prefer iterative"),
        ];

        for (n, density, expected, why) in cases {
            assert!(prefer_direct_solver(n, density) == expected, "{why}");
        }
    }

    /// Probes densities just above and just below the 0.3 cutoff for a system
    /// that is too large to qualify on size alone.
    #[test]
    fn prefer_direct_solver_density_boundary() {
        let n_large = 1000;

        let just_above = prefer_direct_solver(n_large, 0.31);
        assert!(just_above, "density 0.31 > 0.3 should prefer direct");

        let just_below = prefer_direct_solver(n_large, 0.29);
        assert!(
            !just_below,
            "density 0.29 <= 0.3 with large n should prefer iterative"
        );
    }

    /// Every system of order at most 100 prefers the direct solver, whatever
    /// the density.
    #[test]
    fn prefer_direct_solver_small_system() {
        let sizes = [1_usize, 10, 50, 100];
        let densities = [0.001, 0.1, 0.5, 1.0];

        for n in sizes {
            for density in densities {
                assert!(
                    prefer_direct_solver(n, density),
                    "n={n} is small enough for direct solver"
                );
            }
        }
    }
}