use std::sync::{Mutex, MutexGuard, PoisonError};
use mlxrs::{
Stream,
memory::{
WiredBudgetPolicy, WiredFixedPolicy, WiredLimitGuard, WiredMaxPolicy, WiredMemoryMeasurement,
WiredMemoryPolicy, WiredSumPolicy, recommended_working_set_bytes, set_wired_limit,
},
};
/// Lock taken by the tests in this file that install guards or call `set_wired_limit`,
/// so that they do not interleave their changes to the wired-memory limit.
static WIRED_LIMIT_LOCK: Mutex<()> = Mutex::new(());

/// Acquires [`WIRED_LIMIT_LOCK`], recovering the guard if an earlier holder panicked.
///
/// Poisoning is deliberately ignored: the mutex protects `()`, so a panicking test cannot
/// leave any guarded data half-updated, and later tests should still be able to run.
fn lock_wired_limit() -> MutexGuard<'static, ()> {
    match WIRED_LIMIT_LOCK.lock() {
        Ok(held) => held,
        Err(poisoned) => PoisonError::into_inner(poisoned),
    }
}
/// With a cap of 200, `WiredSumPolicy` reports the cap once baseline plus active sizes
/// exceed it, and admits a new request only while the running total stays within the cap.
#[test]
fn wired_sum_policy_clamps_to_cap() {
    let capped = WiredSumPolicy::new(Some(200));

    // 100 + 50 + 100 = 250 exceeds the cap, so the cap is what comes back.
    let clamped_limit = capped.limit(100, &[50, 100]);
    assert_eq!(clamped_limit, 200);

    // 100 + 50 + 50 lands exactly on the cap ...
    let fits_exactly = capped.can_admit(100, &[50], 50);
    assert!(fits_exactly);

    // ... and a single extra byte pushes it over.
    let one_byte_over = capped.can_admit(100, &[50], 51);
    assert!(!one_byte_over);
}
/// When baseline plus active sizes stay under the cap, `WiredSumPolicy` reports the plain
/// sum and still admits a further request that fits.
#[test]
fn wired_sum_policy_below_cap() {
    let roomy = WiredSumPolicy::new(Some(1_000));

    // 100 + 50 + 75 = 225, well under the 1_000 cap.
    let summed_limit = roomy.limit(100, &[50, 75]);
    assert_eq!(summed_limit, 225);

    // 225 + 700 = 925 still fits under the cap.
    let admits_large_request = roomy.can_admit(100, &[50, 75], 700);
    assert!(admits_large_request);
}
/// With no active sizes, the sum policy's limit is just the baseline.
#[test]
fn wired_sum_policy_empty_active_returns_baseline() {
    let sum_policy = WiredSumPolicy::new(Some(500));
    let baseline_only = sum_policy.limit(100, &[]);
    assert_eq!(baseline_only, 100);
}
/// Without an explicit cap, the sum policy's limit is expected to be the raw sum clamped to
/// the host's recommended working-set size when one is reported, and the raw sum otherwise.
#[test]
fn wired_sum_policy_no_cap_clamps_to_recommended_if_available() {
    let uncapped = WiredSumPolicy::new(None);
    let baseline = 100u64;
    let active = [50u64, 75];
    let raw_sum = active.iter().fold(baseline, |total, bytes| total + bytes);
    let result = uncapped.limit(baseline, &active);

    // The expectation depends on whether this host reports a recommended size at all.
    let recommended = recommended_working_set_bytes().expect("recommended_working_set_bytes FFI");
    if let Some(rec) = recommended {
        assert_eq!(
            result,
            raw_sum.min(rec),
            "cap=None: result == min(raw_sum, recommended)"
        );
    } else {
        assert_eq!(result, raw_sum, "cap=None + no recommended: pass-through");
    }
}
/// `WiredMaxPolicy` reports whichever is larger: the baseline or the biggest active size.
#[test]
fn wired_max_policy_max_active_or_baseline() {
    let max_policy = WiredMaxPolicy::new();

    // Baseline below the largest active size: the active size wins.
    let active_dominates = max_policy.limit(100, &[20, 150, 60]);
    assert_eq!(active_dominates, 150);

    // Baseline above every active size: the baseline wins.
    let baseline_dominates = max_policy.limit(200, &[20, 150, 60]);
    assert_eq!(baseline_dominates, 200);
}
/// With no active sizes, the max policy's limit is the baseline.
#[test]
fn wired_max_policy_empty_active_returns_baseline() {
    let max_policy = WiredMaxPolicy::new();
    let baseline_only = max_policy.limit(100, &[]);
    assert_eq!(baseline_only, 100);
}
/// The max policy admits even a very large request (half of `u64::MAX`).
#[test]
fn wired_max_policy_admits_unconditionally() {
    let max_policy = WiredMaxPolicy::new();
    let huge_request = u64::MAX / 2;
    let admitted = max_policy.can_admit(100, &[], huge_request);
    assert!(admitted);
}
/// `WiredFixedPolicy` returns its configured limit regardless of baseline or active sizes.
#[test]
fn wired_fixed_policy_returns_limit() {
    let fixed = WiredFixedPolicy::new(123);

    let with_nothing = fixed.limit(0, &[]);
    assert_eq!(with_nothing, 123);

    let with_load = fixed.limit(500, &[1, 2, 3]);
    assert_eq!(with_load, 123);
}
/// The fixed policy admits a request of `u64::MAX` bytes even though its limit is 100.
#[test]
fn wired_fixed_policy_admits_unconditionally() {
    let fixed = WiredFixedPolicy::new(100);
    let admitted = fixed.can_admit(0, &[], u64::MAX);
    assert!(admitted);
}
/// Checks two things about `WiredBudgetPolicy`:
/// equality follows the id (same id with different numeric arguments compares equal, a
/// different id compares unequal), and the limit/admission arithmetic respects the cap.
#[test]
fn wired_budget_policy_identity_and_cap_behavior() {
    let shared_id = "shared-test-id";
    let reference = WiredBudgetPolicy::with_id(shared_id, 100, Some(300));
    let same_id_other_numbers = WiredBudgetPolicy::with_id(shared_id, 999, Some(999));
    let different_id = WiredBudgetPolicy::with_id("other-id", 100, Some(300));

    // Identity: only the id participates in equality.
    assert_eq!(reference, same_id_other_numbers);
    assert_ne!(reference, different_id);

    // Arithmetic: 100 + 50 + 75 = 225 (presumably the first numeric constructor argument is
    // added on top of the baseline — TODO confirm against the policy's documentation).
    let limit = reference.limit(50, &[75]);
    assert_eq!(limit, 225);

    // 225 + 75 reaches the 300 cap exactly; 76 would exceed it.
    let fits_exactly = reference.can_admit(50, &[75], 75);
    assert!(fits_exactly);
    let one_byte_over = reference.can_admit(50, &[75], 76);
    assert!(!one_byte_over);
}
/// Two policies built with `new` (no explicit id) must differ, both as values and by `id()`.
#[test]
fn wired_budget_policy_auto_id_is_unique() {
    let left = WiredBudgetPolicy::new(100, Some(300));
    let right = WiredBudgetPolicy::new(100, Some(300));

    assert_ne!(left, right, "auto-id instances should be distinct");

    let (left_id, right_id) = (left.id(), right.id());
    assert_ne!(left_id, right_id);
}
/// `Hash` must agree with `Eq`: two policies sharing an id collapse into a single
/// `HashSet` entry, and either one can be used to look that entry up.
#[test]
fn wired_budget_policy_hash_matches_eq() {
    use std::collections::HashSet;

    let id = "hash-test";
    let with_cap = WiredBudgetPolicy::with_id(id, 100, Some(300));
    let without_cap = WiredBudgetPolicy::with_id(id, 999, None);

    // Inserted in the same order as before: `with_cap` first, then `without_cap`.
    let set: HashSet<WiredBudgetPolicy> = HashSet::from([with_cap.clone(), without_cap.clone()]);

    assert_eq!(set.len(), 1);
    for probe in [&with_cap, &without_cap] {
        assert!(set.contains(probe));
    }
}
/// Every constructor argument of `WiredMemoryMeasurement` is readable through its getter,
/// and `total_bytes()` is 1_250 for this input (1_000 + 200 + 50).
#[test]
fn wired_memory_measurement_construction_and_total() {
    let measurement = WiredMemoryMeasurement::new(1_000, 200, 50, 1_400, 128, 32);

    // Byte-sized components, in constructor order.
    assert_eq!(measurement.weight_bytes(), 1_000);
    assert_eq!(measurement.kv_bytes(), 200);
    assert_eq!(measurement.workspace_bytes(), 50);
    assert_eq!(measurement.peak_active_bytes(), 1_400);

    // Token bookkeeping.
    assert_eq!(measurement.token_count(), 128);
    assert_eq!(measurement.prefill_step_size(), 32);

    // The total matches weights + kv + workspace here; the peak value is not part of it.
    let total = measurement.total_bytes();
    assert_eq!(total, 1_250);
}
/// `WiredLimitGuard::install` must return `Ok(..)`; whether the payload is `Some` or `None`
/// is not checked here.
#[test]
fn wired_limit_guard_install_succeeds_or_returns_none() {
    let _serialized = lock_wired_limit();

    // Keep the whole result alive until the end of the test, as a guard may be inside it.
    let result = WiredLimitGuard::install(0, &[]);
    let installed_without_error = result.is_ok();
    assert!(
        installed_without_error,
        "install must not error on a healthy GPU/no-GPU host: {result:?}"
    );
}
/// Dropping a lone guard must put back the limit the guard captured at install time.
///
/// Skipped when the host reports no recommended working-set size.
#[test]
fn wired_limit_guard_drop_restores_old_limit() {
    let _serialized = lock_wired_limit();
    let recommended = match recommended_working_set_bytes() {
        Ok(Some(bytes)) => bytes,
        _ => {
            eprintln!("skipping: recommended_working_set_bytes unavailable on this host");
            return;
        }
    };

    let installed = WiredLimitGuard::install(0, &[])
        .expect("install rc")
        .expect("guard installed (recommended is Some)");
    let captured_old = installed.old_limit();
    drop(installed);

    // Read-back via set: the value returned here is compared against what Drop should
    // have restored.
    let observed_after_drop = set_wired_limit(captured_old).expect("set_wired_limit rc");
    assert_eq!(
        observed_after_drop, captured_old,
        "guard's Drop must restore the captured old_limit (recommended={recommended}, \
         captured_old={captured_old}, observed_after_drop={observed_after_drop})"
    );
}
/// Stress check: two threads each install and immediately drop a guard 32 times. Every
/// install must yield a guard, and afterwards the limit must equal the pre-stress snapshot.
///
/// Skipped when the host reports no recommended working-set size.
#[test]
fn concurrent_install_drop_restores_correct_old_limit_per_owner() {
    let _serialized = lock_wired_limit();
    if !matches!(recommended_working_set_bytes(), Ok(Some(_))) {
        eprintln!("skipping: recommended_working_set_bytes unavailable on this host");
        return;
    }

    // Snapshot the current limit by setting 0 and immediately putting the old value back.
    let snapshot_before = set_wired_limit(0).expect("snapshot via round-trip");
    let _ = set_wired_limit(snapshot_before).expect("restore snapshot baseline");

    // Start both workers before joining either, so their install/drop cycles can overlap.
    let mut workers = Vec::with_capacity(2);
    for _ in 0..2 {
        workers.push(std::thread::spawn(|| {
            for _ in 0..32 {
                let guard = match WiredLimitGuard::install(0, &[]) {
                    Ok(Some(guard)) => guard,
                    Ok(None) => {
                        panic!(
                            "refcounted semantics: every install on a Metal-available \
                             host must yield Some(guard); Ok(None) is reserved for the \
                             Metal-unavailable path. A single-active-guard design \
                             silently returns None on concurrent installs which loses \
                             in-scope protection — do not reintroduce that contract."
                        );
                    }
                    Err(e) => panic!("install must not error on a healthy host: {e:?}"),
                };
                // Release right away; the next iteration installs a fresh guard.
                drop(guard);
            }
        }));
    }
    for worker in workers {
        worker.join().expect("worker thread must not panic");
    }

    let observed = set_wired_limit(snapshot_before).expect("read-back via set");
    assert_eq!(
        observed, snapshot_before,
        "after concurrent install/drop stress, the wired-memory limit \
         must equal the pre-stress snapshot (snapshot_before={snapshot_before}, \
         observed={observed}). A mismatch (typically observed = recommended) \
         indicates an install/Drop race — the refcounted shared state is \
         supposed to make it impossible for two guards to simultaneously \
         hold inconsistent captures, while still letting each guard's full \
         scope enjoy the recommended-limit protection."
    );
}
/// Installs two guards on one thread and checks the limit at three points: with both alive,
/// after the first is dropped (must still be the recommended value), and after both are
/// dropped (must be back at the test baseline).
///
/// Skipped when the host reports no recommended working-set size.
#[test]
fn sequential_install_then_install_then_drop_first_still_protects_second_guard() {
    /// Puts the stored limit back when it goes out of scope, including on unwind.
    struct RestoreOnDrop(u64);
    impl Drop for RestoreOnDrop {
        fn drop(&mut self) {
            // The result is ignored so that Drop itself can never panic.
            let _ = set_wired_limit(self.0);
        }
    }
    const ONE_GIB: u64 = 1024 * 1024 * 1024;

    let _serialize = lock_wired_limit();
    let recommended = match recommended_working_set_bytes() {
        Ok(Some(bytes)) => bytes,
        _ => {
            eprintln!("skipping: recommended_working_set_bytes unavailable on this host");
            return;
        }
    };

    // Pick a baseline that differs from `recommended`, so a restore is observable.
    let test_baseline: u64 = if recommended <= ONE_GIB {
        recommended.saturating_sub(1024)
    } else {
        recommended / 2
    };
    let _restore = RestoreOnDrop(set_wired_limit(test_baseline).expect("set baseline"));

    let first_guard = WiredLimitGuard::install(0, &[])
        .expect("g1 install: FFI error")
        .expect("g1 install: returned None on first install with Metal available");
    let second_guard = WiredLimitGuard::install(0, &[])
        .expect("g2 install: FFI error")
        .expect(
            "g2 install: returned Ok(None) while g1 was active — \
             single-active regression (refcounted semantics must return Ok(Some) via refcount bump)",
        );

    // Both guards alive.
    let observed_before_drop = set_wired_limit(recommended).expect("readback 1");
    assert_eq!(
        observed_before_drop, recommended,
        "before any drop, limit was {observed_before_drop} not recommended {recommended} — \
         refcount bookkeeping broken"
    );

    // Only the second guard alive.
    drop(first_guard);
    let observed_after_g1_drop = set_wired_limit(recommended).expect("readback 2");
    assert_eq!(
        observed_after_g1_drop, recommended,
        "after dropping g1 while g2 alive, limit was {observed_after_g1_drop} not recommended \
         {recommended} — in-scope protection lost"
    );

    // No guard alive.
    drop(second_guard);
    let final_observed = set_wired_limit(test_baseline).expect("readback 3");
    assert_eq!(
        final_observed, test_baseline,
        "after all drops, limit was {final_observed} not test_baseline {test_baseline} — \
         restore broken"
    );
}
/// Cross-thread variant of the "drop the first guard, the second stays protected" check.
///
/// Thread T1 installs a guard and holds it until thread T2 has installed its own guard and
/// reported the limit it saw. T1 then drops its guard, and T2 (still holding its guard)
/// checks that the limit is still `recommended`. After both threads finish, the limit must
/// be back at `test_baseline`.
///
/// Three channels order the steps: `t1_installed` (T1 -> main), `t2_observed` (T2 -> T1)
/// and `t1_dropped` (T1 -> T2). Every wait has a 10-second timeout so a broken handshake
/// fails the test instead of hanging it.
///
/// Returns early (after printing a skip note) when no recommended working-set size is
/// reported; unlike the sibling tests, an `Err` from the query panics via `expect`.
#[test]
fn scoped_threads_t1_drop_first_does_not_revoke_t2_protection() {
// Serialize with the other tests in this file that touch the wired limit.
let _serialize = lock_wired_limit();
let Some(recommended) = recommended_working_set_bytes().expect("rw query") else {
eprintln!("[skip] Metal not available");
return;
};
// Choose a baseline different from `recommended` so that a restore is observable.
let test_baseline: u64 = if recommended > 1024 * 1024 * 1024 {
recommended / 2
} else {
recommended.saturating_sub(1024)
};
// NOTE(review): `set_wired_limit` is used throughout as if it returns the limit that was
// in effect before the call — confirm against its documentation.
let original = set_wired_limit(test_baseline).expect("set baseline");
// Puts the pre-test limit back when the test function exits, including on unwind.
struct RestoreOriginal(u64);
impl Drop for RestoreOriginal {
fn drop(&mut self) {
// The result is ignored so that Drop itself cannot panic.
let _ = set_wired_limit(self.0);
}
}
let _restore = RestoreOriginal(original);
use std::sync::mpsc;
let (t1_installed_tx, t1_installed_rx) = mpsc::channel::<()>();
let (t2_observed_tx, t2_observed_rx) = mpsc::channel::<u64>(); let (t1_dropped_tx, t1_dropped_rx) = mpsc::channel::<()>();
// The scope's value is whatever T2 returns: the limit it read after T1's drop.
let observed_after_t1_drop = std::thread::scope(|s| {
// T1: install, announce, wait for T2's observation, drop, announce the drop.
s.spawn(move || {
let g1 = WiredLimitGuard::install(0, &[])
.expect("T1 install FFI error")
.expect("T1 install returned None on first install with Metal");
t1_installed_tx.send(()).expect("send t1_installed");
// The received value is discarded; only the fact that T2 got this far matters here.
let _ = t2_observed_rx
.recv_timeout(std::time::Duration::from_secs(10))
.expect("T2 never sent observation — T2 likely failed install");
drop(g1);
t1_dropped_tx.send(()).expect("send t1_dropped");
});
// Do not start T2 until T1's guard is in place.
t1_installed_rx
.recv_timeout(std::time::Duration::from_secs(10))
.expect("T1 install never completed within 10s");
// T2: install while T1 is active, check the limit, wait for T1's drop, check again.
let t2 = s.spawn(move || {
// `_g2` stays alive until this closure returns, i.e. across both read-backs below.
let _g2 = WiredLimitGuard::install(0, &[])
.expect("T2 install FFI error")
.expect(
"T2 install returned None while T1 was active — \
single-active regression OR per-thread-scoped guard regression",
);
let limit_with_both = set_wired_limit(recommended).expect("readback both");
assert_eq!(
limit_with_both, recommended,
"while both guards alive, limit was {limit_with_both} not recommended {recommended}"
);
// Unblocks T1 so that it drops its guard.
t2_observed_tx
.send(limit_with_both)
.expect("send t2_observed");
t1_dropped_rx
.recv_timeout(std::time::Duration::from_secs(10))
.expect("T1 drop never completed within 10s");
let limit_after_t1 = set_wired_limit(recommended).expect("readback after T1");
assert_eq!(
limit_after_t1, recommended,
"after T1 drop while T2 alive, limit was {limit_after_t1} not recommended \
{recommended} — cross-thread in-scope protection lost (refcount regression OR \
per-thread-scoped guards)"
);
limit_after_t1
});
// Joining explicitly turns a T2 panic into a panic here and yields T2's return value.
t2.join()
.expect("T2 panicked (assertion failure shown above)")
});
assert_eq!(observed_after_t1_drop, recommended);
// Both guards have been dropped by now; the limit must be the baseline set above.
let final_limit = set_wired_limit(test_baseline).expect("readback final");
assert_eq!(
final_limit, test_baseline,
"after all guards dropped, limit was {final_limit} not test_baseline {test_baseline} — \
refcount or restore broken"
);
}
/// While one guard is held on the test thread, an install on a second thread must return
/// `Ok(Some(_))`. After both guards are gone, the limit must equal the initial snapshot.
///
/// Skipped when the host reports no recommended working-set size.
#[test]
fn concurrent_install_returns_some_guard_not_none_when_already_installed() {
    let _serialized = lock_wired_limit();
    if !matches!(recommended_working_set_bytes(), Ok(Some(_))) {
        eprintln!("skipping: recommended_working_set_bytes unavailable on this host");
        return;
    }

    // Snapshot the current limit by setting 0 and immediately putting the old value back.
    let snapshot_before = set_wired_limit(0).expect("snapshot via round-trip");
    let _ = set_wired_limit(snapshot_before).expect("restore baseline");

    let held_guard = WiredLimitGuard::install(0, &[])
        .expect("t1 install rc")
        .expect("t1 must receive Some(guard)");

    // The second install happens on another thread; only "was it Some?" is reported back,
    // so the second guard (if any) is dropped on that thread.
    let second_install = std::thread::spawn(|| {
        let outcome = WiredLimitGuard::install(0, &[]);
        outcome.map(|maybe_guard| maybe_guard.is_some())
    });
    let t2_yielded_some = second_install
        .join()
        .expect("t2 must not panic")
        .expect("t2 install rc");
    assert!(
        t2_yielded_some,
        "refcounted semantics: a concurrent install (T1 active) must yield \
         Ok(Some(_)) not Ok(None). A single-active-guard design returns \
         None here, silently losing in-scope protection for the caller. This \
         test asserts the refcounted API shape directly — if it ever fails, \
         a single-active regression has crept in."
    );

    drop(held_guard);
    let observed = set_wired_limit(snapshot_before).expect("final read-back");
    assert_eq!(
        observed, snapshot_before,
        "After both T1 + T2 guards drop, limit must restore to snapshot_before"
    );
}
/// With two guards A and B installed, dropping A must leave the limit at `recommended`;
/// only dropping B (the last guard) may restore the pre-test snapshot.
///
/// Skipped when the host reports no recommended working-set size, or when the snapshot
/// already equals `recommended` (the two outcomes would then be indistinguishable).
#[test]
fn refcounted_guard_drop_does_not_restore_until_last_drop() {
    let _serialized = lock_wired_limit();
    let recommended = match recommended_working_set_bytes() {
        Ok(Some(bytes)) => bytes,
        _ => {
            eprintln!("skipping: recommended_working_set_bytes unavailable on this host");
            return;
        }
    };

    // Snapshot the current limit by setting 0 and immediately putting the old value back.
    let snapshot_before = set_wired_limit(0).expect("snapshot via round-trip");
    let _ = set_wired_limit(snapshot_before).expect("restore baseline");
    if snapshot_before == recommended {
        eprintln!(
            "skipping: snapshot_before ({snapshot_before}) == recommended ({recommended}); \
             this test cannot discriminate refcount-discipline in that degenerate state"
        );
        return;
    }

    let guard_a = WiredLimitGuard::install(0, &[])
        .expect("a install rc")
        .expect("a must receive Some(guard)");
    let guard_b = WiredLimitGuard::install(0, &[])
        .expect("b install rc")
        .expect("b must receive Some(guard) under refcounted semantics");

    // First drop: B is still alive, so nothing may be restored yet.
    drop(guard_a);
    let observed_after_a_drop = set_wired_limit(recommended).expect("read-back after a drop");
    assert_eq!(
        observed_after_a_drop, recommended,
        "refcount discipline: dropping A while B is still alive must NOT \
         restore the limit (observed={observed_after_a_drop}, \
         recommended={recommended}). A mismatch here means Drop unconditionally \
         restores instead of refcount-gated restoring — the exact regression \
         the refcounted design exists to prevent."
    );
    let _ = set_wired_limit(recommended).expect("re-set recommended");

    // Last drop: now the snapshot must come back.
    drop(guard_b);
    let observed_after_b_drop =
        set_wired_limit(snapshot_before).expect("read-back after b drop");
    assert_eq!(
        observed_after_b_drop, snapshot_before,
        "last-drop-restores: after B (the last live guard in the epoch) \
         drops, limit must restore to snapshot_before (observed={observed_after_b_drop}, \
         snapshot_before={snapshot_before})."
    );
}
/// On a worker thread: touch the default GPU stream, install a guard, clear the thread's
/// streams, then drop the guard. The worker must not panic, and the limit observed
/// afterwards must equal the guard's captured `old_limit`.
///
/// Skipped when the host reports no recommended working-set size.
#[test]
fn wired_limit_guard_drop_after_stream_cleanup_does_not_panic_and_still_restores() {
    let _serialized = lock_wired_limit();
    if !matches!(recommended_working_set_bytes(), Ok(Some(_))) {
        eprintln!("skipping: recommended_working_set_bytes unavailable on this host");
        return;
    }

    // Snapshot the current limit by setting 0 and immediately putting the old value back.
    let snapshot_before = set_wired_limit(0).expect("snapshot via round-trip");
    let _ = set_wired_limit(snapshot_before).expect("restore snapshot baseline");

    let worker = std::thread::spawn(move || {
        // The result is discarded on purpose; the call itself is what matters here.
        let _ = Stream::default_gpu();
        let installed = WiredLimitGuard::install(0, &[])
            .expect("install rc")
            .expect("guard installed");
        let old_limit = installed.old_limit();
        // Clear the streams *before* the guard is dropped: that ordering is the scenario
        // under test.
        Stream::clear_current_thread_streams().expect("clear_streams shim rc");
        drop(installed);
        old_limit
    });
    let captured = worker.join().expect(
        "worker thread MUST NOT panic — a WiredLimitGuard::drop that called \
         Stream::default_gpu()/synchronize() would panic on a stream-cleared \
         thread, and a panic-on-Drop would leak the process-global \
         wired-memory limit (worst case: double-panic abort).",
    );

    let observed = set_wired_limit(captured).expect("read-back via set");
    assert_eq!(
        observed, captured,
        "Drop must restore captured old_limit (captured={captured}, \
         observed={observed}, snapshot_before={snapshot_before}) — \
         Drop uses Stream::try_synchronize/try_default_gpu to skip the \
         sync step on a cleared thread while still running the limit restore."
    );
}
/// Panics on purpose while a guard is alive, catches the unwind, and then checks that the
/// guard's Drop (run during unwinding) put back the captured `old_limit`.
///
/// Skipped when the host reports no recommended working-set size.
#[test]
fn wired_limit_guard_drop_during_panic_does_not_double_panic() {
    let _serialized = lock_wired_limit();
    if !matches!(recommended_working_set_bytes(), Ok(Some(_))) {
        eprintln!("skipping: recommended_working_set_bytes unavailable on this host");
        return;
    }

    // Snapshot the current limit by setting 0 and immediately putting the old value back.
    let snapshot_before = set_wired_limit(0).expect("snapshot via round-trip");
    let _ = set_wired_limit(snapshot_before).expect("restore snapshot baseline");

    // Carries the guard's `old_limit` out of the closure that is about to panic.
    let captured_cell = std::sync::Mutex::<Option<u64>>::new(None);
    let panicking_scope = std::panic::AssertUnwindSafe(|| {
        let live_guard = WiredLimitGuard::install(0, &[])
            .expect("install rc")
            .expect("guard installed");
        let old_limit = live_guard.old_limit();
        // The lock is released at the end of this statement, i.e. before the panic below,
        // so the mutex is not poisoned.
        *captured_cell.lock().unwrap() = Some(old_limit);
        panic!("deliberate panic with live WiredLimitGuard");
    });
    let result = std::panic::catch_unwind(panicking_scope);
    assert!(
        result.is_err(),
        "deliberate panic inside the closure must propagate as Err from catch_unwind"
    );

    let recorded: Option<u64> = *captured_cell.lock().unwrap();
    let captured = recorded.expect(
        "guard install ran before the panic, so old_limit was recorded; if this is \
         None, the install path or panic ordering changed",
    );
    let observed = set_wired_limit(captured).expect("read-back via set");
    assert_eq!(
        observed, captured,
        "Drop-during-panic must restore captured old_limit (captured={captured}, \
         observed={observed}, snapshot_before={snapshot_before})"
    );
}
/// The query must return `Ok(..)`, and any `Some` value it carries must be positive.
#[test]
fn recommended_working_set_bytes_returns_ok() {
    let query = recommended_working_set_bytes();
    assert!(
        query.is_ok(),
        "FFI rc must surface as Ok(...) on a healthy host"
    );
    match query {
        Ok(Some(n)) => assert!(n > 0, "Some-value contract: > 0 bytes"),
        // `None` is acceptable; `Err` was already rejected by the assertion above.
        Ok(None) | Err(_) => {}
    }
}
/// macOS only: the query must return `Ok(Some(bytes))` with a positive byte count.
#[cfg(target_os = "macos")]
#[test]
fn recommended_working_set_bytes_returns_some_on_metal() {
    let bytes = recommended_working_set_bytes()
        .expect("FFI rc on healthy mac")
        .expect(
            "macOS host MUST surface Some(bytes) from the populated mlx_device_info \
             map — None here means an always-None regression has crept in",
        );
    let is_positive = bytes > 0;
    assert!(
        is_positive,
        "Metal max_recommended_working_set_size must be > 0 (got {bytes})"
    );
}
/// `mlxrs::memory::tune` is expected to fail, with an error whose text contains
/// "not yet implemented".
#[test]
fn tune_returns_actionable_unimplemented_error() {
    let stub_error = mlxrs::memory::tune(0, 0, 0, &[]).expect_err("tune is a stub");
    let msg = stub_error.to_string();
    let advertises_stub = msg.contains("not yet implemented");
    assert!(
        advertises_stub,
        "tune error message should advertise its unimplemented status: {msg}"
    );
}