1use super::{CooMatrix, CscMatrix, LinSolParams, LinSolTrait, Ordering, Scaling, StatsLinSol, Sym};
2use crate::constants::*;
3use crate::StrError;
4use russell_lab::{vec_copy, Stopwatch, Vector};
5
/// Opaque handle to the solver object that lives on the C side
///
/// No field of this type is ever read or written from Rust; the type only exists
/// so that the `solver_klu_*` functions can exchange a typed `*mut InterfaceKLU`.
#[repr(C)]
struct InterfaceKLU {
    // Zero-sized placeholder: the real contents are defined by the C code.
    _data: [u8; 0],
    // The raw pointer and `PhantomPinned` inside the marker remove the automatic
    // `Send`, `Sync`, and `Unpin` implementations from this type.
    _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
14
// Opts `InterfaceKLU` back in to `Send`, which its raw-pointer marker otherwise removes.
// NOTE(review): soundness relies on the C solver holding no thread-affine state -- confirm.
unsafe impl Send for InterfaceKLU {}
19
// Allows a `SolverKLU` to be moved to another thread even though it stores a raw pointer.
// NOTE(review): presumes the C-side object is only reachable through its owning `SolverKLU` -- confirm.
unsafe impl Send for SolverKLU {}
24
// Functions implemented by the C wrapper around KLU.
//
// Every `i32` result is a status code: the Rust callers compare it against
// `SUCCESSFUL_EXIT` and translate any other value with `handle_klu_error_code`.
extern "C" {
    /// Creates the C-side solver object; the caller treats a null result as an allocation failure.
    fn solver_klu_new() -> *mut InterfaceKLU;

    /// Releases a solver object previously returned by `solver_klu_new`.
    fn solver_klu_drop(solver: *mut InterfaceKLU);

    /// One-time setup of the solver for a given sparsity pattern.
    ///
    /// `ordering` and `scaling` receive the codes produced by `klu_ordering` and `klu_scaling`;
    /// `col_pointers` and `row_indices` point into the CSC matrix held by the caller.
    fn solver_klu_initialize(
        solver: *mut InterfaceKLU,
        ordering: i32,
        scaling: i32,
        ndim: i32,
        col_pointers: *const i32,
        row_indices: *const i32,
    ) -> i32;

    /// Numeric factorization of the CSC matrix.
    ///
    /// `effective_ordering`, `effective_scaling`, and `cond_estimate` are output locations:
    /// the caller passes pointers to its own fields and reads them after the call.
    /// NOTE(review): `compute_cond` presumably controls whether `cond_estimate` is filled in -- confirm in the C code.
    fn solver_klu_factorize(
        solver: *mut InterfaceKLU,
        effective_ordering: *mut i32,
        effective_scaling: *mut i32,
        cond_estimate: *mut f64,
        compute_cond: CcBool,
        col_pointers: *const i32,
        row_indices: *const i32,
        values: *const f64,
    ) -> i32;

    /// Solves the linear system using the current factorization.
    ///
    /// The caller copies the right-hand side into the buffer before the call; as the
    /// parameter name suggests, the buffer is expected to hold the solution afterwards.
    fn solver_klu_solve(solver: *mut InterfaceKLU, ndim: i32, in_rhs_out_x: *mut f64) -> i32;
}
48
/// Wraps the KLU solver exposed by the C code behind the `LinSolTrait` interface
///
/// The struct owns the C-side solver object (released in `Drop`) together with the
/// CSC copy of the matrix and the bookkeeping needed to validate repeated factorizations.
pub struct SolverKLU {
    /// Pointer to the C-side solver object returned by `solver_klu_new`
    solver: *mut InterfaceKLU,

    /// CSC version of the COO matrix; created by the first `factorize` call and updated in place afterwards
    csc: Option<CscMatrix>,

    /// True once `solver_klu_initialize` has succeeded
    initialized: bool,

    /// True once `solver_klu_factorize` has succeeded
    factorized: bool,

    /// Symmetric flag of the matrix given to the first `factorize` call (checked on later calls)
    initialized_sym: Sym,

    /// Dimension (number of rows) of the matrix given to the first `factorize` call (checked on later calls)
    initialized_ndim: usize,

