apex_solver/linalg/implicit_schur.rs

//! # Implicit Schur Complement Solver
//!
//! This module implements the **Implicit Schur Complement** method using matrix-free
//! Preconditioned Conjugate Gradients (PCG) for bundle adjustment.
//!
//! ## Explicit vs Implicit Schur Complement
//!
//! **Implicit Schur:** This formulation never constructs the reduced camera matrix S
//! explicitly. Instead, it solves the linear system using a matrix-free approach where
//! only the matrix-vector product S·x is computed. This is highly memory-efficient for
//! large-scale problems.
//!
//! **Explicit Schur:** The alternative formulation (see [`explicit_schur`](super::explicit_schur))
//! physically constructs S = H_cc - H_cp H_pp⁻¹ H_cpᵀ (often written B - E C⁻¹ Eᵀ)
//! in memory and uses sparse Cholesky factorization.
//!
//! ## When to Use Implicit Schur
//!
//! - Very large bundle adjustment problems (> 10,000 cameras)
//! - Memory-constrained environments
//! - When iterative methods converge well (good preconditioning)
//! - When the reduced camera system S is too large to store explicitly
//!
//! ## Algorithm
//!
//! 1. Form the Schur complement implicitly: S = H_cc - H_cp * H_pp^{-1} * H_cp^T
//! 2. Solve S*δc = g_reduced using PCG (matrix-free; the product S·x is expanded below)
//! 3. Back-substitute: δp = H_pp^{-1} * (g_p - H_cp^T * δc)
//!
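//! Inside PCG, the product S·x is assembled from three sparse pieces; this is
//! what `apply_schur_operator_fast` below computes, using preallocated
//! workspace buffers:
//!
//! ```text
//! t   = H_cp^T x          // gather camera columns into landmark rows
//! t   = H_pp^{-1} t       // apply the cached 3x3 landmark block inverses
//! S x = H_cc x - H_cp t   // subtract the correction from the camera product
//! ```
//!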
//! ## Usage Example
//!
//! ```no_run
//! # use apex_solver::linalg::{SchurSolverAdapter, SchurVariant, SchurPreconditioner};
//! # use std::collections::HashMap;
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! # let variables = HashMap::new();
//! # let variable_index_map = HashMap::new();
//! use apex_solver::linalg::{SchurSolverAdapter, SchurVariant, SchurPreconditioner};
//!
//! let mut solver = SchurSolverAdapter::new_with_structure_and_config(
//!     &variables,
//!     &variable_index_map,
//!     SchurVariant::Iterative, // Implicit Schur with PCG
//!     SchurPreconditioner::SchurJacobi, // Recommended for PCG
//! )?;
//! # Ok(())
//! # }
//! ```
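//!
//! The iterative solver can also be constructed directly when custom PCG settings
//! are needed. A minimal sketch using this module's constructors (the import path
//! for `IterativeSchurSolver` assumes this module is public; the parameter values
//! are illustrative):
//!
//! ```no_run
//! use apex_solver::linalg::SchurPreconditioner;
//! use apex_solver::linalg::implicit_schur::IterativeSchurSolver;
//!
//! // 200 PCG iterations, 1e-10 relative tolerance, H_cc block-diagonal preconditioner
//! let solver =
//!     IterativeSchurSolver::with_config(200, 1e-10, SchurPreconditioner::BlockDiagonal);
//! ```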

use super::explicit_schur::{SchurBlockStructure, SchurOrdering, SchurPreconditioner};
use crate::core::problem::VariableEnum;
use crate::linalg::{LinAlgError, LinAlgResult, StructuredSparseLinearSolver};
use faer::Mat;
use faer::sparse::{SparseColMat, Triplet};
use nalgebra::{DMatrix, DVector, Matrix3};
use rayon::prelude::*;
use std::collections::HashMap;
use std::ops::Mul;

/// Iterative Schur complement solver using Preconditioned Conjugate Gradients
#[derive(Debug, Clone)]
pub struct IterativeSchurSolver {
    block_structure: Option<SchurBlockStructure>,
    ordering: SchurOrdering,

    // CG parameters
    max_cg_iterations: usize,
    cg_tolerance: f64,

    // Preconditioner type
    preconditioner_type: SchurPreconditioner,

    // Cached for matrix-vector products
    landmark_block_inverses: Vec<Matrix3<f64>>,
    hessian: Option<SparseColMat<usize, f64>>,
    gradient: Option<Mat<f64>>,

    // Workspace buffers for Schur operator (avoid repeated allocations)
    workspace_lm: Vec<f64>,  // landmark DOF sized buffer
    workspace_cam: Vec<f64>, // camera DOF sized buffer

    // Visibility index: camera_block_idx -> Vec<landmark_block_idx>
    // This avoids O(cameras * landmarks) iteration in preconditioner computation
    camera_to_landmark_visibility: Vec<Vec<usize>>,
}

impl IterativeSchurSolver {
    /// Create a new iterative Schur solver with default parameters.
    /// Default: Schur-Jacobi preconditioner, 500 max iterations (enough for large BA
    /// problems), 1e-9 relative tolerance for accurate step computation.
    /// These tighter settings match Ceres Solver behavior.
    pub fn new() -> Self {
        Self::with_config(500, 1e-9, SchurPreconditioner::SchurJacobi)
    }

    /// Create solver with custom CG parameters (keeps the default Schur-Jacobi preconditioner)
    pub fn with_cg_params(max_iterations: usize, tolerance: f64) -> Self {
        Self::with_config(max_iterations, tolerance, SchurPreconditioner::SchurJacobi)
    }

    /// Create solver with full configuration
    pub fn with_config(
        max_iterations: usize,
        tolerance: f64,
        preconditioner: SchurPreconditioner,
    ) -> Self {
        Self {
            block_structure: None,
            ordering: SchurOrdering::default(),
            max_cg_iterations: max_iterations,
            cg_tolerance: tolerance,
            preconditioner_type: preconditioner,
            landmark_block_inverses: Vec::new(),
            hessian: None,
            gradient: None,
            workspace_lm: Vec::new(),
            workspace_cam: Vec::new(),
            camera_to_landmark_visibility: Vec::new(),
        }
    }

    /// Initialize workspace buffers based on problem dimensions
    fn init_workspaces(&mut self) {
        if let Some(structure) = &self.block_structure {
            let lm_dof = structure.landmark_dof;
            let cam_dof = structure.camera_dof;

            if self.workspace_lm.len() != lm_dof {
                self.workspace_lm = vec![0.0; lm_dof];
            }
            if self.workspace_cam.len() != cam_dof {
                self.workspace_cam = vec![0.0; cam_dof];
            }
        }
    }

    /// Apply the Schur complement operator: S*x = (H_cc - H_cp * H_pp^{-1} * H_cp^T) * x
    ///
    /// This computes the matrix-vector product without explicitly forming S.
    /// Uses workspace buffers to avoid allocations during PCG iterations.
    fn apply_schur_operator_fast(
        &self,
        x: &Mat<f64>,
        result: &mut Mat<f64>,
        temp_lm: &mut [f64],
        temp_cam: &mut [f64],
    ) -> LinAlgResult<()> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;

        let hessian = self
            .hessian
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Hessian not computed".into()))?;

        let symbolic = hessian.symbolic();
        let (cam_start, cam_end) = structure.camera_col_range();
        let (lm_start, lm_end) = structure.landmark_col_range();
        let cam_dof = structure.camera_dof;

        // Clear workspace buffers
        temp_lm.iter_mut().for_each(|v| *v = 0.0);
        temp_cam.iter_mut().for_each(|v| *v = 0.0);

        // Fused Step 1+2: result = H_cc * x AND temp_lm = H_cp^T * x
        // Process camera columns once, extracting both products
        for col in cam_start..cam_end {
            let local_col = col - cam_start;
            let x_val = x[(local_col, 0)];
            let row_indices = symbolic.row_idx_of_col_raw(col);
            let col_values = hessian.val_of_col(col);

            for (idx, &row) in row_indices.iter().enumerate() {
                let val = col_values[idx];
                if row >= cam_start && row < cam_end {
                    // H_cc contribution
                    let local_row = row - cam_start;
                    result[(local_row, 0)] += val * x_val;
                } else if row >= lm_start && row < lm_end {
                    // H_cp^T contribution (camera col -> landmark row)
                    let local_row = row - lm_start;
                    temp_lm[local_row] += val * x_val;
                }
            }
        }

        // Step 3: Apply H_pp^{-1} in-place: temp_lm = H_pp^{-1} * temp_lm
        for (block_idx, (_, start_col, _)) in structure.landmark_blocks.iter().enumerate() {
            let inv_block = &self.landmark_block_inverses[block_idx];
            let local_start = start_col - lm_start;

            // Read input values
            let in0 = temp_lm[local_start];
            let in1 = temp_lm[local_start + 1];
            let in2 = temp_lm[local_start + 2];

            // Apply 3x3 inverse block
            temp_lm[local_start] =
                inv_block[(0, 0)] * in0 + inv_block[(0, 1)] * in1 + inv_block[(0, 2)] * in2;
            temp_lm[local_start + 1] =
                inv_block[(1, 0)] * in0 + inv_block[(1, 1)] * in1 + inv_block[(1, 2)] * in2;
            temp_lm[local_start + 2] =
                inv_block[(2, 0)] * in0 + inv_block[(2, 1)] * in1 + inv_block[(2, 2)] * in2;
        }

        // Step 4: temp_cam = H_cp * temp_lm (iterate over landmark columns)
        for col in lm_start..lm_end {
            let local_col = col - lm_start;
            let lm_val = temp_lm[local_col];
            let row_indices = symbolic.row_idx_of_col_raw(col);
            let col_values = hessian.val_of_col(col);

            for (idx, &row) in row_indices.iter().enumerate() {
                if row >= cam_start && row < cam_end {
                    let local_row = row - cam_start;
                    temp_cam[local_row] += col_values[idx] * lm_val;
                }
            }
        }

        // Step 5: result = result - temp_cam = H_cc*x - H_cp*H_pp^{-1}*H_cp^T*x
        for i in 0..cam_dof {
            result[(i, 0)] -= temp_cam[i];
        }

        Ok(())
    }

    /// Extract H_cp^T and multiply with a vector: (H_cp^T) * x
    fn extract_camera_landmark_transpose_mvp(
        &self,
        hessian: &SparseColMat<usize, f64>,
        x: &Mat<f64>,
    ) -> LinAlgResult<Mat<f64>> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;
        let (cam_start, cam_end) = structure.camera_col_range();
        let (lm_start, lm_end) = structure.landmark_col_range();
        let lm_dof = structure.landmark_dof;

        let mut result = Mat::<f64>::zeros(lm_dof, 1);
        let symbolic = hessian.symbolic();

        // H_cp^T * x: iterate over camera columns, accumulate into landmark rows
        for col in cam_start..cam_end {
            let local_col = col - cam_start;
            let row_indices = symbolic.row_idx_of_col_raw(col);
            let col_values = hessian.val_of_col(col);

            for (idx, &row) in row_indices.iter().enumerate() {
                if row >= lm_start && row < lm_end {
                    let local_row = row - lm_start;
                    result[(local_row, 0)] += col_values[idx] * x[(local_col, 0)];
                }
            }
        }

        Ok(result)
    }

    /// Extract H_cp and multiply with a vector: H_cp * x
    fn extract_camera_landmark_mvp(
        &self,
        hessian: &SparseColMat<usize, f64>,
        x: &Mat<f64>,
    ) -> LinAlgResult<Mat<f64>> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;
        let (cam_start, cam_end) = structure.camera_col_range();
        let (lm_start, lm_end) = structure.landmark_col_range();
        let cam_dof = structure.camera_dof;

        let mut result = Mat::<f64>::zeros(cam_dof, 1);
        let symbolic = hessian.symbolic();

        // H_cp * x: iterate over landmark columns
        for col in lm_start..lm_end {
            let local_col = col - lm_start;
            let row_indices = symbolic.row_idx_of_col_raw(col);
            let col_values = hessian.val_of_col(col);

            for (idx, &row) in row_indices.iter().enumerate() {
                if row >= cam_start && row < cam_end {
                    let local_row = row - cam_start;
                    result[(local_row, 0)] += col_values[idx] * x[(local_col, 0)];
                }
            }
        }

        Ok(result)
    }

    /// Apply H_pp^{-1} using the cached block inverses
    fn apply_landmark_inverse(&self, input: &Mat<f64>, output: &mut Mat<f64>) -> LinAlgResult<()> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;

        for (block_idx, (_, start_col, _)) in structure.landmark_blocks.iter().enumerate() {
            let inv_block = &self.landmark_block_inverses[block_idx];
            let local_start = start_col - structure.landmark_col_range().0;

            for i in 0..3 {
                let mut sum = 0.0;
                for j in 0..3 {
                    sum += inv_block[(i, j)] * input[(local_start + j, 0)];
                }
                output[(local_start + i, 0)] = sum;
            }
        }

        Ok(())
    }

    /// Compute a block-Jacobi preconditioner: inverts the camera diagonal blocks of H_cc only
    ///
    /// Instead of a scalar diagonal (1/H_ii), this inverts the full camera diagonal blocks.
    /// For cameras with 6 DOF (SE3), this creates 6×6 inverse blocks.
    ///
    /// NOTE: This is NOT the true Schur-Jacobi preconditioner. It only uses the
    /// diagonal blocks of H_cc, not the Schur complement S. For better convergence,
    /// use `compute_schur_jacobi_preconditioner()` instead.
    fn compute_block_preconditioner(&self) -> LinAlgResult<Vec<DMatrix<f64>>> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;

        let hessian = self
            .hessian
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Hessian not computed".into()))?;

        let symbolic = hessian.symbolic();

        let mut precond_blocks = Vec::with_capacity(structure.camera_blocks.len());

        for (_, start_col, size) in &structure.camera_blocks {
            // Extract the diagonal block for this camera
            let mut block = DMatrix::<f64>::zeros(*size, *size);

            for local_col in 0..*size {
                let global_col = start_col + local_col;
                let row_indices = symbolic.row_idx_of_col_raw(global_col);
                let col_values = hessian.val_of_col(global_col);

                for (idx, &global_row) in row_indices.iter().enumerate() {
                    if global_row >= *start_col && global_row < start_col + size {
                        let local_row = global_row - start_col;
                        block[(local_row, local_col)] = col_values[idx];
                    }
                }
            }

            // Invert with regularization for numerical stability
            let inv_block = match block.clone().try_inverse() {
                Some(inv) => inv,
                None => {
                    // Add regularization and retry
                    let reg = 1e-6 * block.diagonal().iter().sum::<f64>().abs() / *size as f64;
                    let reg = reg.max(1e-8);
                    for i in 0..*size {
                        block[(i, i)] += reg;
                    }
                    block
                        .try_inverse()
                        .unwrap_or_else(|| DMatrix::identity(*size, *size))
                }
            };

            precond_blocks.push(inv_block);
        }

        Ok(precond_blocks)
    }

    /// Apply block-Jacobi preconditioner: z = M^{-1} * r
    ///
    /// For each camera block, multiply by the inverse block matrix.
    fn apply_block_preconditioner(
        &self,
        r: &Mat<f64>,
        precond_blocks: &[DMatrix<f64>],
    ) -> LinAlgResult<Mat<f64>> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;

        let cam_dof = structure.camera_dof;
        let mut z = Mat::<f64>::zeros(cam_dof, 1);
        let cam_start = structure.camera_col_range().0;

        for (block_idx, (_, start_col, size)) in structure.camera_blocks.iter().enumerate() {
            let local_start = start_col - cam_start;
            let inv_block = &precond_blocks[block_idx];

            // Extract r block as DVector
            let mut r_block = DVector::<f64>::zeros(*size);
            for i in 0..*size {
                r_block[i] = r[(local_start + i, 0)];
            }

            // Apply inverse: z_block = M^{-1} * r_block
            let z_block = inv_block * r_block;

            // Write back to z
            for i in 0..*size {
                z[(local_start + i, 0)] = z_block[i];
            }
        }

        Ok(z)
    }

    /// Compute the TRUE Schur-Jacobi preconditioner: diagonal blocks of the Schur complement S
    ///
    /// This is what Ceres Solver uses for its SCHUR_JACOBI preconditioner.
    ///
    /// For each camera i:
    ///   S[i,i] = H_cc[i,i] - Σ_j H_cp[i,j] * H_pp[j,j]^{-1} * H_cp[i,j]^T
    ///
    /// where the sum is over all landmarks j observed by camera i.
    ///
    /// This preconditioner captures the effect of point elimination on each camera block,
    /// leading to much faster PCG convergence (typically 20-40 iterations vs 100+).
    fn compute_schur_jacobi_preconditioner(&self) -> LinAlgResult<Vec<DMatrix<f64>>> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;

        let hessian = self
            .hessian
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Hessian not computed".into()))?;

        let symbolic = hessian.symbolic();

        // Borrow the visibility index for use in the parallel iterator
        let visibility = &self.camera_to_landmark_visibility;

        // Compute S[i,i] for each camera in parallel:
        //   S[i,i] = H_cc[i,i] - Σ_j H_cp[i,j] * H_pp[j,j]^{-1} * H_cp[i,j]^T
        // Using the visibility index, only connected landmarks are visited
        // (O(observations) instead of O(cameras * landmarks)).
        let precond_blocks: Vec<DMatrix<f64>> = structure
            .camera_blocks
            .par_iter()
            .enumerate()
            .map(|(cam_idx, (_, cam_col_start, cam_size))| {
                // Step 1: Extract the H_cc[i,i] diagonal block
                let mut s_ii = DMatrix::<f64>::zeros(*cam_size, *cam_size);

                for local_col in 0..*cam_size {
                    let global_col = cam_col_start + local_col;
                    let row_indices = symbolic.row_idx_of_col_raw(global_col);
                    let col_values = hessian.val_of_col(global_col);

                    for (idx, &global_row) in row_indices.iter().enumerate() {
                        if global_row >= *cam_col_start && global_row < cam_col_start + cam_size {
                            let local_row = global_row - cam_col_start;
                            s_ii[(local_row, local_col)] = col_values[idx];
                        }
                    }
                }

                // Step 2: For each landmark OBSERVED by this camera (visibility-indexed).
                // This is the key optimization: O(avg_landmarks_per_camera) instead of O(all_landmarks)
                let visible_landmarks: &[usize] =
                    visibility.get(cam_idx).map_or(&[], |v| v.as_slice());

                for &lm_block_idx in visible_landmarks {
                    if lm_block_idx >= structure.landmark_blocks.len() {
                        continue;
                    }
                    let (_, lm_col_start, _) = &structure.landmark_blocks[lm_block_idx];

                    // Extract the H_cp[i,j] block (cam_size x 3)
                    let mut h_cp = DMatrix::<f64>::zeros(*cam_size, 3);

                    for col_offset in 0..3 {
                        let global_col = lm_col_start + col_offset;
                        let row_indices = symbolic.row_idx_of_col_raw(global_col);
                        let col_values = hessian.val_of_col(global_col);

                        for (idx, &global_row) in row_indices.iter().enumerate() {
                            if global_row >= *cam_col_start && global_row < cam_col_start + cam_size
                            {
                                let local_row = global_row - cam_col_start;
                                h_cp[(local_row, col_offset)] = col_values[idx];
                            }
                        }
                    }

                    // Get H_pp[j,j]^{-1} from the cached inverses
                    let hpp_inv = &self.landmark_block_inverses[lm_block_idx];

                    // Compute the contribution H_cp * H_pp^{-1} * H_cp^T.
                    // First: temp = H_cp * H_pp^{-1} (cam_size x 3)
                    let mut temp = DMatrix::<f64>::zeros(*cam_size, 3);
                    for i in 0..*cam_size {
                        for j in 0..3 {
                            let mut sum = 0.0;
                            for k in 0..3 {
                                sum += h_cp[(i, k)] * hpp_inv[(k, j)];
                            }
                            temp[(i, j)] = sum;
                        }
                    }

                    // Then: contribution = temp * H_cp^T (cam_size x cam_size)
                    for i in 0..*cam_size {
                        for j in 0..*cam_size {
                            let mut sum = 0.0;
                            for k in 0..3 {
                                sum += temp[(i, k)] * h_cp[(j, k)];
                            }
                            s_ii[(i, j)] -= sum;
                        }
                    }
                }

                // Step 3: Invert S[i,i], with regularization if needed
                match s_ii.clone().try_inverse() {
                    Some(inv) => inv,
                    None => {
                        // Add regularization and retry
                        let trace = s_ii.trace();
                        let reg = (1e-6 * trace.abs() / *cam_size as f64).max(1e-8);
                        for i in 0..*cam_size {
                            s_ii[(i, i)] += reg;
                        }
                        s_ii.try_inverse()
                            .unwrap_or_else(|| DMatrix::identity(*cam_size, *cam_size))
                    }
                }
            })
            .collect();

        Ok(precond_blocks)
    }

    /// Solve S*x = b using Preconditioned Conjugate Gradients with a block preconditioner.
    /// Uses the optimized Schur operator with workspace buffers to minimize allocations.
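    ///
    /// The loop below follows the standard PCG recurrences, with M the block preconditioner:
    ///
    /// ```text
    /// alpha_k = (r_k^T z_k) / (p_k^T S p_k)
    /// x_{k+1} = x_k + alpha_k * p_k
    /// r_{k+1} = r_k - alpha_k * S p_k
    /// z_{k+1} = M^{-1} r_{k+1}
    /// beta_k  = (r_{k+1}^T z_{k+1}) / (r_k^T z_k)
    /// p_{k+1} = z_{k+1} + beta_k * p_k
    /// ```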
    fn solve_pcg_block(
        &self,
        b: &Mat<f64>,
        precond_blocks: &[DMatrix<f64>],
        workspace_lm: &mut [f64],
        workspace_cam: &mut [f64],
    ) -> LinAlgResult<Mat<f64>> {
        let cam_dof = b.nrows();
        let mut x = Mat::<f64>::zeros(cam_dof, 1);

        // r = b - S*x (x starts at 0, so r = b)
        let mut r = b.clone();

        // z = M^{-1} * r (block preconditioner)
        let mut z = self.apply_block_preconditioner(&r, precond_blocks)?;

        let mut p = z.clone();
        let mut rz_old = 0.0;
        for i in 0..cam_dof {
            rz_old += r[(i, 0)] * z[(i, 0)];
        }

        // Compute initial residual norm for relative convergence
        let b_norm: f64 = (0..cam_dof)
            .map(|i| b[(i, 0)] * b[(i, 0)])
            .sum::<f64>()
            .sqrt();
        let tol = self.cg_tolerance * b_norm.max(1.0);

        // Ap buffer (reused each iteration)
        let mut ap = Mat::<f64>::zeros(cam_dof, 1);

        for iter in 0..self.max_cg_iterations {
            // Ap = S * p (using fast operator with workspace buffers)
            // Reset ap to zeros
            for i in 0..cam_dof {
                ap[(i, 0)] = 0.0;
            }
            self.apply_schur_operator_fast(&p, &mut ap, workspace_lm, workspace_cam)?;

            // alpha = (r^T z) / (p^T Ap)
            let mut p_ap = 0.0;
            for i in 0..cam_dof {
                p_ap += p[(i, 0)] * ap[(i, 0)];
            }

            if p_ap.abs() < 1e-20 {
                tracing::debug!("PCG: p^T*A*p near zero at iteration {}", iter);
                break;
            }

            let alpha = rz_old / p_ap;

            // x = x + alpha * p
            for i in 0..cam_dof {
                x[(i, 0)] += alpha * p[(i, 0)];
            }

            // r = r - alpha * Ap
            for i in 0..cam_dof {
                r[(i, 0)] -= alpha * ap[(i, 0)];
            }

            // Check convergence
            let r_norm: f64 = (0..cam_dof)
                .map(|i| r[(i, 0)] * r[(i, 0)])
                .sum::<f64>()
                .sqrt();

            if r_norm < tol {
                tracing::debug!(
                    "PCG converged in {} iterations (residual={:.2e})",
                    iter + 1,
                    r_norm
                );
                break;
            }

            // z = M^{-1} * r (block preconditioner)
            z = self.apply_block_preconditioner(&r, precond_blocks)?;

            // beta = (r_{k+1}^T z_{k+1}) / (r_k^T z_k)
            let mut rz_new = 0.0;
            for i in 0..cam_dof {
                rz_new += r[(i, 0)] * z[(i, 0)];
            }

            if rz_old.abs() < 1e-30 {
                break;
            }

            let beta = rz_new / rz_old;

            // p = z + beta * p
            for i in 0..cam_dof {
                p[(i, 0)] = z[(i, 0)] + beta * p[(i, 0)];
            }

            rz_old = rz_new;
        }

        Ok(x)
    }

    /// Extract 3x3 diagonal blocks from H_pp and invert them with numerical robustness
    ///
    /// The block inversions run in parallel, since large problems can have hundreds of
    /// thousands of landmark blocks. Each block's condition number is checked and
    /// regularization is applied as needed.
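    ///
    /// Regularization policy (mirrors the thresholds defined in the body):
    /// - min eigenvalue < 1e-12: add (1 + max_ev) * 1e-6 to the diagonal, then invert
    /// - condition number > 1e10: add max_ev * 1e-6 to the diagonal, then invert
    /// - otherwise: plain inversion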
    fn invert_landmark_blocks(&mut self, hessian: &SparseColMat<usize, f64>) -> LinAlgResult<()> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;

        let symbolic = hessian.symbolic();

        // Step 1: Extract all 3x3 blocks (sequential - requires sparse matrix access)
        let blocks: Vec<(usize, Matrix3<f64>)> = structure
            .landmark_blocks
            .iter()
            .enumerate()
            .map(|(i, (_, start_col, _))| {
                let mut block = Matrix3::<f64>::zeros();

                for local_col in 0..3 {
                    let global_col = start_col + local_col;
                    let row_indices = symbolic.row_idx_of_col_raw(global_col);
                    let col_values = hessian.val_of_col(global_col);

                    for (idx, &row) in row_indices.iter().enumerate() {
                        if row >= *start_col && row < start_col + 3 {
                            let local_row = row - start_col;
                            block[(local_row, local_col)] = col_values[idx];
                        }
                    }
                }

                (i, block)
            })
            .collect();

        // Step 2: Invert all blocks in parallel
        // Thresholds for numerical robustness
        const CONDITION_THRESHOLD: f64 = 1e10;
        const MIN_EIGENVALUE_THRESHOLD: f64 = 1e-12;
        const REGULARIZATION_SCALE: f64 = 1e-6;

        let results: Vec<Result<Matrix3<f64>, (usize, String)>> = blocks
            .par_iter()
            .map(|(i, block)| {
                // Check conditioning and apply regularization if needed
                let eigenvalues = block.symmetric_eigenvalues();
                let min_ev = eigenvalues.min();
                let max_ev = eigenvalues.max();

                if min_ev < MIN_EIGENVALUE_THRESHOLD {
                    // Severely ill-conditioned: add strong regularization
                    let reg = REGULARIZATION_SCALE + max_ev * REGULARIZATION_SCALE;
                    let regularized = block + Matrix3::identity() * reg;
                    regularized.try_inverse().ok_or_else(|| {
                        (
                            *i,
                            format!("singular even with regularization (min_ev={:.2e})", min_ev),
                        )
                    })
                } else if max_ev / min_ev > CONDITION_THRESHOLD {
                    // Ill-conditioned: add moderate regularization
                    let extra_reg = max_ev * REGULARIZATION_SCALE;
                    let regularized = block + Matrix3::identity() * extra_reg;
                    regularized.try_inverse().ok_or_else(|| {
                        (
                            *i,
                            format!("ill-conditioned (cond={:.2e})", max_ev / min_ev),
                        )
                    })
                } else {
                    // Well-conditioned: standard inversion
                    block
                        .try_inverse()
                        .ok_or_else(|| (*i, "singular".to_string()))
                }
            })
            .collect();

        // Step 3: Collect results and check for errors
        self.landmark_block_inverses.clear();
        self.landmark_block_inverses.reserve(results.len());

        for result in results {
            match result {
                Ok(inv) => self.landmark_block_inverses.push(inv),
                Err((i, msg)) => {
                    return Err(LinAlgError::SingularMatrix(format!(
                        "Landmark block {} {}",
                        i, msg
                    )));
                }
            }
        }

        Ok(())
    }

    /// Build camera->landmark visibility index from H_cp structure
    ///
    /// This scans the Hessian to find which landmarks each camera observes,
    /// enabling O(observations) preconditioner computation instead of O(cameras * landmarks).
    fn build_visibility_index(&mut self, hessian: &SparseColMat<usize, f64>) -> LinAlgResult<()> {
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;

        let symbolic = hessian.symbolic();
        let (cam_start, cam_end) = structure.camera_col_range();
        let num_cameras = structure.camera_blocks.len();

        // Build a map from global camera row -> camera block index
        let mut cam_row_to_block: HashMap<usize, usize> = HashMap::new();
        for (cam_idx, (_, start_col, size)) in structure.camera_blocks.iter().enumerate() {
            for offset in 0..*size {
                cam_row_to_block.insert(start_col + offset, cam_idx);
            }
        }

        // Initialize visibility: one vec per camera
        let mut visibility: Vec<Vec<usize>> = vec![Vec::new(); num_cameras];

        // Scan landmark columns to find camera connections
        for (lm_block_idx, (_, lm_col_start, _)) in structure.landmark_blocks.iter().enumerate() {
            // Check the first column of this landmark block (all 3 columns share the same row pattern)
            let global_col = *lm_col_start;
            if global_col >= hessian.ncols() {
                continue;
            }

            let row_indices = symbolic.row_idx_of_col_raw(global_col);

            // Find which cameras observe this landmark
            for &row in row_indices {
                if row >= cam_start
                    && row < cam_end
                    && let Some(&cam_idx) = cam_row_to_block.get(&row)
                {
                    // Only add if not already present (avoid duplicates)
                    if visibility[cam_idx].last() != Some(&lm_block_idx) {
                        visibility[cam_idx].push(lm_block_idx);
                    }
                }
            }
        }

        self.camera_to_landmark_visibility = visibility;
        Ok(())
    }

    /// Internal solve using the already-cached Hessian and gradient.
    /// This avoids rebuilding the Hessian, which would discard the damping applied by
    /// `solve_augmented_equation`.
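    ///
    /// Pipeline: invert the H_pp blocks, build the visibility index, reduce the RHS,
    /// run PCG on the camera system, then back-substitute for the landmarks.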
    fn solve_with_cached_hessian(&mut self) -> LinAlgResult<Mat<f64>> {
        let hessian = self
            .hessian
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Hessian not cached".into()))?
            .clone();
        let gradient = self
            .gradient
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Gradient not cached".into()))?
            .clone();

        // Extract structure info
        let structure = self
            .block_structure
            .as_ref()
            .ok_or_else(|| LinAlgError::InvalidInput("Block structure not initialized".into()))?;
        let cam_dof = structure.camera_dof;
        let lm_dof = structure.landmark_dof;
        let (cam_start, _cam_end) = structure.camera_col_range();
        let (lm_start, _lm_end) = structure.landmark_col_range();

        // Invert landmark blocks
        self.invert_landmark_blocks(&hessian)?;

        // Build visibility index for efficient preconditioner computation
        self.build_visibility_index(&hessian)?;

        // Extract reduced RHS: g_c - H_cp * H_pp^{-1} * g_p
        let mut g_reduced = Mat::<f64>::zeros(cam_dof, 1);
        for i in 0..cam_dof {
            g_reduced[(i, 0)] = gradient[(cam_start + i, 0)];
        }

        let mut g_lm = Mat::<f64>::zeros(lm_dof, 1);
        for i in 0..lm_dof {
            g_lm[(i, 0)] = gradient[(lm_start + i, 0)];
        }

        let mut temp = Mat::<f64>::zeros(lm_dof, 1);
        self.apply_landmark_inverse(&g_lm, &mut temp)?;

        let correction = self.extract_camera_landmark_mvp(&hessian, &temp)?;
        for i in 0..cam_dof {
            g_reduced[(i, 0)] -= correction[(i, 0)];
        }

        // Initialize workspace buffers if needed
        self.init_workspaces();

        // Solve S*δc = g_reduced using PCG with the appropriate preconditioner
        let precond_blocks = match self.preconditioner_type {
            SchurPreconditioner::SchurJacobi => {
                // True Schur-Jacobi: diagonal blocks of S (Ceres-style, best convergence)
                self.compute_schur_jacobi_preconditioner()?
            }
            SchurPreconditioner::BlockDiagonal => {
                // Block diagonal of H_cc only (faster to compute, worse convergence)
                self.compute_block_preconditioner()?
            }
            SchurPreconditioner::None => {
                // Identity preconditioner (for debugging)
                let structure = self.block_structure.as_ref().ok_or_else(|| {
                    LinAlgError::InvalidInput("Block structure not initialized".into())
                })?;
                structure
                    .camera_blocks
                    .iter()
                    .map(|(_, _, size)| DMatrix::identity(*size, *size))
                    .collect()
            }
        };

        // Use workspace buffers for PCG iterations
        let mut workspace_lm = std::mem::take(&mut self.workspace_lm);
        let mut workspace_cam = std::mem::take(&mut self.workspace_cam);

        let delta_cam = self.solve_pcg_block(
            &g_reduced,
            &precond_blocks,
            &mut workspace_lm,
            &mut workspace_cam,
        )?;

        // Restore workspace buffers
        self.workspace_lm = workspace_lm;
        self.workspace_cam = workspace_cam;

        // Back-substitute for landmarks
        let hcp_t_delta_cam = self.extract_camera_landmark_transpose_mvp(&hessian, &delta_cam)?;

        let mut rhs_lm = Mat::<f64>::zeros(lm_dof, 1);
        for i in 0..lm_dof {
            rhs_lm[(i, 0)] = g_lm[(i, 0)] - hcp_t_delta_cam[(i, 0)];
        }

        let mut delta_lm = Mat::<f64>::zeros(lm_dof, 1);
        self.apply_landmark_inverse(&rhs_lm, &mut delta_lm)?;

        // Combine camera and landmark updates
        let total_dof = cam_dof + lm_dof;
        let mut delta = Mat::<f64>::zeros(total_dof, 1);

        for i in 0..cam_dof {
            delta[(cam_start + i, 0)] = delta_cam[(i, 0)];
        }
        for i in 0..lm_dof {
            delta[(lm_start + i, 0)] = delta_lm[(i, 0)];
        }

        Ok(delta)
    }
}

impl Default for IterativeSchurSolver {
    fn default() -> Self {
        Self::new()
    }
}

impl StructuredSparseLinearSolver for IterativeSchurSolver {
    fn initialize_structure(
        &mut self,
        variables: &HashMap<String, VariableEnum>,
        variable_index_map: &HashMap<String, usize>,
    ) -> LinAlgResult<()> {
        let mut structure = SchurBlockStructure::new();

        for (name, variable) in variables {
            let manifold_type = variable.manifold_type();
            let start_col = *variable_index_map.get(name).ok_or_else(|| {
                LinAlgError::InvalidInput(format!("Variable {} not in index map", name))
            })?;
            let size = variable.get_size();

            if self.ordering.should_eliminate(name, &manifold_type, size) {
                if size != 3 {
                    return Err(LinAlgError::InvalidInput(format!(
                        "Landmark {} has DOF {}, expected 3",
                        name, size
                    )));
                }
                structure
                    .landmark_blocks
                    .push((name.clone(), start_col, size));
                structure.landmark_dof += size;
                structure.num_landmarks += 1;
            } else {
                structure
                    .camera_blocks
                    .push((name.clone(), start_col, size));
                structure.camera_dof += size;
            }
        }

        structure.camera_blocks.sort_by_key(|(_, col, _)| *col);
        structure.landmark_blocks.sort_by_key(|(_, col, _)| *col);

        if structure.camera_blocks.is_empty() {
            return Err(LinAlgError::InvalidInput(
                "No camera variables found".into(),
            ));
        }
        if structure.landmark_blocks.is_empty() {
            return Err(LinAlgError::InvalidInput(
                "No landmark variables found".into(),
            ));
        }

        self.block_structure = Some(structure);
        Ok(())
    }

    fn solve_normal_equation(
        &mut self,
        residuals: &Mat<f64>,
        jacobian: &SparseColMat<usize, f64>,
    ) -> LinAlgResult<Mat<f64>> {
        // Build H = J^T * J, g = -J^T * r
        let jt = jacobian
            .transpose()
            .to_col_major()
            .map_err(|e| LinAlgError::MatrixConversion(format!("Transpose failed: {:?}", e)))?;
        let hessian = jt.mul(jacobian);
        let jtr = jacobian.transpose().mul(residuals);
        let mut gradient = Mat::<f64>::zeros(jtr.nrows(), 1);
        for i in 0..jtr.nrows() {
            gradient[(i, 0)] = -jtr[(i, 0)];
        }

        self.hessian = Some(hessian);
        self.gradient = Some(gradient);

        // Solve using the cached Hessian
        self.solve_with_cached_hessian()
    }

    fn solve_augmented_equation(
        &mut self,
        residuals: &Mat<f64>,
        jacobian: &SparseColMat<usize, f64>,
        lambda: f64,
    ) -> LinAlgResult<Mat<f64>> {
        // Build H = J^T * J + λI
        let jt = jacobian
            .transpose()
            .to_col_major()
            .map_err(|e| LinAlgError::MatrixConversion(format!("Transpose failed: {:?}", e)))?;
        let jtr = jt.mul(residuals);
        let mut hessian = jacobian
            .transpose()
            .to_col_major()
            .map_err(|e| LinAlgError::MatrixConversion(format!("Transpose failed: {:?}", e)))?
            .mul(jacobian);

        // Add damping to diagonal
        let n = hessian.ncols();
        let symbolic = hessian.symbolic();
        let mut triplets = Vec::new();

        for col in 0..n {
            let row_indices = symbolic.row_idx_of_col_raw(col);
            let col_values = hessian.val_of_col(col);

            for (idx, &row) in row_indices.iter().enumerate() {
                triplets.push(Triplet::new(row, col, col_values[idx]));
            }

            // Add lambda to diagonal
            triplets.push(Triplet::new(col, col, lambda));
        }

        hessian = SparseColMat::try_new_from_triplets(n, n, &triplets).map_err(|e| {
            LinAlgError::InvalidInput(format!("Failed to build damped Hessian: {:?}", e))
        })?;

        let mut gradient = Mat::<f64>::zeros(jtr.nrows(), 1);
        for i in 0..jtr.nrows() {
            gradient[(i, 0)] = -jtr[(i, 0)];
        }

        self.hessian = Some(hessian);
        self.gradient = Some(gradient.clone());

        // Solve using the cached damped Hessian (don't call solve_normal_equation,
        // which would rebuild the Hessian without damping)
        self.solve_with_cached_hessian()
    }

    fn get_hessian(&self) -> Option<&SparseColMat<usize, f64>> {
        self.hessian.as_ref()
    }

    fn get_gradient(&self) -> Option<&Mat<f64>> {
        self.gradient.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_iterative_schur_creation() {
        let solver = IterativeSchurSolver::new();
        // Default: 500 max iterations, 1e-9 tolerance, Schur-Jacobi preconditioner
        assert_eq!(solver.max_cg_iterations, 500);
        assert_eq!(solver.cg_tolerance, 1e-9);
        assert_eq!(solver.preconditioner_type, SchurPreconditioner::SchurJacobi);
    }

    #[test]
    fn test_with_custom_params() {
        let solver = IterativeSchurSolver::with_cg_params(100, 1e-8);
        assert_eq!(solver.max_cg_iterations, 100);
        assert_eq!(solver.cg_tolerance, 1e-8);
        // Should still use the default Schur-Jacobi preconditioner
        assert_eq!(solver.preconditioner_type, SchurPreconditioner::SchurJacobi);
    }

    #[test]
    fn test_with_full_config() {
        let solver =
            IterativeSchurSolver::with_config(200, 1e-10, SchurPreconditioner::BlockDiagonal);
        assert_eq!(solver.max_cg_iterations, 200);
        assert_eq!(solver.cg_tolerance, 1e-10);
        assert_eq!(
            solver.preconditioner_type,
            SchurPreconditioner::BlockDiagonal
        );
    }
}