oxicuda_solver/sparse/direct.rs
1//! Direct sparse solver via dense LU factorization.
2//!
3//! Provides a simple direct solve path for small-to-medium sparse systems
4//! by delegating to the dense LU factorization and solve routines. This is
5//! useful as a fallback when iterative methods fail to converge or when the
6//! system is small enough that direct methods are competitive.
7//!
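//! For large sparse systems, use the iterative solvers (CG, BiCGSTAB, GMRES),
//! which are much more memory-efficient: this path materializes all `n * n`
//! matrix entries, whereas iterative methods need storage proportional only to
//! the nonzeros plus a few work vectors.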
10
11#![allow(dead_code)]
12
13use oxicuda_blas::GpuFloat;
14use oxicuda_memory::DeviceBuffer;
15
16use crate::dense::lu::{lu_factorize, lu_solve};
17use crate::error::{SolverError, SolverResult};
18use crate::handle::SolverHandle;
19
20// ---------------------------------------------------------------------------
21// Public API
22// ---------------------------------------------------------------------------
23
24/// Solves `A * X = B` directly using dense LU factorization.
25///
26/// The matrix `a_dense` is the dense representation of the sparse matrix,
27/// stored in column-major order with leading dimension `n`. The right-hand
28/// side `b` is overwritten with the solution.
29///
30/// This function is a convenience wrapper around [`lu_factorize`] +
31/// [`lu_solve`] for cases where the sparse matrix has been assembled into
32/// dense form.
33///
34/// # Arguments
35///
36/// * `handle` — solver handle.
37/// * `a_dense` — dense matrix (n x n, column-major). Destroyed on output.
38/// * `n` — system dimension.
39/// * `b` — right-hand side / solution (n x nrhs, column-major). Overwritten.
40/// * `nrhs` — number of right-hand side columns.
41///
42/// # Errors
43///
44/// Returns [`SolverError::SingularMatrix`] if the matrix is singular.
45/// Returns [`SolverError::DimensionMismatch`] for invalid dimensions.
46pub fn direct_solve<T: GpuFloat>(
47 handle: &mut SolverHandle,
48 a_dense: &mut DeviceBuffer<T>,
49 n: u32,
50 b: &mut DeviceBuffer<T>,
51 nrhs: u32,
52) -> SolverResult<()> {
53 // Validate dimensions.
54 if n == 0 || nrhs == 0 {
55 return Ok(());
56 }
57 let a_required = n as usize * n as usize;
58 if a_dense.len() < a_required {
59 return Err(SolverError::DimensionMismatch(format!(
60 "direct_solve: A buffer too small ({} < {a_required})",
61 a_dense.len()
62 )));
63 }
64 let b_required = n as usize * nrhs as usize;
65 if b.len() < b_required {
66 return Err(SolverError::DimensionMismatch(format!(
67 "direct_solve: B buffer too small ({} < {b_required})",
68 b.len()
69 )));
70 }
71
72 // Step 1: LU factorize A.
73 let mut pivots = DeviceBuffer::<i32>::zeroed(n as usize)?;
74 let lu_result = lu_factorize(handle, a_dense, n, n, &mut pivots)?;
75
76 if lu_result.info > 0 {
77 return Err(SolverError::SingularMatrix);
78 }
79
80 // Step 2: Solve using LU factors.
81 lu_solve(handle, a_dense, &pivots, b, n, nrhs)?;
82
83 Ok(())
84}
85
86// ---------------------------------------------------------------------------
87// Solver selection heuristic
88// ---------------------------------------------------------------------------
89
/// Largest system dimension for which the direct solver is always preferred.
const DIRECT_SOLVER_MAX_DIM: usize = 100;

/// Fill ratio strictly above which the direct solver is preferred regardless
/// of the system dimension.
const DIRECT_SOLVER_DENSITY_THRESHOLD: f64 = 0.3;

/// Returns `true` if the direct sparse solver (dense LU) is preferred over
/// iterative methods for the given system dimensions and density.
///
/// Heuristic: the direct solver wins for small systems (`n <= 100`) OR for
/// dense/near-dense systems (`density > 0.3`), where iterative methods
/// converge slowly. All other systems should use an iterative solver.
///
/// A NaN `density` never compares greater than the threshold, so in that case
/// the decision depends on `n` alone.
///
/// # Arguments
///
/// * `n` — system dimension.
/// * `density` — fill ratio in [0.0, 1.0] (nnz / (n * n)).
///
/// # Examples
///
/// ```
/// use oxicuda_solver::sparse::direct::prefer_direct_solver;
///
/// // Small system: always direct.
/// assert!(prefer_direct_solver(50, 0.01));
/// // Large sparse system: iterative (CG preferred for SPD).
/// assert!(!prefer_direct_solver(10_000, 0.001));
/// // Dense system: direct even if large-ish.
/// assert!(prefer_direct_solver(200, 0.8));
/// ```
#[must_use]
pub fn prefer_direct_solver(n: usize, density: f64) -> bool {
    n <= DIRECT_SOLVER_MAX_DIM || density > DIRECT_SOLVER_DENSITY_THRESHOLD
}
117
118// ---------------------------------------------------------------------------
119// Tests
120// ---------------------------------------------------------------------------
121
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: `direct_solve` needs a live `SolverHandle` and device buffers, so
    // it cannot be exercised by host-only unit tests.
    // TODO: cover `direct_solve` (including the n == 0 / nrhs == 0 no-op and
    // the buffer-size checks) in a GPU-backed integration test.

    // ---------------------------------------------------------------------------
    // Sparse direct vs iterative selection tests
    // ---------------------------------------------------------------------------

    #[test]
    fn sparse_direct_vs_iterative_selection() {
        // Large (n > 100) system with < 1% fill → iterative preferred.
        assert!(
            !prefer_direct_solver(100_001, 0.009),
            "large sparse system should prefer iterative"
        );

        // n = 50 with 50% fill → direct (dense LU) preferred.
        assert!(
            prefer_direct_solver(50, 0.5),
            "small system should prefer direct"
        );

        // n = 100 is the inclusive upper bound of the small range → direct.
        assert!(
            prefer_direct_solver(100, 0.01),
            "n=100 is within direct solver range"
        );

        // High density regardless of size → direct.
        assert!(
            prefer_direct_solver(500, 0.4),
            "density 0.4 > 0.3 → prefer direct"
        );

        // Large sparse → iterative.
        assert!(
            !prefer_direct_solver(10_000, 0.001),
            "n=10000 with density 0.001 should prefer iterative"
        );
    }

    #[test]
    fn prefer_direct_solver_density_boundary() {
        // density = 0.3 is the boundary: > 0.3 → direct, <= 0.3 → depends on n.
        let n_large = 1000;
        assert!(
            prefer_direct_solver(n_large, 0.31),
            "density 0.31 > 0.3 should prefer direct"
        );
        assert!(
            !prefer_direct_solver(n_large, 0.3),
            "density exactly 0.3 is not > 0.3, so large n should prefer iterative"
        );
        assert!(
            !prefer_direct_solver(n_large, 0.29),
            "density 0.29 <= 0.3 with large n should prefer iterative"
        );
    }

    #[test]
    fn prefer_direct_solver_dimension_boundary() {
        // n = 101 is the first size outside the small range, so density decides.
        assert!(
            !prefer_direct_solver(101, 0.001),
            "n=101 with low density should prefer iterative"
        );
        assert!(
            prefer_direct_solver(101, 0.5),
            "n=101 with density 0.5 should still prefer direct"
        );
    }

    #[test]
    fn prefer_direct_solver_small_system() {
        // Any system with n <= 100 uses direct regardless of density.
        for n in [1_usize, 10, 50, 100] {
            for density in [0.001, 0.1, 0.5, 1.0] {
                assert!(
                    prefer_direct_solver(n, density),
                    "n={n} is small enough for direct solver"
                );
            }
        }
    }

    #[test]
    fn prefer_direct_solver_nan_density() {
        // NaN never compares greater than the threshold, so only n decides.
        assert!(
            prefer_direct_solver(10, f64::NAN),
            "small n should prefer direct even with NaN density"
        );
        assert!(
            !prefer_direct_solver(1000, f64::NAN),
            "large n with NaN density should fall back to iterative"
        );
    }
}