1use num_complex::ComplexFloat;
2use num_traits::{One, Zero};
3
4use mdarray::{DSlice, DTensor, DynRank, Layout, Slice, Tensor, tensor};
5
6use crate::matmul::{Axes, Side, Triangle, Type, _contract};
7use crate::prelude::*;
8
9use crate::Naive;
10
11use super::simple::naive_matmul;
12
13struct NaiveMatMulBuilder<'a, T, La, Lb>
14where
15 La: Layout,
16 Lb: Layout,
17{
18 alpha: T,
19 a: &'a DSlice<T, 2, La>,
20 b: &'a DSlice<T, 2, Lb>,
21}
22
23struct NaiveContractBuilder<'a, T, La, Lb>
24where
25 La: Layout,
26 Lb: Layout,
27{
28 alpha: T,
29 a: &'a Slice<T, DynRank, La>,
30 b: &'a Slice<T, DynRank, Lb>,
31 axes: Axes,
32}
33
34impl<'a, T, La, Lb> MatMulBuilder<'a, T, La, Lb> for NaiveMatMulBuilder<'a, T, La, Lb>
35where
36 La: Layout,
37 Lb: Layout,
38 T: ComplexFloat + Zero + One,
39 {
42 fn parallelize(self) -> Self {
44 self
45 }
46
47 fn scale(mut self, factor: T) -> Self {
49 self.alpha = self.alpha * factor;
50 self
51 }
52
53 fn eval(self) -> DTensor<T, 2> {
55 let (m, _) = *self.a.shape();
56 let (_, n) = *self.b.shape();
57 let mut c = tensor![[T::zero(); n]; m];
58 naive_matmul(self.alpha, self.a, self.b, T::zero(), &mut c);
59 c
60 }
61
62 fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
64 naive_matmul(self.alpha, self.a, self.b, T::zero(), c);
65 }
66
67 fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
69 naive_matmul(self.alpha, self.a, self.b, T::one(), c);
70 }
71
72 fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, beta: T) {
75 naive_matmul(self.alpha, self.a, self.b, beta, c);
76 }
77
78 fn special(self, _lr: Side, _type_of_matrix: Type, _tr: Triangle) -> DTensor<T, 2> {
97 todo!()
98 }
99}
100
101impl<'a, T, La, Lb> ContractBuilder<'a, T, La, Lb> for NaiveContractBuilder<'a, T, La, Lb>
102where
103 La: Layout,
104 Lb: Layout,
105 T: ComplexFloat + Zero + One,
106{
107 fn scale(mut self, factor: T) -> Self {
108 self.alpha = self.alpha * factor;
109 self
110 }
111
112 fn eval(self) -> Tensor<T> {
113 _contract(Naive, self.a, self.b, self.axes, self.alpha)
114 }
115
116 fn overwrite(self, _c: &mut Slice<T>) {
117 todo!()
118 }
119}
120
121impl<T> MatMul<T> for Naive
122where
123 T: ComplexFloat,
124 {
127 fn matmul<'a, La, Lb>(
128 &self,
129 a: &'a DSlice<T, 2, La>,
130 b: &'a DSlice<T, 2, Lb>,
131 ) -> impl MatMulBuilder<'a, T, La, Lb>
132 where
133 La: Layout,
134 Lb: Layout,
135 {
136 NaiveMatMulBuilder {
137 alpha: T::one(),
138 a,
139 b,
140 }
141 }
142
143 fn contract_all<'a, La, Lb>(
145 &self,
146 a: &'a Slice<T, DynRank, La>,
147 b: &'a Slice<T, DynRank, Lb>,
148 ) -> impl ContractBuilder<'a, T, La, Lb>
149 where
150 T: 'a,
151 La: Layout,
152 Lb: Layout,
153 {
154 NaiveContractBuilder {
155 alpha: T::one(),
156 a,
157 b,
158 axes: Axes::All,
159 }
160 }
161
162 fn contract_n<'a, La, Lb>(
166 &self,
167 a: &'a Slice<T, DynRank, La>,
168 b: &'a Slice<T, DynRank, Lb>,
169 n: usize,
170 ) -> impl ContractBuilder<'a, T, La, Lb>
171 where
172 T: 'a,
173 La: Layout,
174 Lb: Layout,
175 {
176 NaiveContractBuilder {
177 alpha: T::one(),
178 a,
179 b,
180 axes: Axes::LastFirst { k: (n) },
181 }
182 }
183
184 fn contract<'a, La, Lb>(
189 &self,
190 a: &'a Slice<T, DynRank, La>,
191 b: &'a Slice<T, DynRank, Lb>,
192 axes_a: impl Into<Box<[usize]>>,
193 axes_b: impl Into<Box<[usize]>>,
194 ) -> impl ContractBuilder<'a, T, La, Lb>
195 where
196 T: 'a,
197 La: Layout,
198 Lb: Layout,
199 {
200 NaiveContractBuilder {
201 alpha: T::one(),
202 a,
203 b,
204 axes: Axes::Specific(axes_a.into(), axes_b.into()),
205 }
206 }
207}