#![allow(clippy::expect_used, clippy::unwrap_used)]
use std::sync::Mutex;
use mlx_native::{
reset_residency_test_counters, residency_allocation_count_for_test, DType, EncoderSession,
KernelRegistry, MlxDevice,
};
/// Lock taken by every test in this file before it touches the device.
///
/// Two of the tests reset and read the residency test counters through
/// `reset_residency_test_counters` / `residency_allocation_count_for_test`;
/// presumably those counters are process-global, which is why the tests are
/// serialised — confirm against the `mlx_native` implementation.
static RESIDENCY_TEST_LOCK: Mutex<()> = Mutex::new(());

/// Returns the guard for [`RESIDENCY_TEST_LOCK`].
///
/// A poisoned lock (an earlier holder panicked, e.g. a failed assertion) is
/// recovered rather than propagated, so one failing test does not cascade
/// into every test that runs after it.
fn acquire_test_lock() -> std::sync::MutexGuard<'static, ()> {
    match RESIDENCY_TEST_LOCK.lock() {
        Ok(guard) => guard,
        // The protected value is `()`, so there is no state that a panic
        // could have left half-updated; taking the guard back is safe.
        Err(poison) => poison.into_inner(),
    }
}
/// Fence, reset, then a second stage on a single `EncoderSession`.
///
/// Stage 1 dispatches an element-wise add and is closed with `fence_stage`.
/// The session is then rotated with `reset_for_next_stage`, stage 2 is
/// dispatched, and the session is finished with `commit_and_wait`. Along the
/// way the test asserts the session state flags, the fence counter, the
/// command-buffer labels and the contents of both output buffers.
///
/// Prints a SKIPPED notice and returns early when
/// `EncoderSession::env_enabled()` is false.
#[test]
fn test_session_fence_stage_then_reset_then_begin_stage() {
    use mlx_native::ops::elementwise::elementwise_add;

    const STAGE1_LABEL: &str = "phase.iter89e2b_stage1";
    const STAGE1_FENCE_LABEL: &str = "phase.iter89e2b_stage1.fence";
    const STAGE2_LABEL: &str = "phase.iter89e2b_stage2";
    const ELEMS: usize = 4;

    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_multistage] test_session_fence_stage_then_reset_then_begin_stage \
             SKIPPED — HF2Q_ENCODER_SESSION not set in process env. \
             Re-run with HF2Q_ENCODER_SESSION=1 to exercise."
        );
        return;
    }

    let _serial = acquire_test_lock();
    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut kernels = KernelRegistry::new();
    let bytes = ELEMS * std::mem::size_of::<f32>();

    // Stage 1 operands and destination.
    let mut lhs1 = device.alloc_buffer(bytes, DType::F32, vec![ELEMS]).expect("a1");
    let mut rhs1 = device.alloc_buffer(bytes, DType::F32, vec![ELEMS]).expect("b1");
    let sum1 = device.alloc_buffer(bytes, DType::F32, vec![ELEMS]).expect("out1");
    let lhs1_host = lhs1.as_mut_slice::<f32>().unwrap();
    lhs1_host.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
    let rhs1_host = rhs1.as_mut_slice::<f32>().unwrap();
    rhs1_host.copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);

    // Stage 2 operands and destination.
    let mut lhs2 = device.alloc_buffer(bytes, DType::F32, vec![ELEMS]).expect("a2");
    let mut rhs2 = device.alloc_buffer(bytes, DType::F32, vec![ELEMS]).expect("b2");
    let sum2 = device.alloc_buffer(bytes, DType::F32, vec![ELEMS]).expect("out2");
    let lhs2_host = lhs2.as_mut_slice::<f32>().unwrap();
    lhs2_host.copy_from_slice(&[100.0, 200.0, 300.0, 400.0]);
    let rhs2_host = rhs2.as_mut_slice::<f32>().unwrap();
    rhs2_host.copy_from_slice(&[7.0, 8.0, 9.0, 10.0]);

    let mut sess = device
        .encoder_session()
        .expect("encoder_session() Ok")
        .expect("Some under env=1");

    // A fresh session has fired no fence yet.
    assert!(!sess.has_event(), "no event before first fence_stage");
    assert_eq!(sess.fence_value(), 0, "fence_value starts at 0");
    assert!(!sess.is_fence_pending());

    // ---- Stage 1: dispatch, then fence ----
    sess.begin_stage(STAGE1_LABEL);
    elementwise_add(
        sess.encoder(),
        &mut kernels,
        device.metal_device(),
        &lhs1,
        &rhs1,
        &sum1,
        ELEMS,
        DType::F32,
    )
    .expect("stage1 dispatch");
    sess.fence_stage(Some(STAGE1_FENCE_LABEL))
        .expect("fence_stage Ok");

    assert!(sess.is_drained(), "drained after fence_stage");
    assert!(sess.is_fence_pending(), "fence_pending after fence_stage");
    assert!(sess.has_event(), "event lazy-allocated by first fence");
    assert_eq!(sess.fence_value(), 1, "fence_value bumped to 1");
    let cb_label_stage1: String = sess.metal_command_buffer().label().to_string();
    assert_eq!(
        cb_label_stage1, STAGE1_FENCE_LABEL,
        "stage 1 fenced CB carries the fence_stage label"
    );

    // ---- Rotate to the next stage ----
    sess.reset_for_next_stage().expect("reset_for_next_stage Ok");

    assert!(!sess.is_drained(), "no longer drained after reset");
    assert!(!sess.is_fence_pending(), "no fence pending after reset");
    assert_eq!(
        sess.fence_value(),
        1,
        "fence_value persists across reset (it is the high-water mark)"
    );
    let cb_label_post_reset: String = sess.metal_command_buffer().label().to_string();
    assert_ne!(
        cb_label_post_reset, STAGE1_FENCE_LABEL,
        "reset_for_next_stage must rotate to a fresh MTLCommandBuffer (no carryover label)"
    );

    // ---- Stage 2: dispatch, then blocking commit ----
    sess.begin_stage(STAGE2_LABEL);
    elementwise_add(
        sess.encoder(),
        &mut kernels,
        device.metal_device(),
        &lhs2,
        &rhs2,
        &sum2,
        ELEMS,
        DType::F32,
    )
    .expect("stage2 dispatch");
    sess.commit_and_wait().expect("stage2 commit_and_wait Ok");

    assert!(sess.is_drained(), "drained after commit_and_wait");
    assert!(!sess.is_fence_pending(), "no fence pending after commit_and_wait");

    // Both stages' results are read only after the final blocking commit.
    let stage1_result = sum1.as_slice::<f32>().expect("read out1");
    assert_eq!(
        stage1_result,
        &[11.0, 22.0, 33.0, 44.0],
        "stage 1 elementwise_add result must propagate via fenced CB"
    );
    let stage2_result = sum2.as_slice::<f32>().expect("read out2");
    assert_eq!(
        stage2_result,
        &[107.0, 208.0, 309.0, 410.0],
        "stage 2 elementwise_add result must propagate via fresh CB after wait-event"
    );
    assert_eq!(
        sess.metal_command_buffer().label(),
        STAGE2_LABEL,
        "stage 2 label must propagate to the fresh CB's MTLCommandBuffer.label"
    );
}
/// Chains three fenced stages on one session, then drains with a final
/// blocking commit.
///
/// Each iteration dispatches an element-wise add, fences, checks that
/// `fence_value()` has advanced by exactly one, resets, and checks that the
/// value is unchanged by the reset. After the chain, a drain stage
/// re-dispatches the first triple and `commit_and_wait` is called; every
/// output buffer is then verified and `fence_value()` must still be 3.
///
/// Prints a SKIPPED notice and returns early when
/// `EncoderSession::env_enabled()` is false.
#[test]
fn test_session_fence_event_signal_wait_round_trip() {
    use mlx_native::ops::elementwise::elementwise_add;

    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_multistage] test_session_fence_event_signal_wait_round_trip \
             SKIPPED — HF2Q_ENCODER_SESSION not set in process env."
        );
        return;
    }

    let _serial = acquire_test_lock();
    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut kernels = KernelRegistry::new();
    let elems = 2usize;
    let bytes = elems * std::mem::size_of::<f32>();

    // One (lhs, rhs, out) triple per chained stage; stage `k` uses
    // lhs = [100k + 1, 100k + 2] and rhs = [10, 20].
    let mut inputs = Vec::with_capacity(3);
    for stage in 0..3 {
        let mut lhs = device.alloc_buffer(bytes, DType::F32, vec![elems]).expect("a");
        let mut rhs = device.alloc_buffer(bytes, DType::F32, vec![elems]).expect("b");
        let sum = device
            .alloc_buffer(bytes, DType::F32, vec![elems])
            .expect("out");
        let offset = (stage as f32) * 100.0;
        let lhs_host = lhs.as_mut_slice::<f32>().unwrap();
        lhs_host.copy_from_slice(&[offset + 1.0, offset + 2.0]);
        let rhs_host = rhs.as_mut_slice::<f32>().unwrap();
        rhs_host.copy_from_slice(&[10.0, 20.0]);
        inputs.push((lhs, rhs, sum));
    }

    let mut sess = device
        .encoder_session()
        .expect("encoder_session() Ok")
        .expect("Some under env=1");
    assert_eq!(sess.fence_value(), 0);
    assert!(!sess.has_event());

    for (idx, (lhs, rhs, sum)) in inputs.iter().enumerate() {
        let stage_label = format!("phase.iter89e2b_chain_stage{idx}");
        sess.begin_stage(&stage_label);
        elementwise_add(
            sess.encoder(),
            &mut kernels,
            device.metal_device(),
            lhs,
            rhs,
            sum,
            elems,
            DType::F32,
        )
        .expect("chain dispatch");
        sess.fence_stage(None).expect("fence_stage Ok");

        // The fence counter is 1-based: the k-th fence leaves it at k.
        let expected_value = idx as u64 + 1;
        assert_eq!(
            sess.fence_value(),
            expected_value,
            "fence_value must increment monotonically per fence (i={idx})"
        );
        assert!(sess.has_event(), "event must be allocated after first fence");
        assert!(sess.is_fence_pending());
        assert!(sess.is_drained());

        sess.reset_for_next_stage()
            .expect("reset_for_next_stage Ok");
        assert!(!sess.is_drained());
        assert!(!sess.is_fence_pending());
        assert_eq!(
            sess.fence_value(),
            expected_value,
            "fence_value persists across reset"
        );
    }

    // Drain stage: re-run the first triple, then block until completion.
    sess.begin_stage("phase.iter89e2b_chain_drain");
    let (drain_lhs, drain_rhs, drain_sum) = &inputs[0];
    elementwise_add(
        sess.encoder(),
        &mut kernels,
        device.metal_device(),
        drain_lhs,
        drain_rhs,
        drain_sum,
        elems,
        DType::F32,
    )
    .expect("drain dispatch");
    sess.commit_and_wait().expect("drain commit_and_wait Ok");

    for (idx, (_, _, sum)) in inputs.iter().enumerate() {
        let observed = sum.as_slice::<f32>().expect("read chain out");
        let offset = (idx as f32) * 100.0;
        assert_eq!(
            observed,
            &[offset + 11.0, offset + 22.0],
            "chain stage {idx} output must be readable after drained commit"
        );
    }

    assert_eq!(
        sess.fence_value(),
        3,
        "fence_value is the high-water mark (3 fences fired); commit_and_wait does not bump it"
    );
}
/// Round-trips one buffer through the session's residency-set methods and
/// checks the residency test counter after every step.
///
/// Sequence: reset counter, allocate (+1), `remove_from_residency_set` (back
/// to baseline), `add_to_residency_set` (+1), then drop the session and the
/// buffer (back to baseline).
///
/// Prints a SKIPPED notice and returns early when
/// `EncoderSession::env_enabled()` is false or when
/// `device.residency_sets_enabled()` is false.
#[test]
fn test_session_residency_delegation_round_trip() {
    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_multistage] test_session_residency_delegation_round_trip \
             SKIPPED — HF2Q_ENCODER_SESSION not set in process env."
        );
        return;
    }

    let _serial = acquire_test_lock();
    let device = MlxDevice::new().expect("MlxDevice::new");
    if !device.residency_sets_enabled() {
        eprintln!(
            "[encoder_session_multistage] test_session_residency_delegation_round_trip \
             SKIPPED — residency sets disabled (macOS<15 or HF2Q_NO_RESIDENCY=1)."
        );
        return;
    }

    // Start from a known counter value so the deltas below are absolute.
    reset_residency_test_counters();
    let baseline = residency_allocation_count_for_test();
    assert_eq!(baseline, 0, "reset_residency_test_counters zeros the count");

    let tracked = device
        .alloc_buffer(1024, DType::F32, vec![256])
        .expect("alloc_buffer");
    assert_eq!(
        residency_allocation_count_for_test(),
        baseline + 1,
        "device.alloc_buffer must auto-register (delta=+1)"
    );

    let sess = device
        .encoder_session()
        .expect("encoder_session() Ok")
        .expect("Some under env=1");

    // Removing through the session must be visible in the same counter.
    assert!(
        sess.remove_from_residency_set(&tracked),
        "remove_from_residency_set must return true when residency is enabled"
    );
    assert_eq!(
        residency_allocation_count_for_test(),
        baseline,
        "session.remove_from_residency_set must decrement the same counter"
    );

    // Adding it back through the session restores the +1.
    assert!(
        sess.add_to_residency_set(&tracked),
        "add_to_residency_set must return true when residency is enabled"
    );
    assert_eq!(
        residency_allocation_count_for_test(),
        baseline + 1,
        "session.add_to_residency_set must increment the same counter"
    );

    // Session first, then the buffer; the buffer's drop must unregister it.
    drop(sess);
    drop(tracked);
    assert_eq!(
        residency_allocation_count_for_test(),
        baseline,
        "buffer Drop must remove its registration via storage Drop"
    );
}
/// Drops an `EncoderSession` while a fence is still pending and checks that
/// (1) the fenced stage's output is readable afterwards and (2) the device
/// can still run work through a plain `command_encoder()`.
///
/// Prints a SKIPPED notice and returns early when
/// `EncoderSession::env_enabled()` is false.
#[test]
fn test_session_drop_with_open_fence_drains_synchronously() {
if !EncoderSession::env_enabled() {
eprintln!(
"[encoder_session_multistage] test_session_drop_with_open_fence_drains_synchronously \
SKIPPED — HF2Q_ENCODER_SESSION not set in process env."
);
return;
}
// Serialise with the other tests in this file.
let _guard = acquire_test_lock();
let device = MlxDevice::new().expect("MlxDevice::new");
let mut registry = KernelRegistry::new();
let n = 4usize;
let byte_len = n * std::mem::size_of::<f32>();
let mut a = device.alloc_buffer(byte_len, DType::F32, vec![n]).expect("a");
let mut b = device.alloc_buffer(byte_len, DType::F32, vec![n]).expect("b");
let out = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("out");
a.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[5.0, 6.0, 7.0, 8.0]);
b.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
// Inner scope: `sess` is dropped at the closing brace with the fence still
// pending (no reset_for_next_stage / commit_and_wait is called on it).
{
let mut sess = device
.encoder_session()
.expect("encoder_session() Ok")
.expect("Some under env=1");
sess.begin_stage("phase.iter89e2b_drop_fenced");
mlx_native::ops::elementwise::elementwise_add(
sess.encoder(),
&mut registry,
device.metal_device(),
&a,
&b,
&out,
n,
DType::F32,
)
.expect("dispatch");
sess.fence_stage(None).expect("fence_stage Ok");
assert!(sess.is_fence_pending());
assert!(sess.is_drained());
// NOTE(review): this explicit wait completes the fenced command buffer
// before the session is dropped, so the read of `out` below does not by
// itself demonstrate that Drop drains the fenced work. Confirm whether the
// wait is required by the Drop contract or whether it masks the behaviour
// the test name describes.
sess.metal_command_buffer().wait_until_completed();
}
// The session is gone; the stage's result must be visible on the host.
let r = out.as_slice::<f32>().expect("read out post-drop");
assert_eq!(
r,
&[6.0, 8.0, 10.0, 12.0],
"fenced CB output must be visible after session drop"
);
// Follow-up work on a plain command encoder: the device must still be usable.
let mut enc = device
.command_encoder()
.expect("command_encoder post-drop");
let mut a2 = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("a2");
let mut b2 = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("b2");
let out2 = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("out2");
a2.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
b2.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[2.0, 2.0, 2.0, 2.0]);
mlx_native::ops::elementwise::elementwise_add(
&mut enc,
&mut registry,
device.metal_device(),
&a2,
&b2,
&out2,
n,
DType::F32,
)
.expect("post-drop dispatch");
enc.commit_and_wait().expect("post-drop commit_and_wait");
assert_eq!(
out2.as_slice::<f32>().expect("read out2"),
&[3.0, 3.0, 3.0, 3.0],
"device usable after fenced EncoderSession Drop"
);
}
/// Drops a buffer that a fenced (not yet reset) stage wrote into, then keeps
/// using the session: resets, runs a second stage, and commits.
///
/// Asserted along the way:
/// - the residency test counter goes +1 on the scratch allocation and returns
///   to baseline as soon as the scratch buffer is dropped, while the fence is
///   still pending;
/// - stage 2 commits successfully and produces the expected sums;
/// - after the session is dropped, a plain `command_encoder()` still works.
///
/// The "F2" wording in the messages refers to a scenario defined outside this
/// file — NOTE(review): confirm its meaning against the project spec.
///
/// Prints a SKIPPED notice and returns early when
/// `EncoderSession::env_enabled()` is false or when
/// `device.residency_sets_enabled()` is false.
#[test]
fn test_session_arena_lifetime_under_fence_no_rescission() {
if !EncoderSession::env_enabled() {
eprintln!(
"[encoder_session_multistage] test_session_arena_lifetime_under_fence_no_rescission \
SKIPPED — HF2Q_ENCODER_SESSION not set in process env."
);
return;
}
// Serialise with the other tests in this file; this test resets and reads
// the residency test counters.
let _guard = acquire_test_lock();
let device = MlxDevice::new().expect("MlxDevice::new");
if !device.residency_sets_enabled() {
eprintln!(
"[encoder_session_multistage] test_session_arena_lifetime_under_fence_no_rescission \
SKIPPED — residency sets disabled (macOS<15 or HF2Q_NO_RESIDENCY=1). \
F2 adversarial test only meaningful when residency sets are live."
);
return;
}
let mut registry = KernelRegistry::new();
let n = 4usize;
let byte_len = n * std::mem::size_of::<f32>();
// Stage-2 and stage-1 inputs are all allocated BEFORE the counter reset
// below, so the only allocation counted after the reset is `scratch_out`.
let mut a2 = device.alloc_buffer(byte_len, DType::F32, vec![n]).expect("a2");
let mut b2 = device.alloc_buffer(byte_len, DType::F32, vec![n]).expect("b2");
let out2 = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("out2");
a2.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[1.0, 1.0, 1.0, 1.0]);
b2.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[2.0, 2.0, 2.0, 2.0]);
let mut a1 = device.alloc_buffer(byte_len, DType::F32, vec![n]).expect("a1");
let mut b1 = device.alloc_buffer(byte_len, DType::F32, vec![n]).expect("b1");
a1.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[5.0, 6.0, 7.0, 8.0]);
b1.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
// Zero the counter so the deltas asserted below are absolute.
reset_residency_test_counters();
let baseline = residency_allocation_count_for_test();
assert_eq!(
baseline, 0,
"reset_residency_test_counters zeros the count baseline"
);
let mut sess = device
.encoder_session()
.expect("encoder_session() Ok")
.expect("Some under env=1");
// The stage-1 destination: allocated after the reset, hence counted (+1).
let scratch_out = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("scratch_out");
assert_eq!(
residency_allocation_count_for_test(),
baseline + 1,
"scratch_out alloc bumped residency count"
);
// Stage 1 writes into the scratch buffer and is fenced, not waited on.
sess.begin_stage("phase.iter89e2b_f2_stage1");
mlx_native::ops::elementwise::elementwise_add(
sess.encoder(),
&mut registry,
device.metal_device(),
&a1,
&b1,
&scratch_out,
n,
DType::F32,
)
.expect("stage1 dispatch into scratch_out");
sess.fence_stage(None).expect("fence_stage Ok");
assert!(sess.is_fence_pending());
// The adversarial step: drop the stage-1 destination while the fence is
// still pending. The counter must go back to baseline immediately.
drop(scratch_out);
assert_eq!(
residency_allocation_count_for_test(),
baseline,
"MlxBufferStorage::Drop must decrement residency count even mid-fence"
);
// The session must remain usable for a further stage after that drop.
sess.reset_for_next_stage()
.expect("reset_for_next_stage Ok");
sess.begin_stage("phase.iter89e2b_f2_stage2");
mlx_native::ops::elementwise::elementwise_add(
sess.encoder(),
&mut registry,
device.metal_device(),
&a2,
&b2,
&out2,
n,
DType::F32,
)
.expect("stage2 dispatch");
sess.commit_and_wait()
.expect("stage2 commit_and_wait must succeed under retained-refs F2");
let r2 = out2.as_slice::<f32>().expect("read out2");
assert_eq!(
r2,
&[3.0, 3.0, 3.0, 3.0],
"stage 2 output must be correct after the F2 mid-fence drop"
);
// Nothing in stage 2 allocated or released a counted buffer, so the
// counter is still at baseline.
assert_eq!(
residency_allocation_count_for_test(),
baseline,
"F2 fence preservation: residency count back to baseline after \
scratch drop + fence + reset + commit"
);
drop(sess);
// Post-session sanity check on a plain command encoder.
let mut enc = device.command_encoder().expect("command_encoder post-F2");
let mut a3 = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("a3");
let mut b3 = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("b3");
let out3 = device
.alloc_buffer(byte_len, DType::F32, vec![n])
.expect("out3");
a3.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);
b3.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
mlx_native::ops::elementwise::elementwise_add(
&mut enc,
&mut registry,
device.metal_device(),
&a3,
&b3,
&out3,
n,
DType::F32,
)
.expect("post-F2 dispatch");
enc.commit_and_wait().expect("post-F2 commit_and_wait");
assert_eq!(
out3.as_slice::<f32>().expect("read out3"),
&[11.0, 22.0, 33.0, 44.0],
"device usable after F2 adversarial session drop"
);
}
/// Reuses one `EncoderSession` for `N` dispatch / fence / reset cycles and
/// then finishes with a terminal `commit_and_wait`.
///
/// After the terminal commit the test prints and asserts that both
/// `fence_value()` and `wait_count()` equal `N`. The same three buffers are
/// used for every cycle, and no `begin_stage` call is made; the label is
/// supplied through `fence_stage` instead.
///
/// Prints a SKIP notice and returns early when
/// `EncoderSession::env_enabled()` is false.
#[test]
fn test_session_borrowed_across_n_stages() {
    // Check the env gate BEFORE taking the serialisation lock, as every other
    // test in this file does: a skipped run touches no device or counter
    // state, so it has nothing to protect and should not queue behind tests
    // that do.
    if !EncoderSession::env_enabled() {
        eprintln!(
            "[test_session_borrowed_across_n_stages] SKIP — HF2Q_ENCODER_SESSION not set to \"1\""
        );
        return;
    }
    let _guard = acquire_test_lock();

    // Typed as u64 so it compares directly against the session counters
    // without `as` casts.
    const N: u64 = 5;

    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut registry = KernelRegistry::new();
    let n_elems = 4usize;
    let byte_len = n_elems * std::mem::size_of::<f32>();
    let mut a = device
        .alloc_buffer(byte_len, DType::F32, vec![n_elems])
        .expect("a");
    let mut b = device
        .alloc_buffer(byte_len, DType::F32, vec![n_elems])
        .expect("b");
    let out = device
        .alloc_buffer(byte_len, DType::F32, vec![n_elems])
        .expect("out");
    a.as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
    b.as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);

    let mut sess = device
        .encoder_session()
        .expect("encoder_session() Ok")
        .expect("Some under HF2Q_ENCODER_SESSION=1");

    for stage_idx in 0..N {
        mlx_native::ops::elementwise::elementwise_add(
            sess.encoder(),
            &mut registry,
            device.metal_device(),
            &a,
            &b,
            &out,
            n_elems,
            DType::F32,
        )
        .expect("dispatch");
        let label = format!("borrowed.stage{stage_idx}");
        sess.fence_stage(Some(label.as_str()))
            .expect("fence_stage Ok");
        // Reset after every fence, including the last one; the wait_count
        // assertion below depends on that.
        sess.reset_for_next_stage()
            .expect("reset_for_next_stage Ok");
    }

    sess.commit_and_wait().expect("terminal commit_and_wait");

    let fence_val = sess.fence_value();
    let wait_count = sess.wait_count();
    eprintln!("borrowed.fence_value={fence_val}");
    eprintln!("borrowed.wait_count={wait_count}");
    assert_eq!(
        fence_val, N,
        "fence_value must equal N={N} (one per fence_stage)"
    );
    assert_eq!(
        wait_count, N,
        "wait_count must equal N={N} when reset is called after every fence \
         (including the last); spec §2.3 expectation"
    );
}