//
// Optimized element-wise multiplication kernel in PTX: C = A * B
// Each thread handles 4 consecutive floats; aligned, fully in-bounds
// groups use 128-bit vectorized (float4) loads/stores for coalesced
// memory access, and the remainder falls back to a scalar loop.
// Target: CUDA compute capability 6.0+
//
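//
// Launch-geometry note (a sketch, not dictated by the kernel itself):
// since each thread covers 4 elements, a host would typically size the
// grid as gridDim.x = ceil(count / (4 * blockDim.x)). For example,
// count = 1,048,576 with blockDim.x = 256 gives ceil(1048576 / 1024)
// = 1024 blocks. The actual configuration is up to the caller.
//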
.version 7.0
.target sm_60
.address_size 64
// Kernel entry point for element-wise multiplication C = A * B
.visible .entry cuda_elementwise_mul(
.param .u64 param_0, // A array pointer
.param .u64 param_1, // B array pointer
.param .u64 param_2, // C array pointer
.param .u32 param_3 // Number of elements
)
{
.reg .pred %p<6>; // %p1-%p5 are used below
.reg .f32 %f<20>;
.reg .b32 %r<20>;
.reg .b64 %rd<12>; // %rd1-%rd11 are used below
// Load parameters
ld.param.u64 %rd1, [param_0]; // A pointer
ld.param.u64 %rd2, [param_1]; // B pointer
ld.param.u64 %rd3, [param_2]; // C pointer
ld.param.u32 %r1, [param_3]; // count
// Calculate global thread index
mov.u32 %r2, %ctaid.x; // blockIdx.x
mov.u32 %r3, %ntid.x; // blockDim.x (the PTX special register is %ntid, not %blockDim)
mov.u32 %r4, %tid.x; // threadIdx.x
mad.lo.s32 %r5, %r2, %r3, %r4; // blockIdx.x * blockDim.x + threadIdx.x
// Process multiple elements per thread
mov.s32 %r6, 4; // elements_per_thread
mul.lo.s32 %r7, %r5, %r6; // base_idx = tid * elements_per_thread
// Check bounds for vectorized access
add.s32 %r8, %r7, %r6; // base_idx + elements_per_thread
setp.le.s32 %p1, %r8, %r1; // base_idx + 4 <= count
and.b32 %r11, %r7, 3; // base_idx & 3
setp.eq.s32 %p2, %r11, 0; // base_idx % 4 == 0 (alignment check; always true here since base_idx = tid * 4)
and.pred %p3, %p1, %p2; // can use vectorized access
@%p3 bra VECTORIZED_MUL;
// Scalar fallback: per-element bounds-checked loop for the tail where base_idx + 4 > count
mov.s32 %r9, 0; // i = 0
SCALAR_LOOP:
setp.ge.s32 %p4, %r9, %r6; // i >= elements_per_thread
@%p4 bra EXIT;
add.s32 %r10, %r7, %r9; // idx = base_idx + i
setp.ge.s32 %p5, %r10, %r1; // idx >= count
@%p5 bra EXIT;
// Calculate addresses
mul.wide.s32 %rd4, %r10, 4; // idx * sizeof(float), widened 32->64 bit
add.s64 %rd5, %rd1, %rd4; // A + offset
add.s64 %rd6, %rd2, %rd4; // B + offset
add.s64 %rd7, %rd3, %rd4; // C + offset
// Load, compute, store
ld.global.f32 %f1, [%rd5]; // A[idx]
ld.global.f32 %f2, [%rd6]; // B[idx]
mul.f32 %f3, %f1, %f2; // A[idx] * B[idx]
st.global.f32 [%rd7], %f3; // C[idx] = result
add.s32 %r9, %r9, 1; // i++
bra SCALAR_LOOP;
VECTORIZED_MUL:
// Use float4 vectorized operations
mul.wide.s32 %rd8, %r7, 4; // base_idx * sizeof(float), widened 32->64 bit
add.s64 %rd9, %rd1, %rd8; // A + offset
add.s64 %rd10, %rd2, %rd8; // B + offset
add.s64 %rd11, %rd3, %rd8; // C + offset
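// Note: v4.f32 accesses require 16-byte-aligned addresses. base_idx is a
// multiple of 4 floats by construction, so alignment holds as long as the
// A/B/C base pointers are themselves 16-byte aligned (cudaMalloc returns
// allocations aligned to at least 256 bytes).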
// Load float4 vectors
ld.global.v4.f32 {%f4, %f5, %f6, %f7}, [%rd9]; // A[base_idx:base_idx+4]
ld.global.v4.f32 {%f8, %f9, %f10, %f11}, [%rd10]; // B[base_idx:base_idx+4]
// Vectorized multiplication
mul.f32 %f12, %f4, %f8; // C[0] = A[0] * B[0]
mul.f32 %f13, %f5, %f9; // C[1] = A[1] * B[1]
mul.f32 %f14, %f6, %f10; // C[2] = A[2] * B[2]
mul.f32 %f15, %f7, %f11; // C[3] = A[3] * B[3]
// Store float4 vector
st.global.v4.f32 [%rd11], {%f12, %f13, %f14, %f15}; // C[base_idx:base_idx+4]
EXIT:
ret;
}
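//
// Host-side usage: a minimal sketch in C with the CUDA Driver API. The
// names ptxSource, dA, dB, dC, and n are assumptions for illustration;
// context creation and error checking are omitted.
//
//   CUmodule   mod;
//   CUfunction fn;
//   cuModuleLoadData(&mod, ptxSource);                     // this .ptx text
//   cuModuleGetFunction(&fn, mod, "cuda_elementwise_mul");
//   void *args[] = { &dA, &dB, &dC, &n };                  // CUdeviceptr x3, unsigned n
//   unsigned block = 256;
//   unsigned grid  = (n + 4 * block - 1) / (4 * block);    // 4 elements per thread
//   cuLaunchKernel(fn, grid, 1, 1, block, 1, 1, 0, NULL, args, NULL);
//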