rstsr_openblas/driver_impl/lapack/solve/
potrf.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6use rstsr_native_impl::prelude_dev::*;
7use std::slice::from_raw_parts_mut;
8
9#[duplicate_item(
10    T     func_   ;
11   [f32] [spotrf_];
12   [f64] [dpotrf_];
13)]
14impl POTRFDriverAPI<T> for DeviceBLAS {
15    unsafe fn driver_potrf(order: FlagOrder, uplo: FlagUpLo, n: usize, a: *mut T, lda: usize) -> blas_int {
16        use lapack_ffi::lapack::func_;
17
18        let mut info = 0;
19
20        if order == ColMajor {
21            // Call LAPACK function and adjust info
22            func_(&uplo.into(), &(n as _), a, &(lda as _), &mut info);
23            if info != 0 {
24                return info;
25            }
26        } else {
27            let lda_t = n.max(1);
28            // Transpose input matrices
29            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
30                Ok(a_t) => a_t,
31                Err(_) => return -1011,
32            };
33            let a_slice = from_raw_parts_mut(a, n * lda);
34            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
35            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
36            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
37            // Call LAPACK function and adjust info
38            func_(&uplo.into(), &(n as _), a_t.as_mut_ptr(), &(lda_t as _), &mut info);
39            if info != 0 {
40                return info;
41            }
42            // Transpose output matrices
43            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
44        }
45        return info;
46    }
47}
48
49#[duplicate_item(
50    T              func_   ;
51   [Complex<f32>] [cpotrf_];
52   [Complex<f64>] [zpotrf_];
53)]
54impl POTRFDriverAPI<T> for DeviceBLAS {
55    unsafe fn driver_potrf(order: FlagOrder, uplo: FlagUpLo, n: usize, a: *mut T, lda: usize) -> blas_int {
56        use lapack_ffi::lapack::func_;
57
58        let mut info = 0;
59
60        if order == ColMajor {
61            // Call LAPACK function and adjust info
62            func_(&uplo.into(), &(n as _), a as *mut _, &(lda as _), &mut info);
63            if info != 0 {
64                return info;
65            }
66        } else {
67            let lda_t = n.max(1);
68            // Transpose input matrices
69            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
70                Ok(a_t) => a_t,
71                Err(_) => return -1011,
72            };
73            let a_slice = from_raw_parts_mut(a, n * lda);
74            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
75            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
76            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
77            // Call LAPACK function and adjust info
78            func_(&uplo.into(), &(n as _), a_t.as_mut_ptr() as *mut _, &(lda_t as _), &mut info);
79            if info != 0 {
80                return info;
81            }
82            // Transpose output matrices
83            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
84        }
85        return info;
86    }
87}