#include <metal_stdlib>
#include <metal_math>
using namespace metal;
// Utils
// Converts a linear element index into an offset for a strided layout.
//   idx      - linear index to decompose
//   num_dims - number of entries in `dims` / `strides`
//   dims     - extent of each dimension
//   strides  - stride of each dimension
// The index is peeled apart starting from the last dimension; each coordinate
// is multiplied by that dimension's stride and summed into the result.
METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint offset = 0;
    uint remaining = idx;
    for (uint step = 0; step < num_dims; ++step) {
        // Walk the dimensions from last to first.
        const uint axis = num_dims - 1 - step;
        offset += (remaining % dims[axis]) * strides[axis];
        remaining /= dims[axis];
    }
    return offset;
}
// Integer division of `x` by the compile-time divisor `Y`, rounded up.
template<uint Y>
constexpr uint div_ceil(uint x) {
    // Add one to the truncated quotient only when there is a remainder.
    return (x % Y == 0) ? (x / Y) : (x / Y + 1);
}
// Fully compile-time variant: X divided by Y, rounded up.
template<uint X, uint Y>
constexpr uint div_ceil() {
    // Add one to the truncated quotient only when there is a remainder.
    return (X % Y == 0) ? (X / Y) : (X / Y + 1);
}
// Number of elements of type T handled per thread: 8 divided by sizeof(T),
// rounded up (e.g. 2 for a 4-byte type, 1 for an 8-byte type).
template<typename T>
constexpr uint work_per_thread() {
    return div_ceil<8u, static_cast<uint>(sizeof(T))>();
}
// Kernels
// Applies the functor `unary` element-wise over a contiguous buffer.
//   dim    - number of elements to process
//   input  - source elements
//   output - destination; output[i] receives static_cast<U>(op(input[i]))
//   tid    - this thread's position in the grid
// Each thread starts at index `tid` and advances by ceil(dim / W), where W
// defaults to work_per_thread<T>().
template <typename T, typename U, typename unary, int W = work_per_thread<T>()>
[[kernel]] void unary_kernel(
    constant size_t &dim,
    device const T* input,
    device U* output,
    uint tid [[thread_position_in_grid]]
) {
    unary apply;
    const uint stride = div_ceil<W>(dim);
    #pragma clang loop unroll(full)
    for (uint pos = tid; pos < dim; pos += stride) {
        const T value = input[pos];
        output[pos] = static_cast<U>(apply(value));
    }
}
// Applies the functor `unary` element-wise, reading the input through a
// strided layout and writing the output at the linear index.
//   dim      - number of elements to process; threads with tid >= dim do nothing
//   num_dims - number of entries in `dims` / `strides`
//   dims     - extent of each dimension
//   strides  - input stride of each dimension
//   input    - source elements, indexed via get_strided_index
//   output   - destination, indexed by `tid`
//   tid      - this thread's position in the grid
template <typename T, typename U, typename unary>
[[kernel]] void unary_kernel_strided(
    constant size_t &dim,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides,
    constant const T *input,
    device U *output,
    uint tid [[ thread_position_in_grid ]]
) {
    if (tid < dim) {
        unary apply;
        const uint src = get_strided_index(tid, num_dims, dims, strides);
        output[tid] = static_cast<U>(apply(input[src]));
    }
}
// Fills a contiguous buffer with a single value.
//   dim    - number of elements to write
//   input  - the value to store
//   output - destination buffer
//   tid    - this thread's position in the grid
// Each thread starts at index `tid` and advances by ceil(dim / W), where W
// defaults to work_per_thread<T>().
template <typename T, int W = work_per_thread<T>()>
[[kernel]] void const_set(
    constant size_t &dim,
    device const T &input,
    device T *output,
    uint tid [[thread_position_in_grid]]
) {
    const uint stride = div_ceil<W>(dim);
    #pragma clang loop unroll(full)
    for (uint pos = tid; pos < dim; pos += stride) {
        output[pos] = input;
    }
}
// Fills a strided destination with a single value.
//   dim      - number of elements to write; threads with tid >= dim do nothing
//   num_dims - number of entries in `dims` / `strides`
//   dims     - extent of each dimension
//   strides  - output stride of each dimension
//   input    - the value to store
//   output   - destination, indexed via get_strided_index
//   tid      - this thread's position in the grid
template <typename T>
[[kernel]] void const_set_strided(
    constant size_t &dim,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides,
    device const T &input,
    device T *output,
    uint tid [[ thread_position_in_grid ]]
) {
    if (tid < dim) {
        const uint dst = get_strided_index(tid, num_dims, dims, strides);
        output[dst] = input;
    }
}
// Copies a d1 x d2 region between two buffers with independent row strides.
//   d1, d2 - extents of the region along idx.x and idx.y respectively
//   src_s  - distance between consecutive idx.x rows in `input`
//   dst_s  - distance between consecutive idx.x rows in `output`
//   idx    - this thread's 2D position in the grid; out-of-range threads exit
template <typename T>
[[kernel]] void copy2d(
    constant int64_t &d1,
    constant int64_t &d2,
    constant int64_t &src_s,
    constant int64_t &dst_s,
    device const T *input,
    device T *output,
    uint2 idx [[thread_position_in_grid]]
) {
    const int64_t row = idx.x;
    const int64_t col = idx.y;
    const bool in_bounds = (row < d1) && (col < d2);
    if (in_bounds) {
        const int64_t src_pos = row * src_s + col;
        const int64_t dst_pos = row * dst_s + col;
        output[dst_pos] = input[src_pos];
    }
}
// Unary functions
// Approximates the error function using Abramowitz & Stegun formula 7.1.26.
// The input is converted to float, the approximation is evaluated on |x|,
// and the sign is restored (erf is odd) before casting back to T.
template <typename T> METAL_FUNC T erf(T in) {
    // Polynomial coefficients and scaling constant of formula 7.1.26.
    constexpr const float a1 = 0.254829592;
    constexpr const float a2 = -0.284496736;
    constexpr const float a3 = 1.421413741;
    constexpr const float a4 = -1.453152027;
    constexpr const float a5 = 1.061405429;
    constexpr const float p = 0.3275911;

    const float value = static_cast<float>(in);
    // Remember the sign, then work with the magnitude only.
    const int sign = (value < 0) ? -1 : 1;
    const float x = fabs(value);

    const float t = 1.0/(1.0 + p*x);
    const float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
    return T(sign*y);
}
// Identity: returns its argument unchanged.
template <typename T>
METAL_FUNC T id(T in) {
    return in;
}
// GELU expressed through the error function:
//   x * (1 + erf(x / sqrt(2))) / 2
// The result is cast back to T.
template <typename T> METAL_FUNC T gelu_erf(T x) {
    const auto scaled = x * M_SQRT1_2_F;
    return static_cast<T>(x * (1 + erf(scaled)) / 2);
}
// GELU using the tanh approximation:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Inputs above 5 are returned unchanged. Intermediates are kept in T.
template <typename T> METAL_FUNC T gelu(T x) {
    if (x > 5) {
        return x;
    }
    const T cubed = x * x * x;
    const T inner = x + static_cast<T>(0.044715) * cubed;
    // M_2_SQRTPI_F * M_SQRT1_2_F == sqrt(2 / pi).
    const T scaled = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * inner);
    const T squashed = T(precise::tanh(scaled));
    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + squashed);
}
// Rectified linear unit: returns 0 for negative inputs and x otherwise.
// The comparison form matches the `urelu` functor (`x < 0 ? 0 : x`), so a NaN
// input is passed through unchanged.
// (The previous body was a copy of the tanh-based GELU approximation and
// therefore returned GELU values instead of ReLU values.)
template <typename T> METAL_FUNC T relu(T x) {
    return x < static_cast<T>(0) ? static_cast<T>(0) : x;
}
// Reciprocal: computes 1 / x and casts the result back to T.
template <typename T> METAL_FUNC T recip(T x) {
    const auto inverse = 1.0 / x;
    return static_cast<T>(inverse);
}
// Logistic sigmoid: 1 / (1 + exp(-x)), computed via recip() and cast to T.
template <typename T> METAL_FUNC T sigmoid(T x) {
    const auto denominator = 1 + exp(-x);
    return static_cast<T>(recip(denominator));
}
// Define unary ops
// define_unary_op(name, op) declares a functor struct `name` whose templated
// operator()(T x) evaluates the expression `op` (written in terms of `x`) and
// casts the result back to T. The functor is what the unary kernels
// instantiate as their `unary` template argument.
#define define_unary_op(name, op) \
struct name { \
template <typename T> \
METAL_FUNC T operator()(T x) { \
return static_cast<T>(op); \
} \
};
// Functor definitions. In each expression `x` is the input element; the macro
// casts the expression's result back to the element type.
define_unary_op(usqr, x * x);
define_unary_op(urecip, recip(x));
define_unary_op(uneg, -x);
define_unary_op(uid, x);
define_unary_op(ugelu, gelu(x));
define_unary_op(urelu, x < 0 ? 0 : x);
define_unary_op(usilu, x / (1 + exp(-x)));
define_unary_op(ugelu_erf, gelu_erf(x));
define_unary_op(usqrt, sqrt(x));
define_unary_op(ucos, cos(x));
define_unary_op(usin, sin(x));
define_unary_op(uexp, exp(x));
define_unary_op(ulog, log(x));
// abs is evaluated on the float-converted value; the macro casts it back.
define_unary_op(uabs, abs(static_cast<float>(x)));
define_unary_op(uceil, ceil(x));
define_unary_op(ufloor, floor(x));
define_unary_op(uround, round(x));
define_unary_op(uerf, erf(x));
define_unary_op(usign, sign(x));
define_unary_op(usigmoid, sigmoid(x));
// precise::tanh is used here because tanh may create NaN on large values
// (e.g. 45) rather than outputting 1.
// This has been an issue for the encodec example.
define_unary_op(utanh, precise::tanh(x));
// Macros to help initialize kernels
// init_kernel(name, func, ...) explicitly instantiates the kernel template
// `func` with the template arguments given in `...` and exposes that
// instantiation under the host-visible name `name` (a string literal).
#define init_kernel(name, func, ...) \
template [[host_name(name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
// init_unary(op_name, unary_op, tname, t) instantiates both unary kernels for
// element type `t` (input and output types are the same) with functor
// `unary_op`. The host names are "<op_name>_<tname>" for the contiguous kernel
// and "<op_name>_<tname>_strided" for the strided kernel.
#define init_unary(op_name, unary_op, tname, t) \
init_kernel(#op_name "_" #tname, unary_kernel, t, t, unary_op) \
init_kernel(#op_name "_" #tname "_strided", unary_kernel_strided, t, t, unary_op)
// init_unary_float(op_name, unary_op) instantiates the unary kernels for the
// floating point types: f32 (float) and f16 (half) always, plus bf16 (bfloat)
// when __HAVE_BFLOAT__ is defined.
#if defined(__HAVE_BFLOAT__)
#define init_unary_float(op_name, unary_op) \
init_unary(op_name, unary_op, f32, float) \
init_unary(op_name, unary_op, f16, half) \
init_unary(op_name, unary_op, bf16, bfloat)
#else
#define init_unary_float(op_name, unary_op) \
init_unary(op_name, unary_op, f32, float) \
init_unary(op_name, unary_op, f16, half)
#endif
// init_copy2d(tname, t) instantiates copy2d for element type `t` under the
// host name "copy2d_<tname>".
#define init_copy2d(tname, t) \
init_kernel("copy2d_" #tname, copy2d, t)
// init_const_set(tname, t) instantiates const_set and const_set_strided for
// element type `t` under the host names "const_set_<tname>" and
// "const_set_<tname>_strided".
#define init_const_set(tname, t) \
init_kernel("const_set_" #tname, const_set, t) \
init_kernel("const_set_" #tname "_strided", const_set_strided, t)
// Initialize all unary kernels for floating point types.
// Each line produces "<op>_<type>" and "<op>_<type>_strided" kernels for every
// floating point type covered by init_unary_float; the first argument is the
// op name used in the host-visible kernel name, the second is the functor.
init_unary_float(gelu_erf, ugelu_erf);
init_unary_float(sqrt, usqrt);
init_unary_float(sqr, usqr);
init_unary_float(neg, uneg);
init_unary_float(recip, urecip);
init_unary_float(copy, uid);
init_unary_float(silu, usilu);
init_unary_float(gelu, ugelu);
init_unary_float(relu, urelu);
init_unary_float(cos, ucos);
init_unary_float(sin, usin);
init_unary_float(exp, uexp);
init_unary_float(log, ulog);
init_unary_float(abs, uabs);
init_unary_float(ceil, uceil);
init_unary_float(floor, ufloor);
init_unary_float(round, uround);
init_unary_float(erf, uerf);
init_unary_float(sign, usign);
init_unary_float(sigmoid, usigmoid);
init_unary_float(tanh, utanh);
// Initialize copy2d kernels for f32 and f16.
init_copy2d(f32, float);
init_copy2d(f16, half);
// Initialize const_set kernels (contiguous and strided) for f32 and f16.
init_const_set(f32, float);
init_const_set(f16, half);
// bf16 variants of copy2d and const_set are only instantiated when
// __HAVE_BFLOAT__ is defined.
#if defined(__HAVE_BFLOAT__)
init_copy2d(bf16, bfloat);
init_const_set(bf16, bfloat);
#endif
// Initialize unary kernels for integer dtypes.
// Only the identity functor (`uid`, exposed as "copy") is instantiated for
// integer types, alongside the copy2d and const_set kernels.
init_unary(copy, uid, u8, uint8_t);
init_unary(copy, uid, u32, uint32_t);
init_copy2d(u8, uint8_t);
init_copy2d(u32, uint32_t);
init_const_set(u8, uint8_t);
init_const_set(u32, uint32_t);
// The i64 kernels are only instantiated when __METAL_VERSION__ >= 220.
#if __METAL_VERSION__ >= 220
init_unary(copy, uid, i64, int64_t);
init_copy2d(i64, int64_t);
init_const_set(i64, int64_t);
#endif