#![allow(clippy::expect_used, clippy::unwrap_used)]
use std::sync::Mutex;
use mlx_native::{
cmd_buf_count, reset_counters, DType, EncoderSession, KernelRegistry, MlxDevice,
};
/// Process-wide lock taken by the tests in this file before they touch the device.
///
/// NOTE(review): presumably required because the counters read through
/// `cmd_buf_count()` / `reset_counters()` are process-global — confirm against
/// the `mlx_native` crate.
static TEST_LOCK: Mutex<()> = Mutex::new(());

/// Acquires [`TEST_LOCK`] and returns its guard, blocking until it is free.
///
/// A poisoned lock is recovered rather than propagated: the mutex protects a
/// `()` payload, so a panic in an earlier holder cannot have left any state
/// behind that would be unsafe to observe.
fn acquire_test_lock() -> std::sync::MutexGuard<'static, ()> {
    match TEST_LOCK.lock() {
        Ok(guard) => guard,
        // An earlier holder panicked; take the guard out of the poison error.
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Smoke test comparing the command-buffer count (as reported by
/// `cmd_buf_count()`) of two ways of encoding the same ten
/// `elementwise_add` dispatches:
///
/// * **plain** — a fresh command encoder per stage, each committed on its own
///   (two per layer), followed by one extra "drain" encoder that is committed
///   and waited on;
/// * **session** — a single `EncoderSession` in which the two dispatches of a
///   layer are separated only by `memory_barrier()`, with one `fence_stage`
///   per layer and a `reset_for_next_stage` after every layer but the last.
///
/// The assertions pin the plain count at 10, the session count at 5, the
/// fence value at 5 and the session wait count at 4.
///
/// When `EncoderSession::env_enabled()` is false the test prints `skipped`
/// markers to stderr and returns early, i.e. it passes without exercising
/// anything.
#[test]
fn encoder_session_cb_count_smoke() {
    // Number of simulated layers. The literal expectations in the assertions
    // at the bottom (5 / 10 / 4) are written for exactly this value.
    const LAYERS: usize = 5;

    let _guard = acquire_test_lock();

    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_cb_count_smoke] SKIP — HF2Q_ENCODER_SESSION not set to \"1\" \
             in process env. Re-run with HF2Q_ENCODER_SESSION=1 to exercise the H2b path.\n\
             fence_value=skipped\n\
             cb_count_plain=skipped\n\
             cb_count_session=skipped\n\
             wait_count=skipped"
        );
        return;
    }

    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut registry = KernelRegistry::new();

    // Three f32 buffers of `n` elements each: two inputs and one output.
    let n = 4usize;
    let byte_len = n * std::mem::size_of::<f32>();
    let alloc_f32 = |name: &str| {
        device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect(name)
    };
    let mut a = alloc_f32("a");
    let mut b = alloc_f32("b");
    let out = alloc_f32("out");

    a.as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
    b.as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);

    // Encodes one `elementwise_add` dispatch over (`a`, `b`, `out`) into the
    // given encoder and panics with `$failure` if encoding reports an error.
    // A macro (rather than a helper fn) keeps the call expression identical at
    // every use site without having to name the encoder / buffer types.
    macro_rules! dispatch_add {
        ($encoder:expr, $failure:literal) => {
            mlx_native::ops::elementwise::elementwise_add(
                $encoder,
                &mut registry,
                device.metal_device(),
                &a,
                &b,
                &out,
                n,
                DType::F32,
            )
            .expect($failure)
        };
    }

    // ---- Plain path: one encoder, one labeled commit per stage. ----
    reset_counters();
    let plain_start = cmd_buf_count();
    for i in 0..LAYERS {
        {
            let mut enc = device
                .command_encoder()
                .expect("command_encoder plain attn");
            dispatch_add!(&mut enc, "elementwise_add plain attn");
            enc.commit_labeled(&format!("plain.attn.layer{i}"));
        }
        {
            let mut enc = device
                .command_encoder()
                .expect("command_encoder plain ffn");
            dispatch_add!(&mut enc, "elementwise_add plain ffn");
            enc.commit_labeled(&format!("plain.ffn.layer{i}"));
        }
    }
    // One extra encoder, committed and waited on, before the counter is read.
    // NOTE(review): presumably this drains the commits above — confirm.
    {
        let mut drain_enc = device.command_encoder().expect("drain encoder");
        dispatch_add!(&mut drain_enc, "drain dispatch");
        drain_enc.commit_and_wait().expect("drain commit_and_wait");
    }
    // The trailing `- 1` removes the drain encoder from the tally.
    let cb_count_plain = cmd_buf_count() - plain_start - 1;

    // ---- Session path: both stages of a layer share one encoder. ----
    reset_counters();
    let session_start = cmd_buf_count();
    let mut sess = device
        .encoder_session()
        .expect("encoder_session() Ok")
        .expect("Some under HF2Q_ENCODER_SESSION=1");
    for i in 0..LAYERS {
        dispatch_add!(sess.encoder(), "elementwise_add session attn");
        // Only a barrier separates the two stages — no commit in between.
        sess.encoder().memory_barrier();
        dispatch_add!(sess.encoder(), "elementwise_add session ffn");

        let label = format!("session.attn.layer{i}");
        sess.fence_stage(Some(label.as_str()))
            .expect("fence_stage Ok");

        // Every layer except the final one re-arms the session.
        if i + 1 < LAYERS {
            sess.reset_for_next_stage().expect("reset_for_next_stage Ok");
        }
    }
    // Both session counters are sampled *before* the final blocking wait.
    let fence_val = sess.fence_value();
    let wait_count = sess.wait_count();
    sess.metal_command_buffer().wait_until_completed();
    let cb_count_session = cmd_buf_count() - session_start;

    // Emit the measurements first so they are visible even when an assertion
    // below fails.
    eprintln!("fence_value={fence_val}");
    eprintln!("cb_count_plain={cb_count_plain}");
    eprintln!("cb_count_session={cb_count_session}");
    eprintln!("wait_count={wait_count}");

    assert_eq!(
        fence_val, 5,
        "fence_value must be 5 after exactly 5 fence_stage calls (got {fence_val})"
    );
    assert_eq!(
        cb_count_plain, 10,
        "cb_count_plain must be 10 (5 attention + 5 FFN CBs); got {cb_count_plain}"
    );
    assert_eq!(
        cb_count_session, 5,
        "cb_count_session must be 5 (1 initial + 4 resets, with FFN folded \
         into each attention CB via memory_barrier); got {cb_count_session}"
    );
    assert!(
        cb_count_session * 2 <= cb_count_plain,
        "iter90b H2b structural proof: cb_count_session ({cb_count_session}) must be \
         at most half of cb_count_plain ({cb_count_plain}). FAILURE means the \
         FFN→next-attention boundary emitted fence_stage instead of memory_barrier — \
         i.e. carry_into_next_stage on the Sessioned variant did NOT keep the \
         persistent compute encoder open."
    );
    assert_eq!(
        wait_count, 4,
        "wait_count must be 4 (one wait per non-terminal reset_for_next_stage); \
         got {wait_count}"
    );
}