// mdarray_linalg_blas/matvec/context.rs

1use cblas_sys::CBLAS_UPLO;
2use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
3use num_complex::ComplexFloat;
4
5use mdarray_linalg::matmul::{Triangle, Type};
6use mdarray_linalg::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
7use mdarray_linalg::utils::unravel_index;
8
9use crate::Blas;
10
11use super::scalar::BlasScalar;
12use super::simple::{amax, asum, axpy, dotc, dotu, gemv, ger, her, nrm2, syr};
13
14struct BlasMatVecBuilder<'a, T, La, Lx>
15where
16    La: Layout,
17    Lx: Layout,
18{
19    alpha: T,
20    a: &'a DSlice<T, 2, La>,
21    x: &'a DSlice<T, 1, Lx>,
22}
23
24impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
25where
26    La: Layout,
27    Lx: Layout,
28    T: BlasScalar + ComplexFloat,
29    i8: Into<T::Real>,
30    T::Real: Into<T>,
31{
32    fn parallelize(self) -> Self {
33        self
34    }
35
36    fn scale(mut self, alpha: T) -> Self {
37        self.alpha = alpha * self.alpha;
38        self
39    }
40
41    fn eval(self) -> DTensor<T, 1> {
42        let mut y = DTensor::<T, 1>::from_elem(self.x.len(), 0.into().into());
43        gemv(self.alpha, self.a, self.x, 0.into().into(), &mut y);
44        y
45    }
46
47    fn overwrite<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
48        gemv(self.alpha, self.a, self.x, 0.into().into(), y);
49    }
50
51    fn add_to<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
52        gemv(self.alpha, self.a, self.x, 1.into().into(), y);
53    }
54
55    fn add_to_scaled<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>, beta: T) {
56        gemv(self.alpha, self.a, self.x, beta, y);
57    }
58
59    fn add_outer<Ly: Layout>(self, y: &DSlice<T, 1, Ly>, beta: T) -> DTensor<T, 2> {
60        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
61        a_copy.assign(self.a);
62
63        // Apply scale factor to preserve builder pattern logic: the alpha parameter
64        // may have been modified before this call, so we must scale the matrix
65        // before applying the rank-1 update. Unlike gemm operations, this requires
66        // a separate pass since BLAS lacks a direct matrix-scalar multiplication.
67
68        if self.alpha != 1.into().into() {
69            a_copy = a_copy.map(|x| x * self.alpha);
70        }
71
72        ger(beta, self.x, y, &mut a_copy);
73        a_copy
74    }
75
76    fn add_outer_special(self, beta: T, ty: Type, tr: Triangle) -> DTensor<T, 2> {
77        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
78        a_copy.assign(self.a);
79
80        if self.alpha != 1.into().into() {
81            a_copy = a_copy.map(|x| x * self.alpha);
82        }
83
84        let cblas_uplo = match tr {
85            Triangle::Lower => CBLAS_UPLO::CblasLower,
86            Triangle::Upper => CBLAS_UPLO::CblasUpper,
87        };
88
89        match ty {
90            Type::Her => her(cblas_uplo, beta.re(), self.x, &mut a_copy),
91            Type::Sym => syr(cblas_uplo, beta, self.x, &mut a_copy),
92            Type::Tri => {
93                ger(beta, self.x, self.x, &mut a_copy);
94            }
95        }
96
97        a_copy
98    }
99}
100
101impl<T> MatVec<T> for Blas
102where
103    T: BlasScalar + ComplexFloat,
104    i8: Into<T::Real>,
105    T::Real: Into<T>,
106{
107    fn matvec<'a, La, Lx>(
108        &self,
109        a: &'a DSlice<T, 2, La>,
110        x: &'a DSlice<T, 1, Lx>,
111    ) -> impl MatVecBuilder<'a, T, La, Lx>
112    where
113        La: Layout,
114        Lx: Layout,
115    {
116        BlasMatVecBuilder {
117            alpha: 1.into().into(),
118            a,
119            x,
120        }
121    }
122}
123
124impl<T: ComplexFloat + BlasScalar + 'static> VecOps<T> for Blas {
125    fn add_to_scaled<Lx: Layout, Ly: Layout>(
126        &self,
127        alpha: T,
128        x: &DSlice<T, 1, Lx>,
129        y: &mut DSlice<T, 1, Ly>,
130    ) {
131        axpy(alpha, x, y);
132    }
133
134    fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
135        dotu(x, y)
136    }
137
138    fn dotc<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
139        dotc(x, y)
140    }
141
142    fn norm2<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real {
143        nrm2(x)
144    }
145
146    fn norm1<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real
147    where
148        T: ComplexFloat,
149    {
150        asum(x)
151    }
152
153    fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
154        todo!()
155    }
156    fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
157        todo!()
158    }
159    fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
160        todo!()
161    }
162    fn rot<Lx: Layout, Ly: Layout>(
163        &self,
164        _x: &mut DSlice<T, 1, Lx>,
165        _y: &mut DSlice<T, 1, Ly>,
166        _c: T::Real,
167        _s: T,
168    ) where
169        T: ComplexFloat,
170    {
171        todo!()
172    }
173}
174
175impl<T: ComplexFloat + 'static + std::cmp::PartialOrd + BlasScalar> Argmax<T> for Blas {
176    fn argmax_overwrite<Lx: Layout, S: Shape>(
177        &self,
178        _x: &Slice<T, S, Lx>,
179        _output: &mut Vec<usize>,
180    ) -> bool {
181        unimplemented!("BLAS does not implement an argmax function, only argmax_abs")
182    }
183
184    fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
185        let mut result = Vec::new();
186        if self.argmax_overwrite(x, &mut result) {
187            Some(result)
188        } else {
189            None
190        }
191    }
192
193    fn argmax_abs_overwrite<Lx: Layout, S: Shape>(
194        &self,
195        x: &Slice<T, S, Lx>,
196        output: &mut Vec<usize>,
197    ) -> bool {
198        output.clear();
199
200        if x.is_empty() {
201            return false;
202        }
203
204        if x.rank() == 0 {
205            return true;
206        }
207
208        let max_flat_idx = amax(x);
209        let indices = unravel_index(x, max_flat_idx);
210        output.extend_from_slice(&indices);
211
212        true
213    }
214
215    fn argmax_abs<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
216        let mut result = Vec::new();
217        if self.argmax_overwrite(x, &mut result) {
218            Some(result)
219        } else {
220            None
221        }
222    }
223}