Skip to main content

trueno/blis/
transpose.rs

1//! Matrix transpose operations.
2//!
3//! AVX2 8×8 in-register transpose with scalar fallback for small matrices
4//! and non-8-aligned remainder edges.
5//!
6//! # Algorithm
7//!
8//! Process matrix in 8×8 blocks. For each block, load 8 rows into YMM
9//! registers, perform 3-phase shuffle/permute transpose, store 8 transposed
10//! rows. Contiguous 32-byte stores coalesce cache misses (8 vs 64 in scalar).
11//!
12//! Contract: provable-contracts/contracts/transpose-kernel-v1.yaml
13//!
14//! # References
15//!
16//! - Lam, Rothberg & Wolf (1991). Cache Performance of Blocked Algorithms. ASPLOS
17//! - Intel Intrinsics Guide: _mm256_unpacklo_ps, _mm256_shuffle_ps, _mm256_permute2f128_ps
18//! - GH-388: transpose 242x slower than ndarray at attention shapes
19
20use crate::error::TruenoError;
21
/// Scalar transpose of a rectangular window of a row-major matrix.
///
/// For every `row` in `rows` and `col` in `cols`, copies `a[row * src_cols + col]`
/// into `b[col * dst_rows + row]`. `src_cols` is the row stride of `a`, and
/// `dst_rows` is the row stride of `b` (callers pass the row count of A).
/// Out-of-range indices panic exactly like ordinary slice indexing.
#[inline(always)]
fn transpose_region(
    a: &[f32],
    b: &mut [f32],
    rows: std::ops::Range<usize>,
    cols: std::ops::Range<usize>,
    src_cols: usize,
    dst_rows: usize,
) {
    rows.for_each(|row| {
        let row_start = row * src_cols;
        cols.clone().for_each(|col| b[col * dst_rows + row] = a[row_start + col]);
    });
}
39
/// AVX2 8×8 in-register transpose micro-kernel.
///
/// Loads 8 rows of 8 f32 from source (stride = `src_stride` elements),
/// performs 3-phase shuffle/permute, stores 8 transposed rows to dest
/// (stride = `dst_stride` elements). Element `(i, j)` of the source block
/// ends up at element `(j, i)` of the destination block.
///
/// Notation in the comments below: `rK[j]` is element `j` of source row `K`;
/// `|` separates the low and high 128-bit lanes of a YMM register.
///
/// # Safety
///
/// Requires AVX2 support. For every `i` in `0..8`, the caller must ensure that
/// `src.add(i * src_stride)` is valid for reading 8 consecutive f32 values and
/// `dst.add(i * dst_stride)` is valid for writing 8 consecutive f32 values.
/// Loads and stores are unaligned (`loadu`/`storeu`), so no alignment is required.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn transpose_8x8_avx2(src: *const f32, src_stride: usize, dst: *mut f32, dst_stride: usize) {
    unsafe {
        use std::arch::x86_64::*;

        // Load 8 source rows
        let r0 = _mm256_loadu_ps(src);
        let r1 = _mm256_loadu_ps(src.add(src_stride));
        let r2 = _mm256_loadu_ps(src.add(src_stride * 2));
        let r3 = _mm256_loadu_ps(src.add(src_stride * 3));
        let r4 = _mm256_loadu_ps(src.add(src_stride * 4));
        let r5 = _mm256_loadu_ps(src.add(src_stride * 5));
        let r6 = _mm256_loadu_ps(src.add(src_stride * 6));
        let r7 = _mm256_loadu_ps(src.add(src_stride * 7));

        // Phase 1: Interleave adjacent row pairs (within each 128-bit lane).
        // t0 = r0[0] r1[0] r0[1] r1[1] | r0[4] r1[4] r0[5] r1[5]
        // t1 = r0[2] r1[2] r0[3] r1[3] | r0[6] r1[6] r0[7] r1[7]
        // t2..t7 follow the same pattern for row pairs (2,3), (4,5), (6,7).
        let t0 = _mm256_unpacklo_ps(r0, r1);
        let t1 = _mm256_unpackhi_ps(r0, r1);
        let t2 = _mm256_unpacklo_ps(r2, r3);
        let t3 = _mm256_unpackhi_ps(r2, r3);
        let t4 = _mm256_unpacklo_ps(r4, r5);
        let t5 = _mm256_unpackhi_ps(r4, r5);
        let t6 = _mm256_unpacklo_ps(r6, r7);
        let t7 = _mm256_unpackhi_ps(r6, r7);

        // Phase 2: Shuffle 64-bit pairs within 128-bit lanes.
        // 0x44 selects elements {0,1} of each operand's lane, 0xEE selects {2,3}:
        // u0 = r0[0] r1[0] r2[0] r3[0] | r0[4] r1[4] r2[4] r3[4]
        // u1 = r0[1] r1[1] r2[1] r3[1] | r0[5] r1[5] r2[5] r3[5]
        // u2/u3 carry columns 2/6 and 3/7; u4..u7 are the same for rows 4..7.
        let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
        let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
        let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
        let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
        let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
        let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
        let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
        let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

        // Phase 3: Swap 128-bit halves across YMM registers.
        // 0x20 = (low lane of first, low lane of second) -> source columns 0..3;
        // 0x31 = (high lane of first, high lane of second) -> source columns 4..7.
        // vJ = r0[J] r1[J] ... r7[J], i.e. source column J = destination row J.
        let v0 = _mm256_permute2f128_ps(u0, u4, 0x20);
        let v1 = _mm256_permute2f128_ps(u1, u5, 0x20);
        let v2 = _mm256_permute2f128_ps(u2, u6, 0x20);
        let v3 = _mm256_permute2f128_ps(u3, u7, 0x20);
        let v4 = _mm256_permute2f128_ps(u0, u4, 0x31);
        let v5 = _mm256_permute2f128_ps(u1, u5, 0x31);
        let v6 = _mm256_permute2f128_ps(u2, u6, 0x31);
        let v7 = _mm256_permute2f128_ps(u3, u7, 0x31);

        // Store 8 transposed rows
        _mm256_storeu_ps(dst, v0);
        _mm256_storeu_ps(dst.add(dst_stride), v1);
        _mm256_storeu_ps(dst.add(dst_stride * 2), v2);
        _mm256_storeu_ps(dst.add(dst_stride * 3), v3);
        _mm256_storeu_ps(dst.add(dst_stride * 4), v4);
        _mm256_storeu_ps(dst.add(dst_stride * 5), v5);
        _mm256_storeu_ps(dst.add(dst_stride * 6), v6);
        _mm256_storeu_ps(dst.add(dst_stride * 7), v7);
    }
}
108
109/// Transpose a matrix: B = A^T
110///
111/// Uses AVX2 8×8 in-register micro-kernel for full blocks, scalar for
112/// remainder edges. Runtime feature detection selects AVX2 or scalar.
113///
114/// Contract: transpose-kernel-v1, equations "transpose"
115///
116/// # Arguments
117///
118/// * `rows` - Number of rows in A (cols in B)
119/// * `cols` - Number of cols in A (rows in B)
120/// * `a` - Input matrix A (rows x cols, row-major)
121/// * `b` - Output matrix B (cols x rows, row-major)
122///
123/// # Returns
124///
125/// `Ok(())` on success, `Err` if dimensions mismatch
126pub fn transpose(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) -> Result<(), TruenoError> {
127    // Contract: transpose-kernel-v1.yaml, equation = transpose
128    debug_assert!(!a.is_empty(), "Contract transpose: input is empty");
129    debug_assert!(rows > 0 && cols > 0, "Contract transpose: zero dimensions");
130    let expected = rows * cols;
131    if a.len() != expected || b.len() != expected {
132        return Err(TruenoError::InvalidInput(format!(
133            "transpose size mismatch: a[{}], b[{}], expected {}",
134            a.len(),
135            b.len(),
136            expected
137        )));
138    }
139
140    if expected < 64 {
141        transpose_region(a, b, 0..rows, 0..cols, cols, rows);
142        return Ok(());
143    }
144
145    // Parallel transpose for large matrices (DRAM-bound → multi-channel).
146    // CGP-DBUF: lowered threshold from 4M to 1M. Previous regression at 1M was
147    // likely from higher thread::scope overhead; Rayon dispatch is ~3µs.
148    // 1024×1024 = 1M elements, 4MB traffic, ~200µs single-thread → 3µs is 1.5%.
149    #[cfg(feature = "parallel")]
150    {
151        const PARALLEL_THRESHOLD: usize = 1_000_000;
152        if expected >= PARALLEL_THRESHOLD {
153            return transpose_parallel(rows, cols, a, b);
154        }
155    }
156
157    #[cfg(target_arch = "x86_64")]
158    {
159        if is_x86_feature_detected!("avx2") {
160            // SAFETY: AVX2 verified by feature detection above.
161            // Slice bounds: 8×8 blocks within rows×cols guaranteed by loop bounds.
162            unsafe {
163                return transpose_avx2_impl(rows, cols, a, b);
164            }
165        }
166    }
167
168    transpose_scalar_impl(rows, cols, a, b);
169    Ok(())
170}
171
/// Parallel transpose: splits A into horizontal strips of consecutive rows and
/// transposes each strip on a rayon worker.
///
/// Strip `s` covers A rows `first..last`. Those land in B columns `first..last` of
/// every B row, so strips write pairwise-disjoint elements of `b`. At most 8
/// strips are used, regardless of the rayon pool size.
#[cfg(feature = "parallel")]
fn transpose_parallel(
    rows: usize,
    cols: usize,
    a: &[f32],
    b: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;

    let strips = rayon::current_num_threads().min(8);
    let strip_rows = rows.div_ceil(strips);

    // Raw pointers are not `Send`; ship the base address across threads as an integer.
    let b_addr = b.as_mut_ptr() as usize;

    (0..strips).into_par_iter().try_for_each(|strip| {
        let first = strip * strip_rows;
        let last = rows.min(first + strip_rows);
        if first >= last {
            return Ok::<(), TruenoError>(());
        }

        // Input strip: A[first..last, 0..cols]
        let a_strip = &a[first * cols..last * cols];

        // SAFETY: B element (c, r) lives at c * rows + r. This strip writes only
        // r in first..last, and strips have disjoint row ranges, so no two threads
        // touch the same element. `first < rows <= b.len()` keeps the offset pointer
        // inside the allocation.
        unsafe {
            let b_base = (b_addr as *mut f32).add(first);
            transpose_strided_avx2(last - first, cols, a_strip, b_base, rows)?;
        }
        Ok(())
    })
}
216
/// Transpose A[sub_rows×cols] → B viewed as [cols × ...] with given b_stride (= rows of full B).
/// Writes into B at offsets (c * b_stride + r) for r in 0..sub_rows, c in 0..cols.
///
/// Dispatches at runtime: the tiled AVX2 kernel when the CPU supports AVX2, otherwise
/// a scalar loop. [`transpose`] routes large inputs to the parallel path before its
/// own `avx2` detection. Without this check, AVX2 instructions would execute on
/// x86_64 CPUs that lack them (undefined behavior, typically SIGILL).
///
/// # Safety
///
/// * `a` must hold at least `sub_rows * cols` elements.
/// * `b_ptr` must be valid for writes at every offset `c * b_stride + r` with
///   `r < sub_rows` and `c < cols`. No other thread may access those elements
///   during the call.
///
/// # Errors
///
/// Currently never fails. The `Result` mirrors the non-x86 fallback's signature.
#[cfg(all(feature = "parallel", target_arch = "x86_64"))]
unsafe fn transpose_strided_avx2(
    sub_rows: usize,
    cols: usize,
    a: &[f32],
    b_ptr: *mut f32,
    b_stride: usize,
) -> Result<(), TruenoError> {
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 verified at runtime; the pointer contract is forwarded
        // unchanged from our caller.
        unsafe { transpose_strided_avx2_tiles(sub_rows, cols, a, b_ptr, b_stride) };
    } else {
        // SAFETY: every index satisfies r < sub_rows and c < cols, which is exactly
        // the range the caller guarantees for both `a` and `b_ptr`.
        unsafe {
            for r in 0..sub_rows {
                for c in 0..cols {
                    *b_ptr.add(c * b_stride + r) = *a.get_unchecked(r * cols + c);
                }
            }
        }
    }
    Ok(())
}

