1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11 T func_ ;
12 [f32] [ssygv_];
13 [f64] [dsygv_];
14)]
15impl SYGVDriverAPI<T> for DeviceBLAS {
16 unsafe fn driver_sygv(
17 order: FlagOrder,
18 itype: blas_int,
19 jobz: char,
20 uplo: FlagUpLo,
21 n: usize,
22 a: *mut T,
23 lda: usize,
24 b: *mut T,
25 ldb: usize,
26 w: *mut T,
27 ) -> blas_int {
28 use lapack_ffi::lapack::func_;
29
30 let mut info = 0;
32 let lwork = -1;
33 let mut work_query = 0.0;
34 func_(
35 &itype,
36 &(jobz as _),
37 &uplo.into(),
38 &(n as _),
39 a,
40 &(lda as _),
41 b,
42 &(ldb as _),
43 w,
44 &mut work_query,
45 &lwork,
46 &mut info,
47 );
48 if info != 0 {
49 return info;
50 }
51 let lwork = work_query as usize;
52
53 let mut work: Vec<T> = match uninitialized_vec(lwork) {
55 Ok(work) => work,
56 Err(_) => return -1010,
57 };
58
59 if order == ColMajor {
60 func_(
62 &itype,
63 &(jobz as _),
64 &uplo.into(),
65 &(n as _),
66 a,
67 &(lda as _),
68 b,
69 &(ldb as _),
70 w,
71 work.as_mut_ptr(),
72 &(lwork as _),
73 &mut info,
74 );
75 if info != 0 {
76 return info;
77 }
78 } else {
79 let lda_t = n.max(1);
80 let ldb_t = n.max(1);
81 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
83 Ok(a_t) => a_t,
84 Err(_) => return -1011,
85 };
86 let mut b_t: Vec<T> = match uninitialized_vec(n * n) {
87 Ok(b_t) => b_t,
88 Err(_) => return -1011,
89 };
90 let a_slice = from_raw_parts_mut(a, n * lda);
91 let b_slice = from_raw_parts_mut(b, n * ldb);
92 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
93 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
94 let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0);
95 let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0);
96 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
97 orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
98 func_(
100 &itype,
101 &(jobz as _),
102 &uplo.into(),
103 &(n as _),
104 a_t.as_mut_ptr(),
105 &(lda_t as _),
106 b_t.as_mut_ptr(),
107 &(ldb_t as _),
108 w,
109 work.as_mut_ptr(),
110 &(lwork as _),
111 &mut info,
112 );
113 if info != 0 {
114 return info;
115 }
116 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
118 orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
119 }
120 return info;
121 }
122}
123
124#[duplicate_item(
125 T func_ ;
126 [Complex<f32>] [chegv_];
127 [Complex<f64>] [zhegv_];
128)]
129impl SYGVDriverAPI<T> for DeviceBLAS {
130 unsafe fn driver_sygv(
131 order: FlagOrder,
132 itype: blas_int,
133 jobz: char,
134 uplo: FlagUpLo,
135 n: usize,
136 a: *mut T,
137 lda: usize,
138 b: *mut T,
139 ldb: usize,
140 w: *mut <T as ComplexFloat>::Real,
141 ) -> blas_int {
142 use lapack_ffi::lapack::func_;
143
144 let rwork_len = (3 * n - 2).max(1);
146 let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(rwork_len) {
147 Ok(rwork) => rwork,
148 Err(_) => return -1010,
149 };
150
151 let mut info = 0;
153 let lwork = -1;
154 let mut work_query = 0.0;
155 func_(
156 &itype,
157 &(jobz as _),
158 &uplo.into(),
159 &(n as _),
160 a as *mut _,
161 &(lda as _),
162 b as *mut _,
163 &(ldb as _),
164 w as *mut _,
165 &mut work_query as *mut _ as *mut _,
166 &lwork,
167 rwork.as_mut_ptr() as *mut _,
168 &mut info,
169 );
170 if info != 0 {
171 return info;
172 }
173 let lwork = work_query as usize;
174
175 let mut work: Vec<T> = match uninitialized_vec(lwork) {
177 Ok(work) => work,
178 Err(_) => return -1010,
179 };
180
181 if order == ColMajor {
182 func_(
184 &itype,
185 &(jobz as _),
186 &uplo.into(),
187 &(n as _),
188 a as *mut _,
189 &(lda as _),
190 b as *mut _,
191 &(ldb as _),
192 w as *mut _,
193 work.as_mut_ptr() as *mut _,
194 &(lwork as _),
195 rwork.as_mut_ptr() as *mut _,
196 &mut info,
197 );
198 if info != 0 {
199 return info;
200 }
201 } else {
202 let lda_t = n.max(1);
203 let ldb_t = n.max(1);
204 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
206 Ok(a_t) => a_t,
207 Err(_) => return -1011,
208 };
209 let mut b_t: Vec<T> = match uninitialized_vec(n * n) {
210 Ok(b_t) => b_t,
211 Err(_) => return -1011,
212 };
213 let a_slice = from_raw_parts_mut(a, n * lda);
214 let b_slice = from_raw_parts_mut(b, n * ldb);
215 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
216 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
217 let lb = Layout::new_unchecked([n, n], [ldb as isize, 1], 0);
218 let lb_t = Layout::new_unchecked([n, n], [1, ldb_t as isize], 0);
219 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
220 orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
221 func_(
223 &itype,
224 &(jobz as _),
225 &uplo.into(),
226 &(n as _),
227 a_t.as_mut_ptr() as *mut _,
228 &(lda_t as _),
229 b_t.as_mut_ptr() as *mut _,
230 &(ldb_t as _),
231 w as *mut _,
232 work.as_mut_ptr() as *mut _,
233 &(lwork as _),
234 rwork.as_mut_ptr() as *mut _,
235 &mut info,
236 );
237 if info != 0 {
238 return info;
239 }
240 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
242 orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
243 }
244 return info;
245 }
246}