gam_solve/arrow_schur/solve_options.rs
1//! Solver configuration and the batched block-solve abstraction: BA solver
2//! modes, PCG/trust-region/mixed-precision/proximal options, diagnostics, and
3//! the [`BatchedBlockSolver`] trait with its CPU implementation.
4
5use super::*;
6
/// Which bundle-adjustment (BA) Schur strategy solves the reduced shared `β`
/// system.
///
/// Each variant documents the BA technique it ports and its current GPU
/// status.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowSolverMode {
    /// BA's dense reduced-camera-system solve: eliminate the
    /// per-point/per-row blocks, form the reduced system, and Cholesky factor
    /// it. This is the Ceres/g2o default for modest camera counts and is
    /// appropriate here for `K <= 2000`.
    ///
    /// **GPU support: ✓** — requires dense `H_ββ` and dense per-row `H_tβ`
    /// slabs.
    Direct,
    /// Port of Square-Root BA (Demmel et al., CVPR 2021): Schur terms are
    /// formed as `(L_i^-1 H_tβ_i)^T (L_i^-1 H_tβ_i)` from the per-row
    /// square-root factor `L_i`, avoiding explicit `H_tt^-1 H_tβ` products.
    /// It is the preferred direct path when single-precision assembly is
    /// introduced or when row blocks are poorly conditioned.
    ///
    /// **GPU support: ✓** — requires dense `H_ββ` and dense per-row `H_tβ`
    /// slabs.
    SqrtBA,
    /// Port of "Bundle Adjustment in the Large" (Agarwal et al.): the Schur
    /// system is solved inexactly by PCG with a Jacobi Schur preconditioner,
    /// avoiding dense `K × K` factorization for SAE-manifold scale shared
    /// systems.
    ///
    /// **GPU support: CPU only** until the row-procedural `H_tβ` GPU PCG path
    /// (issue #288 Part B) is wired. The topology selector must not request
    /// `InexactPCG` via the GPU entry point; `solve_arrow_newton_step` returns
    /// `GpuRequiresDenseSystem` for matrix-free systems, and the wrapper in
    /// `solver/gpu/arrow_schur_gpu.rs` routes those to CPU InexactPCG
    /// automatically. At K ≥ 5000 the GPU PCG path will supersede the CPU
    /// path once the row-procedural `H_tβ` kernel and boxed GPU matvec backend
    /// in `run_pcg_with_preconditioner` are wired.
    InexactPCG,
}
41
42impl ArrowSolverMode {
43 /// BA-size heuristic: dense RCS for modest `K`, inexact Schur PCG for
44 /// large shared systems. This follows Agarwal et al.'s direct-vs-iterative
45 /// split for large BA, mapped from cameras to decoder coefficients.
46 pub const fn automatic(k: usize) -> Self {
47 if k <= DIRECT_SOLVE_MAX_K {
48 Self::Direct
49 } else {
50 Self::InexactPCG
51 }
52 }
53
54 /// Square-Root BA is the direct-solve stability mode for future f32
55 /// callers. Large `K` still routes to inexact PCG because dense Schur
56 /// storage dominates precision concerns at that scale.
57 pub const fn automatic_for_single_precision(k: usize) -> Self {
58 if k <= DIRECT_SOLVE_MAX_K {
59 Self::SqrtBA
60 } else {
61 Self::InexactPCG
62 }
63 }
64}
65
/// Why the Steihaug-CG loop terminated.
///
/// The derived `Default` is [`PcgStopReason::Converged`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PcgStopReason {
    /// The residual dropped below the relative tolerance threshold.
    #[default]
    Converged,
    /// `max_iterations` were used up before convergence.
    MaxIter,
    /// The step reached the trust-region boundary (Steihaug boundary
    /// projection).
    TrustRegion,
    /// Negative curvature was detected in an unbounded solve.
    Indefinite,
    /// The preconditioned residual was non-positive or non-finite after an
    /// update.
    Stagnation,
}
81
82/// Per-solve instrumentation counters returned alongside the PCG solution.
83///
84/// All fields default to zero; callers that do not need diagnostics simply
85/// ignore the value. The struct is Copy so passing it through return tuples
86/// is zero-overhead.
87#[derive(Debug, Default, Clone, Copy)]
88pub struct PcgDiagnostics {
89 /// Number of CG iterations executed.
90 pub iterations: usize,
91 /// Total calls to the Schur matvec A·p.
92 pub matvec_calls: usize,
93 /// Total calls to the preconditioner M^{-1}·r.
94 pub precond_apply_calls: usize,
95 /// Number of times the LM ridge was escalated before a successful factor.
96 pub ridge_escalations: usize,
97 /// Relative residual at termination; 0.0 when the RHS was zero.
98 pub final_relative_residual: f64,
99 /// Why the loop stopped.
100 pub stopping_reason: PcgStopReason,
101 /// Mixed-precision certificate outcome for this solve.
102 pub mixed_precision_status: MixedPrecisionStatus,
103 /// True only when the reduced-Schur solve was **actually executed on the
104 /// device**: either the fully device-resident batched Arrow-Schur Direct
105 /// sequence (`try_device_arrow_direct` → `solve_arrow_newton_step`) or the
106 /// device-resident matrix-free SAE PCG (`solve_sae_matrix_free_pcg`, which
107 /// runs the matvec in CUDA kernels over device-resident frames). It is NOT
108 /// set merely because a GPU runtime exists and a dispatch gate fired (#1209).
109 pub used_device_arrow: bool,
110 /// True when a reduced-Schur matvec backend was injected through
111 /// `maybe_inject_gpu_schur_matvec` but the matvec itself runs as a
112 /// **host** (CPU Rust/Rayon) procedural closure — both the matrix-free
113 /// `build_row_procedural_matvec` branch and the `cuda::build_schur_matvec_backend`
114 /// branch return host closures that evaluate `Σ_i Y_iᵀ(Y_i x)` on the CPU,
115 /// even when a CUDA context was opened to build the per-row factors. This
116 /// path must NOT report `used_device_arrow`: the arithmetic is host-side
117 /// (#1209). Distinct field so perf accounting never mistakes a host
118 /// procedural matvec for true device execution.
119 pub injected_host_procedural_matvec: bool,
120}
121
/// Result of an opt-in mixed-precision arrow solve.
///
/// The derived `Default` is [`MixedPrecisionStatus::Off`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MixedPrecisionStatus {
    /// Mixed precision was not requested, or this solve mode cannot use it.
    #[default]
    Off,
    /// The f32 factor solve was refined until the f64 backward-error
    /// certificate held; `refinement_steps` is carried alongside.
    Certified { refinement_steps: usize },
    /// The kappa gate or the solve shape rejected mixed precision, so the f64
    /// path ran. The declining reason is logged at `info` level when the
    /// fallback fires.
    F64Fallback,
}
134
/// Controls for the PCG loop of BA's inexact reduced-camera-system solve.
///
/// The defaults mirror the loose inner tolerances used by inexact-step LM in
/// "Bundle Adjustment in the Large": the Schur system is solved only
/// accurately enough for a useful trust-region step, and the outer LM
/// iteration corrects the remaining error.
#[derive(Debug, Clone)]
pub struct ArrowPcgOptions {
    /// Cap on the number of PCG iterations.
    pub max_iterations: usize,
    /// Relative residual tolerance at which the loop counts as converged.
    pub relative_tolerance: f64,
}
146
147impl Default for ArrowPcgOptions {
148 fn default() -> Self {
149 Self {
150 max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
151 relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
152 }
153 }
154}
155
/// Trust-region settings for Steihaug-CG on the reduced BA system.
///
/// This is the Ceres-style guard around LM: `ridge_t`/`ridge_beta` supply the
/// Levenberg damping, while the trust radius bounds the reduced shared step in
/// Euclidean β coordinates, using Steihaug's truncated-CG stopping rules for
/// boundary hits and negative curvature.
#[derive(Debug, Clone)]
pub struct ArrowTrustRegionOptions {
    /// Trust radius bounding the reduced shared step (Euclidean β norm).
    pub radius: f64,
    /// Relative residual tolerance for the Steihaug-CG loop.
    pub steihaug_relative_tolerance: f64,
    /// Cap on the number of Steihaug-CG iterations.
    pub max_iterations: usize,
}
168
169impl Default for ArrowTrustRegionOptions {
170 fn default() -> Self {
171 Self {
172 radius: DEFAULT_TRUST_REGION_RADIUS,
173 steihaug_relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
174 max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
175 }
176 }
177}
178
/// Opt-in Carson--Higham mixed-precision refinement for dense arrow solves.
///
/// The derived `Default` is [`ArrowSolvePrecisionPolicy::F64Only`]: exact f64
/// solves remain the default. [`ArrowSolvePrecisionPolicy::CertifiedMixed`]
/// stores f32 copies of the per-row Cholesky factors and dense Schur factor,
/// solves corrections in f32, and recomputes the residual in f64 against the
/// original arrow blocks. The standard refinement certificate is the normwise
/// backward error
///
/// `||r||_inf / (||H||_inf ||x||_inf + ||b||_inf) <= residual_relative_tolerance`.
///
/// The kappa gate enforces `kappa_estimate * u_f32 < kappa_unit_roundoff_margin`;
/// when it fails, the solve reports [`MixedPrecisionStatus::F64Fallback`] and
/// logs the reason before using the f64 path.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub enum ArrowSolvePrecisionPolicy {
    /// Exact f64 solve with no mixed-precision refinement.
    #[default]
    F64Only,
    /// f32 factor solves refined against an f64 residual certificate.
    CertifiedMixed {
        /// Cap on the number of refinement steps.
        max_refinement_steps: usize,
        /// Bound on the normwise backward error for the certificate to hold.
        residual_relative_tolerance: f64,
        /// Upper bound on `kappa_estimate * u_f32` for the kappa gate to pass.
        kappa_unit_roundoff_margin: f64,
    },
}
207
208impl ArrowSolvePrecisionPolicy {
209 pub fn certified_mixed() -> Self {
210 Self::CertifiedMixed {
211 max_refinement_steps: DEFAULT_MIXED_PRECISION_MAX_REFINEMENTS,
212 residual_relative_tolerance: DEFAULT_MIXED_PRECISION_CERTIFICATE_TOLERANCE,
213 kappa_unit_roundoff_margin: DEFAULT_MIXED_PRECISION_KAPPA_MARGIN,
214 }
215 }
216
217 pub(crate) fn is_enabled(self) -> bool {
218 matches!(self, ArrowSolvePrecisionPolicy::CertifiedMixed { .. })
219 }
220}
221
222/// Complete BA Schur solve options.
223///
224/// Use [`ArrowSolveOptions::automatic`] for normal latent-coordinate fits;
225/// use [`ArrowSolveOptions::sqrt_ba`] when the assembler has single-precision
226/// row blocks or an ill-conditioned gauge; use [`ArrowSolveOptions::inexact_pcg`]
227/// for SAE-manifold scale `K`.
228#[derive(Clone)]
229pub struct ArrowSolveOptions {
230 pub mode: ArrowSolverMode,
231 pub pcg: ArrowPcgOptions,
232 pub trust_region: ArrowTrustRegionOptions,
233 /// Row chunk size for streaming direct/Square-Root Schur assembly.
234 pub streaming_chunk_size: Option<usize>,
235 /// Use the Riemannian latent projection before the Schur reduction. The
236 /// reduced Steihaug solve itself remains in Euclidean β coordinates.
237 pub riemannian_trust_region: bool,
238 /// Optional GPU-backed Schur matvec for CPU-driven `InexactPCG` at K ≥ 5000.
239 ///
240 /// When set, `run_pcg_with_preconditioner` delegates each `S·p` call to
241 /// this closure instead of the CPU `schur_matvec`. Constructed by
242 /// `crate::gpu_kernels::arrow_schur::gpu_schur_matvec_backend` when `cuda_selected()`
243 /// and the system has dense per-row H_tβ slabs. `None` means CPU-only PCG.
244 pub gpu_matvec: Option<GpuSchurMatvec>,
245 /// Skip the ill-conditioning *rejection* (the κ-based
246 /// [`ArrowSchurError::PerRowFactorIllConditioned`] per-row guard and the
247 /// matching reduced-Schur κ guard) while still requiring genuine positive
248 /// definiteness (a non-PD Cholesky pivot still errors).
249 ///
250 /// The κ guards exist to protect the accuracy of the Newton *step*: a
251 /// barely-PD `H_tt^(i)` or an over-conditioned reduced Schur yields an
252 /// inaccurate `Δβ`/`Δt`. Evidence-only callers
253 /// (e.g. `SaeManifoldTerm::reml_criterion_with_cache`) do not consume the
254 /// step — they need only the factor cache for the log-determinant
255 /// (`½log|H|`, exact from `diag(L)` regardless of κ) and the selected-inverse
256 /// traces. For those callers the κ rejection is a false abort when ρ sweeps
257 /// to extreme values, so this flag lifts it and hands the
258 /// "is this step trustworthy" decision back to the caller.
259 ///
260 /// Default `false`: ordinary solves keep the full guard.
261 pub tolerate_ill_conditioning: bool,
262 /// Arrow solve precision policy. Default is f64-only.
263 pub solve_precision: ArrowSolvePrecisionPolicy,
264 /// Optional spectral positive-definiteness floor on the *reduced Schur
265 /// complement* `S = H_ββ + ridge_β·I − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)`,
266 /// as a relative fraction of `S`'s largest eigenvalue.
267 ///
268 /// `None` (default) keeps the strict contract: a non-PD `S` errors as
269 /// [`ArrowSchurError::SchurFactorFailed`] so the LM outer loop lifts
270 /// `ridge_beta` globally and re-forms `S`.
271 ///
272 /// `Some(floor)` engages the #1026 SAE co-collapse cure on the SOLVE path:
273 /// when the reduced Schur Cholesky refuses (collapsed atoms drive a per-row
274 /// `H_tt` near-singular, so the accumulated `(H_tt)⁻¹` over-subtracts `S`
275 /// into an INDEFINITE matrix), instead of rejecting and over-damping every
276 /// β direction with a global ridge, symmetric-eigendecompose `S` and clamp
277 /// every eigenvalue UP to `floor·max(λ)`. This is Levenberg–Marquardt
278 /// restricted to exactly the indefinite/collapsed subspace: the
279 /// well-conditioned β directions (`λ ≫ floor·max λ`) are untouched and the
280 /// step in those directions is the exact Newton step, while only the
281 /// collapsed directions receive the minimal damping needed for a PD solve.
282 /// The inner Newton then makes a real descent step rather than crawling
283 /// behind an inflated global ridge. Mirrors the per-row spectral floor the
284 /// evidence path uses for #1377/#1117/#1118
285 /// ([`super::factorization::factor_spectral_deflated_evidence_row`]); the
286 /// difference is the floored value — a small positive `floor·max λ`
287 /// (Tikhonov) for the solve, vs unit stiffness `+1` (`log 1 = 0`) for the
288 /// evidence log-det.
289 ///
290 /// Only consulted by the dense Direct / SqrtBA reduced solve (the only
291 /// caller of [`super::reduced_solve::solve_dense_reduced_system`]); the
292 /// InexactPCG path is unaffected.
293 pub schur_pd_floor: Option<f64>,
294}
295
296impl std::fmt::Debug for ArrowSolveOptions {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 f.debug_struct("ArrowSolveOptions")
299 .field("mode", &self.mode)
300 .field("pcg", &self.pcg)
301 .field("trust_region", &self.trust_region)
302 .field("streaming_chunk_size", &self.streaming_chunk_size)
303 .field("riemannian_trust_region", &self.riemannian_trust_region)
304 .field("gpu_matvec", &self.gpu_matvec.is_some())
305 .field("tolerate_ill_conditioning", &self.tolerate_ill_conditioning)
306 .field("solve_precision", &self.solve_precision)
307 .field("schur_pd_floor", &self.schur_pd_floor)
308 .finish()
309 }
310}
311
/// Globalization guard for non-convex arrow-Schur inner steps.
///
/// The raw Schur solve is exactly Newton, and for non-convex analytic
/// penalties full Newton can cycle. This controller adds a proximal LM shift
/// `mu I` to both blocks and accepts only Armijo-decreasing trial points.
#[derive(Debug, Clone)]
pub struct ArrowProximalCorrectionOptions {
    /// Starting value of the proximal ridge `mu`.
    pub initial_ridge: f64,
    /// Growth factor for the proximal ridge.
    /// NOTE(review): presumably applied after each rejected attempt — confirm
    /// against the correction loop.
    pub ridge_growth: f64,
    /// Cap on the number of proximal attempts.
    pub max_attempts: usize,
    /// Armijo sufficient-decrease constant `c1`.
    pub armijo_c1: f64,
    /// Gradient tolerance.
    /// NOTE(review): the norm and the exact test live in the correction loop,
    /// which is not visible here — confirm.
    pub gradient_tolerance: f64,
    /// Relative objective resolution below which the proximal correction
    /// declares convergence instead of failing.
    ///
    /// Near a stationary point, the largest decrease the damped Newton model
    /// can still achieve shrinks to the floating-point resolution of the
    /// objective itself: at proximal ridge `μ → μ_max` the accepted step
    /// length is `O(‖g‖ / μ)`, so the realised change in the objective falls
    /// below `rel_tol · (|f| + 1)`. At that scale the Armijo
    /// sufficient-decrease test compares two values that differ only by
    /// rounding noise, and no further productive decrease is achievable.
    /// Rather than raise `AdaptiveCorrectionFailed`, the loop then returns the
    /// incumbent state (a zero step) as converged.
    ///
    /// This does NOT mask genuine non-convergence: it triggers only when every
    /// attempted step either fails to decrease the objective by more than this
    /// resolution OR increases it by no more than this resolution (pure
    /// rounding). A step that genuinely reduces the objective is always taken
    /// first.
    pub convergence_objective_rel_tol: f64,
}
342
343impl Default for ArrowProximalCorrectionOptions {
344 fn default() -> Self {
345 Self {
346 initial_ridge: DEFAULT_PROXIMAL_INITIAL_RIDGE,
347 ridge_growth: DEFAULT_PROXIMAL_RIDGE_GROWTH,
348 max_attempts: DEFAULT_PROXIMAL_MAX_ATTEMPTS,
349 armijo_c1: DEFAULT_ARMIJO_C1,
350 gradient_tolerance: DEFAULT_GRADIENT_TOLERANCE,
351 convergence_objective_rel_tol: DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL,
352 }
353 }
354}
355
356/// Accepted proximal arrow-Schur step and the damping that made it descent.
357#[derive(Debug, Clone)]
358pub struct ArrowAcceptedProximalStep {
359 pub delta_t: Array1<f64>,
360 pub delta_beta: Array1<f64>,
361 pub ridge_t: f64,
362 pub ridge_beta: f64,
363 pub proximal_ridge: f64,
364 pub objective_value: f64,
365 pub trial_objective_value: f64,
366 pub gradient_dot_step: f64,
367 pub attempts: usize,
368}
369
370impl ArrowSolveOptions {
371 /// Select Direct for `K <= 2000` and InexactPCG above, following BA RCS
372 /// practice for dense-vs-iterative reduced systems.
373 pub fn automatic(k: usize) -> Self {
374 Self {
375 mode: ArrowSolverMode::automatic(k),
376 pcg: ArrowPcgOptions::default(),
377 trust_region: ArrowTrustRegionOptions::default(),
378 streaming_chunk_size: None,
379 riemannian_trust_region: false,
380 gpu_matvec: None,
381 tolerate_ill_conditioning: false,
382 solve_precision: ArrowSolvePrecisionPolicy::F64Only,
383 schur_pd_floor: None,
384 }
385 }
386
387 /// Force dense reduced-camera-system Cholesky, the classic BA direct
388 /// solve for small `K`.
389 pub fn direct() -> Self {
390 Self {
391 mode: ArrowSolverMode::Direct,
392 pcg: ArrowPcgOptions::default(),
393 trust_region: ArrowTrustRegionOptions::default(),
394 streaming_chunk_size: None,
395 riemannian_trust_region: false,
396 gpu_matvec: None,
397 tolerate_ill_conditioning: false,
398 solve_precision: ArrowSolvePrecisionPolicy::F64Only,
399 schur_pd_floor: None,
400 }
401 }
402
403 /// Force Square-Root BA Schur assembly for the direct reduced solve.
404 pub fn sqrt_ba() -> Self {
405 Self {
406 mode: ArrowSolverMode::SqrtBA,
407 pcg: ArrowPcgOptions::default(),
408 trust_region: ArrowTrustRegionOptions::default(),
409 streaming_chunk_size: None,
410 riemannian_trust_region: false,
411 gpu_matvec: None,
412 tolerate_ill_conditioning: false,
413 solve_precision: ArrowSolvePrecisionPolicy::F64Only,
414 schur_pd_floor: None,
415 }
416 }
417
418 /// Force inexact BA Schur PCG with Jacobi preconditioning.
419 pub fn inexact_pcg() -> Self {
420 Self {
421 mode: ArrowSolverMode::InexactPCG,
422 pcg: ArrowPcgOptions::default(),
423 trust_region: ArrowTrustRegionOptions::default(),
424 streaming_chunk_size: None,
425 riemannian_trust_region: false,
426 gpu_matvec: None,
427 tolerate_ill_conditioning: false,
428 solve_precision: ArrowSolvePrecisionPolicy::F64Only,
429 schur_pd_floor: None,
430 }
431 }
432
433 pub fn with_streaming_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
434 self.streaming_chunk_size = chunk_size.filter(|&chunk| chunk > 0);
435 self
436 }
437
438 /// Lift the ill-conditioning *rejection* for evidence/log-det-only callers
439 /// while still requiring genuine PD. See [`Self::tolerate_ill_conditioning`].
440 ///
441 /// Use this when the returned `(Δt, Δβ)` Newton step is discarded and only
442 /// the factor cache is consumed (log-determinant + selected-inverse traces).
443 /// The cache stays undamped at `ridge_t = 0`, so the log-determinant is
444 /// exact regardless of κ.
445 pub fn with_ill_conditioning_tolerated(mut self) -> Self {
446 self.tolerate_ill_conditioning = true;
447 self
448 }
449
450 /// Enable the spectral PD-floor on an indefinite reduced Schur (the SAE solve
451 /// path): floor the collapsed / dead-atom directions up to `floor·max(λ)` and
452 /// re-factor instead of hard-erroring. An overcomplete manifold-SAE fit parks
453 /// surplus atoms dead, so the reduced Schur (and the undamped evidence factor
454 /// at the optimum) can have near-zero / slightly-negative eigenvalues on the
455 /// dead subspace; flooring those lets the live subspace's exact Newton /
456 /// log-det proceed instead of aborting the whole fit on a non-PD pivot. `None`
457 /// (default) keeps the strict refusal for BA / non-SAE callers.
458 pub fn with_schur_pd_floor(mut self, floor: f64) -> Self {
459 self.schur_pd_floor = Some(floor);
460 self
461 }
462
463 pub fn with_solve_precision_policy(mut self, policy: ArrowSolvePrecisionPolicy) -> Self {
464 self.solve_precision = policy;
465 self
466 }
467
468 /// Turn certified mixed precision ON for the streaming/residency reduced
469 /// solve unless the caller already pinned an explicit policy (#1014).
470 ///
471 /// Only `F64Only` (the inherited default) is upgraded to `CertifiedMixed`;
472 /// a caller that deliberately set a policy keeps it. The reduced-Schur f64
473 /// factor and every evidence log-determinant are unaffected — see
474 /// [`mixed_precision_reduced_beta`].
475 #[must_use]
476 pub fn with_streaming_solve_precision_default(&self) -> Self {
477 let mut out = self.clone();
478 if matches!(out.solve_precision, ArrowSolvePrecisionPolicy::F64Only) {
479 out.solve_precision = ArrowSolvePrecisionPolicy::certified_mixed();
480 }
481 out
482 }
483}
484
485/// CPU/GPU seam for BA point-block work.
486///
487/// BA systems spend most time in independent point-block factorizations,
488/// triangular solves, and Schur block products. MegBA maps exactly these
489/// operations to GPU kernels. This trait keeps that boundary explicit so a
490/// CUDA/Ceres backend can replace [`CpuBatchedBlockSolver`] without changing
491/// `ArrowSchurSystem` algebra.
492pub trait BatchedBlockSolver {
493 /// Factor every per-row point block `H_tt^(i) + ridge_t I`, as in BA's
494 /// point elimination stage.
495 ///
496 /// `tolerate_ill_conditioning` lifts the per-row κ rejection (still
497 /// requiring genuine PD); see [`ArrowSolveOptions::tolerate_ill_conditioning`].
498 fn factor_blocks(
499 &self,
500 rows: &[ArrowRowBlock],
501 ridge_t: f64,
502 d: usize,
503 tolerate_ill_conditioning: bool,
504 ) -> Result<ArrowFactorSlab, ArrowSchurError>;
505
506 /// Solve one factored point block against a vector RHS.
507 fn solve_block_vector(
508 &self,
509 factor: ArrayView2<'_, f64>,
510 rhs: ArrayView1<'_, f64>,
511 ) -> Array1<f64>;
512
513 /// Solve one factored point block against a dense matrix RHS.
514 fn solve_block_matrix(
515 &self,
516 factor: ArrayView2<'_, f64>,
517 rhs: ArrayView2<'_, f64>,
518 ) -> Array2<f64>;
519
520 /// Apply the Square-Root BA lower-triangular solve `L_i^-1 rhs`.
521 fn sqrt_solve_block_matrix(
522 &self,
523 factor: ArrayView2<'_, f64>,
524 rhs: ArrayView2<'_, f64>,
525 ) -> Array2<f64>;
526
527 /// Subtract a row-local Schur product from the dense reduced system.
528 fn block_gemm_subtract(&self, schur: &mut Array2<f64>, left: &Array2<f64>, right: &Array2<f64>);
529}
530
531#[derive(Debug, Clone)]
532pub struct ArrowRowGaugeDeflation {
533 pub directions: Arc<[Vec<Array1<f64>>]>,
534}
535
536impl ArrowRowGaugeDeflation {
537 pub fn new(directions: Vec<Vec<Array1<f64>>>) -> Self {
538 Self {
539 directions: Arc::from(directions.into_boxed_slice()),
540 }
541 }
542
543 pub(crate) fn row(&self, row: usize) -> &[Array1<f64>] {
544 self.directions.get(row).map(Vec::as_slice).unwrap_or(&[])
545 }
546}
547
/// The current CPU implementation of the BA batched block interface.
///
/// It is deliberately written as plain Rust loops because `d` is tiny. What
/// is load-bearing for a future MegBA or Ceres backend is the trait shape, not
/// this implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBatchedBlockSolver;
555
556impl BatchedBlockSolver for CpuBatchedBlockSolver {
557 fn factor_blocks(
558 &self,
559 rows: &[ArrowRowBlock],
560 ridge_t: f64,
561 d: usize,
562 tolerate_ill_conditioning: bool,
563 ) -> Result<ArrowFactorSlab, ArrowSchurError> {
564 // Multi-GPU fast path: the per-row blocks `H_tt^(i) + ridge_t·I` are
565 // independent same-size SPD systems — exactly the batch
566 // `gam_gpu::try_cholesky_batched_lower_inplace` spreads across ALL
567 // usable devices (the batched POTRF tiles over the pool). It is only
568 // valid when every row is the uniform `d×d` shape; heterogeneous row
569 // dimensions keep the per-row CPU loop because the current cuSOLVER
570 // batched POTRF wrapper accepts one `(d, d)` shape per launch. It only
571 // succeeds when EVERY block is PD at
572 // the base ridge; a non-PD block returns `None`, so we fall back to the
573 // exact per-row CPU path that performs minimal per-block ridge
574 // escalation. After a successful batched factorization we re-apply the
575 // identical κ-conditioning rejection `factor_one_row` enforces, so the
576 // result is bit-for-bit equivalent (modulo IEEE reduction order) to the
577 // CPU loop: a barely-PD but ill-conditioned block forces the whole batch
578 // back onto the per-row path so its ridge can lift, never silently using
579 // a contaminated factor.
580 if let Some(batched) =
581 try_factor_blocks_batched(rows, ridge_t, d, tolerate_ill_conditioning)
582 {
583 return Ok(batched);
584 }
585 let mut out = Vec::with_capacity(rows.len());
586 for (row_idx, row) in rows.iter().enumerate() {
587 out.push(factor_one_row(
588 row,
589 ridge_t,
590 d,
591 row_idx,
592 tolerate_ill_conditioning,
593 )?);
594 }
595 Ok(ArrowFactorSlab::from_blocks(out))
596 }
597
598 fn solve_block_vector(
599 &self,
600 factor: ArrayView2<'_, f64>,
601 rhs: ArrayView1<'_, f64>,
602 ) -> Array1<f64> {
603 match (factor.nrows(), factor.ncols(), rhs.len()) {
604 (1, 1, 1) => cholesky_solve_vector_fixed::<1>(factor, rhs),
605 (2, 2, 2) => cholesky_solve_vector_fixed::<2>(factor, rhs),
606 (3, 3, 3) => cholesky_solve_vector_fixed::<3>(factor, rhs),
607 (4, 4, 4) => cholesky_solve_vector_fixed::<4>(factor, rhs),
608 _ => cholesky_solve_vector(factor, rhs),
609 }
610 }
611
612 fn solve_block_matrix(
613 &self,
614 factor: ArrayView2<'_, f64>,
615 rhs: ArrayView2<'_, f64>,
616 ) -> Array2<f64> {
617 cholesky_solve_matrix(factor, rhs)
618 }
619
620 fn sqrt_solve_block_matrix(
621 &self,
622 factor: ArrayView2<'_, f64>,
623 rhs: ArrayView2<'_, f64>,
624 ) -> Array2<f64> {
625 forward_substitution_lower_matrix(factor, rhs)
626 }
627
628 fn block_gemm_subtract(
629 &self,
630 schur: &mut Array2<f64>,
631 left: &Array2<f64>,
632 right: &Array2<f64>,
633 ) {
634 // Performance: ndarray Array2 is row-major, so `right[[c, b]]` is
635 // unit-strided in `b`. The canonical (a, b, c) order produced
636 // strided reads of `left[[c, a]]` for every (a, b); reorder to
637 // (c, a, b) so the inner `b`-loop is contiguous in `right` and
638 // `left[[c, a]]` is hoisted out of the inner loop.
639 let k = schur.nrows();
640 let d = left.nrows();
641 assert_eq!(left.ncols(), k);
642 assert_eq!(right.ncols(), k);
643 assert_eq!(schur.ncols(), k);
644 for c in 0..d {
645 for a in 0..k {
646 let lca = left[[c, a]];
647 if lca == 0.0 {
648 continue;
649 }
650 for b in 0..k {
651 schur[[a, b]] -= lca * right[[c, b]];
652 }
653 }
654 }
655 }
656}