1use std::collections::{HashMap, HashSet};
2use std::sync::{Arc, Mutex};
3
4use faer::sparse::{Argsort, Pair, SparseColMat, SymbolicSparseColMat};
5use faer_ext::IntoFaer;
6use nalgebra as na;
7use rayon::prelude::*;
8
9use crate::manifold::Manifold;
10use crate::parameter_block::ParameterBlock;
11use crate::{factors, loss_functions, residual_block};
12
/// Identifier returned by `Problem::add_residual_block`.
///
/// Ids come from a counter that is only ever incremented, so an id is not
/// reused after its block is removed.
type ResidualBlockId = usize;
14
/// A least-squares problem: a set of residual blocks plus per-variable
/// configuration (fixed indices, bounds, manifolds) that is applied when
/// parameter blocks are created from initial values.
pub struct Problem {
    /// Sum of `dim_residual` over all registered residual blocks; also the row at
    /// which the next added residual block starts.
    pub total_residual_dimension: usize,
    // Next id handed out by `add_residual_block`; never decremented.
    residual_id_count: usize,
    // Registered residual blocks keyed by their id.
    residual_blocks: HashMap<ResidualBlockId, residual_block::ResidualBlock>,
    /// Per variable name, the indices to hold fixed; copied into each
    /// `ParameterBlock::fixed_variables` by `initialize_parameter_blocks`.
    pub fixed_variable_indexes: HashMap<String, HashSet<usize>>,
    /// Per variable name, a map from index to `(lower, upper)` bound; copied into
    /// each `ParameterBlock::variable_bounds` by `initialize_parameter_blocks`.
    pub variable_bounds: HashMap<String, HashMap<usize, (f64, f64)>>,
    /// Per variable name, the manifold attached to the parameter block by
    /// `initialize_parameter_blocks`.
    pub variable_manifold: HashMap<String, Arc<dyn Manifold + Sync + Send>>,
}
23impl Default for Problem {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
/// Precomputed sparsity pattern of the Jacobian, produced by
/// `Problem::build_symbolic_structure` and reused by
/// `Problem::compute_residual_and_jacobian`.
pub struct SymbolicStructure {
    // Column-major symbolic pattern built by `SymbolicSparseColMat::try_new_from_indices`.
    pattern: SymbolicSparseColMat<usize>,
    // Argsort returned alongside `pattern`; `SparseColMat::new_from_argsort` uses it to
    // place values supplied in the same order as the indices the pattern was built from.
    order: Argsort<usize>,
}
33
/// Scalar type of one structural non-zero entry of the sparse Jacobian.
type JacobianValue = f64;
35
36impl Problem {
37 pub fn new() -> Problem {
38 Problem {
39 total_residual_dimension: 0,
40 residual_id_count: 0,
41 residual_blocks: HashMap::new(),
42 fixed_variable_indexes: HashMap::new(),
43 variable_bounds: HashMap::new(),
44 variable_manifold: HashMap::new(),
45 }
46 }
47
48 pub fn build_symbolic_structure(
49 &self,
50 parameter_blocks: &HashMap<String, ParameterBlock>,
51 total_variable_dimension: usize,
52 variable_name_to_col_idx_dict: &HashMap<String, usize>,
53 ) -> SymbolicStructure {
54 let mut indices = Vec::<Pair<usize, usize>>::new();
55
56 self.residual_blocks.iter().for_each(|(_, residual_block)| {
57 let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new();
58 let mut count_variable_local_idx: usize = 0;
59 for var_key in &residual_block.variable_key_list {
60 if let Some(param) = parameter_blocks.get(var_key) {
61 variable_local_idx_size_list
62 .push((count_variable_local_idx, param.tangent_size()));
63 count_variable_local_idx += param.tangent_size();
64 };
65 }
66 for (i, var_key) in residual_block.variable_key_list.iter().enumerate() {
67 if let Some(variable_global_idx) = variable_name_to_col_idx_dict.get(var_key) {
68 let (_, var_size) = variable_local_idx_size_list[i];
69 for row_idx in 0..residual_block.dim_residual {
70 let mut current_var_col_offset = 0;
71 for col_idx in 0..var_size {
72 if let Some(param) = parameter_blocks.get(var_key)
73 && param.manifold.is_none()
74 && param.fixed_variables.contains(&col_idx)
75 {
76 continue;
77 }
78 let global_row_idx = residual_block.residual_row_start_idx + row_idx;
79 let global_col_idx = variable_global_idx + current_var_col_offset;
80 indices.push(Pair::new(global_row_idx, global_col_idx));
81 current_var_col_offset += 1;
82 }
83 }
84 }
85 }
86 });
87 let start = std::time::Instant::now();
88 let (s, o) = SymbolicSparseColMat::try_new_from_indices(
89 self.total_residual_dimension,
90 total_variable_dimension,
91 &indices,
92 )
93 .unwrap();
94 log::trace!("Built symbolic matrix: {:?}", start.elapsed());
95 SymbolicStructure {
96 pattern: s,
97 order: o,
98 }
99 }
100
101 pub fn get_variable_name_to_col_idx_dict(
102 &self,
103 parameter_blocks: &HashMap<String, ParameterBlock>,
104 ) -> HashMap<String, usize> {
105 let mut count_col_idx = 0;
106 let mut variable_name_to_col_idx_dict = HashMap::new();
107 parameter_blocks
108 .iter()
109 .for_each(|(param_name, param_block)| {
110 variable_name_to_col_idx_dict.insert(param_name.to_owned(), count_col_idx);
111 let effective_size = if param_block.manifold.is_some() {
112 param_block.tangent_size()
113 } else {
114 param_block.tangent_size() - param_block.fixed_variables.len()
115 };
116 count_col_idx += effective_size;
117 });
118 variable_name_to_col_idx_dict
119 }
120 pub fn add_residual_block(
121 &mut self,
122 dim_residual: usize,
123 variable_key_size_list: &[&str],
124 factor: Box<dyn factors::FactorImpl + Send>,
125 loss_func: Option<Box<dyn loss_functions::Loss + Send>>,
126 ) -> ResidualBlockId {
127 self.residual_blocks.insert(
128 self.residual_id_count,
129 residual_block::ResidualBlock::new(
130 self.residual_id_count,
131 dim_residual,
132 self.total_residual_dimension,
133 variable_key_size_list,
134 factor,
135 loss_func,
136 ),
137 );
138 let block_id = self.residual_id_count;
139 self.residual_id_count += 1;
140
141 self.total_residual_dimension += dim_residual;
142
143 block_id
144 }
145 pub fn remove_residual_block(
146 &mut self,
147 block_id: ResidualBlockId,
148 ) -> Option<residual_block::ResidualBlock> {
149 if let Some(residual_block) = self.residual_blocks.remove(&block_id) {
150 self.total_residual_dimension -= residual_block.dim_residual;
151 Some(residual_block)
152 } else {
153 None
154 }
155 }
156 pub fn fix_variable(&mut self, var_to_fix: &str, idx: usize) {
157 if let Some(var_mut) = self.fixed_variable_indexes.get_mut(var_to_fix) {
158 var_mut.insert(idx);
159 } else {
160 self.fixed_variable_indexes
161 .insert(var_to_fix.to_owned(), HashSet::from([idx]));
162 }
163 }
164 pub fn unfix_variable(&mut self, var_to_unfix: &str) {
165 self.fixed_variable_indexes.remove(var_to_unfix);
166 }
167 pub fn set_variable_bounds(
168 &mut self,
169 var_to_bound: &str,
170 idx: usize,
171 lower_bound: f64,
172 upper_bound: f64,
173 ) {
174 if lower_bound > upper_bound {
175 log::error!("lower bound is larger than upper bound");
176 } else if let Some(var_mut) = self.variable_bounds.get_mut(var_to_bound) {
177 var_mut.insert(idx, (lower_bound, upper_bound));
178 } else {
179 self.variable_bounds.insert(
180 var_to_bound.to_owned(),
181 HashMap::from([(idx, (lower_bound, upper_bound))]),
182 );
183 }
184 }
185 pub fn set_variable_manifold(
186 &mut self,
187 var_name: &str,
188 manifold: Arc<dyn Manifold + Sync + Send>,
189 ) {
190 self.variable_manifold
191 .insert(var_name.to_string(), manifold);
192 }
193 pub fn remove_variable_bounds(&mut self, var_to_unbound: &str) {
194 self.variable_bounds.remove(var_to_unbound);
195 }
196 pub fn initialize_parameter_blocks(
197 &self,
198 initial_values: &HashMap<String, na::DVector<f64>>,
199 ) -> HashMap<String, ParameterBlock> {
200 let parameter_blocks: HashMap<String, ParameterBlock> = initial_values
201 .iter()
202 .map(|(k, v)| {
203 let mut p_block = ParameterBlock::from_vec(v.clone());
204 if let Some(indexes) = self.fixed_variable_indexes.get(k) {
205 p_block.fixed_variables = indexes.clone();
206 }
207 if let Some(bounds) = self.variable_bounds.get(k) {
208 p_block.variable_bounds = bounds.clone();
209 }
210 if let Some(manifold) = self.variable_manifold.get(k) {
211 p_block.manifold = Some(manifold.clone())
212 }
213
214 (k.to_owned(), p_block)
215 })
216 .collect();
217 parameter_blocks
218 }
219
220 pub fn compute_residuals(
221 &self,
222 parameter_blocks: &HashMap<String, ParameterBlock>,
223 with_loss_fn: bool,
224 ) -> faer::Mat<f64> {
225 let total_residual = Arc::new(Mutex::new(na::DVector::<f64>::zeros(
226 self.total_residual_dimension,
227 )));
228 self.residual_blocks
229 .par_iter()
230 .for_each(|(_, residual_block)| {
231 self.compute_residual_impl(
232 residual_block,
233 parameter_blocks,
234 &total_residual,
235 with_loss_fn,
236 )
237 });
238 let total_residual = Arc::try_unwrap(total_residual)
239 .unwrap()
240 .into_inner()
241 .unwrap();
242
243 total_residual.view_range(.., ..).into_faer().to_owned()
244 }
245
246 pub fn compute_residual_and_jacobian(
247 &self,
248 parameter_blocks: &HashMap<String, ParameterBlock>,
249 variable_name_to_col_idx_dict: &HashMap<String, usize>,
250 symbolic_structure: &SymbolicStructure,
251 ) -> (faer::Mat<f64>, SparseColMat<usize, f64>) {
252 let total_residual = Arc::new(Mutex::new(na::DVector::<f64>::zeros(
254 self.total_residual_dimension,
255 )));
256
257 let jacobian_lists: Vec<JacobianValue> = self
258 .residual_blocks
259 .par_iter()
260 .map(|(_, residual_block)| {
261 self.compute_residual_and_jacobian_impl(
262 residual_block,
263 parameter_blocks,
264 variable_name_to_col_idx_dict,
265 &total_residual,
266 )
267 })
268 .flatten()
269 .collect();
270
271 let total_residual = Arc::try_unwrap(total_residual)
272 .unwrap()
273 .into_inner()
274 .unwrap();
275
276 let residual_faer = total_residual.view_range(.., ..).into_faer().to_owned();
277 let jacobian_faer = SparseColMat::new_from_argsort(
278 symbolic_structure.pattern.clone(),
279 &symbolic_structure.order,
280 jacobian_lists.as_slice(),
281 )
282 .unwrap();
283 (residual_faer, jacobian_faer)
284 }
285
286 fn compute_residual_impl(
287 &self,
288 residual_block: &crate::ResidualBlock,
289 parameter_blocks: &HashMap<String, ParameterBlock>,
290 total_residual: &Arc<Mutex<na::DVector<f64>>>,
291 with_loss_fn: bool,
292 ) {
293 let mut params = Vec::new();
294 for var_key in &residual_block.variable_key_list {
295 if let Some(param) = parameter_blocks.get(var_key) {
296 params.push(param);
297 };
298 }
299 let res = residual_block.residual(¶ms, with_loss_fn);
300
301 {
302 let mut total_residual = total_residual.lock().unwrap();
303 total_residual
304 .rows_mut(
305 residual_block.residual_row_start_idx,
306 residual_block.dim_residual,
307 )
308 .copy_from(&res);
309 }
310 }
311
312 fn compute_residual_and_jacobian_impl(
313 &self,
314 residual_block: &crate::ResidualBlock,
315 parameter_blocks: &HashMap<String, ParameterBlock>,
316 variable_name_to_col_idx_dict: &HashMap<String, usize>,
317 total_residual: &Arc<Mutex<na::DVector<f64>>>,
318 ) -> Vec<JacobianValue> {
319 let mut params = Vec::new();
320 let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new();
321 let mut count_variable_local_idx: usize = 0;
322 for var_key in &residual_block.variable_key_list {
323 if let Some(param) = parameter_blocks.get(var_key) {
324 params.push(param);
325 variable_local_idx_size_list.push((count_variable_local_idx, param.tangent_size()));
326 count_variable_local_idx += param.tangent_size();
327 };
328 }
329 let (res, jac) = residual_block.residual_and_jacobian(¶ms);
330 {
331 let mut total_residual = total_residual.lock().unwrap();
332 total_residual
333 .rows_mut(
334 residual_block.residual_row_start_idx,
335 residual_block.dim_residual,
336 )
337 .copy_from(&res);
338 }
339
340 let mut local_jacobian_list = Vec::new();
341
342 for (i, var_key) in residual_block.variable_key_list.iter().enumerate() {
343 if variable_name_to_col_idx_dict.contains_key(var_key) {
344 let (variable_local_idx, var_size) = variable_local_idx_size_list[i];
345 let variable_jac = jac.view((0, variable_local_idx), (jac.shape().0, var_size));
346 let param = ¶ms[i];
347 for row_idx in 0..jac.shape().0 {
348 for col_idx in 0..var_size {
349 if param.manifold.is_none() && param.fixed_variables.contains(&col_idx) {
350 continue;
351 }
352 let j_value = variable_jac[(row_idx, col_idx)];
353 if j_value.is_finite() {
354 local_jacobian_list.push(j_value);
355 } else {
356 log::warn!(
357 "Non-finite Jacobian value detected at residual block {}, variable {}, row {}, col {}. Setting to 0.0",
358 residual_block.residual_block_id,
359 var_key,
360 row_idx,
361 col_idx
362 );
363 local_jacobian_list.push(0.0);
364 }
365 }
366 }
367 } else {
368 panic!(
369 "Missing key {} in variable-to-column-index mapping",
370 var_key
371 );
372 }
373 }
374
375 local_jacobian_list
376 }
377}