mdarray_linalg/naive/matvec/
context.rs1use num_traits::Zero;
2use std::ops::{Add, Mul};
3
4use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
5use num_complex::ComplexFloat;
6
7use crate::matmul::{Triangle, Type};
8use crate::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
9
10use crate::Naive;
11
12struct BlasMatVecBuilder<'a, T, La, Lx>
13where
14 La: Layout,
15 Lx: Layout,
16{
17 alpha: T,
18 a: &'a DSlice<T, 2, La>,
19 x: &'a DSlice<T, 1, Lx>,
20}
21
22impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
23where
24 La: Layout,
25 Lx: Layout,
26 T: ComplexFloat,
27 i8: Into<T::Real>,
28 T::Real: Into<T>,
29{
30 fn parallelize(self) -> Self {
31 self
32 }
33
34 fn scale(mut self, alpha: T) -> Self {
35 self.alpha = alpha * self.alpha;
36 self
37 }
38
39 fn eval(self) -> DTensor<T, 1> {
40 let mut _y = DTensor::<T, 1>::from_elem(self.x.len(), 0.into().into());
41 todo!()
44 }
45
46 fn overwrite<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>) {
47 todo!()
49 }
50
51 fn add_to<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>) {
52 todo!()
54 }
55
56 fn add_to_scaled<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>, _beta: T) {
57 todo!()
59 }
60
61 fn add_outer<Ly: Layout>(self, _y: &DSlice<T, 1, Ly>, _beta: T) -> DTensor<T, 2> {
62 let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
63 a_copy.assign(self.a);
64
65 todo!()
77 }
78
79 fn add_outer_special(self, _beta: T, _ty: Type, _tr: Triangle) -> DTensor<T, 2> {
80 let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
81 a_copy.assign(self.a);
82
83 todo!()
101 }
102}
103
104impl<T> MatVec<T> for Naive
105where
106 T: ComplexFloat,
107 i8: Into<T::Real>,
108 T::Real: Into<T>,
109{
110 fn matvec<'a, La, Lx>(
111 &self,
112 a: &'a DSlice<T, 2, La>,
113 x: &'a DSlice<T, 1, Lx>,
114 ) -> impl MatVecBuilder<'a, T, La, Lx>
115 where
116 La: Layout,
117 Lx: Layout,
118 {
119 BlasMatVecBuilder {
120 alpha: 1.into().into(),
121 a,
122 x,
123 }
124 }
125}
126
127impl<T: ComplexFloat + 'static + Add<Output = T> + Mul<Output = T> + Zero + Copy> VecOps<T>
128 for Naive
129{
130 fn add_to_scaled<Lx: Layout, Ly: Layout>(
131 &self,
132 _alpha: T,
133 _x: &DSlice<T, 1, Lx>,
134 _y: &mut DSlice<T, 1, Ly>,
135 ) {
136 todo!()
137 }
139
140 fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
141 let mut result = T::zero();
142 for (elem_x, elem_y) in std::iter::zip(x.into_iter(), y.into_iter()) {
143 result = result + *elem_x * (*elem_y);
144 }
145 result
146 }
147
148 fn dotc<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &DSlice<T, 1, Ly>) -> T {
149 todo!()
150 }
152
153 fn norm2<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real {
154 todo!()
155 }
157
158 fn norm1<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real
159 where
160 T: ComplexFloat,
161 {
162 todo!()
163 }
165
166 fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
167 todo!()
168 }
169
170 fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
171 todo!()
172 }
173
174 fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
175 todo!()
176 }
177
178 fn rot<Lx: Layout, Ly: Layout>(
179 &self,
180 _x: &mut DSlice<T, 1, Lx>,
181 _y: &mut DSlice<T, 1, Ly>,
182 _c: T::Real,
183 _s: T,
184 ) where
185 T: ComplexFloat,
186 {
187 todo!()
188 }
189}
190
191impl<T: ComplexFloat + 'static + PartialOrd + Add<Output = T> + Mul<Output = T> + Zero + Copy>
192 Argmax<T> for Naive
193{
194 fn argmax_overwrite<Lx: Layout, S: Shape>(
195 &self,
196 x: &Slice<T, S, Lx>,
197 output: &mut Vec<usize>,
198 ) -> bool {
199 output.clear();
200
201 if x.is_empty() {
202 return false;
203 }
204
205 if x.rank() == 0 {
206 return true;
207 }
208
209 let mut max_flat_idx = 0;
210 let mut max_val = x.iter().next().unwrap();
211
212 for (flat_idx, val) in x.iter().enumerate().skip(1) {
213 if val > max_val {
214 max_val = val;
215 max_flat_idx = flat_idx;
216 }
217 }
218
219 let indices = unravel_index(x, max_flat_idx);
220 output.extend_from_slice(&indices);
221 true
222 }
223
224 fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
225 let mut result = Vec::new();
226 if self.argmax_overwrite(x, &mut result) {
227 Some(result)
228 } else {
229 None
230 }
231 }
232}
233
234pub fn unravel_index<T, S: Shape, L: Layout>(x: &Slice<T, S, L>, mut flat: usize) -> Vec<usize> {
235 let rank = x.rank();
236
237 assert!(
238 flat < x.len(),
239 "flat index out of bounds: {} >= {}",
240 flat,
241 x.len()
242 );
243
244 let mut coords = vec![0usize; rank];
245
246 for i in (0..rank).rev() {
247 let dim = x.shape().dim(i);
248 coords[i] = flat % dim;
249 flat /= dim;
250 }
251
252 coords
253}