1pub mod dense;
2pub mod sparse;
3pub mod utils;
4
5use crate::core::problem::VariableEnum;
6use faer::Mat;
7use std::{
8 collections::HashMap,
9 fmt::{self, Debug, Display, Formatter},
10};
11use thiserror::Error;
12
13pub use sparse::{
14 IterativeSchurSolver, SchurBlockStructure, SchurOrdering, SchurPreconditioner, SchurVariant,
15 SparseCholeskySolver, SparseQRSolver, SparseSchurComplementSolver,
16};
17
18pub use dense::{DenseCholeskySolver, DenseQRSolver};
19
20pub use crate::linearizer::cpu::{DenseMode, LinearizationMode, SparseMode};
21
/// Jacobian mode selector with two options, [`JacobianMode::Sparse`] and
/// [`JacobianMode::Dense`].
///
/// `Sparse` is the value produced by [`Default::default`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum JacobianMode {
    /// Sparse mode; this is the default variant.
    #[default]
    Sparse,
    /// Dense mode.
    Dense,
}
43
/// Identifier for one of the available linear solver kinds.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it. [`LinearSolverType::SparseCholesky`] is
/// the default.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum LinearSolverType {
    /// Sparse Cholesky; the default variant.
    #[default]
    SparseCholesky,
    /// Sparse QR.
    SparseQR,
    /// Sparse Schur complement.
    SparseSchurComplement,
    /// Dense Cholesky.
    DenseCholesky,
    /// Dense QR.
    DenseQR,
}

