// tenflowers-core 0.1.1
//
// Core tensor operations and execution engine for TenfloweRS
// Documentation
//
// Optimized GEMM (General Matrix Multiplication) kernel in PTX
// Implements tiled matrix multiplication for maximum memory coalescing
// Target: CUDA compute capability 6.0+
//

.version 7.0
.target sm_60
.address_size 64

// Kernel entry point for matrix multiplication C = A * B
// Kernel entry point for matrix multiplication C = A * B
// A is M x K, B is K x N, C is M x N; all row-major f32 (assumed from
// the indexing below — row*K+col for A, row*N+col for B and C).
// One 16x16 thread block computes one 16x16 tile of C, staging 16-wide
// slabs of A and B through shared memory.
.visible .entry cuda_gemm(
    .param .u64 param_0, // A matrix pointer
    .param .u64 param_1, // B matrix pointer
    .param .u64 param_2, // C matrix pointer
    .param .u32 param_3, // M dimension
    .param .u32 param_4, // N dimension
    .param .u32 param_5  // K dimension
)
{
    // FIX: original declared %p<10> (p0..p9) but used %p10 and %p11.
    .reg .pred      %p<12>;
    .reg .f32       %f<8>;
    .reg .b32       %r<48>;
    .reg .b64       %rd<16>;

    // Shared memory: one 16x16 f32 tile per operand (16*16*4 = 1024 B)
    .shared .align 4 .b8 shared_A[1024];
    .shared .align 4 .b8 shared_B[1024];

    // Load kernel parameters
    ld.param.u64    %rd1, [param_0];      // A base pointer
    ld.param.u64    %rd2, [param_1];      // B base pointer
    ld.param.u64    %rd3, [param_2];      // C base pointer
    ld.param.u32    %r1, [param_3];       // M
    ld.param.u32    %r2, [param_4];       // N
    ld.param.u32    %r3, [param_5];       // K

    // Thread and block indices
    mov.u32         %r4, %ctaid.x;        // blockIdx.x
    mov.u32         %r5, %ctaid.y;        // blockIdx.y
    mov.u32         %r6, %tid.x;          // threadIdx.x
    mov.u32         %r7, %tid.y;          // threadIdx.y

    // Global element coordinates handled by this thread
    shl.b32         %r8, %r4, 4;          // blockIdx.x * 16
    add.s32         %r9, %r8, %r6;        // col = blockIdx.x*16 + tx
    shl.b32         %r10, %r5, 4;         // blockIdx.y * 16
    add.s32         %r11, %r10, %r7;      // row = blockIdx.y*16 + ty

    // FIX: no early bounds exit here. The original branched out-of-range
    // threads straight to EXIT, so they never reached the two bar.sync
    // calls in the tile loop while in-range CTA threads waited on them
    // (deadlock / undefined behavior). All threads now run the loop —
    // out-of-range lanes load zero padding — and only the final store
    // of C is predicated.

    // FIX: materialize shared-memory base addresses with mov; a variable
    // name may not be used directly as an add.s32 operand.
    mov.u32         %r40, shared_A;
    mov.u32         %r41, shared_B;

    // This thread's slot in both tiles: (ty*16 + tx) * sizeof(float)
    shl.b32         %r19, %r7, 4;         // ty * 16
    add.s32         %r20, %r19, %r6;      // ty*16 + tx
    shl.b32         %r21, %r20, 2;        // * sizeof(float)
    add.s32         %r22, %r40, %r21;     // &shared_A[ty][tx]
    add.s32         %r29, %r41, %r21;     // &shared_B[ty][tx]

    // Accumulator for C[row][col]
    mov.f32         %f1, 0f00000000;

    // num_tiles = ceil(K / 16) = (K + 15) >> 4
    add.s32         %r12, %r3, 15;
    shr.s32         %r13, %r12, 4;

    mov.s32         %r14, 0;              // tile_idx = 0
TILE_LOOP:
    setp.ge.s32     %p4, %r14, %r13;      // tile_idx >= num_tiles?
    @%p4 bra        COMPUTE_DONE;

    shl.b32         %r15, %r14, 4;        // tile_k = tile_idx * 16

    // ---- Load A element: shared_A[ty][tx] = A[row][tile_k+tx], 0 if OOB
    add.s32         %r16, %r15, %r6;      // tile_k + tx
    setp.ge.s32     %p5, %r16, %r3;       // A column out of range?
    setp.ge.s32     %p6, %r11, %r1;       // row out of range?
    or.pred         %p7, %p5, %p6;
    mov.f32         %f2, 0f00000000;      // default: zero padding
    @%p7 bra        STORE_A;

    mad.lo.s32      %r18, %r11, %r3, %r16; // row*K + tile_k + tx
    // FIX: original did shl.b64 on a .b32 source (type mismatch);
    // mul.wide.s32 widens to a 64-bit byte offset in one instruction.
    mul.wide.s32    %rd4, %r18, 4;
    add.s64         %rd5, %rd1, %rd4;     // &A[row][tile_k+tx]
    ld.global.f32   %f2, [%rd5];

STORE_A:
    st.shared.f32   [%r22], %f2;

    // ---- Load B element: shared_B[ty][tx] = B[tile_k+ty][col], 0 if OOB
    add.s32         %r23, %r15, %r7;      // tile_k + ty
    setp.ge.s32     %p8, %r23, %r3;       // B row out of range?
    setp.ge.s32     %p9, %r9, %r2;        // col out of range?
    or.pred         %p10, %p8, %p9;
    mov.f32         %f3, 0f00000000;      // default: zero padding
    @%p10 bra       STORE_B;

    mad.lo.s32      %r25, %r23, %r2, %r9;  // (tile_k+ty)*N + col
    mul.wide.s32    %rd6, %r25, 4;         // 64-bit byte offset (see FIX above)
    add.s64         %rd7, %rd2, %rd6;      // &B[tile_k+ty][col]
    ld.global.f32   %f3, [%rd7];

STORE_B:
    st.shared.f32   [%r29], %f3;

    // Both tiles fully populated before any thread reads them
    bar.sync        0;

    // ---- Inner product across the 16-wide tile
    mov.s32         %r30, 0;              // k = 0
INNER_LOOP:
    setp.ge.s32     %p11, %r30, 16;       // k >= 16?
    @%p11 bra       INNER_DONE;

    // shared_A[ty][k]
    add.s32         %r32, %r19, %r30;     // ty*16 + k
    shl.b32         %r33, %r32, 2;        // * sizeof(float)
    add.s32         %r34, %r40, %r33;
    ld.shared.f32   %f4, [%r34];

    // shared_B[k][tx]
    shl.b32         %r35, %r30, 4;        // k * 16
    add.s32         %r36, %r35, %r6;      // k*16 + tx
    shl.b32         %r37, %r36, 2;        // * sizeof(float)
    add.s32         %r38, %r41, %r37;
    ld.shared.f32   %f5, [%r38];

    fma.rn.f32      %f1, %f4, %f5, %f1;   // acc += a * b

    add.s32         %r30, %r30, 1;        // k++
    bra             INNER_LOOP;

INNER_DONE:
    // All reads finished before the next iteration overwrites the tiles
    bar.sync        0;

    add.s32         %r14, %r14, 1;        // tile_idx++
    bra             TILE_LOOP;

COMPUTE_DONE:
    // Only the store is bounds-guarded; out-of-range threads ran the
    // loop purely to participate in the barriers.
    setp.ge.s32     %p1, %r9, %r2;        // col >= N
    setp.ge.s32     %p2, %r11, %r1;       // row >= M
    or.pred         %p3, %p1, %p2;
    @%p3 bra        EXIT;

    mad.lo.s32      %r39, %r11, %r2, %r9; // row*N + col
    mul.wide.s32    %rd8, %r39, 4;        // 64-bit byte offset
    add.s64         %rd9, %rd3, %rd8;     // &C[row][col]
    st.global.f32   [%rd9], %f1;

EXIT:
    ret;
}