1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11 T func_ ;
12 [f32] [ssygvd_];
13 [f64] [dsygvd_];
14)]
15impl SYGVDDriverAPI<T> for DeviceBLAS {
16 unsafe fn driver_sygvd(
17 order: FlagOrder,
18 itype: blas_int,
19 jobz: char,
20 uplo: FlagUpLo,
21 n: usize,
22 a: *mut T,
23 lda: usize,
24 b: *mut T,
25 ldb: usize,
26 w: *mut T,
27 ) -> blas_int {
28 use lapack_ffi::lapack::func_;
29
30 let mut info = 0;
32 let lwork = -1;
33 let liwork = -1;
34 let mut work_query = 0.0;
35 let mut iwork_query = 0;
36 func_(
37 &itype,
38 &(jobz as _),
39 &uplo.into(),
40 &(n as _),
41 a,
42 &(lda as _),
43 b,
44 &(ldb as _),
45 w,
46 &mut work_query,
47 &lwork,
48 &mut iwork_query,
49 &liwork,
50 &mut info,
51 );
52 if info != 0 {
53 return info;
54 }
55 let lwork = work_query as usize;
56 let liwork = iwork_query as usize;
57
58 let mut work: Vec<T> = match uninitialized_vec(lwork) {
60 Ok(work) => work,
61 Err(_) => return -1010,
62 };
63 let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
64 Ok(iwork) => iwork,
65 Err(_) => return -1010,
66 };
67
68 if order == ColMajor {
69 func_(
71 &itype,
72 &(jobz as _),
73 &uplo.into(),
74 &(n as _),
75 a,
76 &(lda as _),
77 b,
78 &(ldb as _),
79 w,
80 work.as_mut_ptr(),
81 &(lwork as _),
82 iwork.as_mut_ptr(),
83 &(liwork as _),
84 &mut info,
85 );
86 if info != 0 {
87 return info;
88 }
89 } else {
90 let lda_t = n.max(1);
91 let ldb_t = n.max(1);
92 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
94 Ok(a_t) => a_t,
95 Err(_) => return -1011,
96 };
97 let mut b_t: Vec<T> = match uninitialized_vec(n * n) {
98 Ok(b_t) => b_t,
99 Err(_) => return -1011,
100 };
101 let a_slice = from_raw_parts_mut(a, n * lda);
102 let b_slice = from_raw_parts_mut(b, n * ldb);
103 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
104 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
105 let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0);
106 let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0);
107 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
108 orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
109 func_(
111 &itype,
112 &(jobz as _),
113 &uplo.into(),
114 &(n as _),
115 a_t.as_mut_ptr(),
116 &(lda_t as _),
117 b_t.as_mut_ptr(),
118 &(ldb_t as _),
119 w,
120 work.as_mut_ptr(),
121 &(lwork as _),
122 iwork.as_mut_ptr(),
123 &(liwork as _),
124 &mut info,
125 );
126 if info != 0 {
127 return info;
128 }
129 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
131 orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
132 }
133 return info;
134 }
135}
136
137#[duplicate_item(
138 T func_ ;
139 [Complex<f32>] [chegvd_];
140 [Complex<f64>] [zhegvd_];
141)]
142impl SYGVDDriverAPI<T> for DeviceBLAS {
143 unsafe fn driver_sygvd(
144 order: FlagOrder,
145 itype: blas_int,
146 jobz: char,
147 uplo: FlagUpLo,
148 n: usize,
149 a: *mut T,
150 lda: usize,
151 b: *mut T,
152 ldb: usize,
153 w: *mut <T as ComplexFloat>::Real,
154 ) -> blas_int {
155 use lapack_ffi::lapack::func_;
156
157 let mut info = 0;
159 let lwork = -1;
160 let lrwork = -1;
161 let liwork = -1;
162 let mut work_query = 0.0;
163 let mut rwork_query = 0.0;
164 let mut iwork_query = 0;
165 func_(
166 &itype,
167 &(jobz as _),
168 &uplo.into(),
169 &(n as _),
170 a as *mut _,
171 &(lda as _),
172 b as *mut _,
173 &(ldb as _),
174 w as *mut _,
175 &mut work_query as *mut _ as *mut _,
176 &lwork,
177 &mut rwork_query as *mut _ as *mut _,
178 &lrwork,
179 &mut iwork_query,
180 &liwork,
181 &mut info,
182 );
183 if info != 0 {
184 return info;
185 }
186 let lwork = work_query as usize;
187 let lrwork = rwork_query as usize;
188 let liwork = iwork_query as usize;
189
190 let mut work: Vec<T> = match uninitialized_vec(lwork) {
192 Ok(work) => work,
193 Err(_) => return -1010,
194 };
195 let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(lrwork) {
196 Ok(rwork) => rwork,
197 Err(_) => return -1010,
198 };
199 let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
200 Ok(iwork) => iwork,
201 Err(_) => return -1010,
202 };
203
204 if order == ColMajor {
205 func_(
207 &itype,
208 &(jobz as _),
209 &uplo.into(),
210 &(n as _),
211 a as *mut _,
212 &(lda as _),
213 b as *mut _,
214 &(ldb as _),
215 w as *mut _,
216 work.as_mut_ptr() as *mut _,
217 &(lwork as _),
218 rwork.as_mut_ptr() as *mut _,
219 &(lrwork as _),
220 iwork.as_mut_ptr() as *mut _,
221 &(liwork as _),
222 &mut info,
223 );
224 if info != 0 {
225 return info;
226 }
227 } else {
228 let lda_t = n.max(1);
229 let ldb_t = n.max(1);
230 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
232 Ok(a_t) => a_t,
233 Err(_) => return -1011,
234 };
235 let mut b_t: Vec<T> = match uninitialized_vec(n * n) {
236 Ok(b_t) => b_t,
237 Err(_) => return -1011,
238 };
239 let a_slice = from_raw_parts_mut(a, n * lda);
240 let b_slice = from_raw_parts_mut(b, n * ldb);
241 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
242 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
243 let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0);
244 let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0);
245 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
246 orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
247 func_(
249 &itype,
250 &(jobz as _),
251 &uplo.into(),
252 &(n as _),
253 a_t.as_mut_ptr() as *mut _,
254 &(lda_t as _),
255 b_t.as_mut_ptr() as *mut _,
256 &(ldb_t as _),
257 w as *mut _,
258 work.as_mut_ptr() as *mut _,
259 &(lwork as _),
260 rwork.as_mut_ptr() as *mut _,
261 &(lrwork as _),
262 iwork.as_mut_ptr(),
263 &(liwork as _),
264 &mut info,
265 );
266 if info != 0 {
267 return info;
268 }
269 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
271 orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
272 }
273 return info;
274 }
275}