/// Metal source for `gemm_fp8_e4m3`: batched (prefill) product of an FP8
/// E4M3FN block-quantized weight matrix with `batch_size` f32 input columns.
///
/// Weight layout (buffer 0): each row is `k / 32` blocks of 34 bytes — 32
/// E4M3FN codes followed by a little-endian f16 block scale at bytes 32..34.
/// The NaN codes 0x7F / 0xFF decode to 0. `k` is assumed to be a multiple of
/// 32; any `k % 32` tail is ignored.
///
/// Buffers: 1 = inputs (column `c` at `inputs[c * k ..]`), 2 = outputs
/// (column `c`, row `r` at `outputs[c * n_rows + r]`), 3 = `n_rows`,
/// 4 = `batch_size`, 5 = `k`.
///
/// Dispatch: one simdgroup per output row (`row = tgid * 8 + sgid`, so 8
/// simdgroups per threadgroup are expected); lanes stride over blocks by 32,
/// which presumes 32-wide simdgroups. Columns are processed in tiles of 8, and
/// each block's 32 weights are decoded once per tile rather than once per
/// column.
///
/// NOTE(review): results are *added* into `outputs` (`+=`), unlike the
/// residual, SwiGLU and GEMV variants, which overwrite. Callers presumably
/// zero (or pre-fill) `outputs` first — confirm.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_GEMM_FP8_E4M3_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
// FP8 E4M3FN decode (bias=7, no infinity; NaN patterns 0x7F/0xFF → 0).
static inline float pf_fp8_e4m3_to_float(uchar b) {
    if (b == 0x7Fu || b == 0xFFu) return 0.0f;
    const uint sign = (uint(b) >> 7u) & 1u;
    const uint exp = (uint(b) >> 3u) & 15u;
    const uint mant = uint(b) & 7u;
    float val;
    if (exp == 0u) {
        val = float(mant) * (1.0f / 8.0f) * (1.0f / 64.0f);
    } else {
        const uint bits = ((exp - 7u + 127u) << 23u) | (mant << 20u);
        val = as_type<float>(bits);
    }
    return sign ? -val : val;
}
kernel void gemm_fp8_e4m3(
    device const uchar* blocks_raw [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& n_rows [[buffer(3)]],
    constant uint& batch_size [[buffer(4)]],
    constant uint& k [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint sgid [[simdgroup_index_in_threadgroup]],
    uint lane [[thread_index_in_simdgroup]])
{
    const uint row = tgid * 8u + sgid;
    if (row >= n_rows) return;
    const uint blocks_per_row = k >> 5u;
    for (uint col_base = 0u; col_base < batch_size; col_base += 8u) {
        const uint cols_remaining = batch_size - col_base;
        const uint cols = cols_remaining < 8u ? cols_remaining : 8u;
        float col_sums[8] = {0,0,0,0,0,0,0,0};
        for (uint b = lane; b < blocks_per_row; b += 32u) {
            const uint block_idx = row * blocks_per_row + b;
            const uint base_byte = block_idx * 34u;
            const ushort scale_bits =
                ushort(blocks_raw[base_byte + 32u])
                | (ushort(blocks_raw[base_byte + 33u]) << 8u);
            const float scale = float(as_type<half>(scale_bits));
            // Decode the block's 32 weights once; the decode does not depend
            // on the column, so reuse it for every column of the tile.
            float wdec[32];
            for (uint w = 0u; w < 32u; ++w) {
                wdec[w] = pf_fp8_e4m3_to_float(blocks_raw[base_byte + w]);
            }
            const uint inp_base = b * 32u;
            for (uint cc = 0u; cc < cols; cc++) {
                device const float* xbase = inputs + (col_base + cc) * k + inp_base;
                float bsum = 0.0f;
                for (uint w = 0u; w < 32u; ++w) {
                    bsum += wdec[w] * xbase[w];
                }
                col_sums[cc] += scale * bsum;
            }
        }
        for (uint cc = 0u; cc < cols; cc++) {
            float row_sum = simd_sum(col_sums[cc]);
            if (lane == 0u) outputs[(col_base + cc) * n_rows + row] += row_sum;
        }
    }
}
"#;
/// Metal source for `gemm_fp8_e4m3_residual`: batched (prefill) product of an
/// FP8 E4M3FN block-quantized weight matrix with `batch_size` f32 input
/// columns, fused with a residual add:
/// `outputs[c * n_rows + r] = residual[c * n_rows + r] + (W_r · x_c)`.
///
/// Weight layout (buffer 0): each row is `k / 32` blocks of 34 bytes — 32
/// E4M3FN codes followed by a little-endian f16 block scale at bytes 32..34.
/// The NaN codes 0x7F / 0xFF decode to 0. `k` is assumed to be a multiple of
/// 32; any `k % 32` tail is ignored.
///
/// Buffers: 1 = inputs (column `c` at `inputs[c * k ..]`), 2 = outputs,
/// 3 = `n_rows`, 4 = `batch_size`, 5 = `k`, 6 = residual (same indexing as
/// outputs). Outputs are overwritten.
///
/// Dispatch: one simdgroup per output row (`row = tgid * 8 + sgid`, so 8
/// simdgroups per threadgroup are expected); lanes stride over blocks by 32,
/// which presumes 32-wide simdgroups. Columns are processed in tiles of 8, and
/// each block's 32 weights are decoded once per tile rather than once per
/// column.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_GEMM_FP8_E4M3_RESIDUAL_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
// FP8 E4M3FN decode (bias=7, no infinity; NaN patterns 0x7F/0xFF → 0).
static inline float pf_fp8_e4m3_to_float(uchar b) {
    if (b == 0x7Fu || b == 0xFFu) return 0.0f;
    const uint sign = (uint(b) >> 7u) & 1u;
    const uint exp = (uint(b) >> 3u) & 15u;
    const uint mant = uint(b) & 7u;
    float val;
    if (exp == 0u) {
        val = float(mant) * (1.0f / 8.0f) * (1.0f / 64.0f);
    } else {
        const uint bits = ((exp - 7u + 127u) << 23u) | (mant << 20u);
        val = as_type<float>(bits);
    }
    return sign ? -val : val;
}
kernel void gemm_fp8_e4m3_residual(
    device const uchar* blocks_raw [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& n_rows [[buffer(3)]],
    constant uint& batch_size [[buffer(4)]],
    constant uint& k [[buffer(5)]],
    device const float* residual [[buffer(6)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint sgid [[simdgroup_index_in_threadgroup]],
    uint lane [[thread_index_in_simdgroup]])
{
    const uint row = tgid * 8u + sgid;
    if (row >= n_rows) return;
    const uint blocks_per_row = k >> 5u;
    for (uint col_base = 0u; col_base < batch_size; col_base += 8u) {
        const uint cols_remaining = batch_size - col_base;
        const uint cols = cols_remaining < 8u ? cols_remaining : 8u;
        float col_sums[8] = {0,0,0,0,0,0,0,0};
        for (uint b = lane; b < blocks_per_row; b += 32u) {
            const uint block_idx = row * blocks_per_row + b;
            const uint base_byte = block_idx * 34u;
            const ushort scale_bits =
                ushort(blocks_raw[base_byte + 32u])
                | (ushort(blocks_raw[base_byte + 33u]) << 8u);
            const float scale = float(as_type<half>(scale_bits));
            // Decode the block's 32 weights once; the decode does not depend
            // on the column, so reuse it for every column of the tile.
            float wdec[32];
            for (uint w = 0u; w < 32u; ++w) {
                wdec[w] = pf_fp8_e4m3_to_float(blocks_raw[base_byte + w]);
            }
            const uint inp_base = b * 32u;
            for (uint cc = 0u; cc < cols; cc++) {
                device const float* xbase = inputs + (col_base + cc) * k + inp_base;
                float bsum = 0.0f;
                for (uint w = 0u; w < 32u; ++w) {
                    bsum += wdec[w] * xbase[w];
                }
                col_sums[cc] += scale * bsum;
            }
        }
        for (uint cc = 0u; cc < cols; cc++) {
            float row_sum = simd_sum(col_sums[cc]);
            if (lane == 0u) {
                const uint idx = (col_base + cc) * n_rows + row;
                outputs[idx] = residual[idx] + row_sum;
            }
        }
    }
}
"#;
/// Metal source for `fused_gate_up_swiglu_gemm_fp8_e4m3`: batched (prefill)
/// gate/up projection with SwiGLU over FP8 E4M3FN block-quantized weights.
/// For each row `r < n_ffn_rows` and input column `c` it writes
/// `outputs[c * n_ffn_rows + r] = silu(gate_r · x_c) * (up_r · x_c)`, where
/// `silu(x) = x / (1 + exp(-x))` and each dot product is dequantized per
/// 32-weight block with that block's f16 scale.
///
/// Weight buffer (buffer 0) holds the gate rows followed by the up rows, both
/// as 34-byte blocks (32 E4M3FN codes + little-endian f16 scale at bytes
/// 32..34). The up row for `r` starts `n_ffn_rows * (k / 32)` blocks after the
/// gate row. NaN codes 0x7F / 0xFF decode to 0.
///
/// Buffers: 1 = inputs (column `c` at `inputs[c * k ..]`), 2 = outputs
/// (overwritten), 3 = `n_ffn_rows`, 4 = `batch_size`, 5 = `k`. `k` is assumed
/// to be a multiple of 32; the `k % 32` tail is ignored.
///
/// Dispatch: one simdgroup per row (`row = tgid * 8 + sgid`, i.e. 8
/// simdgroups per threadgroup); lanes stride over blocks by 32, which presumes
/// 32-wide simdgroups. Columns are processed in tiles of 8.
///
/// NOTE(review): gate and up weights are re-decoded for every column of a
/// tile (up to 8x redundant decode). Hoisting the decode, as the plain GEMM
/// kernels could, would need two 32-float arrays per thread here; profile
/// register pressure before changing.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_FUSED_GATE_UP_SWIGLU_GEMM_FP8_E4M3_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
static inline float pf_fp8_e4m3_to_float(uchar b) {
if (b == 0x7Fu || b == 0xFFu) return 0.0f;
const uint sign = (uint(b) >> 7u) & 1u;
const uint exp = (uint(b) >> 3u) & 15u;
const uint mant = uint(b) & 7u;
float val;
if (exp == 0u) {
val = float(mant) * (1.0f / 8.0f) * (1.0f / 64.0f);
} else {
const uint bits = ((exp - 7u + 127u) << 23u) | (mant << 20u);
val = as_type<float>(bits);
}
return sign ? -val : val;
}
// SiLU activation: x * sigmoid(x)
static inline float pf_silu(float x) {
return x / (1.0f + exp(-x));
}
kernel void fused_gate_up_swiglu_gemm_fp8_e4m3(
device const uchar* blocks_raw [[buffer(0)]],
device const float* inputs [[buffer(1)]],
device float* outputs [[buffer(2)]],
constant uint& n_ffn_rows [[buffer(3)]],
constant uint& batch_size [[buffer(4)]],
constant uint& k [[buffer(5)]],
uint tgid [[threadgroup_position_in_grid]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]])
{
const uint row = tgid * 8u + sgid;
if (row >= n_ffn_rows) return;
const uint blocks_per_row = k >> 5u;
const uint up_row_offset = n_ffn_rows * blocks_per_row;
for (uint col_base = 0u; col_base < batch_size; col_base += 8u) {
const uint cols_remaining = batch_size - col_base;
const uint cols = cols_remaining < 8u ? cols_remaining : 8u;
float gate_sums[8] = {0,0,0,0,0,0,0,0};
float up_sums[8] = {0,0,0,0,0,0,0,0};
for (uint b = lane; b < blocks_per_row; b += 32u) {
const uint gate_block_idx = row * blocks_per_row + b;
const uint up_block_idx = up_row_offset + row * blocks_per_row + b;
const uint gbase = gate_block_idx * 34u;
const uint ubase = up_block_idx * 34u;
const ushort gd_raw =
ushort(blocks_raw[gbase + 32u])
| (ushort(blocks_raw[gbase + 33u]) << 8u);
const float gscale = float(as_type<half>(gd_raw));
const ushort ud_raw =
ushort(blocks_raw[ubase + 32u])
| (ushort(blocks_raw[ubase + 33u]) << 8u);
const float uscale = float(as_type<half>(ud_raw));
const uint inp_base = b * 32u;
for (uint cc = 0u; cc < cols; cc++) {
device const float* xbase = inputs + (col_base + cc) * k + inp_base;
float gsum = 0.0f;
float usum = 0.0f;
for (uint w = 0u; w < 32u; ++w) {
const float x = xbase[w];
gsum += pf_fp8_e4m3_to_float(blocks_raw[gbase + w]) * x;
usum += pf_fp8_e4m3_to_float(blocks_raw[ubase + w]) * x;
}
gate_sums[cc] += gscale * gsum;
up_sums[cc] += uscale * usum;
}
}
for (uint cc = 0u; cc < cols; cc++) {
float gs = simd_sum(gate_sums[cc]);
float us = simd_sum(up_sums[cc]);
if (lane == 0u) {
outputs[(col_base + cc) * n_ffn_rows + row] = pf_silu(gs) * us;
}
}
}
}
"#;
/// Metal source for `gemv_fp8_e4m3_pf`: single-vector product of an FP8
/// E4M3FN block-quantized weight matrix with one f32 input vector. It writes
/// `output[row]` as the dot product of weight row `row` with `input`, each
/// 32-weight block being dequantized with its f16 scale. Outputs are
/// overwritten.
///
/// Weight layout (buffer 0): each row is `k / 32` blocks of 34 bytes — 32
/// E4M3FN codes followed by a little-endian f16 scale at bytes 32..34. NaN
/// codes 0x7F / 0xFF decode to 0. `k` is assumed to be a multiple of 32; the
/// `k % 32` tail is ignored.
///
/// Buffers: 1 = input, 2 = output, 3 = `n_rows`, 4 = `k`.
///
/// Dispatch: one simdgroup per output row (`row = tgid * 8 + sgid`, i.e. 8
/// simdgroups per threadgroup); lanes stride over blocks by 32, which presumes
/// 32-wide simdgroups.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_GEMV_FP8_E4M3_PF_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
static inline float pf_fp8_e4m3_to_float(uchar b) {
if (b == 0x7Fu || b == 0xFFu) return 0.0f;
const uint sign = (uint(b) >> 7u) & 1u;
const uint exp = (uint(b) >> 3u) & 15u;
const uint mant = uint(b) & 7u;
float val;
if (exp == 0u) {
val = float(mant) * (1.0f / 8.0f) * (1.0f / 64.0f);
} else {
const uint bits = ((exp - 7u + 127u) << 23u) | (mant << 20u);
val = as_type<float>(bits);
}
return sign ? -val : val;
}
kernel void gemv_fp8_e4m3_pf(
device const uchar* blocks_raw [[buffer(0)]],
device const float* input [[buffer(1)]],
device float* output [[buffer(2)]],
constant uint& n_rows [[buffer(3)]],
constant uint& k [[buffer(4)]],
uint tgid [[threadgroup_position_in_grid]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]])
{
const uint row = tgid * 8u + sgid;
if (row >= n_rows) return;
const uint blocks_per_row = k >> 5u;
float local_sum = 0.0f;
for (uint b = lane; b < blocks_per_row; b += 32u) {
const uint base_byte = (row * blocks_per_row + b) * 34u;
const ushort scale_bits =
ushort(blocks_raw[base_byte + 32u])
| (ushort(blocks_raw[base_byte + 33u]) << 8u);
const float scale = float(as_type<half>(scale_bits));
const uint inp_base = b * 32u;
float bsum = 0.0f;
for (uint w = 0u; w < 32u; ++w) {
bsum += pf_fp8_e4m3_to_float(blocks_raw[base_byte + w]) * input[inp_base + w];
}
local_sum += scale * bsum;
}
float row_sum = simd_sum(local_sum);
if (lane == 0u) output[row] = row_sum;
}
"#;
/// Metal source for `gemm_fp8_e5m2`: batched (prefill) product of an FP8 E5M2
/// block-quantized weight matrix with `batch_size` f32 input columns.
///
/// Weight layout (buffer 0): each row is `k / 32` blocks of 34 bytes — 32 E5M2
/// codes followed by a little-endian f16 block scale at bytes 32..34. Codes
/// with exponent field 31 (Inf/NaN) decode to 0. `k` is assumed to be a
/// multiple of 32; any `k % 32` tail is ignored.
///
/// Buffers: 1 = inputs (column `c` at `inputs[c * k ..]`), 2 = outputs
/// (column `c`, row `r` at `outputs[c * n_rows + r]`), 3 = `n_rows`,
/// 4 = `batch_size`, 5 = `k`.
///
/// Dispatch: one simdgroup per output row (`row = tgid * 8 + sgid`, so 8
/// simdgroups per threadgroup are expected); lanes stride over blocks by 32,
/// which presumes 32-wide simdgroups. Columns are processed in tiles of 8, and
/// each block's 32 weights are decoded once per tile rather than once per
/// column.
///
/// NOTE(review): results are *added* into `outputs` (`+=`), unlike the
/// residual, SwiGLU and GEMV variants, which overwrite. Callers presumably
/// zero (or pre-fill) `outputs` first — confirm.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_GEMM_FP8_E5M2_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
// FP8 E5M2 decode (bias=15, exp=31 → 0).
static inline float pf_fp8_e5m2_to_float(uchar b) {
    const uint exp = (uint(b) >> 2u) & 31u;
    const uint mant = uint(b) & 3u;
    if (exp == 31u) return 0.0f;
    const uint sign = (uint(b) >> 7u) & 1u;
    float val;
    if (exp == 0u) {
        val = float(mant) * (1.0f / 4.0f) * (1.0f / 16384.0f);
    } else {
        const uint bits = ((exp - 15u + 127u) << 23u) | (mant << 21u);
        val = as_type<float>(bits);
    }
    return sign ? -val : val;
}
kernel void gemm_fp8_e5m2(
    device const uchar* blocks_raw [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& n_rows [[buffer(3)]],
    constant uint& batch_size [[buffer(4)]],
    constant uint& k [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint sgid [[simdgroup_index_in_threadgroup]],
    uint lane [[thread_index_in_simdgroup]])
{
    const uint row = tgid * 8u + sgid;
    if (row >= n_rows) return;
    const uint blocks_per_row = k >> 5u;
    for (uint col_base = 0u; col_base < batch_size; col_base += 8u) {
        const uint cols_remaining = batch_size - col_base;
        const uint cols = cols_remaining < 8u ? cols_remaining : 8u;
        float col_sums[8] = {0,0,0,0,0,0,0,0};
        for (uint b = lane; b < blocks_per_row; b += 32u) {
            const uint block_idx = row * blocks_per_row + b;
            const uint base_byte = block_idx * 34u;
            const ushort scale_bits =
                ushort(blocks_raw[base_byte + 32u])
                | (ushort(blocks_raw[base_byte + 33u]) << 8u);
            const float scale = float(as_type<half>(scale_bits));
            // Decode the block's 32 weights once; the decode does not depend
            // on the column, so reuse it for every column of the tile.
            float wdec[32];
            for (uint w = 0u; w < 32u; ++w) {
                wdec[w] = pf_fp8_e5m2_to_float(blocks_raw[base_byte + w]);
            }
            const uint inp_base = b * 32u;
            for (uint cc = 0u; cc < cols; cc++) {
                device const float* xbase = inputs + (col_base + cc) * k + inp_base;
                float bsum = 0.0f;
                for (uint w = 0u; w < 32u; ++w) {
                    bsum += wdec[w] * xbase[w];
                }
                col_sums[cc] += scale * bsum;
            }
        }
        for (uint cc = 0u; cc < cols; cc++) {
            float row_sum = simd_sum(col_sums[cc]);
            if (lane == 0u) outputs[(col_base + cc) * n_rows + row] += row_sum;
        }
    }
}
"#;
/// Metal source for `gemm_fp8_e5m2_residual`: batched (prefill) product of an
/// FP8 E5M2 block-quantized weight matrix with `batch_size` f32 input columns,
/// fused with a residual add:
/// `outputs[c * n_rows + r] = residual[c * n_rows + r] + (W_r · x_c)`.
///
/// Weight layout (buffer 0): each row is `k / 32` blocks of 34 bytes — 32 E5M2
/// codes followed by a little-endian f16 block scale at bytes 32..34. Codes
/// with exponent field 31 (Inf/NaN) decode to 0. `k` is assumed to be a
/// multiple of 32; any `k % 32` tail is ignored.
///
/// Buffers: 1 = inputs (column `c` at `inputs[c * k ..]`), 2 = outputs,
/// 3 = `n_rows`, 4 = `batch_size`, 5 = `k`, 6 = residual (same indexing as
/// outputs). Outputs are overwritten.
///
/// Dispatch: one simdgroup per output row (`row = tgid * 8 + sgid`, so 8
/// simdgroups per threadgroup are expected); lanes stride over blocks by 32,
/// which presumes 32-wide simdgroups. Columns are processed in tiles of 8, and
/// each block's 32 weights are decoded once per tile rather than once per
/// column.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_GEMM_FP8_E5M2_RESIDUAL_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
// FP8 E5M2 decode (bias=15, exp=31 → 0).
static inline float pf_fp8_e5m2_to_float(uchar b) {
    const uint exp = (uint(b) >> 2u) & 31u;
    const uint mant = uint(b) & 3u;
    if (exp == 31u) return 0.0f;
    const uint sign = (uint(b) >> 7u) & 1u;
    float val;
    if (exp == 0u) {
        val = float(mant) * (1.0f / 4.0f) * (1.0f / 16384.0f);
    } else {
        const uint bits = ((exp - 15u + 127u) << 23u) | (mant << 21u);
        val = as_type<float>(bits);
    }
    return sign ? -val : val;
}
kernel void gemm_fp8_e5m2_residual(
    device const uchar* blocks_raw [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& n_rows [[buffer(3)]],
    constant uint& batch_size [[buffer(4)]],
    constant uint& k [[buffer(5)]],
    device const float* residual [[buffer(6)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint sgid [[simdgroup_index_in_threadgroup]],
    uint lane [[thread_index_in_simdgroup]])
{
    const uint row = tgid * 8u + sgid;
    if (row >= n_rows) return;
    const uint blocks_per_row = k >> 5u;
    for (uint col_base = 0u; col_base < batch_size; col_base += 8u) {
        const uint cols_remaining = batch_size - col_base;
        const uint cols = cols_remaining < 8u ? cols_remaining : 8u;
        float col_sums[8] = {0,0,0,0,0,0,0,0};
        for (uint b = lane; b < blocks_per_row; b += 32u) {
            const uint block_idx = row * blocks_per_row + b;
            const uint base_byte = block_idx * 34u;
            const ushort scale_bits =
                ushort(blocks_raw[base_byte + 32u])
                | (ushort(blocks_raw[base_byte + 33u]) << 8u);
            const float scale = float(as_type<half>(scale_bits));
            // Decode the block's 32 weights once; the decode does not depend
            // on the column, so reuse it for every column of the tile.
            float wdec[32];
            for (uint w = 0u; w < 32u; ++w) {
                wdec[w] = pf_fp8_e5m2_to_float(blocks_raw[base_byte + w]);
            }
            const uint inp_base = b * 32u;
            for (uint cc = 0u; cc < cols; cc++) {
                device const float* xbase = inputs + (col_base + cc) * k + inp_base;
                float bsum = 0.0f;
                for (uint w = 0u; w < 32u; ++w) {
                    bsum += wdec[w] * xbase[w];
                }
                col_sums[cc] += scale * bsum;
            }
        }
        for (uint cc = 0u; cc < cols; cc++) {
            float row_sum = simd_sum(col_sums[cc]);
            if (lane == 0u) {
                const uint idx = (col_base + cc) * n_rows + row;
                outputs[idx] = residual[idx] + row_sum;
            }
        }
    }
}
"#;
/// Metal source for `fused_gate_up_swiglu_gemm_fp8_e5m2`: batched (prefill)
/// gate/up projection with SwiGLU over FP8 E5M2 block-quantized weights. For
/// each row `r < n_ffn_rows` and input column `c` it writes
/// `outputs[c * n_ffn_rows + r] = silu(gate_r · x_c) * (up_r · x_c)`, where
/// `silu(x) = x / (1 + exp(-x))` and each dot product is dequantized per
/// 32-weight block with that block's f16 scale.
///
/// Weight buffer (buffer 0) holds the gate rows followed by the up rows, both
/// as 34-byte blocks (32 E5M2 codes + little-endian f16 scale at bytes
/// 32..34). The up row for `r` starts `n_ffn_rows * (k / 32)` blocks after the
/// gate row. Codes with exponent field 31 (Inf/NaN) decode to 0.
///
/// Buffers: 1 = inputs (column `c` at `inputs[c * k ..]`), 2 = outputs
/// (overwritten), 3 = `n_ffn_rows`, 4 = `batch_size`, 5 = `k`. `k` is assumed
/// to be a multiple of 32; the `k % 32` tail is ignored.
///
/// Dispatch: one simdgroup per row (`row = tgid * 8 + sgid`, i.e. 8
/// simdgroups per threadgroup); lanes stride over blocks by 32, which presumes
/// 32-wide simdgroups. Columns are processed in tiles of 8.
///
/// NOTE(review): gate and up weights are re-decoded for every column of a
/// tile (up to 8x redundant decode). Hoisting the decode would need two
/// 32-float arrays per thread here; profile register pressure before changing.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_FUSED_GATE_UP_SWIGLU_GEMM_FP8_E5M2_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
static inline float pf_fp8_e5m2_to_float(uchar b) {
const uint exp = (uint(b) >> 2u) & 31u;
const uint mant = uint(b) & 3u;
if (exp == 31u) return 0.0f;
const uint sign = (uint(b) >> 7u) & 1u;
float val;
if (exp == 0u) {
val = float(mant) * (1.0f / 4.0f) * (1.0f / 16384.0f);
} else {
const uint bits = ((exp - 15u + 127u) << 23u) | (mant << 21u);
val = as_type<float>(bits);
}
return sign ? -val : val;
}
static inline float pf_silu(float x) {
return x / (1.0f + exp(-x));
}
kernel void fused_gate_up_swiglu_gemm_fp8_e5m2(
device const uchar* blocks_raw [[buffer(0)]],
device const float* inputs [[buffer(1)]],
device float* outputs [[buffer(2)]],
constant uint& n_ffn_rows [[buffer(3)]],
constant uint& batch_size [[buffer(4)]],
constant uint& k [[buffer(5)]],
uint tgid [[threadgroup_position_in_grid]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]])
{
const uint row = tgid * 8u + sgid;
if (row >= n_ffn_rows) return;
const uint blocks_per_row = k >> 5u;
const uint up_row_offset = n_ffn_rows * blocks_per_row;
for (uint col_base = 0u; col_base < batch_size; col_base += 8u) {
const uint cols_remaining = batch_size - col_base;
const uint cols = cols_remaining < 8u ? cols_remaining : 8u;
float gate_sums[8] = {0,0,0,0,0,0,0,0};
float up_sums[8] = {0,0,0,0,0,0,0,0};
for (uint b = lane; b < blocks_per_row; b += 32u) {
const uint gbase = (row * blocks_per_row + b) * 34u;
const uint ubase = (up_row_offset + row * blocks_per_row + b) * 34u;
const ushort gd_raw =
ushort(blocks_raw[gbase + 32u])
| (ushort(blocks_raw[gbase + 33u]) << 8u);
const float gscale = float(as_type<half>(gd_raw));
const ushort ud_raw =
ushort(blocks_raw[ubase + 32u])
| (ushort(blocks_raw[ubase + 33u]) << 8u);
const float uscale = float(as_type<half>(ud_raw));
const uint inp_base = b * 32u;
for (uint cc = 0u; cc < cols; cc++) {
device const float* xbase = inputs + (col_base + cc) * k + inp_base;
float gsum = 0.0f;
float usum = 0.0f;
for (uint w = 0u; w < 32u; ++w) {
const float x = xbase[w];
gsum += pf_fp8_e5m2_to_float(blocks_raw[gbase + w]) * x;
usum += pf_fp8_e5m2_to_float(blocks_raw[ubase + w]) * x;
}
gate_sums[cc] += gscale * gsum;
up_sums[cc] += uscale * usum;
}
}
for (uint cc = 0u; cc < cols; cc++) {
float gs = simd_sum(gate_sums[cc]);
float us = simd_sum(up_sums[cc]);
if (lane == 0u) {
outputs[(col_base + cc) * n_ffn_rows + row] = pf_silu(gs) * us;
}
}
}
}
"#;
/// Metal source for `gemv_fp8_e5m2_pf`: single-vector product of an FP8 E5M2
/// block-quantized weight matrix with one f32 input vector. It writes
/// `output[row]` as the dot product of weight row `row` with `input`, each
/// 32-weight block being dequantized with its f16 scale. Outputs are
/// overwritten.
///
/// Weight layout (buffer 0): each row is `k / 32` blocks of 34 bytes — 32 E5M2
/// codes followed by a little-endian f16 scale at bytes 32..34. Codes with
/// exponent field 31 (Inf/NaN) decode to 0. `k` is assumed to be a multiple of
/// 32; the `k % 32` tail is ignored.
///
/// Buffers: 1 = input, 2 = output, 3 = `n_rows`, 4 = `k`.
///
/// Dispatch: one simdgroup per output row (`row = tgid * 8 + sgid`, i.e. 8
/// simdgroups per threadgroup); lanes stride over blocks by 32, which presumes
/// 32-wide simdgroups.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_GEMV_FP8_E5M2_PF_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
static inline float pf_fp8_e5m2_to_float(uchar b) {
const uint exp = (uint(b) >> 2u) & 31u;
const uint mant = uint(b) & 3u;
if (exp == 31u) return 0.0f;
const uint sign = (uint(b) >> 7u) & 1u;
float val;
if (exp == 0u) {
val = float(mant) * (1.0f / 4.0f) * (1.0f / 16384.0f);
} else {
const uint bits = ((exp - 15u + 127u) << 23u) | (mant << 21u);
val = as_type<float>(bits);
}
return sign ? -val : val;
}
kernel void gemv_fp8_e5m2_pf(
device const uchar* blocks_raw [[buffer(0)]],
device const float* input [[buffer(1)]],
device float* output [[buffer(2)]],
constant uint& n_rows [[buffer(3)]],
constant uint& k [[buffer(4)]],
uint tgid [[threadgroup_position_in_grid]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]])
{
const uint row = tgid * 8u + sgid;
if (row >= n_rows) return;
const uint blocks_per_row = k >> 5u;
float local_sum = 0.0f;
for (uint b = lane; b < blocks_per_row; b += 32u) {
const uint base_byte = (row * blocks_per_row + b) * 34u;
const ushort scale_bits =
ushort(blocks_raw[base_byte + 32u])
| (ushort(blocks_raw[base_byte + 33u]) << 8u);
const float scale = float(as_type<half>(scale_bits));
const uint inp_base = b * 32u;
float bsum = 0.0f;
for (uint w = 0u; w < 32u; ++w) {
bsum += pf_fp8_e5m2_to_float(blocks_raw[base_byte + w]) * input[inp_base + w];
}
local_sum += scale * bsum;
}
float row_sum = simd_sum(local_sum);
if (lane == 0u) output[row] = row_sum;
}
"#;
#[cfg(test)]
mod tests {
    /// Text-level smoke test of the FP8 prefill shader sources (the shaders
    /// are not compiled here). It checks that every source declares its
    /// entry point, uses the 34-byte block stride and binds weights at
    /// buffer 0; that batched kernels tile columns by 8; that residual
    /// kernels bind buffer 6; and that the fused kernels use the SiLU helper.
    #[test]
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn fp8_prefill_kernels_contain_entry_points() {
        use super::*;

        // (shader source, expected entry-point declaration)
        let kernels: [(&str, &str); 8] = [
            (MSL_GEMM_FP8_E4M3_V1, "kernel void gemm_fp8_e4m3"),
            (MSL_GEMM_FP8_E4M3_RESIDUAL_V1, "kernel void gemm_fp8_e4m3_residual"),
            (
                MSL_FUSED_GATE_UP_SWIGLU_GEMM_FP8_E4M3_V1,
                "kernel void fused_gate_up_swiglu_gemm_fp8_e4m3",
            ),
            (MSL_GEMV_FP8_E4M3_PF_V1, "kernel void gemv_fp8_e4m3_pf"),
            (MSL_GEMM_FP8_E5M2_V1, "kernel void gemm_fp8_e5m2"),
            (MSL_GEMM_FP8_E5M2_RESIDUAL_V1, "kernel void gemm_fp8_e5m2_residual"),
            (
                MSL_FUSED_GATE_UP_SWIGLU_GEMM_FP8_E5M2_V1,
                "kernel void fused_gate_up_swiglu_gemm_fp8_e5m2",
            ),
            (MSL_GEMV_FP8_E5M2_PF_V1, "kernel void gemv_fp8_e5m2_pf"),
        ];

        for (source, entry_point) in kernels {
            assert!(source.contains(entry_point));
        }

        // Every kernel walks 34-byte blocks and reads weights from buffer 0.
        for (source, _) in kernels {
            assert!(source.contains("* 34u"));
            assert!(source.contains("[[buffer(0)]]"));
        }

        // All batched (non-GEMV) kernels tile the batch in columns of 8.
        let batched = [
            MSL_GEMM_FP8_E4M3_V1,
            MSL_GEMM_FP8_E4M3_RESIDUAL_V1,
            MSL_FUSED_GATE_UP_SWIGLU_GEMM_FP8_E4M3_V1,
            MSL_GEMM_FP8_E5M2_V1,
            MSL_GEMM_FP8_E5M2_RESIDUAL_V1,
            MSL_FUSED_GATE_UP_SWIGLU_GEMM_FP8_E5M2_V1,
        ];
        for source in batched {
            assert!(source.contains("col_base += 8u"));
        }

        // Residual variants take the residual tensor at buffer 6.
        for source in [MSL_GEMM_FP8_E4M3_RESIDUAL_V1, MSL_GEMM_FP8_E5M2_RESIDUAL_V1] {
            assert!(source.contains("[[buffer(6)]]"));
        }

        // Fused gate/up kernels apply SiLU through the shared helper.
        for source in [
            MSL_FUSED_GATE_UP_SWIGLU_GEMM_FP8_E4M3_V1,
            MSL_FUSED_GATE_UP_SWIGLU_GEMM_FP8_E5M2_V1,
        ] {
            assert!(source.contains("pf_silu"));
        }
    }
}