    /// Number of non-zeros of the matrix given to the first `factorize` call (checked on later calls)
    initialized_nnz: usize,

    /// Ordering code written by the C code during factorization (-1 before any factorization)
    effective_ordering: i32,

    /// Scaling code written by the C code during factorization (-1 before any factorization)
    effective_scaling: i32,

    /// Condition-number estimate written by the C code during factorization (0.0 before any factorization)
    cond_estimate: f64,

    /// Stopwatch used to time the calls into the C code
    stopwatch: Stopwatch,

    /// Time spent in the initialization call, in nanoseconds
    time_initialize_ns: u128,

    /// Time spent in the most recent factorization call, in nanoseconds
    time_factorize_ns: u128,

    /// Time spent in the most recent solve call, in nanoseconds
    time_solve_ns: u128,
}
95
96impl Drop for SolverKLU {
97 fn drop(&mut self) {
99 unsafe {
100 solver_klu_drop(self.solver);
101 }
102 }
103}
104
105impl SolverKLU {
106 pub fn new() -> Result<Self, StrError> {
108 unsafe {
109 let solver = solver_klu_new();
110 if solver.is_null() {
111 return Err("c-code failed to allocate the KLU solver");
112 }
113 Ok(SolverKLU {
114 solver,
115 csc: None,
116 initialized: false,
117 factorized: false,
118 initialized_sym: Sym::No,
119 initialized_ndim: 0,
120 initialized_nnz: 0,
121 effective_ordering: -1,
122 effective_scaling: -1,
123 cond_estimate: 0.0,
124 stopwatch: Stopwatch::new(),
125 time_initialize_ns: 0,
126 time_factorize_ns: 0,
127 time_solve_ns: 0,
128 })
129 }
130 }
131}
132
133impl LinSolTrait for SolverKLU {
134 fn factorize(&mut self, mat: &CooMatrix, params: Option<LinSolParams>) -> Result<(), StrError> {
153 if self.initialized {
155 if mat.symmetric != self.initialized_sym {
156 return Err("subsequent factorizations must use the same matrix (symmetric differs)");
157 }
158 if mat.nrow != self.initialized_ndim {
159 return Err("subsequent factorizations must use the same matrix (ndim differs)");
160 }
161 if mat.nnz != self.initialized_nnz {
162 return Err("subsequent factorizations must use the same matrix (nnz differs)");
163 }
164 self.csc.as_mut().unwrap().update_from_coo(mat)?;
165 } else {
166 if mat.nrow != mat.ncol {
167 return Err("the matrix must be square");
168 }
169 if mat.nnz < 1 {
170 return Err("the COO matrix must have at least one non-zero value");
171 }
172 if mat.symmetric == Sym::YesLower || mat.symmetric == Sym::YesUpper {
173 return Err("KLU requires Sym::YesFull for symmetric matrices");
174 }
175 self.initialized_sym = mat.symmetric;
176 self.initialized_ndim = mat.nrow;
177 self.initialized_nnz = mat.nnz;
178 self.csc = Some(CscMatrix::from_coo(mat)?);
179 }
180 let csc = self.csc.as_ref().unwrap();
181
182 let par = if let Some(p) = params { p } else { LinSolParams::new() };
184
185 let ordering = klu_ordering(par.ordering);
187 let scaling = klu_scaling(par.scaling);
188
189 let compute_cond = if par.compute_condition_numbers { 1 } else { 0 };
191
192 let ndim = to_i32(csc.nrow);
194
195 if !self.initialized {
197 self.stopwatch.reset();
198 unsafe {
199 let status = solver_klu_initialize(
200 self.solver,
201 ordering,
202 scaling,
203 ndim,
204 csc.col_pointers.as_ptr(),
205 csc.row_indices.as_ptr(),
206 );
207 if status != SUCCESSFUL_EXIT {
208 return Err(handle_klu_error_code(status));
209 }
210 }
211 self.time_initialize_ns = self.stopwatch.stop();
212 self.initialized = true;
213 }
214
215 self.stopwatch.reset();
217 unsafe {
218 let status = solver_klu_factorize(
219 self.solver,
220 &mut self.effective_ordering,
221 &mut self.effective_scaling,
222 &mut self.cond_estimate,
223 compute_cond,
224 csc.col_pointers.as_ptr(),
225 csc.row_indices.as_ptr(),
226 csc.values.as_ptr(),
227 );
228 if status != SUCCESSFUL_EXIT {
229 return Err(handle_klu_error_code(status));
230 }
231 }
232 self.time_factorize_ns = self.stopwatch.stop();
233
234 self.factorized = true;
236 Ok(())
237 }
238
239 fn solve(&mut self, x: &mut Vector, rhs: &Vector, _verbose: bool) -> Result<(), StrError> {
260 if !self.factorized {
262 return Err("the function factorize must be called before solve");
263 }
264
265 if x.dim() != self.initialized_ndim {
267 return Err("the dimension of the vector of unknown values x is incorrect");
268 }
269 if rhs.dim() != self.initialized_ndim {
270 return Err("the dimension of the right-hand side vector is incorrect");
271 }
272
273 let ndim = to_i32(self.initialized_ndim);
275 vec_copy(x, rhs).unwrap();
276 self.stopwatch.reset();
277 unsafe {
278 let status = solver_klu_solve(self.solver, ndim, x.as_mut_data().as_mut_ptr());
279 if status != SUCCESSFUL_EXIT {
280 return Err(handle_klu_error_code(status));
281 }
282 }
283 self.time_solve_ns = self.stopwatch.stop();
284
285 Ok(())
287 }
288
289 fn update_stats(&self, stats: &mut StatsLinSol) {
291 stats.main.solver = if cfg!(feature = "local_suitesparse") {
292 "KLU-local".to_string()
293 } else {
294 "KLU".to_string()
295 };
296 stats.output.umfpack_rcond_estimate = self.cond_estimate;
297 stats.output.effective_ordering = match self.effective_ordering {
298 KLU_ORDERING_AMD => "Amd".to_string(),
299 KLU_ORDERING_COLAMD => "Colamd".to_string(),
300 _ => "Unknown".to_string(),
301 };
302 stats.output.effective_scaling = match self.effective_scaling {
303 KLU_SCALE_NONE => "No".to_string(),
304 KLU_SCALE_SUM => "Sum".to_string(),
305 KLU_SCALE_MAX => "Max".to_string(),
306 _ => "Unknown".to_string(),
307 };
308 stats.time_nanoseconds.initialize = self.time_initialize_ns;
309 stats.time_nanoseconds.factorize = self.time_factorize_ns;
310 stats.time_nanoseconds.solve = self.time_solve_ns;
311 }
312
313 fn get_ns_init(&self) -> u128 {
315 self.time_initialize_ns
316 }
317
318 fn get_ns_fact(&self) -> u128 {
320 self.time_factorize_ns
321 }
322
323 fn get_ns_solve(&self) -> u128 {
325 self.time_solve_ns
326 }
327}
328
329pub(crate) const KLU_ORDERING_AUTO: i32 = -10; pub(crate) const KLU_ORDERING_AMD: i32 = 0; pub(crate) const KLU_ORDERING_COLAMD: i32 = 1; pub(crate) const KLU_SCALE_AUTO: i32 = -10; pub(crate) const KLU_SCALE_NONE: i32 = 0; pub(crate) const KLU_SCALE_SUM: i32 = 1; pub(crate) const KLU_SCALE_MAX: i32 = 2; pub(crate) fn klu_ordering(ordering: Ordering) -> i32 {
340 match ordering {
341 Ordering::Amd => KLU_ORDERING_AMD,
342 Ordering::Amf => KLU_ORDERING_AUTO,
343 Ordering::Auto => KLU_ORDERING_AUTO,
344 Ordering::Best => KLU_ORDERING_AUTO,
345 Ordering::Cholmod => KLU_ORDERING_AUTO,
346 Ordering::Colamd => KLU_ORDERING_COLAMD,
347 Ordering::Metis => KLU_ORDERING_AUTO,
348 Ordering::No => KLU_ORDERING_AUTO,
349 Ordering::Pord => KLU_ORDERING_AUTO,
350 Ordering::Qamd => KLU_ORDERING_AUTO,
351 Ordering::Scotch => KLU_ORDERING_AUTO,
352 }
353}
354
355pub(crate) fn klu_scaling(scaling: Scaling) -> i32 {
357 match scaling {
358 Scaling::Auto => KLU_SCALE_AUTO,
359 Scaling::Column => KLU_ORDERING_AUTO,
360 Scaling::Diagonal => KLU_ORDERING_AUTO,
361 Scaling::Max => KLU_SCALE_MAX,
362 Scaling::No => KLU_SCALE_NONE,
363 Scaling::RowCol => KLU_ORDERING_AUTO,
364 Scaling::RowColIter => KLU_ORDERING_AUTO,
365 Scaling::RowColRig => KLU_ORDERING_AUTO,
366 Scaling::Sum => KLU_SCALE_SUM,
367 }
368}
369
370pub(crate) fn handle_klu_error_code(err: i32) -> StrError {
372 match err {
373 -9 => "klu_analyze failed",
374 -8 => "klu_factor failed",
375 -7 => "klu_condest failed",
376 ERROR_NULL_POINTER => "KLU failed due to NULL POINTER error",
377 ERROR_MALLOC => "KLU failed due to MALLOC error",
378 ERROR_VERSION => "KLU failed due to VERSION error",
379 ERROR_NOT_AVAILABLE => "KLU is not AVAILABLE",
380 ERROR_NEED_INITIALIZATION => "KLU failed because INITIALIZATION is needed",
381 ERROR_NEED_FACTORIZATION => "KLU failed because FACTORIZATION is needed",
382 ERROR_ALREADY_INITIALIZED => "KLU failed because INITIALIZATION has been completed already",
383 _ => "Error: unknown error returned by c-code (KLU)",
384 }
385}
386
// Unit tests for the KLU wrapper: construction, input validation, factorization,
// solution, and the option/error-code translation helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CooMatrix, Samples};
    use russell_lab::vec_approx_eq;

