gemm_common/
pack_operands.rs

1use crate::simd::Simd;
2
/// Overwrites every slot of `slice` with the all-zero byte pattern.
///
/// Lengths from 1 to 64 are dispatched to a store of a fixed-size array so
/// that each arm has a compile-time-known size; every other length
/// (including 0) is handled by a plain element-wise loop.
///
/// The zeros are always written as `MaybeUninit<T>` values, never as `T`
/// values. Materializing a zeroed `T` (or `[T; N]`) is invalid for types such
/// as references or `NonZero*` integers, whereas a zeroed `MaybeUninit<T>` is
/// valid for every `T`, so this function is sound for any `T: Copy`.
/// Whether the zeroed slots may later be read as `T` remains the caller's
/// responsibility.
#[inline(always)]
pub fn quick_zero<T: Copy>(slice: &mut [core::mem::MaybeUninit<T>]) {
    let len = slice.len();
    let base = slice.as_mut_ptr();

    // Generates one `match` arm per listed length plus the generic fallback.
    macro_rules! zero_by_len {
        ($($count:literal)+) => {
            match len {
                $(
                    // SAFETY: `base` points to exactly `len == $count`
                    // contiguous, writable `MaybeUninit<T>` slots borrowed
                    // mutably through `slice`, an array has the alignment of
                    // its element type, and an all-zero
                    // `[MaybeUninit<T>; $count]` is valid for any `T`.
                    $count => unsafe {
                        base.cast::<[core::mem::MaybeUninit<T>; $count]>()
                            .write(core::mem::zeroed())
                    },
                )+
                // Empty slices and anything longer than the largest fast path.
                _ => {
                    for slot in slice {
                        *slot = core::mem::MaybeUninit::zeroed();
                    }
                }
            }
        };
    }

    zero_by_len!(
        1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
        17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
        33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
        49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
    );
}
78
/// Copies `n` elements of type `T` from `src` to `dst`.
///
/// Counts from 1 to 64 are dispatched to a fixed-size array load followed by
/// a fixed-size array store; every other count (including 0) goes through
/// `core::ptr::copy_nonoverlapping`.
///
/// # Safety
///
/// `src` must be valid for reading `n` properly aligned, initialized values
/// of `T`, `dst` must be valid for writing `n` properly aligned values of
/// `T`, and the two regions must not overlap.
#[inline(always)]
unsafe fn quick_copy<T: Copy>(dst: *mut T, src: *const T, n: usize) {
    // Generates one `match` arm per listed count plus the generic fallback.
    macro_rules! copy_by_len {
        ($($count:literal)+) => {
            match n {
                $(
                    // The whole `[T; $count]` is loaded before it is stored.
                    $count => unsafe {
                        dst.cast::<[T; $count]>()
                            .write(src.cast::<[T; $count]>().read())
                    },
                )+
                _ => unsafe { core::ptr::copy_nonoverlapping(src, dst, n) },
            }
        };
    }

    copy_by_len!(
        1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
        17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
        33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
        49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
    );
}
149
150#[inline(always)]
151unsafe fn pack_generic_inner_loop<T: Copy, const N: usize, const DST_WIDTH: usize>(
152    mut dst: *mut T,
153    mut src: *const T,
154    src_rs: isize,
155    src_cs: isize,
156    src_width: usize,
157    k: usize,
158) {
159    if src_width == DST_WIDTH {
160        if src_rs == 1 {
161            for _ in 0..k {
162                let val = (src as *const [T; DST_WIDTH]).read();
163                (dst as *mut [T; DST_WIDTH]).write(val);
164
165                src = src.wrapping_offset(src_cs);
166                dst = dst.add(DST_WIDTH);
167            }
168        } else {
169            for _ in 0..k {
170                for j in 0..DST_WIDTH {
171                    *dst.add(j) = *src.offset(j as isize * src_rs);
172                }
173                src = src.wrapping_offset(src_cs);
174                dst = dst.add(DST_WIDTH);
175            }
176        }
177    } else if src_width == N {
178        if src_rs == 1 {
179            for _ in 0..k {
180                let val = (src as *const [T; N]).read();
181                (dst as *mut [T; N]).write(val);
182
183                src = src.wrapping_offset(src_cs);
184                dst = dst.add(DST_WIDTH);
185            }
186        } else {
187            for _ in 0..k {
188                for j in 0..N {
189                    *dst.add(j) = *src.offset(j as isize * src_rs);
190                }
191                src = src.wrapping_offset(src_cs);
192                dst = dst.add(DST_WIDTH);
193            }
194        }
195    } else if src_width == 2 * N {
196        if src_rs == 1 {
197            for _ in 0..k {
198                let val0 = (src as *const [T; N]).read();
199                let val1 = (src.add(N) as *const [T; N]).read();
200                (dst as *mut [T; N]).write(val0);
201                (dst.add(N) as *mut [T; N]).write(val1);
202
203                src = src.wrapping_offset(src_cs);
204                dst = dst.add(DST_WIDTH);
205            }
206        } else {
207            for _ in 0..k {
208                for j in 0..2 * N {
209                    *dst.add(j) = *src.offset(j as isize * src_rs);
210                }
211                src = src.wrapping_offset(src_cs);
212                dst = dst.add(DST_WIDTH);
213            }
214        }
215    } else {
216        if src_rs == 1 {
217            for _ in 0..k {
218                quick_copy(dst, src, src_width);
219                quick_zero::<T>(core::slice::from_raw_parts_mut(
220                    dst.add(src_width) as _,
221                    DST_WIDTH - src_width,
222                ));
223                src = src.wrapping_offset(src_cs);
224                dst = dst.add(DST_WIDTH);
225            }
226        } else {
227            for _ in 0..k {
228                for j in 0..src_width {
229                    *dst.add(j) = *src.offset(j as isize * src_rs);
230                }
231                quick_zero::<T>(core::slice::from_raw_parts_mut(
232                    dst.add(src_width) as _,
233                    DST_WIDTH - src_width,
234                ));
235                src = src.wrapping_offset(src_cs);
236                dst = dst.add(DST_WIDTH);
237            }
238        }
239    }
240}
241
242#[inline(always)]
243unsafe fn pack_generic<T: Copy, const N: usize, const DST_WIDTH: usize>(
244    m: usize,
245    k: usize,
246    mut dst: *mut T,
247    mut src: *const T,
248    src_cs: isize,
249    src_rs: isize,
250    dst_stride: usize,
251) {
252    let m_width = m / DST_WIDTH * DST_WIDTH;
253
254    let mut i = 0;
255    while i < m_width {
256        pack_generic_inner_loop::<_, N, DST_WIDTH>(dst, src, src_rs, src_cs, DST_WIDTH, k);
257        src = src.wrapping_offset(src_rs * DST_WIDTH as isize);
258        dst = dst.add(dst_stride);
259
260        i += DST_WIDTH;
261    }
262    if i < m {
263        pack_generic_inner_loop::<_, N, DST_WIDTH>(dst, src, src_rs, src_cs, m - i, k);
264    }
265}
266
267#[inline(never)]
268pub unsafe fn pack_lhs<T: Copy, const N: usize, const MR: usize, S: Simd>(
269    _: S,
270    m: usize,
271    k: usize,
272    dst: crate::Ptr<T>,
273    src: crate::Ptr<T>,
274    src_cs: isize,
275    src_rs: isize,
276    dst_stride: usize,
277) {
278    let dst = dst.0;
279    let src = src.0;
280    S::vectorize(
281        #[inline(always)]
282        || pack_generic::<T, N, MR>(m, k, dst, src, src_cs, src_rs, dst_stride),
283    );
284}
285
286#[inline(never)]
287pub unsafe fn pack_rhs<T: Copy, const N: usize, const NR: usize, S: Simd>(
288    _: S,
289    n: usize,
290    k: usize,
291    dst: crate::Ptr<T>,
292    src: crate::Ptr<T>,
293    src_cs: isize,
294    src_rs: isize,
295    dst_stride: usize,
296) {
297    let dst = dst.0;
298    let src = src.0;
299    S::vectorize(
300        #[inline(always)]
301        || pack_generic::<T, N, NR>(n, k, dst, src, src_rs, src_cs, dst_stride),
302    );
303}