#include <metal_stdlib>
#include <metal_integer>
#include <metal_math>
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
using namespace metal;
namespace utils {

// Flat offset for a 1-D index: index * strides[0].
// The product is formed in size_t and converted to uint on return.
METAL_FUNC uint indices_to_idx_1(uint index, constant const size_t strides[1]) {
    const size_t offset = index * strides[0];
    return static_cast<uint>(offset);
}

// Flat offset for a 2-D position. indices.x is paired with the last stride
// (strides[1]) and indices.y with strides[0].
METAL_FUNC uint indices_to_idx_2(uint2 indices, constant const size_t strides[2]) {
    const size_t offset = indices.y * strides[0]
                        + indices.x * strides[1];
    return static_cast<uint>(offset);
}

// Flat offset for a 3-D position. indices.x / .y / .z are paired with
// strides[2] / strides[1] / strides[0] respectively.
METAL_FUNC uint indices_to_idx_3(uint3 indices, constant const size_t strides[3]) {
    const size_t offset = indices.z * strides[0]
                        + indices.y * strides[1]
                        + indices.x * strides[2];
    return static_cast<uint>(offset);
}

// Flat offset for a 4-D tensor addressed with a 3-component position.
// indices.x and indices.y address dims 3 and 2; indices.z carries dims 0 and 1
// folded together and is split using shape[1] (dim 1 = z % shape[1],
// dim 0 = z / shape[1]).
METAL_FUNC uint indices_to_idx_4(uint3 indices,
                                 constant const size_t shape[4],
                                 constant const size_t strides[4]) {
    const size_t dim1 = indices.z % shape[1];
    const size_t dim0 = indices.z / shape[1];
    const size_t offset = indices.x * strides[3]
                        + indices.y * strides[2]
                        + dim1 * strides[1]
                        + dim0 * strides[0];
    return static_cast<uint>(offset);
}

// Flat offset for a 5-D tensor addressed with a 3-component position.
// indices.x and indices.y address dims 4 and 3; indices.z carries dims 0, 1
// and 2 folded together and is peeled apart with shape[2] and then shape[1].
METAL_FUNC uint indices_to_idx_5(uint3 indices,
                                 constant const size_t shape[5],
                                 constant const size_t strides[5]) {
    const size_t dim2 = indices.z % shape[2];
    const size_t rest = indices.z / shape[2];
    const size_t dim1 = rest % shape[1];
    const size_t dim0 = rest / shape[1];
    const size_t offset = indices.x * strides[4]
                        + indices.y * strides[3]
                        + dim2 * strides[2]
                        + dim1 * strides[1]
                        + dim0 * strides[0];
    return static_cast<uint>(offset);
}

}
/*
* Based on code from:
* https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/binary_ops.h
*/
// Binary functor: returns x + y, converted to T.
struct Add {
    template <typename T>
    T operator()(T x, T y) {
        const T sum = x + y;
        return sum;
    }
};
// Binary functor: returns x / y, converted to T.
struct Div {
    template <typename T>
    T operator()(T x, T y) {
        const T quotient = x / y;
        return quotient;
    }
};
// Binary functor: returns x - y, converted to T.
struct Sub {
    template <typename T>
    T operator()(T x, T y) {
        const T difference = x - y;
        return difference;
    }
};
// Binary functor: returns x * y, converted to T.
struct Mul {
    template <typename T>
    T operator()(T x, T y) {
        const T product = x * y;
        return product;
    }
};
// Binary predicate: true when x == y.
struct Equals {
    template <typename T>
    bool operator()(T x, T y) {
        const bool same = (x == y);
        return same;
    }
};
// Binary predicate: true when x != y.
struct NotEquals {
    template <typename T>
    bool operator()(T x, T y) {
        const bool differs = (x != y);
        return differs;
    }
};
// Binary predicate: true when x > y.
struct Greater {
    template <typename T>
    bool operator()(T x, T y) {
        const bool is_greater = (x > y);
        return is_greater;
    }
};
// Binary predicate: true when x >= y.
struct GreaterEqual {
    template <typename T>
    bool operator()(T x, T y) {
        const bool at_least = (x >= y);
        return at_least;
    }
};
// Binary predicate: true when x < y.
struct Less {
    template <typename T>
    bool operator()(T x, T y) {
        const bool is_less = (x < y);
        return is_less;
    }
};
// Binary predicate: true when x <= y.
struct LessEqual {
    template <typename T>
    bool operator()(T x, T y) {
        const bool at_most = (x <= y);
        return at_most;
    }
};
// Binary functor: logical AND of x and y, converted to T.
struct And {
    template <typename T>
    T operator()(T x, T y) {
        const T both = x && y;
        return both;
    }
};
// Binary functor: logical OR of x and y, converted to T.
struct Or {
    template <typename T>
    T operator()(T x, T y) {
        const T either = x || y;
        return either;
    }
};
// Binary functor: returns x when x < y, otherwise y.
struct Min {
    template <typename T>
    T operator()(T x, T y) {
        if (x < y) {
            return x;
        }
        return y;
    }
};
// Binary functor: returns x when x > y, otherwise y.
struct Max {
    template <typename T>
    T operator()(T x, T y) {
        if (x > y) {
            return x;
        }
        return y;
    }
};
// Binary functor: bitwise AND of x and y, converted to T.
struct BitAnd {
    template <typename T>
    T operator()(T x, T y) {
        const T bits = x & y;
        return bits;
    }
};
// Binary functor: bitwise OR of x and y, converted to T.
struct BitOr {
    template <typename T>
    T operator()(T x, T y) {
        const T bits = x | y;
        return bits;
    }
};
// Binary functor: bitwise XOR of x and y, converted to T.
struct BitXor {
    template <typename T>
    T operator()(T x, T y) {
        const T bits = x ^ y;
        return bits;
    }
};
// Binary functor: raises `base` to the power `exp`.
//
// Two overloads are selected by whether T is an integral type:
//   * non-integral T forwards to metal::pow;
//   * integral T uses exponentiation by squaring.
struct Pow {
    // Non-integral overload: delegates to metal::pow.
    template <typename T>
    metal::enable_if_t<!metal::is_integral_v<T>, T>
    operator()(T base, T exp) {
        return metal::pow(base, exp);
    }

