1use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
30use pounce_common::diagnostics::{DiagCategory, DiagnosticsState};
31use pounce_common::timing::TimingStatistics;
32use pounce_common::types::{Index, Number};
33use pounce_linalg::compound_vector::CompoundVector;
34use pounce_linalg::dense_vector::DenseVector;
35use pounce_linalg::diag_matrix::DiagMatrix;
36use pounce_linalg::triplet::{GenTMatrix, SymTMatrix};
37use pounce_linalg::Vector;
38use pounce_linsol::{ESymSolverStatus, FactorPattern, SymLinearSolver, TSymLinearSolver};
39use std::ops::Range;
40use std::rc::Rc;
41
42pub struct StdAugSystemSolver {
44 linsol: TSymLinearSolver,
45
46 initialized: bool,
48 struct_sig: Option<(usize, usize, usize, Index, Index, Index)>,
57 n_x: Index,
58 n_s: Index,
59 n_c: Index,
60 n_d: Index,
61 dim: Index,
63
64 irn: Vec<Index>,
66 jcn: Vec<Index>,
68 vals: Vec<Number>,
70
71 w_range: Range<usize>,
73 dx_range: Range<usize>,
74 ds_range: Range<usize>,
75 jc_range: Range<usize>,
76 dc_range: Range<usize>,
77 jd_range: Range<usize>,
78 minus_i_range: Range<usize>,
79 dd_range: Range<usize>,
80
81 last_neg_evals: Index,
82 last_status: Option<ESymSolverStatus>,
83
84 have_factor: bool,
88
89 timing: Option<Rc<TimingStatistics>>,
93
94 diagnostics: Option<Rc<DiagnosticsState>>,
100}
101
102impl std::fmt::Debug for StdAugSystemSolver {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("StdAugSystemSolver")
105 .field("dim", &self.dim)
106 .field("nnz", &self.vals.len())
107 .field("initialized", &self.initialized)
108 .field("last_neg_evals", &self.last_neg_evals)
109 .field("last_status", &self.last_status)
110 .finish_non_exhaustive()
111 }
112}
113
114impl StdAugSystemSolver {
115 pub fn new(linsol: TSymLinearSolver) -> Self {
117 Self {
118 linsol,
119 initialized: false,
120 struct_sig: None,
121 n_x: 0,
122 n_s: 0,
123 n_c: 0,
124 n_d: 0,
125 dim: 0,
126 irn: Vec::new(),
127 jcn: Vec::new(),
128 vals: Vec::new(),
129 w_range: 0..0,
130 dx_range: 0..0,
131 ds_range: 0..0,
132 jc_range: 0..0,
133 dc_range: 0..0,
134 jd_range: 0..0,
135 minus_i_range: 0..0,
136 dd_range: 0..0,
137 last_neg_evals: 0,
138 last_status: None,
139 have_factor: false,
140 timing: None,
141 diagnostics: None,
142 }
143 }
144
145 fn build_structure(&mut self, coeffs: &AugSysCoeffs<'_>) -> ESymSolverStatus {
146 let n_x = coeffs.j_c.n_cols();
147 let n_c = coeffs.j_c.n_rows();
148 let n_d = coeffs.j_d.n_rows();
149 debug_assert_eq!(coeffs.j_d.n_cols(), n_x);
150 let n_s = n_d;
151
152 let w_nnz = match coeffs.w {
153 None => 0_usize,
154 Some(w) => w_nonzeros(w),
155 };
156 let jc_nnz = gen_t_downcast(coeffs.j_c).nonzeros() as usize;
157 let jd_nnz = gen_t_downcast(coeffs.j_d).nonzeros() as usize;
158
159 let total = w_nnz
160 + (n_x as usize) + (n_s as usize) + jc_nnz
163 + (n_c as usize) + jd_nnz
165 + (n_s as usize) + (n_d as usize); self.irn = Vec::with_capacity(total);
169 self.jcn = Vec::with_capacity(total);
170 self.vals = vec![0.0; total];
171
172 let w_start = self.irn.len();
174 if let Some(w) = coeffs.w {
175 if let Some(t) = w.as_any().downcast_ref::<SymTMatrix>() {
176 self.irn.extend_from_slice(t.irows());
177 self.jcn.extend_from_slice(t.jcols());
178 } else if let Some(dm) = w.as_any().downcast_ref::<DiagMatrix>() {
179 let n = w_diag_dim(dm);
182 for i in 0..n {
183 self.irn.push(i + 1);
184 self.jcn.push(i + 1);
185 }
186 } else {
187 unreachable!("StdAugSystemSolver: W must be a SymTMatrix or DiagMatrix in v1.0")
188 }
189 }
190 self.w_range = w_start..self.irn.len();
191
192 let dx_start = self.irn.len();
194 for i in 0..n_x {
195 self.irn.push(i + 1);
196 self.jcn.push(i + 1);
197 }
198 self.dx_range = dx_start..self.irn.len();
199
200 let ds_start = self.irn.len();
202 for i in 0..n_s {
203 let r = n_x + i + 1;
204 self.irn.push(r);
205 self.jcn.push(r);
206 }
207 self.ds_range = ds_start..self.irn.len();
208
209 let jc_start = self.irn.len();
211 let j_c = gen_t_downcast(coeffs.j_c);
212 let row_off_c = n_x + n_s;
213 for (&i, &j) in j_c.irows().iter().zip(j_c.jcols().iter()) {
214 self.irn.push(row_off_c + i);
217 self.jcn.push(j);
218 }
219 self.jc_range = jc_start..self.irn.len();
220
221 let dc_start = self.irn.len();
223 for i in 0..n_c {
224 let r = n_x + n_s + i + 1;
225 self.irn.push(r);
226 self.jcn.push(r);
227 }
228 self.dc_range = dc_start..self.irn.len();
229
230 let jd_start = self.irn.len();
232 let j_d = gen_t_downcast(coeffs.j_d);
233 let row_off_d = n_x + n_s + n_c;
234 for (&i, &j) in j_d.irows().iter().zip(j_d.jcols().iter()) {
235 self.irn.push(row_off_d + i);
236 self.jcn.push(j);
237 }
238 self.jd_range = jd_start..self.irn.len();
239
240 let mi_start = self.irn.len();
242 for i in 0..n_s {
243 self.irn.push(n_x + n_s + n_c + i + 1);
244 self.jcn.push(n_x + i + 1);
245 }
246 self.minus_i_range = mi_start..self.irn.len();
247
248 let dd_start = self.irn.len();
250 for i in 0..n_d {
251 let r = n_x + n_s + n_c + i + 1;
252 self.irn.push(r);
253 self.jcn.push(r);
254 }
255 self.dd_range = dd_start..self.irn.len();
256
257 debug_assert_eq!(self.irn.len(), total);
258 debug_assert_eq!(self.jcn.len(), total);
259
260 self.n_x = n_x;
261 self.n_s = n_s;
262 self.n_c = n_c;
263 self.n_d = n_d;
264 self.dim = n_x + n_s + n_c + n_d;
265
266 let status = self
267 .linsol
268 .initialize_structure(self.dim, &self.irn, &self.jcn);
269 if status == ESymSolverStatus::Success {
270 self.initialized = true;
271 }
272 status
273 }
274
275 fn refill_values(&mut self, coeffs: &AugSysCoeffs<'_>) {
276 if !self.w_range.is_empty() {
278 let Some(w_dyn) = coeffs.w else {
279 unreachable!("structure pinned with W; W cannot be None now")
280 };
281 let dst = &mut self.vals[self.w_range.clone()];
282 if let Some(t) = w_dyn.as_any().downcast_ref::<SymTMatrix>() {
283 for (d, &v) in dst.iter_mut().zip(t.values().iter()) {
284 *d = coeffs.w_factor * v;
285 }
286 } else if let Some(dm) = w_dyn.as_any().downcast_ref::<DiagMatrix>() {
287 let diag = w_diag_values(dm);
288 for (d, &v) in dst.iter_mut().zip(diag.iter()) {
289 *d = coeffs.w_factor * v;
290 }
291 } else {
292 unreachable!("StdAugSystemSolver: W must be a SymTMatrix or DiagMatrix in v1.0")
293 }
294 }
295 fill_diag(
297 &mut self.vals[self.dx_range.clone()],
298 coeffs.d_x,
299 coeffs.delta_x,
300 1.0,
301 );
302 fill_diag(
304 &mut self.vals[self.ds_range.clone()],
305 coeffs.d_s,
306 coeffs.delta_s,
307 1.0,
308 );
309 {
311 let j_c = gen_t_downcast(coeffs.j_c);
312 self.vals[self.jc_range.clone()].copy_from_slice(j_c.values());
313 }
314 fill_diag(
316 &mut self.vals[self.dc_range.clone()],
317 coeffs.d_c,
318 coeffs.delta_c,
319 -1.0,
320 );
321 {
323 let j_d = gen_t_downcast(coeffs.j_d);
324 self.vals[self.jd_range.clone()].copy_from_slice(j_d.values());
325 }
326 for v in &mut self.vals[self.minus_i_range.clone()] {
328 *v = -1.0;
329 }
330 fill_diag(
332 &mut self.vals[self.dd_range.clone()],
333 coeffs.d_d,
334 coeffs.delta_d,
335 -1.0,
336 );
337 }
338
339 fn pack_rhs(&self, rhs: &AugSysRhs<'_>, packed: &mut [Number]) {
340 let n_x = self.n_x as usize;
341 let n_s = self.n_s as usize;
342 let n_c = self.n_c as usize;
343 let n_d = self.n_d as usize;
344 copy_vec(rhs.rhs_x, &mut packed[..n_x]);
345 copy_vec(rhs.rhs_s, &mut packed[n_x..n_x + n_s]);
346 copy_vec(rhs.rhs_c, &mut packed[n_x + n_s..n_x + n_s + n_c]);
347 copy_vec(
348 rhs.rhs_d,
349 &mut packed[n_x + n_s + n_c..n_x + n_s + n_c + n_d],
350 );
351 }
352
353 fn unpack_sol(&self, packed: &[Number], sol: &mut AugSysSol<'_>) {
354 let n_x = self.n_x as usize;
355 let n_s = self.n_s as usize;
356 let n_c = self.n_c as usize;
357 let n_d = self.n_d as usize;
358 write_vec(sol.sol_x, &packed[..n_x]);
359 write_vec(sol.sol_s, &packed[n_x..n_x + n_s]);
360 write_vec(sol.sol_c, &packed[n_x + n_s..n_x + n_s + n_c]);
361 write_vec(sol.sol_d, &packed[n_x + n_s + n_c..n_x + n_s + n_c + n_d]);
362 }
363}
364
365impl AugSystemSolver for StdAugSystemSolver {
366 fn provides_inertia(&self) -> bool {
367 self.linsol.provides_inertia()
368 }
369
370 fn number_of_neg_evals(&self) -> Index {
371 self.last_neg_evals
372 }
373
374 fn system_dim(&self) -> Index {
375 self.dim
376 }
377
378 fn kkt_triplets(&self) -> Option<(Index, Vec<Index>, Vec<Index>, Vec<Number>)> {
379 if self.irn.is_empty() {
380 return None;
381 }
382 Some((
383 self.dim,
384 self.irn.clone(),
385 self.jcn.clone(),
386 self.vals.clone(),
387 ))
388 }
389
390 fn l_factor(&self, want_values: bool) -> Option<FactorPattern> {
391 self.linsol.factor_pattern(want_values)
392 }
393
394 fn increase_quality(&mut self) -> bool {
395 self.have_factor = false;
399 self.linsol.increase_quality()
400 }
401
402 fn last_solve_status(&self) -> ESymSolverStatus {
403 self.last_status.unwrap_or(ESymSolverStatus::FatalError)
404 }
405
406 fn solve(
407 &mut self,
408 coeffs: &AugSysCoeffs<'_>,
409 rhs: &AugSysRhs<'_>,
410 sol: &mut AugSysSol<'_>,
411 check_neg_evals: bool,
412 num_neg_evals: Index,
413 ) -> ESymSolverStatus {
414 let sig = {
422 let w_nnz = coeffs.w.map(w_nonzeros).unwrap_or(0);
423 let jc_nnz = gen_t_downcast(coeffs.j_c).nonzeros() as usize;
424 let jd_nnz = gen_t_downcast(coeffs.j_d).nonzeros() as usize;
425 (
426 w_nnz,
427 jc_nnz,
428 jd_nnz,
429 coeffs.j_c.n_cols(),
430 coeffs.j_c.n_rows(),
431 coeffs.j_d.n_rows(),
432 )
433 };
434 if !self.initialized || self.struct_sig != Some(sig) {
435 let s = self.build_structure(coeffs);
436 if s != ESymSolverStatus::Success {
437 self.last_status = Some(s);
438 return s;
439 }
440 self.struct_sig = Some(sig);
441 }
442 self.refill_values(coeffs);
443
444 let mut packed = vec![0.0; self.dim as usize];
445 self.pack_rhs(rhs, &mut packed);
446
447 let dump_rhs = packed.clone();
448
449 let _factor_guard = self
453 .timing
454 .as_deref()
455 .map(|t| t.linear_system_factorization.guard());
456 let status = self.linsol.multi_solve(
457 &self.vals,
458 true,
459 1,
460 &mut packed,
461 check_neg_evals,
462 num_neg_evals,
463 );
464 drop(_factor_guard);
465 self.last_status = Some(status);
466 if self.linsol.provides_inertia()
478 && matches!(
479 status,
480 ESymSolverStatus::Success
481 | ESymSolverStatus::WrongInertia
482 | ESymSolverStatus::Singular
483 )
484 {
485 self.last_neg_evals = self.linsol.number_of_neg_evals();
486 }
487 if status == ESymSolverStatus::Success {
488 self.unpack_sol(&packed, sol);
489 self.have_factor = true;
490 }
491
492 if let Some(diag) = self.diagnostics.clone() {
497 if diag.want(DiagCategory::Kkt) {
498 let solve_idx = diag.next_solve_index();
499 let filename = format!("kkt_solve_{solve_idx:03}.jsonl");
500 let variant = diag.config.kkt_variant;
507 let factor_pattern =
508 if status == ESymSolverStatus::Success && variant.wants_l_pattern() {
509 self.linsol.factor_pattern(variant.wants_l_values())
510 } else {
511 None
512 };
513 if let Some(mut w) = diag.open_writer(&filename) {
514 let _ = write_kkt_record(
515 &mut w,
516 self.dim,
517 &self.irn,
518 &self.jcn,
519 &self.vals,
520 &dump_rhs,
521 &packed,
522 check_neg_evals,
523 num_neg_evals,
524 status,
525 self.last_neg_evals,
526 factor_pattern.as_ref(),
527 );
528 }
529 }
530 }
531 if let Ok(path) = std::env::var("POUNCE_DUMP_KKT") {
532 use std::sync::atomic::{AtomicBool, Ordering};
533 static WARNED: AtomicBool = AtomicBool::new(false);
534 if !WARNED.swap(true, Ordering::SeqCst) {
535 tracing::warn!(target: "pounce::linsol",
536 "warning: POUNCE_DUMP_KKT is deprecated; prefer `--dump kkt:<iter-spec>` (see pounce --help)"
537 );
538 }
539 dump_kkt(
540 &path,
541 self.dim,
542 &self.irn,
543 &self.jcn,
544 &self.vals,
545 &dump_rhs,
546 &packed,
547 check_neg_evals,
548 num_neg_evals,
549 status,
550 self.last_neg_evals,
551 );
552 }
553
554 status
555 }
556
557 fn resolve(
558 &mut self,
559 coeffs: &AugSysCoeffs<'_>,
560 rhs: &AugSysRhs<'_>,
561 sol: &mut AugSysSol<'_>,
562 ) -> ESymSolverStatus {
563 if !self.have_factor {
570 return self.solve(coeffs, rhs, sol, false, 0);
571 }
572
573 let mut packed = vec![0.0; self.dim as usize];
574 self.pack_rhs(rhs, &mut packed);
575
576 let _back_guard = self
579 .timing
580 .as_deref()
581 .map(|t| t.linear_system_back_solve.guard());
582 let status = self
583 .linsol
584 .multi_solve(&self.vals, false, 1, &mut packed, false, 0);
585 drop(_back_guard);
586 self.last_status = Some(status);
587 if status == ESymSolverStatus::Success {
588 self.unpack_sol(&packed, sol);
589 }
590 status
591 }
592
593 fn set_diagnostics(&mut self, diag: Rc<DiagnosticsState>) {
594 self.diagnostics = Some(diag);
595 }
596
597 fn set_timing_stats(&mut self, timing: Rc<TimingStatistics>) {
598 self.timing = Some(timing);
599 }
600
601 fn try_resolve_many_flat(
602 &mut self,
603 _coeffs: &AugSysCoeffs<'_>,
604 packed_rhs: &mut [Number],
605 nrhs: usize,
606 ) -> Option<ESymSolverStatus> {
607 if !self.have_factor {
612 return None;
613 }
614 if packed_rhs.len() != (self.dim as usize) * nrhs {
615 return Some(ESymSolverStatus::FatalError);
616 }
617 let _back_guard = self
618 .timing
619 .as_deref()
620 .map(|t| t.linear_system_back_solve.guard());
621 let status =
622 self.linsol
623 .multi_solve(&self.vals, false, nrhs as Index, packed_rhs, false, 0);
624 drop(_back_guard);
625 self.last_status = Some(status);
626 Some(status)
627 }
628}
629
630#[allow(clippy::too_many_arguments)]
633fn write_kkt_record(
638 w: &mut dyn std::io::Write,
639 dim: Index,
640 irn: &[Index],
641 jcn: &[Index],
642 vals: &[Number],
643 rhs: &[Number],
644 sol: &[Number],
645 check_neg_evals: bool,
646 num_neg_evals: Index,
647 status: ESymSolverStatus,
648 last_neg_evals: Index,
649 factor_pattern: Option<&FactorPattern>,
650) -> std::io::Result<()> {
651 use std::fmt::Write as _;
652
653 let mut line = String::with_capacity(64 * vals.len());
654 line.push('{');
655 let _ = write!(line, "\"n\":{dim},");
656 let _ = write!(line, "\"check_neg_evals\":{check_neg_evals},");
657 let _ = write!(line, "\"num_neg_evals_expected\":{num_neg_evals},");
658 let _ = write!(line, "\"num_neg_evals_actual\":{last_neg_evals},");
659 let _ = write!(line, "\"status\":\"{status:?}\",");
660
661 line.push_str("\"irn\":[");
662 for (i, v) in irn.iter().enumerate() {
663 if i > 0 {
664 line.push(',');
665 }
666 let _ = write!(line, "{v}");
667 }
668 line.push_str("],\"jcn\":[");
669 for (i, v) in jcn.iter().enumerate() {
670 if i > 0 {
671 line.push(',');
672 }
673 let _ = write!(line, "{v}");
674 }
675 line.push_str("],\"vals\":[");
676 for (i, v) in vals.iter().enumerate() {
677 if i > 0 {
678 line.push(',');
679 }
680 let _ = write!(line, "{v:.17e}");
681 }
682 line.push_str("],\"rhs\":[");
683 for (i, v) in rhs.iter().enumerate() {
684 if i > 0 {
685 line.push(',');
686 }
687 let _ = write!(line, "{v:.17e}");
688 }
689 line.push_str("],\"sol\":[");
690 for (i, v) in sol.iter().enumerate() {
691 if i > 0 {
692 line.push(',');
693 }
694 let _ = write!(line, "{v:.17e}");
695 }
696 line.push(']');
697
698 if let Some(fp) = factor_pattern {
703 line.push_str(",\"L_irn\":[");
704 for (i, v) in fp.l_irn.iter().enumerate() {
705 if i > 0 {
706 line.push(',');
707 }
708 let _ = write!(line, "{v}");
709 }
710 line.push_str("],\"L_jcn\":[");
711 for (i, v) in fp.l_jcn.iter().enumerate() {
712 if i > 0 {
713 line.push(',');
714 }
715 let _ = write!(line, "{v}");
716 }
717 line.push_str("],\"perm\":[");
718 for (i, v) in fp.perm.iter().enumerate() {
719 if i > 0 {
720 line.push(',');
721 }
722 let _ = write!(line, "{v}");
723 }
724 line.push(']');
725 if let Some(vals) = fp.l_vals.as_ref() {
726 line.push_str(",\"L_vals\":[");
727 for (i, v) in vals.iter().enumerate() {
728 if i > 0 {
729 line.push(',');
730 }
731 let _ = write!(line, "{v:.17e}");
732 }
733 line.push(']');
734 }
735 }
736
737 line.push_str("}\n");
738
739 w.write_all(line.as_bytes())
740}
741
742fn dump_kkt(
743 path: &str,
744 dim: Index,
745 irn: &[Index],
746 jcn: &[Index],
747 vals: &[Number],
748 rhs: &[Number],
749 sol: &[Number],
750 check_neg_evals: bool,
751 num_neg_evals: Index,
752 status: ESymSolverStatus,
753 last_neg_evals: Index,
754) {
755 if let Ok(mut f) = std::fs::OpenOptions::new()
756 .create(true)
757 .append(true)
758 .open(path)
759 {
760 let _ = write_kkt_record(
761 &mut f,
762 dim,
763 irn,
764 jcn,
765 vals,
766 rhs,
767 sol,
768 check_neg_evals,
769 num_neg_evals,
770 status,
771 last_neg_evals,
772 None, );
774 }
775}
776
777fn w_nonzeros(w: &dyn pounce_linalg::SymMatrix) -> usize {
782 if let Some(t) = w.as_any().downcast_ref::<SymTMatrix>() {
783 t.nonzeros() as usize
784 } else if let Some(dm) = w.as_any().downcast_ref::<DiagMatrix>() {
785 w_diag_dim(dm) as usize
786 } else {
787 unreachable!("StdAugSystemSolver: W must be a SymTMatrix or DiagMatrix in v1.0")
788 }
789}
790
791fn w_diag_dim(dm: &DiagMatrix) -> Index {
792 dm.get_diag()
793 .expect("DiagMatrix W has no diagonal set")
794 .dim()
795}
796
797fn w_diag_values(dm: &DiagMatrix) -> Vec<Number> {
798 let diag = dm.get_diag().expect("DiagMatrix W has no diagonal set");
799 diag.as_any()
800 .downcast_ref::<DenseVector>()
801 .expect("StdAugSystemSolver: DiagMatrix W diagonal must be DenseVector in v1.0")
802 .expanded_values()
803}
804
805fn gen_t_downcast(m: &dyn pounce_linalg::Matrix) -> &GenTMatrix {
806 let Some(t) = m.as_any().downcast_ref::<GenTMatrix>() else {
807 unreachable!("StdAugSystemSolver: J_c / J_d must be GenTMatrix in v1.0")
808 };
809 t
810}
811
812fn flat_read(v: &dyn Vector) -> Vec<Number> {
818 if let Some(dv) = v.as_any().downcast_ref::<DenseVector>() {
819 return dv.expanded_values();
820 }
821 if let Some(cv) = v.as_any().downcast_ref::<CompoundVector>() {
822 let mut out = Vec::with_capacity(cv.dim() as usize);
823 for k in 0..cv.n_comps() {
824 let blk = cv.comp(k);
825 let dblk = blk
826 .as_any()
827 .downcast_ref::<DenseVector>()
828 .expect("StdAugSystemSolver: CompoundVector blocks must be DenseVectors");
829 out.extend_from_slice(&dblk.expanded_values());
830 }
831 return out;
832 }
833 unreachable!("StdAugSystemSolver: D_*/rhs/sol must be DenseVector or CompoundVector of DenseVectors in v1.0")
834}
835
836fn flat_write(dst: &mut dyn Vector, src: &[Number]) {
838 if let Some(dv) = dst.as_any_mut().downcast_mut::<DenseVector>() {
839 dv.set_values(src);
840 return;
841 }
842 if let Some(cv) = dst.as_any_mut().downcast_mut::<CompoundVector>() {
843 let mut off = 0usize;
844 for k in 0..cv.n_comps() {
845 let blk = cv.comp_mut(k);
846 let dim = blk.dim() as usize;
847 let dblk = blk
848 .as_any_mut()
849 .downcast_mut::<DenseVector>()
850 .expect("StdAugSystemSolver: CompoundVector blocks must be DenseVectors");
851 dblk.set_values(&src[off..off + dim]);
852 off += dim;
853 }
854 return;
855 }
856 unreachable!(
857 "StdAugSystemSolver: sol must be DenseVector or CompoundVector of DenseVectors in v1.0"
858 )
859}
860
861fn fill_diag(dst: &mut [Number], d: Option<&dyn Vector>, delta: Number, sign: Number) {
864 match d {
865 None => {
866 for v in dst.iter_mut() {
867 *v = sign * delta;
868 }
869 }
870 Some(d) => {
871 let xs = flat_read(d);
872 debug_assert_eq!(xs.len(), dst.len());
873 for (out, &x) in dst.iter_mut().zip(xs.iter()) {
874 *out = sign * (x + delta);
875 }
876 }
877 }
878}
879
880fn copy_vec(src: &dyn Vector, dst: &mut [Number]) {
881 let xs = flat_read(src);
882 debug_assert_eq!(xs.len(), dst.len());
883 dst.copy_from_slice(&xs);
884}
885
886fn write_vec(dst: &mut dyn Vector, src: &[Number]) {
887 flat_write(dst, src);
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_common::types::{Index, Number};
    use pounce_linalg::dense_vector::DenseVectorSpace;
    use pounce_linalg::triplet::{GenTMatrixSpace, SymTMatrixSpace};
    use pounce_linsol::sparse_sym_iface::SparseSymLinearSolverInterface;
    use pounce_linsol::EMatrixFormat;

    /// Tiny dense reference backend for `TSymLinearSolver`.
    ///
    /// It expands the symmetric 1-based triplets into a full dense matrix and
    /// solves every right-hand side by Gaussian elimination with partial
    /// pivoting. The reported negative-eigenvalue count is the number of negative
    /// pivots of that row-pivoted elimination. That is not the true inertia in
    /// general, and the tests below do not rely on it.
    struct DenseMock {
        dim: Index,
        nz: Index,
        /// Triplet values exposed through `values_array_mut` (presumably filled
        /// by `TSymLinearSolver` before each `multi_solve`).
        a: Vec<Number>,
        /// Dense symmetric matrix assembled at the last `new_matrix` solve
        /// (not a factor despite the name; elimination is redone per RHS).
        last_factor: Vec<Number>,
        neg_evals: Index,
    }

    impl DenseMock {
        fn new() -> Self {
            Self {
                dim: 0,
                nz: 0,
                a: Vec::new(),
                last_factor: Vec::new(),
                neg_evals: 0,
            }
        }
    }

    impl SparseSymLinearSolverInterface for DenseMock {
        fn initialize_structure(
            &mut self,
            dim: Index,
            nz: Index,
            _ia: &[Index],
            _ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim;
            self.nz = nz;
            self.a = vec![0.0; nz as usize];
            ESymSolverStatus::Success
        }
        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.a
        }
        fn multi_solve(
            &mut self,
            new_matrix: bool,
            ia: &[Index],
            ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            _check: bool,
            _nev: Index,
        ) -> ESymSolverStatus {
            let n = self.dim as usize;
            if new_matrix {
                // Expand the triplets symmetrically; duplicates accumulate.
                let mut dense = vec![0.0; n * n];
                for k in 0..self.nz as usize {
                    let i = (ia[k] - 1) as usize;
                    let j = (ja[k] - 1) as usize;
                    dense[i * n + j] += self.a[k];
                    if i != j {
                        dense[j * n + i] += self.a[k];
                    }
                }
                self.last_factor = dense;
            }
            for col in 0..nrhs as usize {
                // Fresh copy per RHS: elimination below is destructive.
                let mut a = self.last_factor.clone();
                let b = &mut rhs_vals[col * n..col * n + n];
                let mut neg = 0_i32;
                for k in 0..n {
                    // Partial pivoting: pick the largest |a[r][k]| for r >= k.
                    let mut piv = k;
                    let mut piv_abs = a[k * n + k].abs();
                    for r in (k + 1)..n {
                        let av = a[r * n + k].abs();
                        if av > piv_abs {
                            piv_abs = av;
                            piv = r;
                        }
                    }
                    if piv != k {
                        for c in 0..n {
                            a.swap(k * n + c, piv * n + c);
                        }
                        b.swap(k, piv);
                    }
                    let p = a[k * n + k];
                    if p.abs() < 1e-14 {
                        return ESymSolverStatus::Singular;
                    }
                    if p < 0.0 {
                        neg += 1;
                    }
                    for r in (k + 1)..n {
                        let f = a[r * n + k] / p;
                        for c in k..n {
                            a[r * n + c] -= f * a[k * n + c];
                        }
                        b[r] -= f * b[k];
                    }
                }
                // Back substitution on the upper-triangular result.
                for k in (0..n).rev() {
                    let mut s = b[k];
                    for c in (k + 1)..n {
                        s -= a[k * n + c] * b[c];
                    }
                    b[k] = s / a[k * n + k];
                }
                self.neg_evals = neg;
            }
            ESymSolverStatus::Success
        }
        fn number_of_neg_evals(&self) -> Index {
            self.neg_evals
        }
        fn increase_quality(&mut self) -> bool {
            false
        }
        fn provides_inertia(&self) -> bool {
            true
        }
        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// Solves a 5x5 system (n_x = 2, n_s = n_c = n_d = 1) with
    /// W = diag(2, 3), D_s = 1, J_c = [1 1], J_d = [1 0], whose exact solution
    /// is all ones for rhs = (4, 4, 0, 2, 0).
    #[test]
    fn solves_5x5_kkt_through_dense_mock() {
        let w_space = SymTMatrixSpace::new(2, vec![1, 2], vec![1, 2]);
        let mut w = SymTMatrix::new(w_space);
        w.set_values(&[2.0, 3.0]);

        let jc_space = GenTMatrixSpace::new(1, 2, vec![1, 1], vec![1, 2]);
        let mut j_c = GenTMatrix::new(jc_space);
        j_c.set_values(&[1.0, 1.0]);

        let jd_space = GenTMatrixSpace::new(1, 2, vec![1], vec![1]);
        let mut j_d = GenTMatrix::new(jd_space);
        j_d.set_values(&[1.0]);

        let s_space = DenseVectorSpace::new(1);
        let mut d_s = s_space.make_new_dense();
        d_s.set_values(&[1.0]);

        let xs = DenseVectorSpace::new(2);
        let mut rx = xs.make_new_dense();
        rx.set_values(&[4.0, 4.0]);
        let mut rs = s_space.make_new_dense();
        rs.set_values(&[0.0]);
        let cs = DenseVectorSpace::new(1);
        let mut rc = cs.make_new_dense();
        rc.set_values(&[2.0]);
        let ds_space = DenseVectorSpace::new(1);
        let mut rd = ds_space.make_new_dense();
        rd.set_values(&[0.0]);

        let mut sx = xs.make_new_dense();
        let mut ss = s_space.make_new_dense();
        let mut sc = cs.make_new_dense();
        let mut sd = ds_space.make_new_dense();

        let linsol = TSymLinearSolver::new(Box::new(DenseMock::new()), None, false);
        let mut solver = StdAugSystemSolver::new(linsol);

        let coeffs = AugSysCoeffs {
            w: Some(&w),
            w_factor: 1.0,
            d_x: None,
            delta_x: 0.0,
            d_s: Some(&d_s),
            delta_s: 0.0,
            j_c: &j_c,
            d_c: None,
            delta_c: 0.0,
            j_d: &j_d,
            d_d: None,
            delta_d: 0.0,
        };
        let rhs = AugSysRhs {
            rhs_x: &rx,
            rhs_s: &rs,
            rhs_c: &rc,
            rhs_d: &rd,
        };
        let mut sol = AugSysSol {
            sol_x: &mut sx,
            sol_s: &mut ss,
            sol_c: &mut sc,
            sol_d: &mut sd,
        };
        let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
        assert_eq!(status, ESymSolverStatus::Success);

        for v in sx.values() {
            assert!((v - 1.0).abs() < 1e-10, "sol_x = {v}");
        }
        for v in ss.values() {
            assert!((v - 1.0).abs() < 1e-10, "sol_s = {v}");
        }
        for v in sc.values() {
            assert!((v - 1.0).abs() < 1e-10, "sol_c = {v}");
        }
        for v in sd.values() {
            assert!((v - 1.0).abs() < 1e-10, "sol_d = {v}");
        }
    }

    /// Cross-checks the low-rank (Sherman-Morrison-Woodbury) path against the
    /// dense path: W = sigma*I + V Vᵀ - U Uᵀ is passed once as an explicit
    /// lower-triangular `SymTMatrix` and once as a low-rank-update matrix, and
    /// the x and c solution blocks must agree.
    #[test]
    fn lowrank_smw_matches_dense_w_on_constrained_system() {
        use crate::kkt::low_rank_aug_system_solver::LowRankAugSystemSolver;
        use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
        use pounce_linalg::multi_vector_matrix::MultiVectorMatrixSpace;

        let n = 4usize;
        let sigma = 2.0;
        let vcols = [
            vec![0.6, 0.1, -0.2, 0.3],
            vec![0.2, 0.5, 0.1, -0.1],
            vec![-0.1, 0.2, 0.4, 0.2],
            vec![0.3, -0.2, 0.1, 0.4],
            vec![0.15, 0.25, -0.3, 0.1],
            vec![-0.2, 0.1, 0.2, 0.35],
        ];
        let ucols = [
            vec![0.3, -0.1, 0.2, 0.1],
            vec![0.1, 0.3, -0.2, 0.2],
            vec![0.2, 0.1, 0.1, -0.3],
            vec![-0.1, 0.2, 0.15, 0.1],
            vec![0.25, -0.15, 0.1, 0.2],
            vec![0.1, 0.2, -0.25, 0.15],
        ];
        // Dense reference W = sigma*I + sum v vᵀ - sum u uᵀ.
        let mut wfull = vec![0.0_f64; n * n];
        for i in 0..n {
            wfull[i * n + i] = sigma;
        }
        for c in vcols.iter() {
            for i in 0..n {
                for j in 0..n {
                    wfull[i * n + j] += c[i] * c[j];
                }
            }
        }
        for c in ucols.iter() {
            for i in 0..n {
                for j in 0..n {
                    wfull[i * n + j] -= c[i] * c[j];
                }
            }
        }

        let make_jc = || {
            let sp = GenTMatrixSpace::new(1, 4, vec![1, 1, 1, 1], vec![1, 2, 3, 4]);
            let mut m = GenTMatrix::new(sp);
            m.set_values(&[1.0, 1.0, 1.0, 1.0]);
            m
        };
        let make_jd = || {
            let sp = GenTMatrixSpace::new(1, 4, vec![1, 1], vec![1, 3]);
            let mut m = GenTMatrix::new(sp);
            m.set_values(&[1.0, 1.0]);
            m
        };

        let xs = DenseVectorSpace::new(4);
        let cs = DenseVectorSpace::new(1);
        let mk = |sp: &Rc<DenseVectorSpace>, vals: &[Number]| {
            let mut d = sp.make_new_dense();
            d.set_values(vals);
            d
        };

        let solve_with = |w: &dyn pounce_linalg::SymMatrix,
                          aug: &mut dyn AugSystemSolver|
         -> (Vec<Number>, Vec<Number>) {
            let j_c = make_jc();
            let j_d = make_jd();
            let rx = mk(&xs, &[1.0, 2.0, -1.0, 0.5]);
            let rs = mk(&cs, &[0.4]);
            let rc = mk(&cs, &[3.0]);
            let rd = mk(&cs, &[0.7]);
            let mut sx = mk(&xs, &[0.0, 0.0, 0.0, 0.0]);
            let mut ss = mk(&cs, &[0.0]);
            let mut sc = mk(&cs, &[0.0]);
            let mut sd = mk(&cs, &[0.0]);
            let d_s = mk(&cs, &[1.5]);
            let coeffs = AugSysCoeffs {
                w: Some(w),
                w_factor: 1.0,
                d_x: None,
                delta_x: 0.0,
                d_s: Some(&d_s),
                delta_s: 0.0,
                j_c: &j_c,
                d_c: None,
                delta_c: 0.0,
                j_d: &j_d,
                d_d: None,
                delta_d: 0.0,
            };
            let rhs = AugSysRhs {
                rhs_x: &rx,
                rhs_s: &rs,
                rhs_c: &rc,
                rhs_d: &rd,
            };
            let mut sol = AugSysSol {
                sol_x: &mut sx,
                sol_s: &mut ss,
                sol_c: &mut sc,
                sol_d: &mut sd,
            };
            let status = aug.solve(&coeffs, &rhs, &mut sol, false, 1);
            assert_eq!(status, ESymSolverStatus::Success);
            (sx.expanded_values(), sc.expanded_values())
        };

        // Lower-triangular triplets of the dense reference W.
        let mut wi = Vec::new();
        let mut wj = Vec::new();
        let mut wv = Vec::new();
        for i in 0..n {
            for j in 0..=i {
                wi.push(i as Index + 1);
                wj.push(j as Index + 1);
                wv.push(wfull[i * n + j]);
            }
        }
        let w_space = SymTMatrixSpace::new(4, wi, wj);
        let mut w_dense = SymTMatrix::new(w_space);
        w_dense.set_values(&wv);
        let mut std_solver = StdAugSystemSolver::new(TSymLinearSolver::new(
            Box::new(pounce_feral::FeralSolverInterface::new()),
            None,
            false,
        ));
        let (ref_x, ref_c) = solve_with(&w_dense, &mut std_solver);

        let lr_space = LowRankUpdateSymMatrixSpace::new(4, None, false);
        let mut lr = lr_space.make_new_low_rank();
        let mut diag = xs.make_new_dense();
        diag.set_values(&[sigma; 4]);
        lr.set_diag(Rc::new(diag) as Rc<dyn Vector>);
        let build_mvm = |cols: &[Vec<Number>]| {
            let sp = MultiVectorMatrixSpace::new(cols.len() as Index, Rc::clone(&xs));
            let mut mvm = sp.make_new_multi_vector();
            for (k, c) in cols.iter().enumerate() {
                let mut cv = xs.make_new_dense();
                cv.set_values(c);
                mvm.set_vector(k as Index, Rc::new(cv) as Rc<dyn Vector>);
            }
            mvm
        };
        lr.set_v(Rc::new(build_mvm(&vcols)));
        lr.set_u(Rc::new(build_mvm(&ucols)));
        let mut lr_solver =
            LowRankAugSystemSolver::new(Box::new(StdAugSystemSolver::new(TSymLinearSolver::new(
                Box::new(pounce_feral::FeralSolverInterface::new()),
                None,
                false,
            ))));
        let (lr_x, lr_c) = solve_with(&lr, &mut lr_solver);

        for (a, b) in ref_x.iter().zip(lr_x.iter()) {
            assert!((a - b).abs() < 1e-9, "sol_x mismatch: dense={a} smw={b}");
        }
        for (a, b) in ref_c.iter().zip(lr_c.iter()) {
            assert!((a - b).abs() < 1e-9, "sol_c mismatch: dense={a} smw={b}");
        }
    }
}