1use crate::error::FactorizationError;
32use crate::sparse_sym_iface::SparseSymLinearSolverInterface;
33use crate::t_sym_solver::TSymLinearSolver;
34use pounce_common::types::{Index, Number};
35
/// A sparse symmetric matrix together with its current factorization.
///
/// Owns the numeric values (one per triplet entry, in the order given to
/// [`Factorization::new`]) and the [`TSymLinearSolver`] that wraps the chosen backend.
pub struct Factorization {
    /// Solver front-end wrapping the backend passed to [`Factorization::new`].
    inner: TSymLinearSolver,
    /// Matrix dimension; every right-hand side has this many entries.
    dim: Index,
    /// Number of stored triplet entries (length of `values`).
    nnz: Index,
    /// Current numeric values, aligned index-by-index with the triplet structure.
    values: Vec<Number>,
    /// True exactly while the most recent factorization succeeded; inertia queries are
    /// answered only when it is set.
    inertia_known: bool,
}
54
55impl std::fmt::Debug for Factorization {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("Factorization")
58 .field("dim", &self.dim)
59 .field("nnz", &self.nnz)
60 .field("inertia_known", &self.inertia_known)
61 .finish_non_exhaustive()
62 }
63}
64
65impl Factorization {
66 pub fn new(
87 dim: Index,
88 airn: Vec<Index>,
89 ajcn: Vec<Index>,
90 values: Vec<Number>,
91 backend: Box<dyn SparseSymLinearSolverInterface>,
92 ) -> Result<Self, FactorizationError> {
93 assert_eq!(
94 airn.len(),
95 ajcn.len(),
96 "airn and ajcn must have same length"
97 );
98 assert_eq!(values.len(), airn.len(), "values must match nnz");
99 let nnz = airn.len() as Index;
100 let mut inner = TSymLinearSolver::new(backend, None, false);
101 FactorizationError::from_status(inner.initialize_structure(dim, &airn, &ajcn))?;
102
103 let mut me = Self {
104 inner,
105 dim,
106 nnz,
107 values,
108 inertia_known: false,
109 };
110
111 me.do_factor()?;
117 Ok(me)
118 }
119
120 pub fn solve(&mut self, rhs: &mut [Number], nrhs: usize) -> Result<(), FactorizationError> {
132 assert_eq!(
133 rhs.len(),
134 self.dim as usize * nrhs,
135 "rhs length must equal dim * nrhs"
136 );
137 let status = self.inner.multi_solve(
138 &self.values,
139 false, nrhs as Index,
141 rhs,
142 false,
143 0,
144 );
145 FactorizationError::from_status(status)
146 }
147
148 pub fn solve_one(&mut self, rhs: &mut [Number]) -> Result<(), FactorizationError> {
151 self.solve(rhs, 1)
152 }
153
154 pub fn refactor(&mut self, new_values: &[Number]) -> Result<(), FactorizationError> {
165 assert_eq!(
166 new_values.len(),
167 self.nnz as usize,
168 "new_values length must equal nnz",
169 );
170 self.values.copy_from_slice(new_values);
171 self.inertia_known = false;
172 self.do_factor()
173 }
174
175 pub fn number_of_neg_evals(&self) -> Option<Index> {
178 use crate::sym_solver::SymLinearSolver;
179 if self.inertia_known && self.inner.provides_inertia() {
180 Some(self.inner.number_of_neg_evals())
181 } else {
182 None
183 }
184 }
185
186 pub fn dim(&self) -> Index {
188 self.dim
189 }
190
191 pub fn nnz(&self) -> Index {
193 self.nnz
194 }
195
196 fn do_factor(&mut self) -> Result<(), FactorizationError> {
198 let mut dummy_rhs = vec![0.0; self.dim as usize];
199 let status = self.inner.multi_solve(
200 &self.values,
201 true, 1,
203 &mut dummy_rhs,
204 false,
205 0,
206 );
207 FactorizationError::from_status(status)?;
208 self.inertia_known = true;
209 Ok(())
210 }
211}
212
#[cfg(test)]
mod tests {
    // The tests drive `Factorization` through the small dense backend defined below, so no
    // external sparse solver is needed.
    use super::*;
    use crate::sparse_sym_iface::EMatrixFormat;
    use crate::status::ESymSolverStatus;

