// rstsr_openblas/driver_impl/lapack/eigh/syevd.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11    T     func_   ;
12   [f32] [ssyevd_];
13   [f64] [dsyevd_];
14)]
15impl SYEVDDriverAPI<T> for DeviceBLAS {
16    unsafe fn driver_syevd(
17        order: FlagOrder,
18        jobz: char,
19        uplo: FlagUpLo,
20        n: usize,
21        a: *mut T,
22        lda: usize,
23        w: *mut T,
24    ) -> blas_int {
25        use lapack_ffi::lapack::func_;
26
27        // Query optimal working array(s) size
28        let mut info = 0;
29        let lwork = -1;
30        let liwork = -1;
31        let mut work_query = 0.0;
32        let mut iwork_query = 0;
33        func_(
34            &(jobz as _),
35            &uplo.into(),
36            &(n as _),
37            a,
38            &(lda as _),
39            w,
40            &mut work_query,
41            &lwork,
42            &mut iwork_query,
43            &liwork,
44            &mut info,
45        );
46        if info != 0 {
47            return info;
48        }
49        let lwork = work_query as usize;
50        let liwork = iwork_query as usize;
51
52        // Allocate memory for temporary array(s)
53        let mut work: Vec<T> = match uninitialized_vec(lwork) {
54            Ok(work) => work,
55            Err(_) => return -1010,
56        };
57        let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
58            Ok(iwork) => iwork,
59            Err(_) => return -1010,
60        };
61
62        if order == ColMajor {
63            // Call LAPACK function and adjust info
64            func_(
65                &(jobz as _),
66                &uplo.into(),
67                &(n as _),
68                a,
69                &(lda as _),
70                w,
71                work.as_mut_ptr(),
72                &(lwork as _),
73                iwork.as_mut_ptr(),
74                &(liwork as _),
75                &mut info,
76            );
77            if info != 0 {
78                return info;
79            }
80        } else {
81            let lda_t = n.max(1);
82            // Transpose input matrices
83            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
84                Ok(a_t) => a_t,
85                Err(_) => return -1011,
86            };
87            let a_slice = from_raw_parts_mut(a, n * lda);
88            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
89            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
90            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
91            // Call LAPACK function and adjust info
92            func_(
93                &(jobz as _),
94                &uplo.into(),
95                &(n as _),
96                a_t.as_mut_ptr(),
97                &(lda_t as _),
98                w,
99                work.as_mut_ptr(),
100                &(lwork as _),
101                iwork.as_mut_ptr(),
102                &(liwork as _),
103                &mut info,
104            );
105            if info != 0 {
106                return info;
107            }
108            // Transpose output matrices
109            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
110        }
111        return info;
112    }
113}
114
115#[duplicate_item(
116    T              func_   ;
117   [Complex<f32>] [cheevd_];
118   [Complex<f64>] [zheevd_];
119)]
120impl SYEVDDriverAPI<T> for DeviceBLAS {
121    unsafe fn driver_syevd(
122        order: FlagOrder,
123        jobz: char,
124        uplo: FlagUpLo,
125        n: usize,
126        a: *mut T,
127        lda: usize,
128        w: *mut <T as ComplexFloat>::Real,
129    ) -> blas_int {
130        use lapack_ffi::lapack::func_;
131
132        // Query optimal working array(s) size
133        let mut info = 0;
134        let lwork = -1;
135        let lrwork = -1;
136        let liwork = -1;
137        let mut work_query = 0.0;
138        let mut rwork_query = 0.0;
139        let mut iwork_query = 0;
140        func_(
141            &(jobz as _),
142            &uplo.into(),
143            &(n as _),
144            a as *mut _,
145            &(lda as _),
146            w as *mut _,
147            &mut work_query as *mut _ as *mut _,
148            &lwork,
149            &mut rwork_query as *mut _ as *mut _,
150            &lrwork,
151            &mut iwork_query,
152            &liwork,
153            &mut info,
154        );
155        if info != 0 {
156            return info;
157        }
158        let lwork = work_query as usize;
159        let lrwork = rwork_query as usize;
160        let liwork = iwork_query as usize;
161
162        // Allocate memory for temporary array(s)
163        let mut work: Vec<T> = match uninitialized_vec(lwork) {
164            Ok(work) => work,
165            Err(_) => return -1010,
166        };
167        let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(lrwork) {
168            Ok(rwork) => rwork,
169            Err(_) => return -1010,
170        };
171        let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
172            Ok(iwork) => iwork,
173            Err(_) => return -1010,
174        };
175
176        if order == ColMajor {
177            // Call LAPACK function and adjust info
178            func_(
179                &(jobz as _),
180                &uplo.into(),
181                &(n as _),
182                a as *mut _,
183                &(lda as _),
184                w as *mut _,
185                work.as_mut_ptr() as *mut _,
186                &(lwork as _),
187                rwork.as_mut_ptr() as *mut _,
188                &(lrwork as _),
189                iwork.as_mut_ptr() as *mut _,
190                &(liwork as _),
191                &mut info,
192            );
193            if info != 0 {
194                return info;
195            }
196        } else {
197            let lda_t = n.max(1);
198            // Transpose input matrices
199            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
200                Ok(a_t) => a_t,
201                Err(_) => return -1011,
202            };
203            let a_slice = from_raw_parts_mut(a, n * lda);
204            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
205            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
206            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
207            // Call LAPACK function and adjust info
208            func_(
209                &(jobz as _),
210                &uplo.into(),
211                &(n as _),
212                a_t.as_mut_ptr() as *mut _,
213                &(lda_t as _),
214                w as *mut _,
215                work.as_mut_ptr() as *mut _,
216                &(lwork as _),
217                rwork.as_mut_ptr() as *mut _,
218                &(lrwork as _),
219                iwork.as_mut_ptr(),
220                &(liwork as _),
221                &mut info,
222            );
223            if info != 0 {
224                return info;
225            }
226            // Transpose output matrices
227            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
228        }
229        return info;
230    }
231}