mdarray_linalg/naive/matvec/
context.rs

1use num_traits::Zero;
2use std::ops::{Add, Mul};
3
4use mdarray::{DSlice, DTensor, Layout, Shape, Slice};
5use num_complex::ComplexFloat;
6
7use crate::matmul::{Triangle, Type};
8use crate::matvec::{Argmax, MatVec, MatVecBuilder, VecOps};
9
10use crate::Naive;
11
12struct BlasMatVecBuilder<'a, T, La, Lx>
13where
14    La: Layout,
15    Lx: Layout,
16{
17    alpha: T,
18    a: &'a DSlice<T, 2, La>,
19    x: &'a DSlice<T, 1, Lx>,
20}
21
22impl<'a, T, La, Lx> MatVecBuilder<'a, T, La, Lx> for BlasMatVecBuilder<'a, T, La, Lx>
23where
24    La: Layout,
25    Lx: Layout,
26    T: ComplexFloat,
27    i8: Into<T::Real>,
28    T::Real: Into<T>,
29{
30    fn parallelize(self) -> Self {
31        self
32    }
33
34    fn scale(mut self, alpha: T) -> Self {
35        self.alpha = alpha * self.alpha;
36        self
37    }
38
39    fn eval(self) -> DTensor<T, 1> {
40        let mut _y = DTensor::<T, 1>::from_elem(self.x.len(), 0.into().into());
41        // gemv(self.alpha, self.a, self.x, 0.into().into(), &mut y);
42        // y
43        todo!()
44    }
45
46    fn overwrite<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>) {
47        // gemv(self.alpha, self.a, self.x, 0.into().into(), y);
48        todo!()
49    }
50
51    fn add_to<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>) {
52        // gemv(self.alpha, self.a, self.x, 1.into().into(), y);
53        todo!()
54    }
55
56    fn add_to_scaled<Ly: Layout>(self, _y: &mut DSlice<T, 1, Ly>, _beta: T) {
57        // gemv(self.alpha, self.a, self.x, beta, y);
58        todo!()
59    }
60
61    fn add_outer<Ly: Layout>(self, _y: &DSlice<T, 1, Ly>, _beta: T) -> DTensor<T, 2> {
62        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
63        a_copy.assign(self.a);
64
65        // Apply scale factor to preserve builder pattern logic: the alpha parameter
66        // may have been modified before this call, so we must scale the matrix
67        // before applying the rank-1 update. Unlike gemm operations, this requires
68        // a separate pass since BLAS lacks a direct matrix-scalar multiplication.
69
70        // if self.alpha != 1.into().into() {
71        //     a_copy = a_copy.map(|x| x * self.alpha);
72        // }
73
74        // ger(beta, self.x, y, &mut a_copy);
75        // a_copy
76        todo!()
77    }
78
79    fn add_outer_special(self, _beta: T, _ty: Type, _tr: Triangle) -> DTensor<T, 2> {
80        let mut a_copy = DTensor::<T, 2>::from_elem(*self.a.shape(), 0.into().into());
81        a_copy.assign(self.a);
82
83        // if self.alpha != 1.into().into() {
84        //     a_copy = a_copy.map(|x| x * self.alpha);
85        // }
86
87        // let cblas_uplo = match tr {
88        //     Triangle::Lower => CBLAS_UPLO::CblasLower,
89        //     Triangle::Upper => CBLAS_UPLO::CblasUpper,
90        // };
91
92        // match ty {
93        //     Type::Her => her(cblas_uplo, beta.re(), self.x, &mut a_copy),
94        //     Type::Sym => syr(cblas_uplo, beta, self.x, &mut a_copy),
95        //     Type::Tri => {
96        //         ger(beta, self.x, self.x, &mut a_copy);
97        //     }
98        // }
99        // a_copy
100        todo!()
101    }
102}
103
104impl<T> MatVec<T> for Naive
105where
106    T: ComplexFloat,
107    i8: Into<T::Real>,
108    T::Real: Into<T>,
109{
110    fn matvec<'a, La, Lx>(
111        &self,
112        a: &'a DSlice<T, 2, La>,
113        x: &'a DSlice<T, 1, Lx>,
114    ) -> impl MatVecBuilder<'a, T, La, Lx>
115    where
116        La: Layout,
117        Lx: Layout,
118    {
119        BlasMatVecBuilder {
120            alpha: 1.into().into(),
121            a,
122            x,
123        }
124    }
125}
126
127impl<T: ComplexFloat + 'static + Add<Output = T> + Mul<Output = T> + Zero + Copy> VecOps<T>
128    for Naive
129{
130    fn add_to_scaled<Lx: Layout, Ly: Layout>(
131        &self,
132        _alpha: T,
133        _x: &DSlice<T, 1, Lx>,
134        _y: &mut DSlice<T, 1, Ly>,
135    ) {
136        todo!()
137        // axpy(alpha, x, y);
138    }
139
140    fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T {
141        let mut result = T::zero();
142        for (elem_x, elem_y) in std::iter::zip(x.into_iter(), y.into_iter()) {
143            result = result + *elem_x * (*elem_y);
144        }
145        result
146    }
147
148    fn dotc<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &DSlice<T, 1, Ly>) -> T {
149        todo!()
150        // dotc(x, y)
151    }
152
153    fn norm2<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real {
154        todo!()
155        // nrm2(x)
156    }
157
158    fn norm1<Lx: Layout>(&self, _x: &DSlice<T, 1, Lx>) -> T::Real
159    where
160        T: ComplexFloat,
161    {
162        todo!()
163        // asum(x)
164    }
165
166    fn copy<Lx: Layout, Ly: Layout>(&self, _x: &DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
167        todo!()
168    }
169
170    fn scal<Lx: Layout>(&self, _alpha: T, _x: &mut DSlice<T, 1, Lx>) {
171        todo!()
172    }
173
174    fn swap<Lx: Layout, Ly: Layout>(&self, _x: &mut DSlice<T, 1, Lx>, _y: &mut DSlice<T, 1, Ly>) {
175        todo!()
176    }
177
178    fn rot<Lx: Layout, Ly: Layout>(
179        &self,
180        _x: &mut DSlice<T, 1, Lx>,
181        _y: &mut DSlice<T, 1, Ly>,
182        _c: T::Real,
183        _s: T,
184    ) where
185        T: ComplexFloat,
186    {
187        todo!()
188    }
189}
190
191impl<T: ComplexFloat + 'static + PartialOrd + Add<Output = T> + Mul<Output = T> + Zero + Copy>
192    Argmax<T> for Naive
193{
194    fn argmax_overwrite<Lx: Layout, S: Shape>(
195        &self,
196        x: &Slice<T, S, Lx>,
197        output: &mut Vec<usize>,
198    ) -> bool {
199        output.clear();
200
201        if x.is_empty() {
202            return false;
203        }
204
205        if x.rank() == 0 {
206            return true;
207        }
208
209        let mut max_flat_idx = 0;
210        let mut max_val = x.iter().next().unwrap();
211
212        for (flat_idx, val) in x.iter().enumerate().skip(1) {
213            if val > max_val {
214                max_val = val;
215                max_flat_idx = flat_idx;
216            }
217        }
218
219        let indices = unravel_index(x, max_flat_idx);
220        output.extend_from_slice(&indices);
221        true
222    }
223
224    fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>> {
225        let mut result = Vec::new();
226        if self.argmax_overwrite(x, &mut result) {
227            Some(result)
228        } else {
229            None
230        }
231    }
232}
233
234pub fn unravel_index<T, S: Shape, L: Layout>(x: &Slice<T, S, L>, mut flat: usize) -> Vec<usize> {
235    let rank = x.rank();
236
237    assert!(
238        flat < x.len(),
239        "flat index out of bounds: {} >= {}",
240        flat,
241        x.len()
242    );
243
244    let mut coords = vec![0usize; rank];
245
246    for i in (0..rank).rev() {
247        let dim = x.shape().dim(i);
248        coords[i] = flat % dim;
249        flat /= dim;
250    }
251
252    coords
253}