Skip to main content

pounce_linsol/
t_sym_solver.rs

//! Algorithm-side wrapper that drives a sparse symmetric backend.
//!
//! Port of `Algorithm/LinearSolvers/IpTSymLinearSolver.{hpp,cpp}` from
//! Ipopt 3.14.x. This is the layer between an algorithm's "give me a
//! `SymMatrix` plus RHS, return the solution" expectation and the
//! per-backend [`SparseSymLinearSolverInterface`] contract.
//!
//! Responsibilities, mirroring upstream:
//!
//! * Marshal a triplet `(airn, ajcn, vals)` matrix into the layout the
//!   backend declared via [`SparseSymLinearSolverInterface::matrix_format`].
//! * If a [`TSymScalingMethod`] is configured, compute symmetric
//!   scaling factors `s` once per refactor and apply
//!   `A' = diag(s) A diag(s)` / `b' = diag(s) b` / `x = diag(s) x'`.
//! * Drive the backend's `CALL_AGAIN` retry loop (the MA57 "grow the
//!   workspace and retry" case).
//! * Forward `IncreaseQuality` to the backend, with the upstream
//!   "switch on linear-system scaling on demand" optimization.
//!
//! Tag-based change detection (`SymMatrix::HasChanged`) is *not* part
//! of this wrapper; callers in the Phase-6 KKT layer pass `new_matrix`
//! explicitly. That keeps Phase 4 self-contained while leaving the
//! upstream semantics intact.