    #[test]
    fn new_and_drop_work() {
        // the solver is dropped (and the C object released) at the end of this scope
        let solver = SolverKLU::new().unwrap();
        assert!(!solver.factorized);
    }

    #[test]
    fn factorize_handles_errors() {
        let mut solver = SolverKLU::new().unwrap();
        assert!(!solver.factorized);

        // no entry has been put into this matrix, so the non-zero check is expected to fire
        let coo = CooMatrix::new(1, 1, 1, Sym::No).unwrap();
        assert_eq!(
            solver.factorize(&coo, None).err(),
            Some("the COO matrix must have at least one non-zero value")
        );

        // non-square matrix
        let (coo, _, _, _) = Samples::rectangular_1x7();
        assert_eq!(solver.factorize(&coo, None).err(), Some("the matrix must be square"));
        // triangular symmetric storage is rejected
        let (coo, _, _, _) = Samples::mkl_symmetric_5x5_lower(false, false);
        assert_eq!(
            solver.factorize(&coo, None).err(),
            Some("KLU requires Sym::YesFull for symmetric matrices")
        );

        // first successful factorization: fixes the symmetric flag, ndim, and nnz
        let mut coo = CooMatrix::new(2, 2, 2, Sym::No).unwrap();
        coo.put(0, 0, 1.0).unwrap();
        coo.put(1, 1, 2.0).unwrap();
        solver.factorize(&coo, None).unwrap();
        // from here on, each matrix differs from the first one in exactly one property:
        // different symmetric flag
        let mut coo = CooMatrix::new(2, 2, 2, Sym::YesFull).unwrap();
        coo.put(0, 0, 1.0).unwrap();
        coo.put(1, 1, 2.0).unwrap();
        assert_eq!(
            solver.factorize(&coo, None).err(),
            Some("subsequent factorizations must use the same matrix (symmetric differs)")
        );
        // different dimension
        let mut coo = CooMatrix::new(1, 1, 1, Sym::No).unwrap();
        coo.put(0, 0, 1.0).unwrap();
        assert_eq!(
            solver.factorize(&coo, None).err(),
            Some("subsequent factorizations must use the same matrix (ndim differs)")
        );
        // different number of non-zeros
        let mut coo = CooMatrix::new(2, 2, 1, Sym::No).unwrap();
        coo.put(0, 0, 1.0).unwrap();
        assert_eq!(
            solver.factorize(&coo, None).err(),
            Some("subsequent factorizations must use the same matrix (nnz differs)")
        );
    }

