1use std::num::NonZero;
2
3use faer::Mat;
4use faer::linalg::matmul::matmul;
5use faer_traits::ComplexField;
6
7use faer::{Accum, Par};
8use mdarray::{DSlice, DTensor, DynRank, Layout, Slice, Tensor};
9use num_complex::ComplexFloat;
10
11use num_traits::{One, Zero};
12
13use mdarray_linalg::matmul::{Axes, Side, ContractBuilder, Triangle, Type, _contract};
14use mdarray_linalg::prelude::*;
15use num_cpus;
16
17use crate::{Faer, into_faer, into_faer_mut, into_mdarray};
18
19struct FaerMatMulBuilder<'a, T, La, Lb>
20where
21 La: Layout,
22 Lb: Layout,
23{
24 alpha: T,
25 a: &'a DSlice<T, 2, La>,
26 b: &'a DSlice<T, 2, Lb>,
27 par: Par,
28}
29
30struct FaerContractBuilder<'a, T, La, Lb>
31where
32 La: Layout,
33 Lb: Layout,
34{
35 alpha: T,
36 a: &'a Slice<T, DynRank, La>,
37 b: &'a Slice<T, DynRank, Lb>,
38 axes: Axes,
39}
40
41impl<'a, T, La, Lb> FaerMatMulBuilder<'a, T, La, Lb>
42where
43 La: Layout,
44 Lb: Layout,
45 T: ComplexFloat + ComplexField + One + 'static,
46{
47 #[allow(dead_code)]
48 pub fn parallelize(mut self) -> Self {
49 self.par = Par::Rayon(NonZero::new(num_cpus::get()).unwrap());
51 self
52 }
53}
54
55impl<'a, T, La, Lb> MatMulBuilder<'a, T, La, Lb> for FaerMatMulBuilder<'a, T, La, Lb>
56where
57 La: Layout,
58 Lb: Layout,
59 T: ComplexFloat + ComplexField + One + 'static,
60{
61 fn parallelize(mut self) -> Self {
62 self.par = Par::Rayon(NonZero::new(num_cpus::get()).unwrap());
64 self
65 }
66
67 fn scale(mut self, factor: T) -> Self {
68 self.alpha = self.alpha * factor;
69 self
70 }
71
72 fn eval(self) -> DTensor<T, 2> {
73 let (ma, _) = *self.a.shape();
74 let (_, nb) = *self.b.shape();
75
76 let a_faer = into_faer(self.a);
77 let b_faer = into_faer(self.b);
78
79 let mut c_faer = Mat::<T>::zeros(ma, nb);
80
81 matmul(
82 &mut c_faer,
83 Accum::Replace,
84 a_faer,
85 b_faer,
86 self.alpha,
87 self.par,
88 );
89
90 into_mdarray::<T>(c_faer)
91 }
92
93 fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
94 let mut c_faer = into_faer_mut(c);
95 matmul(
96 &mut c_faer,
97 Accum::Replace,
98 into_faer(self.a),
99 into_faer(self.b),
100 self.alpha,
101 self.par,
102 );
103 }
104
105 fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
106 let mut c_faer = into_faer_mut(c);
107 matmul(
108 &mut c_faer,
109 Accum::Add,
110 into_faer(self.a),
111 into_faer(self.b),
112 self.alpha,
113 self.par,
114 );
115 }
116
117 fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, _beta: T) {
118 let mut c_faer = into_faer_mut(c);
119 matmul(
120 &mut c_faer,
121 Accum::Add,
122 into_faer(self.a),
123 into_faer(self.b),
124 self.alpha,
125 self.par,
126 );
127 todo!(); }
129
130 fn special(self, _lr: Side, _type_of_matrix: Type, _tr: Triangle) -> DTensor<T, 2> {
131 self.eval()
132 }
133}
134
135impl<'a, T, La, Lb> ContractBuilder<'a, T, La, Lb> for FaerContractBuilder<'a, T, La, Lb>
136where
137 La: Layout,
138 Lb: Layout,
139 T: ComplexFloat + Zero + One + ComplexField + 'static,
140{
141 fn scale(mut self, factor: T) -> Self {
142 self.alpha = self.alpha * factor;
143 self
144 }
145
146 fn eval(self) -> Tensor<T, DynRank> {
147 _contract(Faer, self.a, self.b, self.axes, self.alpha)
148 }
149
150 fn overwrite(self, _c: &mut Slice<T>) {
151 todo!()
152 }
153}
154
155impl<T> MatMul<T> for Faer
156where
157 T: ComplexFloat + ComplexField + One + 'static,
158{
159 fn matmul<'a, La, Lb>(
160 &self,
161 a: &'a DSlice<T, 2, La>,
162 b: &'a DSlice<T, 2, Lb>,
163 ) -> impl MatMulBuilder<'a, T, La, Lb>
164 where
165 La: Layout,
166 Lb: Layout,
167 {
168 FaerMatMulBuilder {
169 alpha: T::one(),
170 a,
171 b,
172 par: Par::Seq,
173 }
174 }
175
176 fn contract_all<'a, La, Lb>(
178 &self,
179 a: &'a Slice<T, DynRank, La>,
180 b: &'a Slice<T, DynRank, Lb>,
181 ) -> impl ContractBuilder<'a, T, La, Lb>
182 where
183 T: 'a,
184 La: Layout,
185 Lb: Layout,
186 {
187 FaerContractBuilder {
188 alpha: T::one(),
189 a,
190 b,
191 axes: Axes::All,
192 }
193 }
194
195 fn contract_n<'a, La, Lb>(
199 &self,
200 a: &'a Slice<T, DynRank, La>,
201 b: &'a Slice<T, DynRank, Lb>,
202 n: usize,
203 ) -> impl ContractBuilder<'a, T, La, Lb>
204 where
205 T: 'a,
206 La: Layout,
207 Lb: Layout,
208 {
209 FaerContractBuilder {
210 alpha: T::one(),
211 a,
212 b,
213 axes: Axes::LastFirst { k: (n) },
214 }
215 }
216
217 fn contract<'a, La, Lb>(
222 &self,
223 a: &'a Slice<T, DynRank, La>,
224 b: &'a Slice<T, DynRank, Lb>,
225 axes_a: impl Into<Box<[usize]>>,
226 axes_b: impl Into<Box<[usize]>>,
227 ) -> impl ContractBuilder<'a, T, La, Lb>
228 where
229 T: 'a,
230 La: Layout,
231 Lb: Layout,
232 {
233 FaerContractBuilder {
234 alpha: T::one(),
235 a,
236 b,
237 axes: Axes::Specific(axes_a.into(), axes_b.into()),
238 }
239 }
240}