#include <metal_math>
#include <metal_stdlib>
using namespace metal;
#define NUM_SIMDGROUP 32
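// Strided index helpers. `indices.x` is the innermost (fastest-varying)
// dimension; strides are ordered outermost-first and are in elements. The
// 4-D variant takes a 3-D grid position and unpacks the two outermost
// dimensions from `indices.z` (z = dim0 * shape[1] + dim1).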
METAL_FUNC uint indices_to_idx_2(uint2 indices,
constant const size_t strides[2]) {
return indices.x * strides[1] + indices.y * strides[0];
}
METAL_FUNC uint indices_to_idx_3(uint3 indices,
constant const size_t strides[3]) {
return indices.x * strides[2] + indices.y * strides[1] +
indices.z * strides[0];
}
METAL_FUNC uint indices_to_idx_4(uint3 indices, constant const size_t shape[4],
constant const size_t strides[4]) {
auto idx = indices.x * strides[3] + indices.y * strides[2];
idx += (indices.z % shape[1]) * strides[1];
indices.z /= shape[1];
idx += indices.z * strides[0];
return idx;
}
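// Reduction operators. Each op supplies:
//   init        - identity value for the accumulator (constexpr constant)
//   operator()  - folds one element into a thread-local accumulator
//   simd_reduce - combines per-lane accumulators across a simdgroup;
//                 reduce_dim is only used by ops that normalize by the
//                 reduced length (e.g. MeanOfSquares)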
template <typename U> struct MeanOfSquares {
float simd_reduce(float val, size_t reduce_dim) {
return simd_sum(val) / static_cast<float>(reduce_dim);
}
static constexpr constant float init = 0.0;
// Operator
float operator()(float acc, U a) {
float a_f = static_cast<float>(a);
return acc + a_f * a_f;
}
};
template <typename U> struct Sum {
U simd_reduce(U val, size_t reduce_dim) { return simd_sum(val); }
static constexpr constant U init = U(0);
// Operator
U operator()(U acc, U a) { return acc + a; }
};
template <typename U> struct Min {
  U simd_reduce(U val, size_t reduce_dim) { return simd_min(val); }
static constexpr constant U init = metal::numeric_limits<U>::infinity();
// Operator
U operator()(U a, U b) { return a < b ? a : b; }
};
template <typename U> struct Max {
  U simd_reduce(U val, size_t reduce_dim) { return simd_max(val); }
static constexpr constant U init = -metal::numeric_limits<U>::infinity();
// Operator
U operator()(U a, U b) { return a > b ? a : b; }
};
template <typename U> struct Prod {
U simd_reduce(U val, size_t reduce_dim) { return simd_product(val); }
static constexpr constant U init = U(1);
// Operator
U operator()(U acc, U a) { return acc * a; }
};
template <typename U> struct All {
U simd_reduce(U val, size_t reduce_dim) { return simd_all(val); }
static constexpr constant U init = U(1);
// Operator
U operator()(U acc, U a) { return acc && a; }
};
template <typename U> struct Any {
U simd_reduce(U val, size_t reduce_dim) { return simd_any(val); }
static constexpr constant U init = U(0);
// Operator
U operator()(U acc, U a) { return acc || a; }
};
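// Reduces dim 1 of a 3-D tensor with one simdgroup per output element: each
// lane strides over the reduce dimension, then simd_reduce combines the
// lane-local partials. This assumes the threadgroup is exactly one simdgroup;
// with more, every simdgroup would redundantly compute and write the same
// value. A typical host dispatch (an assumption, not taken from this file):
//   threadgroups per grid:    (shape[2], 1, shape[0])
//   threads per threadgroup:  (threadExecutionWidth, 1, 1)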
template <typename F, typename Op>
[[kernel]] void reduce_nd3(device const void *input_b, device void *output_b,
constant const size_t input_shape[3],
constant const size_t input_strides[3],
constant const size_t output_strides[3],
uint3 tgpig [[threadgroup_position_in_grid]],
uint tiisg [[thread_index_in_simdgroup]],
uint tpsg [[threads_per_simdgroup]]) {
device const F *input = (device const F *)input_b;
device F *output = (device F *)output_b;
Op op = Op();
size_t reduce_dim = input_shape[1];
size_t out_idx = tgpig.x * output_strides[2] + tgpig.y * output_strides[1] +
tgpig.z * output_strides[0];
size_t base_in_idx =
tgpig.x * input_strides[2] + tgpig.z * input_strides[0];
auto partial_acc = Op::init;
for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
F el = input[base_in_idx + i * input_strides[1]];
partial_acc = op(partial_acc, el);
}
auto acc = op.simd_reduce(partial_acc, reduce_dim);
if (tiisg == 0) {
output[out_idx] = acc;
}
}
typedef decltype(reduce_nd3<float, Prod<float>>) reduce_nd3_t;
#define INSTANTIATE_REDUCE(name, op, tname, type) \
template [[host_name( \
"nn_ops::reduce_" #name \
"_nd3_" #tname)]] [[kernel]] reduce_nd3_t reduce_nd3<type, op<type>>;
INSTANTIATE_REDUCE(mean_of_squares, MeanOfSquares, f32, float)
INSTANTIATE_REDUCE(mean_of_squares, MeanOfSquares, f16, half)
INSTANTIATE_REDUCE(sum, Sum, f32, float)
INSTANTIATE_REDUCE(sum, Sum, f16, half)
INSTANTIATE_REDUCE(min, Min, f32, float)
INSTANTIATE_REDUCE(min, Min, f16, half)
INSTANTIATE_REDUCE(max, Max, f32, float)
INSTANTIATE_REDUCE(max, Max, f16, half)
INSTANTIATE_REDUCE(prod, Prod, f32, float)
INSTANTIATE_REDUCE(prod, Prod, f16, half)
INSTANTIATE_REDUCE(all, All, bool, char)
INSTANTIATE_REDUCE(any, Any, bool, char)
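// RMSNorm over dim 1: y = x * rsqrt(mean(x^2) + eps). Uses a two-stage
// reduction: simd_sum within each simdgroup, then per-simdgroup partials are
// combined through shmem_f32, which must hold at least 32 floats (simdgroup 0
// zeroes all 32 slots up front so unused entries read as 0).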
template <typename F>
[[kernel]] void rms_norm_nd3(device const void *input_b, constant void *eps_b,
device void *output_b,
constant const size_t shape[3],
constant const size_t strides[3],
threadgroup float *shmem_f32 [[threadgroup(0)]],
uint tgpig [[threadgroup_position_in_grid]],
ushort tpitg [[thread_position_in_threadgroup]],
ushort sgitg [[simdgroup_index_in_threadgroup]],
ushort tiisg [[thread_index_in_simdgroup]],
ushort ntg [[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const F *input = (device const F *)input_b;
float eps = ((constant float *)eps_b)[0];
device F *output = (device F *)output_b;
size_t dim = shape[1];
size_t base_idx =
(tgpig % shape[2]) * strides[2] + (tgpig / shape[2]) * strides[0];
float partial_acc = 0.0;
for (size_t i = tpitg; i < dim; i += ntg) {
float el = static_cast<float>(input[base_idx + i * strides[1]]);
partial_acc += el * el;
}
partial_acc = simd_sum(partial_acc);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = partial_acc;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
partial_acc = shmem_f32[tiisg];
partial_acc = simd_sum(partial_acc);
float mean_of_squares = partial_acc / dim;
float norm = metal::rsqrt(mean_of_squares + eps);
for (size_t i = tpitg; i < dim; i += ntg) {
auto idx = base_idx + i * strides[1];
output[idx] = input[idx] * norm;
}
}
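// Vectorized RMSNorm fast path for rows that are contiguous in the last
// dimension with a length divisible by 4 (n_div_4 = n / 4). One threadgroup
// handles one row; outer_stride is the byte offset between input rows, while
// the output is assumed densely packed (row i starts at element i * n).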
template <typename F, typename F4>
[[kernel]] void rms_norm_nd2_l4(device const char *input_b,
constant char *eps_b, device char *output_b,
constant const size_t &n,
constant const size_t &n_div_4,
constant const size_t &outer_stride,
threadgroup float *shmem_f32 [[threadgroup(0)]],
uint tgpig [[threadgroup_position_in_grid]],
ushort tpitg [[thread_position_in_threadgroup]],
ushort sgitg [[simdgroup_index_in_threadgroup]],
ushort tiisg [[thread_index_in_simdgroup]],
ushort ntg [[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const F4 *x = (device const F4 *)(input_b + tgpig * outer_stride);
float eps = ((constant float *)eps_b)[0];
float sumf = 0.0f;
// parallel sum
for (size_t i = tpitg; i < n_div_4; i += ntg) {
float4 el = static_cast<float4>(x[i]);
sumf += dot(el, el);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf / n;
const float scale = 1.0f / sqrt(mean + eps);
device F4 *y = (device F4 *)output_b + tgpig * n_div_4;
for (size_t i = tpitg; i < n_div_4; i += ntg) {
y[i] = x[i] * scale;
}
}
typedef decltype(rms_norm_nd3<float>) rms_norm_nd3_t;
typedef decltype(rms_norm_nd2_l4<float, float4>) rms_norm_nd2_l4_t;
template [[host_name(
"nn_ops::rms_norm_nd3_f32")]] [[kernel]] rms_norm_nd3_t rms_norm_nd3<float>;
template [[host_name(
"nn_ops::rms_norm_nd3_f16")]] [[kernel]] rms_norm_nd3_t rms_norm_nd3<half>;
template
[[host_name("nn_ops::rms_norm_nd2_l4_f32")]] [[kernel]] rms_norm_nd2_l4_t
rms_norm_nd2_l4<float, float4>;
template
[[host_name("nn_ops::rms_norm_nd2_l4_f16")]] [[kernel]] rms_norm_nd2_l4_t
rms_norm_nd2_l4<half, half4>;
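// Numerically stable logistic sigmoid: evaluates exp(-|x|) so the argument
// to exp is never positive, then reflects for negative inputs via
// sigmoid(x) = 1 - sigmoid(-x).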
struct Sigmoid {
template <typename T> T operator()(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
};
template <typename T>
[[kernel]] void silu(device const void *input_b [[buffer(0)]],
device void *output_b [[buffer(1)]],
uint tpig [[thread_position_in_grid]]) {
device const T *input = (device const T *)input_b;
device T *output = (device T *)output_b;
output[tpig] = Sigmoid()(static_cast<float>(input[tpig])) * input[tpig];
}
typedef decltype(silu<float>) silu_t;
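// Vectorized SiLU over 4-wide elements using the direct x / (1 + exp(-x))
// form. For large |x| the half-precision exp saturates to 0 or inf, and
// x / inf -> -0, so both tails still land on the correct limits.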
template <typename T4>
[[kernel]] void silu_4(device const void *input_b, device void *output_b,
uint tpig [[thread_position_in_grid]]) {
device const T4 *input = (device const T4 *)input_b;
device T4 *output = (device T4 *)output_b;
auto x = input[tpig];
output[tpig] = x / (1.0f + exp(-x));
}
typedef decltype(silu_4<float4>) silu_4_t;
template [[host_name("nn_ops::silu_f32")]] [[kernel]] silu_t silu<float>;
template [[host_name("nn_ops::silu_f16")]] [[kernel]] silu_t silu<half>;
template [[host_name("nn_ops::silu_4_f32")]] [[kernel]] silu_4_t silu_4<float4>;
template [[host_name("nn_ops::silu_4_f16")]] [[kernel]] silu_4_t silu_4<half4>;
template <typename T>
[[kernel]] void leaky_relu(device const void *input_b [[buffer(0)]],
device void *output_b [[buffer(1)]],
constant float &alpha [[buffer(2)]],
uint tpig [[thread_position_in_grid]]) {
device const T *input = (device const T *)input_b;
device T *output = (device T *)output_b;
T x = input[tpig];
output[tpig] = x >= T(0) ? x : T(alpha) * x;
}
typedef decltype(leaky_relu<float>) leaky_relu_t;
template [[host_name("nn_ops::leaky_relu_f32")]] [[kernel]] leaky_relu_t leaky_relu<float>;
template [[host_name("nn_ops::leaky_relu_f16")]] [[kernel]] leaky_relu_t leaky_relu<half>;
template <typename F>
[[kernel]] void softmax_nd3(device const void *input_b, device void *output_b,
constant const size_t shape[3],
constant const size_t strides[3],
uint3 tgpig [[threadgroup_position_in_grid]],
uint tiisg [[thread_index_in_simdgroup]],
uint tpsg [[threads_per_simdgroup]]) {
device const F *input = (device const F *)input_b;
device F *output = (device F *)output_b;
size_t dim = shape[1];
size_t base_idx = tgpig.x * strides[2] + tgpig.z * strides[0];
// Get max value on softmax dim
float partial_max = -INFINITY;
for (size_t i = tiisg; i < dim; i += tpsg) {
auto idx = base_idx + i * strides[1];
float el = static_cast<float>(input[idx]);
partial_max = max(partial_max, el);
}
float axis_max = simd_max(partial_max);
// Compute Sum(exp(x - max))
float partial_norm = 0;
for (size_t i = tiisg; i < dim; i += tpsg) {
auto idx = base_idx + i * strides[1];
float el = static_cast<float>(input[idx]);
float exp_el = fast::exp(el - axis_max);
partial_norm += exp_el;
output[idx] = static_cast<F>(exp_el);
}
float axis_norm = simd_sum(partial_norm);
float inv_axis_norm = 1.0 / axis_norm;
for (size_t i = tiisg; i < dim; i += tpsg) {
auto idx = base_idx + i * strides[1];
float exp_el = static_cast<float>(output[idx]);
output[idx] = static_cast<F>(exp_el * inv_axis_norm);
}
}
typedef decltype(softmax_nd3<float>) softmax_nd3_t;
template [[host_name(
"nn_ops::softmax_nd3_f32")]] [[kernel]] softmax_nd3_t softmax_nd3<float>;
template [[host_name(
"nn_ops::softmax_nd3_f16")]] [[kernel]] softmax_nd3_t softmax_nd3<half>;
template <typename F>
[[kernel]] void scaled_masked_softmax_nd5(
device const void *input_b, device const void *mask_b,
constant float *scale_b, device void *output_b,
constant const size_t shape[5], constant const size_t strides[5],
constant const size_t mask_strides[5], constant const size_t out_strides[5],
uint3 tgpig [[threadgroup_position_in_grid]],
uint tiisg [[thread_index_in_simdgroup]],
uint tpsg [[threads_per_simdgroup]],
uint3 tptg [[thread_position_in_threadgroup]],
uint3 tptgN [[threads_per_threadgroup]],
threadgroup float *tgmem [[threadgroup(0)]]) {
const uint tid = tptg.x;
const uint tg_sz = tptgN.x;
const uint sg_id = tid / tpsg;
const uint lane = tiisg;
// Grid is (rows, g, b * kh) == (shape[3], shape[2], shape[0] * shape[1])
const size_t row = (size_t)tgpig.x;
const size_t h = (size_t)tgpig.y;
const size_t z = (size_t)tgpig.z;
const size_t z0 = z / shape[1];
const size_t z1 = z % shape[1];
device const F *x = (device const F *)input_b;
device const F *mask = (device const F *)mask_b;
device F *out = (device F *)output_b;
const float scale = *scale_b;
x += row * strides[3] + h * strides[2] + z1 * strides[1] + z0 * strides[0];
out += row * out_strides[3] + h * out_strides[2] + z1 * out_strides[1] +
z0 * out_strides[0];
const bool has_mask = (mask_b != nullptr);
if (has_mask) {
mask += row * mask_strides[3] + h * mask_strides[2] +
z1 * mask_strides[1] + z0 * mask_strides[0];
}
  // Threadgroup scratch layout:
  //   tgmem[0..31]         -> buf_iw: one float per simdgroup (up to 32
  //                           simdgroups)
  //   tgmem[32..32+cols-1] -> vals: float[cols] staging for the row
  // Allocating nextPow2(cols) for vals is fine too.
threadgroup float *buf_iw = tgmem;
threadgroup float *vals = tgmem + 32;
const uint simd_size = tpsg; // usually 32 on Apple GPUs
const uint num_sg = (tg_sz + simd_size - 1u) / simd_size;
const size_t cols = shape[4];
// 1) Load (x*scale + mask) and compute max in float
float max_val = -INFINITY;
for (size_t col = (size_t)tid; col < cols; col += (size_t)tg_sz) {
const float xv = (float)x[col * strides[4]] * scale;
const float mv = has_mask ? (float)mask[col * mask_strides[4]] : 0.0f;
const float v = xv + mv;
vals[col] = v;
max_val = metal::max(max_val, v);
}
// reduce max across simdgroup
float sg_max = simd_max(max_val);
// write per-simdgroup max
if (lane == 0) {
buf_iw[sg_id] = sg_max;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// reduce across simdgroups using first simdgroup
if (sg_id == 0) {
float x0 = (lane < num_sg) ? buf_iw[lane] : -INFINITY;
float block_max = simd_max(x0);
if (lane == 0)
buf_iw[0] = block_max;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf_iw[0];
// 2) exp(vals - max) and sum
float sum = 0.0f;
for (size_t col = (size_t)tid; col < cols; col += (size_t)tg_sz) {
float e = exp(vals[col] - max_val);
vals[col] = e;
sum += e;
}
float sg_sum = simd_sum(sum);
if (lane == 0) {
buf_iw[sg_id] = sg_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sg_id == 0) {
float x0 = (lane < num_sg) ? buf_iw[lane] : 0.0f;
float block_sum = simd_sum(x0);
if (lane == 0)
buf_iw[0] = block_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf_iw[0];
const float inv_sum = 1.0f / sum;
// 3) write output
for (size_t col = (size_t)tid; col < cols; col += (size_t)tg_sz) {
float y = vals[col] * inv_sum;
out[col * out_strides[4]] = (F)y;
}
}
typedef decltype(scaled_masked_softmax_nd5<float>) scaled_masked_softmax_nd5_t;
template [[host_name("nn_ops::scaled_masked_softmax_nd5_"
"f32")]] [[kernel]] scaled_masked_softmax_nd5_t
scaled_masked_softmax_nd5<float>;
template [[host_name("nn_ops::scaled_masked_softmax_nd5_"
"f16")]] [[kernel]] scaled_masked_softmax_nd5_t
scaled_masked_softmax_nd5<half>;
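// Tanh approximation of GELU:
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))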
constant float GELU_COEF_A = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
template <typename F>
[[kernel]] void gelu_approx(device const void *input_b, device void *output_b,
uint tpig [[thread_position_in_grid]]) {
device const F *input = (device const F *)input_b;
device F *output = (device F *)output_b;
float x = static_cast<float>(input[tpig]);
  // metal::powr is only defined for non-negative bases, so form the cubic
  // term directly.
  float output_f32 =
      0.5f * x *
      (1.0f + precise::tanh(SQRT_2_OVER_PI * (x + GELU_COEF_A * x * x * x)));
output[tpig] = static_cast<F>(output_f32);
}
typedef decltype(gelu_approx<float>) gelu_approx_t;
template [[host_name(
"nn_ops::gelu_approx_f32")]] [[kernel]] gelu_approx_t gelu_approx<float>;
template [[host_name(
"nn_ops::gelu_approx_f16")]] [[kernel]] gelu_approx_t gelu_approx<half>;
template <typename F>
[[kernel]] void gelu_approx_fast(device const void *input_b,
device void *output_b,
uint tpig [[thread_position_in_grid]]) {
device const F *input = (device const F *)input_b;
device F *output = (device F *)output_b;
float x = static_cast<float>(input[tpig]);
  // Same tanh-based curve as gelu_approx, evaluated with the default
  // fast-math tanh rather than precise::tanh.
  float output_f32 =
      0.5f * x *
      (1.0f + tanh(SQRT_2_OVER_PI * (x + GELU_COEF_A * x * x * x)));
output[tpig] = static_cast<F>(output_f32);
}
typedef decltype(gelu_approx_fast<float>) gelu_approx_fast_t;
template
[[host_name("nn_ops::gelu_approx_fast_f32")]] [[kernel]] gelu_approx_fast_t
gelu_approx_fast<float>;
template
[[host_name("nn_ops::gelu_approx_fast_f16")]] [[kernel]] gelu_approx_fast_t
gelu_approx_fast<half>;
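// Rotary position embedding (half-rotation layout): element i of the last
// dimension is paired with element i + d/2, and each pair is rotated by the
// per-position cos/sin. The grid covers only the first half of the last
// dimension; each thread writes both halves of its pair.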
template <typename T>
[[kernel]] void apply_rope_nd2(device const void *input_b [[buffer(0)]],
device const void *cos_b [[buffer(1)]],
device const void *sin_b [[buffer(2)]],
device void *output_b [[buffer(3)]],
constant const size_t *shape [[buffer(4)]],
constant const size_t *strides [[buffer(5)]],
constant const size_t *cos_sin_strides
[[buffer(6)]],
constant const size_t *out_strides [[buffer(7)]],
uint2 tpig [[thread_position_in_grid]]) {
device const T *input = (device const T *)input_b;
device const T *cos = (device const T *)cos_b;
device const T *sin = (device const T *)sin_b;
device T *output = (device T *)output_b;
uint2 rotated_tpig = tpig;
rotated_tpig.x += shape[1] / 2;
auto idx = indices_to_idx_2(tpig, strides);
auto rot_idx = indices_to_idx_2(rotated_tpig, strides);
auto out_idx = indices_to_idx_2(tpig, out_strides);
auto out_rot_idx = indices_to_idx_2(rotated_tpig, out_strides);
auto cos_sin_idx = indices_to_idx_2(tpig, cos_sin_strides);
auto rot_cos_sin_idx = indices_to_idx_2(rotated_tpig, cos_sin_strides);
output[out_idx] =
input[idx] * cos[cos_sin_idx] - input[rot_idx] * sin[cos_sin_idx];
output[out_rot_idx] = input[rot_idx] * cos[rot_cos_sin_idx] +
input[idx] * sin[rot_cos_sin_idx];
}
template <typename T>
[[kernel]] void apply_rope_nd3(device const void *input_b [[buffer(0)]],
device const void *cos_b [[buffer(1)]],
device const void *sin_b [[buffer(2)]],
device void *output_b [[buffer(3)]],
constant const size_t *shape [[buffer(4)]],
constant const size_t *strides [[buffer(5)]],
constant const size_t *cos_sin_strides
[[buffer(6)]],
constant const size_t *out_strides [[buffer(7)]],
uint3 tpig [[thread_position_in_grid]]) {
device const T *input = (device const T *)input_b;
device const T *cos = (device const T *)cos_b;
device const T *sin = (device const T *)sin_b;
device T *output = (device T *)output_b;
uint3 rotated_tpig = tpig;
rotated_tpig.x += shape[2] / 2;
auto idx = indices_to_idx_3(tpig, strides);
auto rot_idx = indices_to_idx_3(rotated_tpig, strides);
auto out_idx = indices_to_idx_3(tpig, out_strides);
auto out_rot_idx = indices_to_idx_3(rotated_tpig, out_strides);
auto cos_sin_idx = indices_to_idx_3(tpig, cos_sin_strides);
auto rot_cos_sin_idx = indices_to_idx_3(rotated_tpig, cos_sin_strides);
output[out_idx] =
input[idx] * cos[cos_sin_idx] - input[rot_idx] * sin[cos_sin_idx];
output[out_rot_idx] = input[rot_idx] * cos[rot_cos_sin_idx] +
input[idx] * sin[rot_cos_sin_idx];
}
template <typename T>
[[kernel]] void apply_rope_nd4(device const void *input_b [[buffer(0)]],
device const void *cos_b [[buffer(1)]],
device const void *sin_b [[buffer(2)]],
device void *output_b [[buffer(3)]],
constant const size_t *shape [[buffer(4)]],
constant const size_t *strides [[buffer(5)]],
constant const size_t *cos_sin_strides
[[buffer(6)]],
constant const size_t *out_strides [[buffer(7)]],
uint3 tpig [[thread_position_in_grid]]) {
device const T *input = (device const T *)input_b;
device const T *cos = (device const T *)cos_b;
device const T *sin = (device const T *)sin_b;
device T *output = (device T *)output_b;
uint3 rotated_tpig = tpig;
rotated_tpig.x += shape[3] / 2;
auto idx = indices_to_idx_4(tpig, shape, strides);
auto rot_idx = indices_to_idx_4(rotated_tpig, shape, strides);
auto out_idx = indices_to_idx_4(tpig, shape, out_strides);
auto out_rot_idx = indices_to_idx_4(rotated_tpig, shape, out_strides);
auto cos_sin_idx = indices_to_idx_4(tpig, shape, cos_sin_strides);
auto rot_cos_sin_idx =
indices_to_idx_4(rotated_tpig, shape, cos_sin_strides);
output[out_idx] =
input[idx] * cos[cos_sin_idx] - input[rot_idx] * sin[cos_sin_idx];
output[out_rot_idx] = input[rot_idx] * cos[rot_cos_sin_idx] +
input[idx] * sin[rot_cos_sin_idx];
}
typedef decltype(apply_rope_nd2<float>) apply_rope_nd2_t;
typedef decltype(apply_rope_nd3<float>) apply_rope_nd3_t;
typedef decltype(apply_rope_nd4<float>) apply_rope_nd4_t;
template [[host_name("nn_ops::apply_rope_nd2_f32")]] [[kernel]] apply_rope_nd2_t
apply_rope_nd2<float>;
template [[host_name("nn_ops::apply_rope_nd3_f32")]] [[kernel]] apply_rope_nd3_t
apply_rope_nd3<float>;
template [[host_name("nn_ops::apply_rope_nd4_f32")]] [[kernel]] apply_rope_nd4_t
apply_rope_nd4<float>;
template [[host_name("nn_ops::apply_rope_nd2_f16")]] [[kernel]] apply_rope_nd2_t
apply_rope_nd2<half>;
template [[host_name("nn_ops::apply_rope_nd3_f16")]] [[kernel]] apply_rope_nd3_t
apply_rope_nd3<half>;
template [[host_name("nn_ops::apply_rope_nd4_f16")]] [[kernel]] apply_rope_nd4_t
apply_rope_nd4<half>;