mdarray_linalg/
matvec.rs

//! Basic vector and matrix-vector operations, including `A·x`, `α·A·x + β·y`, Givens rotations,
//! argmax, and rank-1 updates.
use mdarray::{DSlice, DTensor, Layout, Shape, Slice};

use crate::matmul::{Triangle, Type};

use num_complex::ComplexFloat;
7
8/// Matrix-vector multiplication and transformations
9pub trait MatVec<T> {
10    fn matvec<'a, La, Lx>(
11        &self,
12        a: &'a DSlice<T, 2, La>,
13        x: &'a DSlice<T, 1, Lx>,
14    ) -> impl MatVecBuilder<'a, T, La, Lx>
15    where
16        La: Layout,
17        Lx: Layout;
18}
19
20/// Builder interface for configuring matrix-vector operations
21pub trait MatVecBuilder<'a, T, La, Lx>
22where
23    La: Layout,
24    Lx: Layout,
25    T: 'a,
26    La: 'a,
27    Lx: 'a,
28{
29    fn parallelize(self) -> Self;
30
31    /// `A := α·A`
32    fn scale(self, alpha: T) -> Self;
33
34    /// Returns `α·A·x`
35    fn eval(self) -> DTensor<T, 1>;
36
37    /// `A := α·A·x`
38    fn overwrite<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>);
39
40    /// `A := α·A·x + y`
41    fn add_to<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>);
42
43    /// `A := α·A·x + β·y`
44    fn add_to_scaled<Ly: Layout>(self, y: &mut DSlice<T, 1, Ly>, beta: T);
45
46    /// Rank-1 update: `β·x·yᵀ + α·A`
47    fn add_outer<Ly: Layout>(self, y: &DSlice<T, 1, Ly>, beta: T) -> DTensor<T, 2>;
48
49    /// Rank-1 update: `β·x·xᵀ (or x·x†) + α·A`
50    fn add_outer_special(self, beta: T, ty: Type, tr: Triangle) -> DTensor<T, 2>;
51
52    // Special rank-2 update: beta * (x * y^T + y * x^T) + alpha * A
53    // syr2 her2
54
55    // Special rank-k update: beta * AA^T + alpha * C
56    // syrk herk
57}
58
59/// Vector operations and basic linear algebra utilities
60pub trait VecOps<T: ComplexFloat> {
61    /// Accumulate a scaled vector: `y := α·x + y`
62    fn add_to_scaled<Lx: Layout, Ly: Layout>(
63        &self,
64        alpha: T,
65        x: &DSlice<T, 1, Lx>,
66        y: &mut DSlice<T, 1, Ly>,
67    );
68
69    /// Dot product: `∑xᵢyᵢ`
70    fn dot<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T;
71
72    /// Conjugated dot product: `∑(xᵢ * conj(yᵢ))`
73    fn dotc<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &DSlice<T, 1, Ly>) -> T;
74
75    /// L2 norm: `√(∑|xᵢ|²)`
76    fn norm2<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real;
77
78    /// L1 norm: `∑|xᵢ|`
79    fn norm1<Lx: Layout>(&self, x: &DSlice<T, 1, Lx>) -> T::Real
80    where
81        T: ComplexFloat;
82
83    /// Copy vector: `y := x` (**TODO**)
84    fn copy<Lx: Layout, Ly: Layout>(&self, x: &DSlice<T, 1, Lx>, y: &mut DSlice<T, 1, Ly>);
85
86    /// Scale vector: `x := α·xᵢ` (**TODO**)
87    fn scal<Lx: Layout>(&self, alpha: T, x: &mut DSlice<T, 1, Lx>);
88
89    /// Swap vectors: `x ↔ y` (**TODO**)
90    fn swap<Lx: Layout, Ly: Layout>(&self, x: &mut DSlice<T, 1, Lx>, y: &mut DSlice<T, 1, Ly>);
91
92    /// Givens rotation (**TODO**)
93    fn rot<Lx: Layout, Ly: Layout>(
94        &self,
95        x: &mut DSlice<T, 1, Lx>,
96        y: &mut DSlice<T, 1, Ly>,
97        c: T::Real,
98        s: T,
99    ) where
100        T: ComplexFloat;
101}
102
103/// Argmax for tensors, unlike other traits: it requires `T: PartialOrd` and works on tensor of any rank.
104pub trait Argmax<T: ComplexFloat + std::cmp::PartialOrd> {
105    fn argmax_overwrite<Lx: Layout, S: Shape>(
106        &self,
107        x: &Slice<T, S, Lx>,
108        output: &mut Vec<usize>,
109    ) -> bool;
110
111    fn argmax_abs_overwrite<Lx: Layout, S: Shape>(
112        &self,
113        x: &Slice<T, S, Lx>,
114        output: &mut Vec<usize>,
115    ) -> bool;
116
117    /// Index of max xᵢ (argmaxᵢ xᵢ)
118    fn argmax<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>>;
119
120    /// Index of max |xᵢ| (argmaxᵢ |xᵢ|)
121    fn argmax_abs<Lx: Layout, S: Shape>(&self, x: &Slice<T, S, Lx>) -> Option<Vec<usize>>;
122}