use crate::metal::{Buffer, ComputeCommandEncoder, Device, MetalDeviceType};
use crate::utils::EncoderProvider;
use crate::{set_params, ConstantValues, EncoderParam, Kernels, MetalKernelError, Source, Value};
use objc2_metal::{MTLResourceUsage, MTLSize};
/// Element type of a GEMM's operands and output.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum GemmDType {
    /// Brain floating point, 16 bits.
    BF16,
    /// IEEE 754 half precision.
    F16,
    /// IEEE 754 single precision.
    F32,
}
/// Threadgroup tile shape for a GEMM kernel: the output block computed per
/// threadgroup (`bm` x `bn`), the K-depth consumed per main-loop step
/// (`bk`), and the simdgroup grid within the threadgroup (`wm` x `wn`).
#[derive(Copy, Clone, Debug)]
struct TileConfig {
    // Rows of the output handled by one threadgroup.
    bm: usize,
    // Columns of the output handled by one threadgroup.
    bn: usize,
    // Reduction-dimension depth per iteration.
    bk: usize,
    // Simdgroups along M (dispatched as threadgroup depth).
    wm: usize,
    // Simdgroups along N (dispatched as threadgroup height).
    wn: usize,
}

impl TileConfig {
    /// `const` constructor so tile presets can be declared as `const` items.
    const fn new(bm: usize, bn: usize, bk: usize, wm: usize, wn: usize) -> Self {
        TileConfig { bm, bn, bk, wm, wn }
    }
}
// Tile presets, named `TILE_<bm>_<bn>_<bk>_<wm>_<wn>` (same field order as
// `TileConfig::new`).
#[allow(dead_code)]
const TILE_32_32_16_2_2: TileConfig = TileConfig::new(32, 32, 16, 2, 2);
const TILE_64_64_16_2_2: TileConfig = TileConfig::new(64, 64, 16, 2, 2);
const TILE_64_64_16_1_2: TileConfig = TileConfig::new(64, 64, 16, 1, 2);
const TILE_64_32_32_2_2: TileConfig = TileConfig::new(64, 32, 32, 2, 2);
const TILE_32_64_16_1_2: TileConfig = TileConfig::new(32, 64, 16, 1, 2);
/// Picks a tile configuration for the given dtype, problem shape, operand
/// layout and device class. The mapping is a hand-tuned heuristic table;
/// it is pure and total (every combination returns some preset).
#[allow(clippy::too_many_arguments)]
fn select_tile_config(
    dtype: GemmDType,
    m: usize,
    n: usize,
    k: usize,
    batch_size: usize,
    a_trans: bool,
    b_trans: bool,
    device_type: MetalDeviceType,
) -> TileConfig {
    // Very short M: the small square tile wins regardless of device.
    if m < 16 {
        return TILE_32_32_16_2_2;
    }
    // Predicates shared across the per-device tables.
    let is_large = batch_size * m * n >= (1 << 20); // >= 1M output elements
    let nt_layout = !a_trans && b_trans; // A row-major, B transposed ("nt")
    let half_prec = dtype != GemmDType::F32; // F16 or BF16
    let shallow_k = 2 * m.max(n) > k; // output dims dominate the reduction
    match device_type {
        MetalDeviceType::Phone | MetalDeviceType::BasePro => {
            if nt_layout {
                TILE_64_32_32_2_2
            } else if half_prec {
                TILE_64_64_16_1_2
            } else {
                TILE_64_64_16_2_2
            }
        }
        MetalDeviceType::Ultra => match (is_large, half_prec) {
            (true, true) => {
                if shallow_k {
                    TILE_64_64_16_1_2
                } else if nt_layout {
                    TILE_64_32_32_2_2
                } else {
                    TILE_32_64_16_1_2
                }
            }
            (true, false) => TILE_64_64_16_2_2,
            (false, true) => {
                if nt_layout {
                    TILE_64_32_32_2_2
                } else {
                    TILE_64_64_16_1_2
                }
            }
            (false, false) => {
                if nt_layout {
                    TILE_32_64_16_1_2
                } else {
                    TILE_64_32_32_2_2
                }
            }
        },
        MetalDeviceType::Max | MetalDeviceType::Medium => match dtype {
            GemmDType::F32 => {
                if is_large {
                    TILE_64_64_16_2_2
                } else if nt_layout {
                    TILE_32_64_16_1_2
                } else {
                    TILE_64_32_32_2_2
                }
            }
            GemmDType::F16 | GemmDType::BF16 => {
                if !is_large {
                    if nt_layout {
                        TILE_64_32_32_2_2
                    } else {
                        TILE_64_64_16_1_2
                    }
                } else if shallow_k {
                    TILE_64_64_16_1_2
                } else if nt_layout {
                    TILE_64_32_32_2_2
                } else {
                    TILE_32_64_16_1_2
                }
            }
        },
    }
}
/// Decides whether a batched GEMM can be folded into a single 2-D GEMM.
///
/// The fold is valid when A is non-transposed with densely packed batches
/// (batch stride == `m * k`) and B is broadcast across the batch (batch
/// stride == 0): `(b, m, k) x (k, n)` then equals `(b * m, k) x (k, n)`.
///
/// Returns `(effective_batch, effective_m, collapsed)` — either
/// `(1, b * m, true)` on success or `(b, m, false)` otherwise.
fn check_batch_collapse(
    b: usize,
    m: usize,
    k: usize,
    a_trans: bool,
    lhs_stride: &[usize],
    rhs_stride: &[usize],
) -> (usize, usize, bool) {
    // Nothing to fold without a real batch; a transposed A never folds.
    if b <= 1 || a_trans {
        return (b, m, false);
    }
    // Third-innermost stride when present, otherwise the given default
    // (packed for A, broadcast for B).
    let batch_stride = |s: &[usize], default: usize| -> usize {
        if s.len() > 2 {
            s[s.len() - 3]
        } else {
            default
        }
    };
    let a_packed = batch_stride(lhs_stride, m * k) == m * k;
    let b_broadcast = batch_stride(rhs_stride, 0) == 0;
    if a_packed && b_broadcast {
        (1, b * m, true)
    } else {
        (b, m, false)
    }
}
/// Heuristic for whether a split-K decomposition would pay off: only for
/// unbatched problems whose output covers few 16-wide tiles (at most 32)
/// while the reduction dimension spans at least eight 16-deep tiles.
#[allow(dead_code)]
fn should_use_split_k(b: usize, m: usize, n: usize, k: usize) -> bool {
    b == 1 && (m / 16) * (n / 16) <= 32 && k / 16 >= 8
}
/// Encodes and dispatches an MLX-steel-style GEMM Metal kernel computing
/// `output = lhs x rhs` for a (possibly batched) problem of shape
/// `(b, m, k) x (b, k, n) -> (b, m, n)`.
///
/// Strides are element strides with the innermost dimension last;
/// `lhs_offset`/`rhs_offset` are element offsets into their buffers.
/// The tile configuration is chosen per dtype, shape and device class.
///
/// # Errors
/// Returns `MetalKernelError::MatMulNonContiguous` when an operand's two
/// innermost strides describe neither a plain nor a transposed 2-D layout;
/// pipeline-creation failures are propagated.
#[allow(clippy::too_many_arguments)]
pub fn call_mlx_gemm(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    dtype: GemmDType,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs_stride: &[usize],
    lhs_offset: usize,
    lhs_buffer: &Buffer,
    rhs_stride: &[usize],
    rhs_offset: usize,
    rhs_buffer: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    // Host-side mirror of the kernel's parameter block; `#[repr(C)]` so its
    // byte layout matches what the shader reads. Field order presumably
    // mirrors the MLX steel GemmParams struct -- confirm against the shader.
    #[derive(Debug)]
    #[repr(C)]
    struct GemmParams {
        m: i32,
        n: i32,
        k: i32,
        lda: i32, // leading dimension of A, in elements
        ldb: i32, // leading dimension of B, in elements
        ldd: i32, // leading dimension of the output, in elements
        tiles_n: i32,
        tiles_m: i32,
        batch_stride_a: isize,
        batch_stride_b: isize,
        batch_stride_d: isize,
        swizzle_log: i32,
        gemm_k_iterations_aligned: i32,
        batch_ndim: i32,
    }
    assert!(rhs_stride.len() >= 2);
    assert!(lhs_stride.len() >= 2);
    // The two innermost strides of each operand determine its 2-D layout.
    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
    // A is row-major (m x k) when its rows are contiguous, or transposed
    // when its columns are contiguous; a degenerate dimension (k == 1 or
    // m == 1) makes the corresponding stride irrelevant.
    let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
        (k as i32, false)
    } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
        (m as i32, true)
    } else {
        return Err(MetalKernelError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        }
        .bt())?;
    };
    // Same classification for B (k x n).
    let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
        (n as i32, false)
    } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
        (k as i32, true)
    } else {
        return Err(MetalKernelError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        }
        .bt())?;
    };
    // When A is batch-contiguous and B is broadcast across the batch, fold
    // the batch into M and run one larger 2-D GEMM instead of a batched one.
    let (effective_batch, effective_m, batch_collapsed) =
        check_batch_collapse(b, m, k, a_trans, lhs_stride, rhs_stride);
    // Shadow m/b with the possibly collapsed values for the rest of the call.
    let m = effective_m;
    let b = effective_batch;
    let device_type = device.device_type();
    let tile = select_tile_config(dtype, m, n, k, b, a_trans, b_trans, device_type);
    let (bm, bn, bk, wm, wn) = (tile.bm, tile.bn, tile.bk, tile.wm, tile.wn);
    let has_batch = b > 1;
    // Metal function constants; the indices presumably match the shader's
    // [[function_constant(i)]] slots -- confirm against the kernel source.
    // 200/201/202 tell the kernel whether M/N/K divide the tile evenly so
    // it can skip edge handling; the hard-coded `false` entries disable
    // optional kernel paths not used by this call.
    let constants = Some(ConstantValues::new(vec![
        (10, Value::Bool(has_batch)),
        (100, Value::Bool(false)),
        (110, Value::Bool(false)),
        (200, Value::Bool(m % bm == 0)),
        (201, Value::Bool(n % bn == 0)),
        (202, Value::Bool(k % bk == 0)),
        (300, Value::Bool(false)),
    ]));
    // No tile swizzling: with swizzle_log = 0, tile_swizzle = 1 and the
    // tile-grid adjustments below are no-ops.
    let swizzle_log = 0;
    let tile_swizzle = 1 << swizzle_log;
    let tn = n.div_ceil(bn);
    let tm = m.div_ceil(bm);
    let tn = tn * tile_swizzle;
    let tm = tm.div_ceil(tile_swizzle);
    // Per-batch element strides; zeroed when the batch was collapsed into M.
    // The defaults assume densely packed batches when only 2-D strides are
    // given. NOTE(review): the rhs default of n * k differs from the
    // broadcast (0) assumption check_batch_collapse makes for the same
    // 2-D-stride case -- verify against callers.
    let (batch_stride_a, batch_stride_b) = if batch_collapsed {
        (0isize, 0isize)
    } else {
        let a_stride = if lhs_stride.len() > 2 {
            lhs_stride[lhs_stride.len() - 3] as isize
        } else {
            (m * k) as isize
        };
        let b_stride = if rhs_stride.len() > 2 {
            rhs_stride[rhs_stride.len() - 3] as isize
        } else {
            (n * k) as isize
        };
        (a_stride, b_stride)
    };
    let gemm_params = GemmParams {
        m: m as i32,
        n: n as i32,
        k: k as i32,
        // A collapsed (b*m, k) A is row-major with leading dimension k
        // (lda already equals k in that case, since collapse requires
        // a_trans == false; kept explicit for clarity).
        lda: if batch_collapsed { k as i32 } else { lda },
        ldb,
        ldd: n as i32,
        tiles_n: tn as i32,
        tiles_m: tm as i32,
        swizzle_log,
        batch_stride_a,
        batch_stride_b,
        batch_stride_d: (m * n) as isize,
        batch_ndim: 1i32,
        gemm_k_iterations_aligned: (k / bk) as i32,
    };
    let dtype_str = match dtype {
        GemmDType::F32 => "f32",
        GemmDType::F16 => "f16",
        GemmDType::BF16 => "bf16",
    };
    let trans_str = match (a_trans, b_trans) {
        (false, false) => "nn",
        (true, false) => "tn",
        (false, true) => "nt",
        (true, true) => "tt",
    };
    // Kernel name encodes layout, operand dtypes and tile shape, e.g.
    // "gemm_nt_f16_f16_64_32_32_2_2".
    let name = format!(
        "gemm_{}_{}_{}_{}_{}_{}_{}_{}",
        trans_str, dtype_str, dtype_str, bm, bn, bk, wm, wn
    );
    let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoder = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    // Marshal GemmParams into the argument table as raw bytes.
    impl EncoderParam for GemmParams {
        fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
            encoder.set_bytes(position, &data);
        }
    }
    let batch_strides = [batch_stride_a, batch_stride_b];
    // Bind arguments in kernel order; the `()` entries presumably skip
    // argument slots this kernel leaves unused -- confirm against the
    // `set_params!` definition.
    set_params!(
        encoder,
        (
            (lhs_buffer, lhs_offset),
            (rhs_buffer, rhs_offset),
            (),
            output,
            gemm_params,
            (),
            b as i32,
            &batch_strides[..]
        )
    );
    // One threadgroup per output tile (tn x tm x b); each threadgroup runs
    // wn x wm simdgroups of 32 threads.
    let grid_size = MTLSize {
        width: tn,
        height: tm,
        depth: b,
    };
    let group_size = MTLSize {
        width: 32,
        height: wn,
        depth: wm,
    };
    encoder.use_resource(lhs_buffer, MTLResourceUsage::Read);
    encoder.use_resource(rhs_buffer, MTLResourceUsage::Read);
    encoder.use_resource(output, MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(grid_size, group_size);
    Ok(())
}