1use crate::simd::Simd;
2use num_traits::{One, Zero};
3
4#[inline(always)]
5pub unsafe fn gevv<
6 T: Copy
7 + Zero
8 + One
9 + Send
10 + Sync
11 + core::fmt::Debug
12 + core::ops::Add<Output = T>
13 + core::ops::Mul<Output = T>
14 + core::cmp::PartialEq,
15 S: Simd,
16>(
17 _simd: S,
18 m: usize,
19 n: usize,
20 k: usize,
21 dst: *mut T,
22 dst_cs: isize,
23 dst_rs: isize,
24 lhs: *const T,
25 lhs_cs: isize,
26 lhs_rs: isize,
27 rhs: *const T,
28 rhs_cs: isize,
29 rhs_rs: isize,
30 alpha: T,
31 beta: T,
32 mul_add: impl Fn(T, T, T) -> T,
33) {
34 macro_rules! do_work {
35 () => {
36 match k {
37 0 => {
38 if !alpha.is_zero() {
39 for col in 0..n {
40 for row in 0..m {
41 let dst = dst
42 .wrapping_offset(row as isize * dst_rs)
43 .wrapping_offset(col as isize * dst_cs);
44
45 *dst = alpha * *dst;
46 }
47 }
48 } else {
49 for col in 0..n {
50 for row in 0..m {
51 let dst = dst
52 .wrapping_offset(row as isize * dst_rs)
53 .wrapping_offset(col as isize * dst_cs);
54
55 *dst = T::zero();
56 }
57 }
58 }
59 return;
60 }
61 1 => {
62 if !alpha.is_zero() {
63 if alpha.is_one() {
64 for col in 0..n {
65 let rhs = beta * *rhs.wrapping_offset(col as isize * rhs_cs);
66 for row in 0..m {
67 let lhs = *lhs.wrapping_offset(row as isize * lhs_rs);
68 let dst = dst
69 .wrapping_offset(row as isize * dst_rs)
70 .wrapping_offset(col as isize * dst_cs);
71
72 *dst = mul_add(lhs, rhs, *dst);
73 }
74 }
75 } else {
76 for col in 0..n {
77 let rhs = beta * *rhs.wrapping_offset(col as isize * rhs_cs);
78 for row in 0..m {
79 let lhs = *lhs.wrapping_offset(row as isize * lhs_rs);
80 let dst = dst
81 .wrapping_offset(row as isize * dst_rs)
82 .wrapping_offset(col as isize * dst_cs);
83
84 *dst = mul_add(lhs, rhs, alpha * *dst);
85 }
86 }
87 }
88 } else {
89 for col in 0..n {
90 let rhs = beta * *rhs.wrapping_offset(col as isize * rhs_cs);
91 for row in 0..m {
92 let lhs = *lhs.wrapping_offset(row as isize * lhs_rs);
93 let dst = dst
94 .wrapping_offset(row as isize * dst_rs)
95 .wrapping_offset(col as isize * dst_cs);
96
97 *dst = lhs * rhs;
98 }
99 }
100 }
101 return;
102 }
103 2 => {
104 if !alpha.is_zero() {
105 if alpha.is_one() {
106 for col in 0..n {
107 let rhs0 =
108 beta * *rhs.wrapping_offset(col as isize * rhs_cs + 0 * rhs_rs);
109 let rhs1 =
110 beta * *rhs.wrapping_offset(col as isize * rhs_cs + 1 * rhs_rs);
111 for row in 0..m {
112 let lhs0 =
113 *lhs.wrapping_offset(row as isize * lhs_rs + 0 * lhs_cs);
114 let lhs1 =
115 *lhs.wrapping_offset(row as isize * lhs_rs + 1 * lhs_cs);
116 let dst = dst
117 .wrapping_offset(row as isize * dst_rs)
118 .wrapping_offset(col as isize * dst_cs);
119
120 *dst = mul_add(lhs1, rhs1, mul_add(lhs0, rhs0, *dst));
121 }
122 }
123 } else {
124 for col in 0..n {
125 let rhs0 =
126 beta * *rhs.wrapping_offset(col as isize * rhs_cs + 0 * rhs_rs);
127 let rhs1 =
128 beta * *rhs.wrapping_offset(col as isize * rhs_cs + 1 * rhs_rs);
129 for row in 0..m {
130 let lhs0 =
131 *lhs.wrapping_offset(row as isize * lhs_rs + 0 * lhs_cs);
132 let lhs1 =
133 *lhs.wrapping_offset(row as isize * lhs_rs + 1 * lhs_cs);
134 let dst = dst
135 .wrapping_offset(row as isize * dst_rs)
136 .wrapping_offset(col as isize * dst_cs);
137
138 *dst = mul_add(lhs1, rhs1, mul_add(lhs0, rhs0, alpha * *dst));
139 }
140 }
141 }
142 } else {
143 for col in 0..n {
144 let rhs0 =
145 beta * *rhs.wrapping_offset(col as isize * rhs_cs + 0 * rhs_rs);
146 let rhs1 =
147 beta * *rhs.wrapping_offset(col as isize * rhs_cs + 1 * rhs_rs);
148 for row in 0..m {
149 let lhs0 = *lhs.wrapping_offset(row as isize * lhs_rs + 0 * lhs_cs);
150 let lhs1 = *lhs.wrapping_offset(row as isize * lhs_rs + 1 * lhs_cs);
151 let dst = dst
152 .wrapping_offset(row as isize * dst_rs)
153 .wrapping_offset(col as isize * dst_cs);
154
155 *dst = mul_add(lhs1, rhs1, lhs0 * rhs0);
156 }
157 }
158 }
159 return;
160 }
161 _ => unreachable!(),
162 }
163 };
164 }
165 do_work!()
166}