1use num_traits::Zero;
2use std::ops::{Add, Mul};
3
4use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
5use num_complex::ComplexFloat;
6
7use crate::matmul::{Triangle, Type};
8use crate::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
9use crate::utils::unravel_index;
10
11use crate::Naive;
12
13struct BlasMatVecBuilder<'a, T, La, Lx>
14where
15 La: Layout,
16 Lx: Layout,
17{
18 alpha: T,
19 a: &'a DSlice<T, 2, La>,
20 x: &'a DSlice<T, 1, Lx>,
21}
22
23impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
24where
25 La: Layout,
26 Lx: Layout,
27 T: ComplexFloat,
28 i8: Into<T::Real>,
29 T::Real: Into<T>,
30{
31 fn parallelize(self) -> Self {
32 self
33 }
34
35 fn scale(mut self, alpha: T) -> Self {
36 self.alpha = alpha * self.alpha;
37 self
38 }
39
40 fn eval(self) -> DTensor<T, 1> {
41 let mut _y = DTensor::<T, 1>::from_elem(self.x.len(), 0.into().into());
42 todo!()
45 }
46
47 fn overwrite<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>) {
48 todo!()
50 }
51
52 fn add_to<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>) {
53 todo!()
55 }
56
57 fn add_to_scaled<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>, _beta: T) {
58 todo!()
60 }
61
62 fn add_outer<Ly: Layout>(self, y: &DSlice<T, 1, Ly>, beta: T) -> DTensor<T, 2> {
63 let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
64 a_copy.assign(self.a);
65
66 let (m, n) = *a_copy.shape();
67
68 for i in 0..m {
69 for j in 0..n {
70 a_copy[[i, j]] = self.alpha * a_copy[[i, j]] + beta * self.x[[i]] * y[[j]];
71 }
72 }
73
74 a_copy
75 }
76
77 fn add_outer_special(self, _beta: T, _ty: Type, _tr: Triangle) -> DTensor<T, 2> {
78 let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
79 a_copy.assign(self.a);
80
81 todo!()
99 }
100}
101
102impl<T> MatVec<T> for Naive
103where
104 T: ComplexFloat,
105 i8: Into<T::Real>,
106 T::Real: Into<T>,
107{
108 fn matvec<'a, La, Lx>(
109 &self,
110 a: &'a DSlice<T, 2, La>,
111 x: &'a DSlice<T, 1, Lx>,
112 ) -> impl MatVecBuilder<'a, T, La, Lx>
113 where
114 La: Layout,
115 Lx: Layout,
116 {
117 BlasMatVecBuilder {
118 alpha: 1.into().into(),
119 a,
120 x,
121 }
122 }
123}
124
125impl<T: ComplexFloat + 'static + Add<Output = T> + Mul<Output = T> + Zero + Copy> VecOps<T>
126 for Naive
127{
128 fn add_to_scaled<Lx: Layout, Ly: Layout>(
129 &self,
130 _alpha: T,
131 _x: &DSlice<T, 1, Lx>,
132 _y: &mut DSlice<T, 1, Ly>,
133 ) {
134 todo!()
135 }
137
138 fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
139 let mut result = T::zero();
140 for (elem_x, elem_y) in std::iter::zip(x.into_iter(), y.into_iter()) {
141 result = result + *elem_x * (*elem_y);
142 }
143 result
144 }
145
146 fn dotc<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &DSlice<T, 1, Ly>) -> T {
147 todo!()
148 }
150
151 fn norm2<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real {
152 todo!()
153 }
155
156 fn norm1<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real
157 where
158 T: ComplexFloat,
159 {
160 todo!()
161 }
163
164 fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
165 todo!()
166 }
167
168 fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
169 todo!()
170 }
171
172 fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
173 todo!()
174 }
175
176 fn rot<Lx: Layout, Ly: Layout>(
177 &self,
178 _x: &mut DSlice<T, 1, Lx>,
179 _y: &mut DSlice<T, 1, Ly>,
180 _c: T::Real,
181 _s: T,
182 ) where
183 T: ComplexFloat,
184 {
185 todo!()
186 }
187}
188
189impl<
190 T: ComplexFloat<Real = T> + 'static + PartialOrd + Add<Output = T> + Mul<Output = T> + Zero + Copy,
191> Argmax<T> for Naive
192{
193 fn argmax_overwrite<Lx: Layout, S: Shape>(
194 &self,
195 x: &Slice<T, S, Lx>,
196 output: &mut Vec<usize>,
197 ) -> bool {
198 output.clear();
199
200 if x.is_empty() {
201 return false;
202 }
203
204 if x.rank() == 0 {
205 return true;
206 }
207
208 let mut max_flat_idx = 0;
209 let mut max_val = x.iter().next().unwrap();
210
211 for (flat_idx, val) in x.iter().enumerate().skip(1) {
212 if val > max_val {
213 max_val = val;
214 max_flat_idx = flat_idx;
215 }
216 }
217
218 let indices = unravel_index(x, max_flat_idx);
219 output.extend_from_slice(&indices);
220 true
221 }
222
223 fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
224 let mut result = Vec::new();
225 if self.argmax_overwrite(x, &mut result) {
226 Some(result)
227 } else {
228 None
229 }
230 }
231
232 fn argmax_abs_overwrite<Lx: Layout, S: Shape>(
233 &self,
234 x: &Slice<T, S, Lx>,
235 output: &mut Vec<usize>,
236 ) -> bool {
237 output.clear();
238
239 if x.is_empty() {
240 return false;
241 }
242
243 if x.rank() == 0 {
244 return true;
245 }
246
247 let mut max_flat_idx = 0;
248 let mut max_val = x.iter().next().unwrap().abs();
249
250 for (flat_idx, val) in x.iter().enumerate().skip(1) {
251 if val.abs() > max_val {
252 max_val = val.abs();
253 max_flat_idx = flat_idx;
254 }
255 }
256
257 let indices = unravel_index(x, max_flat_idx);
258 output.extend_from_slice(&indices);
259 true
260 }
261
262 fn argmax_abs<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
263 let mut result = Vec::new();
264 if self.argmax_abs_overwrite(x, &mut result) {
265 Some(result)
266 } else {
267 None
268 }
269 }
270}