use core::slice::from_raw_parts_mut;

use num_traits::{One, Zero};
use seq_macro::seq;

use crate::simd::{Boilerplate, MixedSimd, Simd};

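/// Computes `dst := alpha * dst + beta * lhs * rhs`, where `dst` is `m×n`, `lhs` is `m×k` and
/// `rhs` is `k×n`, using the given column/row strides (`*_cs`/`*_rs`) and the `mul_add`
/// callback for the fused multiply-add. Only the `n == 1` case is handled.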
#[inline(always)]
pub unsafe fn gemv<
    T: Copy
        + Zero
        + One
        + Send
        + Sync
        + core::ops::Add<Output = T>
        + core::ops::Mul<Output = T>
        + core::cmp::PartialEq,
    S: Simd,
>(
    _simd: S,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_cs: isize,
    dst_rs: isize,
    lhs: *const T,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs: *const T,
    rhs_cs: isize,
    rhs_rs: isize,
    alpha: T,
    beta: T,
    mul_add: impl Fn(T, T, T) -> T,
) {
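    // Pre-scale dst by alpha (or zero it out when alpha == 0) before accumulating the product.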
    if !alpha.is_zero() {
        for col in 0..n {
            for row in 0..m {
                let dst = dst
                    .wrapping_offset(row as isize * dst_rs)
                    .wrapping_offset(col as isize * dst_cs);

                *dst = alpha * *dst;
            }
        }
    } else {
        for col in 0..n {
            for row in 0..m {
                let dst = dst
                    .wrapping_offset(row as isize * dst_rs)
                    .wrapping_offset(col as isize * dst_cs);

                *dst = T::zero();
            }
        }
    }

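    // Accumulates beta * (lhs * rhs) into dst one depth step at a time; the seq! macro unrolls
    // the update over the $n destination columns.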
    macro_rules! do_work {
        ($n: tt) => {
            for depth in 0..k {
                seq!(COL in 0..$n {
                    let rhs~COL = beta * *rhs
                        .wrapping_offset(COL as isize * rhs_cs)
                        .wrapping_offset(depth as isize * rhs_rs);
                });
                for row in 0..m {
                    let lhs = *lhs
                        .wrapping_offset(depth as isize * lhs_cs)
                        .wrapping_offset(row as isize * lhs_rs);

                    seq!(COL in 0..$n {
                        {
                            let dst = dst
                                .wrapping_offset(COL as isize * dst_cs)
                                .wrapping_offset(row as isize * dst_rs);
                            *dst = mul_add(rhs~COL, lhs, *dst);
                        }
                    });
                }
            }
        }
    }
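    // This kernel is only ever instantiated for a single right-hand-side column.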
    match n {
        1 => do_work!(1),
        _ => unreachable!(),
    }
}

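/// Mixed-precision matrix-vector product `dst := alpha * dst + beta * lhs * rhs` with the
/// accumulation carried out in `Acc`. Requires `lhs_rs == 1` and `dst_rs == 1`, i.e. the
/// columns of `lhs` and `dst` must be contiguous in memory.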
#[inline(always)]
pub unsafe fn mixed_gemv_colmajor<
    Lhs: Boilerplate + One + Zero,
    Rhs: Boilerplate + One + Zero,
    Dst: Boilerplate + One + Zero,
    Acc: Boilerplate + One + Zero,
    S: MixedSimd<Lhs, Rhs, Dst, Acc>,
>(
    simd: S,

    m: usize,
    n: usize,
    k: usize,

    dst: *mut Dst,
    dst_cs: isize,
    dst_rs: isize,

    lhs: *const Lhs,
    lhs_cs: isize,
    lhs_rs: isize,

    rhs: *const Rhs,
    rhs_cs: isize,
    rhs_rs: isize,

    alpha: Acc,
    beta: Acc,
) {
    #[inline(always)]
    unsafe fn implementation<
        'a,
        Lhs: Boilerplate + One + Zero,
        Rhs: Boilerplate + One + Zero,
        Dst: Boilerplate + One + Zero,
        Acc: Boilerplate + One + Zero,
        S: MixedSimd<Lhs, Rhs, Dst, Acc>,
    >(
        noalias_dst: (&'a mut [Dst],),
        simd: S,
        m: usize,
        k: usize,
        lhs: *const Lhs,
        lhs_cs: isize,
        rhs: *const Rhs,
        rhs_cs: isize,
        rhs_rs: isize,
        alpha: Acc,
        beta: Acc,
    ) {
        #[allow(dead_code)]
        struct Impl<'a, Lhs, Rhs, Dst, Acc, S> {
            simd: S,
            m: usize,
            k: usize,
            noalias_dst: (&'a mut [Dst],),
            lhs: *const Lhs,
            lhs_cs: isize,
            rhs: *const Rhs,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: Acc,
            beta: Acc,
        }
        impl<
            Lhs: Boilerplate + One + Zero,
            Rhs: Boilerplate + One + Zero,
            Dst: Boilerplate + One + Zero,
            Acc: Boilerplate + One + Zero,
            S: MixedSimd<Lhs, Rhs, Dst, Acc>,
        > pulp::NullaryFnOnce for Impl<'_, Lhs, Rhs, Dst, Acc, S>
        {
            type Output = ();

            #[inline(always)]
            fn call(self) -> Self::Output {
                unsafe {
                    let Self {
                        simd,
                        m,
                        k,
                        noalias_dst,
                        lhs,
                        lhs_cs,
                        rhs,
                        rhs_cs: _,
                        rhs_rs,
                        mut alpha,
                        beta,
                    } = self;

                    let lane = S::SIMD_WIDTH;
                    let dst = noalias_dst.0.as_mut_ptr();
                    let m_lane = m / lane * lane;
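                    // One column of lhs at a time: dst gets (beta * rhs[col]) * lhs[:, col],
                    // vectorized over rows with a scalar remainder loop for the last rows.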
                    for col in 0..k {
                        let lhs = lhs.wrapping_offset(col as isize * lhs_cs);
                        let rhs = simd.from_rhs(*rhs.wrapping_offset(col as isize * rhs_rs));

                        let alpha_s = alpha;
                        let alpha_v = simd.simd_splat(alpha_s);

                        let rhs_scalar = simd.mult(beta, rhs);
                        let rhs = simd.simd_splat(rhs_scalar);

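                        // Three cases for the dst update: alpha == 0 overwrites dst,
                        // alpha == 1 accumulates into dst, otherwise dst is scaled by alpha
                        // before being accumulated into.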
                        if alpha_s.is_zero() {
                            let mut row = 0usize;
                            while row < m_lane {
                                let dst_ptr = dst.wrapping_add(row) as *mut S::DstN;
                                let lhs =
                                    simd.simd_from_lhs(*(lhs.wrapping_add(row) as *const S::LhsN));
                                *dst_ptr = simd.simd_into_dst(simd.simd_mul(lhs, rhs));
                                row += lane;
                            }
                            while row < m {
                                let dst_ptr = dst.wrapping_add(row);
                                let lhs = simd.from_lhs(*lhs.wrapping_add(row));
                                *dst_ptr = simd.into_dst(simd.mult(lhs, rhs_scalar));
                                row += 1;
                            }
                        } else if alpha_s.is_one() {
                            let mut row = 0usize;
                            while row < m_lane {
                                let dst_ptr = dst.wrapping_add(row) as *mut S::DstN;
                                let dst = *dst_ptr;
                                let lhs =
                                    simd.simd_from_lhs(*(lhs.wrapping_add(row) as *const S::LhsN));
                                *dst_ptr = simd.simd_into_dst(simd.simd_mult_add(
                                    lhs,
                                    rhs,
                                    simd.simd_from_dst(dst),
                                ));
                                row += lane;
                            }
                            while row < m {
                                let dst_ptr = dst.wrapping_add(row);
                                let dst = *dst_ptr;
                                let lhs = simd.from_lhs(*lhs.wrapping_add(row));
                                *dst_ptr = simd.into_dst(simd.mult_add(
                                    lhs,
                                    rhs_scalar,
                                    simd.from_dst(dst),
                                ));
                                row += 1;
                            }
                        } else {
                            let mut row = 0usize;
                            while row < m_lane {
                                let dst_ptr = dst.wrapping_add(row) as *mut S::DstN;
                                let dst = *dst_ptr;
                                let lhs =
                                    simd.simd_from_lhs(*(lhs.wrapping_add(row) as *const S::LhsN));
                                *dst_ptr = simd.simd_into_dst(simd.simd_add(
                                    simd.simd_mul(lhs, rhs),
                                    simd.simd_mul(alpha_v, simd.simd_from_dst(dst)),
                                ));
                                row += lane;
                            }
                            while row < m {
                                let dst_ptr = dst.wrapping_add(row);
                                let dst = *dst_ptr;
                                let lhs = simd.from_lhs(*lhs.wrapping_add(row));
                                *dst_ptr = simd.into_dst(simd.add(
                                    simd.mult(lhs, rhs_scalar),
                                    simd.mult(alpha_s, simd.from_dst(dst)),
                                ));
                                row += 1;
                            }
                        }
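                        // After the first depth step, dst already holds its alpha-scaled
                        // contribution, so the remaining steps are plain accumulations.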
                        alpha = Acc::one();
                    }
                }
            }
        }

        simd.vectorize(Impl {
            simd,
            m,
            k,
            noalias_dst,
            lhs,
            lhs_cs,
            rhs,
            rhs_cs,
            rhs_rs,
            alpha,
            beta,
        })
    }

    assert_eq!(lhs_rs, 1);
    assert_eq!(dst_rs, 1);

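    // With k == 0 there is nothing to accumulate: dst only needs to be scaled by alpha.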
    if k == 0 {
        if alpha.is_one() {
            return;
        }
        if alpha.is_zero() {
            for j in 0..n {
                core::ptr::write_bytes(dst.wrapping_offset(j as isize * dst_cs), 0u8, m);
            }
            return;
        }

        for j in 0..n {
            let dst = dst.wrapping_offset(j as isize * dst_cs);
            for i in 0..m {
                let dst = dst.add(i);
                *dst = simd.into_dst(simd.mult(simd.from_dst(*dst), alpha));
            }
        }
    }

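    // Each column of dst is an independent matrix-vector product with the matching column of rhs.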
    for x in 0..n {
        implementation(
            (from_raw_parts_mut(
                dst.wrapping_offset(x as isize * dst_cs) as _,
                m,
            ),),
            simd,
            m,
            k,
            lhs,
            lhs_cs,
            rhs.wrapping_offset(rhs_cs * x as isize),
            rhs_cs,
            rhs_rs,
            alpha,
            beta,
        );
    }
}

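/// Mixed-precision matrix-vector product `dst := alpha * dst + beta * lhs * rhs` with the
/// accumulation carried out in `Acc`. Requires `lhs_cs == 1` and `rhs_rs == 1`, i.e. the rows
/// of `lhs` and the columns of `rhs` must be contiguous in memory.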
#[inline(always)]
pub unsafe fn mixed_gemv_rowmajor<
    Lhs: Boilerplate + One + Zero,
    Rhs: Boilerplate + One + Zero,
    Dst: Boilerplate + One + Zero,
    Acc: Boilerplate + One + Zero,
    S: MixedSimd<Lhs, Rhs, Dst, Acc>,
>(
    simd: S,

    m: usize,
    n: usize,
    k: usize,

    dst: *mut Dst,
    dst_cs: isize,
    dst_rs: isize,

    lhs: *const Lhs,
    lhs_cs: isize,
    lhs_rs: isize,

    rhs: *const Rhs,
    rhs_cs: isize,
    rhs_rs: isize,

    alpha: Acc,
    beta: Acc,
) {
    #[inline(always)]
    unsafe fn implementation<
        'a,
        Lhs: Boilerplate + One + Zero,
        Rhs: Boilerplate + One + Zero,
        Dst: Boilerplate + One + Zero,
        Acc: Boilerplate + One + Zero,
        S: MixedSimd<Lhs, Rhs, Dst, Acc>,
    >(
        simd: S,
        dst: *mut Dst,
        dst_rs: isize,
        m: usize,
        k: usize,
        lhs: *const Lhs,
        lhs_rs: isize,
        rhs: *const Rhs,
        alpha: Acc,
        beta: Acc,
    ) {
        #[allow(dead_code)]
        struct Impl<Lhs, Rhs, Dst, Acc, S> {
            simd: S,
            dst: *mut Dst,
            dst_rs: isize,
            m: usize,
            k: usize,
            lhs: *const Lhs,
            lhs_rs: isize,
            rhs: *const Rhs,
            alpha: Acc,
            beta: Acc,
        }
        impl<
            Lhs: Boilerplate + One + Zero,
            Rhs: Boilerplate + One + Zero,
            Dst: Boilerplate + One + Zero,
            Acc: Boilerplate + One + Zero,
            S: MixedSimd<Lhs, Rhs, Dst, Acc>,
        > pulp::NullaryFnOnce for Impl<Lhs, Rhs, Dst, Acc, S>
        {
            type Output = ();

            #[inline(always)]
            fn call(self) -> Self::Output {
                unsafe {
                    let Self {
                        simd,
                        dst,
                        dst_rs,
                        m,
                        k,
                        lhs,
                        lhs_rs,
                        rhs,
                        alpha,
                        beta,
                    } = self;

                    let lane = S::SIMD_WIDTH;
                    let lane8 = 8 * S::SIMD_WIDTH;

                    let k_lane = k / lane * lane;
                    let k_lane8 = k / lane8 * lane8;

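                    // Each dst element is the dot product of an lhs row with the rhs column,
                    // accumulated in 8 independent SIMD accumulators (8-way unrolled), then
                    // reduced and finished with a scalar tail.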
                    for row in 0..m {
                        let lhs = lhs.wrapping_offset(row as isize * lhs_rs);

                        let mut depth = 0;

                        let mut acc0 = simd.simd_splat(Acc::zero());
                        let mut acc1 = simd.simd_splat(Acc::zero());
                        let mut acc2 = simd.simd_splat(Acc::zero());
                        let mut acc3 = simd.simd_splat(Acc::zero());
                        let mut acc4 = simd.simd_splat(Acc::zero());
                        let mut acc5 = simd.simd_splat(Acc::zero());
                        let mut acc6 = simd.simd_splat(Acc::zero());
                        let mut acc7 = simd.simd_splat(Acc::zero());

                        while depth < k_lane8 {
                            let lhs0 = *(lhs.wrapping_add(depth + lane * 0) as *const S::LhsN);
                            let rhs0 = *(rhs.wrapping_add(depth + lane * 0) as *const S::RhsN);
                            acc0 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs0),
                                simd.simd_from_rhs(rhs0),
                                acc0,
                            );

                            let lhs1 = *(lhs.wrapping_add(depth + lane * 1) as *const S::LhsN);
                            let rhs1 = *(rhs.wrapping_add(depth + lane * 1) as *const S::RhsN);
                            acc1 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs1),
                                simd.simd_from_rhs(rhs1),
                                acc1,
                            );

                            let lhs2 = *(lhs.wrapping_add(depth + lane * 2) as *const S::LhsN);
                            let rhs2 = *(rhs.wrapping_add(depth + lane * 2) as *const S::RhsN);
                            acc2 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs2),
                                simd.simd_from_rhs(rhs2),
                                acc2,
                            );

                            let lhs3 = *(lhs.wrapping_add(depth + lane * 3) as *const S::LhsN);
                            let rhs3 = *(rhs.wrapping_add(depth + lane * 3) as *const S::RhsN);
                            acc3 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs3),
                                simd.simd_from_rhs(rhs3),
                                acc3,
                            );

                            let lhs4 = *(lhs.wrapping_add(depth + lane * 4) as *const S::LhsN);
                            let rhs4 = *(rhs.wrapping_add(depth + lane * 4) as *const S::RhsN);
                            acc4 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs4),
                                simd.simd_from_rhs(rhs4),
                                acc4,
                            );

                            let lhs5 = *(lhs.wrapping_add(depth + lane * 5) as *const S::LhsN);
                            let rhs5 = *(rhs.wrapping_add(depth + lane * 5) as *const S::RhsN);
                            acc5 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs5),
                                simd.simd_from_rhs(rhs5),
                                acc5,
                            );

                            let lhs6 = *(lhs.wrapping_add(depth + lane * 6) as *const S::LhsN);
                            let rhs6 = *(rhs.wrapping_add(depth + lane * 6) as *const S::RhsN);
                            acc6 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs6),
                                simd.simd_from_rhs(rhs6),
                                acc6,
                            );

                            let lhs7 = *(lhs.wrapping_add(depth + lane * 7) as *const S::LhsN);
                            let rhs7 = *(rhs.wrapping_add(depth + lane * 7) as *const S::RhsN);
                            acc7 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs7),
                                simd.simd_from_rhs(rhs7),
                                acc7,
                            );

                            depth += lane8;
                        }

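                        // Reduce the 8 accumulators pairwise into a single SIMD accumulator.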
                        let acc0 = simd.simd_add(acc0, acc1);
                        let acc2 = simd.simd_add(acc2, acc3);
                        let acc4 = simd.simd_add(acc4, acc5);
                        let acc6 = simd.simd_add(acc6, acc7);

                        let acc0 = simd.simd_add(acc0, acc2);
                        let acc4 = simd.simd_add(acc4, acc6);

                        let mut acc0 = simd.simd_add(acc0, acc4);

                        while depth < k_lane {
                            let lhs0 = *(lhs.wrapping_add(depth) as *const S::LhsN);
                            let rhs0 = *(rhs.wrapping_add(depth) as *const S::RhsN);
                            acc0 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs0),
                                simd.simd_from_rhs(rhs0),
                                acc0,
                            );

                            depth += lane;
                        }

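                        // Horizontal reduction: sum the lanes of the SIMD accumulator into a scalar.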
                        let acc_ptr = &acc0 as *const _ as *const Acc;
                        let mut acc0 = *acc_ptr;
                        for x in 1..S::SIMD_WIDTH {
                            acc0 = simd.add(acc0, *acc_ptr.add(x));
                        }

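                        // Scalar tail for the depth range not covered by the SIMD loops.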
                        while depth < k {
                            let lhs0 = *(lhs.wrapping_add(depth + 0));
                            let rhs0 = *(rhs.wrapping_add(depth + 0));

                            acc0 = simd.mult_add(simd.from_lhs(lhs0), simd.from_rhs(rhs0), acc0);

                            depth += 1;
                        }

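                        // Write back: dst = beta * acc + alpha * dst (the alpha term is skipped
                        // when alpha == 0).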
                        if alpha.is_zero() {
                            let dst = dst.wrapping_offset(dst_rs * row as isize);
                            *dst = simd.into_dst(simd.mult(acc0, beta));
                        } else {
                            let dst = dst.wrapping_offset(dst_rs * row as isize);
                            *dst = simd.into_dst(simd.add(
                                simd.mult(acc0, beta),
                                simd.mult(simd.from_dst(*dst), alpha),
                            ));
                        }
                    }
                }
            }
        }

        simd.vectorize(Impl {
            simd,
            dst,
            dst_rs,
            m,
            k,
            lhs,
            lhs_rs,
            rhs,
            alpha,
            beta,
        })
    }

    assert_eq!(lhs_cs, 1);
    assert_eq!(rhs_rs, 1);

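    // Each column of dst is computed independently from the matching column of rhs.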
    for x in 0..n {
        implementation(
            simd,
            dst.wrapping_offset(x as isize * dst_cs),
            dst_rs,
            m,
            k,
            lhs,
            lhs_rs,
            rhs.wrapping_offset(rhs_cs * x as isize),
            alpha,
            beta,
        );
    }
}