rlst/dense/linalg/lapack/interface/potrf.rs
//! Implementation of `?potrf` - Cholesky factorization of a symmetric (real) or Hermitian (complex) positive-definite matrix.
2
3use lapack::{cpotrf, dpotrf, spotrf, zpotrf};
4
5use crate::base_types::LapackResult;
6use crate::base_types::{c32, c64};
7
8use crate::dense::linalg::lapack::interface::lapack_return;
9
/// Selects which triangle of the matrix `?potrf` operates on.
///
/// Each discriminant is the ASCII byte of the matching LAPACK `UPLO` character, so a value
/// can be handed to the LAPACK routine with a plain `as u8` cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PotrfUplo {
    /// The matrix data lives in the upper triangle (`b'U'`).
    Upper = b'U',
    /// The matrix data lives in the lower triangle (`b'L'`).
    Lower = b'L',
}
19
20/// ?potrf - Cholesky factorization of a symmetric positive-definite matrix.
21pub trait Potrf: Sized {
22 /// Perform Cholesky factorization of a symmetric positive-definite matrix.
23 ///
24 /// If `uplo` is `Upper`, the factorization is of the form:
25 /// A = U^H * U, where U is an upper triangular matrix. If `uplo` is `Lower`, the
26 /// factorization is of the form:
27 /// A = L * L^H, where L is a lower triangular matrix.
28 ///
29 /// **Arguments:**
30 /// - `uplo`: Specifies whether the upper or lower triangular part of the matrix is stored.
31 /// - `n`: The order of the matrix A.
32 /// - `a`: The matrix A to be factored.
33 /// - `lda`: The leading dimension of the matrix A.
34 ///
35 /// **Returns:**
36 /// A `LapackResult<()>` indicating success or failure.
37 fn potrf(uplo: PotrfUplo, n: usize, a: &mut [Self], lda: usize) -> LapackResult<()>;
38}
39
40macro_rules! implement_potrf {
41 ($scalar:ty, $potrf:expr) => {
42 impl Potrf for $scalar {
43 fn potrf(uplo: PotrfUplo, n: usize, a: &mut [Self], lda: usize) -> LapackResult<()> {
44 assert_eq!(
45 a.len(),
46 lda * n,
47 "Require `a.len()` {} == `lda * n` {}.",
48 a.len(),
49 lda * n
50 );
51
52 assert!(
53 lda >= std::cmp::max(1, n),
54 "Require `lda` {} >= `max(1, n)` {}.",
55 lda,
56 std::cmp::max(1, n)
57 );
58
59 let mut info = 0;
60
61 unsafe {
62 $potrf(uplo as u8, n as i32, a, lda as i32, &mut info);
63 }
64
65 lapack_return(info, ())
66 }
67 }
68 };
69}
70
71implement_potrf!(f32, spotrf);
72implement_potrf!(f64, dpotrf);
73implement_potrf!(c32, cpotrf);
74implement_potrf!(c64, zpotrf);