1use cblas_sys::CBLAS_UPLO;
2use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
3use num_complex::ComplexFloat;
4
5use mdarray_linalg::matmul::{Triangle, Type};
6use mdarray_linalg::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
7
8use crate::Blas;
9
10use super::scalar::BlasScalar;
11use super::simple::{asum, axpy, dotc, dotu, gemv, ger, her, nrm2, syr};
12
13struct BlasMatVecBuilder<'a, T, La, Lx>
14where
15 La: Layout,
16 Lx: Layout,
17{
18 alpha: T,
19 a: &'a DSlice<T, 2, La>,
20 x: &'a DSlice<T, 1, Lx>,
21}
22
23impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
24where
25 La: Layout,
26 Lx: Layout,
27 T: BlasScalar + ComplexFloat,
28 i8: Into<T::Real>,
29 T::Real: Into<T>,
30{
31 fn parallelize(self) -> Self {
32 self
33 }
34
35 fn scale(mut self, alpha: T) -> Self {
36 self.alpha = alpha * self.alpha;
37 self
38 }
39
40 fn eval(self) -> DTensor<T, 1> {
41 let mut y = DTensor::<T, 1>::from_elem(self.x.len(), 0.into().into());
42 gemv(self.alpha, self.a, self.x, 0.into().into(), &mut y);
43 y
44 }
45
46 fn overwrite<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
47 gemv(self.alpha, self.a, self.x, 0.into().into(), y);
48 }
49
50 fn add_to<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>) {
51 gemv(self.alpha, self.a, self.x, 1.into().into(), y);
52 }
53
54 fn add_to_scaled<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>, beta: T) {
55 gemv(self.alpha, self.a, self.x, beta, y);
56 }
57
58 fn add_outer<Ly: Layout>(self, y: &DSlice<T, 1, Ly>, beta: T) -> DTensor<T, 2> {
59 let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
60 a_copy.assign(self.a);
61
62 if self.alpha != 1.into().into() {
68 a_copy = a_copy.map(|x| x * self.alpha);
69 }
70
71 ger(beta, self.x, y, &mut a_copy);
72 a_copy
73 }
74
75 fn add_outer_special(self, beta: T, ty: Type, tr: Triangle) -> DTensor<T, 2> {
76 let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
77 a_copy.assign(self.a);
78
79 if self.alpha != 1.into().into() {
80 a_copy = a_copy.map(|x| x * self.alpha);
81 }
82
83 let cblas_uplo = match tr {
84 Triangle::Lower => CBLAS_UPLO::CblasLower,
85 Triangle::Upper => CBLAS_UPLO::CblasUpper,
86 };
87
88 match ty {
89 Type::Her => her(cblas_uplo, beta.re(), self.x, &mut a_copy),
90 Type::Sym => syr(cblas_uplo, beta, self.x, &mut a_copy),
91 Type::Tri => {
92 ger(beta, self.x, self.x, &mut a_copy);
93 }
94 }
95
96 a_copy
97 }
98}
99
100impl<T> MatVec<T> for Blas
101where
102 T: BlasScalar + ComplexFloat,
103 i8: Into<T::Real>,
104 T::Real: Into<T>,
105{
106 fn matvec<'a, La, Lx>(
107 &self,
108 a: &'a DSlice<T, 2, La>,
109 x: &'a DSlice<T, 1, Lx>,
110 ) -> impl MatVecBuilder<'a, T, La, Lx>
111 where
112 La: Layout,
113 Lx: Layout,
114 {
115 BlasMatVecBuilder {
116 alpha: 1.into().into(),
117 a,
118 x,
119 }
120 }
121}
122
123impl<T: ComplexFloat + BlasScalar + 'static> VecOps<T> for Blas {
124 fn add_to_scaled<Lx: Layout, Ly: Layout>(
125 &self,
126 alpha: T,
127 x: &DSlice<T, 1, Lx>,
128 y: &mut DSlice<T, 1, Ly>,
129 ) {
130 axpy(alpha, x, y);
131 }
132
133 fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
134 dotu(x, y)
135 }
136
137 fn dotc<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
138 dotc(x, y)
139 }
140
141 fn norm2<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real {
142 nrm2(x)
143 }
144
145 fn norm1<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real
146 where
147 T: ComplexFloat,
148 {
149 asum(x)
150 }
151
152 fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
153 todo!()
154 }
155 fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
156 todo!()
157 }
158 fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
159 todo!()
160 }
161 fn rot<Lx: Layout, Ly: Layout>(
162 &self,
163 _x: &mut DSlice<T, 1, Lx>,
164 _y: &mut DSlice<T, 1, Ly>,
165 _c: T::Real,
166 _s: T,
167 ) where
168 T: ComplexFloat,
169 {
170 todo!()
171 }
172}
173
174impl<T: ComplexFloat + 'static + std::cmp::PartialOrd> Argmax<T> for Blas {
175 fn argmax_overwrite<Lx: Layout, S: Shape>(
176 &self,
177 _x: &Slice<T, S, Lx>,
178 _output: &mut Vec<usize>,
179 ) -> bool {
180 todo!()
181 }
182
183 fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
184 let mut result = Vec::new();
185 if self.argmax_overwrite(x, &mut result) {
186 Some(result)
187 } else {
188 None
189 }
190 }
191}