use core::ops::{Add, Mul};

use crate::prelude_dev::*;

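// Naive serial-CPU matmul kernels. `matmul` dispatches to gemm / gemv / gevm /
// inner_dot based on the dimensionality of the operands; shapes only, for
// illustration:
//   a: (3, 4), b: (4, 5),    c: (3, 5) -> gemm
//   a: (3, 4), b: (4,),      c: (3,)   -> gemv
//   a: (4,),   b: (2, 4, 5), c: (2, 5) -> gevm, broadcast over the leading batch axis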
impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
    Self: DeviceGEMMAPI<TA, TB, TC>
        + DeviceGEMVAPI<TA, TB, TC>
        + DeviceInnerDotAPI<TA, TB, TC>
        + DeviceAPI<TA, Raw = Vec<TA>>
        + DeviceAPI<TB, Raw = Vec<TB>>
        + DeviceAPI<TC, Raw = Vec<TC>>,
{
    fn matmul(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<DC>,
        a: &Vec<TA>,
        la: &Layout<DA>,
        b: &Vec<TB>,
        lb: &Layout<DB>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let default_order = self.default_order();
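        // Dispatch on (la.ndim(), lb.ndim(), lc.ndim()):
        //   (1, 1, 0) -> inner_dot (vector dot vector)
        //   (2, 2, 2) -> gemm      (matrix-matrix)
        //   (2, 1, 1) -> gemv      (matrix-vector)
        //   (1, 2, 1) -> gevm      (vector-matrix)
        // The remaining arms broadcast over leading batch axes and call the kernels
        // above once per batch element; any other combination is an invalid layout.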
        match (la.ndim(), lb.ndim(), lc.ndim()) {
            (1, 1, 0) => {
                let la = &la.clone().into_dim::<Ix1>().unwrap();
                let lb = &lb.clone().into_dim::<Ix1>().unwrap();
                let lc = &lc.clone().into_dim::<Ix0>().unwrap();
                self.inner_dot(c, lc, a, la, b, lb, alpha, beta)?;
            },
            (2, 2, 2) => {
                let la = &la.clone().into_dim::<Ix2>().unwrap();
                let lb = &lb.clone().into_dim::<Ix2>().unwrap();
                let lc = &lc.clone().into_dim::<Ix2>().unwrap();
                self.gemm(c, lc, a, la, b, lb, alpha, beta)?;
            },
            (2, 1, 1) => {
                let la = &la.clone().into_dim::<Ix2>().unwrap();
                let lb = &lb.clone().into_dim::<Ix1>().unwrap();
                let lc = &lc.clone().into_dim::<Ix1>().unwrap();
                self.gemv(c, lc, a, la, b, lb, alpha, beta)?;
            },
            (1, 2, 1) => {
                let la = &la.clone().into_dim::<Ix1>().unwrap();
                let lb = &lb.clone().into_dim::<Ix2>().unwrap();
                let lc = &lc.clone().into_dim::<Ix1>().unwrap();
                self.gevm(c, lc, a, la, b, lb, alpha, beta)?;
            },
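            // Broadcasting arms: split each layout into leading batch ("rest") dimensions
            // and the trailing matmul dimensions, iterate the batch dimensions in
            // col-major order, and run the kernel on each slice by shifting the offset
            // of the trailing layout.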
            (1, 2.., _) => {
                rstsr_assert_eq!(lb.ndim(), lc.ndim() + 1, InvalidLayout)?;
                if default_order == ColMajor && lb.ndim() > 2 {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let la = &la.clone().into_dim::<Ix1>().unwrap();
                let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
                let lb_matmul = &mut lb_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix1>()?;
                let l_rest = translate_to_col_major(&[&lc_rest, &lb_rest], TensorIterOrder::K)?;
                let (lc_rest, lb_rest) = (&l_rest[0], &l_rest[1]);
                let itb_rest = IterLayoutColMajor::new(lb_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ib_rest, ic_rest) in izip!(itb_rest, itc_rest) {
                    unsafe { lb_matmul.set_offset(ib_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gevm(c, lc_matmul, a, la, b, lb_matmul, alpha.clone(), beta.clone())?;
                }
            },
            (2.., 1, _) => {
                rstsr_assert_eq!(la.ndim(), lc.ndim() + 1, InvalidLayout)?;
                if default_order == ColMajor && la.ndim() > 2 {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let lb = &lb.clone().into_dim::<Ix1>().unwrap();
                let (la_rest, la_matmul) = la.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-1)?;
                let la_matmul = &mut la_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix1>()?;
                let l_rest = translate_to_col_major(&[&lc_rest, &la_rest], TensorIterOrder::K)?;
                let (lc_rest, la_rest) = (&l_rest[0], &l_rest[1]);
                let ita_rest = IterLayoutColMajor::new(la_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ia_rest, ic_rest) in izip!(ita_rest, itc_rest) {
                    unsafe { la_matmul.set_offset(ia_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gemv(c, lc_matmul, a, la_matmul, b, lb, alpha.clone(), beta.clone())?;
                }
            },
            (2, 3.., _) => {
                rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
                if default_order == ColMajor {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let la = &la.clone().into_dim::<Ix2>().unwrap();
                let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
                let lb_matmul = &mut lb_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix2>()?;
                let l_rest = translate_to_col_major(&[&lc_rest, &lb_rest], TensorIterOrder::K)?;
                let (lc_rest, lb_rest) = (&l_rest[0], &l_rest[1]);
                let itb_rest = IterLayoutColMajor::new(lb_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ib_rest, ic_rest) in izip!(itb_rest, itc_rest) {
                    unsafe { lb_matmul.set_offset(ib_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gemm(c, lc_matmul, a, la, b, lb_matmul, alpha.clone(), beta.clone())?;
                }
            },
            (3.., 2, _) => {
                rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
                if default_order == ColMajor {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let lb = &lb.clone().into_dim::<Ix2>().unwrap();
                let (la_rest, la_matmul) = la.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
                let la_matmul = &mut la_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix2>()?;
                let l_rest = translate_to_col_major(&[&lc_rest, &la_rest], TensorIterOrder::K)?;
                let (lc_rest, la_rest) = (&l_rest[0], &l_rest[1]);
                let ita_rest = IterLayoutColMajor::new(la_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ia_rest, ic_rest) in izip!(ita_rest, itc_rest) {
                    unsafe { la_matmul.set_offset(ia_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gemm(c, lc_matmul, a, la_matmul, b, lb, alpha.clone(), beta.clone())?;
                }
            },
            (3.., 3.., _) => {
                rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
                rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
                if default_order == ColMajor {
                    rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")?;
                }
                let (la_rest, la_matmul) = la.dim_split_at(-2)?;
                let (lb_rest, lb_matmul) = lb.dim_split_at(-2)?;
                let (lc_rest, lc_matmul) = lc.dim_split_at(-2)?;
                let la_matmul = &mut la_matmul.into_dim::<Ix2>()?;
                let lb_matmul = &mut lb_matmul.into_dim::<Ix2>()?;
                let lc_matmul = &mut lc_matmul.into_dim::<Ix2>()?;
                let l_rest =
                    translate_to_col_major(&[&lc_rest, &la_rest, &lb_rest], TensorIterOrder::K)?;
                let (lc_rest, la_rest, lb_rest) = (&l_rest[0], &l_rest[1], &l_rest[2]);
                let ita_rest = IterLayoutColMajor::new(la_rest)?;
                let itb_rest = IterLayoutColMajor::new(lb_rest)?;
                let itc_rest = IterLayoutColMajor::new(lc_rest)?;
                for (ia_rest, ib_rest, ic_rest) in izip!(ita_rest, itb_rest, itc_rest) {
                    unsafe { la_matmul.set_offset(ia_rest) };
                    unsafe { lb_matmul.set_offset(ib_rest) };
                    unsafe { lc_matmul.set_offset(ic_rest) };
                    self.gemm(
                        c,
                        lc_matmul,
                        a,
                        la_matmul,
                        b,
                        lb_matmul,
                        alpha.clone(),
                        beta.clone(),
                    )?;
                }
            },
            (0, _, _) | (_, 0, _) | (1, 1, 1..) | (2, 2, 3..) | (2, 2, 0..2) => {
                rstsr_raise!(
                    InvalidLayout,
                    "Invalid ndim for matmul: {}, {}, {}",
                    la.ndim(),
                    lb.ndim(),
                    lc.ndim()
                )?;
            },
        }
        return Ok(());
    }
}

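// Naive triple-loop GEMM kernel for DeviceCpuSerial (no blocking or explicit
// vectorization).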
impl<TA, TB, TC> DeviceGEMMAPI<TA, TB, TC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
{
    fn gemm(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<Ix2>,
        a: &Vec<TA>,
        la: &Layout<Ix2>,
        b: &Vec<TB>,
        lb: &Layout<Ix2>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let sc = lc.shape();
        let sa = la.shape();
        let sb = lb.shape();
        rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
        rstsr_assert_eq!(sa[1], sb[0], InvalidLayout)?;
        rstsr_assert_eq!(sc[1], sb[1], InvalidLayout)?;
        let (m, n, k) = (sc[0], sc[1], sa[1]);

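        // Computes c[i, j] = beta * c[i, j] + alpha * sum_k a[i, k] * b[k, j].
        // c is scaled by beta first; the accumulation then uses the m -> k -> n loop
        // order, so the innermost index runs along the last axis of b and c.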
        unsafe {
            for i_m in 0..m {
                for i_n in 0..n {
                    let idx_c = lc.index_uncheck(&[i_m, i_n]) as usize;
                    c[idx_c] = beta.clone() * c[idx_c].clone();
                }
                for i_k in 0..k {
                    let idx_a = la.index_uncheck(&[i_m, i_k]) as usize;
                    for i_n in 0..n {
                        let idx_c = lc.index_uncheck(&[i_m, i_n]) as usize;
                        let idx_b = lb.index_uncheck(&[i_k, i_n]) as usize;
                        c[idx_c] = alpha.clone() * (a[idx_a].clone() * b[idx_b].clone())
                            + c[idx_c].clone();
                    }
                }
            }
        }
        return Ok(());
    }
}

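// Naive GEMV (matrix-vector) and GEVM (vector-matrix) kernels for DeviceCpuSerial.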
impl<TA, TB, TC> DeviceGEMVAPI<TA, TB, TC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
{
    fn gemv(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<Ix1>,
        a: &Vec<TA>,
        la: &Layout<Ix2>,
        b: &Vec<TB>,
        lb: &Layout<Ix1>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let sc = lc.shape();
        let sa = la.shape();
        let sb = lb.shape();
        rstsr_assert_eq!(sc[0], sa[0], InvalidLayout)?;
        rstsr_assert_eq!(sa[1], sb[0], InvalidLayout)?;
        let (n, k) = (sa[0], sa[1]);

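        // Computes c[i] = beta * c[i] + alpha * sum_k a[i, k] * b[k],
        // where i runs over n = sa[0], the number of rows of a (and the length of c).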
        unsafe {
            for i_n in 0..n {
                let idx_c = lc.index_uncheck(&[i_n]) as usize;
                c[idx_c] = beta.clone() * c[idx_c].clone();
                for i_k in 0..k {
                    let idx_a = la.index_uncheck(&[i_n, i_k]) as usize;
                    let idx_b = lb.index_uncheck(&[i_k]) as usize;
                    c[idx_c] =
                        alpha.clone() * (a[idx_a].clone() * b[idx_b].clone()) + c[idx_c].clone();
                }
            }
        }
        return Ok(());
    }

    fn gevm(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<Ix1>,
        a: &Vec<TA>,
        la: &Layout<Ix1>,
        b: &Vec<TB>,
        lb: &Layout<Ix2>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let sc = lc.shape();
        let sa = la.shape();
        let sb = lb.shape();
        rstsr_assert_eq!(sc[0], sb[1], InvalidLayout)?;
        rstsr_assert_eq!(sa[0], sb[0], InvalidLayout)?;
        let (n, k) = (sb[1], sb[0]);

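        // Computes c[j] = beta * c[j] + alpha * sum_k a[k] * b[k, j],
        // where j runs over n = sb[1], the number of columns of b (and the length of c).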
        unsafe {
            for i_n in 0..n {
                let idx_c = lc.index_uncheck(&[i_n]) as usize;
                c[idx_c] = beta.clone() * c[idx_c].clone();
                for i_k in 0..k {
                    let idx_a = la.index_uncheck(&[i_k]) as usize;
                    let idx_b = lb.index_uncheck(&[i_k, i_n]) as usize;
                    c[idx_c] =
                        alpha.clone() * (a[idx_a].clone() * b[idx_b].clone()) + c[idx_c].clone();
                }
            }
        }
        return Ok(());
    }
}

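// Naive inner (dot) product kernel for DeviceCpuSerial.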
impl<TA, TB, TC> DeviceInnerDotAPI<TA, TB, TC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
{
    fn inner_dot(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<Ix0>,
        a: &Vec<TA>,
        la: &Layout<Ix1>,
        b: &Vec<TB>,
        lb: &Layout<Ix1>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let sa = la.shape();
        let sb = lb.shape();
        rstsr_assert_eq!(sa[0], sb[0], InvalidLayout)?;
        let n = sa[0];

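        // Computes the 0-d result as beta * c + alpha * sum_i a[i] * b[i].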
        unsafe {
            let idx_c = lc.index_uncheck(&[]) as usize;
            let mut sum = beta * c[idx_c].clone();
            for i in 0..n {
                let idx_a = la.index_uncheck(&[i]) as usize;
                let idx_b = lb.index_uncheck(&[i]) as usize;
                sum = sum + alpha.clone() * (a[idx_a].clone() * b[idx_b].clone());
            }
            c[idx_c] = sum;
        }
        return Ok(());
    }
}