1use crate::scaling::TSymScalingMethod;
25use crate::sparse_sym_iface::{EMatrixFormat, FactorPattern, SparseSymLinearSolverInterface};
26use crate::status::ESymSolverStatus;
27use crate::sym_solver::SymLinearSolver;
28use pounce_common::types::{Index, Number};
29use pounce_linalg::triplet_convert::{TriFull, TripletToCsrConverter};
30
31pub struct TSymLinearSolver {
34 backend: Box<dyn SparseSymLinearSolverInterface>,
35 scaling_method: Option<Box<dyn TSymScalingMethod>>,
36 matrix_format: EMatrixFormat,
37 converter: Option<TripletToCsrConverter>,
38
39 initialized: bool,
41 have_structure: bool,
45 use_scaling: bool,
47 just_switched_on_scaling: bool,
50 linear_scaling_on_demand: bool,
54
55 dim: Index,
56 nonzeros_triplet: Index,
57 nonzeros_compressed: Index,
58
59 airn: Vec<Index>,
61 ajcn: Vec<Index>,
63 scaling_factors: Vec<Number>,
66}
67
68impl std::fmt::Debug for TSymLinearSolver {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("TSymLinearSolver")
71 .field("matrix_format", &self.matrix_format)
72 .field("dim", &self.dim)
73 .field("nonzeros_triplet", &self.nonzeros_triplet)
74 .field("nonzeros_compressed", &self.nonzeros_compressed)
75 .field("use_scaling", &self.use_scaling)
76 .field("initialized", &self.initialized)
77 .finish_non_exhaustive()
78 }
79}
80
81impl TSymLinearSolver {
82 pub fn new(
87 backend: Box<dyn SparseSymLinearSolverInterface>,
88 scaling_method: Option<Box<dyn TSymScalingMethod>>,
89 linear_scaling_on_demand: bool,
90 ) -> Self {
91 let matrix_format = backend.matrix_format();
92 let converter = match matrix_format {
93 EMatrixFormat::TripletFormat => None,
94 EMatrixFormat::CsrFormat0Offset => {
95 Some(TripletToCsrConverter::new(0, TriFull::Triangular))
96 }
97 EMatrixFormat::CsrFormat1Offset => {
98 Some(TripletToCsrConverter::new(1, TriFull::Triangular))
99 }
100 EMatrixFormat::CsrFullFormat0Offset => {
101 Some(TripletToCsrConverter::new(0, TriFull::Full))
102 }
103 EMatrixFormat::CsrFullFormat1Offset => {
104 Some(TripletToCsrConverter::new(1, TriFull::Full))
105 }
106 };
107 let use_scaling = scaling_method.is_some() && !linear_scaling_on_demand;
108 Self {
109 backend,
110 scaling_method,
111 matrix_format,
112 converter,
113 initialized: false,
114 have_structure: false,
115 use_scaling,
116 just_switched_on_scaling: false,
117 linear_scaling_on_demand,
118 dim: 0,
119 nonzeros_triplet: 0,
120 nonzeros_compressed: 0,
121 airn: Vec::new(),
122 ajcn: Vec::new(),
123 scaling_factors: Vec::new(),
124 }
125 }
126
127 pub fn initialize_structure(
131 &mut self,
132 dim: Index,
133 airn: &[Index],
134 ajcn: &[Index],
135 ) -> ESymSolverStatus {
136 assert_eq!(airn.len(), ajcn.len());
137 let nz = airn.len() as Index;
138 self.dim = dim;
139 self.nonzeros_triplet = nz;
140 self.airn = airn.to_vec();
141 self.ajcn = ajcn.to_vec();
142
143 let (ia, ja, nonzeros) = match self.converter.as_mut() {
144 None => (&self.airn[..], &self.ajcn[..], self.nonzeros_triplet),
145 Some(conv) => {
146 let nonzeros_compressed = conv.initialize(self.dim, &self.airn, &self.ajcn);
147 self.nonzeros_compressed = nonzeros_compressed;
148 (conv.ia(), conv.ja(), nonzeros_compressed)
149 }
150 };
151 let status = self.backend.initialize_structure(dim, nonzeros, ia, ja);
152 if status != ESymSolverStatus::Success {
153 return status;
154 }
155 if self.scaling_method.is_some() {
156 self.scaling_factors = vec![0.0; dim as usize];
157 }
158 self.have_structure = true;
159 self.initialized = true;
160 status
161 }
162
163 #[allow(clippy::too_many_arguments)]
174 pub fn multi_solve(
175 &mut self,
176 vals: &[Number],
177 new_matrix: bool,
178 nrhs: Index,
179 rhs_vals: &mut [Number],
180 check_neg_evals: bool,
181 number_of_neg_evals: Index,
182 ) -> ESymSolverStatus {
183 debug_assert!(self.initialized);
184 debug_assert_eq!(vals.len(), self.nonzeros_triplet as usize);
185 debug_assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
186
187 {
198 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
199 static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
200 static WARNED: AtomicBool = AtomicBool::new(false);
201 let n_call = CALL_COUNT.fetch_add(1, Ordering::SeqCst);
202 let skip: usize = std::env::var("POUNCE_DBG_KKT_DUMP_SKIP")
203 .ok()
204 .and_then(|s| s.parse().ok())
205 .unwrap_or(0);
206 if n_call < skip {
207 } else if let Ok(path) = std::env::var("POUNCE_DBG_KKT_DUMP") {
209 if !WARNED.swap(true, Ordering::SeqCst) {
210 tracing::warn!(
211 target: "pounce::linsol",
212 "POUNCE_DBG_KKT_DUMP is deprecated; prefer `--dump kkt:<iter-spec>` (see pounce --help)"
213 );
214 }
215 use std::io::Write;
216 if let Ok(mut f) = std::fs::File::create(&path) {
217 let dim = self.dim as u64;
218 let nnz = self.nonzeros_triplet as u64;
219 let nrhs64 = nrhs as u64;
220 let _ = f.write_all(&dim.to_le_bytes());
221 let _ = f.write_all(&nnz.to_le_bytes());
222 let _ = f.write_all(&nrhs64.to_le_bytes());
223 for &i in &self.airn {
224 let _ = f.write_all(&(i as i64).to_le_bytes());
225 }
226 for &j in &self.ajcn {
227 let _ = f.write_all(&(j as i64).to_le_bytes());
228 }
229 for &v in vals {
230 let _ = f.write_all(&v.to_le_bytes());
231 }
232 for &v in &*rhs_vals {
233 let _ = f.write_all(&v.to_le_bytes());
234 }
235 let _ = f.flush();
236 }
237 unsafe {
240 std::env::remove_var("POUNCE_DBG_KKT_DUMP");
241 }
242 }
243 }
244
245 let mut new_matrix = new_matrix;
247 if new_matrix || self.just_switched_on_scaling {
248 self.give_matrix_to_solver(true, vals);
249 new_matrix = true;
250 }
251
252 if self.use_scaling {
254 for irhs in 0..nrhs as usize {
255 let base = irhs * self.dim as usize;
256 for i in 0..self.dim as usize {
257 rhs_vals[base + i] *= self.scaling_factors[i];
258 }
259 }
260 }
261
262 let status = loop {
266 let (ia_ptr, ia_len, ja_ptr, ja_len) = match self.converter.as_ref() {
267 None => (
268 self.airn.as_ptr(),
269 self.airn.len(),
270 self.ajcn.as_ptr(),
271 self.ajcn.len(),
272 ),
273 Some(c) => (c.ia().as_ptr(), c.ia().len(), c.ja().as_ptr(), c.ja().len()),
274 };
275 let (ia, ja) = unsafe {
279 (
280 std::slice::from_raw_parts(ia_ptr, ia_len),
281 std::slice::from_raw_parts(ja_ptr, ja_len),
282 )
283 };
284 let s = self.backend.multi_solve(
285 new_matrix,
286 ia,
287 ja,
288 nrhs,
289 rhs_vals,
290 check_neg_evals,
291 number_of_neg_evals,
292 );
293 if s == ESymSolverStatus::CallAgain {
294 self.give_matrix_to_solver(false, vals);
295 continue;
296 }
297 break s;
298 };
299
300 if status == ESymSolverStatus::Success && self.use_scaling {
301 for irhs in 0..nrhs as usize {
307 let base = irhs * self.dim as usize;
308 for i in 0..self.dim as usize {
309 rhs_vals[base + i] *= self.scaling_factors[i];
310 }
311 }
312 }
313
314 status
315 }
316
317 fn give_matrix_to_solver(&mut self, new_matrix: bool, vals: &[Number]) {
321 if self.matrix_format == EMatrixFormat::TripletFormat && !self.use_scaling {
325 let pa = self.backend.values_array_mut();
326 pa[..self.nonzeros_triplet as usize]
327 .copy_from_slice(&vals[..self.nonzeros_triplet as usize]);
328 return;
329 }
330
331 let mut atriplet: Vec<Number> = vals[..self.nonzeros_triplet as usize].to_vec();
334
335 if self.use_scaling {
336 if new_matrix || self.just_switched_on_scaling {
337 let Some(method) = self.scaling_method.as_mut() else {
340 unreachable!("use_scaling without a scaling method")
341 };
342 let ok = method.compute_sym_t_scaling_factors(
343 self.dim,
344 self.nonzeros_triplet,
345 &self.airn,
346 &self.ajcn,
347 &atriplet,
348 &mut self.scaling_factors,
349 );
350 assert!(ok, "scaling method failed");
351 self.just_switched_on_scaling = false;
352 }
353 for (i, a) in atriplet
354 .iter_mut()
355 .enumerate()
356 .take(self.nonzeros_triplet as usize)
357 {
358 let r = (self.airn[i] - 1) as usize;
359 let c = (self.ajcn[i] - 1) as usize;
360 *a *= self.scaling_factors[r] * self.scaling_factors[c];
361 }
362 }
363
364 if self.matrix_format == EMatrixFormat::TripletFormat {
365 let pa = self.backend.values_array_mut();
366 pa[..self.nonzeros_triplet as usize].copy_from_slice(&atriplet);
367 } else {
368 let Some(conv) = self.converter.as_ref() else {
369 unreachable!("non-triplet matrix_format requires a converter");
370 };
371 let pa = self.backend.values_array_mut();
372 conv.convert_values(&atriplet, &mut pa[..self.nonzeros_compressed as usize]);
373 }
374 }
375
376 pub fn factor_pattern(&self, want_values: bool) -> Option<FactorPattern> {
380 self.backend.factor_pattern(want_values)
381 }
382}
383
384impl SymLinearSolver for TSymLinearSolver {
385 fn number_of_neg_evals(&self) -> Index {
386 self.backend.number_of_neg_evals()
387 }
388
389 fn increase_quality(&mut self) -> bool {
393 if self.scaling_method.is_some() && !self.use_scaling && self.linear_scaling_on_demand {
394 self.use_scaling = true;
395 self.just_switched_on_scaling = true;
396 return true;
397 }
398 self.backend.increase_quality()
399 }
400
401 fn provides_inertia(&self) -> bool {
402 self.backend.provides_inertia()
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scaling::IdentityScalingMethod;
    use std::sync::{Arc, Mutex};

    /// What the mock backend observed on its most recent `multi_solve` call.
    #[derive(Default)]
    struct SolveLog {
        new_matrix: bool,
        matrix_values: Vec<Number>,
        rhs: Vec<Number>,
    }

    /// Triplet-format backend that records its inputs and returns a canned solution.
    ///
    /// The solver takes ownership of the backend as a `Box<dyn ...>`. Observations therefore
    /// go through the shared `log`, which the test can still read after moving the mock in.
    #[derive(Default)]
    struct MockBackend {
        dim: Index,
        a: Vec<Number>,
        log: Arc<Mutex<SolveLog>>,
        canned_solution: Vec<Number>,
        neg_evals: Index,
        increase_quality_calls: u32,
        max_increase_quality_calls: u32,
    }

    impl SparseSymLinearSolverInterface for MockBackend {
        fn initialize_structure(
            &mut self,
            dim: Index,
            nz: Index,
            _ia: &[Index],
            _ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim;
            self.a = vec![0.0; nz as usize];
            ESymSolverStatus::Success
        }
        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.a
        }
        fn multi_solve(
            &mut self,
            new_matrix: bool,
            _ia: &[Index],
            _ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            _check: bool,
            _nev: Index,
        ) -> ESymSolverStatus {
            assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
            {
                let mut log = self.log.lock().unwrap();
                log.new_matrix = new_matrix;
                log.matrix_values = self.a.clone();
                log.rhs = rhs_vals.to_vec();
            }
            for irhs in 0..nrhs as usize {
                let base = irhs * self.dim as usize;
                rhs_vals[base..base + self.dim as usize].copy_from_slice(&self.canned_solution);
            }
            ESymSolverStatus::Success
        }
        fn number_of_neg_evals(&self) -> Index {
            self.neg_evals
        }
        fn increase_quality(&mut self) -> bool {
            self.increase_quality_calls += 1;
            self.increase_quality_calls <= self.max_increase_quality_calls
        }
        fn provides_inertia(&self) -> bool {
            true
        }
        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// Lower-triangular 1-based pattern of a 2x2 matrix: (1,1), (2,1), (2,2).
    fn make_2x2_indef_pattern() -> ([Index; 3], [Index; 3]) {
        ([1, 2, 2], [1, 1, 2])
    }

    #[test]
    fn unscaled_triplet_solve_passes_values_through() {
        let log = Arc::new(Mutex::new(SolveLog::default()));
        let backend = MockBackend {
            canned_solution: vec![10.0, 20.0],
            log: Arc::clone(&log),
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(Box::new(backend), None, false);
        let (irn, jcn) = make_2x2_indef_pattern();
        assert_eq!(
            solver.initialize_structure(2, &irn, &jcn),
            ESymSolverStatus::Success
        );

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [3.0, 4.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [10.0, 20.0]);
        {
            let seen = log.lock().unwrap();
            assert!(seen.new_matrix);
            assert_eq!(seen.matrix_values, vals);
            assert_eq!(seen.rhs, [3.0, 4.0]);
        }
        assert!(solver.provides_inertia());
    }

    #[test]
    fn identity_scaling_does_not_change_values() {
        let log = Arc::new(Mutex::new(SolveLog::default()));
        let backend = MockBackend {
            canned_solution: vec![1.0, 1.0],
            log: Arc::clone(&log),
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(
            Box::new(backend),
            Some(Box::new(IdentityScalingMethod)),
            false,
        );
        let (irn, jcn) = make_2x2_indef_pattern();
        solver.initialize_structure(2, &irn, &jcn);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [1.0, 1.0]);
        let seen = log.lock().unwrap();
        assert_eq!(seen.matrix_values, vals);
        assert_eq!(seen.rhs, [4.0, 5.0]);
    }

    #[test]
    fn nontrivial_scaling_premultiplies_matrix_and_postmultiplies_solution() {
        /// Fixed scaling D = diag(2, 3).
        struct DiagTwoThree;
        impl TSymScalingMethod for DiagTwoThree {
            fn compute_sym_t_scaling_factors(
                &mut self,
                _n: Index,
                _nnz: Index,
                _airn: &[Index],
                _ajcn: &[Index],
                _a: &[Number],
                scaling_factors: &mut [Number],
            ) -> bool {
                scaling_factors[0] = 2.0;
                scaling_factors[1] = 3.0;
                true
            }
        }

        let log = Arc::new(Mutex::new(SolveLog::default()));
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            log: Arc::clone(&log),
            ..Default::default()
        };
        let mut solver =
            TSymLinearSolver::new(Box::new(backend), Some(Box::new(DiagTwoThree)), false);
        let (irn, jcn) = make_2x2_indef_pattern();
        solver.initialize_structure(2, &irn, &jcn);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        // Solution is post-multiplied: x = D y.
        assert_eq!(rhs, [2.0 * 7.0, 3.0 * 11.0]);

        let seen = log.lock().unwrap();
        // Backend sees D A D: (1,1) 2*2*2, (2,1) 1*3*2, (2,2) 3*3*3.
        assert_eq!(seen.matrix_values, [8.0, 6.0, 27.0]);
        // ... and the pre-multiplied right-hand side D b.
        assert_eq!(seen.rhs, [8.0, 15.0]);
    }

    #[test]
    fn increase_quality_switches_on_scaling_first() {
        let backend = MockBackend {
            canned_solution: vec![0.0, 0.0],
            max_increase_quality_calls: 1,
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(
            Box::new(backend),
            Some(Box::new(IdentityScalingMethod)),
            true,
        );
        // First request only switches scaling on; the backend is not consulted.
        assert!(solver.increase_quality());
        // The backend's single allowed improvement is still available...
        assert!(solver.increase_quality());
        // ...and is now used up.
        assert!(!solver.increase_quality());
    }

    #[test]
    fn switching_on_scaling_forces_new_matrix() {
        let log = Arc::new(Mutex::new(SolveLog::default()));
        let backend = MockBackend {
            canned_solution: vec![1.0, 1.0],
            log: Arc::clone(&log),
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(
            Box::new(backend),
            Some(Box::new(IdentityScalingMethod)),
            true,
        );
        let (irn, jcn) = make_2x2_indef_pattern();
        assert_eq!(
            solver.initialize_structure(2, &irn, &jcn),
            ESymSolverStatus::Success
        );

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );

        assert!(solver.increase_quality());
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, false, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert!(log.lock().unwrap().new_matrix);
    }

    #[test]
    fn increase_quality_without_scaling_goes_straight_to_backend() {
        let backend = MockBackend {
            max_increase_quality_calls: 1,
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(Box::new(backend), None, false);
        assert!(solver.increase_quality());
        assert!(!solver.increase_quality());
    }
}