//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! SwiGLU activation — `silu(gate) * up`.
//!
//! Fused element-wise activation used in every modern transformer MLP
//! (Llama 4, Qwen3 dense + MoE, Gemma, Mistral families): given two
//! equally-sized inputs `gate` and `up` (the two halves of an MLP's
//! `w_gate · x` and `w_up · x` outputs), produce
//!
//! ```text
//! out[i] = silu(gate[i]) * up[i]
//!        = (gate[i] * sigmoid(gate[i])) * up[i]
//! ```
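//!
//! As a host-side cross-check of the formula above, a minimal scalar
//! sketch (not the kernel itself; `silu_ref` and `swiglu_ref` are
//! hypothetical names used only for illustration):
//!
//! ```rust
//! /// Scalar SiLU: x * sigmoid(x), written as x / (1 + e^-x).
//! fn silu_ref(x: f32) -> f32 {
//!     x / (1.0 + (-x).exp())
//! }
//!
//! /// Element-wise SwiGLU over two equally sized slices.
//! fn swiglu_ref(gate: &[f32], up: &[f32]) -> Vec<f32> {
//!     assert_eq!(gate.len(), up.len(), "gate and up must have equal length");
//!     gate.iter().zip(up).map(|(&g, &u)| silu_ref(g) * u).collect()
//! }
//! ```
//!
//! For example, `swiglu_ref(&[0.0], &[3.0])` yields `[0.0]`, since
//! `silu(0) = 0`.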
//!
//! Existing baseline: two separate kernel launches. The first applies
//! `silu(gate)` elementwise (`mt_silu` in `unary.rs`); the second
//! multiplies by `up` (`mt_binary` mul). The intermediate `silu(gate)`
//! value is stored to device memory by the first launch and loaded
//! back by the second.
//!
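//! Fusion removes that full-tensor store and reload: the intermediate
//! value stays in registers, so the activation path moves three tensors
//! through global memory (read `gate`, read `up`, write `out`) instead
//! of five. At Qwen3-MoE expert intermediate=768 × prefill 512 tokens
//! the intermediate is 393,216 elements per layer per expert (~768 KiB
//! if elements are 2 bytes wide); across 48 layers × 8 active experts
//! the saved bandwidth adds up.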
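//!
//! A back-of-envelope sketch of that traffic, under a simple model that
//! counts each full-tensor read or write once and ignores caching. The
//! function names and the 2-byte element size are assumptions made for
//! illustration, not taken from the kernel:
//!
//! ```rust
//! /// Two launches: read gate + write tmp, then read tmp + read up + write out.
//! fn unfused_bytes(elems: usize, elem_size: usize) -> usize {
//!     5 * elems * elem_size
//! }
//!
//! /// One launch: read gate + read up + write out.
//! fn fused_bytes(elems: usize, elem_size: usize) -> usize {
//!     3 * elems * elem_size
//! }
//!
//! /// Bytes saved by keeping the intermediate in registers.
//! fn saved_bytes(elems: usize, elem_size: usize) -> usize {
//!     unfused_bytes(elems, elem_size) - fused_bytes(elems, elem_size)
//! }
//! ```
//!
//! With `elems = 768 * 512` and `elem_size = 2`, `saved_bytes` comes to
//! 1,572,864 bytes per expert per layer under this model.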
//!
//! MLX reference: `mx.fast.swiglu` lives in
//! `mlx/mlx/backend/metal/kernels/fast.metal` as a single launch with
//! `silu(g) * u` in the body. We mirror that pattern.
//!
//! ## Cross-kernel calling
//!
//! `mt_swiglu` calls `mt_silu` via the DSL cross-kernel call syntax
//! (just the kernel name). `KernelInlinePass` splices the silu body
//! inline before MSL emission — no extra memory round-trip, same code
//! quality as a manual inline, with a clear compositional structure
//! that future fusion passes can reason about.
//!
//! Type-efficiency: `g` and `u` are loaded and cast to f32 before the
//! call. `KernelInlinePass` replaces `mt_silu`'s input-param load with
//! the actual f32 arg, so all arithmetic stays in f32 regardless of T.
//! The silu result is never rounded back to T before the multiply, so
//! the silu path incurs no T→f32→T precision loss.
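//!
//! A host-side sketch of why the single rounding matters. It assumes a
//! hypothetical 16-bit storage type modeled as bf16 by truncation
//! (`to_bf16_trunc`); real kernels typically round to nearest, and none
//! of these names come from the DSL:
//!
//! ```rust
//! /// Assumed stand-in for storing to T: keep the top 16 bits of the f32.
//! fn to_bf16_trunc(x: f32) -> f32 {
//!     f32::from_bits(x.to_bits() & 0xFFFF_0000)
//! }
//!
//! /// SiLU computed entirely in f32.
//! fn silu_f32(x: f32) -> f32 {
//!     x / (1.0 + (-x).exp())
//! }
//!
//! /// Two-launch shape: the silu result is rounded to T, then multiplied.
//! fn swiglu_two_step(g: f32, u: f32) -> f32 {
//!     let tmp = to_bf16_trunc(silu_f32(g)); // intermediate stored as T
//!     to_bf16_trunc(tmp * u)
//! }
//!
//! /// Fused shape: all arithmetic in f32, one rounding at the final store.
//! fn swiglu_fused(g: f32, u: f32) -> f32 {
//!     to_bf16_trunc(silu_f32(g) * u)
//! }
//! ```
//!
//! Under this truncation model the fused result is never further from
//! the f32 product than the two-step result, because each extra
//! truncation can only shrink the magnitude.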
use ;