//
// Optimized GEMM (General Matrix Multiplication) kernel in PTX
// Implements tiled matrix multiplication for maximum memory coalescing
// Target: CUDA compute capability 6.0+
//
.version 7.0
.target sm_60
.address_size 64
// Kernel entry point for matrix multiplication C = A * B
// A is M x K, B is K x N, C is M x N, all row-major float32 in global memory.
// Launch with 16x16 thread blocks; each block computes one 16x16 tile of C
// and each thread produces one element C[row][col].
//
// Fixes vs. previous revision:
//   * predicate register file enlarged (%p10/%p11 were used but only
//     %p0..%p9 were declared)
//   * out-of-bounds threads no longer exit before the tile loop: every
//     thread in the CTA must reach bar.sync, so the bounds predicate now
//     guards only the final store (the tile loads already zero-fill OOB)
//   * 32-bit -> 64-bit byte-offset computation uses mul.wide.s32 instead
//     of an ill-typed shl.b64 on a 32-bit register
//   * shared-memory base addresses materialized with mov.u32, and the
//     f32 zero written in canonical 0fXXXXXXXX form
.visible .entry cuda_gemm(
.param .u64 param_0, // A matrix pointer
.param .u64 param_1, // B matrix pointer
.param .u64 param_2, // C matrix pointer
.param .u32 param_3, // M dimension
.param .u32 param_4, // N dimension
.param .u32 param_5 // K dimension
)
{
.reg .pred %p<14>;
.reg .f32 %f<8>;
.reg .b32 %r<48>;
.reg .b64 %rd<12>;
// Shared memory: one 16x16 float tile of A and one of B per block
.shared .align 4 .b8 shared_A[1024]; // 16x16 tile of floats
.shared .align 4 .b8 shared_B[1024]; // 16x16 tile of floats
// Load parameters
ld.param.u64 %rd1, [param_0]; // A pointer
ld.param.u64 %rd2, [param_1]; // B pointer
ld.param.u64 %rd3, [param_2]; // C pointer
ld.param.u32 %r1, [param_3]; // M
ld.param.u32 %r2, [param_4]; // N
ld.param.u32 %r3, [param_5]; // K
// Thread and block indices
mov.u32 %r4, %ctaid.x; // blockIdx.x
mov.u32 %r5, %ctaid.y; // blockIdx.y
mov.u32 %r6, %tid.x; // threadIdx.x
mov.u32 %r7, %tid.y; // threadIdx.y
// Calculate global thread position
shl.b32 %r8, %r4, 4; // blockIdx.x * 16
add.s32 %r9, %r8, %r6; // col = blockIdx.x * 16 + threadIdx.x
shl.b32 %r10, %r5, 4; // blockIdx.y * 16
add.s32 %r11, %r10, %r7; // row = blockIdx.y * 16 + threadIdx.y
// Out-of-bounds predicate. Do NOT branch to EXIT here: all threads must
// participate in bar.sync inside the tile loop, so %p3 is consulted only
// when storing the final result.
setp.ge.s32 %p1, %r9, %r2; // col >= N
setp.ge.s32 %p2, %r11, %r1; // row >= M
or.pred %p3, %p1, %p2; // out of bounds
// Initialize accumulator
mov.f32 %f1, 0f00000000;
// Loop-invariant shared-memory addressing: both tile stores index
// [ty*16 + tx], so precompute that byte offset and both base addresses.
mov.u32 %r41, shared_A; // shared-window address of shared_A
mov.u32 %r42, shared_B; // shared-window address of shared_B
shl.b32 %r19, %r7, 4; // ty * 16
add.s32 %r20, %r19, %r6; // ty * 16 + tx
shl.b32 %r21, %r20, 2; // * sizeof(float)
add.s32 %r22, %r41, %r21; // &shared_A[ty][tx]
add.s32 %r29, %r42, %r21; // &shared_B[ty][tx]
// num_tiles = ceil(K / 16)
add.s32 %r12, %r3, 15; // K + 15
shr.s32 %r13, %r12, 4; // (K + 15) / 16
// Tile loop
mov.s32 %r14, 0; // tile_idx
TILE_LOOP:
setp.ge.s32 %p4, %r14, %r13; // tile_idx >= num_tiles
@%p4 bra COMPUTE_DONE;
// Tile offset in the K dimension
shl.b32 %r15, %r14, 4; // tile_k = tile_idx * 16
// Load tile A: shared_A[ty][tx] = A[row][tile_k + tx] (0 if OOB)
add.s32 %r16, %r15, %r6; // tile_k + threadIdx.x
setp.ge.s32 %p5, %r16, %r3; // tile_k + tx >= K
or.pred %p7, %p5, %p2; // ... or row >= M
@%p7 bra LOAD_A_ZERO;
mul.lo.s32 %r17, %r11, %r3; // row * K
add.s32 %r18, %r17, %r16; // row * K + tile_k + tx
mul.wide.s32 %rd4, %r18, 4; // 64-bit byte offset (* sizeof(float))
add.s64 %rd5, %rd1, %rd4; // A + offset
ld.global.f32 %f2, [%rd5];
bra STORE_A;
LOAD_A_ZERO:
mov.f32 %f2, 0f00000000;
STORE_A:
st.shared.f32 [%r22], %f2;
// Load tile B: shared_B[ty][tx] = B[(tile_k + ty)][col] (0 if OOB)
add.s32 %r23, %r15, %r7; // tile_k + threadIdx.y
setp.ge.s32 %p8, %r23, %r3; // tile_k + ty >= K
or.pred %p10, %p8, %p1; // ... or col >= N
@%p10 bra LOAD_B_ZERO;
mul.lo.s32 %r24, %r23, %r2; // (tile_k + ty) * N
add.s32 %r25, %r24, %r9; // (tile_k + ty) * N + col
mul.wide.s32 %rd6, %r25, 4; // 64-bit byte offset (* sizeof(float))
add.s64 %rd7, %rd2, %rd6; // B + offset
ld.global.f32 %f3, [%rd7];
bra STORE_B;
LOAD_B_ZERO:
mov.f32 %f3, 0f00000000;
STORE_B:
st.shared.f32 [%r29], %f3;
// Both tiles must be fully populated before any thread reads them
bar.sync 0;
// acc += sum over k of A_tile[ty][k] * B_tile[k][tx]
mov.s32 %r30, 0; // k
INNER_LOOP:
setp.ge.s32 %p11, %r30, 16; // k >= 16
@%p11 bra INNER_DONE;
add.s32 %r32, %r19, %r30; // ty * 16 + k
shl.b32 %r33, %r32, 2; // * sizeof(float)
add.s32 %r34, %r41, %r33; // shared_A + offset
ld.shared.f32 %f4, [%r34]; // A_tile[ty][k]
shl.b32 %r35, %r30, 4; // k * 16
add.s32 %r36, %r35, %r6; // k * 16 + tx
shl.b32 %r37, %r36, 2; // * sizeof(float)
add.s32 %r38, %r42, %r37; // shared_B + offset
ld.shared.f32 %f5, [%r38]; // B_tile[k][tx]
fma.rn.f32 %f1, %f4, %f5, %f1; // acc += a * b
add.s32 %r30, %r30, 1; // k++
bra INNER_LOOP;
INNER_DONE:
// Tiles must not be overwritten while any thread still reads them
bar.sync 0;
add.s32 %r14, %r14, 1; // tile_idx++
bra TILE_LOOP;
COMPUTE_DONE:
// Only in-bounds threads write their element of C
@%p3 bra EXIT;
mul.lo.s32 %r39, %r11, %r2; // row * N
add.s32 %r40, %r39, %r9; // row * N + col
mul.wide.s32 %rd8, %r40, 4; // 64-bit byte offset (* sizeof(float))
add.s64 %rd9, %rd3, %rd8; // C + offset
st.global.f32 [%rd9], %f1;
EXIT:
ret;
}