#![allow(clippy::expect_used, clippy::unwrap_used)]
use std::sync::Mutex;
use mlx_native::{DType, EncoderSession, KernelRegistry, MlxDevice};
/// Process-wide mutex that tests in this file hold for their whole body.
// NOTE(review): presumably this serializes tests that share the GPU device or
// process environment — confirm against the other tests in this crate.
static TEST_LOCK: Mutex<()> = Mutex::new(());

/// Locks [`TEST_LOCK`] and returns the guard.
///
/// A poisoned lock (a previous holder panicked) is not treated as an error:
/// the guard is recovered from the poison error so that one failing test does
/// not cascade into every later test that takes the lock.
fn acquire_test_lock() -> std::sync::MutexGuard<'static, ()> {
    match TEST_LOCK.lock() {
        Ok(guard) => guard,
        // The protected value is `()`, so there is no state that could have
        // been left inconsistent; taking the guard anyway is harmless.
        Err(poison) => poison.into_inner(),
    }
}
/// Smoke test for the `EncoderSession` fence / wait counters.
///
/// Runs five stages. Each stage encodes one `elementwise_add` into the session
/// and calls `fence_stage`; every stage except the last then calls
/// `reset_for_next_stage`. After each step the test checks the counters the
/// session exposes (`fence_value`, `wait_count`, `wait_value`) and its state
/// flags (`is_fence_pending`, `is_drained`, `has_event`).
///
/// The test is skipped (with a diagnostic on stderr) when
/// `EncoderSession::env_enabled()` returns `false`.
#[test]
fn encoder_session_wait_event_smoke() {
    let _guard = acquire_test_lock();

    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_wait_event_smoke] SKIP — HF2Q_ENCODER_SESSION not set to \"1\" \
             in process env. Re-run with HF2Q_ENCODER_SESSION=1 to exercise the wait-event path.\n\
             fence_value=skipped\n\
             wait_count=skipped\n\
             wait_value=skipped"
        );
        return;
    }

    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut registry = KernelRegistry::new();

    // Three F32 buffers of four elements each: two inputs and one output.
    let elem_count = 4usize;
    let alloc_f32 = || {
        device.alloc_buffer(
            elem_count * std::mem::size_of::<f32>(),
            DType::F32,
            vec![elem_count],
        )
    };
    let mut lhs = alloc_f32().expect("a");
    let mut rhs = alloc_f32().expect("b");
    let sum = alloc_f32().expect("out");

    let lhs_view = lhs.as_mut_slice::<f32>().unwrap();
    lhs_view.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
    let rhs_view = rhs.as_mut_slice::<f32>().unwrap();
    rhs_view.copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);

    let maybe_session = device.encoder_session().expect("encoder_session() Ok");
    let mut sess = maybe_session.expect("Some under HF2Q_ENCODER_SESSION=1");

    // A freshly created session must report all-zero counters and no event.
    assert_eq!(
        sess.fence_value(),
        0,
        "fresh session: fence_value starts at 0"
    );
    assert_eq!(
        sess.wait_count(),
        0,
        "fresh session: wait_count starts at 0"
    );
    assert_eq!(
        sess.wait_value(),
        0,
        "fresh session: wait_value starts at 0 (no wait emitted yet)"
    );
    assert!(
        !sess.has_event(),
        "fresh session: no SharedEvent allocated until first fence_stage"
    );

    // `i` is the 0-based stage index; `expected_signal` is the 1-based value
    // that `fence_value` is expected to report once stage `i` is fenced.
    for (i, expected_signal) in (0..5usize).zip(1u64..) {
        let label = format!("session.wait_smoke.stage{i}");

        mlx_native::ops::elementwise::elementwise_add(
            sess.encoder(),
            &mut registry,
            device.metal_device(),
            &lhs,
            &rhs,
            &sum,
            elem_count,
            DType::F32,
        )
        .expect("elementwise_add session");

        sess.fence_stage(Some(label.as_str()))
            .expect("fence_stage Ok");

        assert_eq!(
            sess.fence_value(),
            expected_signal,
            "after fence_stage #{i}: fence_value must be {expected_signal}"
        );
        assert!(
            sess.is_fence_pending(),
            "after fence_stage #{i}: is_fence_pending must be true"
        );
        assert!(
            sess.is_drained(),
            "after fence_stage #{i}: is_drained must be true"
        );
        assert!(
            sess.has_event(),
            "after fence_stage #{i}: has_event must be true (lazy-alloc on i==0)"
        );

        // The fifth (last) stage is left fenced and is not reset.
        if i >= 4 {
            continue;
        }

        sess.reset_for_next_stage()
            .expect("reset_for_next_stage Ok");

        // After a reset the wait counters are expected to match the value
        // that was just signaled.
        let expected_wait = expected_signal;
        assert_eq!(
            sess.wait_count(),
            expected_wait,
            "after reset following fence #{i}: wait_count must be {expected_wait}"
        );
        assert_eq!(
            sess.wait_value(),
            expected_wait,
            "after reset following fence #{i}: wait_value must be {expected_wait} \
             (matches the signal we just signaled)"
        );
        assert!(
            !sess.is_fence_pending(),
            "after reset following fence #{i}: is_fence_pending must be cleared"
        );
        assert!(
            !sess.is_drained(),
            "after reset following fence #{i}: is_drained must be cleared"
        );
    }

    // Snapshot the counters, block on the command buffer, then verify the
    // counters did not move while waiting.
    let (fence_val, wait_count, wait_value) =
        (sess.fence_value(), sess.wait_count(), sess.wait_value());
    sess.metal_command_buffer().wait_until_completed();

    assert_eq!(
        sess.fence_value(),
        fence_val,
        "fence_value must be stable across wait_until_completed"
    );
    assert_eq!(
        sess.wait_count(),
        wait_count,
        "wait_count must be stable across wait_until_completed"
    );
    assert_eq!(
        sess.wait_value(),
        wait_value,
        "wait_value must be stable across wait_until_completed"
    );

    // Emit the final counters in the same `key=value` shape as the skip path.
    eprintln!("fence_value={fence_val}");
    eprintln!("wait_count={wait_count}");
    eprintln!("wait_value={wait_value}");

    // Five fences and four intermediate resets are expected in total.
    assert_eq!(
        fence_val, 5,
        "fence_value must be 5 after exactly 5 fence_stage calls (got {fence_val})"
    );
    assert_eq!(
        wait_count, 4,
        "wait_count must be 4 (4 non-terminal resets between 5 fences; got {wait_count})"
    );
    assert_eq!(
        wait_value, 4,
        "wait_value must be 4 (the wait following fence #4 at value 4; got {wait_value})"
    );
}