use std::sync::Mutex;
use mlx_native::{
macos_15_or_newer_for_test, reset_residency_env_cache_for_test,
reset_residency_test_counters, residency_allocation_count_for_test,
residency_commit_call_count_for_test, DType, MlxBufferPool, MlxDevice, MlxError,
};
// Serializes the tests in this file: every test takes this lock first and holds
// the guard until it returns. The tests mutate the process-wide
// `HF2Q_NO_RESIDENCY` environment variable and read the residency test
// counters (presumably global as well — they are reached through free
// functions), so they must not interleave.
static TEST_LOCK: Mutex<()> = Mutex::new(());
/// Prepares a clean residency test environment and tries to boot a device.
///
/// Resets the residency env cache and test counters and removes
/// `HF2Q_NO_RESIDENCY` before doing anything else.
///
/// Returns `None` when `macos_15_or_newer_for_test()` is false or when
/// `MlxDevice::new()` reports `MlxError::DeviceNotFound`, so callers can skip.
///
/// # Panics
///
/// Panics on any other `MlxDevice::new()` error, or when the booted device
/// does not report residency sets as enabled.
fn active_residency_device() -> Option<MlxDevice> {
    reset_residency_env_cache_for_test();
    reset_residency_test_counters();
    std::env::remove_var("HF2Q_NO_RESIDENCY");

    if !macos_15_or_newer_for_test() {
        return None;
    }

    // Only a missing device counts as "skip"; every other failure is fatal.
    let booted = match MlxDevice::new() {
        Ok(dev) => dev,
        Err(MlxError::DeviceNotFound) => return None,
        Err(err) => panic!("MlxDevice::new failed: {err}"),
    };

    assert!(
        booted.residency_sets_enabled(),
        "macOS 15+ device boot should enable residency sets and log `residency sets = true`",
    );
    Some(booted)
}
/// A device obtained through `active_residency_device` must report residency
/// sets as enabled. Returns early when the helper yields no device.
#[test]
fn residency_set_initialized_on_macos_15_plus() {
    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let Some(device) = active_residency_device() else {
        return;
    };
    assert!(device.residency_sets_enabled());
}
/// Allocating a batch of four buffers through `MlxBufferPool::alloc_batch`
/// must leave the residency allocation counter at 4.
/// Returns early when `active_residency_device` yields no device.
#[test]
fn residency_set_buffers_added_on_allocation() {
    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let Some(device) = active_residency_device() else {
        return;
    };

    let mut pool = MlxBufferPool::new();
    // Keep the buffers bound until the end of the test so they are still
    // alive when the counter is read.
    let _buffers = pool
        .alloc_batch(&device, (0..4).map(|_| (1024, DType::F32, vec![256])))
        .expect("alloc batch");

    assert_eq!(residency_allocation_count_for_test(), 4);
}
/// After a four-buffer batch goes out of scope and the pool is `reset()` and
/// `clear()`ed, the residency allocation counter must be back at 0.
/// Returns early when `active_residency_device` yields no device.
#[test]
fn residency_set_buffers_removed_on_pool_eviction() {
    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let Some(device) = active_residency_device() else {
        return;
    };

    let mut pool = MlxBufferPool::new();
    {
        // Inner scope: the batch handle is dropped before the pool is emptied.
        let _buffers = pool
            .alloc_batch(&device, (0..4).map(|_| (1024, DType::F32, vec![256])))
            .expect("alloc batch");
        assert_eq!(residency_allocation_count_for_test(), 4);
    }

    pool.reset();
    pool.clear();
    assert_eq!(residency_allocation_count_for_test(), 0);
}
/// With `HF2Q_NO_RESIDENCY=1` set, a freshly created device must report
/// residency sets as disabled. A `DeviceNotFound` error is treated as a skip.
#[test]
fn hf2q_no_residency_disables_init() {
    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    reset_residency_env_cache_for_test();
    reset_residency_test_counters();
    std::env::set_var("HF2Q_NO_RESIDENCY", "1");
    let result = MlxDevice::new();
    // Restore the environment before asserting so a failed assertion below
    // cannot leave the override set for later tests.
    std::env::remove_var("HF2Q_NO_RESIDENCY");
    reset_residency_env_cache_for_test();

    match result {
        Ok(device) => assert!(!device.residency_sets_enabled()),
        Err(MlxError::DeviceNotFound) => {}
        Err(err) => panic!("MlxDevice::new failed: {err}"),
    }
}
/// Ten `MlxDevice::alloc_buffer` calls must raise the residency allocation
/// counter to 10, and dropping all ten buffers must bring it back to 0.
/// Returns early when `active_residency_device` yields no device.
#[test]
fn device_alloc_buffers_removed_on_drop() {
    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let Some(device) = active_residency_device() else {
        return;
    };
    assert_eq!(residency_allocation_count_for_test(), 0);

    let buffers = (0..10)
        .map(|i| {
            device
                .alloc_buffer(1024, DType::F32, vec![256])
                .unwrap_or_else(|e| panic!("alloc_buffer iter {i}: {e}"))
        })
        .collect::<Vec<_>>();
    assert_eq!(
        residency_allocation_count_for_test(),
        10,
        "after 10 alloc_buffer calls, residency count should be 10",
    );

    drop(buffers);
    assert_eq!(
        residency_allocation_count_for_test(),
        0,
        "after dropping all 10 buffers, residency count should be 0",
    );
}
/// Clones and slice views of one buffer must share a single residency
/// registration: the allocation counter stays at 1 while any handle is alive
/// and only drops to 0 once the original buffer is dropped last.
/// Returns early when `active_residency_device` yields no device.
#[test]
fn mlx_buffer_clone_shares_registration() {
    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let Some(device) = active_residency_device() else {
        return;
    };
    assert_eq!(residency_allocation_count_for_test(), 0);

    let buf = device
        .alloc_buffer(2048, DType::U8, vec![2048])
        .expect("alloc_buffer");
    assert_eq!(residency_allocation_count_for_test(), 1);

    // Five extra handles to the same buffer, held in a Vec rather than five
    // numbered locals.
    let clones: Vec<_> = (0..5).map(|_| buf.clone()).collect();
    assert_eq!(
        residency_allocation_count_for_test(),
        1,
        "5x MlxBuffer clones must NOT double-register",
    );

    let view = buf.slice_view(0, 256);
    assert_eq!(
        residency_allocation_count_for_test(),
        1,
        "slice_view must NOT double-register",
    );

    // Release every secondary handle (clones first, then the view) while the
    // original is still alive.
    drop(clones);
    drop(view);
    assert_eq!(
        residency_allocation_count_for_test(),
        1,
        "dropping clones while original is alive must NOT deregister",
    );

    drop(buf);
    assert_eq!(
        residency_allocation_count_for_test(),
        0,
        "dropping the last MlxBuffer clone must deregister",
    );
}
/// With `HF2Q_NO_RESIDENCY=1` set, the device must report residency sets as
/// disabled and five `alloc_buffer` calls must leave the residency allocation
/// counter at 0. A `DeviceNotFound` error is treated as a skip.
#[test]
fn no_residency_env_skips_registration() {
    /// Removes `HF2Q_NO_RESIDENCY` and resets the env cache when dropped, so
    /// the cleanup runs on every exit path, including assertion failures.
    struct RestoreResidencyEnv;

    impl Drop for RestoreResidencyEnv {
        fn drop(&mut self) {
            std::env::remove_var("HF2Q_NO_RESIDENCY");
            reset_residency_env_cache_for_test();
        }
    }

    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    reset_residency_env_cache_for_test();
    reset_residency_test_counters();
    std::env::set_var("HF2Q_NO_RESIDENCY", "1");
    // Declared after the lock guard and before the device: locals drop in
    // reverse declaration order, so the device and buffers are released first,
    // then the environment is restored, and only then is TEST_LOCK released.
    let _restore_env = RestoreResidencyEnv;

    let device = match MlxDevice::new() {
        Ok(d) => d,
        Err(MlxError::DeviceNotFound) => return,
        Err(e) => panic!("MlxDevice::new failed: {e}"),
    };
    assert!(!device.residency_sets_enabled());

    // Keep the buffers alive while the counter is read.
    let _bufs: Vec<_> = (0..5)
        .map(|_| {
            device
                .alloc_buffer(512, DType::F32, vec![128])
                .expect("alloc_buffer")
        })
        .collect();
    assert_eq!(
        residency_allocation_count_for_test(),
        0,
        "HF2Q_NO_RESIDENCY=1 must skip auto-registration entirely",
    );
}
/// 100 `alloc_buffer` calls followed by 100 drops must not bump the residency
/// commit counter; after one `commit_and_wait` on a command encoder the
/// counter must be exactly 1 (and, as a looser gate, at most 5).
/// Returns early when `active_residency_device` yields no device.
#[test]
fn defer_and_flush_commit_count() {
    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let Some(device) = active_residency_device() else {
        return;
    };

    let commits_baseline = residency_commit_call_count_for_test();
    assert_eq!(commits_baseline, 0, "active_residency_device resets counters");

    let buffers = (0..100)
        .map(|i| {
            device
                .alloc_buffer(512, DType::F32, vec![128])
                .unwrap_or_else(|e| panic!("alloc_buffer iter {i}: {e}"))
        })
        .collect::<Vec<_>>();
    let commits_after_allocs = residency_commit_call_count_for_test();
    assert_eq!(
        commits_after_allocs, 0,
        "100 alloc_buffer calls must NOT issue any [set commit] (deferred)",
    );
    assert_eq!(residency_allocation_count_for_test(), 100);

    drop(buffers);
    let commits_after_drops = residency_commit_call_count_for_test();
    assert_eq!(
        commits_after_drops, 0,
        "100 buffer drops must NOT issue any [set commit] (deferred)",
    );
    assert_eq!(residency_allocation_count_for_test(), 0);

    // Submitting a command buffer is the point at which the commit counter is
    // expected to move.
    let mut encoder = device.command_encoder().expect("command_encoder");
    encoder.commit_and_wait().expect("commit_and_wait");
    let commits_after_flush = residency_commit_call_count_for_test();

    // Loose gate first, so a regression reports the acceptance threshold
    // before the exact-count check below fires.
    assert!(
        commits_after_flush <= 5,
        "Phase 3b AC4 (queen gate ≤ 5 commits/token): got {} commits for 100-alloc + 100-drop batch (variant baseline = ~200)",
        commits_after_flush,
    );
    assert_eq!(
        commits_after_flush, 1,
        "defer-and-flush should issue exactly 1 [set commit] for the 100-alloc + 100-drop batch + 1 CB submission",
    );
}
/// An empty `alloc_batch` must return no buffers and leave the residency
/// commit counter at 0; a four-buffer `alloc_batch` must register four
/// allocations and bump the commit counter to exactly 1.
/// Returns early when `active_residency_device` yields no device.
#[test]
fn commit_called_after_alloc_batch() {
    // A panic in another test poisons TEST_LOCK. The mutex wraps `()`, so there
    // is no protected data to be left inconsistent; recover the guard instead
    // of turning one failure into a cascade of "test lock" panics.
    let _guard = TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let Some(device) = active_residency_device() else {
        return;
    };

    let mut pool = MlxBufferPool::new();

    let empty = pool
        .alloc_batch(&device, std::iter::empty())
        .expect("empty alloc batch");
    assert!(empty.is_empty());
    assert_eq!(residency_commit_call_count_for_test(), 0);

    // Keep the buffers bound so they are alive when the counters are read.
    let _buffers = pool
        .alloc_batch(&device, (0..4).map(|_| (1024, DType::F32, vec![256])))
        .expect("alloc batch");
    assert_eq!(residency_allocation_count_for_test(), 4);
    assert_eq!(residency_commit_call_count_for_test(), 1);
}