#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// diag_mask_inf kernels
//------------------------------------------------------------------------------
kernel void kernel_diag_mask_inf(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int n_past
) {
src0 = (global float*)((global char*)src0 + offset0) dst = (global float*)((global char*)dst + offsetd)
int i02 = get_global_id(2) int i01 = get_global_id(1) int i00 = get_global_id(0)
if (i00 > n_past + i01) {
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY } else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00] }
}
kernel void kernel_diag_mask_inf_8(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd,
int ne00,
int ne01,
int n_past
) {
src0 = (global float4*)((global char*)src0 + offset0) dst = (global float4*)((global char*)dst + offsetd)
int i = 2*get_global_id(0)
dst[i+0] = src0[i+0] dst[i+1] = src0[i+1] int i4 = 4*i int i02 = i4/(ne00*ne01) int i01 = i4/(ne00) int i00 = i4 for (int k = 3 if (i00 + 4 + k <= n_past + i01) {
break }
(&dst[i+1])[k] = -INFINITY if (i00 + k > n_past + i01) {
(&dst[i])[k] = -INFINITY }
}
}