mdarray_linalg/naive/matmul/
context.rs

1use num_complex::ComplexFloat;
2use num_traits::{One, Zero};
3
4use mdarray::{DSlice, DTensor, DynRank, Layout, Slice, Tensor, tensor};
5
6use crate::matmul::{Axes, Side, Triangle, Type, _contract};
7use crate::prelude::*;
8
9use crate::Naive;
10
11use super::simple::naive_matmul;
12
13struct NaiveMatMulBuilder<'a, T, La, Lb>
14where
15    La: Layout,
16    Lb: Layout,
17{
18    alpha: T,
19    a: &'a DSlice<T, 2, La>,
20    b: &'a DSlice<T, 2, Lb>,
21}
22
23struct NaiveContractBuilder<'a, T, La, Lb>
24where
25    La: Layout,
26    Lb: Layout,
27{
28    alpha: T,
29    a: &'a Slice<T, DynRank, La>,
30    b: &'a Slice<T, DynRank, Lb>,
31    axes: Axes,
32}
33
34impl<'a, T, La, Lb> MatMulBuilder<'a, T, La, Lb> for NaiveMatMulBuilder<'a, T, La, Lb>
35where
36    La: Layout,
37    Lb: Layout,
38    T: ComplexFloat + Zero + One,
39    // i8: Into<T::Real>,
40    // T::Real: Into<T>,
41{
42    /// Enable parallelization.
43    fn parallelize(self) -> Self {
44        self
45    }
46
47    /// Multiplies the result by a scalar factor.
48    fn scale(mut self, factor: T) -> Self {
49        self.alpha = self.alpha * factor;
50        self
51    }
52
53    /// Returns a new owned tensor containing the result.
54    fn eval(self) -> DTensor<T, 2> {
55        let (m, _) = *self.a.shape();
56        let (_, n) = *self.b.shape();
57        let mut c = tensor![[T::zero(); n]; m];
58        naive_matmul(self.alpha, self.a, self.b, T::zero(), &mut c);
59        c
60    }
61
62    /// Overwrites the provided slice with the result.
63    fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
64        naive_matmul(self.alpha, self.a, self.b, T::zero(), c);
65    }
66
67    /// Adds the result to the provided slice.
68    fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
69        naive_matmul(self.alpha, self.a, self.b, T::one(), c);
70    }
71
72    /// Adds the result to the provided slice after scaling the slice by `beta`
73    /// (i.e. C := beta * C + result).
74    fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, beta: T) {
75        naive_matmul(self.alpha, self.a, self.b, beta, c);
76    }
77
78    /// Computes a matrix product where the first operand is a special
79    /// matrix (symmetric, Hermitian, or triangular) and the other is
80    /// general.
81    ///
82    /// The special matrix is always treated as `A`. `lr` determines the multiplication order:
83    /// - `Side::Left`  : C := alpha * A * B
84    /// - `Side::Right` : C := alpha * B * A
85    ///
86    /// # Parameters
87    /// * `lr` - side of multiplication (left or right)
88    /// * `type_of_matrix` - special matrix type: `Sym`, `Her`, or `Tri`
89    /// * `tr` - triangle containing stored data: `Upper` or `Lower`
90    ///
91    /// Only the specified triangle needs to be stored for symmetric/Hermitian matrices;
92    /// for triangular matrices it specifies which half is used.
93    ///
94    /// # Returns
95    /// A new tensor with the result.
96    fn special(self, _lr: Side, _type_of_matrix: Type, _tr: Triangle) -> DTensor<T, 2> {
97        todo!()
98    }
99}
100
101impl<'a, T, La, Lb> ContractBuilder<'a, T, La, Lb> for NaiveContractBuilder<'a, T, La, Lb>
102where
103    La: Layout,
104    Lb: Layout,
105    T: ComplexFloat + Zero + One,
106{
107    fn scale(mut self, factor: T) -> Self {
108        self.alpha = self.alpha * factor;
109        self
110    }
111
112    fn eval(self) -> Tensor<T> {
113        _contract(Naive, self.a, self.b, self.axes, self.alpha)
114    }
115
116    fn overwrite(self, _c: &mut Slice<T>) {
117        todo!()
118    }
119}
120
121impl<T> MatMul<T> for Naive
122where
123    T: ComplexFloat,
124    // i8: Into<T::Real>,
125    // T::Real: Into<T>,
126{
127    fn matmul<'a, La, Lb>(
128        &self,
129        a: &'a DSlice<T, 2, La>,
130        b: &'a DSlice<T, 2, Lb>,
131    ) -> impl MatMulBuilder<'a, T, La, Lb>
132    where
133        La: Layout,
134        Lb: Layout,
135    {
136        NaiveMatMulBuilder {
137            alpha: T::one(),
138            a,
139            b,
140        }
141    }
142
143    /// Contracts all axes of the first tensor with all axes of the second tensor.
144    fn contract_all<'a, La, Lb>(
145        &self,
146        a: &'a Slice<T, DynRank, La>,
147        b: &'a Slice<T, DynRank, Lb>,
148    ) -> impl ContractBuilder<'a, T, La, Lb>
149    where
150        T: 'a,
151        La: Layout,
152        Lb: Layout,
153    {
154        NaiveContractBuilder {
155            alpha: T::one(),
156            a,
157            b,
158            axes: Axes::All,
159        }
160    }
161
162    /// Contracts the last `n` axes of the first tensor with the first `n` axes of the second tensor.
163    /// # Example
164    /// For two matrices (2D tensors), `contract_n(1)` performs standard matrix multiplication.
165    fn contract_n<'a, La, Lb>(
166        &self,
167        a: &'a Slice<T, DynRank, La>,
168        b: &'a Slice<T, DynRank, Lb>,
169        n: usize,
170    ) -> impl ContractBuilder<'a, T, La, Lb>
171    where
172        T: 'a,
173        La: Layout,
174        Lb: Layout,
175    {
176        NaiveContractBuilder {
177            alpha: T::one(),
178            a,
179            b,
180            axes: Axes::LastFirst { k: (n) },
181        }
182    }
183
184    /// Specifies exactly which axes to contract_all.
185    /// # Example
186    /// `specific([1, 2], [3, 4])` contracts axis 1 and 2 of `a`
187    /// with axes 3 and 4 of `b`.
188    fn contract<'a, La, Lb>(
189        &self,
190        a: &'a Slice<T, DynRank, La>,
191        b: &'a Slice<T, DynRank, Lb>,
192        axes_a: impl Into<Box<[usize]>>,
193        axes_b: impl Into<Box<[usize]>>,
194    ) -> impl ContractBuilder<'a, T, La, Lb>
195    where
196        T: 'a,
197        La: Layout,
198        Lb: Layout,
199    {
200        NaiveContractBuilder {
201            alpha: T::one(),
202            a,
203            b,
204            axes: Axes::Specific(axes_a.into(), axes_b.into()),
205        }
206    }
207}