#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
use std::sync::OnceLock;
use mlx_native::{DType, KernelRegistry, MlxDevice};
fn enable_dispatch_profile() {
static GUARD: OnceLock<()> = OnceLock::new();
GUARD.get_or_init(|| {
unsafe {
std::env::set_var("MLX_PROFILE_DISPATCH", "1");
std::env::set_var("MLX_PROFILE_CB", "1");
}
});
}
fn supports_dispatch_boundary_sampling(device: &MlxDevice) -> bool {
use metal::MTLCounterSamplingPoint;
device
.metal_device()
.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary)
}
fn make_inc_kernel() -> (KernelRegistry, &'static str) {
let mut registry = KernelRegistry::new();
registry.register_source(
"inc_one",
r#"
#include <metal_stdlib>
using namespace metal;
kernel void inc_one(
device const float *in [[buffer(0)]],
device float *out [[buffer(1)]],
uint id [[thread_position_in_grid]]
) {
out[id] = in[id] + 1.0f;
}
"#,
);
(registry, "inc_one")
}
#[test]
fn per_dispatch_sampling_records_nonzero_ns() {
enable_dispatch_profile();
mlx_native::kernel_profile::reset();
let device = MlxDevice::new().expect("MlxDevice::new");
let supported = supports_dispatch_boundary_sampling(&device);
let (mut registry, kernel) = make_inc_kernel();
let pipeline = registry
.get_pipeline(kernel, device.metal_device())
.expect("get_pipeline")
.clone();
let n: usize = 256;
let byte_len = n * std::mem::size_of::<f32>();
let mut in_buf = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("alloc in");
let out_buf = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("alloc out");
{
let s: &mut [f32] = in_buf.as_mut_slice().expect("in slice");
for (i, v) in s.iter_mut().enumerate() {
*v = i as f32;
}
}
let mut enc = device.command_encoder().expect("encoder");
use mlx_native::KernelArg;
use metal::MTLSize;
let bindings: Vec<(u64, KernelArg<'_>)> = vec![
(0, KernelArg::Buffer(&in_buf)),
(1, KernelArg::Buffer(&out_buf)),
];
let tg = MTLSize { width: (n as u64 + 31) / 32, height: 1, depth: 1 };
let tg_size = MTLSize { width: 32, height: 1, depth: 1 };
enc.encode_threadgroups_with_args(&pipeline, &bindings, tg, tg_size);
enc.memory_barrier();
enc.encode_threadgroups_with_args(&pipeline, &bindings, tg, tg_size);
enc.commit_and_wait_labeled("test.profile.cb")
.expect("commit_and_wait_labeled");
let dumps = mlx_native::kernel_profile::dump_dispatches();
if !supported {
eprintln!(
"[test] device does NOT support AtDispatchBoundary; \
verifying graceful degrade"
);
assert!(
dumps.is_empty(),
"device unsupported → dump_dispatches must be empty; got {:?}",
dumps
);
let cb_dump = mlx_native::kernel_profile::dump();
assert!(
cb_dump.iter().any(|(l, _)| l == "test.profile.cb"),
"MLX_PROFILE_CB=1 should still produce a per-CB entry on \
Apple Silicon (independent code path); cb_dump={:?}",
cb_dump
);
mlx_native::kernel_profile::reset();
return;
}
assert!(
!dumps.is_empty(),
"device supports AtDispatchBoundary AND \
MLX_PROFILE_DISPATCH=1 set + commit_and_wait_labeled called \
→ expected at least one DispatchEntry; dumps={:?}",
dumps
);
let (label, entries) = dumps
.iter()
.find(|(l, _)| l == "test.profile.cb")
.expect("test.profile.cb dispatches present");
assert_eq!(label, "test.profile.cb");
assert_eq!(
entries.len(),
2,
"exactly 2 dispatches encoded; got {} entries",
entries.len()
);
for (i, e) in entries.iter().enumerate() {
assert_eq!(
e.dispatch_index as usize, i,
"dispatch_index must equal insertion order; entry={:?}",
e
);
assert!(
e.end_gpu_ns >= e.start_gpu_ns,
"end timestamp must follow start; entry={:?}",
e
);
assert_eq!(e.op_kind, "Other", "default op kind = Other; got {:?}", e);
}
let cb_dump = mlx_native::kernel_profile::dump();
let cb_entry = cb_dump
.iter()
.find(|(l, _)| l == "test.profile.cb")
.expect("MLX_PROFILE_CB=1 should produce a per-CB entry");
let cb_total = cb_entry.1.total_ns;
let dispatch_sum: u64 = entries.iter().map(|e| e.gpu_ns).sum();
assert!(
cb_total + 1_000_000 >= dispatch_sum,
"per-CB total ({} ns) must be >= sum of per-dispatch ({} ns) \
within ±1ms slop; cb_entry={:?}, entries={:?}",
cb_total,
dispatch_sum,
cb_entry,
entries
);
mlx_native::kernel_profile::reset();
}
#[test]
fn per_dispatch_dump_is_grouped_by_cb_label() {
enable_dispatch_profile();
mlx_native::kernel_profile::reset();
let device = MlxDevice::new().expect("MlxDevice::new");
let supported = supports_dispatch_boundary_sampling(&device);
let (mut registry, kernel) = make_inc_kernel();
let pipeline = registry
.get_pipeline(kernel, device.metal_device())
.expect("get_pipeline")
.clone();
let n: usize = 64;
let byte_len = n * std::mem::size_of::<f32>();
let in_buf = device.alloc_buffer(byte_len, DType::F32, vec![n]).unwrap();
let out_buf = device.alloc_buffer(byte_len, DType::F32, vec![n]).unwrap();
use mlx_native::KernelArg;
use metal::MTLSize;
let bindings: Vec<(u64, KernelArg<'_>)> = vec![
(0, KernelArg::Buffer(&in_buf)),
(1, KernelArg::Buffer(&out_buf)),
];
let tg = MTLSize { width: (n as u64 + 31) / 32, height: 1, depth: 1 };
let tg_size = MTLSize { width: 32, height: 1, depth: 1 };
{
let mut enc = device.command_encoder().unwrap();
enc.encode_threadgroups_with_args(&pipeline, &bindings, tg, tg_size);
enc.commit_and_wait_labeled("test.cb.alpha").unwrap();
}
{
let mut enc = device.command_encoder().unwrap();
enc.encode_threadgroups_with_args(&pipeline, &bindings, tg, tg_size);
enc.commit_and_wait_labeled("test.cb.beta").unwrap();
}
let dumps = mlx_native::kernel_profile::dump_dispatches();
if !supported {
assert!(dumps.is_empty(), "unsupported → empty dispatches; got {:?}", dumps);
mlx_native::kernel_profile::reset();
return;
}
let alpha = dumps
.iter()
.find(|(l, _)| l == "test.cb.alpha")
.expect("alpha label");
let beta = dumps
.iter()
.find(|(l, _)| l == "test.cb.beta")
.expect("beta label");
assert_eq!(alpha.1.len(), 1, "alpha has 1 dispatch");
assert_eq!(beta.1.len(), 1, "beta has 1 dispatch");
assert_eq!(alpha.1[0].dispatch_index, 0);
assert_eq!(beta.1[0].dispatch_index, 0);
mlx_native::kernel_profile::reset();
}
#[test]
fn op_kind_label_propagates_into_dispatch_entry() {
enable_dispatch_profile();
mlx_native::kernel_profile::reset();
use mlx_native::CapturedOpKind;
let device = MlxDevice::new().expect("device");
let supported = supports_dispatch_boundary_sampling(&device);
let (mut registry, kernel) = make_inc_kernel();
let pipeline = registry
.get_pipeline(kernel, device.metal_device())
.expect("pipeline")
.clone();
let n: usize = 64;
let byte_len = n * std::mem::size_of::<f32>();
let in_buf = device.alloc_buffer(byte_len, DType::F32, vec![n]).unwrap();
let out_buf = device.alloc_buffer(byte_len, DType::F32, vec![n]).unwrap();
let mut enc = device.command_encoder().unwrap();
enc.set_op_kind(CapturedOpKind::Sdpa);
use mlx_native::KernelArg;
use metal::MTLSize;
let bindings: Vec<(u64, KernelArg<'_>)> = vec![
(0, KernelArg::Buffer(&in_buf)),
(1, KernelArg::Buffer(&out_buf)),
];
let tg = MTLSize { width: (n as u64 + 31) / 32, height: 1, depth: 1 };
let tg_size = MTLSize { width: 32, height: 1, depth: 1 };
enc.encode_threadgroups_with_args(&pipeline, &bindings, tg, tg_size);
enc.commit_and_wait_labeled("test.opkind.cb").unwrap();
let dumps = mlx_native::kernel_profile::dump_dispatches();
if !supported {
assert!(dumps.is_empty(), "unsupported → empty; got {:?}", dumps);
mlx_native::kernel_profile::reset();
return;
}
let (_, entries) = dumps
.iter()
.find(|(l, _)| l == "test.opkind.cb")
.expect("opkind cb present");
assert_eq!(entries.len(), 1);
assert_eq!(
entries[0].op_kind, "Sdpa",
"set_op_kind(Sdpa) must propagate into DispatchEntry; got {:?}",
entries[0]
);
mlx_native::kernel_profile::reset();
}
#[test]
fn supports_counter_sampling_capability_diagnostic() {
let device = MlxDevice::new().expect("device");
let md = device.metal_device();
use metal::MTLCounterSamplingPoint;
eprintln!(
"[mlx-native diag] device={:?} \
supports_counter_sampling: \
AtStageBoundary={} AtDispatchBoundary={} \
AtBlitBoundary={} AtTileDispatchBoundary={} AtDrawBoundary={}",
device.name(),
md.supports_counter_sampling(MTLCounterSamplingPoint::AtStageBoundary),
md.supports_counter_sampling(MTLCounterSamplingPoint::AtDispatchBoundary),
md.supports_counter_sampling(MTLCounterSamplingPoint::AtBlitBoundary),
md.supports_counter_sampling(MTLCounterSamplingPoint::AtTileDispatchBoundary),
md.supports_counter_sampling(MTLCounterSamplingPoint::AtDrawBoundary),
);
}