#pragma once
// Locate the calling thread in a batched, element-wise launch.
//   *elem <- global x thread index (threadIdx.x + blockDim.x * blockIdx.x)
//   *b    <- batch index, taken from blockIdx.y
// Both outputs are always written. Returns true only when the x index is
// below nstates, letting callers skip the surplus threads of the last block.
__device__ inline bool batch_unary_setup(
    int* b, int* elem, int nstates
) {
    const int flat = threadIdx.x + blockDim.x * blockIdx.x;
    const int batch = blockIdx.y;
    *elem = flat;
    *b = batch;
    return flat < nstates;
}
// Compute flat offsets for a batched binary element-wise operation.
// The x thread index selects the element; blockIdx.y selects the batch.
//   *elem <- x thread index (always written)
//   *li   <- batch * lhs_stride + elem
//   *ri   <- (batch % rhs_nbatch) * rhs_stride + elem
// The rhs batch is taken modulo rhs_nbatch, so an rhs with fewer batches is
// reused cyclically; rhs_nbatch must therefore be nonzero.
// *li and *ri are written only when the function returns true (elem < nstates).
// NOTE(review): offsets are computed in 32-bit int arithmetic, so a large
// batch * stride can overflow.
__device__ inline bool batch_binary_setup(
    int* elem, int nstates,
    int* li, int lhs_stride,
    int* ri, int rhs_stride, int rhs_nbatch
) {
    const int e = threadIdx.x + blockDim.x * blockIdx.x;
    *elem = e;
    const bool in_range = e < nstates;
    if (in_range) {
        const int batch = blockIdx.y;
        const int rhs_batch = batch % rhs_nbatch;
        *li = lhs_stride * batch + e;
        *ri = rhs_stride * rhs_batch + e;
    }
    return in_range;
}
// Compute flat offsets for a batched three-operand element-wise operation.
// The x thread index selects the element; blockIdx.y selects the batch.
//   *elem <- x thread index (always written)
//   *li   <- batch * lhs_stride + elem
//   *ri   <- (batch % rhs_nbatch) * rhs_stride + elem
//   *oi   <- (batch % out_nbatch) * out_stride + elem
// The rhs and out batches are taken modulo their batch counts, so operands
// with fewer batches are reused cyclically; rhs_nbatch and out_nbatch must be
// nonzero. *li, *ri and *oi are written only when the function returns true
// (elem < nstates).
// NOTE(review): offsets are computed in 32-bit int arithmetic.
__device__ inline bool batch_ternary_setup(
    int* elem, int nstates,
    int* li, int lhs_stride,
    int* ri, int rhs_stride, int rhs_nbatch,
    int* oi, int out_stride, int out_nbatch
) {
    const int e = threadIdx.x + blockDim.x * blockIdx.x;
    *elem = e;
    const bool in_range = e < nstates;
    if (in_range) {
        const int batch = blockIdx.y;
        const int lhs_base = lhs_stride * batch;
        const int rhs_base = rhs_stride * (batch % rhs_nbatch);
        const int out_base = out_stride * (batch % out_nbatch);
        *li = lhs_base + e;
        *ri = rhs_base + e;
        *oi = out_base + e;
    }
    return in_range;
}
// Compute offsets for a batched indexed transfer between "self" and "other".
// The x thread index j selects an entry of `indices`; blockIdx.y selects the batch.
//   *j  <- x thread index (always written)
//   *si <- (batch % self_nbatch)  * self_stride  + j
//   *oi <- (batch % other_nbatch) * other_stride + indices[j]
// So self is addressed positionally and other through the index table.
// Batch counts are used as moduli and must be nonzero. indices[j] is not
// range-checked here.
// Returns false (leaving *si / *oi untouched) when j >= nindices.
__device__ inline bool batch_gather_scatter_setup(
    int* j, int nindices,
    int* si, int self_stride, int self_nbatch,
    int* oi, int other_stride, int other_nbatch,
    const int* __restrict__ indices
) {
    const int pos = threadIdx.x + blockDim.x * blockIdx.x;
    *j = pos;
    if (!(pos < nindices)) {
        return false;
    }
    const int batch = blockIdx.y;
    const int src = indices[pos];
    const int self_base = (batch % self_nbatch) * self_stride;
    const int other_base = (batch % other_nbatch) * other_stride;
    *si = self_base + pos;
    *oi = other_base + src;
    return true;
}
// Compute the "self" offset for a batched indexed write.
// The x thread index j selects an entry of `indices`; blockIdx.y selects the batch.
//   *j  <- x thread index (always written)
//   *si <- (batch % self_nbatch) * self_stride + indices[j]
// self_nbatch is used as a modulus and must be nonzero. indices[j] is not
// range-checked here.
// Returns false (leaving *si untouched) when j >= nindices.
__device__ inline bool batch_assign_at_setup(
    int* j, int nindices,
    int* si, int self_stride, int self_nbatch,
    const int* __restrict__ indices
) {
    const int pos = threadIdx.x + blockDim.x * blockIdx.x;
    *j = pos;
    const bool in_range = pos < nindices;
    if (in_range) {
        const int batch = blockIdx.y;
        const int target = indices[pos];
        *si = self_stride * (batch % self_nbatch) + target;
    }
    return in_range;
}
// Compute matching "self" and "other" offsets for a batched copy of selected
// positions. The same index value is applied to both operands.
// The x thread index j selects an entry of `indices`; blockIdx.y selects the batch.
//   *j  <- x thread index (always written)
//   *si <- (batch % self_nbatch)  * self_stride  + indices[j]
//   *oi <- (batch % other_nbatch) * other_stride + indices[j]
// Batch counts are used as moduli and must be nonzero. indices[j] is not
// range-checked here.
// Returns false (leaving *si / *oi untouched) when j >= nindices.
__device__ inline bool batch_copy_indices_setup(
    int* j, int nindices,
    int* si, int self_stride, int self_nbatch,
    int* oi, int other_stride, int other_nbatch,
    const int* __restrict__ indices
) {
    const int pos = threadIdx.x + blockDim.x * blockIdx.x;
    *j = pos;
    if (!(pos < nindices)) {
        return false;
    }
    const int batch = blockIdx.y;
    const int slot = indices[pos];
    *si = self_stride * (batch % self_nbatch) + slot;
    *oi = other_stride * (batch % other_nbatch) + slot;
    return true;
}
// Compute offsets for a batched index-to-index transfer.
// Entry j pairs position dst_indices[j] in "self" with src_indices[j] in "other".
// The x thread index j selects the pair; blockIdx.y selects the batch.
//   *j  <- x thread index (always written)
//   *si <- (batch % self_nbatch)  * self_stride  + dst_indices[j]
//   *oi <- (batch % other_nbatch) * other_stride + src_indices[j]
// Batch counts are used as moduli and must be nonzero. Neither index table is
// range-checked here.
// Returns false (leaving *si / *oi untouched) when j >= n.
__device__ inline bool batch_set_data_setup(
    int* j, int n,
    int* si, int self_stride, int self_nbatch,
    int* oi, int other_stride, int other_nbatch,
    const int* __restrict__ dst_indices,
    const int* __restrict__ src_indices
) {
    const int k = threadIdx.x + blockDim.x * blockIdx.x;
    *j = k;
    const bool in_range = k < n;
    if (in_range) {
        const int batch = blockIdx.y;
        const int dst_pos = dst_indices[k];
        const int src_pos = src_indices[k];
        const int self_base = self_stride * (batch % self_nbatch);
        const int other_base = other_stride * (batch % other_nbatch);
        *si = self_base + dst_pos;
        *oi = other_base + src_pos;
    }
    return in_range;
}
// Compute offsets linking diagonal entry i of a batched matrix to entry i of a
// batched vector. The x thread index selects the row i; blockIdx.y selects the batch.
//   *i  <- x thread index (always written)
//   *mi <- (batch % mat_nbatch)  * mat_stride  + i * nrows + i
//   *di <- (batch % diag_nbatch) * diag_stride + i
// The i * nrows + i term is the (i, i) entry of a matrix stored contiguously
// with leading dimension nrows. This presumably assumes square nrows x nrows
// matrices; confirm with callers.
// Batch counts are used as moduli and must be nonzero.
// Returns false (leaving *mi / *di untouched) when i >= nrows.
__device__ inline bool batch_diagonal_setup(
    int* i, int nrows,
    int* mi, int mat_stride, int mat_nbatch,
    int* di, int diag_stride, int diag_nbatch
) {
    const int row = threadIdx.x + blockDim.x * blockIdx.x;
    *i = row;
    if (!(row < nrows)) {
        return false;
    }
    const int batch = blockIdx.y;
    const int mat_base = (batch % mat_nbatch) * mat_stride;
    const int diag_base = (batch % diag_nbatch) * diag_stride;
    *mi = mat_base + row * nrows + row;
    *di = diag_base + row;
    return true;
}
// Compute offsets linking entry i of a batched column vector to a fixed column
// of a batched matrix. The x thread index selects the row i; blockIdx.y selects the batch.
//   *i  <- x thread index (always written)
//   *mi <- (batch % mat_nbatch) * mat_stride + column_index * nrows + i
//   *ci <- (batch % col_nbatch) * col_stride + i
// The column_index * nrows + i term addresses row i of column column_index if
// each matrix is stored column-major with nrows rows (assumption; confirm with
// callers). Batch counts are used as moduli and must be nonzero.
// Returns false (leaving *mi / *ci untouched) when i >= nrows.
__device__ inline bool batch_set_column_setup(
    int* i, int nrows,
    int* mi, int mat_stride, int mat_nbatch,
    int* ci, int col_stride, int col_nbatch,
    int column_index
) {
    const int row = threadIdx.x + blockDim.x * blockIdx.x;
    *i = row;
    const bool in_range = row < nrows;
    if (in_range) {
        const int batch = blockIdx.y;
        const int mat_base = mat_stride * (batch % mat_nbatch);
        const int col_base = col_stride * (batch % col_nbatch);
        *mi = mat_base + column_index * nrows + row;
        *ci = col_base + row;
    }
    return in_range;
}