gam_solve/arrow_schur/solve_options.rs
1//! Solver configuration and the batched block-solve abstraction: BA solver
2//! modes, PCG/trust-region/mixed-precision/proximal options, diagnostics, and
3//! the [`BatchedBlockSolver`] trait with its CPU implementation.
4
5use super::*;
6
/// BA Schur solve variant for the reduced shared `β` system.
///
/// * [`ArrowSolverMode::Direct`] is BA's dense reduced-camera-system solve:
///   eliminate the per-point/per-row blocks, form the reduced system, and
///   Cholesky factor it. This is the Ceres/g2o default for modest camera
///   counts and is appropriate here for `K <= 2000`.
///   **GPU support: ✓** — requires dense H_ββ and dense per-row H_tβ slabs.
///
/// * [`ArrowSolverMode::SqrtBA`] ports Square-Root BA (Demmel, Sommer,
///   Cremers & Usenko, "Square Root Bundle Adjustment for Large-Scale
///   Reconstruction", CVPR 2021): Schur terms are formed as
///   `(L_i^-1 H_tβ_i)^T (L_i^-1 H_tβ_i)` from the per-row square-root factor
///   `L_i`, avoiding explicit `H_tt^-1 H_tβ` products. It is the preferred
///   direct path when single-precision assembly is introduced or when row
///   blocks are poorly conditioned.
///   **GPU support: ✓** — requires dense H_ββ and dense per-row H_tβ slabs.
///
/// * [`ArrowSolverMode::InexactPCG`] ports "Bundle Adjustment in the Large"
///   (Agarwal, Snavely, Seitz & Szeliski, ECCV 2010): the Schur system is
///   solved inexactly by PCG with a Jacobi Schur preconditioner, avoiding
///   dense `K × K` factorization for SAE-manifold scale shared systems.
///   **GPU support: CPU only** until the row-procedural H_tβ GPU PCG path
///   (issue #288 Part B) is wired. The topology selector must not request
///   `InexactPCG` via the GPU entry point; `solve_arrow_newton_step` returns
///   `GpuRequiresDenseSystem` for matrix-free systems, and the wrapper in
///   `solver/gpu/arrow_schur_gpu.rs` routes those to CPU InexactPCG
///   automatically. At K ≥ 5000 the GPU PCG path will supersede the CPU path
///   once the row-procedural H_tβ kernel and boxed GPU matvec backend in
///   `run_pcg_with_preconditioner` are wired.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrowSolverMode {
    /// Dense reduced-system Cholesky (classic BA reduced-camera-system solve).
    Direct,
    /// Square-root Schur assembly from the per-row factors `L_i`.
    SqrtBA,
    /// Inexact PCG on the reduced system with a Jacobi Schur preconditioner.
    InexactPCG,
}
41
42impl ArrowSolverMode {
43 /// BA-size heuristic: dense RCS for modest `K`, inexact Schur PCG for
44 /// large shared systems. This follows Agarwal et al.'s direct-vs-iterative
45 /// split for large BA, mapped from cameras to decoder coefficients.
46 pub const fn automatic(k: usize) -> Self {
47 if k <= DIRECT_SOLVE_MAX_K {
48 Self::Direct
49 } else {
50 Self::InexactPCG
51 }
52 }
53
54 /// Square-Root BA is the direct-solve stability mode for future f32
55 /// callers. Large `K` still routes to inexact PCG because dense Schur
56 /// storage dominates precision concerns at that scale.
57 pub const fn automatic_for_single_precision(k: usize) -> Self {
58 if k <= DIRECT_SOLVE_MAX_K {
59 Self::SqrtBA
60 } else {
61 Self::InexactPCG
62 }
63 }
64}
65
/// Reason the (Steihaug-)CG loop stopped.
///
/// Reported through `PcgDiagnostics::stopping_reason`. The default is
/// [`PcgStopReason::Converged`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PcgStopReason {
    /// Residual fell below the relative tolerance threshold.
    #[default]
    Converged,
    /// Loop exhausted `max_iterations` without converging.
    MaxIter,
    /// Step hit the trust-region boundary (Steihaug boundary projection).
    TrustRegion,
    /// Negative curvature detected in an unbounded solve.
    Indefinite,
    /// Non-positive or non-finite preconditioned residual after an update.
    Stagnation,
}
81
/// Per-solve instrumentation counters returned alongside the PCG solution.
///
/// `Default` yields:
/// - zero for the counters and the residual,
/// - `false` for both flags,
/// - [`PcgStopReason::Converged`] for `stopping_reason`,
/// - [`MixedPrecisionStatus::Off`] for `mixed_precision_status`.
///
/// Callers that do not need diagnostics simply ignore the value. The struct
/// is `Copy`, so passing it through return tuples is a cheap bitwise copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct PcgDiagnostics {
    /// Number of CG iterations executed.
    pub iterations: usize,
    /// Total calls to the Schur matvec A·p.
    pub matvec_calls: usize,
    /// Total calls to the preconditioner M^{-1}·r.
    pub precond_apply_calls: usize,
    /// Number of times the LM ridge was escalated before a successful factor.
    pub ridge_escalations: usize,
    /// Relative residual at termination; 0.0 when the RHS was zero.
    pub final_relative_residual: f64,
    /// Why the loop stopped.
    pub stopping_reason: PcgStopReason,
    /// Mixed-precision certificate outcome for this solve.
    pub mixed_precision_status: MixedPrecisionStatus,
    /// True only when the reduced-Schur solve was **actually executed on the
    /// device**. That is either:
    /// - the fully device-resident batched Arrow-Schur Direct sequence
    ///   (`try_device_arrow_direct` → `solve_arrow_newton_step`), or
    /// - the device-resident matrix-free SAE PCG (`solve_sae_matrix_free_pcg`,
    ///   which runs the matvec in CUDA kernels over device-resident frames).
    ///
    /// It is NOT set merely because a GPU runtime exists and a dispatch gate
    /// fired (#1209).
    pub used_device_arrow: bool,
    /// True when a reduced-Schur matvec backend was injected through
    /// `maybe_inject_gpu_schur_matvec` but the matvec itself runs as a
    /// **host** (CPU Rust/Rayon) procedural closure.
    ///
    /// Both injection branches return host closures that evaluate
    /// `Σ_i Y_iᵀ(Y_i x)` on the CPU:
    /// - the matrix-free `build_row_procedural_matvec` branch, and
    /// - the `cuda::build_schur_matvec_backend` branch, even when a CUDA
    ///   context was opened to build the per-row factors.
    ///
    /// This path must NOT report `used_device_arrow`, because the arithmetic
    /// is host-side (#1209). It is a distinct field so perf accounting never
    /// mistakes a host procedural matvec for true device execution.
    pub injected_host_procedural_matvec: bool,
}
121
/// Outcome of an opt-in mixed-precision arrow solve.
///
/// Defaults to [`MixedPrecisionStatus::Off`]. It is reported through
/// `PcgDiagnostics::mixed_precision_status`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MixedPrecisionStatus {
    /// The caller did not request mixed precision or this solve mode cannot use it.
    #[default]
    Off,
    /// The f32 factor solve was refined until the f64 backward-error certificate held.
    ///
    /// `refinement_steps` is the number of refinement steps that were needed.
    Certified { refinement_steps: usize },
    /// The kappa gate or solve shape rejected mixed precision and the f64 path ran.
    /// The declining reason is logged at `info` level when the fallback fires.
    F64Fallback,
}
134
135/// PCG controls for BA's inexact reduced-camera-system solve.
136///
137/// The defaults mirror the loose inner tolerances used by inexact-step LM in
138/// "Bundle Adjustment in the Large": solve the Schur system only accurately
139/// enough for a useful trust-region step, then let the outer LM iteration
140/// correct the remaining error.
141#[derive(Debug, Clone)]
142pub struct ArrowPcgOptions {
143 pub max_iterations: usize,
144 pub relative_tolerance: f64,
145}
146
147impl Default for ArrowPcgOptions {
148 fn default() -> Self {
149 Self {
150 max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
151 relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
152 }
153 }
154}
155
156/// Trust-region controls for Steihaug-CG on the reduced BA system.
157///
158/// This is the Ceres-style guard around LM: `ridge_t`/`ridge_beta` provide
159/// Levenberg damping, while the trust radius bounds the reduced shared step
160/// in Euclidean β coordinates using Steihaug's truncated-CG stopping rules for
161/// boundary hits and negative curvature.
162#[derive(Debug, Clone)]
163pub struct ArrowTrustRegionOptions {
164 pub radius: f64,
165 pub steihaug_relative_tolerance: f64,
166 pub max_iterations: usize,
167}
168
169impl Default for ArrowTrustRegionOptions {
170 fn default() -> Self {
171 Self {
172 radius: DEFAULT_TRUST_REGION_RADIUS,
173 steihaug_relative_tolerance: DEFAULT_PCG_RELATIVE_TOLERANCE,
174 max_iterations: DEFAULT_PCG_MAX_ITERATIONS,
175 }
176 }
177}
178
/// Opt-in Carson--Higham mixed-precision refinement for dense arrow solves.
///
/// Default is [`ArrowSolvePrecisionPolicy::F64Only`]: exact f64 solves remain the default.
///
/// [`ArrowSolvePrecisionPolicy::CertifiedMixed`]:
/// - stores f32 copies of the per-row Cholesky factors and the dense Schur factor,
/// - solves corrections in f32,
/// - recomputes the residual in f64 against the original arrow blocks.
///
/// The standard refinement certificate is the normwise backward error
///
/// `||r||_inf / (||H||_inf ||x||_inf + ||b||_inf) <= residual_relative_tolerance`.
///
/// The kappa gate enforces `kappa_estimate * u_f32 < kappa_unit_roundoff_margin`.
/// When it fails, the solve reports [`MixedPrecisionStatus::F64Fallback`] and
/// logs the reason before using the f64 path.
///
/// `Default` is derived via `#[default]` on `F64Only`, like the other enums in
/// this module, instead of being written by hand.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum ArrowSolvePrecisionPolicy {
    /// Plain f64 solves; mixed precision is never attempted.
    #[default]
    F64Only,
    /// Certified f32-factor / f64-residual iterative refinement.
    CertifiedMixed {
        /// Maximum number of f64-residual refinement steps.
        max_refinement_steps: usize,
        /// Threshold on the normwise backward error certificate above.
        residual_relative_tolerance: f64,
        /// Right-hand side of the kappa gate `kappa_estimate * u_f32 < margin`.
        kappa_unit_roundoff_margin: f64,
    },
}
207
208impl ArrowSolvePrecisionPolicy {
209 pub fn certified_mixed() -> Self {
210 Self::CertifiedMixed {
211 max_refinement_steps: DEFAULT_MIXED_PRECISION_MAX_REFINEMENTS,
212 residual_relative_tolerance: DEFAULT_MIXED_PRECISION_CERTIFICATE_TOLERANCE,
213 kappa_unit_roundoff_margin: DEFAULT_MIXED_PRECISION_KAPPA_MARGIN,
214 }
215 }
216
217 pub(crate) fn is_enabled(self) -> bool {
218 matches!(self, ArrowSolvePrecisionPolicy::CertifiedMixed { .. })
219 }
220}
221
/// Complete BA Schur solve options.
///
/// Use [`ArrowSolveOptions::automatic`] for normal latent-coordinate fits.
/// Use [`ArrowSolveOptions::sqrt_ba`] when the assembler has single-precision
/// row blocks or an ill-conditioned gauge.
/// Use [`ArrowSolveOptions::inexact_pcg`] for SAE-manifold scale `K`.
///
/// `Debug` is implemented by hand so that `gpu_matvec` is printed only as
/// whether it is set.
#[derive(Clone)]
pub struct ArrowSolveOptions {
    /// Which reduced-system solve to run (see [`ArrowSolverMode`]).
    pub mode: ArrowSolverMode,
    /// PCG iteration cap and relative tolerance (see [`ArrowPcgOptions`]).
    pub pcg: ArrowPcgOptions,
    /// Steihaug-CG trust-region controls (see [`ArrowTrustRegionOptions`]).
    pub trust_region: ArrowTrustRegionOptions,
    /// Row chunk size for streaming direct/Square-Root Schur assembly.
    ///
    /// `None` means no explicit chunk size.
    /// [`ArrowSolveOptions::with_streaming_chunk_size`] normalizes `Some(0)` to
    /// `None`; assigning this field directly bypasses that normalization.
    pub streaming_chunk_size: Option<usize>,
    /// Use the Riemannian latent projection before the Schur reduction. The
    /// reduced Steihaug solve itself remains in Euclidean β coordinates.
    pub riemannian_trust_region: bool,
    /// Optional GPU-backed Schur matvec for CPU-driven `InexactPCG` at K ≥ 5000.
    ///
    /// When set, `run_pcg_with_preconditioner` delegates each `S·p` call to
    /// this closure instead of the CPU `schur_matvec`. It is constructed by
    /// `crate::gpu_kernels::arrow_schur::gpu_schur_matvec_backend` when
    /// `cuda_selected()` holds and the system has dense per-row H_tβ slabs.
    /// `None` means CPU-only PCG.
    ///
    /// NOTE(review): `PcgDiagnostics::injected_host_procedural_matvec` says
    /// the injected backends run on the host. Those backends are
    /// `build_row_procedural_matvec` and `cuda::build_schur_matvec_backend`,
    /// injected via `maybe_inject_gpu_schur_matvec`. So "GPU-backed" may
    /// overstate where the arithmetic runs, and the constructor named above
    /// may be stale. Confirm against the injection site.
    pub gpu_matvec: Option<GpuSchurMatvec>,
    /// Skip the ill-conditioning *rejection* while still requiring genuine
    /// positive definiteness (a non-PD Cholesky pivot still errors).
    ///
    /// The rejection covers the κ-based
    /// [`ArrowSchurError::PerRowFactorIllConditioned`] per-row guard and the
    /// matching reduced-Schur κ guard.
    ///
    /// The κ guards protect the accuracy of the Newton *step*: a barely-PD
    /// `H_tt^(i)` or an over-conditioned reduced Schur yields an inaccurate
    /// `Δβ`/`Δt`.
    ///
    /// Evidence-only callers (e.g. `SaeManifoldTerm::reml_criterion_with_cache`)
    /// do not consume the step. They need only:
    /// - the factor cache for the log-determinant (`½log|H|`, exact from
    ///   `diag(L)` regardless of κ), and
    /// - the selected-inverse traces.
    ///
    /// For those callers the κ rejection is a false abort when ρ sweeps to
    /// extreme values. This flag lifts it and hands the "is this step
    /// trustworthy" decision back to the caller.
    ///
    /// Default `false`: ordinary solves keep the full guard.
    pub tolerate_ill_conditioning: bool,
    /// Arrow solve precision policy. Default is f64-only.
    pub solve_precision: ArrowSolvePrecisionPolicy,
    /// Optional spectral positive-definiteness floor on the *reduced Schur
    /// complement* `S = H_ββ + ridge_β·I − Σ_i H_tβ^(i)ᵀ (H_tt^(i))⁻¹ H_tβ^(i)`,
    /// as a relative fraction of `S`'s largest eigenvalue.
    ///
    /// `None` (default) keeps the strict contract: a non-PD `S` errors as
    /// [`ArrowSchurError::SchurFactorFailed`], so the LM outer loop lifts
    /// `ridge_beta` globally and re-forms `S`.
    ///
    /// `Some(floor)` engages the #1026 SAE co-collapse cure on the SOLVE path.
    /// When the reduced Schur Cholesky refuses, the cause is that collapsed
    /// atoms drive a per-row `H_tt` near-singular, so the accumulated
    /// `(H_tt)⁻¹` over-subtracts `S` into an INDEFINITE matrix. Instead of
    /// rejecting and over-damping every β direction with a global ridge,
    /// `S` is symmetric-eigendecomposed and every eigenvalue is clamped UP to
    /// `floor·max(λ)`.
    ///
    /// This is Levenberg–Marquardt restricted to exactly the
    /// indefinite/collapsed subspace:
    /// - the well-conditioned β directions (`λ ≫ floor·max λ`) are untouched,
    ///   and the step in those directions is the exact Newton step;
    /// - only the collapsed directions receive the minimal damping needed
    ///   for a PD solve.
    ///
    /// The inner Newton then makes a real descent step rather than crawling
    /// behind an inflated global ridge.
    ///
    /// This mirrors the per-row spectral floor the evidence path uses for
    /// #1377/#1117/#1118
    /// ([`super::factorization::factor_spectral_deflated_evidence_row`]).
    /// The difference is the floored value: a small positive `floor·max λ`
    /// (Tikhonov) for the solve, vs unit stiffness `+1` (`log 1 = 0`) for the
    /// evidence log-det.
    ///
    /// Only consulted by the dense Direct / SqrtBA reduced solve, which is the
    /// only caller of [`super::reduced_solve::solve_dense_reduced_system`].
    /// The InexactPCG path is unaffected.
    pub schur_pd_floor: Option<f64>,
}
295
296impl std::fmt::Debug for ArrowSolveOptions {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 f.debug_struct("ArrowSolveOptions")
299 .field("mode", &self.mode)
300 .field("pcg", &self.pcg)
301 .field("trust_region", &self.trust_region)
302 .field("streaming_chunk_size", &self.streaming_chunk_size)
303 .field("riemannian_trust_region", &self.riemannian_trust_region)
304 .field("gpu_matvec", &self.gpu_matvec.is_some())
305 .field("tolerate_ill_conditioning", &self.tolerate_ill_conditioning)
306 .field("solve_precision", &self.solve_precision)
307 .field("schur_pd_floor", &self.schur_pd_floor)
308 .finish()
309 }
310}
311
312/// Globalization guard for non-convex arrow-Schur inner steps.
313///
314/// The raw Schur solve is exactly Newton. For non-convex analytic penalties,
315/// full Newton can cycle. This controller adds a proximal LM shift `mu I` to
316/// both blocks and accepts only Armijo-decreasing trial points.
317#[derive(Debug, Clone)]
318pub struct ArrowProximalCorrectionOptions {
319 pub initial_ridge: f64,
320 pub ridge_growth: f64,
321 pub max_attempts: usize,
322 pub armijo_c1: f64,
323 pub gradient_tolerance: f64,
324 /// Relative objective resolution below which the proximal correction
325 /// declares convergence instead of failing.
326 ///
327 /// Near a stationary point the largest decrease the damped Newton model can
328 /// still achieve shrinks to the floating-point resolution of the objective
329 /// itself: at proximal ridge `μ → μ_max` the accepted step length is
330 /// `O(‖g‖ / μ)`, so the realised change in the objective falls below
331 /// `rel_tol · (|f| + 1)`. At that scale the Armijo sufficient-decrease test
332 /// compares two values that differ only by rounding noise, and no further
333 /// productive decrease is achievable. Rather than raise
334 /// `AdaptiveCorrectionFailed`, the loop then returns the incumbent state
335 /// (a zero step) as converged. This does NOT mask genuine non-convergence:
336 /// it triggers only when every attempted step either fails to decrease the
337 /// objective by more than this resolution OR increases it by no more than
338 /// this resolution (pure rounding). A step that genuinely reduces the
339 /// objective is always taken first.
340 pub convergence_objective_rel_tol: f64,
341}
342
343impl Default for ArrowProximalCorrectionOptions {
344 fn default() -> Self {
345 Self {
346 initial_ridge: DEFAULT_PROXIMAL_INITIAL_RIDGE,
347 ridge_growth: DEFAULT_PROXIMAL_RIDGE_GROWTH,
348 max_attempts: DEFAULT_PROXIMAL_MAX_ATTEMPTS,
349 armijo_c1: DEFAULT_ARMIJO_C1,
350 gradient_tolerance: DEFAULT_GRADIENT_TOLERANCE,
351 convergence_objective_rel_tol: DEFAULT_PROXIMAL_CONVERGENCE_REL_TOL,
352 }
353 }
354}
355
/// Accepted proximal arrow-Schur step and the damping that made it descent.
///
/// NOTE(review): the field meanings below are inferred from their names and
/// from `ArrowProximalCorrectionOptions`; confirm against the code that
/// builds this struct.
#[derive(Debug, Clone)]
pub struct ArrowAcceptedProximalStep {
    /// Accepted step in the per-row (latent `t`) block.
    pub delta_t: Array1<f64>,
    /// Accepted step in the shared `β` block.
    pub delta_beta: Array1<f64>,
    /// Row-block damping in effect for the accepted solve.
    pub ridge_t: f64,
    /// Shared-block damping in effect for the accepted solve.
    pub ridge_beta: f64,
    /// Proximal shift `mu` at which the step was accepted.
    pub proximal_ridge: f64,
    /// Objective at the incumbent point (presumably before the step).
    pub objective_value: f64,
    /// Objective at the accepted trial point.
    pub trial_objective_value: f64,
    /// Directional derivative `gᵀ·step` used by the Armijo test (presumably).
    pub gradient_dot_step: f64,
    /// Number of proximal attempts taken to find this step.
    pub attempts: usize,
}
369
370impl ArrowSolveOptions {
371 /// Select Direct for `K <= 2000` and InexactPCG above, following BA RCS
372 /// practice for dense-vs-iterative reduced systems.
373 pub fn automatic(k: usize) -> Self {
374 Self {
375 mode: ArrowSolverMode::automatic(k),
376 pcg: ArrowPcgOptions::default(),
377 trust_region: ArrowTrustRegionOptions::default(),
378 streaming_chunk_size: None,
379 riemannian_trust_region: false,
380 gpu_matvec: None,
381 tolerate_ill_conditioning: false,
382 solve_precision: ArrowSolvePrecisionPolicy::F64Only,
383 schur_pd_floor: None,
384 }
385 }
386
387 /// Force dense reduced-camera-system Cholesky, the classic BA direct
388 /// solve for small `K`.
389 pub fn direct() -> Self {
390 Self {
391 mode: ArrowSolverMode::Direct,
392 pcg: ArrowPcgOptions::default(),
393 trust_region: ArrowTrustRegionOptions::default(),
394 streaming_chunk_size: None,
395 riemannian_trust_region: false,
396 gpu_matvec: None,
397 tolerate_ill_conditioning: false,
398 solve_precision: ArrowSolvePrecisionPolicy::F64Only,
399 schur_pd_floor: None,
400 }
401 }
402
403 /// Force Square-Root BA Schur assembly for the direct reduced solve.
404 pub fn sqrt_ba() -> Self {
405 Self {
406 mode: ArrowSolverMode::SqrtBA,
407 pcg: ArrowPcgOptions::default(),
408 trust_region: ArrowTrustRegionOptions::default(),
409 streaming_chunk_size: None,
410 riemannian_trust_region: false,
411 gpu_matvec: None,
412 tolerate_ill_conditioning: false,
413 solve_precision: ArrowSolvePrecisionPolicy::F64Only,
414 schur_pd_floor: None,
415 }
416 }
417
418 /// Force inexact BA Schur PCG with Jacobi preconditioning.
419 pub fn inexact_pcg() -> Self {
420 Self {
421 mode: ArrowSolverMode::InexactPCG,
422 pcg: ArrowPcgOptions::default(),
423 trust_region: ArrowTrustRegionOptions::default(),
424 streaming_chunk_size: None,
425 riemannian_trust_region: false,
426 gpu_matvec: None,
427 tolerate_ill_conditioning: false,
428 solve_precision: ArrowSolvePrecisionPolicy::F64Only,
429 schur_pd_floor: None,
430 }
431 }
432
433 pub fn with_streaming_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
434 self.streaming_chunk_size = chunk_size.filter(|&chunk| chunk > 0);
435 self
436 }
437
438 /// Lift the ill-conditioning *rejection* for evidence/log-det-only callers
439 /// while still requiring genuine PD. See [`Self::tolerate_ill_conditioning`].
440 ///
441 /// Use this when the returned `(Δt, Δβ)` Newton step is discarded and only
442 /// the factor cache is consumed (log-determinant + selected-inverse traces).
443 /// The cache stays undamped at `ridge_t = 0`, so the log-determinant is
444 /// exact regardless of κ.
445 pub fn with_ill_conditioning_tolerated(mut self) -> Self {
446 self.tolerate_ill_conditioning = true;
447 self
448 }
449
450 pub fn with_solve_precision_policy(mut self, policy: ArrowSolvePrecisionPolicy) -> Self {
451 self.solve_precision = policy;
452 self
453 }
454
455 /// Turn certified mixed precision ON for the streaming/residency reduced
456 /// solve unless the caller already pinned an explicit policy (#1014).
457 ///
458 /// Only `F64Only` (the inherited default) is upgraded to `CertifiedMixed`;
459 /// a caller that deliberately set a policy keeps it. The reduced-Schur f64
460 /// factor and every evidence log-determinant are unaffected — see
461 /// [`mixed_precision_reduced_beta`].
462 #[must_use]
463 pub fn with_streaming_solve_precision_default(&self) -> Self {
464 let mut out = self.clone();
465 if matches!(out.solve_precision, ArrowSolvePrecisionPolicy::F64Only) {
466 out.solve_precision = ArrowSolvePrecisionPolicy::certified_mixed();
467 }
468 out
469 }
470}
471
/// CPU/GPU seam for BA point-block work.
///
/// BA systems spend most time in independent point-block factorizations,
/// triangular solves, and Schur block products. MegBA maps exactly these
/// operations to GPU kernels. This trait keeps that boundary explicit so a
/// CUDA/Ceres backend can replace [`CpuBatchedBlockSolver`] without changing
/// `ArrowSchurSystem` algebra.
pub trait BatchedBlockSolver {
    /// Factor every per-row point block `H_tt^(i) + ridge_t I`, as in BA's
    /// point elimination stage.
    ///
    /// `d` is the per-row block dimension. `tolerate_ill_conditioning` lifts
    /// the per-row κ rejection while still requiring genuine PD; see
    /// [`ArrowSolveOptions::tolerate_ill_conditioning`].
    ///
    /// # Errors
    /// Returns an [`ArrowSchurError`] when a row block cannot be factored.
    /// The CPU backend returns the first error from its per-row factor path.
    fn factor_blocks(
        &self,
        rows: &[ArrowRowBlock],
        ridge_t: f64,
        d: usize,
        tolerate_ill_conditioning: bool,
    ) -> Result<ArrowFactorSlab, ArrowSchurError>;

