use std::ffi::CString;
use derive_more::{IsVariant, TryUnwrap, Unwrap};
use smol_str::format_smolstr;
use crate::{
array::Array,
dtype::Dtype,
error::{
EmptyInputPayload, Error, FfiNullHandlePayload, InteriorNulPayload, LengthMismatchPayload,
OutOfRangePayload, Result, check, check_vector_array_handle,
},
ffi::VectorArrayGuard,
stream::default_stream,
};
/// A named template parameter value forwarded to a custom Metal kernel.
///
/// Each variant is forwarded by [`MetalKernel::apply`] through the matching
/// `mlx_fast_metal_kernel_config_add_template_arg_*` call (bool / int / dtype).
#[derive(Clone, Copy, Debug, PartialEq, IsVariant, TryUnwrap, Unwrap)]
#[unwrap(ref, ref_mut)]
#[try_unwrap(ref, ref_mut)]
pub enum KernelTemplateArg {
    /// Boolean template parameter.
    Bool(bool),
    /// 32-bit signed integer template parameter.
    Int(i32),
    /// Element-type template parameter.
    Dtype(Dtype),
}
/// Per-dispatch launch parameters for [`MetalKernel::apply`].
///
/// Construct with [`MetalKernelApplyConfig::new`] (which validates the dispatch geometry) and
/// refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct MetalKernelApplyConfig {
    /// Dispatch grid extents (x, y, z); every axis is >= 1 once validated by `new`.
    grid: [u32; 3],
    /// Threadgroup extents (x, y, z); every axis is >= 1 and the product is <= 1024.
    thread_group: [u32; 3],
    /// One shape per kernel output, in output-name order.
    output_shapes: Vec<Vec<i32>>,
    /// One element dtype per kernel output, in output-name order.
    output_dtypes: Vec<Dtype>,
    /// Named template arguments, forwarded in order.
    template: Vec<(String, KernelTemplateArg)>,
    /// Forwarded to `mlx_fast_metal_kernel_config_set_init_value` only when `Some`.
    init_value: Option<f32>,
    /// Forwarded to `mlx_fast_metal_kernel_config_set_verbose`.
    verbose: bool,
}
impl MetalKernelApplyConfig {
    /// Creates a configuration with no template arguments, no init value and verbose off.
    ///
    /// # Errors
    /// Returns [`Error::OutOfRange`] when:
    /// - any `grid` dimension is zero (the dispatch would launch no threads);
    /// - any `thread_group` dimension is zero (Metal requires a non-empty threadgroup);
    /// - the product of the `thread_group` dimensions exceeds 1024.
    pub fn new(
        grid: [u32; 3],
        thread_group: [u32; 3],
        output_shapes: Vec<Vec<i32>>,
        output_dtypes: Vec<Dtype>,
    ) -> Result<Self> {
        // The thread count is the product of the three extents, so a zero on *any* axis
        // dispatches nothing; checking only for `[0, 0, 0]` let e.g. `[8, 0, 1]` through.
        if grid.contains(&0) {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "MetalKernelApplyConfig::new: grid",
                "every dimension must be >= 1 (a zero dimension dispatches no threads)",
                format_smolstr!("grid={grid:?}"),
            )));
        }
        // Same reasoning for the threadgroup: `[4, 0, 1]` has size 0, which also slipped past
        // the 1024 cap below because 0 <= 1024.
        if thread_group.contains(&0) {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "MetalKernelApplyConfig::new: thread_group",
                "every dimension must be >= 1 (Metal requires thread_group_size > 0 on each axis)",
                format_smolstr!("thread_group={thread_group:?}"),
            )));
        }
        // Saturating so three large u32 extents cannot wrap the u64 product.
        let tg_product: u64 = thread_group
            .iter()
            .fold(1_u64, |acc, &d| acc.saturating_mul(u64::from(d)));
        if tg_product > 1024 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "MetalKernelApplyConfig::new: thread_group product",
                "must be <= 1024 (Metal hardware limit for threads per threadgroup)",
                format_smolstr!("thread_group={thread_group:?}, product={tg_product}"),
            )));
        }
        Ok(Self {
            grid,
            thread_group,
            output_shapes,
            output_dtypes,
            template: Vec::new(),
            init_value: None,
            verbose: false,
        })
    }
    /// Replaces (does not append to) the template-argument list.
    #[must_use]
    pub fn with_template(mut self, template: Vec<(String, KernelTemplateArg)>) -> Self {
        self.template = template;
        self
    }
    /// Sets the value forwarded to `mlx_fast_metal_kernel_config_set_init_value`.
    #[must_use]
    pub fn with_init_value(mut self, value: f32) -> Self {
        self.init_value = Some(value);
        self
    }
    /// Sets the flag forwarded to `mlx_fast_metal_kernel_config_set_verbose`.
    #[must_use]
    pub fn with_verbose(mut self, v: bool) -> Self {
        self.verbose = v;
        self
    }
    /// Dispatch grid extents (x, y, z).
    #[inline]
    pub fn grid(&self) -> [u32; 3] {
        self.grid
    }
    /// Threadgroup extents (x, y, z).
    #[inline]
    pub fn thread_group(&self) -> [u32; 3] {
        self.thread_group
    }
    /// Output shapes, one per kernel output.
    #[inline]
    pub fn output_shapes_slice(&self) -> &[Vec<i32>] {
        &self.output_shapes
    }
    /// Output dtypes, one per kernel output.
    #[inline]
    pub fn output_dtypes_slice(&self) -> &[Dtype] {
        &self.output_dtypes
    }
    /// Named template arguments in forwarding order.
    #[inline]
    pub fn template_slice(&self) -> &[(String, KernelTemplateArg)] {
        &self.template
    }
    /// Optional init value; `None` means it is not set on the mlx-c config.
    #[inline]
    pub fn init_value(&self) -> Option<f32> {
        self.init_value
    }
    /// Whether verbose mode is requested.
    #[inline]
    pub fn verbose(&self) -> bool {
        self.verbose
    }
}
/// Sole owner of an `mlx_vector_string` handle; frees it when dropped so every early
/// return in the builders releases the vector.
struct VectorStringGuard(mlxrs_sys::mlx_vector_string);
impl Drop for VectorStringGuard {
    fn drop(&mut self) {
        let handle = self.0;
        // SAFETY: the guard is the only owner of `handle`, so it is freed exactly once here.
        // The status code is discarded because `drop` has no way to report it.
        let _status = unsafe { mlxrs_sys::mlx_vector_string_free(handle) };
    }
}
/// Sole owner of an `mlx_fast_metal_kernel_config` handle; frees it when dropped.
struct MetalKernelConfigGuard(mlxrs_sys::mlx_fast_metal_kernel_config);
impl Drop for MetalKernelConfigGuard {
    fn drop(&mut self) {
        let handle = self.0;
        // SAFETY: the guard is the only owner of `handle`, so it is released exactly once here.
        unsafe { mlxrs_sys::mlx_fast_metal_kernel_config_free(handle) };
    }
}
fn build_vector_string(items: &[&str], context: &'static str) -> Result<VectorStringGuard> {
let vstr = unsafe { mlxrs_sys::mlx_vector_string_new() };
let guard = VectorStringGuard(vstr);
for s in items {
let cs = CString::new(*s).map_err(|_| {
let _ = s;
Error::InteriorNul(InteriorNulPayload::new(
"ops::fast::metal_kernel::vector_string entry append",
context,
))
})?;
check(unsafe { mlxrs_sys::mlx_vector_string_append_value(vstr, cs.as_ptr()) })?;
}
Ok(guard)
}
fn cstring_or_err(s: &str, context: &'static str) -> Result<CString> {
CString::new(s).map_err(|_| {
let _ = s;
Error::InteriorNul(InteriorNulPayload::new(
"ops::fast::metal_kernel::cstring_or_err",
context,
))
})
}
/// Narrows a `u32` dispatch triple to the `i32` triple mlx-c's `set_grid` /
/// `set_thread_group` expect.
///
/// # Errors
/// Returns [`Error::OutOfRange`] for the first axis (x, then y, then z) whose value exceeds
/// `i32::MAX`.
fn to_dispatch_dim(dim: [u32; 3], context: &'static str) -> Result<[i32; 3]> {
    let narrow_axis = |axis: usize| -> Result<i32> {
        let v = dim[axis];
        i32::try_from(v).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                context,
                "must fit in i32 (mlx-c set_grid / set_thread_group requires i32; reduce the dispatch dimension)",
                format_smolstr!("{context}[{axis}]={v}"),
            ))
        })
    };
    // Array elements evaluate left to right, so the first failing axis is reported.
    Ok([narrow_axis(0)?, narrow_axis(1)?, narrow_axis(2)?])
}
/// Owned handle to an MLX custom Metal kernel created by [`MetalKernel::new`].
pub struct MetalKernel {
    /// Non-null mlx-c kernel handle; released in `Drop`.
    inner: mlxrs_sys::mlx_fast_metal_kernel,
    /// Output names in declaration order; their count is the kernel's output arity.
    output_names: Vec<String>,
}
impl std::fmt::Debug for MetalKernel {
    /// Shows only the output names; the raw handle is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MetalKernel");
        builder.field("output_names", &self.output_names);
        builder.finish_non_exhaustive()
    }
}
impl Drop for MetalKernel {
    fn drop(&mut self) {
        let handle = self.inner;
        // SAFETY: `new` only builds a `MetalKernel` from a non-null handle, and this value is
        // its sole owner, so the handle is freed exactly once.
        unsafe { mlxrs_sys::mlx_fast_metal_kernel_free(handle) };
    }
}
impl MetalKernel {
pub fn new(
name: &str,
input_names: &[&str],
output_names: &[&str],
source: &str,
header: &str,
ensure_row_contiguous: bool,
atomic_outputs: bool,
) -> Result<Self> {
crate::error::ensure_handler_installed();
let name_c = cstring_or_err(name, "`name`")?;
let source_c = cstring_or_err(source, "`source`")?;
let header_c = cstring_or_err(header, "`header`")?;
let input_names_guard = build_vector_string(input_names, "input_names")?;
let output_names_guard = build_vector_string(output_names, "output_names")?;
let raw = unsafe {
mlxrs_sys::mlx_fast_metal_kernel_new(
name_c.as_ptr(),
input_names_guard.0,
output_names_guard.0,
source_c.as_ptr(),
header_c.as_ptr(),
ensure_row_contiguous,
atomic_outputs,
)
};
drop(input_names_guard);
drop(output_names_guard);
drop(name_c);
drop(source_c);
drop(header_c);
if raw.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(FfiNullHandlePayload::new(
"mlx_fast_metal_kernel_new",
))),
);
}
Ok(Self {
inner: raw,
output_names: output_names.iter().map(|s| (*s).to_string()).collect(),
})
}
#[inline(always)]
pub fn output_arity(&self) -> usize {
self.output_names.len()
}
#[inline(always)]
pub fn output_names_slice(&self) -> &[String] {
&self.output_names
}
pub fn apply(&self, inputs: &[&Array], config: &MetalKernelApplyConfig) -> Result<Vec<Array>> {
let expected = self.output_names.len();
if config.output_shapes_slice().len() != expected {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"metal_kernel::apply: output_shapes vs kernel output_names",
expected,
config.output_shapes_slice().len(),
)));
}
if config.output_dtypes_slice().len() != expected {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"metal_kernel::apply: output_dtypes vs kernel output_names",
expected,
config.output_dtypes_slice().len(),
)));
}
for shape in config.output_shapes_slice().iter() {
if shape.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"metal_kernel::apply: output_shapes[idx] (custom Metal kernel outputs must have rank >= 1)",
)));
}
crate::shape::validate_dims(shape)?;
}
let grid_i32 = to_dispatch_dim(config.grid(), "grid")?;
let thread_group_i32 = to_dispatch_dim(config.thread_group(), "thread_group")?;
crate::error::ensure_handler_installed();
let stream = default_stream();
let config_raw = unsafe { mlxrs_sys::mlx_fast_metal_kernel_config_new() };
let _config_guard = MetalKernelConfigGuard(config_raw);
if config_raw.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(FfiNullHandlePayload::new(
"mlx_fast_metal_kernel_config_new",
))),
);
}
for (shape, dtype) in config
.output_shapes_slice()
.iter()
.zip(config.output_dtypes_slice().iter())
{
check(unsafe {
mlxrs_sys::mlx_fast_metal_kernel_config_add_output_arg(
config_raw,
crate::shape::dim_ptr(shape),
shape.len(),
(*dtype).into(),
)
})?;
}
check(unsafe {
mlxrs_sys::mlx_fast_metal_kernel_config_set_grid(
config_raw,
grid_i32[0],
grid_i32[1],
grid_i32[2],
)
})?;
check(unsafe {
mlxrs_sys::mlx_fast_metal_kernel_config_set_thread_group(
config_raw,
thread_group_i32[0],
thread_group_i32[1],
thread_group_i32[2],
)
})?;
if let Some(v) = config.init_value() {
check(unsafe { mlxrs_sys::mlx_fast_metal_kernel_config_set_init_value(config_raw, v) })?;
}
check(unsafe {
mlxrs_sys::mlx_fast_metal_kernel_config_set_verbose(config_raw, config.verbose())
})?;
for (arg_name, arg_value) in config.template_slice() {
let name_c = cstring_or_err(arg_name.as_str(), "template-arg name")?;
match arg_value {
KernelTemplateArg::Bool(v) => {
check(unsafe {
mlxrs_sys::mlx_fast_metal_kernel_config_add_template_arg_bool(
config_raw,
name_c.as_ptr(),
*v,
)
})?;
}
KernelTemplateArg::Int(v) => {
check(unsafe {
mlxrs_sys::mlx_fast_metal_kernel_config_add_template_arg_int(
config_raw,
name_c.as_ptr(),
*v,
)
})?;
}
KernelTemplateArg::Dtype(v) => {
check(unsafe {
mlxrs_sys::mlx_fast_metal_kernel_config_add_template_arg_dtype(
config_raw,
name_c.as_ptr(),
(*v).into(),
)
})?;
}
}
}
let raw_inputs: Vec<mlxrs_sys::mlx_array> = inputs.iter().map(|a| a.0).collect();
let inputs_vec =
unsafe { mlxrs_sys::mlx_vector_array_new_data(raw_inputs.as_ptr(), raw_inputs.len()) };
let _inputs_guard = VectorArrayGuard(inputs_vec);
if inputs_vec.ctx.is_null() {
return Err(
crate::error::LAST
.with(|c| c.borrow_mut().take())
.unwrap_or(Error::FfiNullHandle(FfiNullHandlePayload::new(
"mlx_vector_array_new_data",
))),
);
}
let mut out_vec = unsafe { mlxrs_sys::mlx_vector_array_new() };
check_vector_array_handle(out_vec)?;
let _out_guard = VectorArrayGuard(out_vec);
check(unsafe {
mlxrs_sys::mlx_fast_metal_kernel_apply(
&mut out_vec,
self.inner,
inputs_vec,
config_raw,
stream,
)
})?;
let n = unsafe { mlxrs_sys::mlx_vector_array_size(out_vec) };
let mut parts = Vec::with_capacity(n);
for i in 0..n {
let mut part = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe { mlxrs_sys::mlx_vector_array_get(&mut part.0, out_vec, i) })?;
parts.push(part);
}
Ok(parts)
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the custom Metal kernel wrapper. None of these tests dispatches a
    //! kernel: they check plain values, or validation that fails before any kernel runs.
    use super::*;

    #[test]
    fn template_arg_bool_variant_roundtrip() {
        let yes = KernelTemplateArg::Bool(true);
        let no = KernelTemplateArg::Bool(false);
        assert_eq!(yes, KernelTemplateArg::Bool(true));
        assert_ne!(yes, no);
    }

    #[test]
    fn template_arg_int_variant_roundtrip() {
        let seven = KernelTemplateArg::Int(7);
        assert_eq!(seven, KernelTemplateArg::Int(7));
        assert_ne!(seven, KernelTemplateArg::Int(8));
        // Different variants never compare equal.
        assert_ne!(seven, KernelTemplateArg::Bool(true));
    }

    #[test]
    fn template_arg_dtype_variant_roundtrip() {
        let f32_arg = KernelTemplateArg::Dtype(Dtype::F32);
        assert_eq!(f32_arg, KernelTemplateArg::Dtype(Dtype::F32));
        assert_ne!(f32_arg, KernelTemplateArg::Dtype(Dtype::F16));
        assert_ne!(f32_arg, KernelTemplateArg::Int(0));
    }

    #[test]
    fn template_arg_is_copy_and_clone() {
        fn assert_copy<T: Copy>() {}
        fn assert_clone<T: Clone>() {}
        assert_copy::<KernelTemplateArg>();
        assert_clone::<KernelTemplateArg>();
        // Binding `arg` twice only compiles because the enum is `Copy`.
        let arg = KernelTemplateArg::Int(3);
        let _first = arg;
        let _second = arg;
    }

    #[test]
    fn config_new_defaults_optional_fields() {
        let cfg = MetalKernelApplyConfig::new([8, 1, 1], [4, 1, 1], vec![vec![8]], vec![Dtype::F32])
            .unwrap();
        assert_eq!(cfg.grid(), [8, 1, 1]);
        assert_eq!(cfg.thread_group(), [4, 1, 1]);
        assert_eq!(cfg.output_shapes_slice(), &[vec![8]]);
        assert_eq!(cfg.output_dtypes_slice(), &[Dtype::F32]);
        // Optional fields start empty / off.
        assert!(cfg.template_slice().is_empty());
        assert!(cfg.init_value().is_none());
        assert!(!cfg.verbose());
    }

    #[test]
    fn config_struct_update_overrides_optional_fields() {
        let base = MetalKernelApplyConfig::new([16, 1, 1], [8, 1, 1], vec![vec![16]], vec![Dtype::F16])
            .unwrap();
        let cfg = base
            .with_template(vec![("ALPHA".to_string(), KernelTemplateArg::Int(2))])
            .with_init_value(0.5)
            .with_verbose(true);
        assert_eq!(cfg.grid(), [16, 1, 1]);
        assert_eq!(cfg.thread_group(), [8, 1, 1]);
        let template = cfg.template_slice();
        assert_eq!(template.len(), 1);
        assert_eq!(template[0].0, "ALPHA");
        assert_eq!(template[0].1, KernelTemplateArg::Int(2));
        assert_eq!(cfg.init_value(), Some(0.5));
        assert!(cfg.verbose());
    }

    #[test]
    fn config_is_clone_for_repeated_dispatch() {
        fn assert_clone<T: Clone>() {}
        assert_clone::<MetalKernelApplyConfig>();
        let original =
            MetalKernelApplyConfig::new([1, 1, 1], [1, 1, 1], vec![vec![1]], vec![Dtype::F32]).unwrap();
        let copy = original.clone();
        assert_eq!(copy.grid(), original.grid());
        assert_eq!(copy.output_shapes_slice(), original.output_shapes_slice());
    }

    #[test]
    fn config_multi_output_shapes_and_dtypes_align() {
        let shapes = vec![vec![4], vec![4, 4]];
        let dtypes = vec![Dtype::F32, Dtype::I32];
        let cfg = MetalKernelApplyConfig::new([2, 2, 1], [1, 1, 1], shapes, dtypes).unwrap();
        assert_eq!(cfg.output_shapes_slice().len(), cfg.output_dtypes_slice().len());
        assert_eq!(cfg.output_shapes_slice()[1], vec![4, 4]);
        assert_eq!(cfg.output_dtypes_slice()[1], Dtype::I32);
    }

    /// Asserts `err` is `InteriorNul` whose `bytes_kind` equals `needle` or contains it with
    /// surrounding backticks stripped.
    fn assert_interior_nul(err: &Error, needle: &str) {
        match err {
            Error::InteriorNul(payload) => {
                let kind = payload.bytes_kind();
                let matched = kind == needle || kind.contains(needle.trim_matches('`'));
                assert!(
                    matched,
                    "expected bytes_kind to match {needle:?}, got: {payload:?}"
                );
            }
            other => panic!("expected Error::InteriorNul, got: {other:?}"),
        }
    }

    #[test]
    fn metal_kernel_new_rejects_interior_nul_in_name() {
        let err = MetalKernel::new("bad\0name", &["a"], &["out"], "// noop", "", true, false)
            .expect_err("interior NUL in name should be rejected");
        assert_interior_nul(&err, "`name`");
    }

    #[test]
    fn metal_kernel_new_rejects_interior_nul_in_source() {
        let err = MetalKernel::new("k", &["a"], &["out"], "// bad\0", "", true, false)
            .expect_err("interior NUL in source should be rejected");
        assert_interior_nul(&err, "`source`");
    }

    #[test]
    fn metal_kernel_new_rejects_interior_nul_in_header() {
        let err = MetalKernel::new("k", &["a"], &["out"], "// noop", "hdr\0bad", true, false)
            .expect_err("interior NUL in header should be rejected");
        assert_interior_nul(&err, "`header`");
    }

    #[test]
    fn metal_kernel_new_rejects_interior_nul_in_input_names() {
        let err = MetalKernel::new("k", &["a\0b"], &["out"], "// noop", "", true, false)
            .expect_err("interior NUL in input_names should be rejected");
        assert_interior_nul(&err, "input_names");
    }

    #[test]
    fn metal_kernel_new_rejects_interior_nul_in_output_names() {
        let err = MetalKernel::new("k", &["a"], &["out\0bad"], "// noop", "", true, false)
            .expect_err("interior NUL in output_names should be rejected");
        assert_interior_nul(&err, "output_names");
    }

    /// Builds a single-input kernel used only to exercise `apply`'s pre-FFI validation.
    fn make_validation_kernel(output_names: &[&str]) -> MetalKernel {
        let body = "uint elem = thread_position_in_grid.x; out[elem] = x[elem];";
        MetalKernel::new("validation_only", &["x"], output_names, body, "", true, false)
            .expect("construction should not need a Metal device")
    }

    #[test]
    fn apply_rejects_negative_output_dimension() {
        let kernel = make_validation_kernel(&["out"]);
        let input = Array::ones::<f32>(&(8usize,)).expect("ones alloc");
        let cfg =
            MetalKernelApplyConfig::new([8, 1, 1], [8, 1, 1], vec![vec![-1, 8]], vec![Dtype::F32])
                .unwrap();
        let outcome = kernel.apply(&[&input], &cfg);
        let err = outcome.expect_err("negative output dim should be rejected before FFI");
        match err {
            Error::OutOfRange(payload) => {
                assert_eq!(payload.context(), "shape::validate_dims: dim");
                assert_eq!(payload.requirement(), "must be non-negative");
                assert_eq!(payload.value(), "dim[0]=-1");
            }
            other => panic!("expected OutOfRange, got: {other:?}"),
        }
    }

    #[test]
    fn apply_rejects_scalar_output_shape() {
        let kernel = make_validation_kernel(&["out"]);
        let input = Array::ones::<f32>(&(8usize,)).expect("ones alloc");
        let cfg =
            MetalKernelApplyConfig::new([1, 1, 1], [1, 1, 1], vec![vec![]], vec![Dtype::F32]).unwrap();
        let outcome = kernel.apply(&[&input], &cfg);
        let err = outcome.expect_err("empty output shape should be rejected before FFI");
        match err {
            Error::EmptyInput(payload) => {
                assert_eq!(
                    payload.context(),
                    "metal_kernel::apply: output_shapes[idx] (custom Metal kernel outputs must have rank >= 1)"
                );
            }
            other => panic!("expected EmptyInput, got: {other:?}"),
        }
    }
}