rstsr_openblas/driver_impl/lapack/eigh/sygv.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11    T     func_   ;
12   [f32] [ssygv_];
13   [f64] [dsygv_];
14)]
15impl SYGVDriverAPI<T> for DeviceBLAS {
16    unsafe fn driver_sygv(
17        order: FlagOrder,
18        itype: blas_int,
19        jobz: char,
20        uplo: FlagUpLo,
21        n: usize,
22        a: *mut T,
23        lda: usize,
24        b: *mut T,
25        ldb: usize,
26        w: *mut T,
27    ) -> blas_int {
28        use lapack_ffi::lapack::func_;
29
30        // Query optimal working array(s) size
31        let mut info = 0;
32        let lwork = -1;
33        let mut work_query = 0.0;
34        func_(
35            &itype,
36            &(jobz as _),
37            &uplo.into(),
38            &(n as _),
39            a,
40            &(lda as _),
41            b,
42            &(ldb as _),
43            w,
44            &mut work_query,
45            &lwork,
46            &mut info,
47        );
48        if info != 0 {
49            return info;
50        }
51        let lwork = work_query as usize;
52
53        // Allocate memory for temporary array(s)
54        let mut work: Vec<T> = match uninitialized_vec(lwork) {
55            Ok(work) => work,
56            Err(_) => return -1010,
57        };
58
59        if order == ColMajor {
60            // Call LAPACK function and adjust info
61            func_(
62                &itype,
63                &(jobz as _),
64                &uplo.into(),
65                &(n as _),
66                a,
67                &(lda as _),
68                b,
69                &(ldb as _),
70                w,
71                work.as_mut_ptr(),
72                &(lwork as _),
73                &mut info,
74            );
75            if info != 0 {
76                return info;
77            }
78        } else {
79            let lda_t = n.max(1);
80            let ldb_t = n.max(1);
81            // Transpose input matrices
82            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
83                Ok(a_t) => a_t,
84                Err(_) => return -1011,
85            };
86            let mut b_t: Vec<T> = match uninitialized_vec(n * n) {
87                Ok(b_t) => b_t,
88                Err(_) => return -1011,
89            };
90            let a_slice = from_raw_parts_mut(a, n * lda);
91            let b_slice = from_raw_parts_mut(b, n * ldb);
92            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
93            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
94            let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0);
95            let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0);
96            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
97            orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
98            // Call LAPACK function and adjust info
99            func_(
100                &itype,
101                &(jobz as _),
102                &uplo.into(),
103                &(n as _),
104                a_t.as_mut_ptr(),
105                &(lda_t as _),
106                b_t.as_mut_ptr(),
107                &(ldb_t as _),
108                w,
109                work.as_mut_ptr(),
110                &(lwork as _),
111                &mut info,
112            );
113            if info != 0 {
114                return info;
115            }
116            // Transpose output matrices
117            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
118            orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
119        }
120        return info;
121    }
122}
123
124#[duplicate_item(
125    T              func_   ;
126   [Complex<f32>] [chegv_];
127   [Complex<f64>] [zhegv_];
128)]
129impl SYGVDriverAPI<T> for DeviceBLAS {
130    unsafe fn driver_sygv(
131        order: FlagOrder,
132        itype: blas_int,
133        jobz: char,
134        uplo: FlagUpLo,
135        n: usize,
136        a: *mut T,
137        lda: usize,
138        b: *mut T,
139        ldb: usize,
140        w: *mut <T as ComplexFloat>::Real,
141    ) -> blas_int {
142        use lapack_ffi::lapack::func_;
143
144        // Allocate memory for working array(s)
145        let rwork_len = (3 * n - 2).max(1);
146        let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(rwork_len) {
147            Ok(rwork) => rwork,
148            Err(_) => return -1010,
149        };
150
151        // Query optimal working array(s) size
152        let mut info = 0;
153        let lwork = -1;
154        let mut work_query = 0.0;
155        func_(
156            &itype,
157            &(jobz as _),
158            &uplo.into(),
159            &(n as _),
160            a as *mut _,
161            &(lda as _),
162            b as *mut _,
163            &(ldb as _),
164            w as *mut _,
165            &mut work_query as *mut _ as *mut _,
166            &lwork,
167            rwork.as_mut_ptr() as *mut _,
168            &mut info,
169        );
170        if info != 0 {
171            return info;
172        }
173        let lwork = work_query as usize;
174
175        // Allocate memory for work arrays
176        let mut work: Vec<T> = match uninitialized_vec(lwork) {
177            Ok(work) => work,
178            Err(_) => return -1010,
179        };
180
181        if order == ColMajor {
182            // Call LAPACK function and adjust info
183            func_(
184                &itype,
185                &(jobz as _),
186                &uplo.into(),
187                &(n as _),
188                a as *mut _,
189                &(lda as _),
190                b as *mut _,
191                &(ldb as _),
192                w as *mut _,
193                work.as_mut_ptr() as *mut _,
194                &(lwork as _),
195                rwork.as_mut_ptr() as *mut _,
196                &mut info,
197            );
198            if info != 0 {
199                return info;
200            }
201        } else {
202            let lda_t = n.max(1);
203            let ldb_t = n.max(1);
204            // Transpose input matrices
205            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
206                Ok(a_t) => a_t,
207                Err(_) => return -1011,
208            };
209            let mut b_t: Vec<T> = match uninitialized_vec(n * n) {
210                Ok(b_t) => b_t,
211                Err(_) => return -1011,
212            };
213            let a_slice = from_raw_parts_mut(a, n * lda);
214            let b_slice = from_raw_parts_mut(b, n * ldb);
215            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
216            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
217            let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0);
218            let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0);
219            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
220            orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
221            // Call LAPACK function and adjust info
222            func_(
223                &itype,
224                &(jobz as _),
225                &uplo.into(),
226                &(n as _),
227                a_t.as_mut_ptr() as *mut _,
228                &(lda_t as _),
229                b_t.as_mut_ptr() as *mut _,
230                &(ldb_t as _),
231                w as *mut _,
232                work.as_mut_ptr() as *mut _,
233                &(lwork as _),
234                rwork.as_mut_ptr() as *mut _,
235                &mut info,
236            );
237            if info != 0 {
238                return info;
239            }
240            // Transpose output matrices
241            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
242            orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
243        }
244        return info;
245    }
246}