use crate::kernels::matmul::{GemmDispatchParams, GemmKernel};
use crate::{ConstantValues, LibraryName, MetalStream, Value};
use anyhow::ensure;
use metal::{Buffer, MTLSize, NSUInteger};
use std::ffi::c_void;
use std::fmt;
use tract_core::internal::*;
use tract_gpu::tensor::DeviceTensor;
/// Parameter block handed to the MLX GEMM Metal kernel as raw bytes.
///
/// `dispatch_metal_mlx_gemm` fills this struct and binds it with `set_bytes` at
/// index 4 using `size_of::<MlxGemmParams>()`, so the `#[repr(C)]` field order and
/// field types are part of the contract with the shader.
///
/// NOTE(review): presumably mirrors the parameter struct declared in the MLX GEMM
/// shader sources — confirm against the shader before reordering or retyping fields.
#[derive(Debug)]
#[repr(C)]
struct MlxGemmParams {
// Problem dimensions, converted from `usize` by the dispatcher.
m: i32,
n: i32,
k: i32,
// Leading dimension of A: `m` when A is transposed, `k` otherwise.
lda: i32,
// Leading dimension of B: `k` when B is transposed, `n` otherwise.
ldb: i32,
// Leading dimension of the output; the dispatcher always sets it to `n`.
ldd: i32,
// Number of 32-wide tiles along n and m (rounded up) after the swizzle adjustment.
tiles_n: i32,
tiles_m: i32,
// Element stride between consecutive batches of A, B and the output.
// The output stride is always `m * n`.
batch_stride_a: isize,
batch_stride_b: isize,
batch_stride_d: isize,
// The dispatcher currently always passes 0.
swizzle_log: i32,
// `k / 16` (integer division by the k block size used by the dispatcher).
gemm_k_iterations_aligned: i32,
// The dispatcher currently always passes 1.
batch_ndim: i32,
}
/// Zero-sized marker type selecting the MLX-derived Metal matmul kernels.
///
/// It carries no state: all instances compare equal and `Default` yields the only value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct MlxGemm;

impl fmt::Display for MlxGemm {
    /// Writes the fixed label `MlxGemm`, ignoring any width or padding flags.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("MlxGemm")
    }
}
impl GemmKernel for MlxGemm {
    /// Short identifier of this GEMM backend.
    fn name() -> &'static str {
        "mlx"
    }

    /// Encodes `C = A * B` on `stream` using the MLX-derived Metal kernels.
    ///
    /// When `m == 1` or `n == 1` the product is routed to
    /// [`dispatch_metal_mlx_gemv`]; every other shape goes to
    /// [`dispatch_metal_mlx_gemm`] with its debug mode disabled.
    ///
    /// # Errors
    ///
    /// Fails when the first datum type is neither `F32` nor `F16`, when the three
    /// datum types differ, when any stride is negative, or when the selected
    /// dispatcher fails.
    fn dispatch_eval(
        &self,
        stream: &MetalStream,
        params: GemmDispatchParams,
        a_buffer: &Buffer,
        b_buffer: &Buffer,
        c_buffer: &Buffer,
    ) -> TractResult<()> {
        let GemmDispatchParams {
            dts,
            a_batch,
            m,
            k,
            n,
            transpose_a,
            a_offset,
            transpose_b,
            b_offset,
            c_offset,
            a_strides,
            b_strides,
            ..
        } = params;

        ensure!(
            matches!(dts[0], DatumType::F32 | DatumType::F16),
            "Unsupported datum type for MlxGemm {:?}",
            dts[0]
        );
        ensure!(
            dts[0] == dts[1] && dts[0] == dts[2],
            "MlxGemm only supports homogeneous datum types. I: {:?}, {:?}. O: {:?}",
            dts[0],
            dts[1],
            dts[2]
        );

        // The dispatchers take unsigned strides. Convert once, with a check, instead
        // of reinterpreting the signed slices in place.
        let a_strides = mlx_unsigned_strides(a_strides.as_slice())?;
        let b_strides = mlx_unsigned_strides(b_strides.as_slice())?;

        if m == 1 || n == 1 {
            dispatch_metal_mlx_gemv(
                stream,
                dts[0],
                (a_batch, m, n, k),
                &a_strides,
                a_offset,
                a_buffer,
                transpose_a,
                &b_strides,
                b_offset,
                b_buffer,
                transpose_b,
                c_buffer,
                c_offset,
            )
        } else {
            dispatch_metal_mlx_gemm(
                stream,
                dts[0],
                (a_batch, m, n, k),
                &a_strides,
                a_offset,
                a_buffer,
                transpose_a,
                &b_strides,
                b_offset,
                b_buffer,
                transpose_b,
                c_buffer,
                c_offset,
                false,
            )
        }
    }
}

