1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#pragma once
#include "../utils/type_cast.cuh"
#include "../utils/fast_divmod.cuh"
#include "../utils/index_calculator.cuh"
// Element-wise binary op over two contiguous inputs:
//   out[i] = op(cast<LHS, Output>(lhs[i]), cast<RHS, Output>(rhs[i]))  for every i in [0, n).
// cast<> comes from type_cast.cuh (presumably a value conversion to Output — confirm there).
//
// Launch: any 1D grid of 1D blocks. The grid-stride loop visits every i in [0, n) exactly once,
// whatever the grid size. No shared memory is used.
//
// out : output buffer with at least n elements
// lhs : left operand buffer with at least n elements
// rhs : right operand buffer with at least n elements
// n   : number of elements to compute
// op  : functor invoked as op(Output, Output); its result is stored into out[i]
template <typename LHS, typename RHS, typename Output, typename Op>
__device__ __forceinline__ void binary_contiguous(Output *out, const LHS *lhs, const RHS *rhs, int32_t n, Op op)
{
    // 64-bit index/stride: blockIdx.x * blockDim.x and i + stride can exceed INT32_MAX on large grids.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < n; i += stride)
    {
        out[i] = op(cast<LHS, Output>(lhs[i]), cast<RHS, Output>(rhs[i]));
    }
}
// Element-wise binary op with a scalar left operand and a contiguous right operand:
//   out[i] = op(cast<LHS, Output>(lhs), cast<RHS, Output>(rhs[i]))  for every i in [0, n).
// The scalar is converted once, before the loop.
//
// Launch: any 1D grid of 1D blocks (grid-stride loop). No shared memory is used.
//
// out : output buffer with at least n elements
// lhs : scalar left operand (passed by value)
// rhs : right operand buffer with at least n elements
// n   : number of elements to compute
// op  : functor invoked as op(Output, Output)
template <typename LHS, typename RHS, typename Output, typename Op>
__device__ __forceinline__ void binary_contiguous_lhs_scalar(Output *out, const LHS lhs, const RHS *rhs, int32_t n, Op op)
{
    const Output lhs_scalar = cast<LHS, Output>(lhs);
    // 64-bit index/stride avoids wrap-around for large grids / n close to INT32_MAX.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < n; i += stride)
    {
        out[i] = op(lhs_scalar, cast<RHS, Output>(rhs[i]));
    }
}
// Element-wise binary op with a contiguous left operand and a scalar right operand:
//   out[i] = op(cast<LHS, Output>(lhs[i]), cast<RHS, Output>(rhs))  for every i in [0, n).
// The scalar is converted once, before the loop.
//
// Launch: any 1D grid of 1D blocks (grid-stride loop). No shared memory is used.
//
// out : output buffer with at least n elements
// lhs : left operand buffer with at least n elements
// rhs : scalar right operand (passed by value)
// n   : number of elements to compute
// op  : functor invoked as op(Output, Output)
template <typename LHS, typename RHS, typename Output, typename Op>
__device__ __forceinline__ void binary_contiguous_rhs_scalar(Output *out, const LHS *lhs, const RHS rhs, int32_t n, Op op)
{
    const Output rhs_scalar = cast<RHS, Output>(rhs);
    // 64-bit index/stride avoids wrap-around for large grids / n close to INT32_MAX.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < n; i += stride)
    {
        out[i] = op(cast<LHS, Output>(lhs[i]), rhs_scalar);
    }
}
// Element-wise binary op over two (possibly strided) inputs, writing a contiguous output:
//   out[i] = op(cast<LHS, Output>(lhs element i), cast<RHS, Output>(rhs element i))  for i in [0, n).
//
// Operand element i is read via UncontiguousIndexCalculator<T>::get(i). Each calculator is built from
// the operand's base pointer, FastDivmod shape array, int32 stride array and rank.
// NOTE(review): presumably get() maps the linear index i onto the strided layout and returns the
// element value — confirm in index_calculator.cuh.
//
// Launch: any 1D grid of 1D blocks (grid-stride loop). No shared memory is used.
//
// out                   : contiguous output buffer with at least n elements
// lhs, rhs              : operand base pointers
// n                     : number of output elements
// op                    : functor invoked as op(Output, Output)
// lhs_shape, rhs_shape  : FastDivmod arrays (presumably one divisor per dimension — TODO confirm)
// lhs_strides, rhs_strides : per-dimension strides (units assumed to be elements — TODO confirm)
// lhs_ndim, rhs_ndim    : number of dimensions of each operand
template <typename LHS, typename RHS, typename Output, typename Op>
__device__ __forceinline__ void binary_uncontiguous(
    Output *out,
    LHS *lhs,
    RHS *rhs,
    int32_t n,
    Op op,
    FastDivmod *lhs_shape,
    int32_t *lhs_strides,
    FastDivmod *rhs_shape,
    int32_t *rhs_strides,
    int32_t lhs_ndim,
    int32_t rhs_ndim)
{
    UncontiguousIndexCalculator<LHS> lhs_idx_calculator = UncontiguousIndexCalculator<LHS>(lhs, lhs_shape, lhs_strides, lhs_ndim);
    UncontiguousIndexCalculator<RHS> rhs_idx_calculator = UncontiguousIndexCalculator<RHS>(rhs, rhs_shape, rhs_strides, rhs_ndim);
    // 64-bit index/stride avoids wrap-around for large grids / n close to INT32_MAX.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < n; i += stride)
    {
        // i < n <= INT32_MAX, so narrowing to the int index passed to get() is lossless.
        const int linear = static_cast<int>(i);
        out[i] = op(cast<LHS, Output>(lhs_idx_calculator.get(linear)), cast<RHS, Output>(rhs_idx_calculator.get(linear)));
    }
}
// Element-wise binary op with a scalar left operand and a (possibly strided) right operand,
// writing a contiguous output:
//   out[i] = op(cast<LHS, Output>(lhs), cast<RHS, Output>(rhs element i))  for i in [0, n).
//
// The rhs element i is read via UncontiguousIndexCalculator<RHS>::get(i).
// NOTE(review): presumably get() maps i onto the strided layout — confirm in index_calculator.cuh.
//
// Launch: any 1D grid of 1D blocks (grid-stride loop). No shared memory is used.
//
// out         : contiguous output buffer with at least n elements
// lhs         : scalar left operand (passed by value, converted once)
// rhs         : right operand base pointer
// n           : number of output elements
// op          : functor invoked as op(Output, Output)
// rhs_shape   : FastDivmod array (presumably one divisor per dimension — TODO confirm)
// rhs_strides : per-dimension strides (units assumed to be elements — TODO confirm)
// rhs_ndim    : number of dimensions of rhs
template <typename LHS, typename RHS, typename Output, typename Op>
__device__ __forceinline__ void binary_uncontiguous_lhs_scalar(
    Output *out,
    const LHS lhs,
    RHS *rhs,
    int32_t n,
    Op op,
    FastDivmod *rhs_shape,
    int32_t *rhs_strides,
    int32_t rhs_ndim)
{
    const Output lhs_scalar = cast<LHS, Output>(lhs);
    UncontiguousIndexCalculator<RHS> rhs_idx_calculator = UncontiguousIndexCalculator<RHS>(rhs, rhs_shape, rhs_strides, rhs_ndim);
    // 64-bit index/stride avoids wrap-around for large grids / n close to INT32_MAX.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < n; i += stride)
    {
        // i < n <= INT32_MAX, so narrowing to the int index passed to get() is lossless.
        out[i] = op(lhs_scalar, cast<RHS, Output>(rhs_idx_calculator.get(static_cast<int>(i))));
    }
}
// Element-wise binary op with a (possibly strided) left operand and a scalar right operand,
// writing a contiguous output:
//   out[i] = op(cast<LHS, Output>(lhs element i), cast<RHS, Output>(rhs))  for i in [0, n).
//
// The lhs element i is read via UncontiguousIndexCalculator<LHS>::get(i).
// NOTE(review): presumably get() maps i onto the strided layout — confirm in index_calculator.cuh.
//
// Launch: any 1D grid of 1D blocks (grid-stride loop). No shared memory is used.
//
// out         : contiguous output buffer with at least n elements
// lhs         : left operand base pointer
// rhs         : scalar right operand (passed by value, converted once)
// n           : number of output elements
// op          : functor invoked as op(Output, Output)
// lhs_shape   : FastDivmod array (presumably one divisor per dimension — TODO confirm)
// lhs_strides : per-dimension strides (units assumed to be elements — TODO confirm)
// lhs_ndim    : number of dimensions of lhs
template <typename LHS, typename RHS, typename Output, typename Op>
__device__ __forceinline__ void binary_uncontiguous_rhs_scalar(
    Output *out,
    LHS *lhs,
    const RHS rhs,
    int32_t n,
    Op op,
    FastDivmod *lhs_shape,
    int32_t *lhs_strides,
    int32_t lhs_ndim)
{
    UncontiguousIndexCalculator<LHS> lhs_idx_calculator = UncontiguousIndexCalculator<LHS>(lhs, lhs_shape, lhs_strides, lhs_ndim);
    const Output rhs_scalar = cast<RHS, Output>(rhs);
    // 64-bit index/stride avoids wrap-around for large grids / n close to INT32_MAX.
    const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
    for (long long i = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x; i < n; i += stride)
    {
        // i < n <= INT32_MAX, so narrowing to the int index passed to get() is lossless.
        out[i] = op(cast<LHS, Output>(lhs_idx_calculator.get(static_cast<int>(i))), rhs_scalar);
    }
}
// DEFINE_BINARY_KERNEL(func_name, lhs_type, rhs_type, op, promote)
//
// Expands to six extern "C" __global__ kernels. Each one forwards to the matching binary_* device
// helper in this file:
//   func_name_contiguous              (out, lhs, rhs, n)
//   func_name_contiguous_lhs_scalar   (out, lhs scalar, rhs, n)
//   func_name_contiguous_rhs_scalar   (out, lhs, rhs scalar, n)
//   func_name_uncontiguous            (out, lhs, lhs_shape, lhs_strides, rhs, rhs_shape, rhs_strides,
//                                      lhs_ndim, rhs_ndim, n)
//   func_name_uncontiguous_lhs_scalar (out, lhs scalar, rhs, rhs_shape, rhs_strides, rhs_ndim, n)
//   func_name_uncontiguous_rhs_scalar (out, lhs, rhs scalar, lhs_shape, lhs_strides, lhs_ndim, n)
//
// Macro parameters:
//   func_name : prefix pasted onto every kernel name
//   lhs_type  : element type of the left operand
//   rhs_type  : element type of the right operand
//   op        : class template instantiated as op<Output, Output> and default-constructed (op<...>{})
//   promote   : class template; promote<lhs_type, rhs_type>::Output is the computation/output type
//
// Buffers arrive as void* and are static_cast to the typed pointers. Nothing checks at runtime that
// they really hold lhs_type / rhs_type / Output elements.
// extern "C" disables name mangling. NOTE(review): presumably the host resolves these kernels by
// name, so the kernel names and argument order here (n last in every variant, shape/stride
// arguments grouped per operand) act as an ABI — do not reorder without updating the launcher.
// Kernels generated before the fix in the helpers' loop index: verify every element of out is written.
#define DEFINE_BINARY_KERNEL(func_name, lhs_type, rhs_type, op, promote) \
extern "C" __global__ void func_name##_contiguous(void *out, void *lhs, void *rhs, int32_t n) \
{ \
using Output = typename promote<lhs_type, rhs_type>::Output; \
binary_contiguous<lhs_type, rhs_type, Output, op<Output, Output>>(static_cast<Output *>(out), static_cast<const lhs_type *>(lhs), static_cast<const rhs_type *>(rhs), n, op<Output, Output>{}); \
} \
extern "C" __global__ void func_name##_contiguous_lhs_scalar(void *out, lhs_type lhs, void *rhs, int32_t n) \
{ \
using Output = typename promote<lhs_type, rhs_type>::Output; \
binary_contiguous_lhs_scalar<lhs_type, rhs_type, Output, op<Output, Output>>(static_cast<Output *>(out), lhs, static_cast<const rhs_type *>(rhs), n, op<Output, Output>{}); \
} \
extern "C" __global__ void func_name##_contiguous_rhs_scalar(void *out, void *lhs, rhs_type rhs, int32_t n) \
{ \
using Output = typename promote<lhs_type, rhs_type>::Output; \
binary_contiguous_rhs_scalar<lhs_type, rhs_type, Output, op<Output, Output>>(static_cast<Output *>(out), static_cast<const lhs_type *>(lhs), rhs, n, op<Output, Output>{}); \
} \
extern "C" __global__ void func_name##_uncontiguous(void *out, void *lhs, FastDivmod *lhs_shape, int32_t *lhs_strides, void *rhs, FastDivmod *rhs_shape, int32_t *rhs_strides, int32_t lhs_ndim, int32_t rhs_ndim, int32_t n) \
{ \
using Output = typename promote<lhs_type, rhs_type>::Output; \
binary_uncontiguous<lhs_type, rhs_type, Output, op<Output, Output>>(static_cast<Output *>(out), static_cast<lhs_type *>(lhs), static_cast<rhs_type *>(rhs), n, op<Output, Output>{}, lhs_shape, lhs_strides, rhs_shape, rhs_strides, lhs_ndim, rhs_ndim); \
} \
extern "C" __global__ void func_name##_uncontiguous_lhs_scalar(void *out, lhs_type lhs, void *rhs, FastDivmod *rhs_shape, int32_t *rhs_strides, int32_t rhs_ndim, int32_t n) \
{ \
using Output = typename promote<lhs_type, rhs_type>::Output; \
binary_uncontiguous_lhs_scalar<lhs_type, rhs_type, Output, op<Output, Output>>(static_cast<Output *>(out), lhs, static_cast<rhs_type *>(rhs), n, op<Output, Output>{}, rhs_shape, rhs_strides, rhs_ndim); \
} \
extern "C" __global__ void func_name##_uncontiguous_rhs_scalar(void *out, void *lhs, rhs_type rhs, FastDivmod *lhs_shape, int32_t *lhs_strides, int32_t lhs_ndim, int32_t n) \
{ \
using Output = typename promote<lhs_type, rhs_type>::Output; \
binary_uncontiguous_rhs_scalar<lhs_type, rhs_type, Output, op<Output, Output>>(static_cast<Output *>(out), static_cast<lhs_type *>(lhs), rhs, n, op<Output, Output>{}, lhs_shape, lhs_strides, lhs_ndim); \
}