use metaltile::kernel;
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
/// Floating-point element types that both benchmark specs in this file list
/// in their `dtypes` field.
const ALL_FLOAT_DTYPES: &[DType] = &[
    DType::F32,
    DType::F16,
    DType::BF16,
];
// Fused gated GELU: `out[i] = gelu_tanh(gate[i]) * up[i]`.
//
// Each invocation processes the single element at index `program_id::<0>()`.
// Both inputs are cast to f32 for the arithmetic and the product is cast back
// to `T` before it is stored. The activation is the tanh approximation of
// GELU: `0.5 * g * (1 + tanh(sqrt(2/pi) * (g + 0.044715 * g^3)))`.
//
// NOTE(review): the index is used without a bounds guard — presumably the
// launch covers exactly the element count; confirm against the dispatcher.
#[kernel]
pub fn mt_fused_gate_gelu<T>(gate: Tensor<T>, up: Tensor<T>, out: Tensor<T>) {
    let i = program_id::<0>();
    let gate_val = load(gate[i]).cast::<f32>();
    let up_val = load(up[i]).cast::<f32>();
    // Cube of the gate value, evaluated as (g * g) * g.
    let gate_sq = gate_val * gate_val;
    let gate_cubed = gate_sq * gate_val;
    // 0.7978845608 is sqrt(2 / pi).
    let tanh_arg = 0.7978845608f32 * (gate_val + 0.044715f32 * gate_cubed);
    let gelu = 0.5f32 * gate_val * (1.0f32 + tanh(tanh_arg));
    // Gate the `up` projection with the activated value.
    let gated = gelu * up_val;
    store(out[i], gated.cast::<T>());
}
// Fused clipped SwiGLU: `out[i] = g * sigmoid(1.702 * g) * (u + 1)`, where
// `g` and `u` are the gate and up values after limiting to the range [-7, 7].
//
// Each invocation processes the single element at index `program_id::<0>()`.
// Inputs are cast to f32 for the arithmetic; the result is cast back to `T`.
// The range description assumes `select(cond, a, b)` yields `a` when `cond`
// holds, which is how it is used below.
//
// NOTE(review): the gate is limited on both sides here; some reference
// clipped-SwiGLU formulations only cap the gate from above — confirm intent.
// NOTE(review): the index is used without a bounds guard — presumably the
// launch covers exactly the element count; confirm against the dispatcher.
#[kernel]
pub fn mt_fused_gate_clipped_swiglu<T>(gate: Tensor<T>, up: Tensor<T>, out: Tensor<T>) {
    let i = program_id::<0>();
    let gate_in = load(gate[i]).cast::<f32>();
    let up_in = load(up[i]).cast::<f32>();
    // Limit the gate: cap at +7 first, then floor at -7.
    // Negative constants are spelled `0.0 - x` throughout (presumably the
    // kernel DSL has no unary minus — keep this form unless verified).
    let gate_capped = select(gate_in > 7.0f32, 7.0f32, gate_in);
    let gate_val = select(gate_capped < (0.0f32 - 7.0f32), 0.0f32 - 7.0f32, gate_capped);
    // Same two-step limiting for the up value.
    let up_capped = select(up_in > 7.0f32, 7.0f32, up_in);
    let up_val = select(up_capped < (0.0f32 - 7.0f32), 0.0f32 - 7.0f32, up_capped);
    // Logistic sigmoid of 1.702 * g, written out as 1 / (1 + exp(-1.702 * g)).
    let neg_scaled = 0.0f32 - 1.702f32 * gate_val;
    let sigmoid = 1.0f32 / (1.0f32 + exp(neg_scaled));
    // The up branch is offset by +1 before gating.
    let up_shifted = up_val + 1.0f32;
    let activated = gate_val * sigmoid * up_shifted;
    store(out[i], activated.cast::<T>());
}
// Submits a `BenchSpec` for the gated tanh-approximation GELU kernel
// (`mt_fused_gate_gelu`) through `inventory`.
inventory::submit! {
    BenchSpec {
        // Which operation / variant this entry describes.
        op: "fused_gate_activation",
        subop: "gelu_approx",

        // Kernel under test: its name, IR entry point and launch settings.
        kernel_name: "mt_fused_gate_gelu",
        kernel_ir: mt_fused_gate_gelu::kernel_ir_for,
        kernel_mode: Some(KernelMode::Grid3D),
        dispatch: BenchDispatch::Generic,

        // Element types, shape list (left empty here) and tolerance value.
        dtypes: ALL_FLOAT_DTYPES,
        shapes: &[],
        tol: 1e-3,

        // No MLX source or pattern is attached to this entry.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Submits a `BenchSpec` for the clipped SwiGLU kernel
// (`mt_fused_gate_clipped_swiglu`) through `inventory`.
inventory::submit! {
    BenchSpec {
        // Which operation / variant this entry describes.
        op: "fused_gate_activation",
        subop: "clipped_swiglu",

        // Kernel under test: its name, IR entry point and launch settings.
        kernel_name: "mt_fused_gate_clipped_swiglu",
        kernel_ir: mt_fused_gate_clipped_swiglu::kernel_ir_for,
        kernel_mode: Some(KernelMode::Grid3D),
        dispatch: BenchDispatch::Generic,

        // Element types, shape list (left empty here) and tolerance value.
        dtypes: ALL_FLOAT_DTYPES,
        shapes: &[],
        tol: 1e-3,

        // No MLX source or pattern is attached to this entry.
        mlx_src: None,
        mlx_pattern: None,
    }
}