use crate::error::TruenoError;

#[inline(always)]
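/// Scalar transpose of a sub-rectangle: source rows `rows` and columns
/// `cols` of row-major `a` (row stride `src_cols`) are written to their
/// transposed positions in `b` (row stride `dst_rows`).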
fn transpose_region(
    a: &[f32],
    b: &mut [f32],
    rows: std::ops::Range<usize>,
    cols: std::ops::Range<usize>,
    src_cols: usize,
    dst_rows: usize,
) {
    for r in rows {
        let src_base = r * src_cols;
        for c in cols.clone() {
            b[c * dst_rows + r] = a[src_base + c];
        }
    }
}

#[cfg(target_arch = "x86_64")]
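/// In-register 8x8 `f32` transpose using AVX2: three shuffle stages
/// (32-bit unpack, 4-wide shuffle, 128-bit lane permute) turn eight row
/// vectors into eight column vectors without intermediate stores.
///
/// # Safety
/// Caller must ensure AVX2 is available and that `src`/`dst` are valid for
/// eight rows of eight `f32`s at strides `src_stride`/`dst_stride`.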
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn transpose_8x8_avx2(src: *const f32, src_stride: usize, dst: *mut f32, dst_stride: usize) {
    unsafe {
        use std::arch::x86_64::*;

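        // Load eight rows of eight consecutive f32s (unaligned).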
        let r0 = _mm256_loadu_ps(src);
        let r1 = _mm256_loadu_ps(src.add(src_stride));
        let r2 = _mm256_loadu_ps(src.add(src_stride * 2));
        let r3 = _mm256_loadu_ps(src.add(src_stride * 3));
        let r4 = _mm256_loadu_ps(src.add(src_stride * 4));
        let r5 = _mm256_loadu_ps(src.add(src_stride * 5));
        let r6 = _mm256_loadu_ps(src.add(src_stride * 6));
        let r7 = _mm256_loadu_ps(src.add(src_stride * 7));

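        // Interleave adjacent row pairs 32 bits at a time within each
        // 128-bit lane.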
        let t0 = _mm256_unpacklo_ps(r0, r1);
        let t1 = _mm256_unpackhi_ps(r0, r1);
        let t2 = _mm256_unpacklo_ps(r2, r3);
        let t3 = _mm256_unpackhi_ps(r2, r3);
        let t4 = _mm256_unpacklo_ps(r4, r5);
        let t5 = _mm256_unpackhi_ps(r4, r5);
        let t6 = _mm256_unpacklo_ps(r6, r7);
        let t7 = _mm256_unpackhi_ps(r6, r7);

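        // Gather 4-wide column groups from the interleaved pairs.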
        let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
        let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
        let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
        let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
        let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
        let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
        let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
        let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

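        // Swap 128-bit halves to assemble the eight transposed rows.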
        let v0 = _mm256_permute2f128_ps(u0, u4, 0x20);
        let v1 = _mm256_permute2f128_ps(u1, u5, 0x20);
        let v2 = _mm256_permute2f128_ps(u2, u6, 0x20);
        let v3 = _mm256_permute2f128_ps(u3, u7, 0x20);
        let v4 = _mm256_permute2f128_ps(u0, u4, 0x31);
        let v5 = _mm256_permute2f128_ps(u1, u5, 0x31);
        let v6 = _mm256_permute2f128_ps(u2, u6, 0x31);
        let v7 = _mm256_permute2f128_ps(u3, u7, 0x31);

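        // Store the transposed rows (unaligned).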
        _mm256_storeu_ps(dst, v0);
        _mm256_storeu_ps(dst.add(dst_stride), v1);
        _mm256_storeu_ps(dst.add(dst_stride * 2), v2);
        _mm256_storeu_ps(dst.add(dst_stride * 3), v3);
        _mm256_storeu_ps(dst.add(dst_stride * 4), v4);
        _mm256_storeu_ps(dst.add(dst_stride * 5), v5);
        _mm256_storeu_ps(dst.add(dst_stride * 6), v6);
        _mm256_storeu_ps(dst.add(dst_stride * 7), v7);
    }
}

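/// Transposes the row-major `rows`×`cols` matrix `a` into `b`, which
/// receives the `cols`×`rows` result. Returns `InvalidInput` if either
/// slice's length differs from `rows * cols`.
///
/// Dispatch: inputs under 64 elements take a plain scalar loop; with the
/// `parallel` feature, inputs of at least one million elements fan out
/// across threads (on x86_64 only when AVX2 is present); otherwise the
/// blocked AVX2 path runs where available, with a scalar blocked fallback.
///
/// A minimal usage sketch (the `trueno::` path is assumed from the error
/// type; adjust to the crate's real module layout):
///
/// ```ignore
/// let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3, row-major
/// let mut b = vec![0.0f32; 6];
/// trueno::transpose(2, 3, &a, &mut b).unwrap();
/// assert_eq!(b, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); // 3x2
/// ```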
pub fn transpose(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) -> Result<(), TruenoError> {
    debug_assert!(!a.is_empty(), "Contract transpose: input is empty");
    debug_assert!(rows > 0 && cols > 0, "Contract transpose: zero dimensions");
    let expected = rows * cols;
    if a.len() != expected || b.len() != expected {
        return Err(TruenoError::InvalidInput(format!(
            "transpose size mismatch: a[{}], b[{}], expected {}",
            a.len(),
            b.len(),
            expected
        )));
    }

    if expected < 64 {
        transpose_region(a, b, 0..rows, 0..cols, cols, rows);
        return Ok(());
    }

    #[cfg(feature = "parallel")]
    {
        const PARALLEL_THRESHOLD: usize = 1_000_000;
        // On x86_64 the parallel strips are written by the AVX2 kernel, so
        // only take this path when the feature is actually present.
        #[cfg(target_arch = "x86_64")]
        let parallel_ok = is_x86_feature_detected!("avx2");
        #[cfg(not(target_arch = "x86_64"))]
        let parallel_ok = true;
        if parallel_ok && expected >= PARALLEL_THRESHOLD {
            return transpose_parallel(rows, cols, a, b);
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was verified at runtime just above.
            unsafe {
                return transpose_avx2_impl(rows, cols, a, b);
            }
        }
    }

    transpose_scalar_impl(rows, cols, a, b);
    Ok(())
}

#[cfg(feature = "parallel")]
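/// Splits the source into horizontal strips of rows and transposes each
/// strip on its own Rayon thread; every strip owns a disjoint set of output
/// positions, so the concurrent writes never alias.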
fn transpose_parallel(
    rows: usize,
    cols: usize,
    a: &[f32],
    b: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;

    let num_threads = rayon::current_num_threads().min(8);
    let rows_per = (rows + num_threads - 1) / num_threads;

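    // Share the output base pointer across threads as a `usize`: each
    // thread writes only rows r_start..r_end of every output column, so the
    // write regions are pairwise disjoint and never alias.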
    let b_ptr = b.as_mut_ptr() as usize;

    (0..num_threads).into_par_iter().try_for_each(|t| {
        let r_start = t * rows_per;
        let r_end = (r_start + rows_per).min(rows);
        if r_start >= r_end {
            return Ok::<(), TruenoError>(());
        }
        let sub_rows = r_end - r_start;

        let a_strip = &a[r_start * cols..r_end * cols];

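        // SAFETY: this thread's strip starts at output offset r_start and
        // is disjoint from every other thread's strip; see the note above.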
        unsafe {
            let b_ptr_mut = b_ptr as *mut f32;
            transpose_strided_avx2(sub_rows, cols, a_strip, b_ptr_mut.add(r_start), rows)?;
        }
        Ok(())
    })?;
    Ok(())
}

#[cfg(all(feature = "parallel", target_arch = "x86_64"))]
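/// Transposes a `sub_rows` x `cols` strip of `a` into `b_ptr`, whose output
/// rows are `b_stride` elements apart, using 8x8 AVX2 blocks inside 64-wide
/// cache tiles; leftover rows and columns take scalar stores.
///
/// # Safety
/// Caller must guarantee AVX2 support and that `b_ptr` is valid for writes
/// at every offset `c * b_stride + r` with `c < cols` and `r < sub_rows`.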
unsafe fn transpose_strided_avx2(
    sub_rows: usize,
    cols: usize,
    a: &[f32],
    b_ptr: *mut f32,
    b_stride: usize,
) -> Result<(), TruenoError> {
    const TILE: usize = 64;
    const BLOCK: usize = 8;
    let rb_end = sub_rows / BLOCK * BLOCK;
    let cb_end = cols / BLOCK * BLOCK;

    unsafe {
        for rt in (0..rb_end).step_by(TILE) {
            let rt_end = (rt + TILE).min(rb_end);
            for ct in (0..cb_end).step_by(TILE) {
                let ct_end = (ct + TILE).min(cb_end);
                for r0 in (rt..rt_end).step_by(BLOCK) {
                    for c0 in (ct..ct_end).step_by(BLOCK) {
                        let src = a.as_ptr().add(r0 * cols + c0);
                        let dst = b_ptr.add(c0 * b_stride + r0);
                        transpose_8x8_avx2(src, cols, dst, b_stride);
                    }
                }
            }
        }
        // Remainder rows below the last full 8-row block.
        if rb_end < sub_rows {
            for r in rb_end..sub_rows {
                for c in 0..cols {
                    *b_ptr.add(c * b_stride + r) = *a.get_unchecked(r * cols + c);
                }
            }
        }
        // Remainder columns right of the last full 8-column block.
        if cb_end < cols {
            for r in 0..rb_end {
                for c in cb_end..cols {
                    *b_ptr.add(c * b_stride + r) = *a.get_unchecked(r * cols + c);
                }
            }
        }
    }
    Ok(())
}

#[cfg(all(feature = "parallel", not(target_arch = "x86_64")))]
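/// Portable fallback under the same name for non-x86_64 targets: a plain
/// scalar strip transpose with identical semantics.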
unsafe fn transpose_strided_avx2(
    sub_rows: usize,
    cols: usize,
    a: &[f32],
    b_ptr: *mut f32,
    b_stride: usize,
) -> Result<(), TruenoError> {
    unsafe {
        for r in 0..sub_rows {
            for c in 0..cols {
                *b_ptr.add(c * b_stride + r) = *a.get_unchecked(r * cols + c);
            }
        }
    }
    Ok(())
}

#[cfg(target_arch = "x86_64")]
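/// Cache-tiled AVX2 transpose of the whole matrix: 64x64 tiles of 8x8
/// register blocks. For tall-skinny inputs the loop order flips so the
/// strided destination writes advance sequentially, with software prefetch
/// on that side.
///
/// # Safety
/// Caller must have verified AVX2 support; `a` and `b` must each hold
/// exactly `rows * cols` elements.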
#[target_feature(enable = "avx2")]
unsafe fn transpose_avx2_impl(
    rows: usize,
    cols: usize,
    a: &[f32],
    b: &mut [f32],
) -> Result<(), TruenoError> {
    use std::arch::x86_64::*;

    const TILE: usize = 64;
    const BLOCK: usize = 8;
    let rb_end = rows / BLOCK * BLOCK;
    let cb_end = cols / BLOCK * BLOCK;

    let tall_skinny = rows >= 4 * cols;

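    // For matrices at least 4x taller than wide, iterating column blocks in
    // the outer loop keeps the strided destination writes marching forward
    // instead of bouncing across output rows.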
    unsafe {
        for rt in (0..rb_end).step_by(TILE) {
            let rt_end = (rt + TILE).min(rb_end);
            for ct in (0..cb_end).step_by(TILE) {
                let ct_end = (ct + TILE).min(cb_end);

                if tall_skinny {
                    for c0 in (ct..ct_end).step_by(BLOCK) {
                        for r0 in (rt..rt_end).step_by(BLOCK) {
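                            // Prefetch the destination of the upcoming 8x8
                            // block; the stores are the strided side here.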
                            if r0 + BLOCK < rt_end {
                                let pf_dst = b.as_ptr().add(c0 * rows + r0 + BLOCK);
                                _mm_prefetch::<_MM_HINT_T0>(pf_dst as *const i8);
                                _mm_prefetch::<_MM_HINT_T0>(pf_dst.add(rows) as *const i8);
                            }
                            let src = a.as_ptr().add(r0 * cols + c0);
                            let dst = b.as_mut_ptr().add(c0 * rows + r0);
                            transpose_8x8_avx2(src, cols, dst, rows);
                        }
                    }
                } else {
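                    // Wide or square: walk the source row-major instead so
                    // the loads stay sequential.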
                    for r0 in (rt..rt_end).step_by(BLOCK) {
                        for c0 in (ct..ct_end).step_by(BLOCK) {
                            let src = a.as_ptr().add(r0 * cols + c0);
                            let dst = b.as_mut_ptr().add(c0 * rows + r0);
                            transpose_8x8_avx2(src, cols, dst, rows);
                        }
                    }
                }
            }
        }
    }

    // Columns right of the last full 8-wide block.
    if cb_end < cols {
        transpose_region(a, b, 0..rb_end, cb_end..cols, cols, rows);
    }

    // Rows below the last full 8-tall block.
    if rb_end < rows {
        transpose_region(a, b, rb_end..rows, 0..cols, cols, rows);
    }

    Ok(())
}

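/// Cache-blocked scalar fallback: 8x8 tiles first, then the remaining
/// right-edge columns and bottom-edge rows.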
fn transpose_scalar_impl(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
    const BLOCK: usize = 8;
    let row_blocks = rows / BLOCK;
    let col_blocks = cols / BLOCK;

    for rb in 0..row_blocks {
        for cb in 0..col_blocks {
            let rs = rb * BLOCK;
            let cs = cb * BLOCK;
            transpose_region(a, b, rs..rs + BLOCK, cs..cs + BLOCK, cols, rows);
        }
    }

    let col_rem = col_blocks * BLOCK;
    if col_rem < cols {
        transpose_region(a, b, 0..row_blocks * BLOCK, col_rem..cols, cols, rows);
    }

    let row_rem = row_blocks * BLOCK;
    if row_rem < rows {
        transpose_region(a, b, row_rem..rows, 0..cols, cols, rows);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn transpose_naive(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
        for i in 0..rows {
            for j in 0..cols {
                b[j * rows + i] = a[i * cols + j];
            }
        }
    }

    #[test]
    fn test_element_correctness() {
        for (rows, cols) in [(4, 5), (8, 8), (16, 32), (31, 17), (64, 64)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
            let mut b = vec![0.0f32; rows * cols];
            transpose(rows, cols, &a, &mut b).unwrap();

            for i in 0..rows {
                for j in 0..cols {
                    assert_eq!(b[j * rows + i], a[i * cols + j], "({i},{j}) {rows}×{cols}");
                }
            }
        }
    }

    #[test]
    fn test_involution() {
        for (rows, cols) in [(7, 13), (16, 16), (33, 17), (64, 128)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.1 + 0.37).collect();
            let mut b = vec![0.0f32; rows * cols];
            let mut c = vec![0.0f32; rows * cols];

            transpose(rows, cols, &a, &mut b).unwrap();
            transpose(cols, rows, &b, &mut c).unwrap();

            assert_eq!(a, c, "Involution failed for {rows}×{cols}");
        }
    }

    #[test]
    fn test_non_aligned() {
        for (rows, cols) in [(7, 13), (17, 3), (1, 32), (32, 1), (1, 1), (3, 3)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
            let mut b_test = vec![0.0f32; rows * cols];
            let mut b_ref = vec![0.0f32; rows * cols];

            transpose(rows, cols, &a, &mut b_test).unwrap();
            transpose_naive(rows, cols, &a, &mut b_ref);

            assert_eq!(b_test, b_ref, "Mismatch for {rows}×{cols}");
        }
    }

    #[test]
    fn test_avx2_scalar_parity() {
        let rows = 2048;
        let cols = 128;
        let a: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.001).collect();
        let mut b_scalar = vec![0.0f32; rows * cols];
        let mut b_dispatch = vec![0.0f32; rows * cols];

        transpose_scalar_impl(rows, cols, &a, &mut b_scalar);
        transpose(rows, cols, &a, &mut b_dispatch).unwrap();

        assert_eq!(b_scalar, b_dispatch, "AVX2 vs scalar mismatch at 2048×128");
    }

    #[test]
    fn test_identity() {
        for n in [4, 8, 16, 32] {
            let mut a = vec![0.0f32; n * n];
            for i in 0..n {
                a[i * n + i] = 1.0;
            }
            let mut b = vec![0.0f32; n * n];
            transpose(n, n, &a, &mut b).unwrap();
            assert_eq!(a, b, "Identity not preserved for {n}×{n}");
        }
    }

    #[test]
    fn test_attention_shape() {
        let rows = 2048;
        let cols = 128;
        let a: Vec<f32> =
            (0..rows * cols).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut b_test = vec![0.0f32; rows * cols];
        let mut b_ref = vec![0.0f32; rows * cols];

        transpose(rows, cols, &a, &mut b_test).unwrap();
        transpose_naive(rows, cols, &a, &mut b_ref);

        assert_eq!(b_test, b_ref, "Attention shape 2048×128 mismatch");
    }

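    // A sketch exercising the parallel path: 1024x1024 is 1,048,576
    // elements, which crosses the one-million PARALLEL_THRESHOLD, so this
    // compares the multi-threaded result against the naive reference.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_threshold_parity() {
        let (rows, cols) = (1024, 1024);
        let a: Vec<f32> = (0..rows * cols).map(|i| (i % 977) as f32).collect();
        let mut b_test = vec![0.0f32; rows * cols];
        let mut b_ref = vec![0.0f32; rows * cols];

        transpose(rows, cols, &a, &mut b_test).unwrap();
        transpose_naive(rows, cols, &a, &mut b_ref);

        assert_eq!(b_test, b_ref, "Parallel path mismatch at 1024×1024");
    }
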
    #[test]
    fn test_dimension_mismatch() {
        let a = vec![1.0f32; 12];
        let mut b = vec![0.0f32; 10];
        assert!(transpose(3, 4, &a, &mut b).is_err());
    }

    #[test]
    fn test_small_matrix() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut b = vec![0.0f32; 6];
        transpose(2, 3, &a, &mut b).unwrap();
        assert_eq!(b, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }
}