// apex_solver/linalg/mod.rs

1pub mod dense;
2pub mod sparse;
3pub mod utils;
4
5use crate::core::problem::VariableEnum;
6use faer::Mat;
7use std::{
8    collections::HashMap,
9    fmt::{self, Debug, Display, Formatter},
10};
11use thiserror::Error;
12
13pub use sparse::{
14    IterativeSchurSolver, SchurBlockStructure, SchurOrdering, SchurPreconditioner, SchurVariant,
15    SparseCholeskySolver, SparseQRSolver, SparseSchurComplementSolver,
16};
17
18pub use dense::{DenseCholeskySolver, DenseQRSolver};
19
20pub use crate::linearizer::cpu::{DenseMode, LinearizationMode, SparseMode};
21
22// ============================================================================
23// Jacobian mode selection
24// ============================================================================
25
/// Selects the Jacobian assembly strategy used by the Problem.
///
/// The mode is chosen when a [`Problem`](crate::core::problem::Problem) is built:
/// - `Problem::new(JacobianMode::Sparse)` — sparse assembly (the default; best for
///   large-scale problems)
/// - `Problem::new(JacobianMode::Dense)` — dense assembly (best for small-to-medium
///   problems, roughly below 500 DOF)
/// - `Problem::default()` — same as passing `JacobianMode::Sparse`
///
/// The optimizer reads this value and dispatches to the matching assembly path.
/// The concrete factorization is chosen separately through [`LinearSolverType`],
/// which lists both sparse and dense algorithms.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum JacobianMode {
    /// Jacobian stored as a `SparseColMat`, built from a symbolic structure.
    /// Preferred for large problems. This is the default mode.
    #[default]
    Sparse,
    /// Jacobian stored as a dense `Mat<f64>`.
    /// Preferred for small-to-medium problems (< ~500 DOF).
    Dense,
}
43
44// ============================================================================
45// Linear solver type selection
46// ============================================================================
47
/// Names the linear solver algorithm to use.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it. The default is
/// [`LinearSolverType::SparseCholesky`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum LinearSolverType {
    /// Sparse Cholesky (the default).
    #[default]
    SparseCholesky,
    /// Sparse QR.
    SparseQR,
    /// Sparse Schur complement.
    SparseSchurComplement,
    /// Dense Cholesky.
    DenseCholesky,
    /// Dense QR.
    DenseQR,
}
58
59impl Display for LinearSolverType {
60    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
61        match self {
62            LinearSolverType::SparseCholesky => write!(f, "Sparse Cholesky"),
63            LinearSolverType::SparseQR => write!(f, "Sparse QR"),
64            LinearSolverType::SparseSchurComplement => write!(f, "Sparse Schur Complement"),
65            LinearSolverType::DenseCholesky => write!(f, "Dense Cholesky"),
66            LinearSolverType::DenseQR => write!(f, "Dense QR"),
67        }
68    }
69}
70
71// ============================================================================
72// Error types
73// ============================================================================
74
75/// Linear algebra specific error types for apex-solver
76#[derive(Debug, Clone, Error)]
77pub enum LinAlgError {
78    /// Matrix factorization failed (Cholesky, QR, etc.)
79    #[error("Matrix factorization failed: {0}")]
80    FactorizationFailed(String),
81
82    /// Singular or near-singular matrix detected
83    #[error("Singular matrix detected: {0}")]
84    SingularMatrix(String),
85
86    /// Failed to create sparse matrix from triplets
87    #[error("Failed to create sparse matrix: {0}")]
88    SparseMatrixCreation(String),
89
90    /// Matrix format conversion failed
91    #[error("Matrix conversion failed: {0}")]
92    MatrixConversion(String),
93
94    /// Invalid input provided to linear solver
95    #[error("Invalid input: {0}")]
96    InvalidInput(String),
97
98    /// Solver in invalid state (e.g., initialized incorrectly)
99    #[error("Invalid solver state: {0}")]
100    InvalidState(String),
101}
102
103/// Result type for linear algebra operations
104pub type LinAlgResult<T> = Result<T, LinAlgError>;
105
106// ============================================================================
107// StructureAware
108// ============================================================================
109
110/// For solvers that need variable structure information before solving.
111///
112/// Implemented by Schur complement solvers, which must partition variables
113/// into camera and landmark blocks before performing any linear solves.
114/// Call [`initialize_structure`](StructureAware::initialize_structure) once
115/// during solver setup, before passing the solver to an optimizer.
116pub trait StructureAware {
117    /// Initialize the solver's block structure from problem variables.
118    fn initialize_structure(
119        &mut self,
120        variables: &HashMap<String, VariableEnum>,
121        variable_index_map: &HashMap<String, usize>,
122    ) -> LinAlgResult<()>;
123}
124
// ============================================================================
// LinearizationMode — defined in `linearizer::cpu`; it is re-exported by the
// `pub use` list of this module, so nothing is declared in this section.
// ============================================================================
128
129// ============================================================================
130// LinearSolver trait (unified solver interface, generic over LinearizationMode)
131// ============================================================================
132
133/// Unified linear solver interface parameterized by [`LinearizationMode`].
134///
135/// This is the single trait implemented by all linear solvers. When `M` is
136/// a concrete type (e.g., `SparseMode`), this trait is object-safe and can
137/// be used as `dyn LinearSolver<SparseMode>` or `dyn LinearSolver<DenseMode>`.
138///
139/// - Sparse solvers (`SparseCholeskySolver`, `SparseQRSolver`, `SchurSolverAdapter`)
140///   implement `LinearSolver<SparseMode>`.
141/// - Dense solvers (`DenseCholeskySolver`, `DenseQRSolver`)
142///   implement `LinearSolver<DenseMode>`.
143pub trait LinearSolver<M: LinearizationMode> {
144    /// Solve the normal equations: (J^T · J) · dx = −J^T · r
145    fn solve_normal_equation(
146        &mut self,
147        residuals: &Mat<f64>,
148        jacobian: &M::Jacobian,
149    ) -> LinAlgResult<Mat<f64>>;
150
151    /// Solve the augmented equations: (J^T · J + λI) · dx = −J^T · r
152    fn solve_augmented_equation(
153        &mut self,
154        residuals: &Mat<f64>,
155        jacobian: &M::Jacobian,
156        lambda: f64,
157    ) -> LinAlgResult<Mat<f64>>;
158
159    /// Get the cached Hessian matrix (J^T · J) from the last solve
160    fn get_hessian(&self) -> Option<&M::Hessian>;
161
162    /// Get the cached gradient vector (J^T · r) from the last solve
163    fn get_gradient(&self) -> Option<&Mat<f64>>;
164
165    /// Compute the covariance matrix (H^{-1}) by inverting the cached Hessian.
166    ///
167    /// Returns `None` for solvers that do not support covariance estimation
168    /// (e.g., QR solvers, Schur complement solvers). Only Cholesky-based
169    /// solvers provide a real implementation.
170    fn compute_covariance_matrix(&mut self) -> Option<&Mat<f64>> {
171        None
172    }
173
174    /// Get the cached covariance matrix (H^{-1}) computed from the Hessian.
175    ///
176    /// Returns `None` if covariance has not been computed or is not supported.
177    fn get_covariance_matrix(&self) -> Option<&Mat<f64>> {
178        None
179    }
180}
181
182// ============================================================================
183// Utility functions
184// ============================================================================
185
186/// Extract per-variable covariance blocks from the full covariance matrix.
187///
188/// Given the full covariance matrix H^{-1} (inverse of information matrix),
189/// this function extracts the diagonal blocks corresponding to each individual variable.
190pub(crate) fn extract_variable_covariances(
191    full_covariance: &Mat<f64>,
192    variables: &HashMap<String, VariableEnum>,
193    variable_index_map: &HashMap<String, usize>,
194) -> HashMap<String, Mat<f64>> {
195    let mut result = HashMap::new();
196
197    for (var_name, var) in variables {
198        if let Some(&start_idx) = variable_index_map.get(var_name) {
199            let dim = var.get_size();
200            let mut var_cov = Mat::zeros(dim, dim);
201
202            for i in 0..dim {
203                for j in 0..dim {
204                    var_cov[(i, j)] = full_covariance[(start_idx + i, start_idx + j)];
205                }
206            }
207
208            result.insert(var_name.clone(), var_cov);
209        }
210    }
211
212    result
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::problem::VariableEnum;
    use crate::core::variable::Variable;
    use crate::error::ErrorLogging;
    use apex_manifolds::rn::Rn;
    use faer::Mat;
    use nalgebra::dvector;
    use std::collections::HashMap;

    /// Absolute tolerance used when comparing extracted covariance entries.
    const TOLERANCE: f64 = 1e-12;

    /// Assert that the `Display` output of `err` contains `needle`.
    fn assert_message_contains(err: &LinAlgError, needle: &str) {
        let rendered = err.to_string();
        assert!(
            rendered.contains(needle),
            "`{rendered}` should contain `{needle}`"
        );
    }

    // -------------------------------------------------------------------------
    // JacobianMode
    // -------------------------------------------------------------------------

    #[test]
    fn test_jacobian_mode_default_is_sparse() {
        let default_mode = JacobianMode::default();
        assert_eq!(default_mode, JacobianMode::Sparse);
    }

    #[test]
    fn test_jacobian_mode_equality() {
        // Each mode equals a copy of itself ...
        for mode in [JacobianMode::Sparse, JacobianMode::Dense] {
            let copy = mode;
            assert_eq!(mode, copy);
        }
        // ... and the two modes differ from each other.
        assert_ne!(JacobianMode::Sparse, JacobianMode::Dense);
    }

    // -------------------------------------------------------------------------
    // LinearSolverType Display + Default
    // -------------------------------------------------------------------------

    #[test]
    fn test_linear_solver_type_default_is_cholesky() {
        let default_solver = LinearSolverType::default();
        assert_eq!(default_solver, LinearSolverType::SparseCholesky);
    }

    #[test]
    fn test_linear_solver_type_display_all_variants() {
        let expected_labels = [
            (LinearSolverType::SparseCholesky, "Sparse Cholesky"),
            (LinearSolverType::SparseQR, "Sparse QR"),
            (
                LinearSolverType::SparseSchurComplement,
                "Sparse Schur Complement",
            ),
            (LinearSolverType::DenseCholesky, "Dense Cholesky"),
            (LinearSolverType::DenseQR, "Dense QR"),
        ];

        for (solver, label) in expected_labels {
            assert_eq!(solver.to_string(), label);
        }
    }

    // -------------------------------------------------------------------------
    // LinAlgError Display — one test per variant
    // -------------------------------------------------------------------------

    #[test]
    fn test_lin_alg_error_factorization_failed_display() {
        let err = LinAlgError::FactorizationFailed(String::from("non-positive definite"));
        assert_message_contains(&err, "non-positive definite");
    }

    #[test]
    fn test_lin_alg_error_singular_matrix_display() {
        let err = LinAlgError::SingularMatrix(String::from("rank deficient"));
        assert_message_contains(&err, "rank deficient");
    }

    #[test]
    fn test_lin_alg_error_sparse_matrix_creation_display() {
        let err = LinAlgError::SparseMatrixCreation(String::from("bad triplets"));
        assert_message_contains(&err, "bad triplets");
    }

    #[test]
    fn test_lin_alg_error_matrix_conversion_display() {
        let err = LinAlgError::MatrixConversion(String::from("size mismatch"));
        assert_message_contains(&err, "size mismatch");
    }

    #[test]
    fn test_lin_alg_error_invalid_input_display() {
        let err = LinAlgError::InvalidInput(String::from("null jacobian"));
        assert_message_contains(&err, "null jacobian");
    }

    #[test]
    fn test_lin_alg_error_invalid_state_display() {
        let err = LinAlgError::InvalidState(String::from("not initialized"));
        assert_message_contains(&err, "not initialized");
    }

    // -------------------------------------------------------------------------
    // log() / log_with_source() hand the error back
    // -------------------------------------------------------------------------

    #[test]
    fn test_lin_alg_error_log_returns_self() {
        let err = LinAlgError::InvalidInput(String::from("log_test"));
        let logged = err.log();
        assert!(logged.to_string().contains("log_test"));
    }

    #[test]
    fn test_lin_alg_error_log_with_source_returns_self() {
        let err = LinAlgError::SingularMatrix(String::from("source_test"));
        let cause = std::io::Error::other("src");
        let logged = err.log_with_source(cause);
        assert!(logged.to_string().contains("source_test"));
    }

    // -------------------------------------------------------------------------
    // LinAlgResult type alias
    // -------------------------------------------------------------------------

    #[test]
    fn test_lin_alg_result_ok() {
        let outcome: LinAlgResult<i32> = Ok(7);
        assert!(matches!(outcome, Ok(7)));
    }

    #[test]
    fn test_lin_alg_result_err() {
        let outcome: LinAlgResult<i32> = Err(LinAlgError::InvalidInput(String::from("oops")));
        assert!(outcome.is_err());
    }

    // -------------------------------------------------------------------------
    // extract_variable_covariances
    // -------------------------------------------------------------------------

    /// Wrap a single scalar in a one-dimensional `Rn` variable.
    fn make_rn_var(val: f64) -> VariableEnum {
        VariableEnum::Rn(Variable::new(Rn::new(dvector![val])))
    }

    /// Build the variable map from `(name, scalar value)` pairs.
    fn scalar_variables(entries: &[(&str, f64)]) -> HashMap<String, VariableEnum> {
        entries
            .iter()
            .map(|&(name, value)| (name.to_owned(), make_rn_var(value)))
            .collect()
    }

    /// Build the name → index map from `(name, index)` pairs.
    fn index_map(entries: &[(&str, usize)]) -> HashMap<String, usize> {
        entries
            .iter()
            .map(|&(name, offset)| (name.to_owned(), offset))
            .collect()
    }

    #[test]
    fn test_extract_variable_covariances_single_variable() {
        let variables = scalar_variables(&[("x", 1.0)]);
        let variable_index_map = index_map(&[("x", 0)]);

        // 1×1 covariance matrix
        let full_cov = Mat::from_fn(1, 1, |_, _| 2.5);
        let blocks = extract_variable_covariances(&full_cov, &variables, &variable_index_map);

        assert_eq!(blocks.len(), 1);
        assert!((blocks["x"][(0, 0)] - 2.5).abs() < TOLERANCE);
    }

    #[test]
    fn test_extract_variable_covariances_two_variables() {
        let variables = scalar_variables(&[("a", 1.0), ("b", 2.0)]);
        let variable_index_map = index_map(&[("a", 0), ("b", 1)]);

        // 2×2 diagonal covariance: a=3.0, b=7.0
        let diagonal = [3.0, 7.0];
        let full_cov = Mat::from_fn(2, 2, |i, j| if i == j { diagonal[i] } else { 0.0 });
        let blocks = extract_variable_covariances(&full_cov, &variables, &variable_index_map);

        assert_eq!(blocks.len(), 2);
        assert!((blocks["a"][(0, 0)] - 3.0).abs() < TOLERANCE);
        assert!((blocks["b"][(0, 0)] - 7.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_extract_variable_covariances_empty_variables() {
        let variables: HashMap<String, VariableEnum> = HashMap::new();
        let variable_index_map: HashMap<String, usize> = HashMap::new();
        let full_cov = Mat::zeros(0, 0);

        let blocks = extract_variable_covariances(&full_cov, &variables, &variable_index_map);
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_extract_variable_covariances_var_not_in_index_map() {
        // The variable exists but has no entry in the index map, so it is skipped.
        let variables = scalar_variables(&[("x", 1.0)]);
        let variable_index_map: HashMap<String, usize> = HashMap::new();

        let full_cov = Mat::from_fn(1, 1, |_, _| 5.0);
        let blocks = extract_variable_covariances(&full_cov, &variables, &variable_index_map);
        assert!(blocks.is_empty());
    }
}