rstsr_openblas/driver_impl/lapack/solve/
potrf.rs1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6use rstsr_native_impl::prelude_dev::*;
7use std::slice::from_raw_parts_mut;
8
9#[duplicate_item(
10 T func_ ;
11 [f32] [spotrf_];
12 [f64] [dpotrf_];
13)]
14impl POTRFDriverAPI<T> for DeviceBLAS {
15 unsafe fn driver_potrf(order: FlagOrder, uplo: FlagUpLo, n: usize, a: *mut T, lda: usize) -> blas_int {
16 use lapack_ffi::lapack::func_;
17
18 let mut info = 0;
19
20 if order == ColMajor {
21 func_(&uplo.into(), &(n as _), a, &(lda as _), &mut info);
23 if info != 0 {
24 return info;
25 }
26 } else {
27 let lda_t = n.max(1);
28 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
30 Ok(a_t) => a_t,
31 Err(_) => return -1011,
32 };
33 let a_slice = from_raw_parts_mut(a, n * lda);
34 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
35 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
36 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
37 func_(&uplo.into(), &(n as _), a_t.as_mut_ptr(), &(lda_t as _), &mut info);
39 if info != 0 {
40 return info;
41 }
42 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
44 }
45 return info;
46 }
47}
48
49#[duplicate_item(
50 T func_ ;
51 [Complex<f32>] [cpotrf_];
52 [Complex<f64>] [zpotrf_];
53)]
54impl POTRFDriverAPI<T> for DeviceBLAS {
55 unsafe fn driver_potrf(order: FlagOrder, uplo: FlagUpLo, n: usize, a: *mut T, lda: usize) -> blas_int {
56 use lapack_ffi::lapack::func_;
57
58 let mut info = 0;
59
60 if order == ColMajor {
61 func_(&uplo.into(), &(n as _), a as *mut _, &(lda as _), &mut info);
63 if info != 0 {
64 return info;
65 }
66 } else {
67 let lda_t = n.max(1);
68 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
70 Ok(a_t) => a_t,
71 Err(_) => return -1011,
72 };
73 let a_slice = from_raw_parts_mut(a, n * lda);
74 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
75 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
76 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
77 func_(&uplo.into(), &(n as _), a_t.as_mut_ptr() as *mut _, &(lda_t as _), &mut info);
79 if info != 0 {
80 return info;
81 }
82 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
84 }
85 return info;
86 }
87}