Skip to main content

provable_contracts/kernels/
transpose.rs

1//! Matrix transpose kernel: out-of-place B = A^T with AVX2 8×8 micro-kernel.
2//!
3//! Matches `transpose-kernel-v1.yaml`.
4//! Three phases: `outer_blocking` -> `avx2_8x8_microkernel` -> `remainder`.
5//!
6//! # Algorithm
7//!
8//! Process the matrix in 8×8 blocks. For each block, load 8 source rows
9//! into YMM registers, perform 3-phase in-register transpose (unpack →
10//! shuffle → permute), then store 8 transposed rows. Contiguous 32-byte
11//! stores coalesce cache misses (8 vs 64 in scalar).
12//!
13//! # References
14//!
15//! - Lam, Rothberg & Wolf (1991) Cache Performance of Blocked Algorithms
16//! - Intel Intrinsics Guide: _mm256_unpacklo_ps, _mm256_shuffle_ps, _mm256_permute2f128_ps
17
18use provable_contracts_macros::requires;
19
20#[cfg(target_arch = "x86_64")]
21use std::arch::x86_64::*;
22
23// ────────────────────────────────────────────────────────────────────────────
24// Scalar implementation
25// ────────────────────────────────────────────────────────────────────────────
26
/// Scalar reference transpose: `B[j * rows + i] = A[i * cols + j]`.
///
/// `a` is a row-major `rows × cols` matrix; `b` receives its row-major
/// `cols × rows` transpose. The interior is walked in 8×8 tiles for cache
/// efficiency. The right and bottom edges left over when a dimension is not
/// a multiple of 8 go through the same tile routine with narrower ranges.
///
/// # Panics
///
/// Panics if `rows * cols` overflows `usize`, or if `a.len() != rows * cols`
/// or `b.len() != rows * cols`.
pub fn transpose_scalar(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
    const BLOCK: usize = 8;

    // `checked_mul` keeps the length check meaningful in release builds, where
    // a plain `rows * cols` would silently wrap on overflow.
    let len = rows
        .checked_mul(cols)
        .expect("rows * cols overflows usize");
    assert_eq!(a.len(), len, "a length mismatch");
    assert_eq!(b.len(), len, "b length mismatch");

    let rb_end = rows / BLOCK * BLOCK;
    let cb_end = cols / BLOCK * BLOCK;

    // Full 8×8 tiles.
    for r0 in (0..rb_end).step_by(BLOCK) {
        for c0 in (0..cb_end).step_by(BLOCK) {
            transpose_tile_scalar(rows, cols, a, b, r0..r0 + BLOCK, c0..c0 + BLOCK);
        }
    }

    // Right edge: columns past the last full tile, for every tiled row
    // (an empty range when cols % 8 == 0).
    transpose_tile_scalar(rows, cols, a, b, 0..rb_end, cb_end..cols);

    // Bottom edge: rows past the last full tile, across every column
    // (an empty range when rows % 8 == 0).
    transpose_tile_scalar(rows, cols, a, b, rb_end..rows, 0..cols);
}

