gemm_common/
gevv.rs

1use crate::simd::Simd;
2use num_traits::{One, Zero};
3
4#[inline(always)]
5pub unsafe fn gevv<
6    T: Copy
7        + Zero
8        + One
9        + Send
10        + Sync
11        + core::fmt::Debug
12        + core::ops::Add<Output = T>
13        + core::ops::Mul<Output = T>
14        + core::cmp::PartialEq,
15    S: Simd,
16>(
17    _simd: S,
18    m: usize,
19    n: usize,
20    k: usize,
21    dst: *mut T,
22    dst_cs: isize,
23    dst_rs: isize,
24    lhs: *const T,
25    lhs_cs: isize,
26    lhs_rs: isize,
27    rhs: *const T,
28    rhs_cs: isize,
29    rhs_rs: isize,
30    alpha: T,
31    beta: T,
32    mul_add: impl Fn(T, T, T) -> T,
33) {
34    macro_rules! do_work {
35        () => {
36            match k {
37                0 => {
38                    if !alpha.is_zero() {
39                        for col in 0..n {
40                            for row in 0..m {
41                                let dst = dst
42                                    .wrapping_offset(row as isize * dst_rs)
43                                    .wrapping_offset(col as isize * dst_cs);
44
45                                *dst = alpha * *dst;
46                            }
47                        }
48                    } else {
49                        for col in 0..n {
50                            for row in 0..m {
51                                let dst = dst
52                                    .wrapping_offset(row as isize * dst_rs)
53                                    .wrapping_offset(col as isize * dst_cs);
54
55                                *dst = T::zero();
56                            }
57                        }
58                    }
59                    return;
60                }
61                1 => {
62                    if !alpha.is_zero() {
63                        if alpha.is_one() {
64                            for col in 0..n {
65                                let rhs = beta * *rhs.wrapping_offset(col as isize * rhs_cs);
66                                for row in 0..m {
67                                    let lhs = *lhs.wrapping_offset(row as isize * lhs_rs);
68                                    let dst = dst
69                                        .wrapping_offset(row as isize * dst_rs)
70                                        .wrapping_offset(col as isize * dst_cs);
71
72                                    *dst = mul_add(lhs, rhs, *dst);
73                                }
74                            }
75                        } else {
76                            for col in 0..n {
77                                let rhs = beta * *rhs.wrapping_offset(col as isize * rhs_cs);
78                                for row in 0..m {
79                                    let lhs = *lhs.wrapping_offset(row as isize * lhs_rs);
80                                    let dst = dst
81                                        .wrapping_offset(row as isize * dst_rs)
82                                        .wrapping_offset(col as isize * dst_cs);
83
84                                    *dst = mul_add(lhs, rhs, alpha * *dst);
85                                }
86                            }
87                        }
88                    } else {
89                        for col in 0..n {
90                            let rhs = beta * *rhs.wrapping_offset(col as isize * rhs_cs);
91                            for row in 0..m {
92                                let lhs = *lhs.wrapping_offset(row as isize * lhs_rs);
93                                let dst = dst
94                                    .wrapping_offset(row as isize * dst_rs)
95                                    .wrapping_offset(col as isize * dst_cs);
96
97                                *dst = lhs * rhs;
98                            }
99                        }
100                    }
101                    return;
102                }
103                2 => {
104                    if !alpha.is_zero() {
105                        if alpha.is_one() {
106                            for col in 0..n {
107                                let rhs0 =
108                                    beta * *rhs.wrapping_offset(col as isize * rhs_cs + 0 * rhs_rs);
109                                let rhs1 =
110                                    beta * *rhs.wrapping_offset(col as isize * rhs_cs + 1 * rhs_rs);
111                                for row in 0..m {
112                                    let lhs0 =
113                                        *lhs.wrapping_offset(row as isize * lhs_rs + 0 * lhs_cs);
114                                    let lhs1 =
115                                        *lhs.wrapping_offset(row as isize * lhs_rs + 1 * lhs_cs);
116                                    let dst = dst
117                                        .wrapping_offset(row as isize * dst_rs)
118                                        .wrapping_offset(col as isize * dst_cs);
119
120                                    *dst = mul_add(lhs1, rhs1, mul_add(lhs0, rhs0, *dst));
121                                }
122                            }
123                        } else {
124                            for col in 0..n {
125                                let rhs0 =
126                                    beta * *rhs.wrapping_offset(col as isize * rhs_cs + 0 * rhs_rs);
127                                let rhs1 =
128                                    beta * *rhs.wrapping_offset(col as isize * rhs_cs + 1 * rhs_rs);
129                                for row in 0..m {
130                                    let lhs0 =
131                                        *lhs.wrapping_offset(row as isize * lhs_rs + 0 * lhs_cs);
132                                    let lhs1 =
133                                        *lhs.wrapping_offset(row as isize * lhs_rs + 1 * lhs_cs);
134                                    let dst = dst
135                                        .wrapping_offset(row as isize * dst_rs)
136                                        .wrapping_offset(col as isize * dst_cs);
137
138                                    *dst = mul_add(lhs1, rhs1, mul_add(lhs0, rhs0, alpha * *dst));
139                                }
140                            }
141                        }
142                    } else {
143                        for col in 0..n {
144                            let rhs0 =
145                                beta * *rhs.wrapping_offset(col as isize * rhs_cs + 0 * rhs_rs);
146                            let rhs1 =
147                                beta * *rhs.wrapping_offset(col as isize * rhs_cs + 1 * rhs_rs);
148                            for row in 0..m {
149                                let lhs0 = *lhs.wrapping_offset(row as isize * lhs_rs + 0 * lhs_cs);
150                                let lhs1 = *lhs.wrapping_offset(row as isize * lhs_rs + 1 * lhs_cs);
151                                let dst = dst
152                                    .wrapping_offset(row as isize * dst_rs)
153                                    .wrapping_offset(col as isize * dst_cs);
154
155                                *dst = mul_add(lhs1, rhs1, lhs0 * rhs0);
156                            }
157                        }
158                    }
159                    return;
160                }
161                _ => unreachable!(),
162            }
163        };
164    }
165    do_work!()
166}