rstsr_openblas/driver_impl/lapack/solve/
sysv.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11    T     func_   ;
12   [f32] [ssysv_];
13   [f64] [dsysv_];
14)]
15impl<const HERMI: bool> SYSVDriverAPI<T, HERMI> for DeviceBLAS {
16    unsafe fn driver_sysv(
17        order: FlagOrder,
18        uplo: FlagUpLo,
19        n: usize,
20        nrhs: usize,
21        a: *mut T,
22        lda: usize,
23        ipiv: *mut blas_int,
24        b: *mut T,
25        ldb: usize,
26    ) -> blas_int {
27        use lapack_ffi::lapack::func_;
28
29        // Query optimal working array(s) size
30        let mut info = 0;
31        let lwork = -1;
32        let mut work_query = 0.0;
33        func_(
34            &uplo.into(),
35            &(n as _),
36            &(nrhs as _),
37            a,
38            &(n as _),
39            ipiv,
40            b,
41            &(n as _),
42            &mut work_query,
43            &lwork,
44            &mut info,
45        );
46        if info != 0 {
47            return info;
48        }
49        let lwork = work_query as usize;
50
51        // Allocate memory for work arrays
52        let mut work: Vec<T> = match uninitialized_vec(lwork) {
53            Ok(work) => work,
54            Err(_) => return -1010,
55        };
56
57        if order == ColMajor {
58            // Call LAPACK function and adjust info
59            func_(
60                &uplo.into(),
61                &(n as _),
62                &(nrhs as _),
63                a,
64                &(lda as _),
65                ipiv,
66                b,
67                &(ldb as _),
68                work.as_mut_ptr(),
69                &(lwork as _),
70                &mut info,
71            );
72            if info != 0 {
73                return info;
74            }
75        } else {
76            let lda_t = n.max(1);
77            let ldb_t = n.max(1);
78            // Transpose input matrices
79            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
80                Ok(a_t) => a_t,
81                Err(_) => return -1011,
82            };
83            let mut b_t: Vec<T> = match uninitialized_vec(n * nrhs) {
84                Ok(b_t) => b_t,
85                Err(_) => return -1011,
86            };
87            let a_slice = from_raw_parts_mut(a, n * lda);
88            let b_slice = from_raw_parts_mut(b, n * ldb);
89            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
90            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
91            let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0);
92            let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0);
93            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
94            orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
95            // Call LAPACK function and adjust info
96            func_(
97                &uplo.into(),
98                &(n as _),
99                &(nrhs as _),
100                a_t.as_mut_ptr(),
101                &(lda_t as _),
102                ipiv,
103                b_t.as_mut_ptr(),
104                &(ldb_t as _),
105                work.as_mut_ptr(),
106                &(lwork as _),
107                &mut info,
108            );
109            if info != 0 {
110                return info;
111            }
112            // Transpose output matrices
113            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
114            orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
115        }
116        return info;
117    }
118}
119
120#[duplicate_item(
121    T              func_   HERMI  ;
122   [Complex<f32>] [csysv_] [false];
123   [Complex<f32>] [chesv_] [true ];
124   [Complex<f64>] [zsysv_] [false];
125   [Complex<f64>] [zhesv_] [true ];
126)]
127impl SYSVDriverAPI<T, HERMI> for DeviceBLAS {
128    unsafe fn driver_sysv(
129        order: FlagOrder,
130        uplo: FlagUpLo,
131        n: usize,
132        nrhs: usize,
133        a: *mut T,
134        lda: usize,
135        ipiv: *mut blas_int,
136        b: *mut T,
137        ldb: usize,
138    ) -> blas_int {
139        use lapack_ffi::lapack::func_;
140
141        // Query optimal working array(s) size
142        let mut info = 0;
143        let lwork = -1;
144        let mut work_query = 0.0;
145        func_(
146            &uplo.into(),
147            &(n as _),
148            &(nrhs as _),
149            a as *mut _,
150            &(n as _),
151            ipiv,
152            b as *mut _,
153            &(n as _),
154            &mut work_query as *mut _ as *mut _,
155            &lwork,
156            &mut info,
157        );
158        if info != 0 {
159            return info;
160        }
161        let lwork = work_query as usize;
162
163        // Allocate memory for work arrays
164        let mut work: Vec<T> = match uninitialized_vec(lwork) {
165            Ok(work) => work,
166            Err(_) => return -1010,
167        };
168
169        if order == ColMajor {
170            // Call LAPACK function and adjust info
171            func_(
172                &uplo.into(),
173                &(n as _),
174                &(nrhs as _),
175                a as *mut _,
176                &(lda as _),
177                ipiv,
178                b as *mut _,
179                &(ldb as _),
180                work.as_mut_ptr() as *mut _,
181                &(lwork as _),
182                &mut info,
183            );
184            if info != 0 {
185                return info;
186            }
187        } else {
188            let lda_t = n.max(1);
189            let ldb_t = n.max(1);
190            // Transpose input matrices
191            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
192                Ok(a_t) => a_t,
193                Err(_) => return -1011,
194            };
195            let mut b_t: Vec<T> = match uninitialized_vec(n * nrhs) {
196                Ok(b_t) => b_t,
197                Err(_) => return -1011,
198            };
199            let a_slice = from_raw_parts_mut(a, n * lda);
200            let b_slice = from_raw_parts_mut(b, n * ldb);
201            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
202            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
203            let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0);
204            let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0);
205            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
206            orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
207            // Call LAPACK function and adjust info
208            func_(
209                &uplo.into(),
210                &(n as _),
211                &(nrhs as _),
212                a_t.as_mut_ptr() as *mut _,
213                &(lda_t as _),
214                ipiv,
215                b_t.as_mut_ptr() as *mut _,
216                &(ldb_t as _),
217                work.as_mut_ptr() as *mut _,
218                &(lwork as _),
219                &mut info,
220            );
221            if info != 0 {
222                return info;
223            }
224            // Transpose output matrices
225            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
226            orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
227        }
228        return info;
229    }
230}