// mdarray_linalg/naive/matvec/context.rs

1use num_traits::Zero;
2use std::ops::{Add, Mul};
3
4use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
5use num_complex::ComplexFloat;
6
7use crate::matmul::{Triangle, Type};
8use crate::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
9use crate::utils::unravel_index;
10
11use crate::Naive;
12
13struct BlasMatVecBuilder<'a, T, La, Lx>
14where
15    La: Layout,
16    Lx: Layout,
17{
18    alpha: T,
19    a: &'a DSlice<T, 2, La>,
20    x: &'a DSlice<T, 1, Lx>,
21}
22
23impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
24where
25    La: Layout,
26    Lx: Layout,
27    T: ComplexFloat,
28    i8: Into<T::Real>,
29    T::Real: Into<T>,
30{
31    fn parallelize(self) -> Self {
32        self
33    }
34
35    fn scale(mut self, alpha: T) -> Self {
36        self.alpha = alpha * self.alpha;
37        self
38    }
39
40    fn eval(self) -> DTensor<T, 1> {
41        let mut _y = DTensor::<T, 1>::from_elem(self.x.len(), 0.into().into());
42        // gemv(self.alpha, self.a, self.x, 0.into().into(), &mut y);
43        // y
44        todo!()
45    }
46
47    fn overwrite<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>) {
48        // gemv(self.alpha, self.a, self.x, 0.into().into(), y);
49        todo!()
50    }
51
52    fn add_to<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>) {
53        // gemv(self.alpha, self.a, self.x, 1.into().into(), y);
54        todo!()
55    }
56
57    fn add_to_scaled<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>, _beta: T) {
58        // gemv(self.alpha, self.a, self.x, beta, y);
59        todo!()
60    }
61
62    fn add_outer<Ly: Layout>(self, y: &DSlice<T, 1, Ly>, beta: T) -> DTensor<T, 2> {
63        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
64        a_copy.assign(self.a);
65
66        let (m, n) = *a_copy.shape();
67
68        for i in 0..m {
69            for j in 0..n {
70                a_copy[[i, j]] = self.alpha * a_copy[[i, j]] + beta * self.x[[i]] * y[[j]];
71            }
72        }
73
74        a_copy
75    }
76
77    fn add_outer_special(self, _beta: T, _ty: Type, _tr: Triangle) -> DTensor<T, 2> {
78        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
79        a_copy.assign(self.a);
80
81        // if self.alpha != 1.into().into() {
82        //     a_copy = a_copy.map(|x| x * self.alpha);
83        // }
84
85        // let cblas_uplo = match tr {
86        //     Triangle::Lower => CBLAS_UPLO::CblasLower,
87        //     Triangle::Upper => CBLAS_UPLO::CblasUpper,
88        // };
89
90        // match ty {
91        //     Type::Her => her(cblas_uplo, beta.re(), self.x, &mut a_copy),
92        //     Type::Sym => syr(cblas_uplo, beta, self.x, &mut a_copy),
93        //     Type::Tri => {
94        //         ger(beta, self.x, self.x, &mut a_copy);
95        //     }
96        // }
97        // a_copy
98        todo!()
99    }
100}
101
102impl<T> MatVec<T> for Naive
103where
104    T: ComplexFloat,
105    i8: Into<T::Real>,
106    T::Real: Into<T>,
107{
108    fn matvec<'a, La, Lx>(
109        &self,
110        a: &'a DSlice<T, 2, La>,
111        x: &'a DSlice<T, 1, Lx>,
112    ) -> impl MatVecBuilder<'a, T, La, Lx>
113    where
114        La: Layout,
115        Lx: Layout,
116    {
117        BlasMatVecBuilder {
118            alpha: 1.into().into(),
119            a,
120            x,
121        }
122    }
123}
124
125impl<T: ComplexFloat + 'static + Add<Output = T> + Mul<Output = T> + Zero + Copy> VecOps<T>
126    for Naive
127{
128    fn add_to_scaled<Lx: Layout, Ly: Layout>(
129        &self,
130        _alpha: T,
131        _x: &DSlice<T, 1, Lx>,
132        _y: &mut DSlice<T, 1, Ly>,
133    ) {
134        todo!()
135        // axpy(alpha, x, y);
136    }
137
138    fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
139        let mut result = T::zero();
140        for (elem_x, elem_y) in std::iter::zip(x.into_iter(), y.into_iter()) {
141            result = result + *elem_x * (*elem_y);
142        }
143        result
144    }
145
146    fn dotc<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &DSlice<T, 1, Ly>) -> T {
147        todo!()
148        // dotc(x, y)
149    }
150
151    fn norm2<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real {
152        todo!()
153        // nrm2(x)
154    }
155
156    fn norm1<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real
157    where
158        T: ComplexFloat,
159    {
160        todo!()
161        // asum(x)
162    }
163
164    fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
165        todo!()
166    }
167
168    fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
169        todo!()
170    }
171
172    fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
173        todo!()
174    }
175
176    fn rot<Lx: Layout, Ly: Layout>(
177        &self,
178        _x: &mut DSlice<T, 1, Lx>,
179        _y: &mut DSlice<T, 1, Ly>,
180        _c: T::Real,
181        _s: T,
182    ) where
183        T: ComplexFloat,
184    {
185        todo!()
186    }
187}
188
189impl<
190    T: ComplexFloat<Real = T> + 'static + PartialOrd + Add<Output = T> + Mul<Output = T> + Zero + Copy,
191> Argmax<T> for Naive
192{
193    fn argmax_overwrite<Lx: Layout, S: Shape>(
194        &self,
195        x: &Slice<T, S, Lx>,
196        output: &mut Vec<usize>,
197    ) -> bool {
198        output.clear();
199
200        if x.is_empty() {
201            return false;
202        }
203
204        if x.rank() == 0 {
205            return true;
206        }
207
208        let mut max_flat_idx = 0;
209        let mut max_val = x.iter().next().unwrap();
210
211        for (flat_idx, val) in x.iter().enumerate().skip(1) {
212            if val > max_val {
213                max_val = val;
214                max_flat_idx = flat_idx;
215            }
216        }
217
218        let indices = unravel_index(x, max_flat_idx);
219        output.extend_from_slice(&indices);
220        true
221    }
222
223    fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
224        let mut result = Vec::new();
225        if self.argmax_overwrite(x, &mut result) {
226            Some(result)
227        } else {
228            None
229        }
230    }
231
232    fn argmax_abs_overwrite<Lx: Layout, S: Shape>(
233        &self,
234        x: &Slice<T, S, Lx>,
235        output: &mut Vec<usize>,
236    ) -> bool {
237        output.clear();
238
239        if x.is_empty() {
240            return false;
241        }
242
243        if x.rank() == 0 {
244            return true;
245        }
246
247        let mut max_flat_idx = 0;
248        let mut max_val = x.iter().next().unwrap().abs();
249
250        for (flat_idx, val) in x.iter().enumerate().skip(1) {
251            if val.abs() > max_val {
252                max_val = val.abs();
253                max_flat_idx = flat_idx;
254            }
255        }
256
257        let indices = unravel_index(x, max_flat_idx);
258        output.extend_from_slice(&indices);
259        true
260    }
261
262    fn argmax_abs<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
263        let mut result = Vec::new();
264        if self.argmax_abs_overwrite(x, &mut result) {
265            Some(result)
266        } else {
267            None
268        }
269    }
270}