gemm_common/gemv.rs

use core::slice::from_raw_parts_mut;

use num_traits::{One, Zero};
use seq_macro::seq;

use crate::simd::{Boilerplate, MixedSimd, Simd};

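/// Computes `dst = alpha * dst + beta * lhs * rhs`, where `lhs` is `m x k`,
/// `rhs` is `k x n`, and `dst` is `m x n`. Strides are in elements: `*_cs` is
/// the stride between columns and `*_rs` the stride between rows. Only
/// `n == 1` is instantiated by the dispatch at the bottom of this function.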
#[inline(always)]
pub unsafe fn gemv<
    T: Copy
        + Zero
        + One
        + Send
        + Sync
        + core::ops::Add<Output = T>
        + core::ops::Mul<Output = T>
        + core::cmp::PartialEq,
    S: Simd,
>(
    _simd: S,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_cs: isize,
    dst_rs: isize,
    lhs: *const T,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs: *const T,
    rhs_cs: isize,
    rhs_rs: isize,
    alpha: T,
    beta: T,
    mul_add: impl Fn(T, T, T) -> T,
) {
    // Pre-scale dst by alpha. When alpha == 0, write zeros explicitly so any
    // stale contents of dst (including NaN) are discarded, not multiplied.
    if !alpha.is_zero() {
        for col in 0..n {
            for row in 0..m {
                let dst = dst
                    .wrapping_offset(row as isize * dst_rs)
                    .wrapping_offset(col as isize * dst_cs);

                *dst = alpha * *dst;
            }
        }
    } else {
        for col in 0..n {
            for row in 0..m {
                let dst = dst
                    .wrapping_offset(row as isize * dst_rs)
                    .wrapping_offset(col as isize * dst_cs);

                *dst = T::zero();
            }
        }
    }

    // Accumulate beta * lhs * rhs into dst, unrolling over the rhs columns at
    // compile time with `seq!`.
    macro_rules! do_work {
        ($n: tt) => {
            for depth in 0..k {
                seq!(COL in 0..$n {
                    let rhs~COL = beta * *rhs
                        .wrapping_offset(COL as isize * rhs_cs)
                        .wrapping_offset(depth as isize * rhs_rs);
                });
                for row in 0..m {
                    let lhs = *lhs
                        .wrapping_offset(depth as isize * lhs_cs)
                        .wrapping_offset(row as isize * lhs_rs);

                    seq!(COL in 0..$n {
                        {
                            let dst = dst
                                .wrapping_offset(COL as isize * dst_cs)
                                .wrapping_offset(row as isize * dst_rs);
                            *dst = mul_add(rhs~COL, lhs, *dst);
                        }
                    });
                }
            }
        }
    }
    // Callers only dispatch here with a single rhs column.
    match n {
        1 => do_work!(1),
        _ => unreachable!(),
    }
}

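// A minimal usage sketch for `gemv` in the n == 1 case, kept as a comment
// because it assumes a unit-struct scalar `Simd` implementation (called
// `Scalar` below, an assumption for illustration) and column-major layouts:
//
//     let (m, k) = (4usize, 3usize);
//     let lhs = vec![1.0f64; m * k]; // column-major: lhs_rs = 1, lhs_cs = m
//     let rhs = vec![2.0f64; k]; // a single column: rhs_rs = 1
//     let mut dst = vec![0.0f64; m];
//     unsafe {
//         gemv(
//             Scalar,
//             m, 1, k,
//             dst.as_mut_ptr(), m as isize, 1,
//             lhs.as_ptr(), m as isize, 1,
//             rhs.as_ptr(), k as isize, 1,
//             0.0, // alpha = 0 discards the old contents of dst
//             1.0, // beta scales the lhs * rhs product
//             |a, b, c| a * b + c,
//         );
//     }
//     // each dst[i] is the dot product of lhs row i with rhs: 6.0 here.
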
// dst and lhs are column-major (dst_rs == lhs_rs == 1, asserted below).
// n is small.
#[inline(always)]
pub unsafe fn mixed_gemv_colmajor<
    Lhs: Boilerplate + One + Zero,
    Rhs: Boilerplate + One + Zero,
    Dst: Boilerplate + One + Zero,
    Acc: Boilerplate + One + Zero,
    S: MixedSimd<Lhs, Rhs, Dst, Acc>,
>(
    simd: S,

    m: usize,
    n: usize,
    k: usize,

    dst: *mut Dst,
    dst_cs: isize,
    dst_rs: isize,

    lhs: *const Lhs,
    lhs_cs: isize,
    lhs_rs: isize,

    rhs: *const Rhs,
    rhs_cs: isize,
    rhs_rs: isize,

    alpha: Acc,
    beta: Acc,
) {
    #[inline(always)]
    unsafe fn implementation<
        'a,
        Lhs: Boilerplate + One + Zero,
        Rhs: Boilerplate + One + Zero,
        Dst: Boilerplate + One + Zero,
        Acc: Boilerplate + One + Zero,
        S: MixedSimd<Lhs, Rhs, Dst, Acc>,
    >(
        // Passing dst as a unique mutable slice (in a one-element tuple)
        // lets the compiler assume it does not alias lhs or rhs.
        noalias_dst: (&'a mut [Dst],),
        simd: S,
        m: usize,
        k: usize,
        lhs: *const Lhs,
        lhs_cs: isize,
        rhs: *const Rhs,
        rhs_cs: isize,
        rhs_rs: isize,
        alpha: Acc,
        beta: Acc,
    ) {
        #[allow(dead_code)]
        struct Impl<'a, Lhs, Rhs, Dst, Acc, S> {
            simd: S,
            m: usize,
            k: usize,
            noalias_dst: (&'a mut [Dst],),
            lhs: *const Lhs,
            lhs_cs: isize,
            rhs: *const Rhs,
            rhs_cs: isize,
            rhs_rs: isize,
            alpha: Acc,
            beta: Acc,
        }
        impl<
                Lhs: Boilerplate + One + Zero,
                Rhs: Boilerplate + One + Zero,
                Dst: Boilerplate + One + Zero,
                Acc: Boilerplate + One + Zero,
                S: MixedSimd<Lhs, Rhs, Dst, Acc>,
            > pulp::NullaryFnOnce for Impl<'_, Lhs, Rhs, Dst, Acc, S>
        {
            type Output = ();

            #[inline(always)]
            fn call(self) -> Self::Output {
                unsafe {
                    let Self {
                        simd,
                        m,
                        k,
                        noalias_dst,
                        lhs,
                        lhs_cs,
                        rhs,
                        rhs_cs: _,
                        rhs_rs,
                        mut alpha,
                        beta,
                    } = self;

                    let lane = S::SIMD_WIDTH;
                    let dst = noalias_dst.0.as_mut_ptr();
                    // Rows covered by the vectorized loops; the remainder is
                    // handled one element at a time.
                    let m_lane = m / lane * lane;
                    // Axpy-style traversal: for each depth index, stream one
                    // column of lhs and accumulate lhs[:, col] * rhs[col]
                    // into dst.
                    for col in 0..k {
                        let lhs = lhs.wrapping_offset(col as isize * lhs_cs);
                        let rhs = simd.from_rhs(*rhs.wrapping_offset(col as isize * rhs_rs));

                        let alpha_s = alpha;
                        let alpha_v = simd.simd_splat(alpha_s);

                        // Fold beta into the broadcast rhs element, so the
                        // inner loops compute dst = alpha * dst + lhs * (beta * rhs).
                        let rhs_scalar = simd.mult(beta, rhs);
                        let rhs = simd.simd_splat(rhs_scalar);

                        if alpha_s.is_zero() {
                            // alpha == 0: overwrite dst.
                            let mut row = 0usize;
                            while row < m_lane {
                                let dst_ptr = dst.wrapping_add(row) as *mut S::DstN;
                                let lhs =
                                    simd.simd_from_lhs(*(lhs.wrapping_add(row) as *const S::LhsN));
                                *dst_ptr = simd.simd_into_dst(simd.simd_mul(lhs, rhs));
                                row += lane;
                            }
                            while row < m {
                                let dst_ptr = dst.wrapping_add(row);
                                let lhs = simd.from_lhs(*lhs.wrapping_add(row));
                                *dst_ptr = simd.into_dst(simd.mult(lhs, rhs_scalar));
                                row += 1;
                            }
                        } else if alpha_s.is_one() {
                            // alpha == 1: fused multiply-accumulate into dst.
                            let mut row = 0usize;
                            while row < m_lane {
                                let dst_ptr = dst.wrapping_add(row) as *mut S::DstN;
                                let dst = *dst_ptr;
                                let lhs =
                                    simd.simd_from_lhs(*(lhs.wrapping_add(row) as *const S::LhsN));
                                *dst_ptr = simd.simd_into_dst(simd.simd_mult_add(
                                    lhs,
                                    rhs,
                                    simd.simd_from_dst(dst),
                                ));
                                row += lane;
                            }
                            while row < m {
                                let dst_ptr = dst.wrapping_add(row);
                                let dst = *dst_ptr;
                                let lhs = simd.from_lhs(*lhs.wrapping_add(row));
                                *dst_ptr = simd.into_dst(simd.mult_add(
                                    lhs,
                                    rhs_scalar,
                                    simd.from_dst(dst),
                                ));
                                row += 1;
                            }
                        } else {
                            // General alpha: dst = lhs * rhs + alpha * dst.
                            let mut row = 0usize;
                            while row < m_lane {
                                let dst_ptr = dst.wrapping_add(row) as *mut S::DstN;
                                let dst = *dst_ptr;
                                let lhs =
                                    simd.simd_from_lhs(*(lhs.wrapping_add(row) as *const S::LhsN));
                                *dst_ptr = simd.simd_into_dst(simd.simd_add(
                                    simd.simd_mul(lhs, rhs),
                                    simd.simd_mul(alpha_v, simd.simd_from_dst(dst)),
                                ));
                                row += lane;
                            }
                            while row < m {
                                let dst_ptr = dst.wrapping_add(row);
                                let dst = *dst_ptr;
                                let lhs = simd.from_lhs(*lhs.wrapping_add(row));
                                *dst_ptr = simd.into_dst(simd.add(
                                    simd.mult(lhs, rhs_scalar),
                                    simd.mult(alpha_s, simd.from_dst(dst)),
                                ));
                                row += 1;
                            }
                        }
                        // dst now holds valid partial results, so subsequent
                        // depth iterations must accumulate: alpha becomes 1.
                        alpha = Acc::one();
                    }
                }
            }
        }

        simd.vectorize(Impl {
            simd,
            m,
            k,
            noalias_dst,
            lhs,
            lhs_cs,
            rhs,
            rhs_cs,
            rhs_rs,
            alpha,
            beta,
        })
    }

    assert_eq!(lhs_rs, 1);
    assert_eq!(dst_rs, 1);

    // With zero depth the product term vanishes and only dst = alpha * dst
    // remains.
    if k == 0 {
        if alpha.is_one() {
            return;
        }
        if alpha.is_zero() {
            for j in 0..n {
                core::ptr::write_bytes(dst.wrapping_offset(j as isize * dst_cs), 0u8, m);
            }
            return;
        }

        for j in 0..n {
            let dst = dst.wrapping_offset(j as isize * dst_cs);
            for i in 0..m {
                let dst = dst.add(i);
                *dst = simd.into_dst(simd.mult(simd.from_dst(*dst), alpha));
            }
        }
        return;
    }

    // Each column of dst depends on a single column of rhs, so the columns
    // can be processed independently.
    for x in 0..n {
        implementation(
            (from_raw_parts_mut(
                dst.wrapping_offset(x as isize * dst_cs) as _,
                m,
            ),),
            simd,
            m,
            k,
            lhs,
            lhs_cs,
            rhs.wrapping_offset(rhs_cs * x as isize),
            rhs_cs,
            rhs_rs,
            alpha,
            beta,
        );
    }
}

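// A scalar reference model of what `mixed_gemv_colmajor` computes, useful as
// a mental model or a test oracle. This is a sketch, not part of the crate's
// API; it assumes f64 for all four generic types and unit row strides
// (lhs_rs == dst_rs == 1), matching the asserts above:
//
//     fn gemv_colmajor_ref(
//         m: usize, n: usize, k: usize,
//         dst: &mut [f64], dst_cs: usize,
//         lhs: &[f64], lhs_cs: usize,
//         rhs: &[f64], rhs_cs: usize, rhs_rs: usize,
//         alpha: f64, beta: f64,
//     ) {
//         for col in 0..n {
//             for row in 0..m {
//                 // dot product of lhs row `row` with rhs column `col`
//                 let mut acc = 0.0;
//                 for depth in 0..k {
//                     acc += lhs[row + depth * lhs_cs]
//                         * rhs[depth * rhs_rs + col * rhs_cs];
//                 }
//                 let d = &mut dst[row + col * dst_cs];
//                 *d = alpha * *d + beta * acc;
//             }
//         }
//     }
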
// lhs is row-major (lhs_cs == 1, asserted below).
// rhs is column-major (rhs_rs == 1, asserted below).
// n is small.
#[inline(always)]
pub unsafe fn mixed_gemv_rowmajor<
    Lhs: Boilerplate + One + Zero,
    Rhs: Boilerplate + One + Zero,
    Dst: Boilerplate + One + Zero,
    Acc: Boilerplate + One + Zero,
    S: MixedSimd<Lhs, Rhs, Dst, Acc>,
>(
    simd: S,

    m: usize,
    n: usize,
    k: usize,

    dst: *mut Dst,
    dst_cs: isize,
    dst_rs: isize,

    lhs: *const Lhs,
    lhs_cs: isize,
    lhs_rs: isize,

    rhs: *const Rhs,
    rhs_cs: isize,
    rhs_rs: isize,

    alpha: Acc,
    beta: Acc,
) {
    #[inline(always)]
    unsafe fn implementation<
        'a,
        Lhs: Boilerplate + One + Zero,
        Rhs: Boilerplate + One + Zero,
        Dst: Boilerplate + One + Zero,
        Acc: Boilerplate + One + Zero,
        S: MixedSimd<Lhs, Rhs, Dst, Acc>,
    >(
        simd: S,
        dst: *mut Dst,
        dst_rs: isize,
        m: usize,
        k: usize,
        lhs: *const Lhs,
        lhs_rs: isize,
        rhs: *const Rhs,
        alpha: Acc,
        beta: Acc,
    ) {
        #[allow(dead_code)]
        struct Impl<Lhs, Rhs, Dst, Acc, S> {
            simd: S,
            dst: *mut Dst,
            dst_rs: isize,
            m: usize,
            k: usize,
            lhs: *const Lhs,
            lhs_rs: isize,
            rhs: *const Rhs,
            alpha: Acc,
            beta: Acc,
        }
        impl<
                Lhs: Boilerplate + One + Zero,
                Rhs: Boilerplate + One + Zero,
                Dst: Boilerplate + One + Zero,
                Acc: Boilerplate + One + Zero,
                S: MixedSimd<Lhs, Rhs, Dst, Acc>,
            > pulp::NullaryFnOnce for Impl<Lhs, Rhs, Dst, Acc, S>
        {
            type Output = ();

            #[inline(always)]
            fn call(self) -> Self::Output {
                unsafe {
                    let Self {
                        simd,
                        dst,
                        dst_rs,
                        m,
                        k,
                        lhs,
                        lhs_rs,
                        rhs,
                        alpha,
                        beta,
                    } = self;

                    let lane = S::SIMD_WIDTH;
                    let lane8 = 8 * S::SIMD_WIDTH;

                    // Depth covered by the 8-way unrolled loop and by the
                    // single-vector loop, respectively.
                    let k_lane = k / lane * lane;
                    let k_lane8 = k / lane8 * lane8;

                    // Each dst element is an independent dot product of one
                    // lhs row with the rhs column.
                    for row in 0..m {
                        let lhs = lhs.wrapping_offset(row as isize * lhs_rs);

                        let mut depth = 0;

                        // Eight independent accumulators hide the latency of
                        // the fused multiply-add dependency chain.
                        let mut acc0 = simd.simd_splat(Acc::zero());
                        let mut acc1 = simd.simd_splat(Acc::zero());
                        let mut acc2 = simd.simd_splat(Acc::zero());
                        let mut acc3 = simd.simd_splat(Acc::zero());
                        let mut acc4 = simd.simd_splat(Acc::zero());
                        let mut acc5 = simd.simd_splat(Acc::zero());
                        let mut acc6 = simd.simd_splat(Acc::zero());
                        let mut acc7 = simd.simd_splat(Acc::zero());

                        while depth < k_lane8 {
                            let lhs0 = *(lhs.wrapping_add(depth + lane * 0) as *const S::LhsN);
                            let rhs0 = *(rhs.wrapping_add(depth + lane * 0) as *const S::RhsN);
                            acc0 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs0),
                                simd.simd_from_rhs(rhs0),
                                acc0,
                            );

                            let lhs1 = *(lhs.wrapping_add(depth + lane * 1) as *const S::LhsN);
                            let rhs1 = *(rhs.wrapping_add(depth + lane * 1) as *const S::RhsN);
                            acc1 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs1),
                                simd.simd_from_rhs(rhs1),
                                acc1,
                            );

                            let lhs2 = *(lhs.wrapping_add(depth + lane * 2) as *const S::LhsN);
                            let rhs2 = *(rhs.wrapping_add(depth + lane * 2) as *const S::RhsN);
                            acc2 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs2),
                                simd.simd_from_rhs(rhs2),
                                acc2,
                            );

                            let lhs3 = *(lhs.wrapping_add(depth + lane * 3) as *const S::LhsN);
                            let rhs3 = *(rhs.wrapping_add(depth + lane * 3) as *const S::RhsN);
                            acc3 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs3),
                                simd.simd_from_rhs(rhs3),
                                acc3,
                            );

                            let lhs4 = *(lhs.wrapping_add(depth + lane * 4) as *const S::LhsN);
                            let rhs4 = *(rhs.wrapping_add(depth + lane * 4) as *const S::RhsN);
                            acc4 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs4),
                                simd.simd_from_rhs(rhs4),
                                acc4,
                            );

                            let lhs5 = *(lhs.wrapping_add(depth + lane * 5) as *const S::LhsN);
                            let rhs5 = *(rhs.wrapping_add(depth + lane * 5) as *const S::RhsN);
                            acc5 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs5),
                                simd.simd_from_rhs(rhs5),
                                acc5,
                            );

                            let lhs6 = *(lhs.wrapping_add(depth + lane * 6) as *const S::LhsN);
                            let rhs6 = *(rhs.wrapping_add(depth + lane * 6) as *const S::RhsN);
                            acc6 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs6),
                                simd.simd_from_rhs(rhs6),
                                acc6,
                            );

                            let lhs7 = *(lhs.wrapping_add(depth + lane * 7) as *const S::LhsN);
                            let rhs7 = *(rhs.wrapping_add(depth + lane * 7) as *const S::RhsN);
                            acc7 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs7),
                                simd.simd_from_rhs(rhs7),
                                acc7,
                            );

                            depth += lane8;
                        }

                        // Combine the eight accumulators pairwise to keep the
                        // reduction tree shallow.
                        let acc0 = simd.simd_add(acc0, acc1);
                        let acc2 = simd.simd_add(acc2, acc3);
                        let acc4 = simd.simd_add(acc4, acc5);
                        let acc6 = simd.simd_add(acc6, acc7);

                        let acc0 = simd.simd_add(acc0, acc2);
                        let acc4 = simd.simd_add(acc4, acc6);

                        let mut acc0 = simd.simd_add(acc0, acc4);

                        // Drain the remaining full SIMD vectors one at a time.
                        while depth < k_lane {
                            let lhs0 = *(lhs.wrapping_add(depth) as *const S::LhsN);
                            let rhs0 = *(rhs.wrapping_add(depth) as *const S::RhsN);
                            acc0 = simd.simd_mult_add(
                                simd.simd_from_lhs(lhs0),
                                simd.simd_from_rhs(rhs0),
                                acc0,
                            );

                            depth += lane;
                        }

                        // Horizontal reduction: sum the lanes of the SIMD
                        // accumulator into a scalar.
                        let acc_ptr = &acc0 as *const _ as *const Acc;
                        let mut acc0 = *acc_ptr;
                        for x in 1..S::SIMD_WIDTH {
                            acc0 = simd.add(acc0, *acc_ptr.add(x));
                        }

                        // Scalar tail for the last k % lane elements.
                        while depth < k {
                            let lhs0 = *(lhs.wrapping_add(depth + 0));
                            let rhs0 = *(rhs.wrapping_add(depth + 0));

                            acc0 = simd.mult_add(simd.from_lhs(lhs0), simd.from_rhs(rhs0), acc0);

                            depth += 1;
                        }

                        if alpha.is_zero() {
                            // alpha == 0: overwrite dst with beta * (lhs . rhs).
                            let dst = dst.wrapping_offset(dst_rs * row as isize);
                            *dst = simd.into_dst(simd.mult(acc0, beta));
                        } else {
                            let dst = dst.wrapping_offset(dst_rs * row as isize);
                            *dst =
                                simd.into_dst(simd.add(
                                    simd.mult(acc0, beta),
                                    simd.mult(simd.from_dst(*dst), alpha),
                                ));
                        }
                    }
                }
            }
        }

        simd.vectorize(Impl {
            simd,
            dst,
            dst_rs,
            m,
            k,
            lhs,
            lhs_rs,
            rhs,
            alpha,
            beta,
        })
    }

    assert_eq!(lhs_cs, 1);
    assert_eq!(rhs_rs, 1);

    // Process each (dst column, rhs column) pair independently.
    for x in 0..n {
        implementation(
            simd,
            dst.wrapping_offset(x as isize * dst_cs),
            dst_rs,
            m,
            k,
            lhs,
            lhs_rs,
            rhs.wrapping_offset(rhs_cs * x as isize),
            alpha,
            beta,
        );
    }
}
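
// The row-major kernel above is a dot-product formulation: each dst element
// is an independent reduction over k. The eight accumulators exist because
// consecutive fused multiply-adds into a single register would serialize on
// the dependency chain, while eight independent chains keep the FMA units
// busy; the chains are then combined pairwise and finally summed across
// lanes. A scalar sketch of the same idea, with hypothetical names and f64
// data, assuming a 4-way unroll instead of 8 for brevity:
//
//     fn dot_unrolled(lhs: &[f64], rhs: &[f64]) -> f64 {
//         let k = lhs.len();
//         let k4 = k / 4 * 4;
//         let (mut a0, mut a1, mut a2, mut a3) = (0.0, 0.0, 0.0, 0.0);
//         let mut i = 0;
//         while i < k4 {
//             // four independent chains instead of one serial chain
//             a0 += lhs[i] * rhs[i];
//             a1 += lhs[i + 1] * rhs[i + 1];
//             a2 += lhs[i + 2] * rhs[i + 2];
//             a3 += lhs[i + 3] * rhs[i + 3];
//             i += 4;
//         }
//         // pairwise combine, then handle the scalar tail
//         let mut acc = (a0 + a1) + (a2 + a3);
//         while i < k {
//             acc += lhs[i] * rhs[i];
//             i += 1;
//         }
//         acc
//     }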