1use num_complex::ComplexFloat;
23use num_traits::{One, Zero};
24
25use mdarray::{DSlice, DTensor, DynRank, Layout, Slice, Tensor};
26
/// Which side of the product the structured ("special") matrix sits on.
///
/// NOTE(review): meaning inferred from the `lr` parameter of
/// `MatMulBuilder::special`; confirm against the implementing backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Side {
    /// The special matrix is the left operand.
    Left,
    /// The special matrix is the right operand.
    Right,
}
32
/// Structure of the special matrix passed to `MatMulBuilder::special`.
///
/// NOTE(review): variant meanings are read from their abbreviated names
/// (symmetric / Hermitian / triangular); confirm against the backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Type {
    /// Symmetric matrix.
    Sym,
    /// Hermitian matrix.
    Her,
    /// Triangular matrix.
    Tri,
}
39
/// Which triangle of the special matrix is referenced.
///
/// NOTE(review): meaning inferred from the `tr` parameter of
/// `MatMulBuilder::special`; confirm against the backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Triangle {
    /// Upper triangle.
    Upper,
    /// Lower triangle.
    Lower,
}
45
46pub trait MatMul<T: One> {
48 fn matmul<'a, La, Lb>(
49 &self,
50 a: &'a DSlice<T, 2, La>,
51 b: &'a DSlice<T, 2, Lb>,
52 ) -> impl MatMulBuilder<'a, T, La, Lb>
53 where
54 T: One,
55 La: Layout,
56 Lb: Layout;
57
58 fn contract_all<'a, La, Lb>(
60 &self,
61 a: &'a Slice<T, DynRank, La>,
62 b: &'a Slice<T, DynRank, Lb>,
63 ) -> impl ContractBuilder<'a, T, La, Lb>
64 where
65 T: 'a,
66 La: Layout,
67 Lb: Layout;
68
69 fn contract_n<'a, La, Lb>(
73 &self,
74 a: &'a Slice<T, DynRank, La>,
75 b: &'a Slice<T, DynRank, Lb>,
76 n: usize,
77 ) -> impl ContractBuilder<'a, T, La, Lb>
78 where
79 T: 'a,
80 La: Layout,
81 Lb: Layout;
82
83 fn contract<'a, La, Lb>(
88 &self,
89 a: &'a Slice<T, DynRank, La>,
90 b: &'a Slice<T, DynRank, Lb>,
91 axes_a: impl Into<Box<[usize]>>,
92 axes_b: impl Into<Box<[usize]>>,
93 ) -> impl ContractBuilder<'a, T, La, Lb>
94 where
95 T: 'a,
96 La: Layout,
97 Lb: Layout;
98}
99
100pub trait MatMulBuilder<'a, T, La, Lb>
102where
103 La: Layout,
104 Lb: Layout,
105 T: 'a,
106 La: 'a,
107 Lb: 'a,
108{
109 fn parallelize(self) -> Self;
111
112 fn scale(self, factor: T) -> Self;
114
115 fn eval(self) -> DTensor<T, 2>;
117
118 fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>);
120
121 fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>);
123
124 fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, beta: T);
127
128 fn special(self, lr: Side, type_of_matrix: Type, tr: Triangle) -> DTensor<T, 2>;
147}
148
149pub trait ContractBuilder<'a, T, La, Lb>
151where
152 T: 'a,
153 La: Layout,
154 Lb: Layout,
155{
156 fn scale(self, factor: T) -> Self;
158
159 fn eval(self) -> Tensor<T, DynRank>;
161
162 fn overwrite(self, c: &mut Slice<T>);
164}
165
/// Selection of the axes paired up by a tensor contraction (see `_contract`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Axes {
    /// Axis `i` of the first tensor is paired with axis `i` of the second, for every axis.
    All,
    /// The last `k` axes of the first tensor are paired with the first `k` axes of the second.
    LastFirst { k: usize },
    /// `.0[i]` of the first tensor is paired with `.1[i]` of the second.
    Specific(Box<[usize]>, Box<[usize]>),
}
171
172pub fn _contract<T: Zero + ComplexFloat, La: Layout, Lb: Layout>(
174 bd: impl MatMul<T>,
175 a: &Slice<T, DynRank, La>,
176 b: &Slice<T, DynRank, Lb>,
177 axes: Axes,
178 alpha: T,
179) -> Tensor<T, DynRank> {
180 let rank_a = a.rank();
181 let rank_b = b.rank();
182
183 let extract_shape = |s: &DynRank| match s {
184 DynRank::Dyn(arr) => arr.clone(),
185 DynRank::One(n) => Box::new([*n]),
186 };
187 let shape_a = extract_shape(a.shape());
188 let shape_b = extract_shape(b.shape());
189
190 let (axes_a, axes_b) = match axes {
191 Axes::All => ((0..rank_a).collect(), (0..rank_b).collect()),
192 Axes::LastFirst { k } => (((rank_a - k)..rank_a).collect(), (0..k).collect()),
193 Axes::Specific(ax_a, ax_b) => (ax_a, ax_b),
194 };
195
196 assert_eq!(
197 axes_a.len(),
198 axes_b.len(),
199 "Axis count mismatch: {} (tensor A) vs {} (tensor B)",
200 axes_a.len(),
201 axes_b.len()
202 );
203
204 axes_a.iter().zip(&axes_b).for_each(|(a_ax, b_ax)| {
205 assert_eq!(
206 shape_a[*a_ax], shape_b[*b_ax],
207 "Dimension mismatch at contraction: A[axis {}] = {} ≠ B[axis {}] = {}",
208 *a_ax, shape_a[*a_ax], *b_ax, shape_b[*b_ax]
209 );
210 });
211
212 let compute_keep_axes = |rank: usize, axes: &[usize]| -> Vec<usize> {
213 (0..rank).filter(|k| !axes.contains(k)).collect()
214 };
215 let keep_axes_a = compute_keep_axes(rank_a, &axes_a);
216 let keep_axes_b = compute_keep_axes(rank_b, &axes_b);
217 let compute_keep_shape = |axes: &[usize], shape: &[usize]| -> Vec<usize> {
218 axes.iter().map(|&ax| shape[ax]).collect()
219 };
220
221 let mut keep_shape_a = compute_keep_shape(&keep_axes_a, &shape_a);
222 let keep_shape_b = compute_keep_shape(&keep_axes_b, &shape_b);
223
224 let compute_size =
225 |axes: &[usize], shape: &[usize]| -> usize { axes.iter().map(|&k| shape[k]).product() };
226
227 let contract_size_a = compute_size(&axes_a, &shape_a);
228 let contract_size_b = compute_size(&axes_b, &shape_b);
229 let keep_size_a = compute_size(&keep_axes_a, &shape_a);
230 let keep_size_b = compute_size(&keep_axes_b, &shape_b);
231
232 let order_a: Vec<usize> = keep_axes_a.iter().chain(axes_a.iter()).copied().collect();
233 let order_b: Vec<usize> = axes_b.iter().chain(keep_axes_b.iter()).copied().collect();
234
235 let trans_a = a.permute(order_a).to_tensor();
236 let trans_b = b.permute(order_b).to_tensor();
237
238 let a_resh = trans_a.reshape([keep_size_a, contract_size_a]);
239 let b_resh = trans_b.reshape([contract_size_b, keep_size_b]);
240
241 let ab_resh = bd.matmul(&a_resh, &b_resh).scale(alpha).eval();
242
243 if keep_shape_a.is_empty() && keep_shape_b.is_empty() {
244 ab_resh.to_owned().into_dyn()
245 } else if keep_shape_a.is_empty() {
246 ab_resh
247 .view(0, ..)
248 .reshape(keep_shape_a)
249 .to_owned()
250 .into_dyn()
251 .into()
252 } else if keep_shape_b.is_empty() {
253 ab_resh
254 .view(.., 0)
255 .reshape(keep_shape_b)
256 .to_owned()
257 .into_dyn()
258 .into()
259 } else {
260 keep_shape_a.extend(keep_shape_b);
261 ab_resh.reshape(keep_shape_a).to_owned().into_dyn().into()
262 }
263}