mdarray_linalg_blas/matvec/
context.rs

1use cblas_sys::CBLAS_UPLO;
2use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
3use num_complex::ComplexFloat;
4
5use mdarray_linalg::matmul::{Triangle, Type};
6use mdarray_linalg::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
7
8use crate::Blas;
9
10use super::scalar::BlasScalar;
11use super::simple::{asum, axpy, dotc, dotu, gemv, ger, her, nrm2, syr};
12
13struct BlasMatVecBuilder<'a, T, La, Lx>
14where
15    La: Layout,
16    Lx: Layout,
17{
18    alpha: T,
19    a: &'a DSlice<T, 2, La>,
20    x: &'a DSlice<T, 1, Lx>,
21}
22
23impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
24where
25    La: Layout,
26    Lx: Layout,
27    T: BlasScalar + ComplexFloat,
28    i8: Into<T::Real>,
29    T::Real: Into<T>,
30{
31    fn parallelize(self) -> Self {
32        self
33    }
34
35    fn scale(mut self, alpha: T) -> Self {
36        self.alpha = alpha * self.alpha;
37        self
38    }
39
40    fn eval(self) -> DTensor<T, 1> {
41        let mut y = DTensor::<T, 1>::from_elem(self.x.len(), 0.into().into());
42        gemv(self.alpha, self.a, self.x, 0.into().into(), &mut y);
43        y
44    }
45
46    fn overwrite<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
47        gemv(self.alpha, self.a, self.x, 0.into().into(), y);
48    }
49
50    fn add_to<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
51        gemv(self.alpha, self.a, self.x, 1.into().into(), y);
52    }
53
54    fn add_to_scaled<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>, beta: T) {
55        gemv(self.alpha, self.a, self.x, beta, y);
56    }
57
58    fn add_outer<Ly: Layout>(self, y: &DSlice<T, 1, Ly>, beta: T) -> DTensor<T, 2> {
59        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
60        a_copy.assign(self.a);
61
62        // Apply scale factor to preserve builder pattern logic: the alpha parameter
63        // may have been modified before this call, so we must scale the matrix
64        // before applying the rank-1 update. Unlike gemm operations, this requires
65        // a separate pass since BLAS lacks a direct matrix-scalar multiplication.
66
67        if self.alpha != 1.into().into() {
68            a_copy = a_copy.map(|x| x * self.alpha);
69        }
70
71        ger(beta, self.x, y, &mut a_copy);
72        a_copy
73    }
74
75    fn add_outer_special(self, beta: T, ty: Type, tr: Triangle) -> DTensor<T, 2> {
76        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
77        a_copy.assign(self.a);
78
79        if self.alpha != 1.into().into() {
80            a_copy = a_copy.map(|x| x * self.alpha);
81        }
82
83        let cblas_uplo = match tr {
84            Triangle::Lower => CBLAS_UPLO::CblasLower,
85            Triangle::Upper => CBLAS_UPLO::CblasUpper,
86        };
87
88        match ty {
89            Type::Her => her(cblas_uplo, beta.re(), self.x, &mut a_copy),
90            Type::Sym => syr(cblas_uplo, beta, self.x, &mut a_copy),
91            Type::Tri => {
92                ger(beta, self.x, self.x, &mut a_copy);
93            }
94        }
95
96        a_copy
97    }
98}
99
100impl<T> MatVec<T> for Blas
101where
102    T: BlasScalar + ComplexFloat,
103    i8: Into<T::Real>,
104    T::Real: Into<T>,
105{
106    fn matvec<'a, La, Lx>(
107        &self,
108        a: &'a DSlice<T, 2, La>,
109        x: &'a DSlice<T, 1, Lx>,
110    ) -> impl MatVecBuilder<'a, T, La, Lx>
111    where
112        La: Layout,
113        Lx: Layout,
114    {
115        BlasMatVecBuilder {
116            alpha: 1.into().into(),
117            a,
118            x,
119        }
120    }
121}
122
123impl<T: ComplexFloat + BlasScalar + 'static> VecOps<T> for Blas {
124    fn add_to_scaled<Lx: Layout, Ly: Layout>(
125        &self,
126        alpha: T,
127        x: &DSlice<T, 1, Lx>,
128        y: &mut DSlice<T, 1, Ly>,
129    ) {
130        axpy(alpha, x, y);
131    }
132
133    fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
134        dotu(x, y)
135    }
136
137    fn dotc<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
138        dotc(x, y)
139    }
140
141    fn norm2<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real {
142        nrm2(x)
143    }
144
145    fn norm1<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real
146    where
147        T: ComplexFloat,
148    {
149        asum(x)
150    }
151
152    fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
153        todo!()
154    }
155    fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
156        todo!()
157    }
158    fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
159        todo!()
160    }
161    fn rot<Lx: Layout, Ly: Layout>(
162        &self,
163        _x: &mut DSlice<T, 1, Lx>,
164        _y: &mut DSlice<T, 1, Ly>,
165        _c: T::Real,
166        _s: T,
167    ) where
168        T: ComplexFloat,
169    {
170        todo!()
171    }
172}
173
174impl<T: ComplexFloat + 'static + std::cmp::PartialOrd> Argmax<T> for Blas {
175    fn argmax_overwrite<Lx: Layout, S: Shape>(
176        &self,
177        _x: &Slice<T, S, Lx>,
178        _output: &mut Vec<usize>,
179    ) -> bool {
180        todo!()
181    }
182
183    fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
184        let mut result = Vec::new();
185        if self.argmax_overwrite(x, &mut result) {
186            Some(result)
187        } else {
188            None
189        }
190    }
191}