rlst/dense/linalg/lapack/interface/
potrf.rs

//! Implementation of ?potrf - Cholesky factorization of a symmetric (real) or Hermitian (complex) positive-definite matrix.
2
3use lapack::{cpotrf, dpotrf, spotrf, zpotrf};
4
5use crate::base_types::LapackResult;
6use crate::base_types::{c32, c64};
7
8use crate::dense::linalg::lapack::interface::lapack_return;
9
/// Triangle selector (`uplo`) for `?potrf`: tells LAPACK which triangular half of the
/// matrix holds the data. Each variant's discriminant is the ASCII character code that is
/// forwarded to LAPACK as the `uplo` byte.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PotrfUplo {
    /// The matrix data lives in the upper triangle (`'U'`).
    Upper = b'U',
    /// The matrix data lives in the lower triangle (`'L'`).
    Lower = b'L',
}
19
20/// ?potrf - Cholesky factorization of a symmetric positive-definite matrix.
21pub trait Potrf: Sized {
22    /// Perform Cholesky factorization of a symmetric positive-definite matrix.
23    ///
24    /// If `uplo` is `Upper`, the factorization is of the form:
25    /// A = U^H * U, where U is an upper triangular matrix. If `uplo` is `Lower`, the
26    /// factorization is of the form:
27    /// A = L * L^H, where L is a lower triangular matrix.
28    ///
29    /// **Arguments:**
30    /// - `uplo`: Specifies whether the upper or lower triangular part of the matrix is stored.
31    /// - `n`: The order of the matrix A.
32    /// - `a`: The matrix A to be factored.
33    /// - `lda`: The leading dimension of the matrix A.
34    ///
35    /// **Returns:**
36    /// A `LapackResult<()>` indicating success or failure.
37    fn potrf(uplo: PotrfUplo, n: usize, a: &mut [Self], lda: usize) -> LapackResult<()>;
38}
39
/// Implements [`Potrf`] for a scalar type by forwarding to the matching LAPACK routine.
///
/// `$scalar` is the element type; `$potrf` is the LAPACK `?potrf` function for that type,
/// called as `$potrf(uplo: u8, n: i32, a: &mut [$scalar], lda: i32, info: &mut i32)`.
///
/// The generated method validates the slice length and dimensions before the `unsafe`
/// FFI call:
/// - `lda * n` must not overflow `usize`;
/// - `a.len()` must equal `lda * n`;
/// - `lda` must be at least `max(1, n)`;
/// - `lda` must fit in `i32`.
/// Any violation panics. The LAPACK `info` code is converted with `lapack_return`.
macro_rules! implement_potrf {
    ($scalar:ty, $potrf:expr) => {
        impl Potrf for $scalar {
            fn potrf(uplo: PotrfUplo, n: usize, a: &mut [Self], lda: usize) -> LapackResult<()> {
                // Checked so an overflowing product cannot wrap (release builds) into a small
                // value that would let an undersized slice through to LAPACK.
                let expected_len = lda.checked_mul(n).unwrap_or_else(|| {
                    panic!(
                        "Require `lda * n` to fit in `usize` (lda = {}, n = {}).",
                        lda, n
                    )
                });

                assert_eq!(
                    a.len(),
                    expected_len,
                    "Require `a.len()` {} == `lda * n` {}.",
                    a.len(),
                    expected_len
                );

                assert!(
                    lda >= std::cmp::max(1, n),
                    "Require `lda` {} >= `max(1, n)` {}.",
                    lda,
                    std::cmp::max(1, n)
                );

                // LAPACK takes 32-bit integer dimensions; a plain `as` cast would silently
                // truncate. Since `n <= lda` (checked above), bounding `lda` bounds `n` too.
                assert!(
                    lda <= i32::MAX as usize,
                    "Require `lda` {} <= `i32::MAX` {}.",
                    lda,
                    i32::MAX
                );

                let mut info = 0;

                // SAFETY: `a` holds exactly `lda * n` elements and `lda >= max(1, n)`, so the
                // `n`-by-`n` column-major matrix with leading dimension `lda` that LAPACK
                // accesses lies entirely inside `a`. Both `n` and `lda` fit in `i32` (checked
                // above), so the casts below are lossless.
                unsafe {
                    $potrf(uplo as u8, n as i32, a, lda as i32, &mut info);
                }

                lapack_return(info, ())
            }
        }
    };
}
70
71implement_potrf!(f32, spotrf);
72implement_potrf!(f64, dpotrf);
73implement_potrf!(c32, cpotrf);
74implement_potrf!(c64, zpotrf);