mdarray_linalg/
matmul.rs

1//! Matrix multiplication and tensor contraction
2//!
3//!```rust
4//!use mdarray::tensor;
5//!use mdarray_linalg::prelude::*;
6//!use mdarray_linalg::Naive;
7//!
8//!let a = tensor![[1., 2.], [3., 4.]].into_dyn(); // requires dynamic tensor
9//!let b = tensor![[5., 6.], [7., 8.]].into_dyn();
10//!
11//!let expected_all = tensor![[70.0]].into_dyn();
//!let result_all = Naive.contract_all(&a, &b).eval();
//!assert_eq!(result_all, expected_all);
//!let result_contract_k = Naive.contract_n(&a, &b, 2).eval();
//!assert_eq!(result_contract_k, expected_all);
15//!
16//!let expected_matmul = tensor![[19., 22.], [43., 50.]].into_dyn();
17//!let result_specific = Naive
18//!    .contract(&a, &b, vec![1], vec![0])
19//!    .eval();
20//!assert_eq!(result_specific, expected_matmul);
21//!```
22use num_complex::ComplexFloat;
23use num_traits::{One, Zero};
24
25use mdarray::{DSlice, DTensor, DynRank, Layout, Slice, Tensor};
26
/// Specifies whether the special (structured) matrix is the left or the
/// right operand of a product.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Side {
    /// The special matrix is the left operand.
    Left,
    /// The special matrix is the right operand.
    Right,
}
32
/// Identifies the structural type of a matrix (symmetric, Hermitian, or
/// triangular).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// Symmetric matrix.
    Sym,
    /// Hermitian matrix.
    Her,
    /// Triangular matrix.
    Tri,
}
39
/// Specifies whether the upper or the lower triangle of a matrix is meant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Triangle {
    /// The upper triangle.
    Upper,
    /// The lower triangle.
    Lower,
}
45
46/// Matrix-matrix multiplication and related operations
47pub trait MatMul<T: One> {
48    fn matmul<'a, La, Lb>(
49        &self,
50        a: &'a DSlice<T, 2, La>,
51        b: &'a DSlice<T, 2, Lb>,
52    ) -> impl MatMulBuilder<'a, T, La, Lb>
53    where
54        T: One,
55        La: Layout,
56        Lb: Layout;
57
58    /// Contracts all axes of the first tensor with all axes of the second tensor.
59    fn contract_all<'a, La, Lb>(
60        &self,
61        a: &'a Slice<T, DynRank, La>,
62        b: &'a Slice<T, DynRank, Lb>,
63    ) -> impl ContractBuilder<'a, T, La, Lb>
64    where
65        T: 'a,
66        La: Layout,
67        Lb: Layout;
68
69    /// Contracts the last `n` axes of the first tensor with the first `n` axes of the second tensor.
70    /// # Example
71    /// For two matrices (2D tensors), `contract_n(1)` performs standard matrix multiplication.
72    fn contract_n<'a, La, Lb>(
73        &self,
74        a: &'a Slice<T, DynRank, La>,
75        b: &'a Slice<T, DynRank, Lb>,
76        n: usize,
77    ) -> impl ContractBuilder<'a, T, La, Lb>
78    where
79        T: 'a,
80        La: Layout,
81        Lb: Layout;
82
83    /// Specifies exactly which axes to contract_all.
84    /// # Example
85    /// `specific([1, 2], [3, 4])` contracts axis 1 and 2 of `a`
86    /// with axes 3 and 4 of `b`.
87    fn contract<'a, La, Lb>(
88        &self,
89        a: &'a Slice<T, DynRank, La>,
90        b: &'a Slice<T, DynRank, Lb>,
91        axes_a: impl Into<Box<[usize]>>,
92        axes_b: impl Into<Box<[usize]>>,
93    ) -> impl ContractBuilder<'a, T, La, Lb>
94    where
95        T: 'a,
96        La: Layout,
97        Lb: Layout;
98}
99
100/// Builder interface for configuring matrix-matrix operations
101pub trait MatMulBuilder<'a, T, La, Lb>
102where
103    La: Layout,
104    Lb: Layout,
105    T: 'a,
106    La: 'a,
107    Lb: 'a,
108{
109    /// Enable parallelization.
110    fn parallelize(self) -> Self;
111
112    /// Multiplies the result by a scalar factor.
113    fn scale(self, factor: T) -> Self;
114
115    /// Returns a new owned tensor containing the result.
116    fn eval(self) -> DTensor<T, 2>;
117
118    /// Overwrites the provided slice with the result.
119    fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>);
120
121    /// Adds the result to the provided slice.
122    fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>);
123
124    /// Adds the result to the provided slice after scaling the slice by `beta`
125    /// (i.e. C := beta * C + result).
126    fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, beta: T);
127
128    /// Computes a matrix product where the first operand is a special
129    /// matrix (symmetric, Hermitian, or triangular) and the other is
130    /// general.
131    ///
132    /// The special matrix is always treated as `A`. `lr` determines the multiplication order:
133    /// - `Side::Left`  : C := alpha * A * B
134    /// - `Side::Right` : C := alpha * B * A
135    ///
136    /// # Parameters
137    /// * `lr` - side of multiplication (left or right)
138    /// * `type_of_matrix` - special matrix type: `Sym`, `Her`, or `Tri`
139    /// * `tr` - triangle containing stored data: `Upper` or `Lower`
140    ///
141    /// Only the specified triangle needs to be stored for symmetric/Hermitian matrices;
142    /// for triangular matrices it specifies which half is used.
143    ///
144    /// # Returns
145    /// A new tensor with the result.
146    fn special(self, lr: Side, type_of_matrix: Type, tr: Triangle) -> DTensor<T, 2>;
147}
148
149/// Builder interface for configuring tensor contraction operations
150pub trait ContractBuilder<'a, T, La, Lb>
151where
152    T: 'a,
153    La: Layout,
154    Lb: Layout,
155{
156    /// Multiplies the result by a scalar factor.
157    fn scale(self, factor: T) -> Self;
158
159    /// Returns a new owned tensor containing the result.
160    fn eval(self) -> Tensor<T, DynRank>;
161
162    /// Overwrites the provided tensor with the result.
163    fn overwrite(self, c: &mut Slice<T>);
164}
165
/// Describes which axes of two tensors are paired up in a contraction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Axes {
    /// Contract every axis of the first tensor with every axis of the second.
    All,
    /// Contract the last `k` axes of the first tensor with the first `k`
    /// axes of the second.
    LastFirst { k: usize },
    /// Contract the listed axes of the first tensor (first field) with the
    /// listed axes of the second tensor (second field), paired by position.
    Specific(Box<[usize]>, Box<[usize]>),
}
171
172/// Helper for implementing contraction through matrix multiplication
173pub fn _contract<T: Zero + ComplexFloat, La: Layout, Lb: Layout>(
174    bd: impl MatMul<T>,
175    a: &Slice<T, DynRank, La>,
176    b: &Slice<T, DynRank, Lb>,
177    axes: Axes,
178    alpha: T,
179) -> Tensor<T, DynRank> {
180    let rank_a = a.rank();
181    let rank_b = b.rank();
182
183    let extract_shape = |s: &DynRank| match s {
184        DynRank::Dyn(arr) => arr.clone(),
185        DynRank::One(n) => Box::new([*n]),
186    };
187    let shape_a = extract_shape(a.shape());
188    let shape_b = extract_shape(b.shape());
189
190    let (axes_a, axes_b) = match axes {
191        Axes::All => ((0..rank_a).collect(), (0..rank_b).collect()),
192        Axes::LastFirst { k } => (((rank_a - k)..rank_a).collect(), (0..k).collect()),
193        Axes::Specific(ax_a, ax_b) => (ax_a, ax_b),
194    };
195
196    assert_eq!(
197        axes_a.len(),
198        axes_b.len(),
199        "Axis count mismatch: {} (tensor A) vs {} (tensor B)",
200        axes_a.len(),
201        axes_b.len()
202    );
203
204    axes_a.iter().zip(&axes_b).for_each(|(a_ax, b_ax)| {
205        assert_eq!(
206            shape_a[*a_ax], shape_b[*b_ax],
207            "Dimension mismatch at contraction: A[axis {}] = {} ≠ B[axis {}] = {}",
208            *a_ax, shape_a[*a_ax], *b_ax, shape_b[*b_ax]
209        );
210    });
211
212    let compute_keep_axes = |rank: usize, axes: &[usize]| -> Vec<usize> {
213        (0..rank).filter(|k| !axes.contains(k)).collect()
214    };
215    let keep_axes_a = compute_keep_axes(rank_a, &axes_a);
216    let keep_axes_b = compute_keep_axes(rank_b, &axes_b);
217    let compute_keep_shape = |axes: &[usize], shape: &[usize]| -> Vec<usize> {
218        axes.iter().map(|&ax| shape[ax]).collect()
219    };
220
221    let mut keep_shape_a = compute_keep_shape(&keep_axes_a, &shape_a);
222    let keep_shape_b = compute_keep_shape(&keep_axes_b, &shape_b);
223
224    let compute_size =
225        |axes: &[usize], shape: &[usize]| -> usize { axes.iter().map(|&k| shape[k]).product() };
226
227    let contract_size_a = compute_size(&axes_a, &shape_a);
228    let contract_size_b = compute_size(&axes_b, &shape_b);
229    let keep_size_a = compute_size(&keep_axes_a, &shape_a);
230    let keep_size_b = compute_size(&keep_axes_b, &shape_b);
231
232    let order_a: Vec<usize> = keep_axes_a.iter().chain(axes_a.iter()).copied().collect();
233    let order_b: Vec<usize> = axes_b.iter().chain(keep_axes_b.iter()).copied().collect();
234
235    let trans_a = a.permute(order_a).to_tensor();
236    let trans_b = b.permute(order_b).to_tensor();
237
238    let a_resh = trans_a.reshape([keep_size_a, contract_size_a]);
239    let b_resh = trans_b.reshape([contract_size_b, keep_size_b]);
240
241    let ab_resh = bd.matmul(&a_resh, &b_resh).scale(alpha).eval();
242
243    if keep_shape_a.is_empty() && keep_shape_b.is_empty() {
244        ab_resh.to_owned().into_dyn()
245    } else if keep_shape_a.is_empty() {
246        ab_resh
247            .view(0, ..)
248            .reshape(keep_shape_a)
249            .to_owned()
250            .into_dyn()
251            .into()
252    } else if keep_shape_b.is_empty() {
253        ab_resh
254            .view(.., 0)
255            .reshape(keep_shape_b)
256            .to_owned()
257            .into_dyn()
258            .into()
259    } else {
260        keep_shape_a.extend(keep_shape_b);
261        ab_resh.reshape(keep_shape_a).to_owned().into_dyn().into()
262    }
263}