/// AVX2 body of [`transpose_strided_avx2`].
///
/// Uses 64×64 outer tiles with the 8×8 in-register micro-kernel inside, and scalar
/// loops for the non-8-aligned row and column edges. The helper is compiled with
/// AVX2 enabled so the micro-kernel can be inlined into the tile loop.
///
/// # Safety
///
/// Requires AVX2, plus the pointer contract documented on [`transpose_strided_avx2`].
#[cfg(all(feature = "parallel", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn transpose_strided_avx2_tiles(
    sub_rows: usize,
    cols: usize,
    a: &[f32],
    b_ptr: *mut f32,
    b_stride: usize,
) {
    const TILE: usize = 64; // outer tile: 64×64 f32 = 16KB of A
    const BLOCK: usize = 8; // AVX2 micro-kernel

    let rb_end = sub_rows / BLOCK * BLOCK;
    let cb_end = cols / BLOCK * BLOCK;

    // SAFETY: every block satisfies r0 + 8 <= rb_end <= sub_rows and
    // c0 + 8 <= cb_end <= cols, and the scalar edges stay below sub_rows/cols.
    // All accesses therefore fall within the caller-guaranteed ranges.
    unsafe {
        // Two-level tiling: 64×64 outer + 8×8 inner (AVX2).
        // Standard inner order (r0 outer, c0 inner): sequential A reads.
        for rt in (0..rb_end).step_by(TILE) {
            let rt_end = (rt + TILE).min(rb_end);
            for ct in (0..cb_end).step_by(TILE) {
                let ct_end = (ct + TILE).min(cb_end);
                for r0 in (rt..rt_end).step_by(BLOCK) {
                    for c0 in (ct..ct_end).step_by(BLOCK) {
                        let src = a.as_ptr().add(r0 * cols + c0);
                        let dst = b_ptr.add(c0 * b_stride + r0);
                        transpose_8x8_avx2(src, cols, dst, b_stride);
                    }
                }
            }
        }
        // Remainder: edge rows (all columns, including the bottom-right corner).
        for r in rb_end..sub_rows {
            for c in 0..cols {
                *b_ptr.add(c * b_stride + r) = *a.get_unchecked(r * cols + c);
            }
        }
        // Remainder: edge cols (for the row-aligned part only).
        if cb_end < cols {
            for r in 0..rb_end {
                for c in cb_end..cols {
                    *b_ptr.add(c * b_stride + r) = *a.get_unchecked(r * cols + c);
                }
            }
        }
    }
}
268
/// Scalar strided transpose for targets without x86_64 SIMD.
///
/// Writes `a[r * cols + c]` to `*b_ptr.add(c * b_stride + r)` for every
/// `r < sub_rows`, `c < cols`. It shares the `_avx2` name with the x86_64 version
/// so the parallel caller can use a single name on every target. It always
/// returns `Ok(())`.
///
/// # Safety
///
/// `a` must hold at least `sub_rows * cols` elements, and `b_ptr` must be valid for
/// writes at every offset `c * b_stride + r` in that range.
#[cfg(all(feature = "parallel", not(target_arch = "x86_64")))]
unsafe fn transpose_strided_avx2(
    sub_rows: usize,
    cols: usize,
    a: &[f32],
    b_ptr: *mut f32,
    b_stride: usize,
) -> Result<(), TruenoError> {
    for r in 0..sub_rows {
        let row_base = r * cols;
        (0..cols).for_each(|c| {
            // SAFETY: r < sub_rows and c < cols, the range covered by the caller's contract.
            unsafe { b_ptr.add(c * b_stride + r).write(*a.get_unchecked(row_base + c)) };
        });
    }
    Ok(())
}
287
/// AVX2 transpose with two-level tiling: 64×64 outer tiles, 8×8 inner micro-kernel.
///
/// One 64×64 f32 tile is 64×64×4 = 16KB. A source tile plus the destination tile
/// it maps to is 2 × 16KB = 32KB. The tile size was chosen to keep that working set
/// close to L1-resident; actual residency depends on the CPU's L1d size.
///
/// **Shape-adaptive loop order**:
/// - Tall-skinny (rows ≥ 4×cols): inner loop over row-blocks (r0), outer column-blocks.
///   Consecutive r0 iterations write adjacent 32-byte chunks into the same 8 rows of B
///   (B[c0..c0+8, r0..r0+8], then B[c0..c0+8, r0+8..r0+16], ...). This maximizes
///   cache-line reuse on the write side.
/// - Otherwise: inner loop over column-blocks (standard order, sequential A reads).
///
/// In the tall-skinny order only, a software prefetch is issued for the first two
/// destination rows (B rows c0 and c0+1) of the next micro-kernel in the same tile.
///
/// Full 8×8 blocks cover rows `0..rows/8*8` × cols `0..cols/8*8`. The scalar
/// right-edge pass then covers the aligned rows × the leftover columns. The scalar
/// bottom-edge pass covers the leftover rows × all columns, including the corner.
///
/// # Safety
///
/// Requires AVX2 support. The raw-pointer accesses also require
/// `a.len() >= rows * cols` and `b.len() >= rows * cols`; `transpose` verifies
/// both lengths equal `rows * cols` before calling.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn transpose_avx2_impl(
    rows: usize,
    cols: usize,
    a: &[f32],
    b: &mut [f32],
) -> Result<(), TruenoError> {
    use std::arch::x86_64::*;

    const TILE: usize = 64; // L1-resident outer tile
    const BLOCK: usize = 8; // AVX2 micro-kernel

    // Largest multiples of 8 not exceeding rows/cols: extent handled by the SIMD kernel.
    let rb_end = rows / BLOCK * BLOCK;
    let cb_end = cols / BLOCK * BLOCK;

    // Tall-skinny: rows >> cols → destination stride (=rows) is large.
    // Swap loop order so inner loop walks consecutive r0 values,
    // making destination writes sequential within each cache line.
    // Tested rows >= 1024 extension (2026-04-05): regressed 1024×1024
    // from 25 → 15 GB/s. Standard order is better for square matrices
    // due to sequential A reads (prefetcher friendly).
    let tall_skinny = rows >= 4 * cols;

    unsafe {
        for rt in (0..rb_end).step_by(TILE) {
            let rt_end = (rt + TILE).min(rb_end);
            for ct in (0..cb_end).step_by(TILE) {
                let ct_end = (ct + TILE).min(cb_end);

                if tall_skinny {
                    // Outer c0, inner r0: destination writes are sequential
                    for c0 in (ct..ct_end).step_by(BLOCK) {
                        for r0 in (rt..rt_end).step_by(BLOCK) {
                            // Prefetch the next micro-kernel's destination (B rows c0 and
                            // c0+1 at column r0+8). Skipped for the last block of the tile.
                            // Both addresses stay inside `b`: r0 + 8 < rows and c0 + 1 < cols.
                            if r0 + BLOCK < rt_end {
                                let pf_dst = b.as_ptr().add(c0 * rows + r0 + BLOCK);
                                _mm_prefetch(pf_dst as *const i8, _MM_HINT_T0);
                                _mm_prefetch(pf_dst.add(rows) as *const i8, _MM_HINT_T0);
                            }
                            let src = a.as_ptr().add(r0 * cols + c0);
                            let dst = b.as_mut_ptr().add(c0 * rows + r0);
                            transpose_8x8_avx2(src, cols, dst, rows);
                        }
                    }
                } else {
                    // Square/wide: standard order (no prefetch — at large strides
                    // the destination is too far apart for L1 prefetch to help)
                    for r0 in (rt..rt_end).step_by(BLOCK) {
                        for c0 in (ct..ct_end).step_by(BLOCK) {
                            let src = a.as_ptr().add(r0 * cols + c0);
                            let dst = b.as_mut_ptr().add(c0 * rows + r0);
                            transpose_8x8_avx2(src, cols, dst, rows);
                        }
                    }
                }
            }
        }
    }

    // Right edge remainder (cols % 8 != 0): scalar, aligned rows only
    if cb_end < cols {
        transpose_region(a, b, 0..rb_end, cb_end..cols, cols, rows);
    }

    // Bottom edge remainder (rows % 8 != 0): scalar, all columns (includes the corner)
    if rb_end < rows {
        transpose_region(a, b, rb_end..rows, 0..cols, cols, rows);
    }

    Ok(())
}
375
376/// Scalar transpose with 8×8 blocking.
377fn transpose_scalar_impl(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
378    const BLOCK: usize = 8;
379    let row_blocks = rows / BLOCK;
380    let col_blocks = cols / BLOCK;
381
382    for rb in 0..row_blocks {
383        for cb in 0..col_blocks {
384            let rs = rb * BLOCK;
385            let cs = cb * BLOCK;
386            transpose_region(a, b, rs..rs + BLOCK, cs..cs + BLOCK, cols, rows);
387        }
388    }
389
390    let col_rem = col_blocks * BLOCK;
391    if col_rem < cols {
392        transpose_region(a, b, 0..row_blocks * BLOCK, col_rem..cols, cols, rows);
393    }
394
395    let row_rem = row_blocks * BLOCK;
396    if row_rem < rows {
397        transpose_region(a, b, row_rem..rows, 0..cols, cols, rows);
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation straight from the definition B[j, i] = A[i, j].
    fn transpose_naive(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
        for i in 0..rows {
            for j in 0..cols {
                b[j * rows + i] = a[i * cols + j];
            }
        }
    }

    /// Compare `transpose` against the naive reference on an index-valued input.
    fn assert_matches_naive(rows: usize, cols: usize) {
        let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
        let mut b_test = vec![0.0f32; rows * cols];
        let mut b_ref = vec![0.0f32; rows * cols];

        transpose(rows, cols, &a, &mut b_test).unwrap();
        transpose_naive(rows, cols, &a, &mut b_ref);

        assert_eq!(b_test, b_ref, "Mismatch for {rows}×{cols}");
    }

    /// FALSIFY-TP-001: Element correctness
    #[test]
    fn test_element_correctness() {
        for (rows, cols) in [(4, 5), (8, 8), (16, 32), (31, 17), (64, 64)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
            let mut b = vec![0.0f32; rows * cols];
            transpose(rows, cols, &a, &mut b).unwrap();

            for i in 0..rows {
                for j in 0..cols {
                    assert_eq!(b[j * rows + i], a[i * cols + j], "({i},{j}) {rows}×{cols}");
                }
            }
        }
    }

    /// FALSIFY-TP-002: Involution
    #[test]
    fn test_involution() {
        for (rows, cols) in [(7, 13), (16, 16), (33, 17), (64, 128)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.1 + 0.37).collect();
            let mut b = vec![0.0f32; rows * cols];
            let mut c = vec![0.0f32; rows * cols];

            transpose(rows, cols, &a, &mut b).unwrap();
            transpose(cols, rows, &b, &mut c).unwrap();

            assert_eq!(a, c, "Involution failed for {rows}×{cols}");
        }
    }

    /// FALSIFY-TP-003: Non-8-aligned dimensions
    #[test]
    fn test_non_aligned() {
        for (rows, cols) in [(7, 13), (17, 3), (1, 32), (32, 1), (1, 1), (3, 3)] {
            assert_matches_naive(rows, cols);
        }
    }

    /// FALSIFY-TP-004: AVX2 vs scalar parity (bitwise exact)
    #[test]
    fn test_avx2_scalar_parity() {
        let rows = 2048;
        let cols = 128;
        let a: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.001).collect();
        let mut b_scalar = vec![0.0f32; rows * cols];
        let mut b_dispatch = vec![0.0f32; rows * cols];

        transpose_scalar_impl(rows, cols, &a, &mut b_scalar);
        transpose(rows, cols, &a, &mut b_dispatch).unwrap();

        assert_eq!(b_scalar, b_dispatch, "AVX2 vs scalar mismatch at 2048×128");
    }

    /// FALSIFY-TP-005: Identity matrix
    #[test]
    fn test_identity() {
        for n in [4, 8, 16, 32] {
            let mut a = vec![0.0f32; n * n];
            for i in 0..n {
                a[i * n + i] = 1.0;
            }
            let mut b = vec![0.0f32; n * n];
            transpose(n, n, &a, &mut b).unwrap();
            assert_eq!(a, b, "Identity not preserved for {n}×{n}");
        }
    }

    /// FALSIFY-TP-006: Attention shape (2048×128)
    #[test]
    fn test_attention_shape() {
        let rows = 2048;
        let cols = 128;
        let a: Vec<f32> =
            (0..rows * cols).map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5).collect();
        let mut b_test = vec![0.0f32; rows * cols];
        let mut b_ref = vec![0.0f32; rows * cols];

        transpose(rows, cols, &a, &mut b_test).unwrap();
        transpose_naive(rows, cols, &a, &mut b_ref);

        assert_eq!(b_test, b_ref, "Attention shape 2048×128 mismatch");
    }

    /// Tall-skinny (rows >= 4*cols) with both remainder edges non-empty: exercises
    /// the swapped loop order and its prefetch guard together with the scalar edges.
    #[test]
    fn test_tall_skinny_non_aligned() {
        assert_matches_naive(131, 13);
    }

    /// Non-8-aligned shape above the 1M-element parallel threshold. With the
    /// `parallel` feature this runs `transpose_parallel` (including strips whose
    /// height is not a multiple of 8). Without it, the single-threaded path runs at
    /// a large size.
    #[test]
    fn test_large_non_aligned() {
        assert_matches_naive(1031, 1013);
    }

    #[test]
    fn test_dimension_mismatch() {
        let a = vec![1.0f32; 12];
        let mut b = vec![0.0f32; 10]; // wrong size
        assert!(transpose(3, 4, &a, &mut b).is_err());
    }

    #[test]
    fn test_small_matrix() {
        // Below 64 elements threshold — uses scalar directly
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut b = vec![0.0f32; 6];
        transpose(2, 3, &a, &mut b).unwrap();
        assert_eq!(b, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }
}