// candle_gemm_common/gemv.rs
use num_traits::{One, Zero};
use seq_macro::seq;

use crate::simd::Simd;

6#[inline(always)]
7pub unsafe fn gemv<
8 T: Copy
9 + Zero
10 + One
11 + Send
12 + Sync
13 + core::ops::Add<Output = T>
14 + core::ops::Mul<Output = T>
15 + core::cmp::PartialEq,
16 S: Simd,
17>(
18 _simd: S,
19 m: usize,
20 n: usize,
21 k: usize,
22 dst: *mut T,
23 dst_cs: isize,
24 dst_rs: isize,
25 lhs: *const T,
26 lhs_cs: isize,
27 lhs_rs: isize,
28 rhs: *const T,
29 rhs_cs: isize,
30 rhs_rs: isize,
31 alpha: T,
32 beta: T,
33 mul_add: impl Fn(T, T, T) -> T,
34) {
35 if !alpha.is_zero() {
36 for col in 0..n {
37 for row in 0..m {
38 let dst = dst
39 .wrapping_offset(row as isize * dst_rs)
40 .wrapping_offset(col as isize * dst_cs);
41
42 *dst = alpha * *dst;
43 }
44 }
45 } else {
46 for col in 0..n {
47 for row in 0..m {
48 let dst = dst
49 .wrapping_offset(row as isize * dst_rs)
50 .wrapping_offset(col as isize * dst_cs);
51
52 *dst = T::zero();
53 }
54 }
55 }
56
57 macro_rules! do_work {
58 ($n: tt) => {
59 for depth in 0..k {
60 seq!(COL in 0..$n {
61 let rhs~COL = beta * *rhs
62 .wrapping_offset(COL as isize * rhs_cs)
63 .wrapping_offset(depth as isize * rhs_rs);
64 });
65 for row in 0..m {
66 let lhs = *lhs
67 .wrapping_offset(depth as isize * lhs_cs)
68 .wrapping_offset(row as isize * lhs_rs);
69
70 seq!(COL in 0..$n {
71 {
72 let dst = dst
73 .wrapping_offset(COL as isize * dst_cs)
74 .wrapping_offset(row as isize * dst_rs);
75 *dst = mul_add(rhs~COL, lhs, *dst);
76 }
77 });
78 }
79 }
80 }
81 }
82 match n {
83 1 => do_work!(1),
84 _ => unreachable!(),
85 }
86}