1use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
30use pounce_common::diagnostics::{DiagCategory, DiagnosticsState};
31use pounce_common::timing::TimingStatistics;
32use pounce_common::types::{Index, Number};
33use pounce_linalg::compound_vector::CompoundVector;
34use pounce_linalg::dense_vector::DenseVector;
35use pounce_linalg::diag_matrix::DiagMatrix;
36use pounce_linalg::triplet::{GenTMatrix, SymTMatrix};
37use pounce_linalg::Vector;
38use pounce_linsol::{ESymSolverStatus, FactorPattern, SymLinearSolver, TSymLinearSolver};
39use std::ops::Range;
40use std::rc::Rc;
41
42pub struct StdAugSystemSolver {
44 linsol: TSymLinearSolver,
45
46 initialized: bool,
48 struct_sig: Option<(usize, usize, usize, Index, Index, Index)>,
57 n_x: Index,
58 n_s: Index,
59 n_c: Index,
60 n_d: Index,
61 dim: Index,
63
64 irn: Vec<Index>,
66 jcn: Vec<Index>,
68 vals: Vec<Number>,
70
71 w_range: Range<usize>,
73 dx_range: Range<usize>,
74 ds_range: Range<usize>,
75 jc_range: Range<usize>,
76 dc_range: Range<usize>,
77 jd_range: Range<usize>,
78 minus_i_range: Range<usize>,
79 dd_range: Range<usize>,
80
81 last_neg_evals: Index,
82 last_status: Option<ESymSolverStatus>,
83
84 have_factor: bool,
88
89 timing: Option<Rc<TimingStatistics>>,
93
94 diagnostics: Option<Rc<DiagnosticsState>>,
100}
101
102impl std::fmt::Debug for StdAugSystemSolver {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("StdAugSystemSolver")
105 .field("dim", &self.dim)
106 .field("nnz", &self.vals.len())
107 .field("initialized", &self.initialized)
108 .field("last_neg_evals", &self.last_neg_evals)
109 .field("last_status", &self.last_status)
110 .finish_non_exhaustive()
111 }
112}
113
114impl StdAugSystemSolver {
115 pub fn new(linsol: TSymLinearSolver) -> Self {
117 Self {
118 linsol,
119 initialized: false,
120 struct_sig: None,
121 n_x: 0,
122 n_s: 0,
123 n_c: 0,
124 n_d: 0,
125 dim: 0,
126 irn: Vec::new(),
127 jcn: Vec::new(),
128 vals: Vec::new(),
129 w_range: 0..0,
130 dx_range: 0..0,
131 ds_range: 0..0,
132 jc_range: 0..0,
133 dc_range: 0..0,
134 jd_range: 0..0,
135 minus_i_range: 0..0,
136 dd_range: 0..0,
137 last_neg_evals: 0,
138 last_status: None,
139 have_factor: false,
140 timing: None,
141 diagnostics: None,
142 }
143 }
144
145 fn build_structure(&mut self, coeffs: &AugSysCoeffs<'_>) -> ESymSolverStatus {
146 let n_x = coeffs.j_c.n_cols();
147 let n_c = coeffs.j_c.n_rows();
148 let n_d = coeffs.j_d.n_rows();
149 debug_assert_eq!(coeffs.j_d.n_cols(), n_x);
150 let n_s = n_d;
151
152 let w_nnz = match coeffs.w {
153 None => 0_usize,
154 Some(w) => w_nonzeros(w),
155 };
156 let jc_nnz = gen_t_downcast(coeffs.j_c).nonzeros() as usize;
157 let jd_nnz = gen_t_downcast(coeffs.j_d).nonzeros() as usize;
158
159 let total = w_nnz
160 + (n_x as usize) + (n_s as usize) + jc_nnz
163 + (n_c as usize) + jd_nnz
165 + (n_s as usize) + (n_d as usize); self.irn = Vec::with_capacity(total);
169 self.jcn = Vec::with_capacity(total);
170 self.vals = vec![0.0; total];
171
172 let w_start = self.irn.len();
174 if let Some(w) = coeffs.w {
175 if let Some(t) = w.as_any().downcast_ref::<SymTMatrix>() {
176 self.irn.extend_from_slice(t.irows());
177 self.jcn.extend_from_slice(t.jcols());
178 } else if let Some(dm) = w.as_any().downcast_ref::<DiagMatrix>() {
179 let n = w_diag_dim(dm);
182 for i in 0..n {
183 self.irn.push(i + 1);
184 self.jcn.push(i + 1);
185 }
186 } else {
187 unreachable!("StdAugSystemSolver: W must be a SymTMatrix or DiagMatrix in v1.0")
188 }
189 }
190 self.w_range = w_start..self.irn.len();
191
192 let dx_start = self.irn.len();
194 for i in 0..n_x {
195 self.irn.push(i + 1);
196 self.jcn.push(i + 1);
197 }
198 self.dx_range = dx_start..self.irn.len();
199
200 let ds_start = self.irn.len();
202 for i in 0..n_s {
203 let r = n_x + i + 1;
204 self.irn.push(r);
205 self.jcn.push(r);
206 }
207 self.ds_range = ds_start..self.irn.len();
208
209 let jc_start = self.irn.len();
211 let j_c = gen_t_downcast(coeffs.j_c);
212 let row_off_c = n_x + n_s;
213 for (&i, &j) in j_c.irows().iter().zip(j_c.jcols().iter()) {
214 self.irn.push(row_off_c + i);
217 self.jcn.push(j);
218 }
219 self.jc_range = jc_start..self.irn.len();
220
221 let dc_start = self.irn.len();
223 for i in 0..n_c {
224 let r = n_x + n_s + i + 1;
225 self.irn.push(r);
226 self.jcn.push(r);
227 }
228 self.dc_range = dc_start..self.irn.len();
229
230 let jd_start = self.irn.len();
232 let j_d = gen_t_downcast(coeffs.j_d);
233 let row_off_d = n_x + n_s + n_c;
234 for (&i, &j) in j_d.irows().iter().zip(j_d.jcols().iter()) {
235 self.irn.push(row_off_d + i);
236 self.jcn.push(j);
237 }
238 self.jd_range = jd_start..self.irn.len();
239
240 let mi_start = self.irn.len();
242 for i in 0..n_s {
243 self.irn.push(n_x + n_s + n_c + i + 1);
244 self.jcn.push(n_x + i + 1);
245 }
246 self.minus_i_range = mi_start..self.irn.len();
247
248 let dd_start = self.irn.len();
250 for i in 0..n_d {
251 let r = n_x + n_s + n_c + i + 1;
252 self.irn.push(r);
253 self.jcn.push(r);
254 }
255 self.dd_range = dd_start..self.irn.len();
256
257 debug_assert_eq!(self.irn.len(), total);
258 debug_assert_eq!(self.jcn.len(), total);
259
260 self.n_x = n_x;
261 self.n_s = n_s;
262 self.n_c = n_c;
263 self.n_d = n_d;
264 self.dim = n_x + n_s + n_c + n_d;
265
266 let status = self
267 .linsol
268 .initialize_structure(self.dim, &self.irn, &self.jcn);
269 if status == ESymSolverStatus::Success {
270 self.initialized = true;
271 }
272 status
273 }
274
275 fn refill_values(&mut self, coeffs: &AugSysCoeffs<'_>) {
276 if !self.w_range.is_empty() {
278 let Some(w_dyn) = coeffs.w else {
279 unreachable!("structure pinned with W; W cannot be None now")
280 };
281 let dst = &mut self.vals[self.w_range.clone()];
282 if let Some(t) = w_dyn.as_any().downcast_ref::<SymTMatrix>() {
283 for (d, &v) in dst.iter_mut().zip(t.values().iter()) {
284 *d = coeffs.w_factor * v;
285 }
286 } else if let Some(dm) = w_dyn.as_any().downcast_ref::<DiagMatrix>() {
287 let diag = w_diag_values(dm);
288 for (d, &v) in dst.iter_mut().zip(diag.iter()) {
289 *d = coeffs.w_factor * v;
290 }
291 } else {
292 unreachable!("StdAugSystemSolver: W must be a SymTMatrix or DiagMatrix in v1.0")
293 }
294 }
295 fill_diag(
297 &mut self.vals[self.dx_range.clone()],
298 coeffs.d_x,
299 coeffs.delta_x,
300 1.0,
301 );
302 fill_diag(
304 &mut self.vals[self.ds_range.clone()],
305 coeffs.d_s,
306 coeffs.delta_s,
307 1.0,
308 );
309 {
311 let j_c = gen_t_downcast(coeffs.j_c);
312 self.vals[self.jc_range.clone()].copy_from_slice(j_c.values());
313 }
314 fill_diag(
316 &mut self.vals[self.dc_range.clone()],
317 coeffs.d_c,
318 coeffs.delta_c,
319 -1.0,
320 );
321 {
323 let j_d = gen_t_downcast(coeffs.j_d);
324 self.vals[self.jd_range.clone()].copy_from_slice(j_d.values());
325 }
326 for v in &mut self.vals[self.minus_i_range.clone()] {
328 *v = -1.0;
329 }
330 fill_diag(
332 &mut self.vals[self.dd_range.clone()],
333 coeffs.d_d,
334 coeffs.delta_d,
335 -1.0,
336 );
337 }
338
339 fn pack_rhs(&self, rhs: &AugSysRhs<'_>, packed: &mut [Number]) {
340 let n_x = self.n_x as usize;
341 let n_s = self.n_s as usize;
342 let n_c = self.n_c as usize;
343 let n_d = self.n_d as usize;
344 copy_vec(rhs.rhs_x, &mut packed[..n_x]);
345 copy_vec(rhs.rhs_s, &mut packed[n_x..n_x + n_s]);
346 copy_vec(rhs.rhs_c, &mut packed[n_x + n_s..n_x + n_s + n_c]);
347 copy_vec(
348 rhs.rhs_d,
349 &mut packed[n_x + n_s + n_c..n_x + n_s + n_c + n_d],
350 );
351 }
352
353 fn unpack_sol(&self, packed: &[Number], sol: &mut AugSysSol<'_>) {
354 let n_x = self.n_x as usize;
355 let n_s = self.n_s as usize;
356 let n_c = self.n_c as usize;
357 let n_d = self.n_d as usize;
358 write_vec(sol.sol_x, &packed[..n_x]);
359 write_vec(sol.sol_s, &packed[n_x..n_x + n_s]);
360 write_vec(sol.sol_c, &packed[n_x + n_s..n_x + n_s + n_c]);
361 write_vec(sol.sol_d, &packed[n_x + n_s + n_c..n_x + n_s + n_c + n_d]);
362 }
363}
364
365impl AugSystemSolver for StdAugSystemSolver {
366 fn provides_inertia(&self) -> bool {
367 self.linsol.provides_inertia()
368 }
369
370 fn number_of_neg_evals(&self) -> Index {
371 self.last_neg_evals
372 }
373
374 fn increase_quality(&mut self) -> bool {
375 self.have_factor = false;
379 self.linsol.increase_quality()
380 }
381
382 fn last_solve_status(&self) -> ESymSolverStatus {
383 self.last_status.unwrap_or(ESymSolverStatus::FatalError)
384 }
385
386 fn solve(
387 &mut self,
388 coeffs: &AugSysCoeffs<'_>,
389 rhs: &AugSysRhs<'_>,
390 sol: &mut AugSysSol<'_>,
391 check_neg_evals: bool,
392 num_neg_evals: Index,
393 ) -> ESymSolverStatus {
394 let sig = {
402 let w_nnz = coeffs.w.map(w_nonzeros).unwrap_or(0);
403 let jc_nnz = gen_t_downcast(coeffs.j_c).nonzeros() as usize;
404 let jd_nnz = gen_t_downcast(coeffs.j_d).nonzeros() as usize;
405 (
406 w_nnz,
407 jc_nnz,
408 jd_nnz,
409 coeffs.j_c.n_cols(),
410 coeffs.j_c.n_rows(),
411 coeffs.j_d.n_rows(),
412 )
413 };
414 if !self.initialized || self.struct_sig != Some(sig) {
415 let s = self.build_structure(coeffs);
416 if s != ESymSolverStatus::Success {
417 self.last_status = Some(s);
418 return s;
419 }
420 self.struct_sig = Some(sig);
421 }
422 self.refill_values(coeffs);
423
424 let mut packed = vec![0.0; self.dim as usize];
425 self.pack_rhs(rhs, &mut packed);
426
427 let dump_rhs = packed.clone();
428
429 let _factor_guard = self
433 .timing
434 .as_deref()
435 .map(|t| t.linear_system_factorization.guard());
436 let status = self.linsol.multi_solve(
437 &self.vals,
438 true,
439 1,
440 &mut packed,
441 check_neg_evals,
442 num_neg_evals,
443 );
444 drop(_factor_guard);
445 self.last_status = Some(status);
446 if status == ESymSolverStatus::Success {
447 if self.linsol.provides_inertia() {
448 self.last_neg_evals = self.linsol.number_of_neg_evals();
449 }
450 self.unpack_sol(&packed, sol);
451 self.have_factor = true;
452 }
453
454 if let Some(diag) = self.diagnostics.clone() {
459 if diag.want(DiagCategory::Kkt) {
460 let solve_idx = diag.next_solve_index();
461 let filename = format!("kkt_solve_{solve_idx:03}.jsonl");
462 let variant = diag.config.kkt_variant;
469 let factor_pattern =
470 if status == ESymSolverStatus::Success && variant.wants_l_pattern() {
471 self.linsol.factor_pattern(variant.wants_l_values())
472 } else {
473 None
474 };
475 if let Some(mut w) = diag.open_writer(&filename) {
476 let _ = write_kkt_record(
477 &mut w,
478 self.dim,
479 &self.irn,
480 &self.jcn,
481 &self.vals,
482 &dump_rhs,
483 &packed,
484 check_neg_evals,
485 num_neg_evals,
486 status,
487 self.last_neg_evals,
488 factor_pattern.as_ref(),
489 );
490 }
491 }
492 }
493 if let Ok(path) = std::env::var("POUNCE_DUMP_KKT") {
494 use std::sync::atomic::{AtomicBool, Ordering};
495 static WARNED: AtomicBool = AtomicBool::new(false);
496 if !WARNED.swap(true, Ordering::SeqCst) {
497 tracing::warn!(target: "pounce::linsol",
498 "warning: POUNCE_DUMP_KKT is deprecated; prefer `--dump kkt:<iter-spec>` (see pounce --help)"
499 );
500 }
501 dump_kkt(
502 &path,
503 self.dim,
504 &self.irn,
505 &self.jcn,
506 &self.vals,
507 &dump_rhs,
508 &packed,
509 check_neg_evals,
510 num_neg_evals,
511 status,
512 self.last_neg_evals,
513 );
514 }
515
516 status
517 }
518
519 fn resolve(
520 &mut self,
521 coeffs: &AugSysCoeffs<'_>,
522 rhs: &AugSysRhs<'_>,
523 sol: &mut AugSysSol<'_>,
524 ) -> ESymSolverStatus {
525 if !self.have_factor {
532 return self.solve(coeffs, rhs, sol, false, 0);
533 }
534
535 let mut packed = vec![0.0; self.dim as usize];
536 self.pack_rhs(rhs, &mut packed);
537
538 let _back_guard = self
541 .timing
542 .as_deref()
543 .map(|t| t.linear_system_back_solve.guard());
544 let status = self
545 .linsol
546 .multi_solve(&self.vals, false, 1, &mut packed, false, 0);
547 drop(_back_guard);
548 self.last_status = Some(status);
549 if status == ESymSolverStatus::Success {
550 self.unpack_sol(&packed, sol);
551 }
552 status
553 }
554
555 fn set_diagnostics(&mut self, diag: Rc<DiagnosticsState>) {
556 self.diagnostics = Some(diag);
557 }
558
559 fn set_timing_stats(&mut self, timing: Rc<TimingStatistics>) {
560 self.timing = Some(timing);
561 }
562
563 fn try_resolve_many_flat(
564 &mut self,
565 _coeffs: &AugSysCoeffs<'_>,
566 packed_rhs: &mut [Number],
567 nrhs: usize,
568 ) -> Option<ESymSolverStatus> {
569 if !self.have_factor {
574 return None;
575 }
576 if packed_rhs.len() != (self.dim as usize) * nrhs {
577 return Some(ESymSolverStatus::FatalError);
578 }
579 let _back_guard = self
580 .timing
581 .as_deref()
582 .map(|t| t.linear_system_back_solve.guard());
583 let status =
584 self.linsol
585 .multi_solve(&self.vals, false, nrhs as Index, packed_rhs, false, 0);
586 drop(_back_guard);
587 self.last_status = Some(status);
588 Some(status)
589 }
590}
591
592#[allow(clippy::too_many_arguments)]
595fn write_kkt_record(
600 w: &mut dyn std::io::Write,
601 dim: Index,
602 irn: &[Index],
603 jcn: &[Index],
604 vals: &[Number],
605 rhs: &[Number],
606 sol: &[Number],
607 check_neg_evals: bool,
608 num_neg_evals: Index,
609 status: ESymSolverStatus,
610 last_neg_evals: Index,
611 factor_pattern: Option<&FactorPattern>,
612) -> std::io::Result<()> {
613 use std::fmt::Write as _;
614
615 let mut line = String::with_capacity(64 * vals.len());
616 line.push('{');
617 let _ = write!(line, "\"n\":{dim},");
618 let _ = write!(line, "\"check_neg_evals\":{check_neg_evals},");
619 let _ = write!(line, "\"num_neg_evals_expected\":{num_neg_evals},");
620 let _ = write!(line, "\"num_neg_evals_actual\":{last_neg_evals},");
621 let _ = write!(line, "\"status\":\"{status:?}\",");
622
623 line.push_str("\"irn\":[");
624 for (i, v) in irn.iter().enumerate() {
625 if i > 0 {
626 line.push(',');
627 }
628 let _ = write!(line, "{v}");
629 }
630 line.push_str("],\"jcn\":[");
631 for (i, v) in jcn.iter().enumerate() {
632 if i > 0 {
633 line.push(',');
634 }
635 let _ = write!(line, "{v}");
636 }
637 line.push_str("],\"vals\":[");
638 for (i, v) in vals.iter().enumerate() {
639 if i > 0 {
640 line.push(',');
641 }
642 let _ = write!(line, "{v:.17e}");
643 }
644 line.push_str("],\"rhs\":[");
645 for (i, v) in rhs.iter().enumerate() {
646 if i > 0 {
647 line.push(',');
648 }
649 let _ = write!(line, "{v:.17e}");
650 }
651 line.push_str("],\"sol\":[");
652 for (i, v) in sol.iter().enumerate() {
653 if i > 0 {
654 line.push(',');
655 }
656 let _ = write!(line, "{v:.17e}");
657 }
658 line.push(']');
659
660 if let Some(fp) = factor_pattern {
665 line.push_str(",\"L_irn\":[");
666 for (i, v) in fp.l_irn.iter().enumerate() {
667 if i > 0 {
668 line.push(',');
669 }
670 let _ = write!(line, "{v}");
671 }
672 line.push_str("],\"L_jcn\":[");
673 for (i, v) in fp.l_jcn.iter().enumerate() {
674 if i > 0 {
675 line.push(',');
676 }
677 let _ = write!(line, "{v}");
678 }
679 line.push_str("],\"perm\":[");
680 for (i, v) in fp.perm.iter().enumerate() {
681 if i > 0 {
682 line.push(',');
683 }
684 let _ = write!(line, "{v}");
685 }
686 line.push(']');
687 if let Some(vals) = fp.l_vals.as_ref() {
688 line.push_str(",\"L_vals\":[");
689 for (i, v) in vals.iter().enumerate() {
690 if i > 0 {
691 line.push(',');
692 }
693 let _ = write!(line, "{v:.17e}");
694 }
695 line.push(']');
696 }
697 }
698
699 line.push_str("}\n");
700
701 w.write_all(line.as_bytes())
702}
703
704fn dump_kkt(
705 path: &str,
706 dim: Index,
707 irn: &[Index],
708 jcn: &[Index],
709 vals: &[Number],
710 rhs: &[Number],
711 sol: &[Number],
712 check_neg_evals: bool,
713 num_neg_evals: Index,
714 status: ESymSolverStatus,
715 last_neg_evals: Index,
716) {
717 if let Ok(mut f) = std::fs::OpenOptions::new()
718 .create(true)
719 .append(true)
720 .open(path)
721 {
722 let _ = write_kkt_record(
723 &mut f,
724 dim,
725 irn,
726 jcn,
727 vals,
728 rhs,
729 sol,
730 check_neg_evals,
731 num_neg_evals,
732 status,
733 last_neg_evals,
734 None, );
736 }
737}
738
739fn w_nonzeros(w: &dyn pounce_linalg::SymMatrix) -> usize {
744 if let Some(t) = w.as_any().downcast_ref::<SymTMatrix>() {
745 t.nonzeros() as usize
746 } else if let Some(dm) = w.as_any().downcast_ref::<DiagMatrix>() {
747 w_diag_dim(dm) as usize
748 } else {
749 unreachable!("StdAugSystemSolver: W must be a SymTMatrix or DiagMatrix in v1.0")
750 }
751}
752
753fn w_diag_dim(dm: &DiagMatrix) -> Index {
754 dm.get_diag()
755 .expect("DiagMatrix W has no diagonal set")
756 .dim()
757}
758
759fn w_diag_values(dm: &DiagMatrix) -> Vec<Number> {
760 let diag = dm.get_diag().expect("DiagMatrix W has no diagonal set");
761 diag.as_any()
762 .downcast_ref::<DenseVector>()
763 .expect("StdAugSystemSolver: DiagMatrix W diagonal must be DenseVector in v1.0")
764 .expanded_values()
765}
766
767fn gen_t_downcast(m: &dyn pounce_linalg::Matrix) -> &GenTMatrix {
768 let Some(t) = m.as_any().downcast_ref::<GenTMatrix>() else {
769 unreachable!("StdAugSystemSolver: J_c / J_d must be GenTMatrix in v1.0")
770 };
771 t
772}
773
774fn flat_read(v: &dyn Vector) -> Vec<Number> {
780 if let Some(dv) = v.as_any().downcast_ref::<DenseVector>() {
781 return dv.expanded_values();
782 }
783 if let Some(cv) = v.as_any().downcast_ref::<CompoundVector>() {
784 let mut out = Vec::with_capacity(cv.dim() as usize);
785 for k in 0..cv.n_comps() {
786 let blk = cv.comp(k);
787 let dblk = blk
788 .as_any()
789 .downcast_ref::<DenseVector>()
790 .expect("StdAugSystemSolver: CompoundVector blocks must be DenseVectors");
791 out.extend_from_slice(&dblk.expanded_values());
792 }
793 return out;
794 }
795 unreachable!("StdAugSystemSolver: D_*/rhs/sol must be DenseVector or CompoundVector of DenseVectors in v1.0")
796}
797
798fn flat_write(dst: &mut dyn Vector, src: &[Number]) {
800 if let Some(dv) = dst.as_any_mut().downcast_mut::<DenseVector>() {
801 dv.set_values(src);
802 return;
803 }
804 if let Some(cv) = dst.as_any_mut().downcast_mut::<CompoundVector>() {
805 let mut off = 0usize;
806 for k in 0..cv.n_comps() {
807 let blk = cv.comp_mut(k);
808 let dim = blk.dim() as usize;
809 let dblk = blk
810 .as_any_mut()
811 .downcast_mut::<DenseVector>()
812 .expect("StdAugSystemSolver: CompoundVector blocks must be DenseVectors");
813 dblk.set_values(&src[off..off + dim]);
814 off += dim;
815 }
816 return;
817 }
818 unreachable!(
819 "StdAugSystemSolver: sol must be DenseVector or CompoundVector of DenseVectors in v1.0"
820 )
821}
822
823fn fill_diag(dst: &mut [Number], d: Option<&dyn Vector>, delta: Number, sign: Number) {
826 match d {
827 None => {
828 for v in dst.iter_mut() {
829 *v = sign * delta;
830 }
831 }
832 Some(d) => {
833 let xs = flat_read(d);
834 debug_assert_eq!(xs.len(), dst.len());
835 for (out, &x) in dst.iter_mut().zip(xs.iter()) {
836 *out = sign * (x + delta);
837 }
838 }
839 }
840}
841
842fn copy_vec(src: &dyn Vector, dst: &mut [Number]) {
843 let xs = flat_read(src);
844 debug_assert_eq!(xs.len(), dst.len());
845 dst.copy_from_slice(&xs);
846}
847
848fn write_vec(dst: &mut dyn Vector, src: &[Number]) {
849 flat_write(dst, src);
850}
851
852#[cfg(test)]
853mod tests {
854 use super::*;
855 use pounce_common::types::{Index, Number};
856 use pounce_linalg::dense_vector::DenseVectorSpace;
857 use pounce_linalg::triplet::{GenTMatrixSpace, SymTMatrixSpace};
858 use pounce_linsol::sparse_sym_iface::SparseSymLinearSolverInterface;
859 use pounce_linsol::EMatrixFormat;
860
861 struct DenseMock {
864 dim: Index,
865 nz: Index,
866 a: Vec<Number>,
867 last_factor: Vec<Number>, neg_evals: Index,
869 }
870
871 impl DenseMock {
872 fn new() -> Self {
873 Self {
874 dim: 0,
875 nz: 0,
876 a: Vec::new(),
877 last_factor: Vec::new(),
878 neg_evals: 0,
879 }
880 }
881 }
882
883 impl SparseSymLinearSolverInterface for DenseMock {
884 fn initialize_structure(
885 &mut self,
886 dim: Index,
887 nz: Index,
888 _ia: &[Index],
889 _ja: &[Index],
890 ) -> ESymSolverStatus {
891 self.dim = dim;
892 self.nz = nz;
893 self.a = vec![0.0; nz as usize];
894 ESymSolverStatus::Success
895 }
896 fn values_array_mut(&mut self) -> &mut [Number] {
897 &mut self.a
898 }
899 fn multi_solve(
900 &mut self,
901 new_matrix: bool,
902 ia: &[Index],
903 ja: &[Index],
904 nrhs: Index,
905 rhs_vals: &mut [Number],
906 _check: bool,
907 _nev: Index,
908 ) -> ESymSolverStatus {
909 let n = self.dim as usize;
910 if new_matrix {
911 let mut dense = vec![0.0; n * n];
914 for k in 0..self.nz as usize {
915 let i = (ia[k] - 1) as usize;
916 let j = (ja[k] - 1) as usize;
917 dense[i * n + j] += self.a[k];
918 if i != j {
919 dense[j * n + i] += self.a[k];
920 }
921 }
922 self.last_factor = dense;
923 }
924 for col in 0..nrhs as usize {
926 let mut a = self.last_factor.clone();
927 let b = &mut rhs_vals[col * n..col * n + n];
928 let mut neg = 0_i32;
929 for k in 0..n {
930 let mut piv = k;
932 let mut piv_abs = a[k * n + k].abs();
933 for r in (k + 1)..n {
934 let av = a[r * n + k].abs();
935 if av > piv_abs {
936 piv_abs = av;
937 piv = r;
938 }
939 }
940 if piv != k {
941 for c in 0..n {
942 a.swap(k * n + c, piv * n + c);
943 }
944 b.swap(k, piv);
945 }
946 let p = a[k * n + k];
947 if p.abs() < 1e-14 {
948 return ESymSolverStatus::Singular;
949 }
950 if p < 0.0 {
951 neg += 1;
952 }
953 for r in (k + 1)..n {
954 let f = a[r * n + k] / p;
955 for c in k..n {
956 a[r * n + c] -= f * a[k * n + c];
957 }
958 b[r] -= f * b[k];
959 }
960 }
961 for k in (0..n).rev() {
963 let mut s = b[k];
964 for c in (k + 1)..n {
965 s -= a[k * n + c] * b[c];
966 }
967 b[k] = s / a[k * n + k];
968 }
969 self.neg_evals = neg;
970 }
971 ESymSolverStatus::Success
972 }
973 fn number_of_neg_evals(&self) -> Index {
974 self.neg_evals
975 }
976 fn increase_quality(&mut self) -> bool {
977 false
978 }
979 fn provides_inertia(&self) -> bool {
980 true
981 }
982 fn matrix_format(&self) -> EMatrixFormat {
983 EMatrixFormat::TripletFormat
984 }
985 }
986
987 #[test]
999 fn solves_5x5_kkt_through_dense_mock() {
1000 let w_space = SymTMatrixSpace::new(2, vec![1, 2], vec![1, 2]);
1002 let mut w = SymTMatrix::new(w_space);
1003 w.set_values(&[2.0, 3.0]);
1004
1005 let jc_space = GenTMatrixSpace::new(1, 2, vec![1, 1], vec![1, 2]);
1007 let mut j_c = GenTMatrix::new(jc_space);
1008 j_c.set_values(&[1.0, 1.0]);
1009
1010 let jd_space = GenTMatrixSpace::new(1, 2, vec![1], vec![1]);
1012 let mut j_d = GenTMatrix::new(jd_space);
1013 j_d.set_values(&[1.0]);
1014
1015 let s_space = DenseVectorSpace::new(1);
1017 let mut d_s = s_space.make_new_dense();
1018 d_s.set_values(&[1.0]);
1019
1020 let xs = DenseVectorSpace::new(2);
1028 let mut rx = xs.make_new_dense();
1029 rx.set_values(&[4.0, 4.0]);
1030 let mut rs = s_space.make_new_dense();
1031 rs.set_values(&[0.0]);
1032 let cs = DenseVectorSpace::new(1);
1033 let mut rc = cs.make_new_dense();
1034 rc.set_values(&[2.0]);
1035 let ds_space = DenseVectorSpace::new(1);
1036 let mut rd = ds_space.make_new_dense();
1037 rd.set_values(&[0.0]);
1038
1039 let mut sx = xs.make_new_dense();
1040 let mut ss = s_space.make_new_dense();
1041 let mut sc = cs.make_new_dense();
1042 let mut sd = ds_space.make_new_dense();
1043
1044 let linsol = TSymLinearSolver::new(Box::new(DenseMock::new()), None, false);
1045 let mut solver = StdAugSystemSolver::new(linsol);
1046
1047 let coeffs = AugSysCoeffs {
1048 w: Some(&w),
1049 w_factor: 1.0,
1050 d_x: None,
1051 delta_x: 0.0,
1052 d_s: Some(&d_s),
1053 delta_s: 0.0,
1054 j_c: &j_c,
1055 d_c: None,
1056 delta_c: 0.0,
1057 j_d: &j_d,
1058 d_d: None,
1059 delta_d: 0.0,
1060 };
1061 let rhs = AugSysRhs {
1062 rhs_x: &rx,
1063 rhs_s: &rs,
1064 rhs_c: &rc,
1065 rhs_d: &rd,
1066 };
1067 let mut sol = AugSysSol {
1068 sol_x: &mut sx,
1069 sol_s: &mut ss,
1070 sol_c: &mut sc,
1071 sol_d: &mut sd,
1072 };
1073 let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
1074 assert_eq!(status, ESymSolverStatus::Success);
1075
1076 for v in sx.values() {
1077 assert!((v - 1.0).abs() < 1e-10, "sol_x = {v}");
1078 }
1079 for v in ss.values() {
1080 assert!((v - 1.0).abs() < 1e-10, "sol_s = {v}");
1081 }
1082 for v in sc.values() {
1083 assert!((v - 1.0).abs() < 1e-10, "sol_c = {v}");
1084 }
1085 for v in sd.values() {
1086 assert!((v - 1.0).abs() < 1e-10, "sol_d = {v}");
1087 }
1088 }
1089
1090 #[test]
1100 fn lowrank_smw_matches_dense_w_on_constrained_system() {
1101 use crate::kkt::low_rank_aug_system_solver::LowRankAugSystemSolver;
1102 use pounce_linalg::diag_matrix::DiagMatrix;
1103 use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
1104 use pounce_linalg::multi_vector_matrix::MultiVectorMatrixSpace;
1105
1106 let n = 4usize;
1107 let sigma = 2.0;
1108 let vcols = [
1113 vec![0.6, 0.1, -0.2, 0.3],
1114 vec![0.2, 0.5, 0.1, -0.1],
1115 vec![-0.1, 0.2, 0.4, 0.2],
1116 vec![0.3, -0.2, 0.1, 0.4],
1117 vec![0.15, 0.25, -0.3, 0.1],
1118 vec![-0.2, 0.1, 0.2, 0.35],
1119 ];
1120 let ucols = [
1121 vec![0.3, -0.1, 0.2, 0.1],
1122 vec![0.1, 0.3, -0.2, 0.2],
1123 vec![0.2, 0.1, 0.1, -0.3],
1124 vec![-0.1, 0.2, 0.15, 0.1],
1125 vec![0.25, -0.15, 0.1, 0.2],
1126 vec![0.1, 0.2, -0.25, 0.15],
1127 ];
1128 let mut wfull = vec![0.0_f64; n * n];
1130 for i in 0..n {
1131 wfull[i * n + i] = sigma;
1132 }
1133 for c in vcols.iter() {
1134 for i in 0..n {
1135 for j in 0..n {
1136 wfull[i * n + j] += c[i] * c[j];
1137 }
1138 }
1139 }
1140 for c in ucols.iter() {
1141 for i in 0..n {
1142 for j in 0..n {
1143 wfull[i * n + j] -= c[i] * c[j];
1144 }
1145 }
1146 }
1147
1148 let make_jc = || {
1150 let sp = GenTMatrixSpace::new(1, 4, vec![1, 1, 1, 1], vec![1, 2, 3, 4]);
1151 let mut m = GenTMatrix::new(sp);
1152 m.set_values(&[1.0, 1.0, 1.0, 1.0]);
1153 m
1154 };
1155 let make_jd = || {
1159 let sp = GenTMatrixSpace::new(1, 4, vec![1, 1], vec![1, 3]);
1160 let mut m = GenTMatrix::new(sp);
1161 m.set_values(&[1.0, 1.0]);
1162 m
1163 };
1164
1165 let xs = DenseVectorSpace::new(4);
1166 let cs = DenseVectorSpace::new(1);
1167 let mk = |sp: &Rc<DenseVectorSpace>, vals: &[Number]| {
1168 let mut d = sp.make_new_dense();
1169 d.set_values(vals);
1170 d
1171 };
1172
1173 let solve_with = |w: &dyn pounce_linalg::SymMatrix,
1174 aug: &mut dyn AugSystemSolver|
1175 -> (Vec<Number>, Vec<Number>) {
1176 let j_c = make_jc();
1177 let j_d = make_jd();
1178 let rx = mk(&xs, &[1.0, 2.0, -1.0, 0.5]);
1179 let rs = mk(&cs, &[0.4]);
1180 let rc = mk(&cs, &[3.0]);
1181 let rd = mk(&cs, &[0.7]);
1182 let mut sx = mk(&xs, &[0.0, 0.0, 0.0, 0.0]);
1183 let mut ss = mk(&cs, &[0.0]);
1184 let mut sc = mk(&cs, &[0.0]);
1185 let mut sd = mk(&cs, &[0.0]);
1186 let d_s = mk(&cs, &[1.5]);
1187 let coeffs = AugSysCoeffs {
1188 w: Some(w),
1189 w_factor: 1.0,
1190 d_x: None,
1191 delta_x: 0.0,
1192 d_s: Some(&d_s),
1193 delta_s: 0.0,
1194 j_c: &j_c,
1195 d_c: None,
1196 delta_c: 0.0,
1197 j_d: &j_d,
1198 d_d: None,
1199 delta_d: 0.0,
1200 };
1201 let rhs = AugSysRhs {
1202 rhs_x: &rx,
1203 rhs_s: &rs,
1204 rhs_c: &rc,
1205 rhs_d: &rd,
1206 };
1207 let mut sol = AugSysSol {
1208 sol_x: &mut sx,
1209 sol_s: &mut ss,
1210 sol_c: &mut sc,
1211 sol_d: &mut sd,
1212 };
1213 let status = aug.solve(&coeffs, &rhs, &mut sol, false, 1);
1214 assert_eq!(status, ESymSolverStatus::Success);
1215 (sx.expanded_values(), sc.expanded_values())
1216 };
1217
1218 let mut wi = Vec::new();
1220 let mut wj = Vec::new();
1221 let mut wv = Vec::new();
1222 for i in 0..n {
1223 for j in 0..=i {
1224 wi.push(i as Index + 1);
1225 wj.push(j as Index + 1);
1226 wv.push(wfull[i * n + j]);
1227 }
1228 }
1229 let w_space = SymTMatrixSpace::new(4, wi, wj);
1230 let mut w_dense = SymTMatrix::new(w_space);
1231 w_dense.set_values(&wv);
1232 let mut std_solver = StdAugSystemSolver::new(TSymLinearSolver::new(
1233 Box::new(pounce_feral::FeralSolverInterface::new()),
1234 None,
1235 false,
1236 ));
1237 let (ref_x, ref_c) = solve_with(&w_dense, &mut std_solver);
1238
1239 let lr_space = LowRankUpdateSymMatrixSpace::new(4, None, false);
1241 let mut lr = lr_space.make_new_low_rank();
1242 let mut diag = xs.make_new_dense();
1243 diag.set_values(&[sigma; 4]);
1244 lr.set_diag(Rc::new(diag) as Rc<dyn Vector>);
1245 let build_mvm = |cols: &[Vec<Number>]| {
1246 let sp = MultiVectorMatrixSpace::new(cols.len() as Index, Rc::clone(&xs));
1247 let mut mvm = sp.make_new_multi_vector();
1248 for (k, c) in cols.iter().enumerate() {
1249 let mut cv = xs.make_new_dense();
1250 cv.set_values(c);
1251 mvm.set_vector(k as Index, Rc::new(cv) as Rc<dyn Vector>);
1252 }
1253 mvm
1254 };
1255 lr.set_v(Rc::new(build_mvm(&vcols)));
1256 lr.set_u(Rc::new(build_mvm(&ucols)));
1257 let _ = DiagMatrix::new(4); let mut lr_solver =
1260 LowRankAugSystemSolver::new(Box::new(StdAugSystemSolver::new(TSymLinearSolver::new(
1261 Box::new(pounce_feral::FeralSolverInterface::new()),
1262 None,
1263 false,
1264 ))));
1265 let (lr_x, lr_c) = solve_with(&lr, &mut lr_solver);
1266
1267 for (a, b) in ref_x.iter().zip(lr_x.iter()) {
1268 assert!((a - b).abs() < 1e-9, "sol_x mismatch: dense={a} smw={b}");
1269 }
1270 for (a, b) in ref_c.iter().zip(lr_c.iter()) {
1271 assert!((a - b).abs() < 1e-9, "sol_c mismatch: dense={a} smw={b}");
1272 }
1273 }
1274}