    /// Test-only backend: assembles the triplet matrix into a dense `dim x dim` array and
    /// factors it with row-pivoted Gaussian elimination (LU).
    struct DenseLuBackend {
        dim: usize,
        nnz: usize,
        // Triplet structure captured by `initialize_structure` (1-based, see
        // `assemble_dense`) and the value buffer exposed via `values_array_mut` (presumably
        // filled by the solver front-end before `multi_solve`).
        rows: Vec<Index>,
        cols: Vec<Index>,
        values: Vec<Number>,
        // Result of the last successful `factor_dense`, if any.
        factor: Option<DenseLu>,
    }

    /// Packed LU factors of the dense matrix.
    struct DenseLu {
        // Row `perm[k]` of `a` holds row `k` of the permuted system: entries left of the
        // diagonal are the unit-lower `L` multipliers, the rest is `U`.
        a: Vec<Vec<f64>>,
        perm: Vec<usize>,
        // Count of negative diagonal entries of `U` (see the caveat in `factor_dense`).
        neg_evals: Index,
    }

    impl DenseLuBackend {
        /// Empty backend; the real sizes arrive through `initialize_structure`.
        fn new() -> Self {
            Self {
                dim: 0,
                nnz: 0,
                rows: Vec::new(),
                cols: Vec::new(),
                values: Vec::new(),
                factor: None,
            }
        }

        /// Expands the stored triplets into a dense symmetric matrix.
        ///
        /// Indices are treated as 1-based. Off-diagonal entries are mirrored, so each
        /// off-diagonal pair should be stored once; duplicate triplets are summed.
        fn assemble_dense(&self) -> Vec<Vec<f64>> {
            let n = self.dim;
            let mut a = vec![vec![0.0; n]; n];
            for k in 0..self.nnz {
                let i = (self.rows[k] - 1) as usize;
                let j = (self.cols[k] - 1) as usize;
                a[i][j] += self.values[k];
                if i != j {
                    a[j][i] += self.values[k];
                }
            }
            a
        }

        /// LU-factors the assembled matrix with partial (row) pivoting into `self.factor`.
        ///
        /// Returns `Singular` as soon as the largest available pivot magnitude is below
        /// `1e-300`, leaving any previous `self.factor` untouched; otherwise `Success`.
        fn factor_dense(&mut self) -> ESymSolverStatus {
            let n = self.dim;
            let mut a = self.assemble_dense();
            // Rows are never moved physically; `perm` records which row sits at each position.
            let mut perm: Vec<usize> = (0..n).collect();
            for k in 0..n {
                // Partial pivoting: choose the remaining row with the largest |entry| in
                // column `k`.
                let mut p = k;
                let mut maxv = a[perm[k]][k].abs();
                for i in (k + 1)..n {
                    let v = a[perm[i]][k].abs();
                    if v > maxv {
                        maxv = v;
                        p = i;
                    }
                }
                if maxv < 1e-300 {
                    return ESymSolverStatus::Singular;
                }
                perm.swap(k, p);
                let pk = perm[k];
                // Eliminate column `k` below the pivot, storing each multiplier in place.
                for &pi in &perm[(k + 1)..n] {
                    let factor = a[pi][k] / a[pk][k];
                    a[pi][k] = factor;
                    // Index loop on purpose: row `pi` is updated while row `pk` of the same
                    // `Vec` is read, which an iterator over either row would need a split
                    // borrow for.
                    #[allow(clippy::needless_range_loop)]
                    for j in (k + 1)..n {
                        a[pi][j] -= factor * a[pk][j];
                    }
                }
            }
            // NOTE(review): the number of negative `U` pivots equals the number of negative
            // eigenvalues only when no row interchange happened (the pivots are then those of
            // an LDL^T factorization and Sylvester's law of inertia applies). With pivoting,
            // e.g. [[0, 1], [1, 0]], this count is not the true inertia. It is adequate for
            // the SPD matrices whose inertia is checked below; revisit before adding
            // indefinite cases.
            let mut neg = 0;
            for k in 0..n {
                if a[perm[k]][k] < 0.0 {
                    neg += 1;
                }
            }
            self.factor = Some(DenseLu {
                a,
                perm,
                neg_evals: neg as Index,
            });
            ESymSolverStatus::Success
        }

