rstsr_openblas/driver_impl/lapack/eigh/
syev.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11    T     func_   ;
12   [f32] [ssyev_];
13   [f64] [dsyev_];
14)]
15impl SYEVDriverAPI<T> for DeviceBLAS {
16    unsafe fn driver_syev(
17        order: FlagOrder,
18        jobz: char,
19        uplo: FlagUpLo,
20        n: usize,
21        a: *mut T,
22        lda: usize,
23        w: *mut T,
24    ) -> blas_int {
25        use lapack_ffi::lapack::func_;
26
27        // Query optimal working array(s) size
28        let mut info = 0;
29        let lwork = -1;
30        let mut work_query = 0.0;
31        func_(&(jobz as _), &uplo.into(), &(n as _), a, &(lda as _), w, &mut work_query, &lwork, &mut info);
32        if info != 0 {
33            return info;
34        }
35        let lwork = work_query as usize;
36
37        // Allocate memory for work arrays
38        let mut work: Vec<T> = match uninitialized_vec(lwork) {
39            Ok(work) => work,
40            Err(_) => return -1010,
41        };
42
43        if order == ColMajor {
44            // Call LAPACK function and adjust info
45            func_(
46                &(jobz as _),
47                &uplo.into(),
48                &(n as _),
49                a,
50                &(lda as _),
51                w,
52                work.as_mut_ptr(),
53                &(lwork as _),
54                &mut info,
55            );
56            if info != 0 {
57                return info;
58            }
59        } else {
60            let lda_t = n.max(1);
61            // Transpose input matrices
62            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
63                Ok(a_t) => a_t,
64                Err(_) => return -1011,
65            };
66            let a_slice = from_raw_parts_mut(a, n * lda);
67            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
68            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
69            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
70            // Call LAPACK function and adjust info
71            func_(
72                &(jobz as _),
73                &uplo.into(),
74                &(n as _),
75                a_t.as_mut_ptr(),
76                &(lda_t as _),
77                w,
78                work.as_mut_ptr(),
79                &(lwork as _),
80                &mut info,
81            );
82            if info != 0 {
83                return info;
84            }
85            // Transpose output matrices
86            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
87        }
88        return info;
89    }
90}
91
92#[duplicate_item(
93    T              func_   ;
94   [Complex<f32>] [cheev_];
95   [Complex<f64>] [zheev_];
96)]
97impl SYEVDriverAPI<T> for DeviceBLAS {
98    unsafe fn driver_syev(
99        order: FlagOrder,
100        jobz: char,
101        uplo: FlagUpLo,
102        n: usize,
103        a: *mut T,
104        lda: usize,
105        w: *mut <T as ComplexFloat>::Real,
106    ) -> blas_int {
107        use lapack_ffi::lapack::func_;
108
109        // Allocate memory for working array(s)
110        let rwork_len = (3 * n - 2).max(1);
111        let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(rwork_len) {
112            Ok(rwork) => rwork,
113            Err(_) => return -1010,
114        };
115
116        // Query optimal working array(s) size
117        let mut info = 0;
118        let lwork = -1;
119        let mut work_query = 0.0;
120        func_(
121            &(jobz as _),
122            &uplo.into(),
123            &(n as _),
124            a as *mut _,
125            &(lda as _),
126            w as *mut _,
127            &mut work_query as *mut _ as *mut _,
128            &lwork,
129            rwork.as_mut_ptr() as *mut _,
130            &mut info,
131        );
132        if info != 0 {
133            return info;
134        }
135        let lwork = work_query as usize;
136
137        // Allocate memory for work arrays
138        let mut work: Vec<T> = match uninitialized_vec(lwork) {
139            Ok(work) => work,
140            Err(_) => return -1010,
141        };
142
143        if order == ColMajor {
144            // Call LAPACK function and adjust info
145            func_(
146                &(jobz as _),
147                &uplo.into(),
148                &(n as _),
149                a as *mut _,
150                &(lda as _),
151                w as *mut _,
152                work.as_mut_ptr() as *mut _,
153                &(lwork as _),
154                rwork.as_mut_ptr() as *mut _,
155                &mut info,
156            );
157            if info != 0 {
158                return info;
159            }
160        } else {
161            let lda_t = n.max(1);
162            // Transpose input matrices
163            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
164                Ok(a_t) => a_t,
165                Err(_) => return -1011,
166            };
167            let a_slice = from_raw_parts_mut(a, n * lda);
168            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
169            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
170            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
171            // Call LAPACK function and adjust info
172            func_(
173                &(jobz as _),
174                &uplo.into(),
175                &(n as _),
176                a_t.as_mut_ptr() as *mut _,
177                &(lda_t as _),
178                w as *mut _,
179                work.as_mut_ptr() as *mut _,
180                &(lwork as _),
181                rwork.as_mut_ptr() as *mut _,
182                &mut info,
183            );
184            if info != 0 {
185                return info;
186            }
187            // Transpose output matrices
188            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
189        }
190        return info;
191    }
192}