/// Writes `b[c * rows + r] = a[r * cols + c]` for every `r` in `row_range`
/// and every `c` in `col_range`. Empty ranges are a no-op.
#[inline]
fn transpose_tile_scalar(
    rows: usize,
    cols: usize,
    a: &[f32],
    b: &mut [f32],
    row_range: std::ops::Range<usize>,
    col_range: std::ops::Range<usize>,
) {
    for r in row_range {
        // One bounds check per source row instead of one per element read.
        let src_row = &a[r * cols..(r + 1) * cols];
        for c in col_range.clone() {
            b[c * rows + r] = src_row[c];
        }
    }
}
76
77// ────────────────────────────────────────────────────────────────────────────
78// AVX2 implementation
79// ────────────────────────────────────────────────────────────────────────────
80
/// AVX2 8×8 in-register transpose micro-kernel.
///
/// Loads eight rows of eight `f32` from `src`, with consecutive rows
/// `src_stride` elements apart. It transposes them in registers with a
/// three-step unpack → shuffle → permute sequence. It then stores the eight
/// transposed rows to `dst`, with consecutive rows `dst_stride` elements apart.
///
/// # Safety
///
/// - The running CPU must support AVX2.
/// - For every `k` in `0..8`, the eight `f32`s starting at
///   `src.add(k * src_stride)` must be valid for reads, and the eight starting
///   at `dst.add(k * dst_stride)` must be valid for writes. The furthest
///   element touched is at offset `7 * stride + 7`, so a stride below 8 needs
///   more than `8 * stride` elements.
/// - No alignment is required: every access uses `loadu`/`storeu`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn transpose_8x8_avx2(src: *const f32, src_stride: usize, dst: *mut f32, dst_stride: usize) {
    // SAFETY: pointer validity and AVX2 availability are the caller's
    // obligations, as documented under `# Safety`.
    unsafe {
        // Load 8 source rows
        let r0 = _mm256_loadu_ps(src);
        let r1 = _mm256_loadu_ps(src.add(src_stride));
        let r2 = _mm256_loadu_ps(src.add(src_stride * 2));
        let r3 = _mm256_loadu_ps(src.add(src_stride * 3));
        let r4 = _mm256_loadu_ps(src.add(src_stride * 4));
        let r5 = _mm256_loadu_ps(src.add(src_stride * 5));
        let r6 = _mm256_loadu_ps(src.add(src_stride * 6));
        let r7 = _mm256_loadu_ps(src.add(src_stride * 7));

        // Phase 1: Interleave adjacent row pairs using unpacklo/unpackhi
        let t0 = _mm256_unpacklo_ps(r0, r1); // a0 b0 a1 b1 | a4 b4 a5 b5
        let t1 = _mm256_unpackhi_ps(r0, r1); // a2 b2 a3 b3 | a6 b6 a7 b7
        let t2 = _mm256_unpacklo_ps(r2, r3); // c0 d0 c1 d1 | c4 d4 c5 d5
        let t3 = _mm256_unpackhi_ps(r2, r3); // c2 d2 c3 d3 | c6 d6 c7 d7
        let t4 = _mm256_unpacklo_ps(r4, r5); // e0 f0 e1 f1 | e4 f4 e5 f5
        let t5 = _mm256_unpackhi_ps(r4, r5); // e2 f2 e3 f3 | e6 f6 e7 f7
        let t6 = _mm256_unpacklo_ps(r6, r7); // g0 h0 g1 h1 | g4 h4 g5 h5
        let t7 = _mm256_unpackhi_ps(r6, r7); // g2 h2 g3 h3 | g6 h6 g7 h7

        // Phase 2: Shuffle 64-bit pairs within 128-bit lanes
        // (0x44 picks elements [0,1] of each input, 0xEE picks [2,3]).
        let u0 = _mm256_shuffle_ps(t0, t2, 0x44); // a0 b0 c0 d0 | a4 b4 c4 d4
        let u1 = _mm256_shuffle_ps(t0, t2, 0xEE); // a1 b1 c1 d1 | a5 b5 c5 d5
        let u2 = _mm256_shuffle_ps(t1, t3, 0x44); // a2 b2 c2 d2 | a6 b6 c6 d6
        let u3 = _mm256_shuffle_ps(t1, t3, 0xEE); // a3 b3 c3 d3 | a7 b7 c7 d7
        let u4 = _mm256_shuffle_ps(t4, t6, 0x44); // e0 f0 g0 h0 | e4 f4 g4 h4
        let u5 = _mm256_shuffle_ps(t4, t6, 0xEE); // e1 f1 g1 h1 | e5 f5 g5 h5
        let u6 = _mm256_shuffle_ps(t5, t7, 0x44); // e2 f2 g2 h2 | e6 f6 g6 h6
        let u7 = _mm256_shuffle_ps(t5, t7, 0xEE); // e3 f3 g3 h3 | e7 f7 g7 h7

        // Phase 3: Swap 128-bit halves across YMM registers
        // (0x20 = both low halves, 0x31 = both high halves).
        let v0 = _mm256_permute2f128_ps(u0, u4, 0x20); // row 0 of transpose
        let v1 = _mm256_permute2f128_ps(u1, u5, 0x20); // row 1
        let v2 = _mm256_permute2f128_ps(u2, u6, 0x20); // row 2
        let v3 = _mm256_permute2f128_ps(u3, u7, 0x20); // row 3
        let v4 = _mm256_permute2f128_ps(u0, u4, 0x31); // row 4
        let v5 = _mm256_permute2f128_ps(u1, u5, 0x31); // row 5
        let v6 = _mm256_permute2f128_ps(u2, u6, 0x31); // row 6
        let v7 = _mm256_permute2f128_ps(u3, u7, 0x31); // row 7

        // Store 8 transposed rows
        _mm256_storeu_ps(dst, v0);
        _mm256_storeu_ps(dst.add(dst_stride), v1);
        _mm256_storeu_ps(dst.add(dst_stride * 2), v2);
        _mm256_storeu_ps(dst.add(dst_stride * 3), v3);
        _mm256_storeu_ps(dst.add(dst_stride * 4), v4);
        _mm256_storeu_ps(dst.add(dst_stride * 5), v5);
        _mm256_storeu_ps(dst.add(dst_stride * 6), v6);
        _mm256_storeu_ps(dst.add(dst_stride * 7), v7);
    }
}