        /// Overwrites `b` with the solution of `A x = b` using the stored LU factors.
        ///
        /// Panics if no successful `factor_dense` has happened yet.
        fn solve_one(&self, b: &mut [f64]) {
            let factor = self.factor.as_ref().unwrap();
            let n = self.dim;
            // Apply the row permutation: x = P b.
            let mut x: Vec<f64> = factor.perm.iter().map(|&p| b[p]).collect();
            // Forward substitution with the unit-lower factor L.
            for i in 0..n {
                let pi = factor.perm[i];
                for j in 0..i {
                    x[i] -= factor.a[pi][j] * x[j];
                }
            }
            // Back substitution with the upper factor U.
            for i in (0..n).rev() {
                let pi = factor.perm[i];
                for j in (i + 1)..n {
                    x[i] -= factor.a[pi][j] * x[j];
                }
                x[i] /= factor.a[pi][i];
            }
            b.copy_from_slice(&x);
        }
    }

    impl SparseSymLinearSolverInterface for DenseLuBackend {
        /// Records the triplet structure and allocates a zeroed value buffer of length
        /// `nonzeros`.
        fn initialize_structure(
            &mut self,
            dim: Index,
            nonzeros: Index,
            ia: &[Index],
            ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim as usize;
            self.nnz = nonzeros as usize;
            self.rows = ia.to_vec();
            self.cols = ja.to_vec();
            self.values = vec![0.0; self.nnz];
            ESymSolverStatus::Success
        }

        /// Mutable access to the value buffer that `factor_dense` assembles from.
        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.values
        }

        /// Refactors when `new_matrix` is set (optionally checking the negative-eigenvalue
        /// count), then solves in place for `nrhs` right-hand sides stored back to back in
        /// `rhs_vals`, `dim` entries each. `ia`/`ja` are ignored: the structure was captured
        /// by `initialize_structure`.
        fn multi_solve(
            &mut self,
            new_matrix: bool,
            _ia: &[Index],
            _ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            check_neg_evals: bool,
            number_of_neg_evals: Index,
        ) -> ESymSolverStatus {
            if new_matrix {
                let s = self.factor_dense();
                if s != ESymSolverStatus::Success {
                    return s;
                }
                if check_neg_evals {
                    let actual = self.factor.as_ref().unwrap().neg_evals;
                    if actual != number_of_neg_evals {
                        return ESymSolverStatus::WrongInertia;
                    }
                }
            }
            let n = self.dim;
            for k in 0..nrhs as usize {
                let base = k * n;
                self.solve_one(&mut rhs_vals[base..base + n]);
            }
            ESymSolverStatus::Success
        }

        /// Negative-pivot count of the last factor, or 0 when nothing has been factored.
        fn number_of_neg_evals(&self) -> Index {
            self.factor.as_ref().map(|f| f.neg_evals).unwrap_or(0)
        }

        /// Always `false`: this backend has no quality settings to raise.
        fn increase_quality(&mut self) -> bool {
            false
        }

        /// Claims inertia support (see the caveat in `factor_dense`).
        fn provides_inertia(&self) -> bool {
            true
        }