    #[test]
    fn factorize_works() {
        let mut solver = SolverKLU::new().unwrap();
        assert!(!solver.factorized);
        let (coo, _, _, _) = Samples::umfpack_unsymmetric_5x5();

        // Metis has no dedicated KLU code, so the "auto" code is sent to the C code
        let mut params = LinSolParams::new();
        params.ordering = Ordering::Metis;
        params.scaling = Scaling::Sum;

        solver.factorize(&coo, Some(params)).unwrap();
        assert!(solver.factorized);
        // the test expects AMD to be reported back as the effective ordering
        assert_eq!(solver.effective_ordering, KLU_ORDERING_AMD);
        assert_eq!(solver.effective_scaling, KLU_SCALE_SUM);

        // calling factorize again with the same matrix must also succeed
        solver.factorize(&coo, Some(params)).unwrap();
    }

    #[test]
    fn factorize_fails_on_singular_matrix() {
        let mut solver = SolverKLU::new().unwrap();
        // a zero on the diagonal makes this diagonal matrix singular
        let mut coo = CooMatrix::new(2, 2, 2, Sym::No).unwrap();
        coo.put(0, 0, 1.0).unwrap();
        coo.put(1, 1, 0.0).unwrap();
        assert_eq!(solver.factorize(&coo, None), Err("klu_factor failed"));
    }