impl Display for LinearSolverType {
    /// Writes the human-readable name of the solver kind (e.g. `"Sparse QR"`).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Resolve the label first so there is a single write site.
        let label = match self {
            Self::SparseCholesky => "Sparse Cholesky",
            Self::SparseQR => "Sparse QR",
            Self::SparseSchurComplement => "Sparse Schur Complement",
            Self::DenseCholesky => "Dense Cholesky",
            Self::DenseQR => "Dense QR",
        };
        f.write_str(label)
    }
}
70
71#[derive(Debug, Clone, Error)]
77pub enum LinAlgError {
78 #[error("Matrix factorization failed: {0}")]
80 FactorizationFailed(String),
81
82 #[error("Singular matrix detected: {0}")]
84 SingularMatrix(String),
85
86 #[error("Failed to create sparse matrix: {0}")]
88 SparseMatrixCreation(String),
89
90 #[error("Matrix conversion failed: {0}")]
92 MatrixConversion(String),
93
94 #[error("Invalid input: {0}")]
96 InvalidInput(String),
97
98 #[error("Invalid solver state: {0}")]
100 InvalidState(String),
101}
102
103pub type LinAlgResult<T> = Result<T, LinAlgError>;
105
106pub trait StructureAware {
117 fn initialize_structure(
119 &mut self,
120 variables: &HashMap<String, VariableEnum>,
121 variable_index_map: &HashMap<String, usize>,
122 ) -> LinAlgResult<()>;
123}
124
125pub trait LinearSolver<M: LinearizationMode> {
144 fn solve_normal_equation(
146 &mut self,
147 residuals: &Mat<f64>,
148 jacobian: &M::Jacobian,
149 ) -> LinAlgResult<Mat<f64>>;
150
151 fn solve_augmented_equation(
153 &mut self,
154 residuals: &Mat<f64>,
155 jacobian: &M::Jacobian,
156 lambda: f64,
157 ) -> LinAlgResult<Mat<f64>>;
158
159 fn get_hessian(&self) -> Option<&M::Hessian>;
161
162 fn get_gradient(&self) -> Option<&Mat<f64>>;
164
165 fn compute_covariance_matrix(&mut self) -> Option<&Mat<f64>> {
171 None
172 }
173
174 fn get_covariance_matrix(&self) -> Option<&Mat<f64>> {
178 None
179 }
180}
181
182pub(crate) fn extract_variable_covariances(
191 full_covariance: &Mat<f64>,
192 variables: &HashMap<String, VariableEnum>,
193 variable_index_map: &HashMap<String, usize>,
194) -> HashMap<String, Mat<f64>> {
195 let mut result = HashMap::new();
196
197 for (var_name, var) in variables {
198 if let Some(&start_idx) = variable_index_map.get(var_name) {
199 let dim = var.get_size();
200 let mut var_cov = Mat::zeros(dim, dim);
201
202 for i in 0..dim {
203 for j in 0..dim {
204 var_cov[(i, j)] = full_covariance[(start_idx + i, start_idx + j)];
205 }
206 }
207
208 result.insert(var_name.clone(), var_cov);
209 }
210 }
211
212 result
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::problem::VariableEnum;
    use crate::core::variable::Variable;
    use crate::error::ErrorLogging;
    use apex_manifolds::rn::Rn;
    use faer::Mat;
    use nalgebra::dvector;
    use std::collections::HashMap;

    /// Absolute tolerance used when comparing extracted covariance entries.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that the `Display` rendering of `error` mentions `detail`.
    fn assert_display_contains(error: &LinAlgError, detail: &str) {
        let rendered = error.to_string();
        assert!(rendered.contains(detail));
    }

    /// Builds a one-dimensional `Rn` variable holding `val`.
    fn make_rn_var(val: f64) -> VariableEnum {
        let manifold = Rn::new(dvector![val]);
        VariableEnum::Rn(Variable::new(manifold))
    }

    // ---------------------------------------------------------------------
    // JacobianMode
    // ---------------------------------------------------------------------

    #[test]
    fn test_jacobian_mode_default_is_sparse() {
        let mode = JacobianMode::default();
        assert_eq!(mode, JacobianMode::Sparse);
    }

    #[test]
    fn test_jacobian_mode_equality() {
        let (sparse, dense) = (JacobianMode::Sparse, JacobianMode::Dense);
        assert_eq!(sparse, JacobianMode::Sparse);
        assert_eq!(dense, JacobianMode::Dense);
        assert_ne!(sparse, dense);
    }

    // ---------------------------------------------------------------------
    // LinearSolverType
    // ---------------------------------------------------------------------

    #[test]
    fn test_linear_solver_type_default_is_cholesky() {
        let solver = LinearSolverType::default();
        assert_eq!(solver, LinearSolverType::SparseCholesky);
    }

    #[test]
    fn test_linear_solver_type_display_all_variants() {
        let expected = [
            (LinearSolverType::SparseCholesky, "Sparse Cholesky"),
            (LinearSolverType::SparseQR, "Sparse QR"),
            (
                LinearSolverType::SparseSchurComplement,
                "Sparse Schur Complement",
            ),
            (LinearSolverType::DenseCholesky, "Dense Cholesky"),
            (LinearSolverType::DenseQR, "Dense QR"),
        ];

        for (solver, label) in expected {
            assert_eq!(solver.to_string(), label);
        }
    }

    // ---------------------------------------------------------------------
    // LinAlgError: Display output carries the detail string
    // ---------------------------------------------------------------------

    #[test]
    fn test_lin_alg_error_factorization_failed_display() {
        let error = LinAlgError::FactorizationFailed("non-positive definite".into());
        assert_display_contains(&error, "non-positive definite");
    }

    #[test]
    fn test_lin_alg_error_singular_matrix_display() {
        let error = LinAlgError::SingularMatrix("rank deficient".into());
        assert_display_contains(&error, "rank deficient");
    }

    #[test]
    fn test_lin_alg_error_sparse_matrix_creation_display() {
        let error = LinAlgError::SparseMatrixCreation("bad triplets".into());
        assert_display_contains(&error, "bad triplets");
    }

    #[test]
    fn test_lin_alg_error_matrix_conversion_display() {
        let error = LinAlgError::MatrixConversion("size mismatch".into());
        assert_display_contains(&error, "size mismatch");
    }

    #[test]
    fn test_lin_alg_error_invalid_input_display() {
        let error = LinAlgError::InvalidInput("null jacobian".into());
        assert_display_contains(&error, "null jacobian");
    }

    #[test]
    fn test_lin_alg_error_invalid_state_display() {
        let error = LinAlgError::InvalidState("not initialized".into());
        assert_display_contains(&error, "not initialized");
    }

    // ---------------------------------------------------------------------
    // LinAlgError: logging helpers hand the error back
    // ---------------------------------------------------------------------

    #[test]
    fn test_lin_alg_error_log_returns_self() {
        let error = LinAlgError::InvalidInput("log_test".into());
        let logged = error.log();
        assert!(logged.to_string().contains("log_test"));
    }

    #[test]
    fn test_lin_alg_error_log_with_source_returns_self() {
        let error = LinAlgError::SingularMatrix("source_test".into());
        let cause = std::io::Error::other("src");
        let logged = error.log_with_source(cause);
        assert!(logged.to_string().contains("source_test"));
    }

    // ---------------------------------------------------------------------
    // LinAlgResult
    // ---------------------------------------------------------------------

    #[test]
    fn test_lin_alg_result_ok() {
        let outcome: LinAlgResult<i32> = Ok(7);
        assert!(matches!(outcome, Ok(7)));
    }

    #[test]
    fn test_lin_alg_result_err() {
        let outcome: LinAlgResult<i32> = Err(LinAlgError::InvalidInput("oops".into()));
        assert!(outcome.is_err());
    }

    // ---------------------------------------------------------------------
    // extract_variable_covariances
    // ---------------------------------------------------------------------

    #[test]
    fn test_extract_variable_covariances_single_variable() {
        let variables = HashMap::from([("x".to_string(), make_rn_var(1.0))]);
        let variable_index_map = HashMap::from([("x".to_string(), 0usize)]);
        let full_cov = Mat::from_fn(1, 1, |_, _| 2.5);

        let result = extract_variable_covariances(&full_cov, &variables, &variable_index_map);

        assert_eq!(result.len(), 1);
        assert!((result["x"][(0, 0)] - 2.5).abs() < TOLERANCE);
    }

    #[test]
    fn test_extract_variable_covariances_two_variables() {
        let variables = HashMap::from([
            ("a".to_string(), make_rn_var(1.0)),
            ("b".to_string(), make_rn_var(2.0)),
        ]);
        let variable_index_map = HashMap::from([
            ("a".to_string(), 0usize),
            ("b".to_string(), 1usize),
        ]);

        // Diagonal covariance: entry k belongs to the variable at index k.
        let diagonal = [3.0, 7.0];
        let full_cov = Mat::from_fn(2, 2, |row, col| {
            if row == col {
                diagonal[row]
            } else {
                0.0
            }
        });

        let result = extract_variable_covariances(&full_cov, &variables, &variable_index_map);

        assert_eq!(result.len(), 2);
        assert!((result["a"][(0, 0)] - 3.0).abs() < TOLERANCE);
        assert!((result["b"][(0, 0)] - 7.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_extract_variable_covariances_empty_variables() {
        let variables: HashMap<String, VariableEnum> = HashMap::new();
        let variable_index_map: HashMap<String, usize> = HashMap::new();
        let full_cov = Mat::zeros(0, 0);

        let result = extract_variable_covariances(&full_cov, &variables, &variable_index_map);

        assert!(result.is_empty());
    }

    #[test]
    fn test_extract_variable_covariances_var_not_in_index_map() {
        // "x" exists as a variable but has no index entry, so it is skipped.
        let variables = HashMap::from([("x".to_string(), make_rn_var(1.0))]);
        let variable_index_map: HashMap<String, usize> = HashMap::new();
        let full_cov = Mat::from_fn(1, 1, |_, _| 5.0);

        let result = extract_variable_covariances(&full_cov, &variables, &variable_index_map);

        assert!(result.is_empty());
    }
}