rstsr_openblas/driver_impl/lapack/eigh/sygvd.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11    T     func_   ;
12   [f32] [ssygvd_];
13   [f64] [dsygvd_];
14)]
15impl SYGVDDriverAPI<T> for DeviceBLAS {
16    unsafe fn driver_sygvd(
17        order: FlagOrder,
18        itype: blas_int,
19        jobz: char,
20        uplo: FlagUpLo,
21        n: usize,
22        a: *mut T,
23        lda: usize,
24        b: *mut T,
25        ldb: usize,
26        w: *mut T,
27    ) -> blas_int {
28        use lapack_ffi::lapack::func_;
29
30        // Query optimal working array(s) size
31        let mut info = 0;
32        let lwork = -1;
33        let liwork = -1;
34        let mut work_query = 0.0;
35        let mut iwork_query = 0;
36        func_(
37            &itype,
38            &(jobz as _),
39            &uplo.into(),
40            &(n as _),
41            a,
42            &(lda as _),
43            b,
44            &(ldb as _),
45            w,
46            &mut work_query,
47            &lwork,
48            &mut iwork_query,
49            &liwork,
50            &mut info,
51        );
52        if info != 0 {
53            return info;
54        }
55        let lwork = work_query as usize;
56        let liwork = iwork_query as usize;
57
58        // Allocate memory for temporary array(s)
59        let mut work: Vec<T> = match uninitialized_vec(lwork) {
60            Ok(work) => work,
61            Err(_) => return -1010,
62        };
63        let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
64            Ok(iwork) => iwork,
65            Err(_) => return -1010,
66        };
67
68        if order == ColMajor {
69            // Call LAPACK function and adjust info
70            func_(
71                &itype,
72                &(jobz as _),
73                &uplo.into(),
74                &(n as _),
75                a,
76                &(lda as _),
77                b,
78                &(ldb as _),
79                w,
80                work.as_mut_ptr(),
81                &(lwork as _),
82                iwork.as_mut_ptr(),
83                &(liwork as _),
84                &mut info,
85            );
86            if info != 0 {
87                return info;
88            }
89        } else {
90            let lda_t = n.max(1);
91            let ldb_t = n.max(1);
92            // Transpose input matrices
93            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
94                Ok(a_t) => a_t,
95                Err(_) => return -1011,
96            };
97            let mut b_t: Vec<T> = match uninitialized_vec(n * n) {
98                Ok(b_t) => b_t,
99                Err(_) => return -1011,
100            };
101            let a_slice = from_raw_parts_mut(a, n * lda);
102            let b_slice = from_raw_parts_mut(b, n * ldb);
103            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
104            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
105            let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0);
106            let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0);
107            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
108            orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
109            // Call LAPACK function and adjust info
110            func_(
111                &itype,
112                &(jobz as _),
113                &uplo.into(),
114                &(n as _),
115                a_t.as_mut_ptr(),
116                &(lda_t as _),
117                b_t.as_mut_ptr(),
118                &(ldb_t as _),
119                w,
120                work.as_mut_ptr(),
121                &(lwork as _),
122                iwork.as_mut_ptr(),
123                &(liwork as _),
124                &mut info,
125            );
126            if info != 0 {
127                return info;
128            }
129            // Transpose output matrices
130            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
131            orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
132        }
133        return info;
134    }
135}
136
137#[duplicate_item(
138    T              func_   ;
139   [Complex<f32>] [chegvd_];
140   [Complex<f64>] [zhegvd_];
141)]
142impl SYGVDDriverAPI<T> for DeviceBLAS {
143    unsafe fn driver_sygvd(
144        order: FlagOrder,
145        itype: blas_int,
146        jobz: char,
147        uplo: FlagUpLo,
148        n: usize,
149        a: *mut T,
150        lda: usize,
151        b: *mut T,
152        ldb: usize,
153        w: *mut <T as ComplexFloat>::Real,
154    ) -> blas_int {
155        use lapack_ffi::lapack::func_;
156
157        // Query optimal working array(s) size
158        let mut info = 0;
159        let lwork = -1;
160        let lrwork = -1;
161        let liwork = -1;
162        let mut work_query = 0.0;
163        let mut rwork_query = 0.0;
164        let mut iwork_query = 0;
165        func_(
166            &itype,
167            &(jobz as _),
168            &uplo.into(),
169            &(n as _),
170            a as *mut _,
171            &(lda as _),
172            b as *mut _,
173            &(ldb as _),
174            w as *mut _,
175            &mut work_query as *mut _ as *mut _,
176            &lwork,
177            &mut rwork_query as *mut _ as *mut _,
178            &lrwork,
179            &mut iwork_query,
180            &liwork,
181            &mut info,
182        );
183        if info != 0 {
184            return info;
185        }
186        let lwork = work_query as usize;
187        let lrwork = rwork_query as usize;
188        let liwork = iwork_query as usize;
189
190        // Allocate memory for temporary array(s)
191        let mut work: Vec<T> = match uninitialized_vec(lwork) {
192            Ok(work) => work,
193            Err(_) => return -1010,
194        };
195        let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(lrwork) {
196            Ok(rwork) => rwork,
197            Err(_) => return -1010,
198        };
199        let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
200            Ok(iwork) => iwork,
201            Err(_) => return -1010,
202        };
203
204        if order == ColMajor {
205            // Call LAPACK function and adjust info
206            func_(
207                &itype,
208                &(jobz as _),
209                &uplo.into(),
210                &(n as _),
211                a as *mut _,
212                &(lda as _),
213                b as *mut _,
214                &(ldb as _),
215                w as *mut _,
216                work.as_mut_ptr() as *mut _,
217                &(lwork as _),
218                rwork.as_mut_ptr() as *mut _,
219                &(lrwork as _),
220                iwork.as_mut_ptr() as *mut _,
221                &(liwork as _),
222                &mut info,
223            );
224            if info != 0 {
225                return info;
226            }
227        } else {
228            let lda_t = n.max(1);
229            let ldb_t = n.max(1);
230            // Transpose input matrices
231            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
232                Ok(a_t) => a_t,
233                Err(_) => return -1011,
234            };
235            let mut b_t: Vec<T> = match uninitialized_vec(n * n) {
236                Ok(b_t) => b_t,
237                Err(_) => return -1011,
238            };
239            let a_slice = from_raw_parts_mut(a, n * lda);
240            let b_slice = from_raw_parts_mut(b, n * ldb);
241            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
242            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
243            let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0);
244            let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0);
245            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
246            orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
247            // Call LAPACK function and adjust info
248            func_(
249                &itype,
250                &(jobz as _),
251                &uplo.into(),
252                &(n as _),
253                a_t.as_mut_ptr() as *mut _,
254                &(lda_t as _),
255                b_t.as_mut_ptr() as *mut _,
256                &(ldb_t as _),
257                w as *mut _,
258                work.as_mut_ptr() as *mut _,
259                &(lwork as _),
260                rwork.as_mut_ptr() as *mut _,
261                &(lrwork as _),
262                iwork.as_mut_ptr(),
263                &(liwork as _),
264                &mut info,
265            );
266            if info != 0 {
267                return info;
268            }
269            // Transpose output matrices
270            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
271            orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
272        }
273        return info;
274    }
275}