1use provable_contracts_macros::requires;
19
20#[cfg(target_arch = "x86_64")]
21use std::arch::x86_64::*;
22
/// Transposes the row-major `rows × cols` matrix `a` into `b` using scalar code.
///
/// The element at `a[r * cols + c]` is written to `b[c * rows + r]`, so `b` ends up
/// holding the row-major `cols × rows` transpose. The region made of full 8×8 tiles is
/// copied tile by tile; the columns and rows past the last full tile are copied after.
///
/// # Panics
///
/// Panics if `a.len()` or `b.len()` differs from `rows * cols`.
pub fn transpose_scalar(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
    const TILE: usize = 8;

    /// Copies the rectangle `row_span × col_span` of `a` into its transposed
    /// position in `b`, walking the source row by row.
    #[inline]
    fn copy_rect(
        a: &[f32],
        b: &mut [f32],
        rows: usize,
        cols: usize,
        row_span: std::ops::Range<usize>,
        col_span: std::ops::Range<usize>,
    ) {
        for r in row_span {
            let row_start = r * cols;
            for c in col_span.clone() {
                b[c * rows + r] = a[row_start + c];
            }
        }
    }

    assert_eq!(a.len(), rows * cols, "a length mismatch");
    assert_eq!(b.len(), rows * cols, "b length mismatch");

    // Extent of the area that is covered by whole tiles.
    let tiled_rows = rows - rows % TILE;
    let tiled_cols = cols - cols % TILE;

    // Full tiles: keeping both the source rows and destination rows of a tile
    // close together is the point of the blocking.
    for tile_r in (0..tiled_rows).step_by(TILE) {
        for tile_c in (0..tiled_cols).step_by(TILE) {
            copy_rect(a, b, rows, cols, tile_r..tile_r + TILE, tile_c..tile_c + TILE);
        }
    }

    // Leftover columns to the right of the tiled area (empty when `cols` is a
    // multiple of the tile size).
    copy_rect(a, b, rows, cols, 0..tiled_rows, tiled_cols..cols);

    // Leftover rows below the tiled area, across the full width (empty when
    // `rows` is a multiple of the tile size).
    copy_rect(a, b, rows, cols, tiled_rows..rows, 0..cols);
}
76
#[cfg(target_arch = "x86_64")]
/// Transposes a single 8×8 tile of `f32` values using 256-bit vector intrinsics.
///
/// Eight rows of eight floats are loaded starting at `src`, with consecutive rows
/// `src_stride` elements apart. The transposed tile is stored starting at `dst`, with
/// consecutive output rows `dst_stride` elements apart. Loads and stores are the
/// unaligned variants.
///
/// # Safety
///
/// - The CPU executing this function must support AVX2.
/// - For every `i` in `0..8`, `src.add(i * src_stride)` must be valid for reading
///   eight `f32` values and `dst.add(i * dst_stride)` must be valid for writing
///   eight `f32` values.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn transpose_8x8_avx2(src: *const f32, src_stride: usize, dst: *mut f32, dst_stride: usize) {
    // SAFETY: the caller guarantees that all eight source rows are readable and
    // all eight destination rows are writable (see `# Safety`).
    unsafe {
        // Load the eight source rows.
        let row0 = _mm256_loadu_ps(src);
        let row1 = _mm256_loadu_ps(src.add(src_stride));
        let row2 = _mm256_loadu_ps(src.add(src_stride * 2));
        let row3 = _mm256_loadu_ps(src.add(src_stride * 3));
        let row4 = _mm256_loadu_ps(src.add(src_stride * 4));
        let row5 = _mm256_loadu_ps(src.add(src_stride * 5));
        let row6 = _mm256_loadu_ps(src.add(src_stride * 6));
        let row7 = _mm256_loadu_ps(src.add(src_stride * 7));

        // Stage 1: interleave neighbouring rows element-wise within each
        // 128-bit lane (`lo` = elements 0,1 / 4,5; `hi` = elements 2,3 / 6,7).
        let lo01 = _mm256_unpacklo_ps(row0, row1);
        let hi01 = _mm256_unpackhi_ps(row0, row1);
        let lo23 = _mm256_unpacklo_ps(row2, row3);
        let hi23 = _mm256_unpackhi_ps(row2, row3);
        let lo45 = _mm256_unpacklo_ps(row4, row5);
        let hi45 = _mm256_unpackhi_ps(row4, row5);
        let lo67 = _mm256_unpacklo_ps(row6, row7);
        let hi67 = _mm256_unpackhi_ps(row6, row7);

        // Stage 2: combine row pairs so each 128-bit lane holds four rows of
        // one source column. `top_cXY` carries rows 0..4 of column X in the low
        // lane and of column Y in the high lane; `bot_cXY` does the same for
        // rows 4..8.
        let top_c04 = _mm256_shuffle_ps(lo01, lo23, 0x44);
        let top_c15 = _mm256_shuffle_ps(lo01, lo23, 0xEE);
        let top_c26 = _mm256_shuffle_ps(hi01, hi23, 0x44);
        let top_c37 = _mm256_shuffle_ps(hi01, hi23, 0xEE);
        let bot_c04 = _mm256_shuffle_ps(lo45, lo67, 0x44);
        let bot_c15 = _mm256_shuffle_ps(lo45, lo67, 0xEE);
        let bot_c26 = _mm256_shuffle_ps(hi45, hi67, 0x44);
        let bot_c37 = _mm256_shuffle_ps(hi45, hi67, 0xEE);

        // Stage 3: stitch matching 128-bit lanes together. 0x20 takes the low
        // lanes of both operands (columns 0..4), 0x31 the high lanes
        // (columns 4..8), giving one full source column per vector.
        let col0 = _mm256_permute2f128_ps(top_c04, bot_c04, 0x20);
        let col1 = _mm256_permute2f128_ps(top_c15, bot_c15, 0x20);
        let col2 = _mm256_permute2f128_ps(top_c26, bot_c26, 0x20);
        let col3 = _mm256_permute2f128_ps(top_c37, bot_c37, 0x20);
        let col4 = _mm256_permute2f128_ps(top_c04, bot_c04, 0x31);
        let col5 = _mm256_permute2f128_ps(top_c15, bot_c15, 0x31);
        let col6 = _mm256_permute2f128_ps(top_c26, bot_c26, 0x31);
        let col7 = _mm256_permute2f128_ps(top_c37, bot_c37, 0x31);

        // Each source column becomes one destination row.
        _mm256_storeu_ps(dst, col0);
        _mm256_storeu_ps(dst.add(dst_stride), col1);
        _mm256_storeu_ps(dst.add(dst_stride * 2), col2);
        _mm256_storeu_ps(dst.add(dst_stride * 3), col3);
        _mm256_storeu_ps(dst.add(dst_stride * 4), col4);
        _mm256_storeu_ps(dst.add(dst_stride * 5), col5);
        _mm256_storeu_ps(dst.add(dst_stride * 6), col6);
        _mm256_storeu_ps(dst.add(dst_stride * 7), col7);
    }
}
147
148#[cfg(target_arch = "x86_64")]
161#[target_feature(enable = "avx2")]
162pub unsafe fn transpose_avx2(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
163 assert_eq!(a.len(), rows * cols, "a length mismatch");
164 assert_eq!(b.len(), rows * cols, "b length mismatch");
165
166 let rb_end = rows / 8 * 8;
167 let cb_end = cols / 8 * 8;
168
169 unsafe {
171 for r0 in (0..rb_end).step_by(8) {
173 for c0 in (0..cb_end).step_by(8) {
174 let src = a.as_ptr().add(r0 * cols + c0);
175 let dst = b.as_mut_ptr().add(c0 * rows + r0);
176 transpose_8x8_avx2(src, cols, dst, rows);
177 }
178 }
179 }
180
181 if cb_end < cols {
183 for r in 0..rb_end {
184 let src_base = r * cols;
185 for c in cb_end..cols {
186 b[c * rows + r] = a[src_base + c];
187 }
188 }
189 }
190
191 if rb_end < rows {
193 for r in rb_end..rows {
194 let src_base = r * cols;
195 for c in 0..cols {
196 b[c * rows + r] = a[src_base + c];
197 }
198 }
199 }
200}
201
202#[requires(a.len() == rows * cols && b.len() == rows * cols)]
212pub fn transpose(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
213 #[cfg(target_arch = "x86_64")]
214 {
215 if is_x86_feature_detected!("avx2") {
216 unsafe {
218 transpose_avx2(rows, cols, a, b);
219 }
220 return;
221 }
222 }
223 transpose_scalar(rows, cols, a, b);
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `len`-element vector whose `i`-th entry is `value_at(i)`.
    fn filled(len: usize, value_at: impl Fn(usize) -> f32) -> Vec<f32> {
        (0..len).map(value_at).collect()
    }

    /// Builds a zero-initialised output buffer of `len` elements.
    fn zeros(len: usize) -> Vec<f32> {
        vec![0.0f32; len]
    }

    /// Reference transpose used as ground truth: walks the source in row-major
    /// order and writes each element to its transposed position.
    fn transpose_naive(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
        for flat in 0..rows * cols {
            let (i, j) = (flat / cols, flat % cols);
            b[j * rows + i] = a[flat];
        }
    }

    /// Every output element must equal the source element at the mirrored index.
    #[test]
    fn falsify_tp_001_element_correctness() {
        let shapes = [(4, 5), (8, 8), (16, 32), (31, 17), (64, 64)];
        for (rows, cols) in shapes {
            let a = filled(rows * cols, |i| i as f32);
            let mut b = zeros(rows * cols);
            transpose(rows, cols, &a, &mut b);

            for (flat, &expected) in a.iter().enumerate() {
                let i = flat / cols;
                let j = flat % cols;
                assert_eq!(
                    b[j * rows + i],
                    expected,
                    "Mismatch at ({i},{j}) for {rows}×{cols}"
                );
            }
        }
    }

    /// Transposing twice must give back the original matrix.
    #[test]
    fn falsify_tp_002_involution() {
        let shapes = [(7, 13), (16, 16), (33, 17), (64, 128)];
        for (rows, cols) in shapes {
            let a = filled(rows * cols, |i| (i as f32) * 0.1 + 0.37);
            let mut once = zeros(rows * cols);
            let mut twice = zeros(rows * cols);

            transpose(rows, cols, &a, &mut once);
            transpose(cols, rows, &once, &mut twice);

            assert_eq!(a, twice, "Involution failed for {rows}×{cols}");
        }
    }

    /// Shapes that are not multiples of the tile size must match the reference.
    #[test]
    fn falsify_tp_003_non_aligned() {
        let shapes = [(7, 13), (17, 3), (1, 32), (32, 1), (1, 1), (3, 3)];
        for (rows, cols) in shapes {
            let a = filled(rows * cols, |i| i as f32);
            let mut b_test = zeros(rows * cols);
            let mut b_ref = zeros(rows * cols);

            transpose(rows, cols, &a, &mut b_test);
            transpose_naive(rows, cols, &a, &mut b_ref);

            assert_eq!(b_test, b_ref, "Mismatch for {rows}×{cols}");
        }
    }

    /// The dispatched implementation must agree with the scalar one.
    #[test]
    fn falsify_tp_004_avx2_scalar_parity() {
        let (rows, cols) = (2048, 128);
        let a = filled(rows * cols, |i| (i as f32) * 0.001);
        let mut b_scalar = zeros(rows * cols);
        let mut b_dispatch = zeros(rows * cols);

        transpose_scalar(rows, cols, &a, &mut b_scalar);
        transpose(rows, cols, &a, &mut b_dispatch);

        assert_eq!(b_scalar, b_dispatch, "AVX2 vs scalar mismatch at 2048×128");
    }

    /// The identity matrix is its own transpose.
    #[test]
    fn falsify_tp_005_identity() {
        for n in [4, 8, 16, 32] {
            let mut a = zeros(n * n);
            // Diagonal entries sit `n + 1` elements apart in row-major order.
            a.iter_mut().step_by(n + 1).for_each(|diag| *diag = 1.0);

            let mut b = zeros(n * n);
            transpose(n, n, &a, &mut b);
            assert_eq!(a, b, "Identity matrix not preserved for {n}×{n}");
        }
    }

    /// A 2048×128 matrix with non-monotonic contents must match the reference.
    #[test]
    fn falsify_tp_006_attention_shape() {
        let (rows, cols) = (2048, 128);
        let a = filled(rows * cols, |i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5);
        let mut b_test = zeros(rows * cols);
        let mut b_ref = zeros(rows * cols);

        transpose(rows, cols, &a, &mut b_test);
        transpose_naive(rows, cols, &a, &mut b_ref);

        assert_eq!(b_test, b_ref, "Attention shape 2048×128 mismatch");
    }

    /// Exercises the scalar implementation's leftover-row and leftover-column paths.
    #[test]
    fn scalar_remainder_paths() {
        let shapes = [(3, 5), (10, 13), (15, 9), (7, 7)];
        for (rows, cols) in shapes {
            let a = filled(rows * cols, |i| i as f32);
            let mut b_scalar = zeros(rows * cols);
            let mut b_ref = zeros(rows * cols);

            transpose_scalar(rows, cols, &a, &mut b_scalar);
            transpose_naive(rows, cols, &a, &mut b_ref);

            assert_eq!(b_scalar, b_ref, "Scalar mismatch for {rows}×{cols}");
        }
    }
}