rstsr_openblas/driver_impl/lapack/eigh/
syev.rs1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11 T func_ ;
12 [f32] [ssyev_];
13 [f64] [dsyev_];
14)]
15impl SYEVDriverAPI<T> for DeviceBLAS {
16 unsafe fn driver_syev(
17 order: FlagOrder,
18 jobz: char,
19 uplo: FlagUpLo,
20 n: usize,
21 a: *mut T,
22 lda: usize,
23 w: *mut T,
24 ) -> blas_int {
25 use lapack_ffi::lapack::func_;
26
27 let mut info = 0;
29 let lwork = -1;
30 let mut work_query = 0.0;
31 func_(&(jobz as _), &uplo.into(), &(n as _), a, &(lda as _), w, &mut work_query, &lwork, &mut info);
32 if info != 0 {
33 return info;
34 }
35 let lwork = work_query as usize;
36
37 let mut work: Vec<T> = match uninitialized_vec(lwork) {
39 Ok(work) => work,
40 Err(_) => return -1010,
41 };
42
43 if order == ColMajor {
44 func_(
46 &(jobz as _),
47 &uplo.into(),
48 &(n as _),
49 a,
50 &(lda as _),
51 w,
52 work.as_mut_ptr(),
53 &(lwork as _),
54 &mut info,
55 );
56 if info != 0 {
57 return info;
58 }
59 } else {
60 let lda_t = n.max(1);
61 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
63 Ok(a_t) => a_t,
64 Err(_) => return -1011,
65 };
66 let a_slice = from_raw_parts_mut(a, n * lda);
67 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
68 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
69 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
70 func_(
72 &(jobz as _),
73 &uplo.into(),
74 &(n as _),
75 a_t.as_mut_ptr(),
76 &(lda_t as _),
77 w,
78 work.as_mut_ptr(),
79 &(lwork as _),
80 &mut info,
81 );
82 if info != 0 {
83 return info;
84 }
85 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
87 }
88 return info;
89 }
90}
91
92#[duplicate_item(
93 T func_ ;
94 [Complex<f32>] [cheev_];
95 [Complex<f64>] [zheev_];
96)]
97impl SYEVDriverAPI<T> for DeviceBLAS {
98 unsafe fn driver_syev(
99 order: FlagOrder,
100 jobz: char,
101 uplo: FlagUpLo,
102 n: usize,
103 a: *mut T,
104 lda: usize,
105 w: *mut <T as ComplexFloat>::Real,
106 ) -> blas_int {
107 use lapack_ffi::lapack::func_;
108
109 let rwork_len = (3 * n - 2).max(1);
111 let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(rwork_len) {
112 Ok(rwork) => rwork,
113 Err(_) => return -1010,
114 };
115
116 let mut info = 0;
118 let lwork = -1;
119 let mut work_query = 0.0;
120 func_(
121 &(jobz as _),
122 &uplo.into(),
123 &(n as _),
124 a as *mut _,
125 &(lda as _),
126 w as *mut _,
127 &mut work_query as *mut _ as *mut _,
128 &lwork,
129 rwork.as_mut_ptr() as *mut _,
130 &mut info,
131 );
132 if info != 0 {
133 return info;
134 }
135 let lwork = work_query as usize;
136
137 let mut work: Vec<T> = match uninitialized_vec(lwork) {
139 Ok(work) => work,
140 Err(_) => return -1010,
141 };
142
143 if order == ColMajor {
144 func_(
146 &(jobz as _),
147 &uplo.into(),
148 &(n as _),
149 a as *mut _,
150 &(lda as _),
151 w as *mut _,
152 work.as_mut_ptr() as *mut _,
153 &(lwork as _),
154 rwork.as_mut_ptr() as *mut _,
155 &mut info,
156 );
157 if info != 0 {
158 return info;
159 }
160 } else {
161 let lda_t = n.max(1);
162 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
164 Ok(a_t) => a_t,
165 Err(_) => return -1011,
166 };
167 let a_slice = from_raw_parts_mut(a, n * lda);
168 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
169 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
170 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
171 func_(
173 &(jobz as _),
174 &uplo.into(),
175 &(n as _),
176 a_t.as_mut_ptr() as *mut _,
177 &(lda_t as _),
178 w as *mut _,
179 work.as_mut_ptr() as *mut _,
180 &(lwork as _),
181 rwork.as_mut_ptr() as *mut _,
182 &mut info,
183 );
184 if info != 0 {
185 return info;
186 }
187 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
189 }
190 return info;
191 }
192}