/// AVX2 matrix transpose using the 8×8 in-register micro-kernel.
///
/// `a` is a row-major `rows × cols` matrix; `b` receives its row-major
/// `cols × rows` transpose. Full 8×8 tiles go through the SIMD kernel, with
/// source row stride `cols` and destination row stride `rows`. The right and
/// bottom edges left over when a dimension is not a multiple of 8 are copied
/// with scalar loops. Every element is a plain copy, so the output is
/// bit-identical to [`transpose_scalar`].
///
/// # Safety
///
/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
///
/// # Panics
///
/// Panics if `rows * cols` overflows `usize`, or if `a.len() != rows * cols`
/// or `b.len() != rows * cols`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn transpose_avx2(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
    // `checked_mul` is load-bearing for soundness. In release builds a
    // wrapped `rows * cols` could match short slices, and the raw-pointer
    // kernel below would then read and write far past their ends.
    let len = rows
        .checked_mul(cols)
        .expect("rows * cols overflows usize");
    assert_eq!(a.len(), len, "a length mismatch");
    assert_eq!(b.len(), len, "b length mismatch");

    let rb_end = rows / 8 * 8;
    let cb_end = cols / 8 * 8;

    // Full 8×8 tiles: AVX2 micro-kernel.
    let src_base = a.as_ptr();
    let dst_base = b.as_mut_ptr();
    for r0 in (0..rb_end).step_by(8) {
        for c0 in (0..cb_end).step_by(8) {
            // SAFETY: AVX2 is guaranteed by this function's contract. Because
            // r0 + 8 <= rows and c0 + 8 <= cols, the kernel only touches
            // a[(r0 + k) * cols + c0 .. + 8] and b[(c0 + k) * rows + r0 .. + 8]
            // for k in 0..8. The furthest index on each side is <= len - 1,
            // so every access stays inside `a` / `b`, which are disjoint
            // (`&` vs `&mut`).
            unsafe {
                transpose_8x8_avx2(
                    src_base.add(r0 * cols + c0),
                    cols,
                    dst_base.add(c0 * rows + r0),
                    rows,
                );
            }
        }
    }

    // Right edge remainder (cols % 8 != 0): scalar, for every tiled row.
    if cb_end < cols {
        for r in 0..rb_end {
            let src_row = &a[r * cols..(r + 1) * cols];
            for c in cb_end..cols {
                b[c * rows + r] = src_row[c];
            }
        }
    }

    // Bottom edge remainder (rows % 8 != 0): scalar, across every column.
    if rb_end < rows {
        for r in rb_end..rows {
            let src_row = &a[r * cols..(r + 1) * cols];
            for (c, &value) in src_row.iter().enumerate() {
                b[c * rows + r] = value;
            }
        }
    }
}
201
202// ────────────────────────────────────────────────────────────────────────────
203// Dispatch
204// ────────────────────────────────────────────────────────────────────────────
205
206/// Transpose a matrix: B = A^T. Dispatches to AVX2 or scalar.
207///
208/// # Panics
209///
210/// Panics if `a.len() != rows * cols` or `b.len() != rows * cols`.
211#[requires(a.len() == rows * cols && b.len() == rows * cols)]
212pub fn transpose(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
213    #[cfg(target_arch = "x86_64")]
214    {
215        if is_x86_feature_detected!("avx2") {
216            // SAFETY: AVX2 verified by feature detection above.
217            unsafe {
218                transpose_avx2(rows, cols, a, b);
219            }
220            return;
221        }
222    }
223    transpose_scalar(rows, cols, a, b);
224}
225
226// ────────────────────────────────────────────────────────────────────────────
227// Tests — contract falsification
228// ────────────────────────────────────────────────────────────────────────────
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Naive reference transpose for validation.
    fn transpose_naive(rows: usize, cols: usize, a: &[f32], b: &mut [f32]) {
        for i in 0..rows {
            for j in 0..cols {
                b[j * rows + i] = a[i * cols + j];
            }
        }
    }

    /// FALSIFY-TP-001: Element correctness
    /// transpose(A)[j][i] == A[i][j] for random A
    #[test]
    fn falsify_tp_001_element_correctness() {
        for (rows, cols) in [(4, 5), (8, 8), (16, 32), (31, 17), (64, 64)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
            let mut b = vec![0.0f32; rows * cols];
            transpose(rows, cols, &a, &mut b);

            for i in 0..rows {
                for j in 0..cols {
                    assert_eq!(
                        b[j * rows + i],
                        a[i * cols + j],
                        "Mismatch at ({i},{j}) for {rows}×{cols}"
                    );
                }
            }
        }
    }

    /// FALSIFY-TP-002: Involution
    /// transpose(transpose(A)) == A (bitwise exact)
    #[test]
    fn falsify_tp_002_involution() {
        for (rows, cols) in [(7, 13), (16, 16), (33, 17), (64, 128)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.1 + 0.37).collect();
            let mut b = vec![0.0f32; rows * cols];
            let mut c = vec![0.0f32; rows * cols];

            transpose(rows, cols, &a, &mut b);
            transpose(cols, rows, &b, &mut c);

            assert_eq!(a, c, "Involution failed for {rows}×{cols}");
        }
    }

    /// FALSIFY-TP-003: Non-8-aligned dimensions
    /// Correct for 7×13, 17×3, 1×N, N×1
    #[test]
    fn falsify_tp_003_non_aligned() {
        for (rows, cols) in [(7, 13), (17, 3), (1, 32), (32, 1), (1, 1), (3, 3)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
            let mut b_test = vec![0.0f32; rows * cols];
            let mut b_ref = vec![0.0f32; rows * cols];

            transpose(rows, cols, &a, &mut b_test);
            transpose_naive(rows, cols, &a, &mut b_ref);

            assert_eq!(b_test, b_ref, "Mismatch for {rows}×{cols}");
        }
    }

    /// FALSIFY-TP-004: AVX2 vs scalar parity (bitwise exact)
    ///
    /// Calls `transpose_avx2` directly when the CPU supports it, rather than
    /// going through dispatch, so a non-AVX2 host cannot silently compare
    /// scalar against scalar. Covers aligned shapes as well as shapes with
    /// right and bottom remainders.
    #[test]
    fn falsify_tp_004_avx2_scalar_parity() {
        for (rows, cols) in [(2048, 128), (8, 8), (9, 17), (31, 40), (128, 7), (7, 128)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.001).collect();
            let mut b_scalar = vec![0.0f32; rows * cols];
            let mut b_dispatch = vec![0.0f32; rows * cols];

            transpose_scalar(rows, cols, &a, &mut b_scalar);
            transpose(rows, cols, &a, &mut b_dispatch);
            assert_eq!(b_scalar, b_dispatch, "Dispatch vs scalar mismatch at {rows}×{cols}");

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2") {
                    let mut b_avx2 = vec![0.0f32; rows * cols];
                    // SAFETY: AVX2 support was confirmed at runtime just above.
                    unsafe { transpose_avx2(rows, cols, &a, &mut b_avx2) };
                    assert_eq!(b_scalar, b_avx2, "AVX2 vs scalar mismatch at {rows}×{cols}");
                }
            }
        }
    }

    /// FALSIFY-TP-005: Identity matrix
    /// transpose(I) == I for square identity
    #[test]
    fn falsify_tp_005_identity() {
        for n in [4, 8, 16, 32] {
            let mut a = vec![0.0f32; n * n];
            for i in 0..n {
                a[i * n + i] = 1.0;
            }
            let mut b = vec![0.0f32; n * n];
            transpose(n, n, &a, &mut b);
            assert_eq!(a, b, "Identity matrix not preserved for {n}×{n}");
        }
    }

    /// FALSIFY-TP-006: Attention shape (2048×128)
    /// Matches naive reference
    #[test]
    fn falsify_tp_006_attention_shape() {
        let rows = 2048;
        let cols = 128;
        let a: Vec<f32> = (0..rows * cols)
            .map(|i| ((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5)
            .collect();
        let mut b_test = vec![0.0f32; rows * cols];
        let mut b_ref = vec![0.0f32; rows * cols];

        transpose(rows, cols, &a, &mut b_test);
        transpose_naive(rows, cols, &a, &mut b_ref);

        assert_eq!(b_test, b_ref, "Attention shape 2048×128 mismatch");
    }

    /// Cover scalar remainder paths (rows/cols not divisible by 8)
    #[test]
    fn scalar_remainder_paths() {
        for (rows, cols) in [(3, 5), (10, 13), (15, 9), (7, 7)] {
            let a: Vec<f32> = (0..rows * cols).map(|i| i as f32).collect();
            let mut b_scalar = vec![0.0f32; rows * cols];
            let mut b_ref = vec![0.0f32; rows * cols];

            transpose_scalar(rows, cols, &a, &mut b_scalar);
            transpose_naive(rows, cols, &a, &mut b_ref);

            assert_eq!(b_scalar, b_ref, "Scalar mismatch for {rows}×{cols}");
        }
    }

    /// Degenerate shapes: zero rows and/or zero columns must be a no-op on the
    /// empty buffers, not a panic, on both the dispatch and scalar paths.
    #[test]
    fn degenerate_empty_shapes() {
        for (rows, cols) in [(0, 0), (0, 9), (9, 0)] {
            let a: Vec<f32> = Vec::new();
            let mut b: Vec<f32> = Vec::new();

            transpose(rows, cols, &a, &mut b);
            transpose_scalar(rows, cols, &a, &mut b);

            assert!(b.is_empty(), "Output grew for {rows}×{cols}");
        }
    }
}