    #[test]
    fn solve_handles_errors() {
        let mut coo = CooMatrix::new(2, 2, 2, Sym::No).unwrap();
        coo.put(0, 0, 123.0).unwrap();
        coo.put(1, 1, 456.0).unwrap();
        let mut solver = SolverKLU::new().unwrap();
        assert!(!solver.factorized);
        // solve before factorize
        let mut x = Vector::new(2);
        let rhs = Vector::new(2);
        assert_eq!(
            solver.solve(&mut x, &rhs, false),
            Err("the function factorize must be called before solve")
        );
        // x with the wrong dimension
        let mut x = Vector::new(1);
        solver.factorize(&coo, None).unwrap();
        assert_eq!(
            solver.solve(&mut x, &rhs, false),
            Err("the dimension of the vector of unknown values x is incorrect")
        );
        // rhs with the wrong dimension
        let mut x = Vector::new(2);
        let rhs = Vector::new(1);
        assert_eq!(
            solver.solve(&mut x, &rhs, false),
            Err("the dimension of the right-hand side vector is incorrect")
        );
    }

    #[test]
    fn solve_works() {
        let mut solver = SolverKLU::new().unwrap();
        let (coo, _, _, _) = Samples::umfpack_unsymmetric_5x5();
        let mut x = Vector::new(5);
        let rhs = Vector::from(&[8.0, 45.0, -3.0, 3.0, 19.0]);
        let x_correct = &[1.0, 2.0, 3.0, 4.0, 5.0];

        // Cholmod has no dedicated KLU code, so the "auto" code is sent to the C code
        let mut params = LinSolParams::new();
        params.ordering = Ordering::Cholmod;
        params.scaling = Scaling::Max;

        solver.factorize(&coo, Some(params)).unwrap();
        solver.solve(&mut x, &rhs, false).unwrap();
        vec_approx_eq(&x, x_correct, 1e-14);

        // solving again reuses the existing factorization
        let mut x_again = Vector::new(5);
        solver.solve(&mut x_again, &rhs, false).unwrap();
        vec_approx_eq(&x_again, x_correct, 1e-14);

        // the statistics report the effective options as strings
        let mut stats = StatsLinSol::new();
        solver.update_stats(&mut stats);
        assert_eq!(stats.output.effective_ordering, "Amd");
        assert_eq!(stats.output.effective_scaling, "Max");
    }

