#![allow(clippy::expect_used, clippy::unwrap_used)]
use mlx_native::{DType, EncoderSession, KernelRegistry, MlxDevice};
/// The cached `EncoderSession::env_enabled()` gate must agree with a fresh read of
/// `HF2Q_ENCODER_SESSION`, and `MlxDevice::encoder_session()` must hand out a
/// session exactly when that gate is on.
#[test]
fn test_env_gate_factory_agreement() {
    // Read the cached gate first, then the raw variable; the env is never mutated
    // in this process, so both observations must line up.
    let env_on = EncoderSession::env_enabled();
    let actual_env_var = matches!(std::env::var("HF2Q_ENCODER_SESSION").as_deref(), Ok("1"));
    assert_eq!(
        env_on, actual_env_var,
        "EncoderSession::env_enabled() ({env_on}) must match the actual \
         HF2Q_ENCODER_SESSION env var ({actual_env_var}) — OnceLock cache \
         primes from os env at first read."
    );

    let device = MlxDevice::new().expect("MlxDevice::new");
    let maybe_session = device
        .encoder_session()
        .expect("encoder_session() infallible past metal-rs new_command_buffer");

    // Dispatch on (gate, factory result) so every combination is explicit.
    match (env_on, maybe_session) {
        (true, Some(session)) => assert!(
            !session.is_drained(),
            "fresh EncoderSession::is_drained() must be false"
        ),
        (true, None) => {
            panic!("encoder_session() must return Some(_) when env_enabled()==true")
        }
        (false, leftover) => assert!(
            leftover.is_none(),
            "encoder_session() must return None when HF2Q_ENCODER_SESSION is unset \
             (zero-behavior-change invariant)"
        ),
    }
}
/// Repeated calls to `EncoderSession::env_enabled()` must keep returning the value
/// observed on the first call.
#[test]
fn test_env_enabled_is_stable() {
    let baseline = EncoderSession::env_enabled();
    // Five further probes, each compared against the first observation.
    std::iter::repeat_with(EncoderSession::env_enabled)
        .take(5)
        .for_each(|again| {
            assert_eq!(
                again, baseline,
                "env_enabled() must be stable across calls (OnceLock cache contract)"
            );
        });
}
/// Checks the full `begin_stage` → dispatch → `commit_stage` lifecycle:
/// - the session is not drained until it is committed;
/// - the kernel result lands in `out`;
/// - the stage label reaches the Metal command buffer;
/// - a second `commit_stage` returns `Ok`.
///
/// Skipped (early return) unless `HF2Q_ENCODER_SESSION` enables the session path.
#[test]
fn test_begin_stage_then_commit_stage_drains() {
    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_lifecycle] test_begin_stage_then_commit_stage_drains \
             SKIPPED — HF2Q_ENCODER_SESSION not set in process env. \
             Re-run with HF2Q_ENCODER_SESSION=1 to exercise."
        );
        return;
    }

    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut registry = KernelRegistry::new();
    let elems = 4usize;
    let bytes = elems * std::mem::size_of::<f32>();

    // Every buffer in this test has the same f32 shape; only the panic label differs.
    let alloc = |what: &str| {
        device
            .alloc_buffer(bytes, DType::F32, vec![elems])
            .expect(what)
    };
    let mut lhs = alloc("a");
    let mut rhs = alloc("b");
    let out = alloc("out");

    let lhs_host = lhs.as_mut_slice::<f32>().unwrap();
    lhs_host.copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);
    let rhs_host = rhs.as_mut_slice::<f32>().unwrap();
    rhs_host.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);

    // The same label is handed to begin_stage and later read back from the CB.
    let stage = "phase.iter89e2a_smoke_commit_stage";
    let mut session = device
        .encoder_session()
        .expect("encoder_session() Ok")
        .expect("Some under env=1");
    session.begin_stage(stage);

    mlx_native::ops::elementwise::elementwise_add(
        session.encoder(),
        &mut registry,
        device.metal_device(),
        &lhs,
        &rhs,
        &out,
        elems,
        DType::F32,
    )
    .expect("dispatch elementwise_add through session.encoder()");

    assert!(
        !session.is_drained(),
        "EncoderSession must NOT be drained until commit_* is called"
    );
    session.commit_stage().expect("commit_stage() must succeed");
    assert!(
        session.is_drained(),
        "EncoderSession::is_drained must be true after commit_stage"
    );

    // commit_stage does not block; wait on the CB before reading results on the host.
    session.metal_command_buffer().wait_until_completed();
    let sums = out.as_slice::<f32>().expect("read out");
    assert_eq!(
        sums,
        &[11.0, 22.0, 33.0, 44.0],
        "elementwise_add result must propagate through EncoderSession dispatch path"
    );

    let observed_label = session.metal_command_buffer().label();
    assert_eq!(
        observed_label, stage,
        "stage label must propagate to MTLCommandBuffer.label via commit_labeled"
    );

    session
        .commit_stage()
        .expect("second commit_stage() must be a no-op, not an error");
}
/// Dropping an `EncoderSession` that still has encoded-but-uncommitted work must
/// leave the device usable.
///
/// After the drop, a plain `command_encoder()` dispatch must still produce
/// correct results. Skipped (early return) unless `HF2Q_ENCODER_SESSION`
/// enables the session path.
#[test]
fn test_drop_uncommitted_is_safe() {
    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_lifecycle] test_drop_uncommitted_is_safe \
             SKIPPED — HF2Q_ENCODER_SESSION not set in process env."
        );
        return;
    }

    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut registry = KernelRegistry::new();

    // Allocate a 1-D f32 buffer of `len` elements; `what` is the panic label.
    let alloc_f32 = |len: usize, what: &str| {
        device
            .alloc_buffer(len * std::mem::size_of::<f32>(), DType::F32, vec![len])
            .expect(what)
    };

    // Phase 1: encode a dispatch into a session, then abandon it uncommitted.
    let first_len = 4usize;
    let mut lhs = alloc_f32(first_len, "a");
    let mut rhs = alloc_f32(first_len, "b");
    let abandoned_out = alloc_f32(first_len, "out");
    lhs.as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&[100.0, 200.0, 300.0, 400.0]);
    rhs.as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);

    let mut session = device
        .encoder_session()
        .expect("encoder_session()")
        .expect("Some under env=1");
    session.begin_stage("phase.iter89e2a_drop_uncommitted");
    mlx_native::ops::elementwise::elementwise_add(
        session.encoder(),
        &mut registry,
        device.metal_device(),
        &lhs,
        &rhs,
        &abandoned_out,
        first_len,
        DType::F32,
    )
    .expect("dispatch through session.encoder()");
    assert!(
        !session.is_drained(),
        "session is in Encoding state, must not yet be drained"
    );
    // Drop while still in the Encoding state — this is the behavior under test.
    drop(session);

    // Phase 2: an ordinary encoder on the same device must still work.
    let mut enc = device.command_encoder().expect("command_encoder post-drop");
    let second_len = 2usize;
    let mut lhs2 = alloc_f32(second_len, "a2");
    let mut rhs2 = alloc_f32(second_len, "b2");
    let out2 = alloc_f32(second_len, "out2");
    lhs2.as_mut_slice::<f32>().unwrap().copy_from_slice(&[7.0, 8.0]);
    rhs2.as_mut_slice::<f32>().unwrap().copy_from_slice(&[1.0, 2.0]);

    mlx_native::ops::elementwise::elementwise_add(
        &mut enc,
        &mut registry,
        device.metal_device(),
        &lhs2,
        &rhs2,
        &out2,
        second_len,
        DType::F32,
    )
    .expect("post-drop dispatch");
    enc.commit_and_wait().expect("post-drop commit_and_wait");

    let post_drop_sums = out2.as_slice::<f32>().expect("read out2");
    assert_eq!(
        post_drop_sums,
        &[8.0, 10.0],
        "device usable after EncoderSession Drop"
    );
}
/// `EncoderSession::commit_and_wait` must block until the GPU work finishes.
///
/// Checks that after the call:
/// - the session reports drained;
/// - `mlx_native::sync_count()` has advanced;
/// - the kernel output is readable on the host;
/// - the stage label is on the command buffer;
/// - a second `commit_and_wait` still returns `Ok`.
///
/// Skipped (early return) unless `HF2Q_ENCODER_SESSION` enables the session path.
#[test]
fn test_commit_and_wait_blocks_until_done() {
    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_lifecycle] test_commit_and_wait_blocks_until_done \
             SKIPPED — HF2Q_ENCODER_SESSION not set in process env."
        );
        return;
    }
    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut registry = KernelRegistry::new();
    let n = 4usize;
    let byte_len = n * std::mem::size_of::<f32>();
    let mut a = device
        .alloc_buffer(byte_len, DType::F32, vec![n])
        .expect("a");
    let mut b = device
        .alloc_buffer(byte_len, DType::F32, vec![n])
        .expect("b");
    let out = device
        .alloc_buffer(byte_len, DType::F32, vec![n])
        .expect("out");
    a.as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&[1.5, 2.5, 3.5, 4.5]);
    b.as_mut_slice::<f32>()
        .unwrap()
        .copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);

    // sync_count() is read as a global counter: other tests in this binary may
    // run in parallel and bump it too, so only lower bounds on its growth are
    // asserted below — never exact deltas.
    let sync_before = mlx_native::sync_count();
    let mut sess = device
        .encoder_session()
        .expect("encoder_session()")
        .expect("Some under env=1");
    sess.begin_stage("phase.iter89e2a_commit_and_wait");
    mlx_native::ops::elementwise::elementwise_add(
        sess.encoder(),
        &mut registry,
        device.metal_device(),
        &a,
        &b,
        &out,
        n,
        DType::F32,
    )
    .expect("dispatch through session.encoder()");
    sess.commit_and_wait().expect("commit_and_wait must succeed");
    assert!(
        sess.is_drained(),
        "EncoderSession::is_drained must be true after commit_and_wait"
    );

    let sync_after = mlx_native::sync_count();
    // `>` instead of `>= before + 1`: same meaning, no clippy::int_plus_one
    // warning, and no arithmetic that could overflow.
    assert!(
        sync_after > sync_before,
        "commit_and_wait must increment SYNC_COUNT by ≥ 1 \
         (before={sync_before}, after={sync_after})"
    );

    // No explicit wait here: commit_and_wait is expected to have blocked already.
    let result = out.as_slice::<f32>().expect("read out");
    assert_eq!(
        result,
        &[11.5, 22.5, 33.5, 44.5],
        "elementwise_add result must be visible after commit_and_wait \
         (synchronous semantic)"
    );
    let cb_label = sess.metal_command_buffer().label();
    assert_eq!(
        cb_label, "phase.iter89e2a_commit_and_wait",
        "stage label must propagate to MTLCommandBuffer.label"
    );

    // A second commit_and_wait must succeed. An exact "no new sync" check would
    // race with concurrently running tests, so this is only a monotonicity check.
    let sync_mid = mlx_native::sync_count();
    sess.commit_and_wait()
        .expect("second commit_and_wait must be a no-op");
    let sync_end = mlx_native::sync_count();
    assert!(
        sync_end >= sync_mid,
        "second commit_and_wait must not decrement SYNC_COUNT"
    );
}
/// Pins the fail-loud contract of `EncoderSession::commit_and_wait`.
///
/// 1. Compile time: `commit_and_wait` is callable as
///    `FnOnce(&mut EncoderSession) -> mlx_native::Result<()>`.
/// 2. Runtime (only when `HF2Q_ENCODER_SESSION` enables sessions): a successful
///    dispatch followed by `commit_and_wait` returns `Ok(())`.
/// 3. Source structure (always runs): `src/encoder_session.rs` must
///    - contain the iter94 doc anchor;
///    - contain the explicit `result?; Ok(())` propagation;
///    - not discard the inner commit result with `let _ = ...`.
#[test]
fn test_commit_and_wait_propagates_inner_cb_error() {
    // Compile-time signature pin; the closure is never invoked.
    fn _typecheck<F: FnOnce(&mut EncoderSession) -> mlx_native::Result<()>>(_f: F) {}
    _typecheck(|sess| sess.commit_and_wait());

    if !EncoderSession::env_enabled() {
        eprintln!(
            "[encoder_session_lifecycle] test_commit_and_wait_propagates_inner_cb_error \
             SKIPPED — HF2Q_ENCODER_SESSION not set; structural source check still runs below."
        );
    } else {
        let device = MlxDevice::new().expect("MlxDevice::new");
        let mut registry = KernelRegistry::new();
        let n = 4usize;
        let byte_len = n * std::mem::size_of::<f32>();
        let mut a = device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("a");
        let mut b = device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("b");
        let out = device
            .alloc_buffer(byte_len, DType::F32, vec![n])
            .expect("out");
        a.as_mut_slice::<f32>()
            .unwrap()
            .copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        b.as_mut_slice::<f32>()
            .unwrap()
            .copy_from_slice(&[5.0, 6.0, 7.0, 8.0]);
        let mut sess = device
            .encoder_session()
            .expect("encoder_session()")
            .expect("Some under env=1");
        sess.begin_stage("phase.iter94_task2_fail_loud_smoke");
        mlx_native::ops::elementwise::elementwise_add(
            sess.encoder(),
            &mut registry,
            device.metal_device(),
            &a,
            &b,
            &out,
            n,
            DType::F32,
        )
        .expect("dispatch through session.encoder()");
        let res: mlx_native::Result<()> = sess.commit_and_wait();
        assert!(res.is_ok(), "commit_and_wait must return Ok on success; got {res:?}");
    }

    // Structural check against the crate's own source file.
    let manifest_dir = env!("CARGO_MANIFEST_DIR");
    let src_path = std::path::Path::new(manifest_dir)
        .join("src")
        .join("encoder_session.rs");
    let src = std::fs::read_to_string(&src_path)
        .unwrap_or_else(|e| panic!("read {}: {e}", src_path.display()));

    let needle_anchor = "ADR-015 iter94 Task #2";
    assert!(
        src.contains(needle_anchor),
        "encoder_session.rs must contain the iter94 Task #2 doc anchor \
         '{needle_anchor}' so future maintainers see the fail-loud rationale"
    );

    // Collapse every whitespace run to a single space so line breaks inside
    // `result?;\n    Ok(())` still match the positive check.
    let normalized: String = src.split_whitespace().collect::<Vec<_>>().join(" ");
    assert!(
        normalized.contains("result?; Ok(())"),
        "encoder_session.rs::commit_and_wait MUST end with the explicit \
         `result?; Ok(())` propagation pattern (iter94 Task #2 fail-loud \
         contract). A tail-only `self.inner.commit_and_wait()` form is \
         functionally equivalent but defeats the documentation intent and \
         is brittle under future refactors."
    );

    // The negative checks run on a copy with ALL whitespace removed. A raw
    // substring search misses rustfmt's method-chain layout, e.g.
    // `let _ = self\n    .inner\n    .commit_and_wait_labeled(`.
    let compact: String = src.chars().filter(|c| !c.is_whitespace()).collect();
    assert!(
        !compact.contains("let_=self.inner.commit_and_wait("),
        "encoder_session.rs::commit_and_wait MUST NOT swallow the inner \
         commit_and_wait result with `let _ = ...` (iter94 Task #2)."
    );
    assert!(
        !compact.contains("let_=self.inner.commit_and_wait_labeled("),
        "encoder_session.rs::commit_and_wait MUST NOT swallow the inner \
         commit_and_wait_labeled result with `let _ = ...` (iter94 Task #2)."
    );
}