1use crate::simd::Simd;
2
/// Zero-fills every element of `slice`.
///
/// Lengths `1..=64` are written with a single store of a fixed-size array, so the compiler
/// sees a constant-length fill. Every other length, including `0`, falls back to an
/// element-by-element loop.
///
/// The zero value is always built at type `MaybeUninit<T>`, never at type `T`. This makes the
/// function sound for every `T: Copy`, including types such as references or `NonZero*` for
/// which an all-zero bit pattern is not a valid `T`. (Materialising a zeroed `[T; n]` for such a
/// `T` is undefined behaviour, and current rustc panics on it.)
#[inline(always)]
pub fn quick_zero<T: Copy>(slice: &mut [core::mem::MaybeUninit<T>]) {
    // Expands to a `match` on the slice length with one fixed-size arm per listed length.
    macro_rules! zero_fixed_len {
        ($buf:ident; $($len:literal)*) => {
            match $buf.len() {
                $(
                    // SAFETY: in this arm the slice holds exactly `$len` elements, so its
                    // pointer is valid and aligned for one `[MaybeUninit<T>; $len]`, which
                    // has the same layout as the slice contents. All-zero bytes are a valid
                    // `MaybeUninit<T>` for any `T`.
                    $len => unsafe {
                        $buf.as_mut_ptr()
                            .cast::<[core::mem::MaybeUninit<T>; $len]>()
                            .write(core::mem::zeroed())
                    },
                )*
                _ => {
                    for value in $buf.iter_mut() {
                        *value = core::mem::MaybeUninit::zeroed();
                    }
                }
            }
        };
    }

    zero_fixed_len!(slice;
        1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
        17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
        33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
        49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
    );
}
78
/// Copies `n` consecutive elements of `T` from `src` to `dst`.
///
/// Lengths `1..=64` are moved as one fixed-size `[T; n]` value, so the compiler can emit a
/// constant-length copy. Every other length, including `0`, is delegated to
/// `core::ptr::copy_nonoverlapping`.
///
/// # Safety
/// - `src` must be non-null, aligned for `T` and valid for reading `n` elements.
/// - `dst` must be non-null, aligned for `T` and valid for writing `n` elements.
/// - When `n > 64`, the two ranges must not overlap, as `copy_nonoverlapping` requires.
#[inline(always)]
unsafe fn quick_copy<T: Copy>(dst: *mut T, src: *const T, n: usize) {
    // Expands to a `match` on the length with one fixed-size arm per listed length.
    macro_rules! copy_fixed_len {
        ($to:ident, $from:ident, $count:ident; $($len:literal)*) => {
            match $count {
                $(
                    $len => unsafe {
                        $to.cast::<[T; $len]>().write($from.cast::<[T; $len]>().read())
                    },
                )*
                _ => unsafe { core::ptr::copy_nonoverlapping($from, $to, $count) },
            }
        };
    }

    copy_fixed_len!(dst, src, n;
        1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
        17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
        33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
        49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
    );
}
149
150#[inline(always)]
151unsafe fn pack_generic_inner_loop<T: Copy, const N: usize, const DST_WIDTH: usize>(
152 mut dst: *mut T,
153 mut src: *const T,
154 src_rs: isize,
155 src_cs: isize,
156 src_width: usize,
157 k: usize,
158) {
159 if src_width == DST_WIDTH {
160 if src_rs == 1 {
161 for _ in 0..k {
162 let val = (src as *const [T; DST_WIDTH]).read();
163 (dst as *mut [T; DST_WIDTH]).write(val);
164
165 src = src.wrapping_offset(src_cs);
166 dst = dst.add(DST_WIDTH);
167 }
168 } else {
169 for _ in 0..k {
170 for j in 0..DST_WIDTH {
171 *dst.add(j) = *src.offset(j as isize * src_rs);
172 }
173 src = src.wrapping_offset(src_cs);
174 dst = dst.add(DST_WIDTH);
175 }
176 }
177 } else if src_width == N {
178 if src_rs == 1 {
179 for _ in 0..k {
180 let val = (src as *const [T; N]).read();
181 (dst as *mut [T; N]).write(val);
182
183 src = src.wrapping_offset(src_cs);
184 dst = dst.add(DST_WIDTH);
185 }
186 } else {
187 for _ in 0..k {
188 for j in 0..N {
189 *dst.add(j) = *src.offset(j as isize * src_rs);
190 }
191 src = src.wrapping_offset(src_cs);
192 dst = dst.add(DST_WIDTH);
193 }
194 }
195 } else if src_width == 2 * N {
196 if src_rs == 1 {
197 for _ in 0..k {
198 let val0 = (src as *const [T; N]).read();
199 let val1 = (src.add(N) as *const [T; N]).read();
200 (dst as *mut [T; N]).write(val0);
201 (dst.add(N) as *mut [T; N]).write(val1);
202
203 src = src.wrapping_offset(src_cs);
204 dst = dst.add(DST_WIDTH);
205 }
206 } else {
207 for _ in 0..k {
208 for j in 0..2 * N {
209 *dst.add(j) = *src.offset(j as isize * src_rs);
210 }
211 src = src.wrapping_offset(src_cs);
212 dst = dst.add(DST_WIDTH);
213 }
214 }
215 } else {
216 if src_rs == 1 {
217 for _ in 0..k {
218 quick_copy(dst, src, src_width);
219 quick_zero::<T>(core::slice::from_raw_parts_mut(
220 dst.add(src_width) as _,
221 DST_WIDTH - src_width,
222 ));
223 src = src.wrapping_offset(src_cs);
224 dst = dst.add(DST_WIDTH);
225 }
226 } else {
227 for _ in 0..k {
228 for j in 0..src_width {
229 *dst.add(j) = *src.offset(j as isize * src_rs);
230 }
231 quick_zero::<T>(core::slice::from_raw_parts_mut(
232 dst.add(src_width) as _,
233 DST_WIDTH - src_width,
234 ));
235 src = src.wrapping_offset(src_cs);
236 dst = dst.add(DST_WIDTH);
237 }
238 }
239 }
240}
241
242#[inline(always)]
243unsafe fn pack_generic<T: Copy, const N: usize, const DST_WIDTH: usize>(
244 m: usize,
245 k: usize,
246 mut dst: *mut T,
247 mut src: *const T,
248 src_cs: isize,
249 src_rs: isize,
250 dst_stride: usize,
251) {
252 let m_width = m / DST_WIDTH * DST_WIDTH;
253
254 let mut i = 0;
255 while i < m_width {
256 pack_generic_inner_loop::<_, N, DST_WIDTH>(dst, src, src_rs, src_cs, DST_WIDTH, k);
257 src = src.wrapping_offset(src_rs * DST_WIDTH as isize);
258 dst = dst.add(dst_stride);
259
260 i += DST_WIDTH;
261 }
262 if i < m {
263 pack_generic_inner_loop::<_, N, DST_WIDTH>(dst, src, src_rs, src_cs, m - i, k);
264 }
265}
266
267#[inline(never)]
268pub unsafe fn pack_lhs<T: Copy, const N: usize, const MR: usize, S: Simd>(
269 _: S,
270 m: usize,
271 k: usize,
272 dst: crate::Ptr<T>,
273 src: crate::Ptr<T>,
274 src_cs: isize,
275 src_rs: isize,
276 dst_stride: usize,
277) {
278 let dst = dst.0;
279 let src = src.0;
280 S::vectorize(
281 #[inline(always)]
282 || pack_generic::<T, N, MR>(m, k, dst, src, src_cs, src_rs, dst_stride),
283 );
284}
285
286#[inline(never)]
287pub unsafe fn pack_rhs<T: Copy, const N: usize, const NR: usize, S: Simd>(
288 _: S,
289 n: usize,
290 k: usize,
291 dst: crate::Ptr<T>,
292 src: crate::Ptr<T>,
293 src_cs: isize,
294 src_rs: isize,
295 dst_stride: usize,
296) {
297 let dst = dst.0;
298 let src = src.0;
299 S::vectorize(
300 #[inline(always)]
301 || pack_generic::<T, N, NR>(n, k, dst, src, src_rs, src_cs, dst_stride),
302 );
303}