    #[test]
    fn solve_works_symmetric() {
        let mut solver = SolverKLU::new().unwrap();
        // symmetric matrix given with full storage
        let (coo, _, _, _) = Samples::mkl_symmetric_5x5_full();
        let mut x = Vector::new(5);
        let rhs = Vector::from(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let x_correct = &[-979.0 / 3.0, 983.0, 1961.0 / 12.0, 398.0, 123.0 / 2.0];

        let mut params = LinSolParams::new();
        params.ordering = Ordering::Colamd;
        params.scaling = Scaling::No;

        solver.factorize(&coo, Some(params)).unwrap();
        solver.solve(&mut x, &rhs, false).unwrap();
        vec_approx_eq(&x, x_correct, 1e-10);

        // solving again reuses the existing factorization
        let mut x_again = Vector::new(5);
        solver.solve(&mut x_again, &rhs, false).unwrap();
        vec_approx_eq(&x_again, x_correct, 1e-10);

        // the statistics report the effective options as strings
        let mut stats = StatsLinSol::new();
        solver.update_stats(&mut stats);
        assert_eq!(stats.output.effective_ordering, "Colamd");
        assert_eq!(stats.output.effective_scaling, "No");
    }

    #[test]
    fn ordering_and_scaling_works() {
        // every Ordering variant
        assert_eq!(klu_ordering(Ordering::Amd), KLU_ORDERING_AMD);
        assert_eq!(klu_ordering(Ordering::Amf), KLU_ORDERING_AUTO);
        assert_eq!(klu_ordering(Ordering::Auto), KLU_ORDERING_AUTO);
        assert_eq!(klu_ordering(Ordering::Best), KLU_ORDERING_AUTO);
        assert_eq!(klu_ordering(Ordering::Cholmod), KLU_ORDERING_AUTO);
        assert_eq!(klu_ordering(Ordering::Colamd), KLU_ORDERING_COLAMD);
        assert_eq!(klu_ordering(Ordering::Metis), KLU_ORDERING_AUTO);
        assert_eq!(klu_ordering(Ordering::No), KLU_ORDERING_AUTO);
        assert_eq!(klu_ordering(Ordering::Pord), KLU_ORDERING_AUTO);
        assert_eq!(klu_ordering(Ordering::Qamd), KLU_ORDERING_AUTO);
        assert_eq!(klu_ordering(Ordering::Scotch), KLU_ORDERING_AUTO);

        // every Scaling variant
        assert_eq!(klu_scaling(Scaling::Auto), KLU_SCALE_AUTO);
        assert_eq!(klu_scaling(Scaling::Column), KLU_SCALE_AUTO);
        assert_eq!(klu_scaling(Scaling::Diagonal), KLU_SCALE_AUTO);
        assert_eq!(klu_scaling(Scaling::Max), KLU_SCALE_MAX);
        assert_eq!(klu_scaling(Scaling::No), KLU_SCALE_NONE);
        assert_eq!(klu_scaling(Scaling::RowCol), KLU_SCALE_AUTO);
        assert_eq!(klu_scaling(Scaling::RowColIter), KLU_SCALE_AUTO);
        assert_eq!(klu_scaling(Scaling::RowColRig), KLU_SCALE_AUTO);
        assert_eq!(klu_scaling(Scaling::Sum), KLU_SCALE_SUM);
    }

    #[test]
    fn handle_klu_error_code_works() {
        let default = "Error: unknown error returned by c-code (KLU)";
        // codes specific to the KLU wrapper
        assert_eq!(handle_klu_error_code(-9), "klu_analyze failed");
        assert_eq!(handle_klu_error_code(-8), "klu_factor failed");
        assert_eq!(handle_klu_error_code(-7), "klu_condest failed");
        // codes taken from the shared constants
        assert_eq!(
            handle_klu_error_code(ERROR_NULL_POINTER),
            "KLU failed due to NULL POINTER error"
        );
        assert_eq!(handle_klu_error_code(ERROR_MALLOC), "KLU failed due to MALLOC error");
        assert_eq!(handle_klu_error_code(ERROR_VERSION), "KLU failed due to VERSION error");
        assert_eq!(handle_klu_error_code(ERROR_NOT_AVAILABLE), "KLU is not AVAILABLE");
        assert_eq!(
            handle_klu_error_code(ERROR_NEED_INITIALIZATION),
            "KLU failed because INITIALIZATION is needed"
        );
        assert_eq!(
            handle_klu_error_code(ERROR_NEED_FACTORIZATION),
            "KLU failed because FACTORIZATION is needed"
        );
        assert_eq!(
            handle_klu_error_code(ERROR_ALREADY_INITIALIZED),
            "KLU failed because INITIALIZATION has been completed already"
        );
        // any other code falls back to the generic message
        assert_eq!(handle_klu_error_code(123), default);
    }
}