        /// Expects triplet input.
        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// `[[2, 1], [1, 3]] x = [3, 4]` (lower triangle given as triplets) has `x = [1, 1]`.
    #[test]
    fn factors_spd_2x2_and_solves_one_rhs() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let mut f =
            Factorization::new(2, airn, ajcn, values, Box::new(DenseLuBackend::new())).unwrap();
        let mut rhs = vec![3.0, 4.0];
        f.solve_one(&mut rhs).unwrap();
        assert!((rhs[0] - 1.0).abs() < 1e-12);
        assert!((rhs[1] - 1.0).abs() < 1e-12);
    }

    /// Three right-hand sides packed back to back must give the same solutions as three
    /// separate single-RHS solves.
    #[test]
    fn packed_multi_rhs_matches_one_at_a_time() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let backend1 = Box::new(DenseLuBackend::new());
        let backend2 = Box::new(DenseLuBackend::new());
        let mut f1 =
            Factorization::new(2, airn.clone(), ajcn.clone(), values.clone(), backend1).unwrap();
        let mut f2 = Factorization::new(2, airn, ajcn, values, backend2).unwrap();

        let mut packed = vec![
            3.0, 4.0, // rhs 0
            5.0, 5.0, // rhs 1
            2.0, 6.0, // rhs 2
        ];
        f1.solve(&mut packed, 3).unwrap();

        // The same right-hand sides, solved one at a time on an independent factorization.
        let mut col0 = vec![3.0, 4.0];
        let mut col1 = vec![5.0, 5.0];
        let mut col2 = vec![2.0, 6.0];
        f2.solve_one(&mut col0).unwrap();
        f2.solve_one(&mut col1).unwrap();
        f2.solve_one(&mut col2).unwrap();

        for (i, &v) in col0.iter().enumerate() {
            assert!((packed[i] - v).abs() < 1e-12, "col0 mismatch at {i}");
        }
        for (i, &v) in col1.iter().enumerate() {
            assert!((packed[2 + i] - v).abs() < 1e-12, "col1 mismatch at {i}");
        }
        for (i, &v) in col2.iter().enumerate() {
            assert!((packed[4 + i] - v).abs() < 1e-12, "col2 mismatch at {i}");
        }
    }

    /// After `refactor` to `[[4, 1], [1, 5]]`, the solution must satisfy the new system
    /// (checked via the residual).
    #[test]
    fn refactor_yields_correct_solution_for_new_values() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let mut f = Factorization::new(
            2,
            airn,
            ajcn,
            vec![2.0, 1.0, 3.0],
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();

        f.refactor(&[4.0, 1.0, 5.0]).unwrap();
        let mut rhs = vec![5.0, 6.0];
        f.solve_one(&mut rhs).unwrap();
        // Residuals of the two rows of [[4, 1], [1, 5]] x = [5, 6].
        let r0 = 4.0 * rhs[0] + rhs[1] - 5.0;
        let r1 = rhs[0] + 5.0 * rhs[1] - 6.0;
        assert!(r0.abs() < 1e-10);
        assert!(r1.abs() < 1e-10);
    }

    /// `[[1, 1], [1, 1]]` is singular, so construction must fail with `Singular`.
    #[test]
    fn singular_matrix_returns_singular_error() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let err = Factorization::new(
            2,
            airn,
            ajcn,
            vec![1.0, 1.0, 1.0],
            Box::new(DenseLuBackend::new()),
        )
        .unwrap_err();
        assert_eq!(err, FactorizationError::Singular);
    }

    /// `solve_one` must produce exactly the same bits as `solve(.., 1)`.
    #[test]
    fn solve_one_matches_solve_with_nrhs_one() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let mut f1 = Factorization::new(
            2,
            airn.clone(),
            ajcn.clone(),
            values.clone(),
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();
        let mut f2 =
            Factorization::new(2, airn, ajcn, values, Box::new(DenseLuBackend::new())).unwrap();

        let mut rhs1 = vec![3.0, 4.0];
        let mut rhs2 = vec![3.0, 4.0];
        f1.solve_one(&mut rhs1).unwrap();
        f2.solve(&mut rhs2, 1).unwrap();
        assert_eq!(rhs1, rhs2);
    }

    /// An SPD matrix has no negative eigenvalues, and the size accessors report the values
    /// given at construction.
    #[test]
    fn inertia_is_reported_when_backend_provides_it() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let f = Factorization::new(
            2,
            airn,
            ajcn,
            vec![2.0, 1.0, 3.0],
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();
        assert_eq!(f.number_of_neg_evals(), Some(0));
        assert_eq!(f.dim(), 2);
        assert_eq!(f.nnz(), 3);
    }
}