rstsr_openblas/driver_impl/lapack/eigh/
syevd.rs1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11 T func_ ;
12 [f32] [ssyevd_];
13 [f64] [dsyevd_];
14)]
15impl SYEVDDriverAPI<T> for DeviceBLAS {
16 unsafe fn driver_syevd(
17 order: FlagOrder,
18 jobz: char,
19 uplo: FlagUpLo,
20 n: usize,
21 a: *mut T,
22 lda: usize,
23 w: *mut T,
24 ) -> blas_int {
25 use lapack_ffi::lapack::func_;
26
27 let mut info = 0;
29 let lwork = -1;
30 let liwork = -1;
31 let mut work_query = 0.0;
32 let mut iwork_query = 0;
33 func_(
34 &(jobz as _),
35 &uplo.into(),
36 &(n as _),
37 a,
38 &(lda as _),
39 w,
40 &mut work_query,
41 &lwork,
42 &mut iwork_query,
43 &liwork,
44 &mut info,
45 );
46 if info != 0 {
47 return info;
48 }
49 let lwork = work_query as usize;
50 let liwork = iwork_query as usize;
51
52 let mut work: Vec<T> = match uninitialized_vec(lwork) {
54 Ok(work) => work,
55 Err(_) => return -1010,
56 };
57 let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
58 Ok(iwork) => iwork,
59 Err(_) => return -1010,
60 };
61
62 if order == ColMajor {
63 func_(
65 &(jobz as _),
66 &uplo.into(),
67 &(n as _),
68 a,
69 &(lda as _),
70 w,
71 work.as_mut_ptr(),
72 &(lwork as _),
73 iwork.as_mut_ptr(),
74 &(liwork as _),
75 &mut info,
76 );
77 if info != 0 {
78 return info;
79 }
80 } else {
81 let lda_t = n.max(1);
82 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
84 Ok(a_t) => a_t,
85 Err(_) => return -1011,
86 };
87 let a_slice = from_raw_parts_mut(a, n * lda);
88 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
89 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
90 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
91 func_(
93 &(jobz as _),
94 &uplo.into(),
95 &(n as _),
96 a_t.as_mut_ptr(),
97 &(lda_t as _),
98 w,
99 work.as_mut_ptr(),
100 &(lwork as _),
101 iwork.as_mut_ptr(),
102 &(liwork as _),
103 &mut info,
104 );
105 if info != 0 {
106 return info;
107 }
108 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
110 }
111 return info;
112 }
113}
114
115#[duplicate_item(
116 T func_ ;
117 [Complex<f32>] [cheevd_];
118 [Complex<f64>] [zheevd_];
119)]
120impl SYEVDDriverAPI<T> for DeviceBLAS {
121 unsafe fn driver_syevd(
122 order: FlagOrder,
123 jobz: char,
124 uplo: FlagUpLo,
125 n: usize,
126 a: *mut T,
127 lda: usize,
128 w: *mut <T as ComplexFloat>::Real,
129 ) -> blas_int {
130 use lapack_ffi::lapack::func_;
131
132 let mut info = 0;
134 let lwork = -1;
135 let lrwork = -1;
136 let liwork = -1;
137 let mut work_query = 0.0;
138 let mut rwork_query = 0.0;
139 let mut iwork_query = 0;
140 func_(
141 &(jobz as _),
142 &uplo.into(),
143 &(n as _),
144 a as *mut _,
145 &(lda as _),
146 w as *mut _,
147 &mut work_query as *mut _ as *mut _,
148 &lwork,
149 &mut rwork_query as *mut _ as *mut _,
150 &lrwork,
151 &mut iwork_query,
152 &liwork,
153 &mut info,
154 );
155 if info != 0 {
156 return info;
157 }
158 let lwork = work_query as usize;
159 let lrwork = rwork_query as usize;
160 let liwork = iwork_query as usize;
161
162 let mut work: Vec<T> = match uninitialized_vec(lwork) {
164 Ok(work) => work,
165 Err(_) => return -1010,
166 };
167 let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(lrwork) {
168 Ok(rwork) => rwork,
169 Err(_) => return -1010,
170 };
171 let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
172 Ok(iwork) => iwork,
173 Err(_) => return -1010,
174 };
175
176 if order == ColMajor {
177 func_(
179 &(jobz as _),
180 &uplo.into(),
181 &(n as _),
182 a as *mut _,
183 &(lda as _),
184 w as *mut _,
185 work.as_mut_ptr() as *mut _,
186 &(lwork as _),
187 rwork.as_mut_ptr() as *mut _,
188 &(lrwork as _),
189 iwork.as_mut_ptr() as *mut _,
190 &(liwork as _),
191 &mut info,
192 );
193 if info != 0 {
194 return info;
195 }
196 } else {
197 let lda_t = n.max(1);
198 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
200 Ok(a_t) => a_t,
201 Err(_) => return -1011,
202 };
203 let a_slice = from_raw_parts_mut(a, n * lda);
204 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
205 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
206 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
207 func_(
209 &(jobz as _),
210 &uplo.into(),
211 &(n as _),
212 a_t.as_mut_ptr() as *mut _,
213 &(lda_t as _),
214 w as *mut _,
215 work.as_mut_ptr() as *mut _,
216 &(lwork as _),
217 rwork.as_mut_ptr() as *mut _,
218 &(lrwork as _),
219 iwork.as_mut_ptr(),
220 &(liwork as _),
221 &mut info,
222 );
223 if info != 0 {
224 return info;
225 }
226 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
228 }
229 return info;
230 }
231}