/// Converts signed strides to unsigned ones, rejecting negative values instead of
/// letting them wrap around to huge offsets.
fn mlx_unsigned_strides(strides: &[isize]) -> TractResult<Vec<usize>> {
    ensure!(
        strides.iter().all(|&stride| stride >= 0),
        "MlxGemm does not support negative strides: {:?}",
        strides
    );
    // Every value was just checked to be non-negative, so the cast is lossless.
    Ok(strides.iter().map(|&stride| stride as usize).collect())
}
/// Encodes a batched matrix-vector product on `stream` with a kernel from
/// `LibraryName::MlxGemv`.
///
/// One of `m` or `n` must be 1. When `n != 1`, B is treated as the matrix and A as
/// the vector (B is bound at buffer index 0, A at index 1); otherwise A is the
/// matrix and B the vector. The result is written to `output` at `output_offset`.
///
/// The kernel variant is chosen from the matrix shape and encoded in the kernel
/// name as `bm`/`bn`/`sm`/`sn`/`tm`/`tn` values.
///
/// NOTE(review): the batch stride of each operand is read from index 0 of its
/// stride slice, which matches a rank-3 `[batch, rows, cols]` layout — confirm
/// before passing strides of another rank.
///
/// # Errors
///
/// Fails when neither `m` nor `n` is 1, when a stride slice has fewer than two
/// entries, when a dimension does not fit the kernel's 32-bit arguments, or when
/// the type name or pipeline lookup fails.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_metal_mlx_gemv(
    stream: &MetalStream,
    dt: DatumType,
    (b, m, n, k): (usize, usize, usize, usize),
    a_strides: &[usize],
    a_offset: usize,
    a_buffer: &Buffer,
    a_trans: bool,
    b_strides: &[usize],
    b_offset: usize,
    b_buffer: &Buffer,
    b_trans: bool,
    output: &Buffer,
    output_offset: usize,
) -> TractResult<()> {
    ensure!(m == 1 || n == 1);
    ensure!(a_strides.len() >= 2 && b_strides.len() >= 2);

    let lda = if a_trans { m } else { k };
    let ldb = if b_trans { k } else { n };

    // Map (m, k, n) onto a matrix * vector problem.
    let is_b_matrix = n != 1;
    let mv_m = if is_b_matrix { n } else { m };
    let mv_k = k;
    let mv_ld = if is_b_matrix { ldb } else { lda };
    let mv_trans = if is_b_matrix { !b_trans } else { a_trans };
    let (mat_batch_stride, vec_batch_stride) = if is_b_matrix {
        (b_strides[0], a_strides[0])
    } else {
        (a_strides[0], b_strides[0])
    };

    // Kernel variant selection. Defaults are bm = bn = 1, sm = 1, sn = 32,
    // tm = tn = 4; each branch overrides only the values relevant to its layout.
    let (bm, bn, sm, sn, tm, tn): (usize, usize, usize, usize, usize, usize) = if mv_trans {
        let (sm, sn) = if mv_k >= 8192 && mv_m >= 2048 { (4, 8) } else { (8, 4) };
        let bn = if mv_m >= 2048 {
            16
        } else if mv_m >= 512 {
            4
        } else {
            2
        };
        // Specialized kernel for very small outputs.
        let tn = if mv_m < 4 { 1 } else { 4 };
        (1, bn, sm, sn, 4, tn)
    } else {
        let bm = if mv_m >= 4096 { 8 } else { 4 };
        // Specialized kernel for very small outputs.
        let tm = if mv_m < 4 { 1 } else { 4 };
        (bm, 1, 1, 32, tm, 4)
    };
    let n_out_per_tgp = if mv_trans { bn * sn * tn } else { bm * sm * tm };
    let n_tgp = mv_m.div_ceil(n_out_per_tgp);

    // The kernel takes 32-bit integers: refuse values that would silently wrap.
    let to_i32 = |what: &str, value: usize| -> TractResult<i32> {
        ensure!(
            value <= i32::MAX as usize,
            "MLX gemv: {what} ({value}) does not fit in a 32-bit kernel argument"
        );
        Ok(value as i32)
    };
    let k_arg = to_i32("reduction dimension", mv_k)?;
    let m_arg = to_i32("output dimension", mv_m)?;
    let ld_arg = to_i32("matrix leading dimension", mv_ld)?;
    let batch_arg = to_i32("batch size", b)?;
    let batch_ndim_arg = 1_i32;

    let group_size = MTLSize { width: 32, height: bn as _, depth: bm as _ };
    let grid_size = MTLSize { width: n_tgp as _, height: 1, depth: b as u64 };

    let t_mat = if mv_trans { "t_" } else { "" };
    let tname = DeviceTensor::tname(dt)?;
    let name = format!("gemv_{t_mat}{tname}_bm{bm}_bn{bn}_sm{sm}_sn{sn}_tm{tm}_tn{tn}_nc0_axpby0");
    let pipeline = stream.load_pipeline(LibraryName::MlxGemv, &name)?;

    let i32_len = std::mem::size_of::<i32>() as u64;
    let usize_len = std::mem::size_of::<usize>() as u64;

    let command_buffer = stream.command_buffer();
    command_buffer.encode(|encoder| {
        encoder.set_compute_pipeline_state(&pipeline);
        // Buffer 0 is the matrix operand, buffer 1 the vector operand.
        if is_b_matrix {
            encoder.set_buffer(0, Some(b_buffer), b_offset as _);
            encoder.set_buffer(1, Some(a_buffer), a_offset as _);
        } else {
            encoder.set_buffer(0, Some(a_buffer), a_offset as _);
            encoder.set_buffer(1, Some(b_buffer), b_offset as _);
        }
        encoder.set_buffer(3, Some(output), output_offset as _);

        encoder.set_bytes(4, i32_len, &k_arg as *const i32 as *const c_void);
        encoder.set_bytes(5, i32_len, &m_arg as *const i32 as *const c_void);
        encoder.set_bytes(6, i32_len, &ld_arg as *const i32 as *const c_void);
        encoder.set_bytes(9, i32_len, &batch_ndim_arg as *const i32 as *const c_void);
        encoder.set_bytes(10, i32_len, &batch_arg as *const i32 as *const c_void);
        encoder.set_bytes(11, usize_len, &vec_batch_stride as *const usize as *const c_void);
        encoder.set_bytes(12, usize_len, &mat_batch_stride as *const usize as *const c_void);

        encoder.use_resource(a_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(b_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(output, metal::MTLResourceUsage::Write);
        encoder.dispatch_thread_groups(grid_size, group_size);
    });
    Ok(())
}
/// Encodes a batched matrix-matrix product on `stream` with a kernel from
/// `LibraryName::MlxGemm`.
///
/// Both operands must be densely laid out in their two innermost dimensions: the
/// last stride must be 1 and the second-to-last stride must equal the operand's
/// leading dimension (`m` or `k` for the left operand, `k` or `n` for the right
/// one, depending on the transpose flags). The batch stride is the third-to-last
/// stride when present, and the dense matrix size otherwise.
///
/// The kernel uses fixed 32x32x16 blocks; whether `m`, `n` and `k` are multiples
/// of those block sizes is passed as pipeline constants. When `debug` is set, the
/// debug constant (index 400) is enabled and the call waits for the stream to
/// complete before returning.
///
/// # Errors
///
/// Fails when a stride slice has fewer than two entries, when the strides do not
/// match the layout described above, when a dimension does not fit the kernel's
/// 32-bit parameters, or when the kernel name or pipeline lookup fails.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_metal_mlx_gemm(
    stream: &MetalStream,
    dt: DatumType,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs_stride: &[usize],
    lhs_offset: usize,
    lhs_buffer: &Buffer,
    lhs_transpose: bool,
    rhs_stride: &[usize],
    rhs_offset: usize,
    rhs_buffer: &Buffer,
    rhs_transpose: bool,
    output: &Buffer,
    output_offset: usize,
    debug: bool,
) -> TractResult<()> {
    ensure!(rhs_stride.len() >= 2);
    ensure!(lhs_stride.len() >= 2);

    let rhs_rank = rhs_stride.len();
    let lhs_rank = lhs_stride.len();
    let (rhs_m2, rhs_m1) = (rhs_stride[rhs_rank - 2], rhs_stride[rhs_rank - 1]);
    let (lhs_m2, lhs_m1) = (lhs_stride[lhs_rank - 2], lhs_stride[lhs_rank - 1]);

    // Expected row stride of each operand, which is also its leading dimension.
    let lda = if lhs_transpose { m } else { k };
    let ldb = if rhs_transpose { k } else { n };

    ensure!(
        lhs_m1 == 1 && lhs_m2 == lda,
        "Invalid left matmul argument [{lhs_m2}, {lhs_m1}] != [{lda}, 1], strides: {:?} {:?} dims: (m: {m}, n: {n}, k: {k})",
        lhs_stride,
        rhs_stride
    );
    ensure!(
        rhs_m1 == 1 && rhs_m2 == ldb,
        "Invalid right matmul argument [{rhs_m2}, {rhs_m1}] != [{ldb}, 1], strides: {:?} {:?} dims: (m: {m}, n: {n}, k: {k})",
        lhs_stride,
        rhs_stride
    );

    // Block sizes (bm, bn, bk) and the wn/wm values used for the threadgroup shape;
    // they match the suffix of the kernel name.
    let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);

    let constants = Some(ConstantValues::new(vec![
        (10, Value::Bool(b > 1)),
        (100, Value::Bool(false)),
        (110, Value::Bool(false)),
        (200, Value::Bool(m % bm == 0)),
        (201, Value::Bool(n % bn == 0)),
        (202, Value::Bool(k % bk == 0)),
        (300, Value::Bool(false)),
        (400, Value::Bool(debug)),
    ]));

    let swizzle_log = 0;
    let tile = 1 << swizzle_log;
    let tiles_n = n.div_ceil(bn) * tile;
    let tiles_m = m.div_ceil(bm).div_ceil(tile);

    let batch_stride_a = if lhs_rank > 2 { lhs_stride[lhs_rank - 3] } else { m * k };
    let batch_stride_b = if rhs_rank > 2 { rhs_stride[rhs_rank - 3] } else { n * k };

    // The kernel takes 32-bit integers: refuse values that would silently wrap.
    let to_i32 = |what: &str, value: usize| -> TractResult<i32> {
        ensure!(
            value <= i32::MAX as usize,
            "MLX gemm: {what} ({value}) does not fit in a 32-bit kernel parameter"
        );
        Ok(value as i32)
    };

    let gemm_params = MlxGemmParams {
        m: to_i32("m", m)?,
        n: to_i32("n", n)?,
        k: to_i32("k", k)?,
        lda: to_i32("lda", lda)?,
        ldb: to_i32("ldb", ldb)?,
        ldd: to_i32("ldd", n)?,
        tiles_n: to_i32("tiles_n", tiles_n)?,
        tiles_m: to_i32("tiles_m", tiles_m)?,
        swizzle_log,
        batch_stride_a: batch_stride_a as isize,
        batch_stride_b: batch_stride_b as isize,
        batch_stride_d: (m * n) as isize,
        batch_ndim: 1i32,
        gemm_k_iterations_aligned: to_i32("k iterations", k / bk)?,
    };
    let batch_arg = to_i32("batch size", b)?;
    let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];

    let name = kernel_name_gemm(dt, lhs_transpose, rhs_transpose)?;
    let pipeline = stream.load_pipeline_with_constants(LibraryName::MlxGemm, &name, constants)?;

    let command_buffer = stream.command_buffer();
    command_buffer.encode(|encoder| {
        encoder.set_compute_pipeline_state(&pipeline);
        encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
        encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
        encoder.set_buffer(3, Some(output), output_offset as NSUInteger);
        encoder.set_bytes(
            4,
            std::mem::size_of::<MlxGemmParams>() as u64,
            &gemm_params as *const MlxGemmParams as *const c_void,
        );
        encoder.set_bytes(
            6,
            std::mem::size_of::<i32>() as u64,
            &batch_arg as *const i32 as *const c_void,
        );
        encoder.set_bytes(
            7,
            (std::mem::size_of::<isize>() * batch_strides.len()) as u64,
            batch_strides.as_ptr() as *const c_void,
        );

        let grid_size =
            MTLSize { width: tiles_n as u64, height: tiles_m as u64, depth: b as u64 };
        let group_size = MTLSize { width: 32, height: wn, depth: wm };

        encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
        encoder.use_resource(output, metal::MTLResourceUsage::Write);
        encoder.dispatch_thread_groups(grid_size, group_size);
    });

    if debug {
        stream.wait_until_completed()?;
    }
    Ok(())
}
/// Builds the name of the MLX GEMM kernel for a datum type and operand layouts.
///
/// The name has the form `gemm_<a><b>_<type>_<type>_32_32_16_2_2`, where `<a>` and
/// `<b>` are `t` for a transposed operand and `n` otherwise, and `<type>` is the
/// name returned by `DeviceTensor::tname` for `dt`.
///
/// # Errors
///
/// Propagates the error from `DeviceTensor::tname`.
pub fn kernel_name_gemm(
    dt: DatumType,
    transpose_a: bool,
    transpose_b: bool,
) -> TractResult<String> {
    let layout = match (transpose_a, transpose_b) {
        (false, false) => "nn",
        (false, true) => "nt",
        (true, false) => "tn",
        (true, true) => "tt",
    };
    let tname = DeviceTensor::tname(dt)?;
    Ok(format!("gemm_{layout}_{tname}_{tname}_32_32_16_2_2"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::matmul::tests::run_mmm_test_case;
    use crate::kernels::matmul::GemmImpl;
    use crate::utils::with_borrowed_metal_stream;
    use tract_gpu::tensor::{DeviceTensor, IntoDevice};

    /// The MLX gemv Metal library must load without error.
    #[test]
    fn test_mlx_gemv_compilation() -> TractResult<()> {
        with_borrowed_metal_stream(|stream| stream.load_library(LibraryName::MlxGemv))?;
        Ok(())
    }

    /// Checks the GEMM path against hand-computed results.
    #[test]
    fn test_mlx_gemm() -> TractResult<()> {
        with_borrowed_metal_stream(|stream| {
            // All-ones operands: every output element is the sum of k ones.
            let (batch, m, n, k) = (10, 32, 32, 16);
            let a = Tensor::from_shape(&[batch, m, k], &vec![1.0f32; batch * m * k])?
                .into_device()?;
            let b = Tensor::from_shape(&[batch, k, n], &vec![1.0f32; batch * k * n])?
                .into_device()?;
            let c = GemmImpl::<MlxGemm>::default().eval(stream, &a, &b)?.to_host()?;
            let expected_c =
                Tensor::from_shape(&[batch, m, n], &vec![k as f32; batch * m * n])?;
            // Propagate the comparison error so a failure reports the mismatch.
            c.close_enough(&expected_c, Approximation::Approximate)?;

            // Small operands filled with 0, 1, 2, ... in row-major order.
            let (batch, m, n, k) = (2, 2, 4, 3);
            let a = DeviceTensor::from_shape(
                &[batch, m, k],
                &(0..batch * m * k).map(|v| v as f32).collect::<Vec<_>>(),
            )?;
            let b = DeviceTensor::from_shape(
                &[batch, k, n],
                &(0..batch * k * n).map(|v| v as f32).collect::<Vec<_>>(),
            )?;
            let c = GemmImpl::<MlxGemm>::default().eval(stream, &a, &b)?.to_host()?;
            let expected_c = Tensor::from_shape(
                &[2, 2, 4],
                &[
                    20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0,
                    488.0, 518.0, 548.0, 578.0,
                ],
            )?;
            c.close_enough(&expected_c, Approximation::Approximate)?;
            Ok(())
        })
    }

    /// Shapes with `m == 1` or `n == 1`, which take the gemv path.
    #[test]
    fn test_mat_vec() -> TractResult<()> {
        let cases = [((1, 4, 4, 1), false), ((10, 1, 4, 4), false), ((5, 1, 15, 7), true)];
        for (dims, transpose_b) in cases {
            run_mmm_test_case::<MlxGemm>(
                dims,
                false,
                transpose_b,
                DatumType::F32,
                DatumType::F32,
            )?;
        }
        Ok(())
    }

    /// General shapes, which take the gemm path, including sizes that are not
    /// multiples of the kernel block sizes.
    #[test]
    fn test_mat_mul() -> TractResult<()> {
        let cases = [
            ((1, 3, 5, 4), false),
            ((1, 2, 5, 10), true),
            ((1, 4, 4, 4), true),
            ((1, 4, 4, 200), true),
            ((1, 25, 1280, 32000), true),
        ];
        for (dims, transpose_b) in cases {
            run_mmm_test_case::<MlxGemm>(
                dims,
                false,
                transpose_b,
                DatumType::F32,
                DatumType::F32,
            )?;
        }
        Ok(())
    }
}