24use crate::scaling::TSymScalingMethod;
25use crate::sparse_sym_iface::{EMatrixFormat, FactorPattern, SparseSymLinearSolverInterface};
26use crate::status::ESymSolverStatus;
27use crate::sym_solver::SymLinearSolver;
28use pounce_common::types::{Index, Number};
29use pounce_linalg::triplet_convert::{TriFull, TripletToCsrConverter};
30
31/// Driver wrapping a [`SparseSymLinearSolverInterface`] (and optionally
32/// a [`TSymScalingMethod`]).
33pub struct TSymLinearSolver {
34    backend: Box<dyn SparseSymLinearSolverInterface>,
35    scaling_method: Option<Box<dyn TSymScalingMethod>>,
36    matrix_format: EMatrixFormat,
37    converter: Option<TripletToCsrConverter>,
38
39    /// `true` once [`Self::initialize_structure`] has succeeded.
40    initialized: bool,
41    /// `true` once the row/column index arrays are populated (= ditto
42    /// in this port; see upstream's `have_structure_` for warm-start
43    /// semantics).
44    have_structure: bool,
45    /// `true` if the wrapper should currently apply scaling.
46    use_scaling: bool,
47    /// Set by [`Self::increase_quality`] when scaling-on-demand fires;
48    /// triggers a one-shot scaling-factor recompute on the next solve.
49    just_switched_on_scaling: bool,
50    /// Mirrors `linear_scaling_on_demand`. `true` keeps scaling off
51    /// until `increase_quality` switches it on; `false` scales every
52    /// refactor.
53    linear_scaling_on_demand: bool,
54
55    dim: Index,
56    nonzeros_triplet: Index,
57    nonzeros_compressed: Index,
58
59    /// 1-based row indices, one per triplet entry.
60    airn: Vec<Index>,
61    /// 1-based column indices, one per triplet entry.
62    ajcn: Vec<Index>,
63    /// Per-row symmetric scaling factors (length `dim`). Empty unless
64    /// a scaling method is configured.
65    scaling_factors: Vec<Number>,
66}
67
68impl std::fmt::Debug for TSymLinearSolver {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("TSymLinearSolver")
71            .field("matrix_format", &self.matrix_format)
72            .field("dim", &self.dim)
73            .field("nonzeros_triplet", &self.nonzeros_triplet)
74            .field("nonzeros_compressed", &self.nonzeros_compressed)
75            .field("use_scaling", &self.use_scaling)
76            .field("initialized", &self.initialized)
77            .finish_non_exhaustive()
78    }
79}
80
81impl TSymLinearSolver {
82    /// Build a driver around `backend`. Pass `Some(scaling)` to enable
83    /// symmetric scaling. `linear_scaling_on_demand=true` matches
84    /// upstream's default and keeps scaling off until
85    /// [`Self::increase_quality`] turns it on.
86    pub fn new(
87        backend: Box<dyn SparseSymLinearSolverInterface>,
88        scaling_method: Option<Box<dyn TSymScalingMethod>>,
89        linear_scaling_on_demand: bool,
90    ) -> Self {
91        let matrix_format = backend.matrix_format();
92        let converter = match matrix_format {
93            EMatrixFormat::TripletFormat => None,
94            EMatrixFormat::CsrFormat0Offset => {
95                Some(TripletToCsrConverter::new(0, TriFull::Triangular))
96            }
97            EMatrixFormat::CsrFormat1Offset => {
98                Some(TripletToCsrConverter::new(1, TriFull::Triangular))
99            }
100            EMatrixFormat::CsrFullFormat0Offset => {
101                Some(TripletToCsrConverter::new(0, TriFull::Full))
102            }
103            EMatrixFormat::CsrFullFormat1Offset => {
104                Some(TripletToCsrConverter::new(1, TriFull::Full))
105            }
106        };
107        let use_scaling = scaling_method.is_some() && !linear_scaling_on_demand;
108        Self {
109            backend,
110            scaling_method,
111            matrix_format,
112            converter,
113            initialized: false,
114            have_structure: false,
115            use_scaling,
116            just_switched_on_scaling: false,
117            linear_scaling_on_demand,
118            dim: 0,
119            nonzeros_triplet: 0,
120            nonzeros_compressed: 0,
121            airn: Vec::new(),
122            ajcn: Vec::new(),
123            scaling_factors: Vec::new(),
124        }
125    }
126
127    /// Pin the triplet sparsity pattern. Must be called once before
128    /// the first [`Self::multi_solve`]. `airn` / `ajcn` are 1-based.
129    /// Mirrors the bulk of `TSymLinearSolver::InitializeStructure`.
130    pub fn initialize_structure(
131        &mut self,
132        dim: Index,
133        airn: &[Index],
134        ajcn: &[Index],
135    ) -> ESymSolverStatus {
136        assert_eq!(airn.len(), ajcn.len());
137        let nz = airn.len() as Index;
138        self.dim = dim;
139        self.nonzeros_triplet = nz;
140        self.airn = airn.to_vec();
141        self.ajcn = ajcn.to_vec();
142
143        let (ia, ja, nonzeros) = match self.converter.as_mut() {
144            None => (&self.airn[..], &self.ajcn[..], self.nonzeros_triplet),
145            Some(conv) => {
146                let nonzeros_compressed = conv.initialize(self.dim, &self.airn, &self.ajcn);
147                self.nonzeros_compressed = nonzeros_compressed;
148                (conv.ia(), conv.ja(), nonzeros_compressed)
149            }
150        };
151        let status = self.backend.initialize_structure(dim, nonzeros, ia, ja);
152        if status != ESymSolverStatus::Success {
153            return status;
154        }
155        if self.scaling_method.is_some() {
156            self.scaling_factors = vec![0.0; dim as usize];
157        }
158        self.have_structure = true;
159        self.initialized = true;
160        status
161    }
162
163    /// Solve `A x = b` (or multiple RHS).
164    ///
165    /// `vals` is the new triplet-format value array (length
166    /// `nonzeros_triplet`). `new_matrix=true` requests a refactor
167    /// (and a fresh scaling-factor computation if scaling is on);
168    /// `new_matrix=false` reuses the existing factor and just runs
169    /// back-substitution.
170    ///
171    /// `rhs_vals` packs `nrhs` columns, each length `dim`, in
172    /// column-major layout. Solutions overwrite `rhs_vals`.
173    #[allow(clippy::too_many_arguments)]
174    pub fn multi_solve(
175        &mut self,
176        vals: &[Number],
177        new_matrix: bool,
178        nrhs: Index,
179        rhs_vals: &mut [Number],
180        check_neg_evals: bool,
181        number_of_neg_evals: Index,
182    ) -> ESymSolverStatus {
183        debug_assert!(self.initialized);
184        debug_assert_eq!(vals.len(), self.nonzeros_triplet as usize);
185        debug_assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
186
187        // One-shot KKT dump for backend-comparison testing. Triggered
188        // when POUNCE_DBG_KKT_DUMP is set to a file path; writes one
189        // binary record (dim, nnz, nrhs, ia[], ja[], vals[], rhs[]) on
190        // the Nth multi_solve call (N = POUNCE_DBG_KKT_DUMP_SKIP, default 0),
191        // then disables itself.
192        //
193        // DEPRECATED: superseded by the unified `--dump kkt:<spec>` CLI
194        // surface (see `pounce_common::diagnostics`). Kept for the
195        // existing offline FERAL/MA57/LAPACK binary-comparison tool;
196        // the env var emits a one-shot warning on first observation.
197        {
198            use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
199            static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
200            static DUMPED: AtomicBool = AtomicBool::new(false);
201            let n_call = CALL_COUNT.fetch_add(1, Ordering::SeqCst);
202            let skip: usize = std::env::var("POUNCE_DBG_KKT_DUMP_SKIP")
203                .ok()
204                .and_then(|s| s.parse().ok())
205                .unwrap_or(0);
206            // The trigger path is only *read* (never mutated), and the
207            // one-shot "disable after firing" is an atomic claim via
208            // `claim_kkt_dump`. This replaces the previous
209            // `unsafe std::env::remove_var` step, whose SAFETY note
210            // ("single-threaded … main IPM thread") was false: pounce-feral
211            // runs rayon-parallel outer solves (feral lib.rs:159-168), so
212            // multiple `multi_solve` calls can be in flight at once and
213            // `remove_var` would race any concurrent environment read
214            // (undefined behavior). The atomic guarantees exactly one
215            // thread writes the dump while the process environment is left
216            // untouched.
217            if let Ok(path) = std::env::var("POUNCE_DBG_KKT_DUMP") {
218                if claim_kkt_dump(n_call, skip, &DUMPED) {
219                    tracing::warn!(
220                        target: "pounce::linsol",
221                        "POUNCE_DBG_KKT_DUMP is deprecated; prefer `--dump kkt:<iter-spec>` (see pounce --help)"
222                    );
223                    use std::io::Write;
224                    if let Ok(mut f) = std::fs::File::create(&path) {
225                        let dim = self.dim as u64;
226                        let nnz = self.nonzeros_triplet as u64;
227                        let nrhs64 = nrhs as u64;
228                        let _ = f.write_all(&dim.to_le_bytes());
229                        let _ = f.write_all(&nnz.to_le_bytes());
230                        let _ = f.write_all(&nrhs64.to_le_bytes());
231                        for &i in &self.airn {
232                            let _ = f.write_all(&(i as i64).to_le_bytes());
233                        }
234                        for &j in &self.ajcn {
235                            let _ = f.write_all(&(j as i64).to_le_bytes());
236                        }
237                        for &v in vals {
238                            let _ = f.write_all(&v.to_le_bytes());
239                        }
240                        for &v in &*rhs_vals {
241                            let _ = f.write_all(&v.to_le_bytes());
242                        }
243                        let _ = f.flush();
244                    }
245                }
246            }
247        }
248
249        // Push values + (optional) scaling into the backend.
250        let mut new_matrix = new_matrix;
251        if new_matrix || self.just_switched_on_scaling {
252            self.give_matrix_to_solver(true, vals);
253            new_matrix = true;
254        }
255
256        // Apply scaling to RHS columns (multiply by `s_i` per row).
257        if self.use_scaling {
258            for irhs in 0..nrhs as usize {
259                let base = irhs * self.dim as usize;
260                for i in 0..self.dim as usize {
261                    rhs_vals[base + i] *= self.scaling_factors[i];
262                }
263            }
264        }
265
266        // Backend solve, with `CALL_AGAIN` retry (MA57 grow path).
267        // Pre-resolve the index arrays into local pointers so we can
268        // hand them to the backend without re-borrowing `self`.
269        let status = loop {
270            let (ia_ptr, ia_len, ja_ptr, ja_len) = match self.converter.as_ref() {
271                None => (
272                    self.airn.as_ptr(),
273                    self.airn.len(),
274                    self.ajcn.as_ptr(),
275                    self.ajcn.len(),
276                ),
277                Some(c) => (c.ia().as_ptr(), c.ia().len(), c.ja().as_ptr(), c.ja().len()),
278            };
279            // SAFETY: the slices live in `self.airn/ajcn` or in the
280            // converter, both owned by `self`; the pointers are valid
281            // for the duration of this `multi_solve` call.
282            let (ia, ja) = unsafe {
283                (
284                    std::slice::from_raw_parts(ia_ptr, ia_len),
285                    std::slice::from_raw_parts(ja_ptr, ja_len),
286                )
287            };
288            let s = self.backend.multi_solve(
289                new_matrix,
290                ia,
291                ja,
292                nrhs,
293                rhs_vals,
294                check_neg_evals,
295                number_of_neg_evals,
296            );
297            if s == ESymSolverStatus::CallAgain {
298                self.give_matrix_to_solver(false, vals);
299                continue;
300            }
301            break s;
302        };
303
304        if status == ESymSolverStatus::Success && self.use_scaling {
305            // Solution comes back in scaled coordinates `x' = diag(s)
306            // x`; restore by another diag(s) multiply (since the
307            // scaled system is `(D A D)(D^-1 x) = D b` and we passed
308            // `D b`, the backend returns `D^-1 x`, hence multiply by
309            // `D` once more — see cpp:286-289).
310            for irhs in 0..nrhs as usize {
311                let base = irhs * self.dim as usize;
312                for i in 0..self.dim as usize {
313                    rhs_vals[base + i] *= self.scaling_factors[i];
314                }
315            }
316        }
317
318        status
319    }
320
321    /// Push `vals` (triplet-format) into the backend in the right
322    /// layout, optionally computing scaling factors and applying the
323    /// symmetric scale. Mirrors `TSymLinearSolver::GiveMatrixToSolver`.
324    fn give_matrix_to_solver(&mut self, new_matrix: bool, vals: &[Number]) {
325        // For triplet-format backends we write directly into the
326        // backend's array; for CSR backends we marshal via a temporary
327        // and call `convert_values`.
328        if self.matrix_format == EMatrixFormat::TripletFormat && !self.use_scaling {
329            let pa = self.backend.values_array_mut();
330            pa[..self.nonzeros_triplet as usize]
331                .copy_from_slice(&vals[..self.nonzeros_triplet as usize]);
332            return;
333        }
334
335        // Stage values in a local buffer so we can scale before
336        // shipping to the backend.
337        let mut atriplet: Vec<Number> = vals[..self.nonzeros_triplet as usize].to_vec();
338
339        if self.use_scaling {
340            if new_matrix || self.just_switched_on_scaling {
341                // `use_scaling` implies the scaling method is set
342                // (checked at construction time).
343                let Some(method) = self.scaling_method.as_mut() else {
344                    unreachable!("use_scaling without a scaling method")
345                };
346                let ok = method.compute_sym_t_scaling_factors(
347                    self.dim,
348                    self.nonzeros_triplet,
349                    &self.airn,
350                    &self.ajcn,
351                    &atriplet,
352                    &mut self.scaling_factors,
353                );
354                assert!(ok, "scaling method failed");
355                self.just_switched_on_scaling = false;
356            }
357            for (i, a) in atriplet
358                .iter_mut()
359                .enumerate()
360                .take(self.nonzeros_triplet as usize)
361            {
362                let r = (self.airn[i] - 1) as usize;
363                let c = (self.ajcn[i] - 1) as usize;
364                *a *= self.scaling_factors[r] * self.scaling_factors[c];
365            }
366        }
367
368        if self.matrix_format == EMatrixFormat::TripletFormat {
369            let pa = self.backend.values_array_mut();
370            pa[..self.nonzeros_triplet as usize].copy_from_slice(&atriplet);
371        } else {
372            let Some(conv) = self.converter.as_ref() else {
373                unreachable!("non-triplet matrix_format requires a converter");
374            };
375            let pa = self.backend.values_array_mut();
376            conv.convert_values(&atriplet, &mut pa[..self.nonzeros_compressed as usize]);
377        }
378    }
379
380    /// Pass-through to the backend's diagnostic factor-pattern
381    /// accessor. Returns `None` when the backend does not expose its
382    /// factor data (e.g. MA57). Consumed by the `--dump kkt:*+L` path.
383    pub fn factor_pattern(&self, want_values: bool) -> Option<FactorPattern> {
384        self.backend.factor_pattern(want_values)
385    }
386}
387
388impl SymLinearSolver for TSymLinearSolver {
389    fn number_of_neg_evals(&self) -> Index {
390        self.backend.number_of_neg_evals()
391    }
392
393    /// Mirrors upstream's `IncreaseQuality`: switching scaling on at
394    /// the wrapper level (`linear_scaling_on_demand=true` path) takes
395    /// precedence over asking the backend for tighter pivoting.
396    fn increase_quality(&mut self) -> bool {
397        if self.scaling_method.is_some() && !self.use_scaling && self.linear_scaling_on_demand {
398            self.use_scaling = true;
399            self.just_switched_on_scaling = true;
400            return true;
401        }
402        self.backend.increase_quality()
403    }
404
405    fn provides_inertia(&self) -> bool {
406        self.backend.provides_inertia()
407    }
408}
409
/// Decide whether this `multi_solve` call performs the deprecated one-shot
/// `POUNCE_DBG_KKT_DUMP` dump.
///
/// Calls with `n_call < skip` never claim and leave `dumped` untouched.
/// Among the remaining calls, only the one whose `swap` flips `dumped`
/// from `false` to `true` wins. Every later or concurrent caller gets
/// `false`. This holds across threads, which matters because pounce-feral
/// runs rayon-parallel outer solves.
///
/// This is a free function so the one-shot logic can be unit-tested with
/// a local `AtomicBool` (the live call site passes a `static`). It also
/// makes explicit that disabling the dump is an atomic claim, not a
/// mutation of the process environment. The earlier
/// `unsafe std::env::remove_var` approach was unsound under concurrent
/// environment reads.
fn claim_kkt_dump(n_call: usize, skip: usize, dumped: &std::sync::atomic::AtomicBool) -> bool {
    use std::sync::atomic::Ordering::SeqCst;
    // `&&` short-circuits, so the flag is only touched once the skip
    // count has been reached.
    n_call >= skip && !dumped.swap(true, SeqCst)
}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scaling::IdentityScalingMethod;
    use std::sync::{Arc, Mutex};

    /// Observations recorded by [`MockBackend`]. They are shared with the
    /// test through `Arc<Mutex<_>>` because the backend itself is moved into
    /// the wrapper as a trait object and cannot be inspected afterwards.
    #[derive(Default)]
    struct Probe {
        /// Number of backend `multi_solve` calls, including ones answered
        /// with `CallAgain`.
        solve_calls: u32,
        /// `new_matrix` flag of the most recent backend solve.
        last_new_matrix: Option<bool>,
        /// Backend values array at the time of the most recent solve.
        last_a: Option<Vec<Number>>,
        /// RHS handed to the most recent solve, before the mock overwrote it.
        last_rhs: Option<Vec<Number>>,
        /// Number of `increase_quality` requests that reached the backend.
        increase_quality_calls: u32,
    }

    /// Mock triplet-format backend. It exposes the values array, records
    /// every solve in a shared [`Probe`], optionally answers `CallAgain` a
    /// few times, and then returns a hand-rolled solution. This lets us
    /// exercise the wrapper without an FFI dep.
    #[derive(Default)]
    struct MockBackend {
        dim: Index,
        a: Vec<Number>,
        canned_solution: Vec<Number>,
        neg_evals: Index,
        max_increase_quality_calls: u32,
        /// How many leading solves answer `CallAgain` before succeeding.
        call_again_budget: u32,
        probe: Arc<Mutex<Probe>>,
    }

    impl SparseSymLinearSolverInterface for MockBackend {
        fn initialize_structure(
            &mut self,
            dim: Index,
            nz: Index,
            _ia: &[Index],
            _ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim;
            self.a = vec![0.0; nz as usize];
            ESymSolverStatus::Success
        }
        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.a
        }
        fn multi_solve(
            &mut self,
            new_matrix: bool,
            _ia: &[Index],
            _ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            _check: bool,
            _nev: Index,
        ) -> ESymSolverStatus {
            {
                let mut probe = self.probe.lock().unwrap();
                probe.solve_calls += 1;
                probe.last_new_matrix = Some(new_matrix);
                probe.last_a = Some(self.a.clone());
                probe.last_rhs = Some(rhs_vals.to_vec());
            }
            if self.call_again_budget > 0 {
                self.call_again_budget -= 1;
                return ESymSolverStatus::CallAgain;
            }
            assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
            for irhs in 0..nrhs as usize {
                let base = irhs * self.dim as usize;
                rhs_vals[base..base + self.dim as usize].copy_from_slice(&self.canned_solution);
            }
            ESymSolverStatus::Success
        }
        fn number_of_neg_evals(&self) -> Index {
            self.neg_evals
        }
        fn increase_quality(&mut self) -> bool {
            let mut probe = self.probe.lock().unwrap();
            probe.increase_quality_calls += 1;
            probe.increase_quality_calls <= self.max_increase_quality_calls
        }
        fn provides_inertia(&self) -> bool {
            true
        }
        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// Scaling method returning the fixed factors `s = (2, 3)`.
    struct DiagTwoThree;
    impl TSymScalingMethod for DiagTwoThree {
        fn compute_sym_t_scaling_factors(
            &mut self,
            _n: Index,
            _nnz: Index,
            _airn: &[Index],
            _ajcn: &[Index],
            _a: &[Number],
            scaling_factors: &mut [Number],
        ) -> bool {
            scaling_factors[0] = 2.0;
            scaling_factors[1] = 3.0;
            true
        }
    }

    fn make_2x2_indef_pattern() -> ([Index; 3], [Index; 3]) {
        ([1, 2, 2], [1, 1, 2])
    }

    /// Wrap `backend` and pin the 2x2 lower-triangular pattern.
    fn initialized_solver(
        backend: MockBackend,
        scaling: Option<Box<dyn TSymScalingMethod>>,
        on_demand: bool,
    ) -> TSymLinearSolver {
        let mut solver = TSymLinearSolver::new(Box::new(backend), scaling, on_demand);
        let (irn, jcn) = make_2x2_indef_pattern();
        assert_eq!(
            solver.initialize_structure(2, &irn, &jcn),
            ESymSolverStatus::Success
        );
        solver
    }

    #[test]
    fn unscaled_triplet_solve_passes_values_through() {
        let probe = Arc::new(Mutex::new(Probe::default()));
        let backend = MockBackend {
            canned_solution: vec![10.0, 20.0],
            probe: Arc::clone(&probe),
            ..Default::default()
        };
        let mut solver = initialized_solver(backend, None, false);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [3.0, 4.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        // Mock writes its canned solution.
        assert_eq!(rhs, [10.0, 20.0]);
        {
            let p = probe.lock().unwrap();
            assert_eq!(p.solve_calls, 1);
            assert_eq!(p.last_new_matrix, Some(true));
            assert_eq!(p.last_a.as_deref(), Some(&vals[..]));
            assert_eq!(p.last_rhs.as_deref(), Some(&[3.0, 4.0][..]));
        }
        assert!(solver.provides_inertia());
    }

    #[test]
    fn identity_scaling_does_not_change_values() {
        let probe = Arc::new(Mutex::new(Probe::default()));
        let backend = MockBackend {
            canned_solution: vec![1.0, 1.0],
            probe: Arc::clone(&probe),
            ..Default::default()
        };
        // linear_scaling_on_demand=false → scaling active from the
        // first solve.
        let mut solver =
            initialized_solver(backend, Some(Box::new(IdentityScalingMethod)), false);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        // Identity scaling: the backend received the original values and
        // RHS unchanged, and the canned solution survives the unscale step.
        assert_eq!(rhs, [1.0, 1.0]);
        let p = probe.lock().unwrap();
        assert_eq!(p.last_a.as_deref(), Some(&vals[..]));
        assert_eq!(p.last_rhs.as_deref(), Some(&[4.0, 5.0][..]));
    }

    #[test]
    fn nontrivial_scaling_premultiplies_matrix_and_postmultiplies_solution() {
        // With D = diag(2, 3), the backend must see (D A D) and RHS (D b).
        // Its answer (D^-1 x) is unscaled by D once more to recover x.
        let probe = Arc::new(Mutex::new(Probe::default()));
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            probe: Arc::clone(&probe),
            ..Default::default()
        };
        let mut solver = initialized_solver(backend, Some(Box::new(DiagTwoThree)), false);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [2.0 * 7.0, 3.0 * 11.0]);
        let p = probe.lock().unwrap();
        // Entries (1,1), (2,1), (2,2) scaled by s_r * s_c = 4, 6, 9.
        assert_eq!(p.last_a.as_deref(), Some(&[8.0, 6.0, 27.0][..]));
        assert_eq!(p.last_rhs.as_deref(), Some(&[8.0, 15.0][..]));
    }

    #[test]
    fn nontrivial_scaling_handles_multiple_rhs_columns() {
        let probe = Arc::new(Mutex::new(Probe::default()));
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            probe: Arc::clone(&probe),
            ..Default::default()
        };
        let mut solver = initialized_solver(backend, Some(Box::new(DiagTwoThree)), false);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0, 6.0, 7.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 2, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        // Every column is scaled on the way in and unscaled on the way out.
        assert_eq!(rhs, [14.0, 33.0, 14.0, 33.0]);
        let p = probe.lock().unwrap();
        assert_eq!(p.last_rhs.as_deref(), Some(&[8.0, 15.0, 12.0, 21.0][..]));
    }

    #[test]
    fn call_again_regives_matrix_and_retries_without_rescaling_rhs() {
        let probe = Arc::new(Mutex::new(Probe::default()));
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            call_again_budget: 1,
            probe: Arc::clone(&probe),
            ..Default::default()
        };
        let mut solver = initialized_solver(backend, Some(Box::new(DiagTwoThree)), false);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [14.0, 33.0]);
        let p = probe.lock().unwrap();
        // One CallAgain, then the successful retry.
        assert_eq!(p.solve_calls, 2);
        assert_eq!(p.last_new_matrix, Some(true));
        // Values re-sent (still scaled); RHS scaled exactly once.
        assert_eq!(p.last_a.as_deref(), Some(&[8.0, 6.0, 27.0][..]));
        assert_eq!(p.last_rhs.as_deref(), Some(&[8.0, 15.0][..]));
    }

    #[test]
    fn on_demand_scaling_switch_forces_refactor_with_scaled_system() {
        let probe = Arc::new(Mutex::new(Probe::default()));
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            probe: Arc::clone(&probe),
            ..Default::default()
        };
        let mut solver = initialized_solver(backend, Some(Box::new(DiagTwoThree)), true);
        let vals = [2.0, 1.0, 3.0];

        // Scaling still off: the backend sees the raw system.
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [7.0, 11.0]);
        {
            let p = probe.lock().unwrap();
            assert_eq!(p.last_a.as_deref(), Some(&vals[..]));
            assert_eq!(p.last_rhs.as_deref(), Some(&[4.0, 5.0][..]));
        }

        // Switch scaling on. The next solve must refactor with the scaled
        // matrix even though the caller passes new_matrix=false.
        assert!(solver.increase_quality());
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, false, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [14.0, 33.0]);
        let p = probe.lock().unwrap();
        assert_eq!(p.last_new_matrix, Some(true));
        assert_eq!(p.last_a.as_deref(), Some(&[8.0, 6.0, 27.0][..]));
        assert_eq!(p.last_rhs.as_deref(), Some(&[8.0, 15.0][..]));
    }

    #[test]
    fn increase_quality_switches_on_scaling_first() {
        let probe = Arc::new(Mutex::new(Probe::default()));
        let backend = MockBackend {
            canned_solution: vec![0.0, 0.0],
            max_increase_quality_calls: 1,
            probe: Arc::clone(&probe),
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(
            Box::new(backend),
            Some(Box::new(IdentityScalingMethod)),
            true, // on demand
        );
        // First IncreaseQuality flips on scaling and does NOT touch the
        // backend.
        assert!(solver.increase_quality());
        assert_eq!(probe.lock().unwrap().increase_quality_calls, 0);
        // Subsequent requests go to the backend, which caps at one.
        assert!(solver.increase_quality());
        assert!(!solver.increase_quality());
        assert_eq!(probe.lock().unwrap().increase_quality_calls, 2);
    }

    #[test]
    fn increase_quality_without_scaling_goes_straight_to_backend() {
        let probe = Arc::new(Mutex::new(Probe::default()));
        let backend = MockBackend {
            max_increase_quality_calls: 1,
            probe: Arc::clone(&probe),
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(Box::new(backend), None, false);
        assert!(solver.increase_quality());
        // Backend caps at 1; second call returns false.
        assert!(!solver.increase_quality());
        assert_eq!(probe.lock().unwrap().increase_quality_calls, 2);
    }

    /// The deprecated KKT-dump gate must fire exactly once, on the first
    /// call at or after the skip count, and never re-fire. It must do so
    /// without mutating the environment (L9: replaces the old
    /// `unsafe env::remove_var`).
    #[test]
    fn claim_kkt_dump_is_one_shot_after_skip() {
        use std::sync::atomic::AtomicBool;
        let dumped = AtomicBool::new(false);
        // Below the skip count: never claims.
        assert!(!super::claim_kkt_dump(0, 2, &dumped));
        assert!(!super::claim_kkt_dump(1, 2, &dumped));
        // First call at/after skip claims exactly once; later calls no-op.
        assert!(super::claim_kkt_dump(2, 2, &dumped));
        assert!(!super::claim_kkt_dump(3, 2, &dumped));
        assert!(!super::claim_kkt_dump(4, 2, &dumped));
    }

    /// The one-shot claim must be race-free across threads. The old
    /// `unsafe env::remove_var` disable step could not provide this under
    /// pounce-feral's rayon-parallel outer solves. Exactly one of many
    /// concurrent callers may win the claim.
    #[test]
    fn claim_kkt_dump_claims_exactly_once_under_concurrency() {
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
        use std::sync::Barrier;
        let dumped = Arc::new(AtomicBool::new(false));
        let wins = Arc::new(AtomicUsize::new(0));
        let n_threads = 32;
        let barrier = Arc::new(Barrier::new(n_threads));
        let mut handles = Vec::new();
        for _ in 0..n_threads {
            let d = Arc::clone(&dumped);
            let w = Arc::clone(&wins);
            let b = Arc::clone(&barrier);
            handles.push(std::thread::spawn(move || {
                // Maximize contention: all threads reach the claim together.
                b.wait();
                if super::claim_kkt_dump(0, 0, &d) {
                    w.fetch_add(1, Ordering::SeqCst);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(
            wins.load(Ordering::SeqCst),
            1,
            "exactly one thread must claim the one-shot KKT dump"
        );
    }
}