1use crate::scaling::TSymScalingMethod;
25use crate::sparse_sym_iface::{EMatrixFormat, SparseSymLinearSolverInterface};
26use crate::status::ESymSolverStatus;
27use crate::sym_solver::SymLinearSolver;
28use pounce_common::types::{Index, Number};
29use pounce_linalg::triplet_convert::{TriFull, TripletToCsrConverter};
30
31pub struct TSymLinearSolver {
34 backend: Box<dyn SparseSymLinearSolverInterface>,
35 scaling_method: Option<Box<dyn TSymScalingMethod>>,
36 matrix_format: EMatrixFormat,
37 converter: Option<TripletToCsrConverter>,
38
39 initialized: bool,
41 have_structure: bool,
45 use_scaling: bool,
47 just_switched_on_scaling: bool,
50 linear_scaling_on_demand: bool,
54
55 dim: Index,
56 nonzeros_triplet: Index,
57 nonzeros_compressed: Index,
58
59 airn: Vec<Index>,
61 ajcn: Vec<Index>,
63 scaling_factors: Vec<Number>,
66}
67
68impl std::fmt::Debug for TSymLinearSolver {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("TSymLinearSolver")
71 .field("matrix_format", &self.matrix_format)
72 .field("dim", &self.dim)
73 .field("nonzeros_triplet", &self.nonzeros_triplet)
74 .field("nonzeros_compressed", &self.nonzeros_compressed)
75 .field("use_scaling", &self.use_scaling)
76 .field("initialized", &self.initialized)
77 .finish_non_exhaustive()
78 }
79}
80
81impl TSymLinearSolver {
82 pub fn new(
87 backend: Box<dyn SparseSymLinearSolverInterface>,
88 scaling_method: Option<Box<dyn TSymScalingMethod>>,
89 linear_scaling_on_demand: bool,
90 ) -> Self {
91 let matrix_format = backend.matrix_format();
92 let converter = match matrix_format {
93 EMatrixFormat::TripletFormat => None,
94 EMatrixFormat::CsrFormat0Offset => {
95 Some(TripletToCsrConverter::new(0, TriFull::Triangular))
96 }
97 EMatrixFormat::CsrFormat1Offset => {
98 Some(TripletToCsrConverter::new(1, TriFull::Triangular))
99 }
100 EMatrixFormat::CsrFullFormat0Offset => {
101 Some(TripletToCsrConverter::new(0, TriFull::Full))
102 }
103 EMatrixFormat::CsrFullFormat1Offset => {
104 Some(TripletToCsrConverter::new(1, TriFull::Full))
105 }
106 };
107 let use_scaling = scaling_method.is_some() && !linear_scaling_on_demand;
108 Self {
109 backend,
110 scaling_method,
111 matrix_format,
112 converter,
113 initialized: false,
114 have_structure: false,
115 use_scaling,
116 just_switched_on_scaling: false,
117 linear_scaling_on_demand,
118 dim: 0,
119 nonzeros_triplet: 0,
120 nonzeros_compressed: 0,
121 airn: Vec::new(),
122 ajcn: Vec::new(),
123 scaling_factors: Vec::new(),
124 }
125 }
126
127 pub fn initialize_structure(
131 &mut self,
132 dim: Index,
133 airn: &[Index],
134 ajcn: &[Index],
135 ) -> ESymSolverStatus {
136 assert_eq!(airn.len(), ajcn.len());
137 let nz = airn.len() as Index;
138 self.dim = dim;
139 self.nonzeros_triplet = nz;
140 self.airn = airn.to_vec();
141 self.ajcn = ajcn.to_vec();
142
143 let (ia, ja, nonzeros) = match self.converter.as_mut() {
144 None => (&self.airn[..], &self.ajcn[..], self.nonzeros_triplet),
145 Some(conv) => {
146 let nonzeros_compressed = conv.initialize(self.dim, &self.airn, &self.ajcn);
147 self.nonzeros_compressed = nonzeros_compressed;
148 (conv.ia(), conv.ja(), nonzeros_compressed)
149 }
150 };
151 let status = self.backend.initialize_structure(dim, nonzeros, ia, ja);
152 if status != ESymSolverStatus::Success {
153 return status;
154 }
155 if self.scaling_method.is_some() {
156 self.scaling_factors = vec![0.0; dim as usize];
157 }
158 self.have_structure = true;
159 self.initialized = true;
160 status
161 }
162
163 #[allow(clippy::too_many_arguments)]
174 pub fn multi_solve(
175 &mut self,
176 vals: &[Number],
177 new_matrix: bool,
178 nrhs: Index,
179 rhs_vals: &mut [Number],
180 check_neg_evals: bool,
181 number_of_neg_evals: Index,
182 ) -> ESymSolverStatus {
183 debug_assert!(self.initialized);
184 debug_assert_eq!(vals.len(), self.nonzeros_triplet as usize);
185 debug_assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
186
187 {
198 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
199 static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
200 static WARNED: AtomicBool = AtomicBool::new(false);
201 let n_call = CALL_COUNT.fetch_add(1, Ordering::SeqCst);
202 let skip: usize = std::env::var("POUNCE_DBG_KKT_DUMP_SKIP")
203 .ok()
204 .and_then(|s| s.parse().ok())
205 .unwrap_or(0);
206 if n_call < skip {
207 } else if let Ok(path) = std::env::var("POUNCE_DBG_KKT_DUMP") {
209 if !WARNED.swap(true, Ordering::SeqCst) {
210 eprintln!(
211 "warning: POUNCE_DBG_KKT_DUMP is deprecated; prefer `--dump kkt:<iter-spec>` (see pounce --help)"
212 );
213 }
214 use std::io::Write;
215 if let Ok(mut f) = std::fs::File::create(&path) {
216 let dim = self.dim as u64;
217 let nnz = self.nonzeros_triplet as u64;
218 let nrhs64 = nrhs as u64;
219 let _ = f.write_all(&dim.to_le_bytes());
220 let _ = f.write_all(&nnz.to_le_bytes());
221 let _ = f.write_all(&nrhs64.to_le_bytes());
222 for &i in &self.airn {
223 let _ = f.write_all(&(i as i64).to_le_bytes());
224 }
225 for &j in &self.ajcn {
226 let _ = f.write_all(&(j as i64).to_le_bytes());
227 }
228 for &v in vals {
229 let _ = f.write_all(&v.to_le_bytes());
230 }
231 for &v in &*rhs_vals {
232 let _ = f.write_all(&v.to_le_bytes());
233 }
234 let _ = f.flush();
235 }
236 unsafe {
239 std::env::remove_var("POUNCE_DBG_KKT_DUMP");
240 }
241 }
242 }
243
244 let mut new_matrix = new_matrix;
246 if new_matrix || self.just_switched_on_scaling {
247 self.give_matrix_to_solver(true, vals);
248 new_matrix = true;
249 }
250
251 if self.use_scaling {
253 for irhs in 0..nrhs as usize {
254 let base = irhs * self.dim as usize;
255 for i in 0..self.dim as usize {
256 rhs_vals[base + i] *= self.scaling_factors[i];
257 }
258 }
259 }
260
261 let status = loop {
265 let (ia_ptr, ia_len, ja_ptr, ja_len) = match self.converter.as_ref() {
266 None => (
267 self.airn.as_ptr(),
268 self.airn.len(),
269 self.ajcn.as_ptr(),
270 self.ajcn.len(),
271 ),
272 Some(c) => (c.ia().as_ptr(), c.ia().len(), c.ja().as_ptr(), c.ja().len()),
273 };
274 let (ia, ja) = unsafe {
278 (
279 std::slice::from_raw_parts(ia_ptr, ia_len),
280 std::slice::from_raw_parts(ja_ptr, ja_len),
281 )
282 };
283 let s = self.backend.multi_solve(
284 new_matrix,
285 ia,
286 ja,
287 nrhs,
288 rhs_vals,
289 check_neg_evals,
290 number_of_neg_evals,
291 );
292 if s == ESymSolverStatus::CallAgain {
293 self.give_matrix_to_solver(false, vals);
294 continue;
295 }
296 break s;
297 };
298
299 if status == ESymSolverStatus::Success && self.use_scaling {
300 for irhs in 0..nrhs as usize {
306 let base = irhs * self.dim as usize;
307 for i in 0..self.dim as usize {
308 rhs_vals[base + i] *= self.scaling_factors[i];
309 }
310 }
311 }
312
313 status
314 }
315
316 fn give_matrix_to_solver(&mut self, new_matrix: bool, vals: &[Number]) {
320 if self.matrix_format == EMatrixFormat::TripletFormat && !self.use_scaling {
324 let pa = self.backend.values_array_mut();
325 pa[..self.nonzeros_triplet as usize]
326 .copy_from_slice(&vals[..self.nonzeros_triplet as usize]);
327 return;
328 }
329
330 let mut atriplet: Vec<Number> = vals[..self.nonzeros_triplet as usize].to_vec();
333
334 if self.use_scaling {
335 if new_matrix || self.just_switched_on_scaling {
336 let Some(method) = self.scaling_method.as_mut() else {
339 unreachable!("use_scaling without a scaling method")
340 };
341 let ok = method.compute_sym_t_scaling_factors(
342 self.dim,
343 self.nonzeros_triplet,
344 &self.airn,
345 &self.ajcn,
346 &atriplet,
347 &mut self.scaling_factors,
348 );
349 assert!(ok, "scaling method failed");
350 self.just_switched_on_scaling = false;
351 }
352 for (i, a) in atriplet
353 .iter_mut()
354 .enumerate()
355 .take(self.nonzeros_triplet as usize)
356 {
357 let r = (self.airn[i] - 1) as usize;
358 let c = (self.ajcn[i] - 1) as usize;
359 *a *= self.scaling_factors[r] * self.scaling_factors[c];
360 }
361 }
362
363 if self.matrix_format == EMatrixFormat::TripletFormat {
364 let pa = self.backend.values_array_mut();
365 pa[..self.nonzeros_triplet as usize].copy_from_slice(&atriplet);
366 } else {
367 let Some(conv) = self.converter.as_ref() else {
368 unreachable!("non-triplet matrix_format requires a converter");
369 };
370 let pa = self.backend.values_array_mut();
371 conv.convert_values(&atriplet, &mut pa[..self.nonzeros_compressed as usize]);
372 }
373 }
374}
375
376impl SymLinearSolver for TSymLinearSolver {
377 fn number_of_neg_evals(&self) -> Index {
378 self.backend.number_of_neg_evals()
379 }
380
381 fn increase_quality(&mut self) -> bool {
385 if self.scaling_method.is_some() && !self.use_scaling && self.linear_scaling_on_demand {
386 self.use_scaling = true;
387 self.just_switched_on_scaling = true;
388 return true;
389 }
390 self.backend.increase_quality()
391 }
392
393 fn provides_inertia(&self) -> bool {
394 self.backend.provides_inertia()
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scaling::IdentityScalingMethod;

    /// Minimal triplet-format backend: keeps the values it is handed and
    /// "solves" by overwriting every right-hand side with `canned_solution`.
    #[derive(Default)]
    struct MockBackend {
        dim: Index,
        nz: Index,
        a: Vec<Number>,
        last_solve_was_new_matrix: bool,
        last_solve_was_scaled_a: Option<Vec<Number>>,
        canned_solution: Vec<Number>,
        neg_evals: Index,
        increase_quality_calls: u32,
        max_increase_quality_calls: u32,
    }

    impl SparseSymLinearSolverInterface for MockBackend {
        fn initialize_structure(
            &mut self,
            dim: Index,
            nz: Index,
            _ia: &[Index],
            _ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim;
            self.nz = nz;
            self.a = vec![0.0; nz as usize];
            ESymSolverStatus::Success
        }
        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.a
        }
        fn multi_solve(
            &mut self,
            new_matrix: bool,
            _ia: &[Index],
            _ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            _check: bool,
            _nev: Index,
        ) -> ESymSolverStatus {
            self.last_solve_was_new_matrix = new_matrix;
            self.last_solve_was_scaled_a = Some(self.a.clone());
            assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
            for irhs in 0..nrhs as usize {
                let base = irhs * self.dim as usize;
                rhs_vals[base..base + self.dim as usize].copy_from_slice(&self.canned_solution);
            }
            ESymSolverStatus::Success
        }
        fn number_of_neg_evals(&self) -> Index {
            self.neg_evals
        }
        fn increase_quality(&mut self) -> bool {
            self.increase_quality_calls += 1;
            self.increase_quality_calls <= self.max_increase_quality_calls
        }
        fn provides_inertia(&self) -> bool {
            true
        }
        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// Scaling method with the fixed factors `[2, 3]`, for 2x2 tests.
    struct DiagTwoThree;

    impl TSymScalingMethod for DiagTwoThree {
        fn compute_sym_t_scaling_factors(
            &mut self,
            _n: Index,
            _nnz: Index,
            _airn: &[Index],
            _ajcn: &[Index],
            _a: &[Number],
            scaling_factors: &mut [Number],
        ) -> bool {
            scaling_factors[0] = 2.0;
            scaling_factors[1] = 3.0;
            true
        }
    }

    /// Lower-triangular 1-based pattern of a dense 2x2 matrix: (1,1), (2,1), (2,2).
    fn make_2x2_indef_pattern() -> ([Index; 3], [Index; 3]) {
        ([1, 2, 2], [1, 1, 2])
    }

    /// Copy of the values the solver last handed to its backend.
    fn backend_values(solver: &mut TSymLinearSolver) -> Vec<Number> {
        solver.backend.values_array_mut().to_vec()
    }

    #[test]
    fn unscaled_triplet_solve_passes_values_through() {
        let backend = MockBackend {
            canned_solution: vec![10.0, 20.0],
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(Box::new(backend), None, false);
        let (irn, jcn) = make_2x2_indef_pattern();
        assert_eq!(
            solver.initialize_structure(2, &irn, &jcn),
            ESymSolverStatus::Success
        );

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [3.0, 4.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [10.0, 20.0]);
        assert_eq!(backend_values(&mut solver), vec![2.0, 1.0, 3.0]);
        assert!(solver.provides_inertia());
    }

    #[test]
    fn identity_scaling_does_not_change_values() {
        let backend = MockBackend {
            canned_solution: vec![1.0, 1.0],
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(
            Box::new(backend),
            Some(Box::new(IdentityScalingMethod)),
            false,
        );
        let (irn, jcn) = make_2x2_indef_pattern();
        assert_eq!(
            solver.initialize_structure(2, &irn, &jcn),
            ESymSolverStatus::Success
        );

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [1.0, 1.0]);
        assert_eq!(backend_values(&mut solver), vec![2.0, 1.0, 3.0]);
    }

    #[test]
    fn nontrivial_scaling_premultiplies_matrix_and_postmultiplies_solution() {
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            ..Default::default()
        };
        let mut solver =
            TSymLinearSolver::new(Box::new(backend), Some(Box::new(DiagTwoThree)), false);
        let (irn, jcn) = make_2x2_indef_pattern();
        assert_eq!(
            solver.initialize_structure(2, &irn, &jcn),
            ESymSolverStatus::Success
        );

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        // D A D with D = diag(2, 3): (1,1) 2*2*2, (2,1) 3*1*2, (2,2) 3*3*3.
        assert_eq!(backend_values(&mut solver), vec![8.0, 6.0, 27.0]);
        assert_eq!(rhs, [2.0 * 7.0, 3.0 * 11.0]);
    }

    #[test]
    fn increase_quality_switches_on_scaling_first() {
        let backend = MockBackend {
            canned_solution: vec![0.0, 0.0],
            max_increase_quality_calls: 5,
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(
            Box::new(backend),
            Some(Box::new(IdentityScalingMethod)),
            true,
        );
        assert!(!solver.use_scaling);

        // First request is absorbed by enabling the deferred scaling.
        assert!(solver.increase_quality());
        assert!(solver.use_scaling);
        assert!(solver.just_switched_on_scaling);

        // Scaling is already on, so this one goes to the backend.
        assert!(solver.increase_quality());
        assert!(solver.use_scaling);
    }

    #[test]
    fn switching_on_scaling_refreshes_matrix_without_new_matrix_flag() {
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            max_increase_quality_calls: 5,
            ..Default::default()
        };
        let mut solver =
            TSymLinearSolver::new(Box::new(backend), Some(Box::new(DiagTwoThree)), true);
        let (irn, jcn) = make_2x2_indef_pattern();
        assert_eq!(
            solver.initialize_structure(2, &irn, &jcn),
            ESymSolverStatus::Success
        );
        let vals = [2.0, 1.0, 3.0];

        // Scaling is deferred: the first solve is unscaled.
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [7.0, 11.0]);
        assert_eq!(backend_values(&mut solver), vec![2.0, 1.0, 3.0]);

        assert!(solver.increase_quality());

        // `new_matrix` is false, but the freshly enabled scaling must still be applied.
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, false, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(backend_values(&mut solver), vec![8.0, 6.0, 27.0]);
        assert_eq!(rhs, [14.0, 33.0]);
        assert!(!solver.just_switched_on_scaling);
    }

    #[test]
    fn increase_quality_without_scaling_goes_straight_to_backend() {
        let backend = MockBackend {
            max_increase_quality_calls: 1,
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(Box::new(backend), None, false);
        assert!(solver.increase_quality());
        assert!(!solver.increase_quality());
    }
}