metaltile-std 0.1.0

MetalTile kernel standard library — benchmark metadata and type definitions
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! SwiGLU activation — `silu(gate) * up`.
//!
//! Fused element-wise activation used in the gated MLPs of most current
//! transformer families (Llama 4, Qwen3 dense + MoE, Mistral; Gemma uses
//! the GELU-gated GeGLU variant instead): given two equally-sized inputs
//! `gate` and `up` (the outputs of the MLP's gate and up projections,
//! `w_gate · x` and `w_up · x`), produce
//!
//! ```text
//!   out[i] = silu(gate[i]) * up[i]
//!         = (gate[i] * sigmoid(gate[i])) * up[i]
//! ```
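//!
//! For reference, the formula maps directly onto a few lines of plain
//! Rust. This is a CPU sketch in f32 only. `silu` and `swiglu` are
//! illustrative helpers, not part of the metaltile API.
//!
//! ```rust
//! /// SiLU (a.k.a. swish): x * sigmoid(x), written as x / (1 + e^-x).
//! fn silu(x: f32) -> f32 {
//!     x / (1.0 + (-x).exp())
//! }
//!
//! /// Element-wise SwiGLU over equally-sized slices: out[i] = silu(gate[i]) * up[i].
//! fn swiglu(gate: &[f32], up: &[f32]) -> Vec<f32> {
//!     assert_eq!(gate.len(), up.len(), "gate and up must be equally sized");
//!     gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect()
//! }
//!
//! fn main() {
//!     let out = swiglu(&[0.0, 1.0, -2.0], &[3.0, 2.0, 1.0]);
//!     // silu(0) = 0, so the first output is exactly zero.
//!     assert_eq!(out[0], 0.0);
//!     // silu(1) = 1 / (1 + e^-1) ≈ 0.7310586, scaled by up = 2.
//!     assert!((out[1] - 1.462_117_2).abs() < 1e-5);
//!     // Negative gates are squashed toward zero (silu(-2) ≈ -0.2384).
//!     assert!(out[2] < 0.0 && out[2] > -0.25);
//! }
//! ```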
//!
//! Unfused baseline: two separate kernel launches. The first applies
//! `silu(gate)` elementwise (`mt_silu` in `unary.rs`) and writes the
//! result to device memory. The second reads it back and multiplies by
//! `up` (`mt_binary` mul). The intermediate `silu(gate)` therefore makes
//! a full store-and-reload round trip through device memory.
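//!
//! A CPU sketch of the two shapes, assuming f32 intermediates. The
//! baseline materializes `silu(gate)` in a buffer between two passes,
//! standing in for the device-memory round trip. The fused pass keeps it
//! in a local. `silu_launch`, `mul_launch` and `swiglu_fused` are
//! hypothetical stand-ins for the kernels, not metaltile functions.
//!
//! ```rust
//! fn silu(x: f32) -> f32 {
//!     x / (1.0 + (-x).exp())
//! }
//!
//! /// Baseline launch 1 (the `mt_silu` role): read `gate`, write the intermediate.
//! fn silu_launch(gate: &[f32]) -> Vec<f32> {
//!     gate.iter().map(|&g| silu(g)).collect()
//! }
//!
//! /// Baseline launch 2 (the `mt_binary` mul role): read the intermediate and `up`, write `out`.
//! fn mul_launch(s: &[f32], up: &[f32]) -> Vec<f32> {
//!     s.iter().zip(up).map(|(&a, &b)| a * b).collect()
//! }
//!
//! /// Fused: one pass, and the intermediate never leaves a local variable.
//! fn swiglu_fused(gate: &[f32], up: &[f32]) -> Vec<f32> {
//!     gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect()
//! }
//!
//! fn main() {
//!     let gate = [-3.0_f32, -0.5, 0.0, 0.5, 3.0];
//!     let up = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
//!     // Materialized buffer, playing the role of the device-memory round trip.
//!     let intermediate = silu_launch(&gate);
//!     let unfused = mul_launch(&intermediate, &up);
//!     // With an f32 intermediate both paths agree bit-for-bit; only memory traffic differs.
//!     assert_eq!(unfused, swiglu_fused(&gate, &up));
//! }
//! ```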
//!
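//! Fusion removes that round trip. The intermediate stays in registers,
//! so one full-tensor store and one full-tensor load disappear. Counting
//! full-tensor passes, the baseline makes five (read `gate`, write the
//! intermediate, read it back, read `up`, write `out`) and the fused
//! kernel makes three. That is a 40% cut in global memory traffic on the
//! activation path. At Qwen3-MoE's expert intermediate size of 768 ×
//! 512 prefill tokens, each tensor holds 393,216 elements (768 KiB at
//! 16-bit precision) per layer per expert. Across 48 layers × 8 active
//! experts, the avoided traffic can reach several hundred MiB per
//! prefill.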
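//!
//! Counting full-tensor passes makes the saving concrete. This sketch
//! assumes 16-bit activations (2 bytes per element) and that every
//! expert processes all 512 prefill tokens. With top-k routing each
//! expert sees only its share of tokens, so treat the total as an upper
//! bound. The constants and `bytes_saved` are illustrative, not
//! metaltile APIs.
//!
//! ```rust
//! /// Full-tensor passes over global memory in the unfused baseline:
//! /// the silu launch reads `gate` and writes the intermediate; the mul
//! /// launch reads the intermediate and `up`, then writes `out`.
//! const UNFUSED_PASSES: u64 = 5;
//! /// Fused kernel: read `gate`, read `up`, write `out`.
//! const FUSED_PASSES: u64 = 3;
//!
//! /// Bytes of global-memory traffic avoided by fusion for tensors of `elems` elements.
//! fn bytes_saved(elems: u64, bytes_per_elem: u64) -> u64 {
//!     (UNFUSED_PASSES - FUSED_PASSES) * elems * bytes_per_elem
//! }
//!
//! fn main() {
//!     // Qwen3-MoE expert intermediate = 768, prefill = 512 tokens.
//!     let elems: u64 = 768 * 512; // 393,216 elements per tensor
//!     // Assumption: f16/bf16 activations, 2 bytes per element.
//!     let per_expert = bytes_saved(elems, 2); // 1,572,864 B = 1.5 MiB
//!     let total = per_expert * 48 * 8; // 48 layers x 8 active experts
//!     println!("traffic cut: {}%", 100 * (UNFUSED_PASSES - FUSED_PASSES) / UNFUSED_PASSES);
//!     println!("saved per layer per expert: {per_expert} B");
//!     println!("saved per prefill (upper bound): {total} B"); // 603,979,776 B = 576 MiB
//! }
//! ```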
//!
//! MLX reference: `mx.fast.swiglu` lives in
//! `mlx/mlx/backend/metal/kernels/fast.metal` as a single launch with
//! `silu(g) * u` in the body. We mirror that pattern.
//!
//! ## Cross-kernel calling
//!
//! `mt_swiglu` calls `mt_silu` through the DSL's cross-kernel call syntax,
//! which is a plain call by kernel name. `KernelInlinePass` splices the
//! silu body inline before MSL emission. There is no extra memory round
//! trip and the generated code matches a manual inline. The call site
//! also keeps a clear compositional structure that future fusion passes
//! can reason about.
//!
//! Type efficiency: `g` and `u` are loaded and cast to f32 before the
//! call. `KernelInlinePass` replaces `mt_silu`'s input-parameter load
//! with the f32 argument, so all arithmetic stays in f32 regardless of
//! `T`. The only conversion back to `T` is the final store, so the silu
//! result never takes a lossy T→f32→T round trip.
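//!
//! To see what staying in f32 buys, this sketch models `T` as bf16. It
//! compares a single rounding on the final store (fused) against the
//! baseline's extra rounding of the stored intermediate. bf16 is emulated
//! by round-to-nearest-even on the f32 bit pattern to keep the example
//! dependency-free; a crate such as `half` provides a real `bf16` type.
//! NaN and overflow handling are omitted. `round_bf16`, `fused`,
//! `unfused` and `count_differing` are hypothetical helpers.
//!
//! ```rust
//! /// Round an f32 to the nearest bf16 value (ties to even), returned as f32.
//! /// Assumption: finite, non-overflowing inputs.
//! fn round_bf16(x: f32) -> f32 {
//!     let bits = x.to_bits();
//!     let bias = 0x7FFF + ((bits >> 16) & 1);
//!     f32::from_bits(bits.wrapping_add(bias) & 0xFFFF_0000)
//! }
//!
//! fn silu(x: f32) -> f32 {
//!     x / (1.0 + (-x).exp())
//! }
//!
//! /// Fused path: arithmetic in f32, one rounding to T on the final store.
//! fn fused(g: f32, u: f32) -> f32 {
//!     round_bf16(silu(g) * u)
//! }
//!
//! /// Unfused path: silu(g) is stored as T (first rounding), reloaded,
//! /// multiplied and stored again (second rounding).
//! fn unfused(g: f32, u: f32) -> f32 {
//!     round_bf16(round_bf16(silu(g)) * u)
//! }
//!
//! /// Count how many of `n` grid inputs give different outputs on the two paths.
//! fn count_differing(n: u32) -> u32 {
//!     let mut differing = 0;
//!     for i in 0..n {
//!         // Deterministic input grid, pre-rounded so inputs are exact bf16 values.
//!         let g = round_bf16(-4.0 + i as f32 * 0.008);
//!         let u = round_bf16(0.5 + i as f32 * 0.003);
//!         if fused(g, u) != unfused(g, u) {
//!             differing += 1;
//!         }
//!     }
//!     differing
//! }
//!
//! fn main() {
//!     let n = 1000;
//!     println!("{} of {n} outputs differ between fused and unfused paths", count_differing(n));
//! }
//! ```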

use metaltile::{bench_kernel, kernel};

#[bench_kernel(
    op="swiglu",
    subop="swiglu",
    class=Binary,
    input_a=Signed,
    input_b=Signed,
    tol=1e-3,
)]
#[kernel]
pub fn mt_swiglu<T>(gate: Tensor<T>, up: Tensor<T>, out: Tensor<T>) {
    let idx = tid;
    let g = load(gate[idx]).cast::<f32>();
    let u = load(up[idx]).cast::<f32>();
    // Cross-kernel call: KernelInlinePass splices mt_silu's scalar body
    // here. mt_silu's input-param load is replaced by g (already f32),
    // so silu runs in f32. Future fusion passes can identify the
    // (silu, mul) → swiglu composition pattern from this call site.
    let s = mt_silu(g);
    store(out[idx], (s * u).cast::<T>());
}