/// Metal Shading Language source for the `gemv_fp8_e4m3` compute kernel.
///
/// Only compiled when the `metal` feature is enabled and the target OS is
/// macOS. The string is a complete MSL translation unit containing:
///
/// * `fp8_e4m3_to_float` — decodes one byte as sign (1 bit), exponent
///   (4 bits, bias 7) and mantissa (3 bits). The bytes `0x7F` and `0xFF`
///   decode to `0.0` instead of NaN; an exponent of zero is decoded as a
///   denormal (`mant / 8 * 2^-6`).
/// * `gemv_fp8_e4m3` — for each row, accumulates
///   `scale * sum(weight[w] * input[w])` over the row's blocks and writes the
///   total to `output[row]`.
///
/// # Weight layout
///
/// The kernel reads `blocks_raw` as consecutive 34-byte blocks: bytes
/// `0..=31` are 32 FP8 weights, bytes `32..=33` are a 16-bit scale assembled
/// low byte first and reinterpreted as `half`. A row consists of `k >> 5`
/// blocks, and block `b` of `row` is at index `row * (k >> 5) + b`. Any
/// remainder of `k` modulo 32 is not processed.
///
/// # Buffer bindings
///
/// | index | name         | type                  |
/// |-------|--------------|-----------------------|
/// | 0     | `blocks_raw` | `device const uchar*` |
/// | 1     | `input`      | `device const float*` |
/// | 2     | `output`     | `device float*`       |
/// | 3     | `n_rows`     | `constant uint&`      |
/// | 4     | `k`          | `constant uint&`      |
///
/// # Dispatch
///
/// The row index is `threadgroup_position_in_grid * 8 + simdgroup_index`, and
/// simdgroups whose row is `>= n_rows` exit early. Lanes stride over a row's
/// blocks in steps of 32; the per-lane partial sums are combined with
/// `simd_sum` and lane 0 stores the result.
/// Presumably the host launches 8 simdgroups per threadgroup and enough
/// threadgroups to cover `n_rows` — confirm against the dispatch code.
///
/// NOTE(review): block and byte offsets are computed in `uint`, so they wrap
/// once `row * (k >> 5) * 34` exceeds the 32-bit range — confirm that weight
/// buffers bound here stay below that size.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_GEMV_FP8_E4M3_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
// FP8 E4M3FN decode helper (bias=7, no infinity).
static inline float fp8_e4m3_to_float(uchar b) {
// NaN patterns: 0x7F and 0xFF → return 0 for inference.
if (b == 0x7Fu || b == 0xFFu) return 0.0f;
const uint sign = (uint(b) >> 7u) & 1u;
const uint exp = (uint(b) >> 3u) & 15u; // 4-bit exponent
const uint mant = uint(b) & 7u; // 3-bit mantissa
float val;
if (exp == 0u) {
// Denormal: 2^(-6) * (mant / 8)
val = float(mant) * (1.0f / 8.0f) * (1.0f / 64.0f);
} else {
// Normal: 2^(exp - 7) * (1 + mant/8)
// Assemble as IEEE-754 f32: ((exp - 7 + 127) << 23) | (mant << 20)
const uint bits = ((exp - 7u + 127u) << 23u) | (mant << 20u);
val = as_type<float>(bits);
}
return sign ? -val : val;
}
kernel void gemv_fp8_e4m3(
device const uchar* blocks_raw [[buffer(0)]],
device const float* input [[buffer(1)]],
device float* output [[buffer(2)]],
constant uint& n_rows [[buffer(3)]],
constant uint& k [[buffer(4)]],
uint tgid [[threadgroup_position_in_grid]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]])
{
const uint row = tgid * 8u + sgid;
if (row >= n_rows) return;
const uint blocks_per_row = k >> 5u; // k / 32
float local_sum = 0.0f;
for (uint b = lane; b < blocks_per_row; b += 32u) {
const uint block_idx = row * blocks_per_row + b;
const uint base_byte = block_idx * 34u;
// Scale at bytes 32-33 (after the 32 FP8 weight bytes).
const ushort scale_bits =
ushort(blocks_raw[base_byte + 32u])
| (ushort(blocks_raw[base_byte + 33u]) << 8u);
const float scale = float(as_type<half>(scale_bits));
// Dot product: 32 FP8 weights at bytes 0..31 with 32 float inputs.
const uint inp_base = b * 32u;
float block_sum = 0.0f;
for (uint w = 0u; w < 32u; ++w) {
block_sum += fp8_e4m3_to_float(blocks_raw[base_byte + w]) * input[inp_base + w];
}
local_sum += scale * block_sum;
}
// Sum across all 32 lanes within the simdgroup.
float row_sum = simd_sum(local_sum);
if (lane == 0u) {
output[row] = row_sum;
}
}
"#;
/// Metal Shading Language source for the `gemv_fp8_e5m2` compute kernel.
///
/// Only compiled when the `metal` feature is enabled and the target OS is
/// macOS. The string is a complete MSL translation unit containing:
///
/// * `fp8_e5m2_to_float` — decodes one byte as sign (1 bit), exponent
///   (5 bits, bias 15) and mantissa (2 bits). Every byte whose exponent field
///   is 31 (the infinity / NaN encodings) decodes to `0.0`; an exponent of
///   zero is decoded as a denormal (`mant / 4 * 2^-14`).
/// * `gemv_fp8_e5m2` — for each row, accumulates
///   `scale * sum(weight[w] * input[w])` over the row's blocks and writes the
///   total to `output[row]`.
///
/// # Weight layout
///
/// The kernel reads `blocks_raw` as consecutive 34-byte blocks: bytes
/// `0..=31` are 32 FP8 weights, bytes `32..=33` are a 16-bit scale assembled
/// low byte first and reinterpreted as `half`. A row consists of `k >> 5`
/// blocks, and block `b` of `row` is at index `row * (k >> 5) + b`. Any
/// remainder of `k` modulo 32 is not processed.
///
/// # Buffer bindings
///
/// | index | name         | type                  |
/// |-------|--------------|-----------------------|
/// | 0     | `blocks_raw` | `device const uchar*` |
/// | 1     | `input`      | `device const float*` |
/// | 2     | `output`     | `device float*`       |
/// | 3     | `n_rows`     | `constant uint&`      |
/// | 4     | `k`          | `constant uint&`      |
///
/// # Dispatch
///
/// The row index is `threadgroup_position_in_grid * 8 + simdgroup_index`, and
/// simdgroups whose row is `>= n_rows` exit early. Lanes stride over a row's
/// blocks in steps of 32; the per-lane partial sums are combined with
/// `simd_sum` and lane 0 stores the result.
/// Presumably the host launches 8 simdgroups per threadgroup and enough
/// threadgroups to cover `n_rows` — confirm against the dispatch code.
///
/// NOTE(review): block and byte offsets are computed in `uint`, so they wrap
/// once `row * (k >> 5) * 34` exceeds the 32-bit range — confirm that weight
/// buffers bound here stay below that size.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub const MSL_GEMV_FP8_E5M2_V1: &str = r#"
#include <metal_stdlib>
using namespace metal;
// FP8 E5M2 decode helper (bias=15, with infinity).
static inline float fp8_e5m2_to_float(uchar b) {
const uint exp = (uint(b) >> 2u) & 31u; // 5-bit exponent
const uint mant = uint(b) & 3u; // 2-bit mantissa
// Inf / NaN: exp = 31 → 0 for inference.
if (exp == 31u) return 0.0f;
const uint sign = (uint(b) >> 7u) & 1u;
float val;
if (exp == 0u) {
// Denormal: 2^(-14) * (mant / 4)
val = float(mant) * (1.0f / 4.0f) * (1.0f / 16384.0f);
} else {
// Normal: 2^(exp - 15) * (1 + mant/4)
// Assemble as IEEE-754 f32: ((exp - 15 + 127) << 23) | (mant << 21)
const uint bits = ((exp - 15u + 127u) << 23u) | (mant << 21u);
val = as_type<float>(bits);
}
return sign ? -val : val;
}
kernel void gemv_fp8_e5m2(
device const uchar* blocks_raw [[buffer(0)]],
device const float* input [[buffer(1)]],
device float* output [[buffer(2)]],
constant uint& n_rows [[buffer(3)]],
constant uint& k [[buffer(4)]],
uint tgid [[threadgroup_position_in_grid]],
uint sgid [[simdgroup_index_in_threadgroup]],
uint lane [[thread_index_in_simdgroup]])
{
const uint row = tgid * 8u + sgid;
if (row >= n_rows) return;
const uint blocks_per_row = k >> 5u; // k / 32
float local_sum = 0.0f;
for (uint b = lane; b < blocks_per_row; b += 32u) {
const uint block_idx = row * blocks_per_row + b;
const uint base_byte = block_idx * 34u;
const ushort scale_bits =
ushort(blocks_raw[base_byte + 32u])
| (ushort(blocks_raw[base_byte + 33u]) << 8u);
const float scale = float(as_type<half>(scale_bits));
const uint inp_base = b * 32u;
float block_sum = 0.0f;
for (uint w = 0u; w < 32u; ++w) {
block_sum += fp8_e5m2_to_float(blocks_raw[base_byte + w]) * input[inp_base + w];
}
local_sum += scale * block_sum;
}
float row_sum = simd_sum(local_sum);
if (lane == 0u) {
output[row] = row_sum;
}
}
"#;
#[cfg(test)]
mod tests {
    /// Smoke test over the embedded shader text: each kernel source must
    /// mention its entry point and decode helper, the E4M3 source must bind
    /// buffers 0 and 4, and both sources must use the 34-byte block stride.
    #[test]
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn fp8_kernels_contain_entry_points() {
        use super::{MSL_GEMV_FP8_E4M3_V1, MSL_GEMV_FP8_E5M2_V1};

        // (shader source, substring that must occur in it), checked in order.
        let expected_fragments: [(&str, &str); 8] = [
            (MSL_GEMV_FP8_E4M3_V1, "kernel void gemv_fp8_e4m3"),
            (MSL_GEMV_FP8_E5M2_V1, "kernel void gemv_fp8_e5m2"),
            (MSL_GEMV_FP8_E4M3_V1, "fp8_e4m3_to_float"),
            (MSL_GEMV_FP8_E5M2_V1, "fp8_e5m2_to_float"),
            (MSL_GEMV_FP8_E4M3_V1, "[[buffer(0)]]"),
            (MSL_GEMV_FP8_E4M3_V1, "[[buffer(4)]]"),
            (MSL_GEMV_FP8_E4M3_V1, "* 34u"),
            (MSL_GEMV_FP8_E5M2_V1, "* 34u"),
        ];

        for &(shader_src, fragment) in expected_fragments.iter() {
            assert!(shader_src.contains(fragment));
        }
    }
}