// tenflowers-core 0.1.1
//
// Core tensor operations and execution engine for TenfloweRS.
// Documentation
//
// Optimized element-wise addition kernel in PTX
// Implements vectorized addition with memory coalescing
// Target: CUDA compute capability 6.0+
//

.version 7.0
.target sm_60
.address_size 64

// Kernel entry point for element-wise addition: C[i] = A[i] + B[i], f32,
// for i in [0, count).
//
// Params:
//   param_0 (u64) - A array pointer (f32 elements)
//   param_1 (u64) - B array pointer (f32 elements)
//   param_2 (u64) - C array pointer (f32 elements, written)
//   param_3 (u32) - number of elements (count)
//
// Work split: each thread owns 4 consecutive elements starting at
//   base_idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4
// so the grid must provide at least ceil(count / 4) threads.
// NOTE(review): confirm the host launch configuration matches this.
//
// A thread uses the float4 (ld/st .v4.f32) path only when:
//   * all 4 of its elements are in range, and
//   * A, B and C are all 16-byte aligned. Vector global accesses must be
//     naturally aligned; base_idx * 4 bytes is always a multiple of 16, so
//     only the base pointers need checking.
// Otherwise it falls back to a scalar loop that stops at count.
//
// Counts and indices are handled as unsigned 32-bit values, matching the
// u32 type of param_3.
.visible .entry cuda_elementwise_add(
    .param .u64 param_0,    // A array pointer
    .param .u64 param_1,    // B array pointer
    .param .u64 param_2,    // C array pointer
    .param .u32 param_3     // Number of elements
)
{
    // Register files sized to cover every index used below:
    // %p1-%p5, %f1-%f15, %r1-%r9, %rd1-%rd12.
    .reg .pred      %p<6>;
    .reg .f32       %f<16>;
    .reg .b32       %r<10>;
    .reg .b64       %rd<13>;

    // Load parameters.
    ld.param.u64        %rd1, [param_0];            // A pointer (generic)
    ld.param.u64        %rd2, [param_1];            // B pointer (generic)
    ld.param.u64        %rd3, [param_2];            // C pointer (generic)
    ld.param.u32        %r1, [param_3];             // count

    // Kernel pointer params are generic addresses. Convert them to the
    // global state space before using them with ld.global / st.global.
    cvta.to.global.u64  %rd1, %rd1;                 // A
    cvta.to.global.u64  %rd2, %rd2;                 // B
    cvta.to.global.u64  %rd3, %rd3;                 // C

    // Global thread index. Block size lives in %ntid.x in PTX.
    mov.u32             %r2, %ctaid.x;              // blockIdx.x
    mov.u32             %r3, %ntid.x;               // blockDim.x
    mov.u32             %r4, %tid.x;                // threadIdx.x
    mad.lo.u32          %r5, %r2, %r3, %r4;         // tid = blockIdx.x * blockDim.x + threadIdx.x
    mul.lo.u32          %r6, %r5, 4;                // base_idx = tid * elements_per_thread (4)
    // NOTE(review): base_idx wraps if tid >= 2^30; this assumes the grid is
    // sized to about ceil(count / 4) threads.

    // A thread whose first element is already out of range has no work.
    setp.ge.u32         %p1, %r6, %r1;              // base_idx >= count
    @%p1 bra            EXIT;

    // Need 4 in-range elements for the vector path. Compute
    // count - base_idx (no underflow: base_idx < count here) rather than
    // base_idx + 4, which could wrap in 32-bit arithmetic.
    sub.u32             %r7, %r1, %r6;              // remaining = count - base_idx
    setp.lt.u32         %p2, %r7, 4;                // remaining < 4 -> scalar tail
    @%p2 bra            SCALAR_ADD;

    // Vector path also needs A, B and C to be 16-byte aligned.
    or.b64              %rd4, %rd1, %rd2;
    or.b64              %rd4, %rd4, %rd3;
    and.b64             %rd4, %rd4, 15;             // low 4 bits of A | B | C
    setp.eq.b64         %p3, %rd4, 0;               // all three 16-byte aligned
    @%p3 bra            VECTORIZED_ADD;

    // Scalar fallback: up to 4 elements, stopping at count.
SCALAR_ADD:
    mov.u32             %r8, 0;                     // i = 0
SCALAR_LOOP:
    setp.ge.u32         %p4, %r8, 4;                // i >= elements_per_thread
    @%p4 bra            EXIT;

    add.u32             %r9, %r6, %r8;              // idx = base_idx + i
    setp.ge.u32         %p5, %r9, %r1;              // idx >= count
    @%p5 bra            EXIT;

    // Widen the 32-bit index to a 64-bit byte offset.
    mul.wide.u32        %rd5, %r9, 4;               // idx * sizeof(float)
    add.s64             %rd6, %rd1, %rd5;           // &A[idx]
    add.s64             %rd7, %rd2, %rd5;           // &B[idx]
    add.s64             %rd8, %rd3, %rd5;           // &C[idx]

    ld.global.f32       %f1, [%rd6];                // A[idx]
    ld.global.f32       %f2, [%rd7];                // B[idx]
    add.f32             %f3, %f1, %f2;              // A[idx] + B[idx]
    st.global.f32       [%rd8], %f3;                // C[idx] = result

    add.u32             %r8, %r8, 1;                // i++
    bra                 SCALAR_LOOP;

    // Vector path. Invariants: base_idx + 4 <= count, and A, B, C are
    // 16-byte aligned.
VECTORIZED_ADD:
    mul.wide.u32        %rd9, %r6, 4;               // base_idx * sizeof(float)
    add.s64             %rd10, %rd1, %rd9;          // &A[base_idx]
    add.s64             %rd11, %rd2, %rd9;          // &B[base_idx]
    add.s64             %rd12, %rd3, %rd9;          // &C[base_idx]

    ld.global.v4.f32    {%f4, %f5, %f6, %f7}, [%rd10];      // A[base_idx .. base_idx+3]
    ld.global.v4.f32    {%f8, %f9, %f10, %f11}, [%rd11];    // B[base_idx .. base_idx+3]

    add.f32             %f12, %f4, %f8;             // C[0] = A[0] + B[0]
    add.f32             %f13, %f5, %f9;             // C[1] = A[1] + B[1]
    add.f32             %f14, %f6, %f10;            // C[2] = A[2] + B[2]
    add.f32             %f15, %f7, %f11;            // C[3] = A[3] + B[3]

    st.global.v4.f32    [%rd12], {%f12, %f13, %f14, %f15};  // C[base_idx .. base_idx+3]

EXIT:
    ret;
}