    // Integral overload: square-and-multiply.
    //
    // A negative exponent is handled up front. Without this, the loop below
    // never terminates for signed T, because an arithmetic right shift keeps
    // a negative `exp` negative (-1 >> 1 == -1). The value returned is
    // base^exp truncated towards zero: 1 for base == 1, +/-1 for base == -1
    // (sign chosen by the parity of exp), and 0 otherwise. base == 0 has no
    // defined result mathematically and also yields 0 here.
    // For unsigned T the `exp < 0` test is never true.
    template <typename T>
    metal::enable_if_t<metal::is_integral_v<T>, T>
    operator()(T base, T exp) {
        if (exp < T(0)) {
            if (base == T(1)) {
                return T(1);
            }
            if (base == T(-1)) {
                return (exp & T(1)) ? T(-1) : T(1);
            }
            return T(0);
        }
        T res = 1;
        while (exp) {
            // Multiply in the current power of `base` when this bit is set.
            if (exp & 1) {
                res *= base;
            }
            exp >>= 1;
            base *= base;
        }
        return res;
    }
};
// Explicitly instantiates the `bin_op_1row` kernel for Add/Sub/Div/Mul on
// float4 (host names "..._1row_f32") and half4 (host names "..._1row_f16").
// Must be expanded after `bin_op_1row` and `bin_op_1row_t` are declared.
//
// The last line deliberately has NO trailing backslash: a continuation there
// would splice whatever line follows the macro into its body.
// The half-precision Div kernel is exported as "div_1row_f16" (previously
// misspelled "dib_1row_f16"), matching the naming of every other entry.
#define INSTANTIATE_1ROW_BIN_OP()                                   \
    template [[host_name("bin_ops::add_1row_f32")]] [[kernel]]      \
    bin_op_1row_t bin_op_1row<float4, Add>;                         \
    template [[host_name("bin_ops::sub_1row_f32")]] [[kernel]]      \
    bin_op_1row_t bin_op_1row<float4, Sub>;                         \
    template [[host_name("bin_ops::div_1row_f32")]] [[kernel]]      \
    bin_op_1row_t bin_op_1row<float4, Div>;                         \
    template [[host_name("bin_ops::mul_1row_f32")]] [[kernel]]      \
    bin_op_1row_t bin_op_1row<float4, Mul>;                         \
    template [[host_name("bin_ops::add_1row_f16")]] [[kernel]]      \
    bin_op_1row_t bin_op_1row<half4, Add>;                          \
    template [[host_name("bin_ops::sub_1row_f16")]] [[kernel]]      \
    bin_op_1row_t bin_op_1row<half4, Sub>;                          \
    template [[host_name("bin_ops::div_1row_f16")]] [[kernel]]      \
    bin_op_1row_t bin_op_1row<half4, Div>;                          \
    template [[host_name("bin_ops::mul_1row_f16")]] [[kernel]]      \
    bin_op_1row_t bin_op_1row<half4, Mul>;
