mdarray_linalg_blas/matvec/
simple.rs

1use cblas_sys::{CBLAS_LAYOUT, CBLAS_TRANSPOSE, CBLAS_UPLO};
2use mdarray::{DSlice, Layout, Shape, Slice};
3use mdarray_linalg::{into_i32, trans_stride};
4use num_complex::ComplexFloat;
5
6use num_complex::Complex;
7use std::any::TypeId;
8
9use super::scalar::BlasScalar;
10
11pub fn gemv<T, La, Lx, Ly>(
12    alpha: T,
13    a: &DSlice<T, 2, La>,
14    x: &DSlice<T, 1, Lx>,
15    beta: T,
16    y: &mut DSlice<T, 1, Ly>,
17) where
18    T: BlasScalar + ComplexFloat,
19    La: Layout,
20    Lx: Layout,
21    Ly: Layout,
22{
23    let (m, n) = *a.shape();
24
25    if a.stride(1) == 1 {
26        assert_eq!(x.len(), n, "x length must match number of columns in a");
27    } else {
28        assert_eq!(x.len(), m, "x length must match number of rows in a");
29    }
30
31    assert_eq!(
32        y.len(),
33        if a.stride(1) == 1 { m } else { n },
34        "y length must match the output dimension"
35    );
36
37    let row_major = a.stride(1) == 1;
38    assert!(
39        row_major || a.stride(0) == 1,
40        "a must be contiguous in one dimension"
41    );
42
43    let (same_order, other_order) = if row_major {
44        (CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans)
45    } else {
46        (CBLAS_TRANSPOSE::CblasTrans, CBLAS_TRANSPOSE::CblasNoTrans)
47    };
48    let (a_trans, a_stride) = trans_stride!(a, same_order, other_order);
49
50    let x_inc = into_i32(x.stride(0));
51    let y_inc = into_i32(y.stride(0));
52
53    unsafe {
54        T::cblas_gemv(
55            if row_major {
56                CBLAS_LAYOUT::CblasRowMajor
57            } else {
58                CBLAS_LAYOUT::CblasColMajor
59            },
60            a_trans,
61            into_i32(m),
62            into_i32(n),
63            alpha,
64            a.as_ptr(),
65            a_stride,
66            x.as_ptr(),
67            x_inc,
68            beta,
69            y.as_mut_ptr(),
70            y_inc,
71        )
72    }
73}
74
75pub fn ger<T, La, Lx, Ly>(
76    beta: T,
77    x: &DSlice<T, 1, Lx>,
78    y: &DSlice<T, 1, Ly>,
79    a: &mut DSlice<T, 2, La>,
80) where
81    T: BlasScalar + ComplexFloat,
82    La: Layout,
83    Lx: Layout,
84    Ly: Layout,
85{
86    let (m, n) = *a.shape();
87
88    assert_eq!(x.len(), m, "x length must match number of rows in a");
89    assert_eq!(y.len(), n, "y length must match number of columns in a");
90
91    let x_inc = into_i32(x.stride(0));
92    let y_inc = into_i32(y.stride(0));
93
94    let row_major = a.stride(1) == 1;
95    assert!(
96        row_major || a.stride(0) == 1,
97        "a must be contiguous in one dimension"
98    );
99
100    let lda = if row_major { into_i32(n) } else { into_i32(m) };
101
102    unsafe {
103        T::cblas_ger(
104            if row_major {
105                CBLAS_LAYOUT::CblasRowMajor
106            } else {
107                CBLAS_LAYOUT::CblasColMajor
108            },
109            into_i32(m),
110            into_i32(n),
111            beta,
112            x.as_ptr(),
113            x_inc,
114            y.as_ptr(),
115            y_inc,
116            a.as_mut_ptr(),
117            lda,
118        )
119    }
120}
121
122pub fn scal<T, Lx>(alpha: T, x: &mut DSlice<T, 1, Lx>)
123where
124    T: BlasScalar + ComplexFloat,
125    Lx: Layout,
126{
127    let n = into_i32(x.len());
128    let incx = into_i32(x.stride(0));
129
130    unsafe { T::cblas_scal(n, alpha, x.as_mut_ptr(), incx) }
131}
132
133pub fn syr<T, Lx, La>(uplo: CBLAS_UPLO, alpha: T, x: &DSlice<T, 1, Lx>, a: &mut DSlice<T, 2, La>)
134where
135    T: BlasScalar + ComplexFloat,
136    Lx: Layout,
137    La: Layout,
138{
139    let (m, n) = *a.shape();
140    assert_eq!(m, n, "Matrix a must be square for symmetric update");
141    assert_eq!(x.len(), n, "x length must match matrix dimension");
142
143    let row_major = a.stride(1) == 1;
144    assert!(
145        row_major || a.stride(0) == 1,
146        "a must be contiguous in one dimension"
147    );
148
149    let x_inc = into_i32(x.stride(0));
150    let lda = if row_major { into_i32(n) } else { into_i32(m) };
151
152    unsafe {
153        T::cblas_syr(
154            if row_major {
155                CBLAS_LAYOUT::CblasRowMajor
156            } else {
157                CBLAS_LAYOUT::CblasColMajor
158            },
159            uplo,
160            into_i32(n),
161            alpha,
162            x.as_ptr(),
163            x_inc,
164            a.as_mut_ptr(),
165            lda,
166        )
167    }
168}
169
170pub fn her<T, Lx, La>(
171    uplo: CBLAS_UPLO,
172    alpha: T::Real,
173    x: &DSlice<T, 1, Lx>,
174    a: &mut DSlice<T, 2, La>,
175) where
176    T: BlasScalar + ComplexFloat,
177    Lx: Layout,
178    La: Layout,
179{
180    let (m, n) = *a.shape();
181    assert_eq!(m, n, "Matrix a must be square for hermitian update");
182    assert_eq!(x.len(), n, "x length must match matrix dimension");
183
184    let row_major = a.stride(1) == 1;
185    assert!(
186        row_major || a.stride(0) == 1,
187        "a must be contiguous in one dimension"
188    );
189
190    let x_inc = into_i32(x.stride(0));
191    let lda = if row_major { into_i32(n) } else { into_i32(m) };
192
193    unsafe {
194        T::cblas_her(
195            if row_major {
196                CBLAS_LAYOUT::CblasRowMajor
197            } else {
198                CBLAS_LAYOUT::CblasColMajor
199            },
200            uplo,
201            into_i32(n),
202            alpha,
203            x.as_ptr(),
204            x_inc,
205            a.as_mut_ptr(),
206            lda,
207        )
208    }
209}
210
211pub fn asum<T, Lx>(x: &DSlice<T, 1, Lx>) -> T::Real
212where
213    T: BlasScalar + ComplexFloat,
214    Lx: Layout,
215{
216    let n = into_i32(x.len());
217    let incx = into_i32(x.stride(0));
218
219    unsafe { T::cblas_asum(n, x.as_ptr(), incx) }
220}
221
222pub fn axpy<T, Lx, Ly>(alpha: T, x: &DSlice<T, 1, Lx>, y: &mut DSlice<T, 1, Ly>)
223where
224    T: BlasScalar + ComplexFloat,
225    Lx: Layout,
226    Ly: Layout,
227{
228    assert_eq!(x.len(), y.len(), "Vector lengths must match");
229
230    let n = into_i32(x.len());
231    let incx = into_i32(x.stride(0));
232    let incy = into_i32(y.stride(0));
233
234    unsafe { T::cblas_axpy(n, alpha, x.as_ptr(), incx, y.as_mut_ptr(), incy) }
235}
236
237pub fn nrm2<T, Lx>(x: &DSlice<T, 1, Lx>) -> T::Real
238where
239    T: BlasScalar + ComplexFloat,
240    Lx: Layout,
241{
242    let n = into_i32(x.len());
243    let incx = into_i32(x.stride(0));
244
245    unsafe { T::cblas_nrm2(n, x.as_ptr(), incx) }
246}
247
248pub fn dotu<T, Lx, Ly>(x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T
249where
250    T: BlasScalar + ComplexFloat + 'static,
251    Lx: Layout,
252    Ly: Layout,
253{
254    assert_eq!(x.len(), y.len(), "Vector lengths must match");
255
256    let n = into_i32(x.len());
257    let incx = into_i32(x.stride(0));
258    let incy = into_i32(y.stride(0));
259
260    let mut result = T::zero();
261
262    if TypeId::of::<T>() == TypeId::of::<Complex<f32>>()
263        || TypeId::of::<T>() == TypeId::of::<Complex<f64>>()
264    {
265        unsafe {
266            T::cblas_dotu_sub(n, x.as_ptr(), incx, y.as_ptr(), incy, &mut result);
267        }
268    } else {
269        unsafe {
270            result = T::cblas_dot(n, x.as_ptr(), incx, y.as_ptr(), incy);
271        }
272    }
273
274    result
275}
276
277pub fn dotc<T, Lx, Ly>(x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T
278where
279    T: BlasScalar + ComplexFloat + 'static,
280    Lx: Layout,
281    Ly: Layout,
282{
283    assert_eq!(x.len(), y.len(), "Vector lengths must match");
284
285    let n = into_i32(x.len());
286    let incx = into_i32(x.stride(0));
287    let incy = into_i32(y.stride(0));
288
289    let mut result = T::zero();
290
291    if TypeId::of::<T>() == TypeId::of::<Complex<f32>>()
292        || TypeId::of::<T>() == TypeId::of::<Complex<f64>>()
293    {
294        unsafe {
295            T::cblas_dotc_sub(n, x.as_ptr(), incx, y.as_ptr(), incy, &mut result);
296        }
297    } else {
298        unsafe {
299            result = T::cblas_dot(n, x.as_ptr(), incx, y.as_ptr(), incy);
300        }
301    }
302
303    result
304}
305
306pub fn amax<T, S, L>(x: &Slice<T, S, L>) -> usize
307where
308    T: BlasScalar + ComplexFloat + 'static,
309    S: Shape,
310    L: Layout,
311{
312    assert!(!x.is_empty(), "Cannot find amax of empty slice");
313
314    let n = into_i32(x.len());
315    let incx = if x.rank() == 1 {
316        into_i32(x.stride(0))
317    } else {
318        1 // Treat multi-dimensional as flat contiguous
319    };
320
321    let max_idx = unsafe { T::cblas_amax(n, x.as_ptr(), incx) } as usize - 1; // BLAS uses 1-based indexing
322
323    max_idx
324}