Skip to main content

gam_solve/arrow_schur/
solve_options.rs

1//! Solver configuration and the batched block-solve abstraction: BA solver
2//! modes, PCG/trust-region/mixed-precision/proximal options, diagnostics, and
3//! the [`BatchedBlockSolver`] trait with its CPU implementation.
4
5use super::*;
6
/// Strategy used to solve the reduced (shared `β`) Schur system of a BA-style
/// arrow problem.
///
/// [`ArrowSolverMode::automatic`] picks between the dense and iterative
/// variants by the shared dimension `K` (threshold `DIRECT_SOLVE_MAX_K`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowSolverMode {
    /// Dense reduced-camera-system solve, as in BA: eliminate every
    /// per-point/per-row block, form the reduced system, and Cholesky-factor
    /// it. This is the Ceres/g2o default for modest camera counts and is what
    /// `automatic` selects for `K <= DIRECT_SOLVE_MAX_K`.
    ///
    /// **GPU support: ✓** — needs dense H_ββ and dense per-row H_tβ slabs.
    Direct,
    /// Square-Root BA (Demmel/Gao/Gu et al., CVPR 2021). Schur terms are built
    /// as `(L_i^-1 H_tβ_i)^T (L_i^-1 H_tβ_i)` from the per-row square-root
    /// factor `L_i`, so the explicit `H_tt^-1 H_tβ` products never appear.
    /// Preferred direct path for single-precision assembly or poorly
    /// conditioned row blocks.
    ///
    /// **GPU support: ✓** — needs dense H_ββ and dense per-row H_tβ slabs.
    SqrtBA,
    /// Inexact Schur PCG from "Bundle Adjustment in the Large" (Agarwal et
    /// al.): the reduced system is solved approximately by PCG with a Jacobi
    /// Schur preconditioner, so the dense `K × K` factorization is never
    /// formed. Intended for SAE-manifold scale shared systems.
    ///
    /// **GPU support: CPU only** until the row-procedural H_tβ GPU PCG path
    /// (issue #288 Part B) lands. The topology selector must not request this
    /// mode through the GPU entry point: `solve_arrow_newton_step` returns
    /// `GpuRequiresDenseSystem` for matrix-free systems, and the wrapper in
    /// `solver/gpu/arrow_schur_gpu.rs` reroutes those to CPU InexactPCG. Once
    /// the row-procedural H_tβ kernel and the boxed GPU matvec backend in
    /// `run_pcg_with_preconditioner` are wired, the GPU PCG path is expected to
    /// replace the CPU one at K ≥ 5000.
    InexactPCG,
}
41
42impl ArrowSolverMode {
43    /// BA-size heuristic: dense RCS for modest `K`, inexact Schur PCG for
44    /// large shared systems. This follows Agarwal et al.'s direct-vs-iterative
45    /// split for large BA, mapped from cameras to decoder coefficients.
46    pub const fn automatic(k: usize) -> Self {
47        if k <= DIRECT_SOLVE_MAX_K {
48            Self::Direct
49        } else {
50            Self::InexactPCG
51        }
52    }
53
54    /// Square-Root BA is the direct-solve stability mode for future f32
55    /// callers. Large `K` still routes to inexact PCG because dense Schur
56    /// storage dominates precision concerns at that scale.
57    pub const fn automatic_for_single_precision(k: usize) -> Self {
58        if k <= DIRECT_SOLVE_MAX_K {
59            Self::SqrtBA
60        } else {
61            Self::InexactPCG
62        }
63    }
64}
65
/// Why the (Steihaug-)CG loop terminated.
///
/// `Default` yields [`PcgStopReason::Converged`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PcgStopReason {
    /// The residual dropped below the relative tolerance.
    #[default]
    Converged,
    /// `max_iterations` was reached before convergence.
    MaxIter,
    /// The iterate reached the trust-region boundary (Steihaug boundary
    /// projection).
    TrustRegion,
    /// Negative curvature was found during an unbounded solve.
    Indefinite,
    /// After an update the preconditioned residual was non-positive or not
    /// finite.
    Stagnation,
}
81
82/// Per-solve instrumentation counters returned alongside the PCG solution.
83///
84/// All fields default to zero; callers that do not need diagnostics simply
85/// ignore the value. The struct is Copy so passing it through return tuples
86/// is zero-overhead.
87#[derive(Debug, Default, Clone, Copy)]
88pub struct PcgDiagnostics {
89    /// Number of CG iterations executed.
90    pub iterations: usize,
91    /// Total calls to the Schur matvec A·p.
92    pub matvec_calls: usize,
93    /// Total calls to the preconditioner M^{-1}·r.
94    pub precond_apply_calls: usize,
95    /// Number of times the LM ridge was escalated before a successful factor.
96    pub ridge_escalations: usize,
97    /// Relative residual at termination; 0.0 when the RHS was zero.
98    pub final_relative_residual: f64,
99    /// Why the loop stopped.
100    pub stopping_reason: PcgStopReason,
101    /// Mixed-precision certificate outcome for this solve.
102    pub mixed_precision_status: MixedPrecisionStatus,
103    /// True only when the reduced-Schur solve was **actually executed on the
104    /// device**: either the fully device-resident batched Arrow-Schur Direct
105    /// sequence (`try_device_arrow_direct` → `solve_arrow_newton_step`) or the
106    /// device-resident matrix-free SAE PCG (`solve_sae_matrix_free_pcg`, which
107    /// runs the matvec in CUDA kernels over device-resident frames). It is NOT
108    /// set merely because a GPU runtime exists and a dispatch gate fired (#1209).
109    pub used_device_arrow: bool,
110    /// True when a reduced-Schur matvec backend was injected through
111    /// `maybe_inject_gpu_schur_matvec` but the matvec itself runs as a
112    /// **host** (CPU Rust/Rayon) procedural closure — both the matrix-free
113    /// `build_row_procedural_matvec` branch and the `cuda::build_schur_matvec_backend`
114    /// branch return host closures that evaluate `Σ_i Y_iᵀ(Y_i x)` on the CPU,
115    /// even when a CUDA context was opened to build the per-row factors. This
116    /// path must NOT report `used_device_arrow`: the arithmetic is host-side
117    /// (#1209). Distinct field so perf accounting never mistakes a host
118    /// procedural matvec for true device execution.
119    pub injected_host_procedural_matvec: bool,
120}
121
/// Result of an opt-in mixed-precision arrow solve.
///
/// `Default` yields [`MixedPrecisionStatus::Off`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MixedPrecisionStatus {
    /// Mixed precision was not requested, or this solve mode cannot use it.
    #[default]
    Off,
    /// The f32 factor solve was refined until the f64 backward-error
    /// certificate held; `refinement_steps` records how many refinements ran.
    Certified { refinement_steps: usize },
    /// Mixed precision was rejected by the kappa gate or by the solve shape,
    /// and the f64 path ran instead. The reason is logged at `info` level when
    /// this fallback fires.
    F64Fallback,
}
134
/// Controls for BA's inexact PCG solve of the reduced camera system.
///
/// The defaults follow the loose inner tolerances of inexact-step LM in
/// "Bundle Adjustment in the Large". The Schur system only needs to be solved
/// well enough to yield a useful trust-region step; the outer LM iteration
/// corrects whatever error remains.
#[derive(Debug, Clone)]
pub struct ArrowPcgOptions {
    /// Upper bound on PCG iterations per solve.
    pub max_iterations: usize,
    /// Relative residual tolerance that counts as convergence.
    pub relative_tolerance: f64,
}
146
147impl Default for ArrowPcgOptions {
148    fn default() -> Self {
149        Self {
150            max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
151            relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
152        }
153    }
154}
155
/// Trust-region controls for Steihaug-CG on the reduced BA system.
///
/// This is the Ceres-style guard around LM. `ridge_t`/`ridge_beta` supply the
/// Levenberg damping. The trust radius bounds the reduced shared step in
/// Euclidean β coordinates, with Steihaug's truncated-CG rules deciding when
/// to stop on a boundary hit or on negative curvature.
#[derive(Debug, Clone)]
pub struct ArrowTrustRegionOptions {
    /// Trust radius on the reduced β step, measured in Euclidean β coordinates.
    pub radius: f64,
    /// Relative residual tolerance for the Steihaug-CG loop.
    pub steihaug_relative_tolerance: f64,
    /// Upper bound on Steihaug-CG iterations.
    pub max_iterations: usize,
}
168
169impl Default for ArrowTrustRegionOptions {
170    fn default() -> Self {
171        Self {
172            radius: DEFAULT_TRUST_REGION_RADIUS,
173            steihaug_relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
174            max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
175        }
176    }
177}
178
/// Opt-in Carson–Higham mixed-precision refinement for dense arrow solves.
///
/// The default is [`ArrowSolvePrecisionPolicy::F64Only`], so exact f64 solves
/// stay the default.
///
/// [`ArrowSolvePrecisionPolicy::CertifiedMixed`] works as follows:
/// - it stores f32 copies of the per-row Cholesky factors and of the dense
///   Schur factor;
/// - it solves corrections in f32;
/// - it recomputes the residual in f64 against the original arrow blocks.
///
/// The refinement certificate is the standard normwise backward error
///
/// `||r||_inf / (||H||_inf ||x||_inf + ||b||_inf) <= residual_relative_tolerance`.
///
/// The kappa gate enforces `kappa_estimate * u_f32 < kappa_unit_roundoff_margin`.
/// When it fails, the solve reports [`MixedPrecisionStatus::F64Fallback`] and
/// logs the reason before using the f64 path.
///
/// `Default` is derived with `#[default]` on `F64Only`. This matches how the
/// other enums in this module declare their defaults and replaces an
/// equivalent hand-written impl.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum ArrowSolvePrecisionPolicy {
    /// Plain f64 solves; mixed precision is never attempted.
    #[default]
    F64Only,
    /// Certified f32-factor / f64-residual iterative refinement.
    CertifiedMixed {
        /// Maximum number of f32 refinement steps.
        max_refinement_steps: usize,
        /// Backward-error threshold the certificate must meet.
        residual_relative_tolerance: f64,
        /// Bound that `kappa_estimate * u_f32` must stay below for the kappa
        /// gate to admit mixed precision.
        kappa_unit_roundoff_margin: f64,
    },
}
207
208impl ArrowSolvePrecisionPolicy {
209    pub fn certified_mixed() -> Self {
210        Self::CertifiedMixed {
211            max_refinement_steps: DEFAULT_MIXED_PRECISION_MAX_REFINEMENTS,
212            residual_relative_tolerance: DEFAULT_MIXED_PRECISION_CERTIFICATE_TOLERANCE,
213            kappa_unit_roundoff_margin: DEFAULT_MIXED_PRECISION_KAPPA_MARGIN,
214        }
215    }
216
217    pub(crate) fn is_enabled(self) -> bool {
218        matches!(self, ArrowSolvePrecisionPolicy::CertifiedMixed { .. })
219    }
220}
221
222/// Complete BA Schur solve options.
223///
224/// Use [`ArrowSolveOptions::automatic`] for normal latent-coordinate fits;
225/// use [`ArrowSolveOptions::sqrt_ba`] when the assembler has single-precision
226/// row blocks or an ill-conditioned gauge; use [`ArrowSolveOptions::inexact_pcg`]
227/// for SAE-manifold scale `K`.
228#[derive(Clone)]
229pub struct ArrowSolveOptions {
230    pub mode: ArrowSolverMode,
231    pub pcg: ArrowPcgOptions,
232    pub trust_region: ArrowTrustRegionOptions,
233    /// Row chunk size for streaming direct/Square-Root Schur assembly.
234    pub streaming_chunk_size: Option<usize>,
235    /// Use the Riemannian latent projection before the Schur reduction. The
236    /// reduced Steihaug solve itself remains in Euclidean β coordinates.
237    pub riemannian_trust_region: bool,
238    /// Optional GPU-backed Schur matvec for CPU-driven `InexactPCG` at K ≥ 5000.
239    ///
240    /// When set, `run_pcg_with_preconditioner` delegates each `S·p` call to
241    /// this closure instead of the CPU `schur_matvec`. Constructed by
242    /// `crate::gpu_kernels::arrow_schur::gpu_schur_matvec_backend` when `cuda_selected()`
243    /// and the system has dense per-row H_tβ slabs. `None` means CPU-only PCG.
244    pub gpu_matvec: Option<GpuSchurMatvec>,
245    /// Skip the ill-conditioning *rejection* (the κ-based
246    /// [`ArrowSchurError::PerRowFactorIllConditioned`] per-row guard and the
247    /// matching reduced-Schur κ guard) while still requiring genuine positive
248    /// definiteness (a non-PD Cholesky pivot still errors).
249    ///
250    /// The κ guards exist to protect the accuracy of the Newton *step*: a
251    /// barely-PD `H_tt^(i)` or an over-conditioned reduced Schur yields an
252    /// inaccurate `Δβ`/`Δt`. Evidence-only callers
253    /// (e.g. `SaeManifoldTerm::reml_criterion_with_cache`) do not consume the
254    /// step — they need only the factor cache for the log-determinant
255    /// (`½log|H|`, exact from `diag(L)` regardless of κ) and the selected-inverse
256    /// traces. For those callers the κ rejection is a false abort when ρ sweeps
257    /// to extreme values, so this flag lifts it and hands the
258    /// "is this step trustworthy" decision back to the caller.
259    ///
260    /// Default `false`: ordinary solves keep the full guard.
261    pub tolerate_ill_conditioning: bool,
262    /// Arrow solve precision policy. Default is f64-only.
263    pub solve_precision: ArrowSolvePrecisionPolicy,
264    /// Optional spectral positive-definiteness floor on the *reduced Schur
265    /// complement* `S = H_ββ + ridge_β·I − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)`,
266    /// as a relative fraction of `S`'s largest eigenvalue.
267    ///
268    /// `None` (default) keeps the strict contract: a non-PD `S` errors as
269    /// [`ArrowSchurError::SchurFactorFailed`] so the LM outer loop lifts
270    /// `ridge_beta` globally and re-forms `S`.
271    ///
272    /// `Some(floor)` engages the #1026 SAE co-collapse cure on the SOLVE path:
273    /// when the reduced Schur Cholesky refuses (collapsed atoms drive a per-row
274    /// `H_tt` near-singular, so the accumulated `(H_tt)⁻¹` over-subtracts `S`
275    /// into an INDEFINITE matrix), instead of rejecting and over-damping every
276    /// β direction with a global ridge, symmetric-eigendecompose `S` and clamp
277    /// every eigenvalue UP to `floor·max(λ)`. This is Levenberg–Marquardt
278    /// restricted to exactly the indefinite/collapsed subspace: the
279    /// well-conditioned β directions (`λ ≫ floor·max λ`) are untouched and the
280    /// step in those directions is the exact Newton step, while only the
281    /// collapsed directions receive the minimal damping needed for a PD solve.
282    /// The inner Newton then makes a real descent step rather than crawling
283    /// behind an inflated global ridge. Mirrors the per-row spectral floor the
284    /// evidence path uses for #1377/#1117/#1118
285    /// ([`super::factorization::factor_spectral_deflated_evidence_row`]); the
286    /// difference is the floored value — a small positive `floor·max λ`
287    /// (Tikhonov) for the solve, vs unit stiffness `+1` (`log 1 = 0`) for the
288    /// evidence log-det.
289    ///
290    /// Only consulted by the dense Direct / SqrtBA reduced solve (the only
291    /// caller of [`super::reduced_solve::solve_dense_reduced_system`]); the
292    /// InexactPCG path is unaffected.
293    pub schur_pd_floor: Option<f64>,
294}
295
296impl std::fmt::Debug for ArrowSolveOptions {
297    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298        f.debug_struct("ArrowSolveOptions")
299            .field("mode", &self.mode)
300            .field("pcg", &self.pcg)
301            .field("trust_region", &self.trust_region)
302            .field("streaming_chunk_size", &self.streaming_chunk_size)
303            .field("riemannian_trust_region", &self.riemannian_trust_region)
304            .field("gpu_matvec", &self.gpu_matvec.is_some())
305            .field("tolerate_ill_conditioning", &self.tolerate_ill_conditioning)
306            .field("solve_precision", &self.solve_precision)
307            .field("schur_pd_floor", &self.schur_pd_floor)
308            .finish()
309    }
310}
311
/// Globalization guard for non-convex arrow-Schur inner steps.
///
/// The raw Schur solve is an exact Newton step, and for non-convex analytic
/// penalties full Newton can cycle. This controller therefore adds a proximal
/// LM shift `mu I` to both blocks and only accepts trial points that pass the
/// Armijo decrease test.
#[derive(Debug, Clone)]
pub struct ArrowProximalCorrectionOptions {
    /// Proximal shift `mu` used for the first attempt.
    pub initial_ridge: f64,
    /// Growth factor for `mu` between attempts. NOTE(review): presumably
    /// multiplicative after a rejected trial — confirm against the controller.
    pub ridge_growth: f64,
    /// Maximum number of damped trial steps.
    pub max_attempts: usize,
    /// Armijo sufficient-decrease constant `c1`.
    pub armijo_c1: f64,
    /// Gradient tolerance. NOTE(review): presumably the stationarity threshold
    /// below which no correction is attempted — confirm against the controller.
    pub gradient_tolerance: f64,
    /// Relative objective resolution below which the proximal correction
    /// declares convergence instead of failing.
    ///
    /// Close to a stationary point, the largest decrease the damped Newton
    /// model can still deliver shrinks to the floating-point resolution of the
    /// objective itself. As the proximal ridge `μ → μ_max`, the accepted step
    /// length is `O(‖g‖ / μ)`, so the realised change in the objective drops
    /// below `rel_tol · (|f| + 1)`. The Armijo test then compares two values
    /// that differ only by rounding noise, and no further productive decrease
    /// is possible.
    ///
    /// Instead of raising `AdaptiveCorrectionFailed`, the loop returns the
    /// incumbent state (a zero step) as converged.
    ///
    /// This does NOT hide genuine non-convergence. It fires only when every
    /// attempted step either fails to lower the objective by more than this
    /// resolution, or raises it by no more than this resolution (pure
    /// rounding). A step that truly reduces the objective is always taken
    /// first.
    pub convergence_objective_rel_tol: f64,
}
342
343impl Default for ArrowProximalCorrectionOptions {
344    fn default() -> Self {
345        Self {
346            initial_ridge: DEFAULT_PROXIMAL_INITIAL_RIDGE,
347            ridge_growth: DEFAULT_PROXIMAL_RIDGE_GROWTH,
348            max_attempts: DEFAULT_PROXIMAL_MAX_ATTEMPTS,
349            armijo_c1: DEFAULT_ARMIJO_C1,
350            gradient_tolerance: DEFAULT_GRADIENT_TOLERANCE,
351            convergence_objective_rel_tol: DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL,
352        }
353    }
354}
355
356/// Accepted proximal arrow-Schur step and the damping that made it descent.
357#[derive(Debug, Clone)]
358pub struct ArrowAcceptedProximalStep {
359    pub delta_t: Array1<f64>,
360    pub delta_beta: Array1<f64>,
361    pub ridge_t: f64,
362    pub ridge_beta: f64,
363    pub proximal_ridge: f64,
364    pub objective_value: f64,
365    pub trial_objective_value: f64,
366    pub gradient_dot_step: f64,
367    pub attempts: usize,
368}
369
370impl ArrowSolveOptions {
371    /// Select Direct for `K <= 2000` and InexactPCG above, following BA RCS
372    /// practice for dense-vs-iterative reduced systems.
373    pub fn automatic(k: usize) -> Self {
374        Self {
375            mode: ArrowSolverMode::automatic(k),
376            pcg: ArrowPcgOptions::default(),
377            trust_region: ArrowTrustRegionOptions::default(),
378            streaming_chunk_size: None,
379            riemannian_trust_region: false,
380            gpu_matvec: None,
381            tolerate_ill_conditioning: false,
382            solve_precision: ArrowSolvePrecisionPolicy::F64Only,
383            schur_pd_floor: None,
384        }
385    }
386
387    /// Force dense reduced-camera-system Cholesky, the classic BA direct
388    /// solve for small `K`.
389    pub fn direct() -> Self {
390        Self {
391            mode: ArrowSolverMode::Direct,
392            pcg: ArrowPcgOptions::default(),
393            trust_region: ArrowTrustRegionOptions::default(),
394            streaming_chunk_size: None,
395            riemannian_trust_region: false,
396            gpu_matvec: None,
397            tolerate_ill_conditioning: false,
398            solve_precision: ArrowSolvePrecisionPolicy::F64Only,
399            schur_pd_floor: None,
400        }
401    }
402
403    /// Force Square-Root BA Schur assembly for the direct reduced solve.
404    pub fn sqrt_ba() -> Self {
405        Self {
406            mode: ArrowSolverMode::SqrtBA,
407            pcg: ArrowPcgOptions::default(),
408            trust_region: ArrowTrustRegionOptions::default(),
409            streaming_chunk_size: None,
410            riemannian_trust_region: false,
411            gpu_matvec: None,
412            tolerate_ill_conditioning: false,
413            solve_precision: ArrowSolvePrecisionPolicy::F64Only,
414            schur_pd_floor: None,
415        }
416    }
417
418    /// Force inexact BA Schur PCG with Jacobi preconditioning.
419    pub fn inexact_pcg() -> Self {
420        Self {
421            mode: ArrowSolverMode::InexactPCG,
422            pcg: ArrowPcgOptions::default(),
423            trust_region: ArrowTrustRegionOptions::default(),
424            streaming_chunk_size: None,
425            riemannian_trust_region: false,
426            gpu_matvec: None,
427            tolerate_ill_conditioning: false,
428            solve_precision: ArrowSolvePrecisionPolicy::F64Only,
429            schur_pd_floor: None,
430        }
431    }
432
433    pub fn with_streaming_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
434        self.streaming_chunk_size = chunk_size.filter(|&chunk| chunk > 0);
435        self
436    }
437
438    /// Lift the ill-conditioning *rejection* for evidence/log-det-only callers
439    /// while still requiring genuine PD. See [`Self::tolerate_ill_conditioning`].
440    ///
441    /// Use this when the returned `(Δt, Δβ)` Newton step is discarded and only
442    /// the factor cache is consumed (log-determinant + selected-inverse traces).
443    /// The cache stays undamped at `ridge_t = 0`, so the log-determinant is
444    /// exact regardless of κ.
445    pub fn with_ill_conditioning_tolerated(mut self) -> Self {
446        self.tolerate_ill_conditioning = true;
447        self
448    }
449
450    pub fn with_solve_precision_policy(mut self, policy: ArrowSolvePrecisionPolicy) -> Self {
451        self.solve_precision = policy;
452        self
453    }
454
455    /// Turn certified mixed precision ON for the streaming/residency reduced
456    /// solve unless the caller already pinned an explicit policy (#1014).
457    ///
458    /// Only `F64Only` (the inherited default) is upgraded to `CertifiedMixed`;
459    /// a caller that deliberately set a policy keeps it. The reduced-Schur f64
460    /// factor and every evidence log-determinant are unaffected — see
461    /// [`mixed_precision_reduced_beta`].
462    #[must_use]
463    pub fn with_streaming_solve_precision_default(&self) -> Self {
464        let mut out = self.clone();
465        if matches!(out.solve_precision, ArrowSolvePrecisionPolicy::F64Only) {
466            out.solve_precision = ArrowSolvePrecisionPolicy::certified_mixed();
467        }
468        out
469    }
470}
471
472/// CPU/GPU seam for BA point-block work.
473///
474/// BA systems spend most time in independent point-block factorizations,
475/// triangular solves, and Schur block products. MegBA maps exactly these
476/// operations to GPU kernels. This trait keeps that boundary explicit so a
477/// CUDA/Ceres backend can replace [`CpuBatchedBlockSolver`] without changing
478/// `ArrowSchurSystem` algebra.
479pub trait BatchedBlockSolver {
480    /// Factor every per-row point block `H_tt^(i) + ridge_t I`, as in BA's
481    /// point elimination stage.
482    ///
483    /// `tolerate_ill_conditioning` lifts the per-row κ rejection (still
484    /// requiring genuine PD); see [`ArrowSolveOptions::tolerate_ill_conditioning`].
485    fn factor_blocks(
486        &self,
487        rows: &[ArrowRowBlock],
488        ridge_t: f64,
489        d: usize,
490        tolerate_ill_conditioning: bool,
491    ) -> Result<ArrowFactorSlab, ArrowSchurError>;
492
493    /// Solve one factored point block against a vector RHS.
494    fn solve_block_vector(
495        &self,
496        factor: ArrayView2<'_, f64>,
497        rhs: ArrayView1<'_, f64>,
498    ) -> Array1<f64>;
499
500    /// Solve one factored point block against a dense matrix RHS.
501    fn solve_block_matrix(
502        &self,
503        factor: ArrayView2<'_, f64>,
504        rhs: ArrayView2<'_, f64>,
505    ) -> Array2<f64>;
506
507    /// Apply the Square-Root BA lower-triangular solve `L_i^-1 rhs`.
508    fn sqrt_solve_block_matrix(
509        &self,
510        factor: ArrayView2<'_, f64>,
511        rhs: ArrayView2<'_, f64>,
512    ) -> Array2<f64>;
513
514    /// Subtract a row-local Schur product from the dense reduced system.
515    fn block_gemm_subtract(&self, schur: &mut Array2<f64>, left: &Array2<f64>, right: &Array2<f64>);
516}
517
518#[derive(Debug, Clone)]
519pub struct ArrowRowGaugeDeflation {
520    pub directions: Arc<[Vec<Array1<f64>>]>,
521}
522
523impl ArrowRowGaugeDeflation {
524    pub fn new(directions: Vec<Vec<Array1<f64>>>) -> Self {
525        Self {
526            directions: Arc::from(directions.into_boxed_slice()),
527        }
528    }
529
530    pub(crate) fn row(&self, row: usize) -> &[Array1<f64>] {
531        self.directions.get(row).map(Vec::as_slice).unwrap_or(&[])
532    }
533}
534
/// Default, CPU-driven implementation of [`BatchedBlockSolver`].
///
/// The per-block solves and the Schur GEMM are plain Rust loops, because the
/// block size `d` is tiny.
///
/// `factor_blocks` is not CPU-only, though. It first tries the batched
/// Cholesky fast path (`try_factor_blocks_batched`), which can spread the
/// whole batch across the usable GPU devices. It falls back to the exact
/// per-row CPU loop only when that path declines.
///
/// The trait shape, not this implementation, is the load-bearing part for a
/// future MegBA or Ceres backend.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBatchedBlockSolver;
542
543impl BatchedBlockSolver for CpuBatchedBlockSolver {
544    fn factor_blocks(
545        &self,
546        rows: &[ArrowRowBlock],
547        ridge_t: f64,
548        d: usize,
549        tolerate_ill_conditioning: bool,
550    ) -> Result<ArrowFactorSlab, ArrowSchurError> {
551        // Multi-GPU fast path: the per-row blocks `H_tt^(i) + ridge_t·I` are
552        // independent same-size SPD systems — exactly the batch
553        // `gam_gpu::try_cholesky_batched_lower_inplace` spreads across ALL
554        // usable devices (the batched POTRF tiles over the pool). It is only
555        // valid when every row is the uniform `d×d` shape; heterogeneous row
556        // dimensions keep the per-row CPU loop because the current cuSOLVER
557        // batched POTRF wrapper accepts one `(d, d)` shape per launch. It only
558        // succeeds when EVERY block is PD at
559        // the base ridge; a non-PD block returns `None`, so we fall back to the
560        // exact per-row CPU path that performs minimal per-block ridge
561        // escalation. After a successful batched factorization we re-apply the
562        // identical κ-conditioning rejection `factor_one_row` enforces, so the
563        // result is bit-for-bit equivalent (modulo IEEE reduction order) to the
564        // CPU loop: a barely-PD but ill-conditioned block forces the whole batch
565        // back onto the per-row path so its ridge can lift, never silently using
566        // a contaminated factor.
567        if let Some(batched) =
568            try_factor_blocks_batched(rows, ridge_t, d, tolerate_ill_conditioning)
569        {
570            return Ok(batched);
571        }
572        let mut out = Vec::with_capacity(rows.len());
573        for (row_idx, row) in rows.iter().enumerate() {
574            out.push(factor_one_row(
575                row,
576                ridge_t,
577                d,
578                row_idx,
579                tolerate_ill_conditioning,
580            )?);
581        }
582        Ok(ArrowFactorSlab::from_blocks(out))
583    }
584
585    fn solve_block_vector(
586        &self,
587        factor: ArrayView2<'_, f64>,
588        rhs: ArrayView1<'_, f64>,
589    ) -> Array1<f64> {
590        match (factor.nrows(), factor.ncols(), rhs.len()) {
591            (1, 1, 1) => cholesky_solve_vector_fixed::<1>(factor, rhs),
592            (2, 2, 2) => cholesky_solve_vector_fixed::<2>(factor, rhs),
593            (3, 3, 3) => cholesky_solve_vector_fixed::<3>(factor, rhs),
594            (4, 4, 4) => cholesky_solve_vector_fixed::<4>(factor, rhs),
595            _ => cholesky_solve_vector(factor, rhs),
596        }
597    }
598
599    fn solve_block_matrix(
600        &self,
601        factor: ArrayView2<'_, f64>,
602        rhs: ArrayView2<'_, f64>,
603    ) -> Array2<f64> {
604        cholesky_solve_matrix(factor, rhs)
605    }
606
607    fn sqrt_solve_block_matrix(
608        &self,
609        factor: ArrayView2<'_, f64>,
610        rhs: ArrayView2<'_, f64>,
611    ) -> Array2<f64> {
612        forward_substitution_lower_matrix(factor, rhs)
613    }
614
615    fn block_gemm_subtract(
616        &self,
617        schur: &mut Array2<f64>,
618        left: &Array2<f64>,
619        right: &Array2<f64>,
620    ) {
621        // Performance: ndarray Array2 is row-major, so `right[[c, b]]` is
622        // unit-strided in `b`. The canonical (a, b, c) order produced
623        // strided reads of `left[[c, a]]` for every (a, b); reorder to
624        // (c, a, b) so the inner `b`-loop is contiguous in `right` and
625        // `left[[c, a]]` is hoisted out of the inner loop.
626        let k = schur.nrows();
627        let d = left.nrows();
628        assert_eq!(left.ncols(), k);
629        assert_eq!(right.ncols(), k);
630        assert_eq!(schur.ncols(), k);
631        for c in 0..d {
632            for a in 0..k {
633                let lca = left[[c, a]];
634                if lca == 0.0 {
635                    continue;
636                }
637                for b in 0..k {
638                    schur[[a, b]] -= lca * right[[c, b]];
639                }
640            }
641        }
642    }
643}