// Explicitly instantiates the generic `bin_op` kernel for one
// (input type, output type, functor) combination and exports it under the
// host name "bin_ops::<name>_<itname>".
//   name   - operation name used in the host-visible kernel name
//   op     - functor type applied to each pair of elements
//   itname - short input type tag used in the host-visible kernel name
//   itype  - element type of both inputs
//   otype  - element type of the output
// The last line deliberately has NO trailing backslash: a continuation there
// would splice whatever line follows the macro into its body.
#define INSTANTIATE_BIN_OP(name, op, itname, itype, otype)           \
    template [[host_name("bin_ops::" #name "_" #itname)]] [[kernel]] \
    bin_op_t bin_op<itype, otype, op>;
// Float inputs (f32, f16), output of the same type.
#define INSTANTIATE_FLOAT(name, op)                        \
    INSTANTIATE_BIN_OP(name, op, f32, float, float)        \
    INSTANTIATE_BIN_OP(name, op, f16, half, half)

// Float inputs (f32, f16), bool output.
#define INSTANTIATE_FLOAT_BOOL(name, op)                   \
    INSTANTIATE_BIN_OP(name, op, f32, float, bool)         \
    INSTANTIATE_BIN_OP(name, op, f16, half, bool)

// Unsigned and signed integer inputs of 8/16/32/64 bits, output of the same type.
#define INSTANTIATE_INTEGER(name, op)                      \
    INSTANTIATE_BIN_OP(name, op, u8, uint8_t, uint8_t)     \
    INSTANTIATE_BIN_OP(name, op, u16, uint16_t, uint16_t)  \
    INSTANTIATE_BIN_OP(name, op, u32, uint32_t, uint32_t)  \
    INSTANTIATE_BIN_OP(name, op, u64, uint64_t, uint64_t)  \
    INSTANTIATE_BIN_OP(name, op, i8, int8_t, int8_t)       \
    INSTANTIATE_BIN_OP(name, op, i16, int16_t, int16_t)    \
    INSTANTIATE_BIN_OP(name, op, i32, int32_t, int32_t)    \
    INSTANTIATE_BIN_OP(name, op, i64, int64_t, int64_t)

// Unsigned and signed integer inputs of 8/16/32/64 bits, bool output.
#define INSTANTIATE_INTEGER_BOOL(name, op)                 \
    INSTANTIATE_BIN_OP(name, op, u8, uint8_t, bool)        \
    INSTANTIATE_BIN_OP(name, op, u16, uint16_t, bool)      \
    INSTANTIATE_BIN_OP(name, op, u32, uint32_t, bool)      \
    INSTANTIATE_BIN_OP(name, op, u64, uint64_t, bool)      \
    INSTANTIATE_BIN_OP(name, op, i8, int8_t, bool)         \
    INSTANTIATE_BIN_OP(name, op, i16, int16_t, bool)       \
    INSTANTIATE_BIN_OP(name, op, i32, int32_t, bool)       \
    INSTANTIATE_BIN_OP(name, op, i64, int64_t, bool)

// Every float and integer input type, output of the same type.
#define INSTANTIATE_ALL_TYPES(name, op)                    \
    INSTANTIATE_FLOAT(name, op)                            \
    INSTANTIATE_INTEGER(name, op)

// Every float and integer input type, bool output.
#define INSTANTIATE_ALL_TYPES_BOOL(name, op)               \
    INSTANTIATE_FLOAT_BOOL(name, op)                       \
    INSTANTIATE_INTEGER_BOOL(name, op)
// Generic strided element-wise binary kernel over rank-4 operands.
//
// Template parameters:
//   In  - element type of both inputs
//   Out - element type of the output
//   Op  - functor type; Op()(a, b) is applied to each pair of elements
//
// Buffers:
//   0/3/6 - lhs / rhs / output data, passed untyped and cast to In / In / Out
//   1/4   - lhs_shape / rhs_shape (not read by this kernel)
//   2/5/8 - lhs / rhs / output strides, four entries each, in elements
//   7     - out_shape; only out_shape[3] (innermost extent) is read
//
// Work mapping: the threadgroup position supplies the three outer
// coordinates (z -> dim 0, y -> dim 1, x -> dim 2). Within a threadgroup,
// thread tpitg.x starts at innermost index tpitg.x and advances by ntg.x
// until out_shape[3] is reached.
//
// The output is addressed through out_strides[3] just as the inputs are
// addressed through their own innermost strides; for a densely packed
// output (out_strides[3] == 1) this equals the plain `out_idx + i` form.
template<typename In, typename Out, typename Op>
[[kernel]] void bin_op(device const void *lhs_b [[buffer(0)]],
                       constant const size_t * lhs_shape [[buffer(1)]],
                       constant const size_t * lhs_strides [[buffer(2)]],
                       device const void *rhs_b [[buffer(3)]],
                       constant const size_t * rhs_shape [[buffer(4)]],
                       constant const size_t * rhs_strides [[buffer(5)]],
                       device void *output_b [[buffer(6)]],
                       constant const size_t * out_shape [[buffer(7)]],
                       constant const size_t * out_strides [[buffer(8)]],
                       uint3 tgpig[[threadgroup_position_in_grid]],
                       ushort3 tpitg[[thread_position_in_threadgroup]],
                       ushort3 ntg[[threads_per_threadgroup]]) {
    device const In * lhs = (device const In *)lhs_b;
    device const In * rhs = (device const In *)rhs_b;
    device Out * output = (device Out *)output_b;

    // Base offsets of the row this threadgroup works on.
    auto lhs_idx = tgpig.z * lhs_strides[0] + tgpig.y * lhs_strides[1] + tgpig.x * lhs_strides[2];
    auto rhs_idx = tgpig.z * rhs_strides[0] + tgpig.y * rhs_strides[1] + tgpig.x * rhs_strides[2];
    auto out_idx = tgpig.z * out_strides[0] + tgpig.y * out_strides[1] + tgpig.x * out_strides[2];

    // Loop-invariant innermost extent and per-operand steps.
    const size_t inner_len = out_shape[3];
    const size_t lhs_step = lhs_strides[3];
    const size_t rhs_step = rhs_strides[3];
    const size_t out_step = out_strides[3];

    for (size_t i = tpitg.x; i < inner_len; i += ntg.x) {
        output[out_idx + i * out_step] = Op()(lhs[lhs_idx + i * lhs_step], rhs[rhs_idx + i * rhs_step]);
    }
}
// Function type shared by every bin_op instantiation; used by INSTANTIATE_BIN_OP.
typedef decltype(bin_op<float, float, Mul>) bin_op_t;
// Vectorised element-wise binary kernel in which rhs is re-read cyclically.
//
// Template parameters:
//   T4 - element type the buffers are viewed as (instantiated with float4 / half4)
//   Op - functor type; Op()(a, b) is applied to each pair of T4 values
//
// Buffers:
//   0 - lhs data, 1 - rhs data, 2 - output data (all viewed as T4)
//   3 - n; n / 4 is the period, in T4 elements, with which rhs repeats
//
// Each thread handles the single T4 element at its grid position `tpig`:
//   output[tpig] = Op()(lhs[tpig], rhs[tpig % (n / 4)])
// There is no upper bound check on `tpig`.
// NOTE(review): this presumably relies on the dispatch covering exactly the
// number of T4 elements in lhs/output - confirm against the host code.
template<typename T4, typename Op>
[[kernel]] void bin_op_1row(device const void *lhs_b [[buffer(0)]],
                            device const void *rhs_b [[buffer(1)]],
                            device void *output_b [[buffer(2)]],
                            device const size_t & n [[buffer(3)]],
                            uint tpig[[thread_position_in_grid]]) {
    device const T4 * lhs = (device const T4 *)lhs_b;
    device const T4 * rhs = (device const T4 *)rhs_b;
    device T4 * output = (device T4 *)output_b;
    const uint nb = n/4;
    // n < 4 gives nb == 0; `tpig % 0` has no defined result and would produce
    // an arbitrary rhs index, so do nothing in that case.
    if (nb == 0) {
        return;
    }
    output[tpig] = Op()(lhs[tpig], rhs[tpig % nb]);
}
// Function type shared by every bin_op_1row instantiation; used by INSTANTIATE_1ROW_BIN_OP.
typedef decltype(bin_op_1row<float4, Mul>) bin_op_1row_t;
// Arithmetic operations: float and integer inputs, output of the same type.
INSTANTIATE_ALL_TYPES(mul, Mul)
INSTANTIATE_ALL_TYPES(div, Div)
INSTANTIATE_ALL_TYPES(add, Add)
INSTANTIATE_ALL_TYPES(sub, Sub)
INSTANTIATE_ALL_TYPES(pow, Pow)
// Comparisons: float and integer inputs, bool output.
INSTANTIATE_ALL_TYPES_BOOL(lt, Less)
INSTANTIATE_ALL_TYPES_BOOL(gt, Greater)
INSTANTIATE_ALL_TYPES_BOOL(lte, LessEqual)
INSTANTIATE_ALL_TYPES_BOOL(gte, GreaterEqual)
INSTANTIATE_ALL_TYPES_BOOL(eq, Equals)
INSTANTIATE_ALL_TYPES_BOOL(ne, NotEquals)
// Min / max: float and integer inputs, output of the same type.
INSTANTIATE_ALL_TYPES(min, Min)
INSTANTIATE_ALL_TYPES(max, Max)
// Bitwise operations: integer types only.
INSTANTIATE_INTEGER(bitand, BitAnd)
INSTANTIATE_INTEGER(bitor, BitOr)
INSTANTIATE_INTEGER(bitxor, BitXor)
// Logical operations: bool inputs and bool output.
INSTANTIATE_BIN_OP(and, And, bool, bool, bool)
INSTANTIATE_BIN_OP(or, Or, bool, bool, bool)
// Vectorised float4 / half4 kernels for add, sub, div and mul.
INSTANTIATE_1ROW_BIN_OP()
// --- Iff (select) kernel ---
// Element-wise select over rank-5 operands: for every output position,
// copies the element of `then_values` when `cond` is true at that position
// and the element of `else_values` otherwise.
//
// Buffers:
//   0 - cond, 1 - then_values, 2 - else_values, 3 - out
//   4 - out_shape (five extents)
//   5..8 - strides of cond / then / else / out (five entries each, in elements)
//
// One thread per output element: `tpig` is interpreted as a linear index
// into out_shape with dimension 4 varying fastest. Threads whose index is at
// or beyond the product of the five extents return without doing anything.
template <typename T>
[[kernel]] void iff_generic(
    device const bool *cond [[buffer(0)]],
    device const T *then_values [[buffer(1)]],
    device const T *else_values [[buffer(2)]],
    device T *out [[buffer(3)]],
    constant const size_t *out_shape [[buffer(4)]],
    constant const size_t *cond_strides [[buffer(5)]],
    constant const size_t *then_strides [[buffer(6)]],
    constant const size_t *else_strides [[buffer(7)]],
    constant const size_t *out_strides [[buffer(8)]],
    uint tpig [[thread_position_in_grid]])
{
    const int rank = 5;

    // Number of output elements.
    size_t element_count = 1;
    for (int d = 0; d < rank; ++d) {
        element_count *= out_shape[d];
    }
    if (tpig >= element_count) {
        return;
    }

    // Split the linear index into per-dimension coordinates, last dimension
    // first; whatever remains after peeling dims 4..1 is the dim-0 coordinate.
    size_t coord[5];
    size_t remaining = tpig;
    for (int d = rank - 1; d > 0; --d) {
        coord[d] = remaining % out_shape[d];
        remaining /= out_shape[d];
    }
    coord[0] = remaining;

    // Strided offsets into the condition and output buffers.
    size_t cond_offset = 0;
    size_t out_offset = 0;
    for (int d = 0; d < rank; ++d) {
        cond_offset += coord[d] * cond_strides[d];
        out_offset += coord[d] * out_strides[d];
    }

    // Choose the source buffer and its strides once, then read from it.
    const bool pick = cond[cond_offset];
    device const T *source = pick ? then_values : else_values;
    constant const size_t *source_strides = pick ? then_strides : else_strides;

    size_t source_offset = 0;
    for (int d = 0; d < rank; ++d) {
        source_offset += coord[d] * source_strides[d];
    }

    out[out_offset] = source[source_offset];
}
// Explicitly instantiates `iff_generic` for element type `type` and exports
// it under the host name "bin_ops::iff_generic_<tname>".
#define INSTANTIATE_IFF(tname, type)                                  \
    template [[host_name("bin_ops::iff_generic_" #tname)]] [[kernel]] \
    void iff_generic<type>(                                           \
        device const bool*,                                           \
        device const type*,                                           \
        device const type*,                                           \
        device type*,                                                 \
        constant const size_t*,                                       \
        constant const size_t*,                                       \
        constant const size_t*,                                       \
        constant const size_t*,                                       \
        constant const size_t*,                                       \
        uint);
// iff_generic kernels for the unsigned integer element types of 8/16/32/64 bits.
INSTANTIATE_IFF(u8, uint8_t)
INSTANTIATE_IFF(u16, uint16_t)
INSTANTIATE_IFF(u32, uint32_t)
INSTANTIATE_IFF(u64, uint64_t)