//! Metal GPU backend for macOS: `#[repr(C)]` kernel parameter structs shared
//! with the Metal shaders, tuned tile-size constants, and the embedded shader
//! sources. Device-specific items are gated on `target_os = "macos"`; the
//! parameter structs themselves compile on every platform.

#[cfg(target_os = "macos")]
mod buffers;
#[cfg(target_os = "macos")]
mod context;
#[cfg(target_os = "macos")]
mod operations;
#[cfg(target_os = "macos")]
mod pipelines;
#[cfg(target_os = "macos")]
pub use buffers::{MetalBuffer, MetalBufferPool};
#[cfg(target_os = "macos")]
pub use context::{MetalConfig, MetalContext};
#[cfg(target_os = "macos")]
pub use operations::{
batched_gemm_metal,
dequantize_int8,
fp16_to_fp32,
fp32_to_fp16,
gemv_batched_metal,
gemv_metal,
gemv_metal_f16,
gemv_metal_with_params,
quantize_int8,
verify_speculative_tokens,
GemvParams,
};
#[cfg(target_os = "macos")]
pub use pipelines::{MetalPipelines, PipelineCache};
use crate::kernels::AttentionConfig;
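/// Parameters for the attention kernels; `#[repr(C)]` so the layout is
/// intended to mirror the struct consumed by `shaders/attention.metal`.
/// `causal` is a `u32` flag (0/1) to keep the layout FFI-safe.
///
/// A construction sketch, mirroring `test_attention_params` below:
///
/// ```ignore
/// let config = AttentionConfig {
///     num_heads: 32, num_kv_heads: 8, head_dim: 128,
///     max_seq_len: 4096, causal: true, scale: 0.0,
/// };
/// let params = AttentionParams::from_config(&config, 1, 100);
/// ```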
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct AttentionParams {
pub num_heads: u32,
pub num_kv_heads: u32,
pub head_dim: u32,
pub seq_len: u32,
pub kv_len: u32,
pub scale: f32,
pub causal: u32,
pub _padding: u32,
}
impl AttentionParams {
pub fn from_config(config: &AttentionConfig, seq_len: usize, kv_len: usize) -> Self {
Self {
num_heads: config.num_heads as u32,
num_kv_heads: config.num_kv_heads as u32,
head_dim: config.head_dim as u32,
seq_len: seq_len as u32,
kv_len: kv_len as u32,
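            // effective_scale() is assumed to fall back to 1/sqrt(head_dim)
            // when config.scale is 0; the tests assert scale > 0.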
scale: config.effective_scale(),
causal: config.causal as u32,
_padding: 0,
}
}
}
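/// Parameters for the GEMM kernels, computing `C = alpha * A * B + beta * C`.
/// `new` assumes row-major operands, so the leading dimensions default to
/// `lda = k`, `ldb = n`, and `ldc = n`.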
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct GemmParams {
pub m: u32,
pub n: u32,
pub k: u32,
pub lda: u32,
pub ldb: u32,
pub ldc: u32,
pub alpha: f32,
pub beta: f32,
}
impl GemmParams {
pub fn new(m: usize, n: usize, k: usize) -> Self {
Self {
m: m as u32,
n: n as u32,
k: k as u32,
            lda: k as u32,
            ldb: n as u32,
            ldc: n as u32,
alpha: 1.0,
beta: 0.0,
}
}
}
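/// Parameters for the normalization kernels (`shaders/norm.metal`).
/// `elements_per_thread` is precomputed for an assumed 256-thread
/// threadgroup.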
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct NormParams {
pub hidden_size: u32,
pub eps: f32,
pub elements_per_thread: u32,
pub _padding: u32,
}
impl NormParams {
pub fn new(hidden_size: usize, eps: f32) -> Self {
        // Ceiling division: each thread handles this many elements, assuming
        // the norm kernel dispatches 256 threads per threadgroup.
        let elements_per_thread = (hidden_size + 255) / 256;
        Self {
hidden_size: hidden_size as u32,
eps,
elements_per_thread: elements_per_thread as u32,
_padding: 0,
}
}
}
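/// Parameters for the standalone rotary position embedding (RoPE) kernel
/// (`shaders/rope.metal`). `theta_base` is typically 10000.0.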
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct RopeParams {
pub head_dim: u32,
pub num_heads: u32,
pub position: u32,
pub theta_base: f32,
}
impl RopeParams {
pub fn new(head_dim: usize, num_heads: usize, position: usize, theta_base: f32) -> Self {
Self {
head_dim: head_dim as u32,
num_heads: num_heads as u32,
position: position as u32,
theta_base,
}
}
}
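/// Parameters for the fused, flash-style attention kernel
/// (`shaders/attention_fused.metal`). `new` fixes `block_size` at 64 to match
/// `tile_sizes::FUSED_ATTENTION_Q_BLOCK` and derives
/// `scale = 1 / sqrt(head_dim)`.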
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct FusedAttentionParams {
pub num_heads: u32,
pub num_kv_heads: u32,
pub head_dim: u32,
pub seq_len: u32,
pub kv_len: u32,
pub scale: f32,
pub causal: u32,
pub block_size: u32,
}
impl FusedAttentionParams {
pub fn new(
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
seq_len: usize,
kv_len: usize,
causal: bool,
) -> Self {
Self {
num_heads: num_heads as u32,
num_kv_heads: num_kv_heads as u32,
head_dim: head_dim as u32,
seq_len: seq_len as u32,
kv_len: kv_len as u32,
scale: 1.0 / (head_dim as f32).sqrt(),
causal: causal as u32,
            // Matches tile_sizes::FUSED_ATTENTION_Q_BLOCK.
            block_size: 64,
        }
}
}
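/// Parameters for the fused residual-add + normalization kernel.
/// `residual_scale` defaults to 1.0, i.e. a plain residual addition; use
/// `with_residual_scale` to scale the residual branch.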
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct FusedNormParams {
pub hidden_size: u32,
pub eps: f32,
pub residual_scale: f32,
pub _padding: u32,
}
impl FusedNormParams {
pub fn new(hidden_size: usize, eps: f32) -> Self {
Self {
hidden_size: hidden_size as u32,
eps,
residual_scale: 1.0,
_padding: 0,
}
}
pub fn with_residual_scale(hidden_size: usize, eps: f32, residual_scale: f32) -> Self {
Self {
hidden_size: hidden_size as u32,
eps,
residual_scale,
_padding: 0,
}
}
}
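/// Parameters for 4-bit group-quantized GEMV: weights carry one scale per
/// `group_size` elements, so `num_groups = ceil(n / group_size)`.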
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct Int4GemvParams {
pub m: u32,
pub n: u32,
pub group_size: u32,
pub num_groups: u32,
}
impl Int4GemvParams {
pub fn new(m: usize, n: usize, group_size: usize) -> Self {
let num_groups = (n + group_size - 1) / group_size;
Self {
m: m as u32,
n: n as u32,
group_size: group_size as u32,
num_groups: num_groups as u32,
}
}
}
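/// Parameters for the fused RoPE + attention kernel
/// (`shaders/rope_attention.metal`). `position_offset` is the absolute
/// position of the first query token, e.g. when decoding against a KV cache.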
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct RopeAttentionParams {
pub num_heads: u32,
pub num_kv_heads: u32,
pub head_dim: u32,
pub seq_len: u32,
pub kv_len: u32,
pub position_offset: u32,
pub rope_theta: f32,
pub scale: f32,
pub causal: u32,
pub _padding: [u32; 3],
}
impl RopeAttentionParams {
pub fn new(
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
seq_len: usize,
kv_len: usize,
position_offset: usize,
rope_theta: f32,
causal: bool,
) -> Self {
Self {
num_heads: num_heads as u32,
num_kv_heads: num_kv_heads as u32,
head_dim: head_dim as u32,
seq_len: seq_len as u32,
kv_len: kv_len as u32,
position_offset: position_offset as u32,
rope_theta,
scale: 1.0 / (head_dim as f32).sqrt(),
causal: causal as u32,
_padding: [0; 3],
}
}
}
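/// Parameters for YaRN-scaled attention, which rescales RoPE to extend the
/// context window: `yarn_scale` is the extension ratio
/// `target_max_position / original_max_position`.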
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct YarnAttentionParams {
pub num_heads: u32,
pub num_kv_heads: u32,
pub head_dim: u32,
pub seq_len: u32,
pub kv_len: u32,
pub position_offset: u32,
pub rope_theta: f32,
pub attn_scale: f32,
pub yarn_scale: f32,
pub original_max_position: u32,
pub causal: u32,
pub _padding: u32,
}
impl YarnAttentionParams {
pub fn new(
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
seq_len: usize,
kv_len: usize,
position_offset: usize,
rope_theta: f32,
original_max_position: usize,
target_max_position: usize,
causal: bool,
) -> Self {
let yarn_scale = (target_max_position as f32) / (original_max_position as f32);
Self {
num_heads: num_heads as u32,
num_kv_heads: num_kv_heads as u32,
head_dim: head_dim as u32,
seq_len: seq_len as u32,
kv_len: kv_len as u32,
position_offset: position_offset as u32,
rope_theta,
attn_scale: 1.0 / (head_dim as f32).sqrt(),
yarn_scale,
original_max_position: original_max_position as u32,
causal: causal as u32,
_padding: 0,
}
}
}
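/// Parameters for paged attention over a block-structured KV cache split
/// into `num_pages` pages of `block_size` tokens each.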
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct PagedAttentionParams {
pub num_heads: u32,
pub num_kv_heads: u32,
pub head_dim: u32,
pub seq_len: u32,
pub block_size: u32,
pub num_pages: u32,
pub scale: f32,
pub causal: u32,
}
impl PagedAttentionParams {
pub fn new(
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
seq_len: usize,
block_size: usize,
num_pages: usize,
causal: bool,
) -> Self {
Self {
num_heads: num_heads as u32,
num_kv_heads: num_kv_heads as u32,
head_dim: head_dim as u32,
seq_len: seq_len as u32,
block_size: block_size as u32,
num_pages: num_pages as u32,
scale: 1.0 / (head_dim as f32).sqrt(),
causal: causal as u32,
}
}
}
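/// Parameters for the quantize/dequantize kernels (`shaders/quantized.metal`).
/// `zero_point_mode` is 1 for asymmetric quantization (scale + zero point)
/// and 0 for symmetric (scale only).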
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct QuantParams {
pub group_size: u32,
pub num_groups: u32,
pub zero_point_mode: u32,
pub _padding: u32,
}
impl QuantParams {
pub fn new(group_size: usize, num_elements: usize, asymmetric: bool) -> Self {
let num_groups = (num_elements + group_size - 1) / group_size;
Self {
group_size: group_size as u32,
num_groups: num_groups as u32,
zero_point_mode: asymmetric as u32,
_padding: 0,
}
}
}
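/// Parameters for the fused SwiGLU feed-forward kernel
/// (`silu(gate) * up`), projecting between `hidden_size` and
/// `intermediate_size`.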
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct SwiGLUParams {
pub hidden_size: u32,
pub intermediate_size: u32,
pub _padding: [u32; 2],
}
impl SwiGLUParams {
pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
Self {
hidden_size: hidden_size as u32,
intermediate_size: intermediate_size as u32,
_padding: [0; 2],
}
}
}
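/// Tile sizes and dispatch limits used by the kernels. The `M4_*` constants
/// are tuned for Apple M4-class GPUs; the rest are generic defaults.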
pub mod tile_sizes {
    // Generic tile sizes for any Apple-silicon Metal device.
    pub const ATTENTION_TILE: usize = 64;
    pub const GEMM_TILE_M: usize = 64;
    pub const GEMM_TILE_N: usize = 64;
    pub const GEMM_TILE_K: usize = 32;
    // Apple GPUs execute in SIMD-groups of 32 threads.
    pub const SIMD_SIZE: usize = 32;
    pub const MAX_THREADS_PER_THREADGROUP: usize = 1024;
    // Larger GEMM tiles tuned for M4-class GPUs.
    pub const M4_GEMM_TILE_M: usize = 128;
    pub const M4_GEMM_TILE_N: usize = 128;
    pub const M4_GEMM_TILE_K: usize = 32;
    // Block sizes for the flash/fused attention kernels.
    pub const FLASH_ATTENTION_BLOCK: usize = 64;
    pub const FUSED_ATTENTION_Q_BLOCK: usize = 64;
    pub const FUSED_ATTENTION_KV_BLOCK: usize = 64;
    // Quantization group sizes (elements per scale factor).
    pub const INT4_GROUP_SIZE: usize = 32;
    pub const INT8_GROUP_SIZE: usize = 128;
    // Dispatch geometry and per-core memory sizes assumed for M4.
    pub const M4_WARPS_PER_BLOCK: usize = 16;
    pub const THREADS_PER_WARP: usize = 32;
    pub const M4_L1_CACHE_SIZE: usize = 16 * 1024;
    pub const M4_L2_CACHE_SIZE: usize = 192 * 1024;
    pub const M4_THREADGROUP_MEMORY: usize = 16 * 1024;
}
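/// Returns `true` if a default Metal device is available; always `false`
/// off macOS.
///
/// A minimal usage sketch (paths assumed to be in scope):
///
/// ```ignore
/// if is_metal_available() {
///     let info = get_device_info().expect("device info");
///     println!("Running on {}", info.name);
/// }
/// ```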
#[cfg(target_os = "macos")]
pub fn is_metal_available() -> bool {
metal::Device::system_default().is_some()
}
#[cfg(not(target_os = "macos"))]
pub fn is_metal_available() -> bool {
false
}
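/// Queries the default Metal device and reports its capabilities; returns
/// `None` when no device is available (and always off macOS).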
#[cfg(target_os = "macos")]
pub fn get_device_info() -> Option<MetalDeviceInfo> {
metal::Device::system_default().map(|device| MetalDeviceInfo {
name: device.name().to_string(),
registry_id: device.registry_id(),
max_threads_per_threadgroup: device.max_threads_per_threadgroup().width as usize,
max_buffer_length: device.max_buffer_length() as usize,
has_unified_memory: device.has_unified_memory(),
recommended_max_working_set_size: device.recommended_max_working_set_size() as usize,
})
}
#[cfg(not(target_os = "macos"))]
pub fn get_device_info() -> Option<MetalDeviceInfo> {
None
}
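/// Capabilities of the default Metal device, as reported by the `metal`
/// crate. `recommended_max_working_set_size` is the system's estimate of how
/// much memory the device can use with good performance.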
#[derive(Debug, Clone)]
pub struct MetalDeviceInfo {
pub name: String,
pub registry_id: u64,
pub max_threads_per_threadgroup: usize,
pub max_buffer_length: usize,
pub has_unified_memory: bool,
pub recommended_max_working_set_size: usize,
}
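/// Metal shader sources embedded at compile time via `include_str!`.
///
/// A compilation sketch using the `metal` crate (error handling elided):
///
/// ```ignore
/// let device = metal::Device::system_default().expect("no Metal device");
/// let library = device
///     .new_library_with_source(
///         &shader_source::all_optimized_shaders(),
///         &metal::CompileOptions::new(),
///     )
///     .expect("shader compilation failed");
/// ```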
pub mod shader_source {
pub const ATTENTION: &str = include_str!("shaders/attention.metal");
pub const GEMM: &str = include_str!("shaders/gemm.metal");
pub const NORM: &str = include_str!("shaders/norm.metal");
pub const ROPE: &str = include_str!("shaders/rope.metal");
pub const ATTENTION_FUSED: &str = include_str!("shaders/attention_fused.metal");
pub const FUSED_OPS: &str = include_str!("shaders/fused_ops.metal");
pub const QUANTIZED: &str = include_str!("shaders/quantized.metal");
pub const ROPE_ATTENTION: &str = include_str!("shaders/rope_attention.metal");
pub fn all_optimized_shaders() -> String {
format!(
"{}\n{}\n{}\n{}",
ATTENTION_FUSED, FUSED_OPS, QUANTIZED, ROPE_ATTENTION
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_attention_params() {
let config = AttentionConfig {
num_heads: 32,
num_kv_heads: 8,
head_dim: 128,
max_seq_len: 4096,
causal: true,
scale: 0.0,
};
let params = AttentionParams::from_config(&config, 1, 100);
assert_eq!(params.num_heads, 32);
assert_eq!(params.num_kv_heads, 8);
assert_eq!(params.head_dim, 128);
assert!(params.scale > 0.0);
}
#[test]
fn test_gemm_params() {
let params = GemmParams::new(64, 128, 256);
assert_eq!(params.m, 64);
assert_eq!(params.n, 128);
assert_eq!(params.k, 256);
assert_eq!(params.alpha, 1.0);
assert_eq!(params.beta, 0.0);
}
#[test]
fn test_norm_params() {
let params = NormParams::new(4096, 1e-6);
assert_eq!(params.hidden_size, 4096);
assert!((params.eps - 1e-6).abs() < 1e-10);
}
#[test]
fn test_rope_params() {
let params = RopeParams::new(128, 32, 0, 10000.0);
assert_eq!(params.head_dim, 128);
assert_eq!(params.num_heads, 32);
assert_eq!(params.theta_base, 10000.0);
}
#[cfg(target_os = "macos")]
#[test]
fn test_metal_available() {
let available = is_metal_available();
println!("Metal available: {}", available);
if available {
let info = get_device_info().unwrap();
println!("Device: {}", info.name);
println!("Unified memory: {}", info.has_unified_memory);
}
}
#[test]
fn test_fused_attention_params() {
let params = FusedAttentionParams::new(32, 8, 128, 16, 2048, true);
assert_eq!(params.num_heads, 32);
assert_eq!(params.num_kv_heads, 8);
assert_eq!(params.head_dim, 128);
assert_eq!(params.seq_len, 16);
assert_eq!(params.kv_len, 2048);
assert_eq!(params.causal, 1);
        assert_eq!(params.block_size, 64);
        assert!((params.scale - 0.0884).abs() < 0.001);
}
#[test]
fn test_fused_norm_params() {
let params = FusedNormParams::new(4096, 1e-5);
assert_eq!(params.hidden_size, 4096);
assert!((params.eps - 1e-5).abs() < 1e-10);
assert!((params.residual_scale - 1.0).abs() < 1e-10);
let params_scaled = FusedNormParams::with_residual_scale(4096, 1e-5, 0.5);
assert!((params_scaled.residual_scale - 0.5).abs() < 1e-10);
}
#[test]
fn test_int4_gemv_params() {
let params = Int4GemvParams::new(4096, 4096, 32);
assert_eq!(params.m, 4096);
assert_eq!(params.n, 4096);
assert_eq!(params.group_size, 32);
        assert_eq!(params.num_groups, 128);
    }
#[test]
fn test_rope_attention_params() {
let params = RopeAttentionParams::new(32, 8, 128, 16, 2048, 100, 10000.0, true);
assert_eq!(params.num_heads, 32);
assert_eq!(params.num_kv_heads, 8);
assert_eq!(params.head_dim, 128);
assert_eq!(params.position_offset, 100);
assert_eq!(params.rope_theta, 10000.0);
assert_eq!(params.causal, 1);
}
#[test]
fn test_yarn_attention_params() {
let params = YarnAttentionParams::new(32, 8, 128, 16, 2048, 0, 10000.0, 4096, 16384, true);
assert_eq!(params.num_heads, 32);
assert_eq!(params.original_max_position, 4096);
assert!((params.yarn_scale - 4.0).abs() < 1e-5);
}
#[test]
fn test_paged_attention_params() {
let params = PagedAttentionParams::new(32, 8, 128, 16, 64, 32, true);
assert_eq!(params.num_heads, 32);
assert_eq!(params.num_kv_heads, 8);
assert_eq!(params.block_size, 64);
assert_eq!(params.num_pages, 32);
assert_eq!(params.causal, 1);
}
#[test]
fn test_quant_params() {
let params = QuantParams::new(32, 4096, false);
assert_eq!(params.group_size, 32);
        assert_eq!(params.num_groups, 128);
        assert_eq!(params.zero_point_mode, 0);
let params_asym = QuantParams::new(128, 4096, true);
assert_eq!(params_asym.group_size, 128);
        assert_eq!(params_asym.num_groups, 32);
        assert_eq!(params_asym.zero_point_mode, 1);
    }
#[test]
fn test_swiglu_params() {
let params = SwiGLUParams::new(4096, 11008);
assert_eq!(params.hidden_size, 4096);
assert_eq!(params.intermediate_size, 11008);
}
#[test]
fn test_m4_pro_tile_sizes() {
assert_eq!(tile_sizes::M4_GEMM_TILE_M, 128);
assert_eq!(tile_sizes::M4_GEMM_TILE_N, 128);
assert_eq!(tile_sizes::M4_GEMM_TILE_K, 32);
assert_eq!(tile_sizes::FLASH_ATTENTION_BLOCK, 64);
assert_eq!(tile_sizes::INT4_GROUP_SIZE, 32);
assert_eq!(tile_sizes::M4_THREADGROUP_MEMORY, 16 * 1024);
assert_eq!(tile_sizes::MAX_THREADS_PER_THREADGROUP, 1024);
}
}