candle_gemm_common/
gemv.rs

1use num_traits::{One, Zero};
2use seq_macro::seq;
3
4use crate::simd::Simd;
5
6#[inline(always)]
7pub unsafe fn gemv<
8    T: Copy
9        + Zero
10        + One
11        + Send
12        + Sync
13        + core::ops::Add<Output = T>
14        + core::ops::Mul<Output = T>
15        + core::cmp::PartialEq,
16    S: Simd,
17>(
18    _simd: S,
19    m: usize,
20    n: usize,
21    k: usize,
22    dst: *mut T,
23    dst_cs: isize,
24    dst_rs: isize,
25    lhs: *const T,
26    lhs_cs: isize,
27    lhs_rs: isize,
28    rhs: *const T,
29    rhs_cs: isize,
30    rhs_rs: isize,
31    alpha: T,
32    beta: T,
33    mul_add: impl Fn(T, T, T) -> T,
34) {
35    if !alpha.is_zero() {
36        for col in 0..n {
37            for row in 0..m {
38                let dst = dst
39                    .wrapping_offset(row as isize * dst_rs)
40                    .wrapping_offset(col as isize * dst_cs);
41
42                *dst = alpha * *dst;
43            }
44        }
45    } else {
46        for col in 0..n {
47            for row in 0..m {
48                let dst = dst
49                    .wrapping_offset(row as isize * dst_rs)
50                    .wrapping_offset(col as isize * dst_cs);
51
52                *dst = T::zero();
53            }
54        }
55    }
56
57    macro_rules! do_work {
58        ($n: tt) => {
59            for depth in 0..k {
60                seq!(COL in 0..$n {
61                    let rhs~COL = beta * *rhs
62                        .wrapping_offset(COL as isize * rhs_cs)
63                        .wrapping_offset(depth as isize * rhs_rs);
64                });
65                for row in 0..m {
66                    let lhs = *lhs
67                        .wrapping_offset(depth as isize * lhs_cs)
68                        .wrapping_offset(row as isize * lhs_rs);
69
70                    seq!(COL in 0..$n {
71                        {
72                            let dst = dst
73                                .wrapping_offset(COL as isize * dst_cs)
74                                .wrapping_offset(row as isize * dst_rs);
75                            *dst = mul_add(rhs~COL, lhs, *dst);
76                        }
77                    });
78                }
79            }
80        }
81    }
82    match n {
83        1 => do_work!(1),
84        _ => unreachable!(),
85    }
86}