    /// Solve one factored point block against a vector RHS.
    ///
    /// `factor` is one block produced by `factor_blocks`. The CPU backend
    /// treats it as a lower Cholesky factor and delegates to
    /// `cholesky_solve_vector`, presumably solving `L Lᵀ x = rhs`.
    fn solve_block_vector(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView1<'_, f64>,
    ) -> Array1<f64>;

    /// Solve one factored point block against a dense matrix RHS
    /// (column-wise counterpart of `solve_block_vector`).
    fn solve_block_matrix(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
    ) -> Array2<f64>;

    /// Apply the Square-Root BA lower-triangular solve `L_i^-1 rhs`.
    ///
    /// This is only the forward substitution, not the full `(L Lᵀ)^-1`.
    fn sqrt_solve_block_matrix(
        &self,
        factor: ArrayView2<'_, f64>,
        rhs: ArrayView2<'_, f64>,
    ) -> Array2<f64>;

    /// Subtract a row-local Schur product from the dense reduced system:
    /// `schur[a, b] -= Σ_c left[c, a] · right[c, b]`, i.e. `schur -= leftᵀ · right`.
    ///
    /// Shapes: `schur` is `K × K`, and `left` and `right` are both `d × K`.
    /// The CPU implementation asserts these shapes.
    fn block_gemm_subtract(&self, schur: &mut Array2<f64>, left: &Array2<f64>, right: &Array2<f64>);
}
517
/// Per-row direction sets, indexed by row and shared cheaply via `Arc`.
///
/// `directions[i]` holds the vectors registered for row `i`. Rows past the
/// end of the list have no directions (`row` returns an empty slice for them).
///
/// NOTE(review): judging by the name, these are gauge (null-space) directions
/// to deflate from each row block. Confirm against the consumer.
#[derive(Debug, Clone)]
pub struct ArrowRowGaugeDeflation {
    /// One list of direction vectors per row, in row order.
    pub directions: Arc<[Vec<Array1<f64>>]>,
}
522
523impl ArrowRowGaugeDeflation {
524 pub fn new(directions: Vec<Vec<Array1<f64>>>) -> Self {
525 Self {
526 directions: Arc::from(directions.into_boxed_slice()),
527 }
528 }
529
530 pub(crate) fn row(&self, row: usize) -> &[Array1<f64>] {
531 self.directions.get(row).map(Vec::as_slice).unwrap_or(&[])
532 }
533}
534
/// Current CPU implementation of the BA batched block interface.
///
/// `factor_blocks` first offers the whole batch to `try_factor_blocks_batched`.
/// Per the notes in `factor_blocks`, that is a batched multi-device Cholesky
/// fast path. Only when it declines does `factor_blocks` fall back to the
/// per-row `factor_one_row` loop.
///
/// The per-block solves delegate to small dense helpers, with const-size
/// specializations for `1 ≤ d ≤ 4` in `solve_block_vector`. The Schur update
/// is a plain Rust loop because `d` is tiny.
///
/// The trait shape, not this implementation, is the load-bearing part for a
/// future MegBA or Ceres backend.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuBatchedBlockSolver;
542
543impl BatchedBlockSolver for CpuBatchedBlockSolver {
544 fn factor_blocks(
545 &self,
546 rows: &[ArrowRowBlock],
547 ridge_t: f64,
548 d: usize,
549 tolerate_ill_conditioning: bool,
550 ) -> Result<ArrowFactorSlab, ArrowSchurError> {
551 // Multi-GPU fast path: the per-row blocks `H_tt^(i) + ridge_t·I` are
552 // independent same-size SPD systems — exactly the batch
553 // `gam_gpu::try_cholesky_batched_lower_inplace` spreads across ALL
554 // usable devices (the batched POTRF tiles over the pool). It is only
555 // valid when every row is the uniform `d×d` shape; heterogeneous row
556 // dimensions keep the per-row CPU loop because the current cuSOLVER
557 // batched POTRF wrapper accepts one `(d, d)` shape per launch. It only
558 // succeeds when EVERY block is PD at
559 // the base ridge; a non-PD block returns `None`, so we fall back to the
560 // exact per-row CPU path that performs minimal per-block ridge
561 // escalation. After a successful batched factorization we re-apply the
562 // identical κ-conditioning rejection `factor_one_row` enforces, so the
563 // result is bit-for-bit equivalent (modulo IEEE reduction order) to the
564 // CPU loop: a barely-PD but ill-conditioned block forces the whole batch
565 // back onto the per-row path so its ridge can lift, never silently using
566 // a contaminated factor.
567 if let Some(batched) =
568 try_factor_blocks_batched(rows, ridge_t, d, tolerate_ill_conditioning)
569 {
570 return Ok(batched);
571 }
572 let mut out = Vec::with_capacity(rows.len());
573 for (row_idx, row) in rows.iter().enumerate() {
574 out.push(factor_one_row(
575 row,
576 ridge_t,
577 d,
578 row_idx,
579 tolerate_ill_conditioning,
580 )?);
581 }
582 Ok(ArrowFactorSlab::from_blocks(out))
583 }
584
585 fn solve_block_vector(
586 &self,
587 factor: ArrayView2<'_, f64>,
588 rhs: ArrayView1<'_, f64>,
589 ) -> Array1<f64> {
590 match (factor.nrows(), factor.ncols(), rhs.len()) {
591 (1, 1, 1) => cholesky_solve_vector_fixed::<1>(factor, rhs),
592 (2, 2, 2) => cholesky_solve_vector_fixed::<2>(factor, rhs),
593 (3, 3, 3) => cholesky_solve_vector_fixed::<3>(factor, rhs),
594 (4, 4, 4) => cholesky_solve_vector_fixed::<4>(factor, rhs),
595 _ => cholesky_solve_vector(factor, rhs),
596 }
597 }
598
599 fn solve_block_matrix(
600 &self,
601 factor: ArrayView2<'_, f64>,
602 rhs: ArrayView2<'_, f64>,
603 ) -> Array2<f64> {
604 cholesky_solve_matrix(factor, rhs)
605 }
606
607 fn sqrt_solve_block_matrix(
608 &self,
609 factor: ArrayView2<'_, f64>,
610 rhs: ArrayView2<'_, f64>,
611 ) -> Array2<f64> {
612 forward_substitution_lower_matrix(factor, rhs)
613 }
614
615 fn block_gemm_subtract(
616 &self,
617 schur: &mut Array2<f64>,
618 left: &Array2<f64>,
619 right: &Array2<f64>,
620 ) {
621 // Performance: ndarray Array2 is row-major, so `right[[c, b]]` is
622 // unit-strided in `b`. The canonical (a, b, c) order produced
623 // strided reads of `left[[c, a]]` for every (a, b); reorder to
624 // (c, a, b) so the inner `b`-loop is contiguous in `right` and
625 // `left[[c, a]]` is hoisted out of the inner loop.
626 let k = schur.nrows();
627 let d = left.nrows();
628 assert_eq!(left.ncols(), k);
629 assert_eq!(right.ncols(), k);
630 assert_eq!(schur.ncols(), k);
631 for c in 0..d {
632 for a in 0..k {
633 let lca = left[[c, a]];
634 if lca == 0.0 {
635 continue;
636 }
637 for b in 0..k {
638 schur[[a, b]] -= lca * right[[c, b]];
639 }
640 }
641 }
642 }
643}