// mdarray_linalg_faer/matmul/context.rs

1use std::num::NonZero;
2
3use faer::Mat;
4use faer::linalg::matmul::matmul;
5use faer_traits::ComplexField;
6
7use faer::{Accum, Par};
8use mdarray::{DSlice, DTensor, DynRank, Layout, Slice, Tensor};
9use num_complex::ComplexFloat;
10
11use num_traits::{One, Zero};
12
13use mdarray_linalg::matmul::{Axes, Side, ContractBuilder, Triangle, Type, _contract};
14use mdarray_linalg::prelude::*;
15use num_cpus;
16
17use crate::{Faer, into_faer, into_faer_mut, into_mdarray};
18
19struct FaerMatMulBuilder<'a, T, La, Lb>
20where
21    La: Layout,
22    Lb: Layout,
23{
24    alpha: T,
25    a: &'a DSlice<T, 2, La>,
26    b: &'a DSlice<T, 2, Lb>,
27    par: Par,
28}
29
30struct FaerContractBuilder<'a, T, La, Lb>
31where
32    La: Layout,
33    Lb: Layout,
34{
35    alpha: T,
36    a: &'a Slice<T, DynRank, La>,
37    b: &'a Slice<T, DynRank, Lb>,
38    axes: Axes,
39}
40
41impl<'a, T, La, Lb> FaerMatMulBuilder<'a, T, La, Lb>
42where
43    La: Layout,
44    Lb: Layout,
45    T: ComplexFloat + ComplexField + One + 'static,
46{
47    #[allow(dead_code)]
48    pub fn parallelize(mut self) -> Self {
49        // Alternative ??? : use faer::get_global_parallelism()
50        self.par = Par::Rayon(NonZero::new(num_cpus::get()).unwrap());
51        self
52    }
53}
54
55impl<'a, T, La, Lb> MatMulBuilder<'a, T, La, Lb> for FaerMatMulBuilder<'a, T, La, Lb>
56where
57    La: Layout,
58    Lb: Layout,
59    T: ComplexFloat + ComplexField + One + 'static,
60{
61    fn parallelize(mut self) -> Self {
62        // Alternative ?????
63        self.par = Par::Rayon(NonZero::new(num_cpus::get()).unwrap());
64        self
65    }
66
67    fn scale(mut self, factor: T) -> Self {
68        self.alpha = self.alpha * factor;
69        self
70    }
71
72    fn eval(self) -> DTensor<T, 2> {
73        let (ma, _) = *self.a.shape();
74        let (_, nb) = *self.b.shape();
75
76        let a_faer = into_faer(self.a);
77        let b_faer = into_faer(self.b);
78
79        let mut c_faer = Mat::<T>::zeros(ma, nb);
80
81        matmul(
82            &mut c_faer,
83            Accum::Replace,
84            a_faer,
85            b_faer,
86            self.alpha,
87            self.par,
88        );
89
90        into_mdarray::<T>(c_faer)
91    }
92
93    fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
94        let mut c_faer = into_faer_mut(c);
95        matmul(
96            &mut c_faer,
97            Accum::Replace,
98            into_faer(self.a),
99            into_faer(self.b),
100            self.alpha,
101            self.par,
102        );
103    }
104
105    fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
106        let mut c_faer = into_faer_mut(c);
107        matmul(
108            &mut c_faer,
109            Accum::Add,
110            into_faer(self.a),
111            into_faer(self.b),
112            self.alpha,
113            self.par,
114        );
115    }
116
117    fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, _beta: T) {
118        let mut c_faer = into_faer_mut(c);
119        matmul(
120            &mut c_faer,
121            Accum::Add,
122            into_faer(self.a),
123            into_faer(self.b),
124            self.alpha,
125            self.par,
126        );
127        todo!(); // multiplication by beta not implemented in faer ?
128    }
129
130    fn special(self, _lr: Side, _type_of_matrix: Type, _tr: Triangle) -> DTensor<T, 2> {
131        self.eval()
132    }
133}
134
135impl<'a, T, La, Lb> ContractBuilder<'a, T, La, Lb> for FaerContractBuilder<'a, T, La, Lb>
136where
137    La: Layout,
138    Lb: Layout,
139    T: ComplexFloat + Zero + One + ComplexField + 'static,
140{
141    fn scale(mut self, factor: T) -> Self {
142        self.alpha = self.alpha * factor;
143        self
144    }
145
146    fn eval(self) -> Tensor<T, DynRank> {
147        _contract(Faer, self.a, self.b, self.axes, self.alpha)
148    }
149
150    fn overwrite(self, _c: &mut Slice<T>) {
151        todo!()
152    }
153}
154
155impl<T> MatMul<T> for Faer
156where
157    T: ComplexFloat + ComplexField + One + 'static,
158{
159    fn matmul<'a, La, Lb>(
160        &self,
161        a: &'a DSlice<T, 2, La>,
162        b: &'a DSlice<T, 2, Lb>,
163    ) -> impl MatMulBuilder<'a, T, La, Lb>
164    where
165        La: Layout,
166        Lb: Layout,
167    {
168        FaerMatMulBuilder {
169            alpha: T::one(),
170            a,
171            b,
172            par: Par::Seq,
173        }
174    }
175
176    /// Contracts all axes of the first tensor with all axes of the second tensor.
177    fn contract_all<'a, La, Lb>(
178        &self,
179        a: &'a Slice<T, DynRank, La>,
180        b: &'a Slice<T, DynRank, Lb>,
181    ) -> impl ContractBuilder<'a, T, La, Lb>
182    where
183        T: 'a,
184        La: Layout,
185        Lb: Layout,
186    {
187        FaerContractBuilder {
188            alpha: T::one(),
189            a,
190            b,
191            axes: Axes::All,
192        }
193    }
194
195    /// Contracts the last `n` axes of the first tensor with the first `n` axes of the second tensor.
196    /// # Example
197    /// For two matrices (2D tensors), `contract_n(1)` performs standard matrix multiplication.
198    fn contract_n<'a, La, Lb>(
199        &self,
200        a: &'a Slice<T, DynRank, La>,
201        b: &'a Slice<T, DynRank, Lb>,
202        n: usize,
203    ) -> impl ContractBuilder<'a, T, La, Lb>
204    where
205        T: 'a,
206        La: Layout,
207        Lb: Layout,
208    {
209        FaerContractBuilder {
210            alpha: T::one(),
211            a,
212            b,
213            axes: Axes::LastFirst { k: (n) },
214        }
215    }
216
217    /// Specifies exactly which axes to contract_all.
218    /// # Example
219    /// `specific([1, 2], [3, 4])` contracts axis 1 and 2 of `a`
220    /// with axes 3 and 4 of `b`.
221    fn contract<'a, La, Lb>(
222        &self,
223        a: &'a Slice<T, DynRank, La>,
224        b: &'a Slice<T, DynRank, Lb>,
225        axes_a: impl Into<Box<[usize]>>,
226        axes_b: impl Into<Box<[usize]>>,
227    ) -> impl ContractBuilder<'a, T, La, Lb>
228    where
229        T: 'a,
230        La: Layout,
231        Lb: Layout,
232    {
233        FaerContractBuilder {
234            alpha: T::one(),
235            a,
236            b,
237            axes: Axes::Specific(axes_a.into(), axes_b.into()),
238        }
239    }
240}