1use crate::scaling::TSymScalingMethod;
25use crate::sparse_sym_iface::{EMatrixFormat, FactorPattern, SparseSymLinearSolverInterface};
26use crate::status::ESymSolverStatus;
27use crate::sym_solver::SymLinearSolver;
28use pounce_common::types::{Index, Number};
29use pounce_linalg::triplet_convert::{TriFull, TripletToCsrConverter};
30
31pub struct TSymLinearSolver {
34 backend: Box<dyn SparseSymLinearSolverInterface>,
35 scaling_method: Option<Box<dyn TSymScalingMethod>>,
36 matrix_format: EMatrixFormat,
37 converter: Option<TripletToCsrConverter>,
38
39 initialized: bool,
41 have_structure: bool,
45 use_scaling: bool,
47 just_switched_on_scaling: bool,
50 linear_scaling_on_demand: bool,
54
55 dim: Index,
56 nonzeros_triplet: Index,
57 nonzeros_compressed: Index,
58
59 airn: Vec<Index>,
61 ajcn: Vec<Index>,
63 scaling_factors: Vec<Number>,
66}
67
68impl std::fmt::Debug for TSymLinearSolver {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("TSymLinearSolver")
71 .field("matrix_format", &self.matrix_format)
72 .field("dim", &self.dim)
73 .field("nonzeros_triplet", &self.nonzeros_triplet)
74 .field("nonzeros_compressed", &self.nonzeros_compressed)
75 .field("use_scaling", &self.use_scaling)
76 .field("initialized", &self.initialized)
77 .finish_non_exhaustive()
78 }
79}
80
81impl TSymLinearSolver {
82 pub fn new(
87 backend: Box<dyn SparseSymLinearSolverInterface>,
88 scaling_method: Option<Box<dyn TSymScalingMethod>>,
89 linear_scaling_on_demand: bool,
90 ) -> Self {
91 let matrix_format = backend.matrix_format();
92 let converter = match matrix_format {
93 EMatrixFormat::TripletFormat => None,
94 EMatrixFormat::CsrFormat0Offset => {
95 Some(TripletToCsrConverter::new(0, TriFull::Triangular))
96 }
97 EMatrixFormat::CsrFormat1Offset => {
98 Some(TripletToCsrConverter::new(1, TriFull::Triangular))
99 }
100 EMatrixFormat::CsrFullFormat0Offset => {
101 Some(TripletToCsrConverter::new(0, TriFull::Full))
102 }
103 EMatrixFormat::CsrFullFormat1Offset => {
104 Some(TripletToCsrConverter::new(1, TriFull::Full))
105 }
106 };
107 let use_scaling = scaling_method.is_some() && !linear_scaling_on_demand;
108 Self {
109 backend,
110 scaling_method,
111 matrix_format,
112 converter,
113 initialized: false,
114 have_structure: false,
115 use_scaling,
116 just_switched_on_scaling: false,
117 linear_scaling_on_demand,
118 dim: 0,
119 nonzeros_triplet: 0,
120 nonzeros_compressed: 0,
121 airn: Vec::new(),
122 ajcn: Vec::new(),
123 scaling_factors: Vec::new(),
124 }
125 }
126
127 pub fn initialize_structure(
131 &mut self,
132 dim: Index,
133 airn: &[Index],
134 ajcn: &[Index],
135 ) -> ESymSolverStatus {
136 assert_eq!(airn.len(), ajcn.len());
137 let nz = airn.len() as Index;
138 self.dim = dim;
139 self.nonzeros_triplet = nz;
140 self.airn = airn.to_vec();
141 self.ajcn = ajcn.to_vec();
142
143 let (ia, ja, nonzeros) = match self.converter.as_mut() {
144 None => (&self.airn[..], &self.ajcn[..], self.nonzeros_triplet),
145 Some(conv) => {
146 let nonzeros_compressed = conv.initialize(self.dim, &self.airn, &self.ajcn);
147 self.nonzeros_compressed = nonzeros_compressed;
148 (conv.ia(), conv.ja(), nonzeros_compressed)
149 }
150 };
151 let status = self.backend.initialize_structure(dim, nonzeros, ia, ja);
152 if status != ESymSolverStatus::Success {
153 return status;
154 }
155 if self.scaling_method.is_some() {
156 self.scaling_factors = vec![0.0; dim as usize];
157 }
158 self.have_structure = true;
159 self.initialized = true;
160 status
161 }
162
163 #[allow(clippy::too_many_arguments)]
174 pub fn multi_solve(
175 &mut self,
176 vals: &[Number],
177 new_matrix: bool,
178 nrhs: Index,
179 rhs_vals: &mut [Number],
180 check_neg_evals: bool,
181 number_of_neg_evals: Index,
182 ) -> ESymSolverStatus {
183 debug_assert!(self.initialized);
184 debug_assert_eq!(vals.len(), self.nonzeros_triplet as usize);
185 debug_assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
186
187 {
198 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
199 static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
200 static DUMPED: AtomicBool = AtomicBool::new(false);
201 let n_call = CALL_COUNT.fetch_add(1, Ordering::SeqCst);
202 let skip: usize = std::env::var("POUNCE_DBG_KKT_DUMP_SKIP")
203 .ok()
204 .and_then(|s| s.parse().ok())
205 .unwrap_or(0);
206 if let Ok(path) = std::env::var("POUNCE_DBG_KKT_DUMP") {
218 if claim_kkt_dump(n_call, skip, &DUMPED) {
219 tracing::warn!(
220 target: "pounce::linsol",
221 "POUNCE_DBG_KKT_DUMP is deprecated; prefer `--dump kkt:<iter-spec>` (see pounce --help)"
222 );
223 use std::io::Write;
224 if let Ok(mut f) = std::fs::File::create(&path) {
225 let dim = self.dim as u64;
226 let nnz = self.nonzeros_triplet as u64;
227 let nrhs64 = nrhs as u64;
228 let _ = f.write_all(&dim.to_le_bytes());
229 let _ = f.write_all(&nnz.to_le_bytes());
230 let _ = f.write_all(&nrhs64.to_le_bytes());
231 for &i in &self.airn {
232 let _ = f.write_all(&(i as i64).to_le_bytes());
233 }
234 for &j in &self.ajcn {
235 let _ = f.write_all(&(j as i64).to_le_bytes());
236 }
237 for &v in vals {
238 let _ = f.write_all(&v.to_le_bytes());
239 }
240 for &v in &*rhs_vals {
241 let _ = f.write_all(&v.to_le_bytes());
242 }
243 let _ = f.flush();
244 }
245 }
246 }
247 }
248
249 let mut new_matrix = new_matrix;
251 if new_matrix || self.just_switched_on_scaling {
252 self.give_matrix_to_solver(true, vals);
253 new_matrix = true;
254 }
255
256 if self.use_scaling {
258 for irhs in 0..nrhs as usize {
259 let base = irhs * self.dim as usize;
260 for i in 0..self.dim as usize {
261 rhs_vals[base + i] *= self.scaling_factors[i];
262 }
263 }
264 }
265
266 let status = loop {
270 let (ia_ptr, ia_len, ja_ptr, ja_len) = match self.converter.as_ref() {
271 None => (
272 self.airn.as_ptr(),
273 self.airn.len(),
274 self.ajcn.as_ptr(),
275 self.ajcn.len(),
276 ),
277 Some(c) => (c.ia().as_ptr(), c.ia().len(), c.ja().as_ptr(), c.ja().len()),
278 };
279 let (ia, ja) = unsafe {
283 (
284 std::slice::from_raw_parts(ia_ptr, ia_len),
285 std::slice::from_raw_parts(ja_ptr, ja_len),
286 )
287 };
288 let s = self.backend.multi_solve(
289 new_matrix,
290 ia,
291 ja,
292 nrhs,
293 rhs_vals,
294 check_neg_evals,
295 number_of_neg_evals,
296 );
297 if s == ESymSolverStatus::CallAgain {
298 self.give_matrix_to_solver(false, vals);
299 continue;
300 }
301 break s;
302 };
303
304 if status == ESymSolverStatus::Success && self.use_scaling {
305 for irhs in 0..nrhs as usize {
311 let base = irhs * self.dim as usize;
312 for i in 0..self.dim as usize {
313 rhs_vals[base + i] *= self.scaling_factors[i];
314 }
315 }
316 }
317
318 status
319 }
320
321 fn give_matrix_to_solver(&mut self, new_matrix: bool, vals: &[Number]) {
325 if self.matrix_format == EMatrixFormat::TripletFormat && !self.use_scaling {
329 let pa = self.backend.values_array_mut();
330 pa[..self.nonzeros_triplet as usize]
331 .copy_from_slice(&vals[..self.nonzeros_triplet as usize]);
332 return;
333 }
334
335 let mut atriplet: Vec<Number> = vals[..self.nonzeros_triplet as usize].to_vec();
338
339 if self.use_scaling {
340 if new_matrix || self.just_switched_on_scaling {
341 let Some(method) = self.scaling_method.as_mut() else {
344 unreachable!("use_scaling without a scaling method")
345 };
346 let ok = method.compute_sym_t_scaling_factors(
347 self.dim,
348 self.nonzeros_triplet,
349 &self.airn,
350 &self.ajcn,
351 &atriplet,
352 &mut self.scaling_factors,
353 );
354 assert!(ok, "scaling method failed");
355 self.just_switched_on_scaling = false;
356 }
357 for (i, a) in atriplet
358 .iter_mut()
359 .enumerate()
360 .take(self.nonzeros_triplet as usize)
361 {
362 let r = (self.airn[i] - 1) as usize;
363 let c = (self.ajcn[i] - 1) as usize;
364 *a *= self.scaling_factors[r] * self.scaling_factors[c];
365 }
366 }
367
368 if self.matrix_format == EMatrixFormat::TripletFormat {
369 let pa = self.backend.values_array_mut();
370 pa[..self.nonzeros_triplet as usize].copy_from_slice(&atriplet);
371 } else {
372 let Some(conv) = self.converter.as_ref() else {
373 unreachable!("non-triplet matrix_format requires a converter");
374 };
375 let pa = self.backend.values_array_mut();
376 conv.convert_values(&atriplet, &mut pa[..self.nonzeros_compressed as usize]);
377 }
378 }
379
380 pub fn factor_pattern(&self, want_values: bool) -> Option<FactorPattern> {
384 self.backend.factor_pattern(want_values)
385 }
386}
387
388impl SymLinearSolver for TSymLinearSolver {
389 fn number_of_neg_evals(&self) -> Index {
390 self.backend.number_of_neg_evals()
391 }
392
393 fn increase_quality(&mut self) -> bool {
397 if self.scaling_method.is_some() && !self.use_scaling && self.linear_scaling_on_demand {
398 self.use_scaling = true;
399 self.just_switched_on_scaling = true;
400 return true;
401 }
402 self.backend.increase_quality()
403 }
404
405 fn provides_inertia(&self) -> bool {
406 self.backend.provides_inertia()
407 }
408}
409
/// Decides whether the current call may perform the one-shot KKT dump.
///
/// Calls with `n_call < skip` never claim and leave `dumped` untouched. From
/// then on the first caller to flip `dumped` from `false` to `true` gets
/// `true`; every other caller gets `false`.
fn claim_kkt_dump(n_call: usize, skip: usize, dumped: &std::sync::atomic::AtomicBool) -> bool {
    // Short-circuit keeps the flag unmodified while still inside the skip window.
    n_call >= skip && !dumped.swap(true, std::sync::atomic::Ordering::SeqCst)
}
430
431#[cfg(test)]
432mod tests {
433 use super::*;
434 use crate::scaling::IdentityScalingMethod;
435
436 #[derive(Default)]
440 struct MockBackend {
441 dim: Index,
442 nz: Index,
443 a: Vec<Number>,
444 last_solve_was_new_matrix: bool,
445 last_solve_was_scaled_a: Option<Vec<Number>>,
446 canned_solution: Vec<Number>,
447 neg_evals: Index,
448 increase_quality_calls: u32,
449 max_increase_quality_calls: u32,
450 }
451
452 impl SparseSymLinearSolverInterface for MockBackend {
453 fn initialize_structure(
454 &mut self,
455 dim: Index,
456 nz: Index,
457 _ia: &[Index],
458 _ja: &[Index],
459 ) -> ESymSolverStatus {
460 self.dim = dim;
461 self.nz = nz;
462 self.a = vec![0.0; nz as usize];
463 ESymSolverStatus::Success
464 }
465 fn values_array_mut(&mut self) -> &mut [Number] {
466 &mut self.a
467 }
468 fn multi_solve(
469 &mut self,
470 new_matrix: bool,
471 _ia: &[Index],
472 _ja: &[Index],
473 nrhs: Index,
474 rhs_vals: &mut [Number],
475 _check: bool,
476 _nev: Index,
477 ) -> ESymSolverStatus {
478 self.last_solve_was_new_matrix = new_matrix;
479 self.last_solve_was_scaled_a = Some(self.a.clone());
480 assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
481 for irhs in 0..nrhs as usize {
482 let base = irhs * self.dim as usize;
483 rhs_vals[base..base + self.dim as usize].copy_from_slice(&self.canned_solution);
484 }
485 ESymSolverStatus::Success
486 }
487 fn number_of_neg_evals(&self) -> Index {
488 self.neg_evals
489 }
490 fn increase_quality(&mut self) -> bool {
491 self.increase_quality_calls += 1;
492 self.increase_quality_calls <= self.max_increase_quality_calls
493 }
494 fn provides_inertia(&self) -> bool {
495 true
496 }
497 fn matrix_format(&self) -> EMatrixFormat {
498 EMatrixFormat::TripletFormat
499 }
500 }
501
502 fn make_2x2_indef_pattern() -> ([Index; 3], [Index; 3]) {
503 ([1, 2, 2], [1, 1, 2])
504 }
505
/// Without a scaling method the backend's solution must reach the caller
/// unchanged.
#[test]
fn unscaled_triplet_solve_passes_values_through() {
    let mock = MockBackend {
        canned_solution: vec![10.0, 20.0],
        ..Default::default()
    };
    let mut solver = TSymLinearSolver::new(Box::new(mock), None, false);

    let (rows, cols) = make_2x2_indef_pattern();
    let init_status = solver.initialize_structure(2, &rows, &cols);
    assert_eq!(init_status, ESymSolverStatus::Success);

    let matrix_values = [2.0, 1.0, 3.0];
    let mut rhs = [3.0, 4.0];
    let solve_status = solver.multi_solve(&matrix_values, true, 1, &mut rhs, false, 0);
    assert_eq!(solve_status, ESymSolverStatus::Success);

    assert_eq!(rhs, [10.0, 20.0]);
    assert!(solver.provides_inertia());
}
529
/// With `IdentityScalingMethod` active from the start, the canned solution
/// must come back unmodified.
#[test]
fn identity_scaling_does_not_change_values() {
    let mock = MockBackend {
        canned_solution: vec![1.0, 1.0],
        ..Default::default()
    };
    let scaling: Option<Box<dyn TSymScalingMethod>> = Some(Box::new(IdentityScalingMethod));
    let mut solver = TSymLinearSolver::new(Box::new(mock), scaling, false);

    let (rows, cols) = make_2x2_indef_pattern();
    solver.initialize_structure(2, &rows, &cols);

    let matrix_values = [2.0, 1.0, 3.0];
    let mut rhs = [4.0, 5.0];
    let solve_status = solver.multi_solve(&matrix_values, true, 1, &mut rhs, false, 0);

    assert_eq!(solve_status, ESymSolverStatus::Success);
    assert_eq!(rhs, [1.0, 1.0]);
}
557
/// A scaling method with factors (2, 3) must show up in the returned solution:
/// the backend's canned answer is multiplied by the factors after the solve.
#[test]
fn nontrivial_scaling_premultiplies_matrix_and_postmultiplies_solution() {
    /// Scaling method that always reports the factors 2 and 3.
    struct DiagTwoThree;

    impl TSymScalingMethod for DiagTwoThree {
        fn compute_sym_t_scaling_factors(
            &mut self,
            _n: Index,
            _nnz: Index,
            _airn: &[Index],
            _ajcn: &[Index],
            _a: &[Number],
            scaling_factors: &mut [Number],
        ) -> bool {
            scaling_factors[..2].copy_from_slice(&[2.0, 3.0]);
            true
        }
    }

    let mock = MockBackend {
        canned_solution: vec![7.0, 11.0],
        ..Default::default()
    };
    let scaling: Option<Box<dyn TSymScalingMethod>> = Some(Box::new(DiagTwoThree));
    let mut solver = TSymLinearSolver::new(Box::new(mock), scaling, false);

    let (rows, cols) = make_2x2_indef_pattern();
    solver.initialize_structure(2, &rows, &cols);

    let matrix_values = [2.0, 1.0, 3.0];
    let mut rhs = [4.0, 5.0];
    let solve_status = solver.multi_solve(&matrix_values, true, 1, &mut rhs, false, 0);

    assert_eq!(solve_status, ESymSolverStatus::Success);
    assert_eq!(rhs, [2.0 * 7.0, 3.0 * 11.0]);
}
601
/// With on-demand scaling, the first quality request is answered by enabling
/// scaling; the second one is delegated to the backend, which still has
/// budget left.
#[test]
fn increase_quality_switches_on_scaling_first() {
    let mock = MockBackend {
        canned_solution: vec![0.0, 0.0],
        max_increase_quality_calls: 5,
        ..Default::default()
    };
    let scaling: Option<Box<dyn TSymScalingMethod>> = Some(Box::new(IdentityScalingMethod));
    let mut solver = TSymLinearSolver::new(Box::new(mock), scaling, true);

    let first = solver.increase_quality();
    assert!(first);
    let second = solver.increase_quality();
    assert!(second);
}
620
/// Without a scaling method every quality request goes to the backend, whose
/// budget of one call is exhausted by the second request.
#[test]
fn increase_quality_without_scaling_goes_straight_to_backend() {
    let mock = MockBackend {
        max_increase_quality_calls: 1,
        ..Default::default()
    };
    let mut solver = TSymLinearSolver::new(Box::new(mock), None, false);

    let first = solver.increase_quality();
    let second = solver.increase_quality();
    assert!(first);
    assert!(!second);
}
632
/// With `skip == 2`, calls 0 and 1 are ignored, call 2 claims the dump and
/// all later calls are refused.
#[test]
fn claim_kkt_dump_is_one_shot_after_skip() {
    use std::sync::atomic::AtomicBool;

    let dumped = AtomicBool::new(false);
    let skip = 2;
    let expected = [false, false, true, false, false];
    for (n_call, &want) in expected.iter().enumerate() {
        assert_eq!(super::claim_kkt_dump(n_call, skip, &dumped), want);
    }
}
648
/// Releases 32 threads at once through a barrier and checks that only one of
/// them wins the claim.
#[test]
fn claim_kkt_dump_claims_exactly_once_under_concurrency() {
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};

    let n_threads = 32;
    let dumped = Arc::new(AtomicBool::new(false));
    let wins = Arc::new(AtomicUsize::new(0));
    let barrier = Arc::new(Barrier::new(n_threads));

    // `collect` spawns every thread before any join, so the barrier can open.
    let workers: Vec<_> = (0..n_threads)
        .map(|_| {
            let flag = Arc::clone(&dumped);
            let tally = Arc::clone(&wins);
            let gate = Arc::clone(&barrier);
            std::thread::spawn(move || {
                gate.wait();
                if super::claim_kkt_dump(0, 0, &flag) {
                    tally.fetch_add(1, Ordering::SeqCst);
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let winners = wins.load(Ordering::SeqCst);
    assert_eq!(
        winners, 1,
        "exactly one thread must claim the one-shot KKT dump"
    );
}
683}