1use cblas_sys::{CBLAS_LAYOUT, CBLAS_TRANSPOSE, CBLAS_UPLO};
2use mdarray::{DSlice, Layout, Shape, Slice};
3use mdarray_linalg::{into_i32, trans_stride};
4use num_complex::ComplexFloat;
5
6use num_complex::Complex;
7use std::any::TypeId;
8
9use super::scalar::BlasScalar;
10
11pub fn gemv<T, La, Lx, Ly>(
12 alpha: T,
13 a: &DSlice<T, 2, La>,
14 x: &DSlice<T, 1, Lx>,
15 beta: T,
16 y: &mut DSlice<T, 1, Ly>,
17) where
18 T: BlasScalar + ComplexFloat,
19 La: Layout,
20 Lx: Layout,
21 Ly: Layout,
22{
23 let (m, n) = *a.shape();
24
25 if a.stride(1) == 1 {
26 assert_eq!(x.len(), n, "x length must match number of columns in a");
27 } else {
28 assert_eq!(x.len(), m, "x length must match number of rows in a");
29 }
30
31 assert_eq!(
32 y.len(),
33 if a.stride(1) == 1 { m } else { n },
34 "y length must match the output dimension"
35 );
36
37 let row_major = a.stride(1) == 1;
38 assert!(
39 row_major || a.stride(0) == 1,
40 "a must be contiguous in one dimension"
41 );
42
43 let (same_order, other_order) = if row_major {
44 (CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans)
45 } else {
46 (CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasNoTrans)
47 };
48 let (a_trans, a_stride) = trans_stride!(a, same_order, other_order);
49
50 let x_inc = into_i32(x.stride(0));
51 let y_inc = into_i32(y.stride(0));
52
53 unsafe {
54 T::cblas_gemv(
55 if row_major {
56 CBLAS_LAYOUT::CblasRowMajor
57 } else {
58 CBLAS_LAYOUT::CblasColMajor
59 },
60 a_trans,
61 into_i32(m),
62 into_i32(n),
63 alpha,
64 a.as_ptr(),
65 a_stride,
66 x.as_ptr(),
67 x_inc,
68 beta,
69 y.as_mut_ptr(),
70 y_inc,
71 )
72 }
73}
74
75pub fn ger<T, La, Lx, Ly>(
76 beta: T,
77 x: &DSlice<T, 1, Lx>,
78 y: &DSlice<T, 1, Ly>,
79 a: &mut DSlice<T, 2, La>,
80) where
81 T: BlasScalar + ComplexFloat,
82 La: Layout,
83 Lx: Layout,
84 Ly: Layout,
85{
86 let (m, n) = *a.shape();
87
88 assert_eq!(x.len(), m, "x length must match number of rows in a");
89 assert_eq!(y.len(), n, "y length must match number of columns in a");
90
91 let x_inc = into_i32(x.stride(0));
92 let y_inc = into_i32(y.stride(0));
93
94 let row_major = a.stride(1) == 1;
95 assert!(
96 row_major || a.stride(0) == 1,
97 "a must be contiguous in one dimension"
98 );
99
100 let lda = if row_major { into_i32(n) } else { into_i32(m) };
101
102 unsafe {
103 T::cblas_ger(
104 if row_major {
105 CBLAS_LAYOUT::CblasRowMajor
106 } else {
107 CBLAS_LAYOUT::CblasColMajor
108 },
109 into_i32(m),
110 into_i32(n),
111 beta,
112 x.as_ptr(),
113 x_inc,
114 y.as_ptr(),
115 y_inc,
116 a.as_mut_ptr(),
117 lda,
118 )
119 }
120}
121
122pub fn scal<T, Lx>(alpha: T, x: &mut DSlice<T, 1, Lx>)
123where
124 T: BlasScalar + ComplexFloat,
125 Lx: Layout,
126{
127 let n = into_i32(x.len());
128 let incx = into_i32(x.stride(0));
129
130 unsafe { T::cblas_scal(n, alpha, x.as_mut_ptr(), incx) }
131}
132
133pub fn syr<T, Lx, La>(uplo: CBLAS_UPLO, alpha: T, x: &DSlice<T, 1, Lx>, a: &mut DSlice<T, 2, La>)
134where
135 T: BlasScalar + ComplexFloat,
136 Lx: Layout,
137 La: Layout,
138{
139 let (m, n) = *a.shape();
140 assert_eq!(m, n, "Matrix a must be square for symmetric update");
141 assert_eq!(x.len(), n, "x length must match matrix dimension");
142
143 let row_major = a.stride(1) == 1;
144 assert!(
145 row_major || a.stride(0) == 1,
146 "a must be contiguous in one dimension"
147 );
148
149 let x_inc = into_i32(x.stride(0));
150 let lda = if row_major { into_i32(n) } else { into_i32(m) };
151
152 unsafe {
153 T::cblas_syr(
154 if row_major {
155 CBLAS_LAYOUT::CblasRowMajor
156 } else {
157 CBLAS_LAYOUT::CblasColMajor
158 },
159 uplo,
160 into_i32(n),
161 alpha,
162 x.as_ptr(),
163 x_inc,
164 a.as_mut_ptr(),
165 lda,
166 )
167 }
168}
169
170pub fn her<T, Lx, La>(
171 uplo: CBLAS_UPLO,
172 alpha: T::Real,
173 x: &DSlice<T, 1, Lx>,
174 a: &mut DSlice<T, 2, La>,
175) where
176 T: BlasScalar + ComplexFloat,
177 Lx: Layout,
178 La: Layout,
179{
180 let (m, n) = *a.shape();
181 assert_eq!(m, n, "Matrix a must be square for hermitian update");
182 assert_eq!(x.len(), n, "x length must match matrix dimension");
183
184 let row_major = a.stride(1) == 1;
185 assert!(
186 row_major || a.stride(0) == 1,
187 "a must be contiguous in one dimension"
188 );
189
190 let x_inc = into_i32(x.stride(0));
191 let lda = if row_major { into_i32(n) } else { into_i32(m) };
192
193 unsafe {
194 T::cblas_her(
195 if row_major {
196 CBLAS_LAYOUT::CblasRowMajor
197 } else {
198 CBLAS_LAYOUT::CblasColMajor
199 },
200 uplo,
201 into_i32(n),
202 alpha,
203 x.as_ptr(),
204 x_inc,
205 a.as_mut_ptr(),
206 lda,
207 )
208 }
209}
210
211pub fn asum<T, Lx>(x: &DSlice<T, 1, Lx>) -> T::Real
212where
213 T: BlasScalar + ComplexFloat,
214 Lx: Layout,
215{
216 let n = into_i32(x.len());
217 let incx = into_i32(x.stride(0));
218
219 unsafe { T::cblas_asum(n, x.as_ptr(), incx) }
220}
221
222pub fn axpy<T, Lx, Ly>(alpha: T, x: &DSlice<T, 1, Lx>, y: &mut DSlice<T, 1, Ly>)
223where
224 T: BlasScalar + ComplexFloat,
225 Lx: Layout,
226 Ly: Layout,
227{
228 assert_eq!(x.len(), y.len(), "Vector lengths must match");
229
230 let n = into_i32(x.len());
231 let incx = into_i32(x.stride(0));
232 let incy = into_i32(y.stride(0));
233
234 unsafe { T::cblas_axpy(n, alpha, x.as_ptr(), incx, y.as_mut_ptr(), incy) }
235}
236
237pub fn nrm2<T, Lx>(x: &DSlice<T, 1, Lx>) -> T::Real
238where
239 T: BlasScalar + ComplexFloat,
240 Lx: Layout,
241{
242 let n = into_i32(x.len());
243 let incx = into_i32(x.stride(0));
244
245 unsafe { T::cblas_nrm2(n, x.as_ptr(), incx) }
246}
247
248pub fn dotu<T, Lx, Ly>(x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T
249where
250 T: BlasScalar + ComplexFloat + 'static,
251 Lx: Layout,
252 Ly: Layout,
253{
254 assert_eq!(x.len(), y.len(), "Vector lengths must match");
255
256 let n = into_i32(x.len());
257 let incx = into_i32(x.stride(0));
258 let incy = into_i32(y.stride(0));
259
260 let mut result = T::zero();
261
262 if TypeId::of::<T>() == TypeId::of::<Complex<f32>>()
263 || TypeId::of::<T>() == TypeId::of::<Complex<f64>>()
264 {
265 unsafe {
266 T::cblas_dotu_sub(n, x.as_ptr(), incx, y.as_ptr(), incy, &mut result);
267 }
268 } else {
269 unsafe {
270 result = T::cblas_dot(n, x.as_ptr(), incx, y.as_ptr(), incy);
271 }
272 }
273
274 result
275}
276
277pub fn dotc<T, Lx, Ly>(x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T
278where
279 T: BlasScalar + ComplexFloat + 'static,
280 Lx: Layout,
281 Ly: Layout,
282{
283 assert_eq!(x.len(), y.len(), "Vector lengths must match");
284
285 let n = into_i32(x.len());
286 let incx = into_i32(x.stride(0));
287 let incy = into_i32(y.stride(0));
288
289 let mut result = T::zero();
290
291 if TypeId::of::<T>() == TypeId::of::<Complex<f32>>()
292 || TypeId::of::<T>() == TypeId::of::<Complex<f64>>()
293 {
294 unsafe {
295 T::cblas_dotc_sub(n, x.as_ptr(), incx, y.as_ptr(), incy, &mut result);
296 }
297 } else {
298 unsafe {
299 result = T::cblas_dot(n, x.as_ptr(), incx, y.as_ptr(), incy);
300 }
301 }
302
303 result
304}
305
306pub fn amax<T, S, L>(x: &Slice<T, S, L>) -> usize
307where
308 T: BlasScalar + ComplexFloat + 'static,
309 S: Shape,
310 L: Layout,
311{
312 assert!(!x.is_empty(), "Cannot find amax of empty slice");
313
314 let n = into_i32(x.len());
315 let incx = if x.rank() == 1 {
316 into_i32(x.stride(0))
317 } else {
318 1 };
320
321 let max_idx = unsafe { T::cblas_amax(n, x.as_ptr(), incx) } as usize - 1; max_idx
324}