1use cblas_sys::CBLAS_UPLO;
2use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
3use num_complex::ComplexFloat;
4
5use mdarray_linalg::matmul::{Triangle, Type};
6use mdarray_linalg::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
7use mdarray_linalg::utils::unravel_index;
8
9use crate::Blas;
10
11use super::scalar::BlasScalar;
12use super::simple::{amax, asum, axpy, dotc, dotu, gemv, ger, her, nrm2, syr};
13
14struct BlasMatVecBuilder<'a, T, La, Lx>
15where
16 La: Layout,
17 Lx: Layout,
18{
19 alpha: T,
20 a: &'a DSlice<T, 2, La>,
21 x: &'a DSlice<T, 1, Lx>,
22}
23
24impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
25where
26 La: Layout,
27 Lx: Layout,
28 T: BlasScalar + ComplexFloat,
29 i8: Into<T::Real>,
30 T::Real: Into<T>,
31{
32 fn parallelize(self) -> Self {
33 self
34 }
35
36 fn scale(mut self, alpha: T) -> Self {
37 self.alpha = alpha * self.alpha;
38 self
39 }
40
41 fn eval(self) -> DTensor<T, 1> {
42 let mut y = DTensor::<T, 1>::from_elem(self.x.len(), 0.into().into());
43 gemv(self.alpha, self.a, self.x, 0.into().into(), &mut y);
44 y
45 }
46
47 fn overwrite<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
48 gemv(self.alpha, self.a, self.x, 0.into().into(), y);
49 }
50
51 fn add_to<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
52 gemv(self.alpha, self.a, self.x, 1.into().into(), y);
53 }
54
55 fn add_to_scaled<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>, beta: T) {
56 gemv(self.alpha, self.a, self.x, beta, y);
57 }
58
59 fn add_outer<Ly: Layout>(self, y: &DSlice<T, 1, Ly>, beta: T) -> DTensor<T, 2> {
60 let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
61 a_copy.assign(self.a);
62
63 if self.alpha != 1.into().into() {
69 a_copy = a_copy.map(|x| x * self.alpha);
70 }
71
72 ger(beta, self.x, y, &mut a_copy);
73 a_copy
74 }
75
76 fn add_outer_special(self, beta: T, ty: Type, tr: Triangle) -> DTensor<T, 2> {
77 let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
78 a_copy.assign(self.a);
79
80 if self.alpha != 1.into().into() {
81 a_copy = a_copy.map(|x| x * self.alpha);
82 }
83
84 let cblas_uplo = match tr {
85 Triangle::Lower => CBLAS_UPLO::CblasLower,
86 Triangle::Upper => CBLAS_UPLO::CblasUpper,
87 };
88
89 match ty {
90 Type::Her => her(cblas_uplo, beta.re(), self.x, &mut a_copy),
91 Type::Sym => syr(cblas_uplo, beta, self.x, &mut a_copy),
92 Type::Tri => {
93 ger(beta, self.x, self.x, &mut a_copy);
94 }
95 }
96
97 a_copy
98 }
99}
100
101impl<T> MatVec<T> for Blas
102where
103 T: BlasScalar + ComplexFloat,
104 i8: Into<T::Real>,
105 T::Real: Into<T>,
106{
107 fn matvec<'a, La, Lx>(
108 &self,
109 a: &'a DSlice<T, 2, La>,
110 x: &'a DSlice<T, 1, Lx>,
111 ) -> impl MatVecBuilder<'a, T, La, Lx>
112 where
113 La: Layout,
114 Lx: Layout,
115 {
116 BlasMatVecBuilder {
117 alpha: 1.into().into(),
118 a,
119 x,
120 }
121 }
122}
123
124impl<T: ComplexFloat + BlasScalar + 'static> VecOps<T> for Blas {
125 fn add_to_scaled<Lx: Layout, Ly: Layout>(
126 &self,
127 alpha: T,
128 x: &DSlice<T, 1, Lx>,
129 y: &mut DSlice<T, 1, Ly>,
130 ) {
131 axpy(alpha, x, y);
132 }
133
134 fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
135 dotu(x, y)
136 }
137
138 fn dotc<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
139 dotc(x, y)
140 }
141
142 fn norm2<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real {
143 nrm2(x)
144 }
145
146 fn norm1<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real
147 where
148 T: ComplexFloat,
149 {
150 asum(x)
151 }
152
153 fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
154 todo!()
155 }
156 fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
157 todo!()
158 }
159 fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
160 todo!()
161 }
162 fn rot<Lx: Layout, Ly: Layout>(
163 &self,
164 _x: &mut DSlice<T, 1, Lx>,
165 _y: &mut DSlice<T, 1, Ly>,
166 _c: T::Real,
167 _s: T,
168 ) where
169 T: ComplexFloat,
170 {
171 todo!()
172 }
173}
174
175impl<T: ComplexFloat + 'static + std::cmp::PartialOrd + BlasScalar> Argmax<T> for Blas {
176 fn argmax_overwrite<Lx: Layout, S: Shape>(
177 &self,
178 _x: &Slice<T, S, Lx>,
179 _output: &mut Vec<usize>,
180 ) -> bool {
181 unimplemented!("BLAS does not implement an argmax function, only argmax_abs")
182 }
183
184 fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
185 let mut result = Vec::new();
186 if self.argmax_overwrite(x, &mut result) {
187 Some(result)
188 } else {
189 None
190 }
191 }
192
193 fn argmax_abs_overwrite<Lx: Layout, S: Shape>(
194 &self,
195 x: &Slice<T, S, Lx>,
196 output: &mut Vec<usize>,
197 ) -> bool {
198 output.clear();
199
200 if x.is_empty() {
201 return false;
202 }
203
204 if x.rank() == 0 {
205 return true;
206 }
207
208 let max_flat_idx = amax(x);
209 let indices = unravel_index(x, max_flat_idx);
210 output.extend_from_slice(&indices);
211
212 true
213 }
214
215 fn argmax_abs<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
216 let mut result = Vec::new();
217 if self.argmax_overwrite(x, &mut result) {
218 Some(result)
219 } else {
220 None
221 }
222 }
223}