use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use super::{
test_seam::{ClosureCustomNewFn, ClosureNewFn, ScopedClosureCtor, ScopedCustomCtor},
*,
};
/// Acquires the lock that the tests in this file take before doing anything else, so that only
/// one of them runs at a time.
///
/// The returned guard releases the lock when dropped. A poisoned lock (left behind by a test
/// that panicked while holding it) is recovered rather than propagated, so one failing test does
/// not cascade into every later one.
fn serial_guard() -> std::sync::MutexGuard<'static, ()> {
    static SERIAL: std::sync::Mutex<()> = std::sync::Mutex::new(());

    match SERIAL.lock() {
        Ok(held) => held,
        // The mutex protects `()`, so there is no state a panic could have left half-updated.
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Probe value that records its own destruction.
///
/// The tests in this file move one into a closure and then read the shared tally to learn how
/// many times that closure's captured state was dropped.
struct DropSentinel {
    /// Shared tally; incremented by one every time a sentinel is dropped.
    counter: Arc<AtomicUsize>,
}

impl Drop for DropSentinel {
    /// Adds one to the shared tally.
    fn drop(&mut self) {
        let tally: &AtomicUsize = &self.counter;
        tally.fetch_add(1, Ordering::SeqCst);
    }
}
/// Stand-in closure constructor that runs the supplied destructor on the payload and then
/// reports failure by returning a handle whose `ctx` is null.
///
/// `_fun` is ignored. If `dtor` is `None`, the payload is left untouched.
///
/// # Safety
///
/// When `dtor` is `Some`, it must be sound to call it with `payload`.
unsafe extern "C" fn stub_closure_new_invokes_dtor_then_returns_null(
    _fun: Option<
        unsafe extern "C" fn(
            *mut mlxrs_sys::mlx_vector_array,
            mlxrs_sys::mlx_vector_array,
            *mut c_void,
        ) -> c_int,
    >,
    payload: *mut c_void,
    dtor: Option<unsafe extern "C" fn(*mut c_void)>,
) -> mlxrs_sys::mlx_closure {
    if let Some(release) = dtor {
        // SAFETY: `payload` and `dtor` are forwarded exactly as received and the destructor is
        // called at most once here; soundness rests on the caller passing a matching pair
        // (presumably the boxed closure and its drop shim; confirm against the parent module).
        unsafe { release(payload) };
    }

    let ctx = ptr::null_mut();
    mlxrs_sys::mlx_closure { ctx }
}
/// Stand-in closure constructor that reports failure by returning a handle whose `ctx` is null,
/// without touching the payload.
///
/// Every argument is ignored; in particular the destructor is never called.
///
/// # Safety
///
/// The body performs no unsafe operation; the function is `unsafe` only because of its
/// signature.
unsafe extern "C" fn stub_closure_new_returns_null_no_dtor(
    _fun: Option<
        unsafe extern "C" fn(
            *mut mlxrs_sys::mlx_vector_array,
            mlxrs_sys::mlx_vector_array,
            *mut c_void,
        ) -> c_int,
    >,
    _payload: *mut c_void,
    _dtor: Option<unsafe extern "C" fn(*mut c_void)>,
) -> mlxrs_sys::mlx_closure {
    let ctx = ptr::null_mut();
    mlxrs_sys::mlx_closure { ctx }
}
/// Regression test for the path where the FFI constructor runs the payload destructor and then
/// hands back a null `ctx`: `Closure::new` must return `Err`, and the boxed closure must have
/// been dropped exactly once.
#[test]
fn closure_new_returns_err_without_double_free_when_ffi_returns_null_after_invoking_destructor() {
    // Declared first so it is released last, after everything else in this test is dropped.
    let _serialized = serial_guard();

    let drops = Arc::new(AtomicUsize::new(0));
    let probe = DropSentinel {
        counter: Arc::clone(&drops),
    };

    // The binding keeps the stub installed until the end of the test.
    let _stub_installed =
        ScopedClosureCtor::install(stub_closure_new_invokes_dtor_then_returns_null as ClosureNewFn);

    let outcome = Closure::new(move |_inputs: &[Array]| {
        // Referencing `probe` inside a `move` closure makes it part of the closure's state.
        let _captured = &probe;
        Ok(Vec::<Array>::new())
    });

    assert!(
        outcome.is_err(),
        "Closure::new must surface Err when mlx-c returns NULL ctx"
    );

    let observed = drops.load(Ordering::SeqCst);
    assert_eq!(
        observed, 1,
        "REGRESSION: boxed closure was dropped {observed} times; expected \
         EXACTLY 1 (a `Box::from_raw` reclaim on the NULL-ctx branch produces \
         2 = double-free / UAF; a missing-dtor regression produces 0)."
    );
}
/// Checks the path where the FFI constructor returns a null `ctx` without running the payload
/// destructor: `Closure::new` must return `Err` and must not drop the boxed closure itself, so
/// the drop count stays at zero (the payload is deliberately leaked).
#[test]
fn closure_new_returns_err_without_reclaim_when_ffi_returns_null_without_invoking_destructor() {
    // Declared first so it is released last, after everything else in this test is dropped.
    let _serialized = serial_guard();

    let drops = Arc::new(AtomicUsize::new(0));
    let probe = DropSentinel {
        counter: Arc::clone(&drops),
    };

    // The binding keeps the stub installed until the end of the test.
    let _stub_installed =
        ScopedClosureCtor::install(stub_closure_new_returns_null_no_dtor as ClosureNewFn);

    let outcome = Closure::new(move |_inputs: &[Array]| {
        // Referencing `probe` inside a `move` closure makes it part of the closure's state.
        let _captured = &probe;
        Ok(Vec::<Array>::new())
    });

    assert!(
        outcome.is_err(),
        "Closure::new must surface Err when mlx-c returns NULL ctx (no-dtor path)"
    );

    let observed = drops.load(Ordering::SeqCst);
    assert_eq!(
        observed, 0,
        "Leak-over-UAF contract violated: boxed closure was dropped {observed} \
         times; expected EXACTLY 0 (Rust reclaim here would be UAF on any \
         subsequent mlx-c-internal payload drop)."
    );
}
/// Stand-in custom-closure constructor that runs the supplied destructor on the payload and
/// then reports failure by returning a handle whose `ctx` is null.
///
/// `_fun` is ignored. If `dtor` is `None`, the payload is left untouched.
///
/// # Safety
///
/// When `dtor` is `Some`, it must be sound to call it with `payload`.
unsafe extern "C" fn stub_closure_custom_new_invokes_dtor_then_returns_null(
    _fun: Option<
        unsafe extern "C" fn(
            *mut mlxrs_sys::mlx_vector_array,
            mlxrs_sys::mlx_vector_array,
            mlxrs_sys::mlx_vector_array,
            mlxrs_sys::mlx_vector_array,
            *mut c_void,
        ) -> c_int,
    >,
    payload: *mut c_void,
    dtor: Option<unsafe extern "C" fn(*mut c_void)>,
) -> mlxrs_sys::mlx_closure_custom {
    if let Some(release) = dtor {
        // SAFETY: `payload` and `dtor` are forwarded exactly as received and the destructor is
        // called at most once here; soundness rests on the caller passing a matching pair
        // (presumably the boxed closure and its drop shim; confirm against the parent module).
        unsafe { release(payload) };
    }

    let ctx = ptr::null_mut();
    mlxrs_sys::mlx_closure_custom { ctx }
}
/// Stand-in custom-closure constructor that reports failure by returning a handle whose `ctx`
/// is null, without touching the payload.
///
/// Every argument is ignored; in particular the destructor is never called.
///
/// # Safety
///
/// The body performs no unsafe operation; the function is `unsafe` only because of its
/// signature.
unsafe extern "C" fn stub_closure_custom_new_returns_null_no_dtor(
    _fun: Option<
        unsafe extern "C" fn(
            *mut mlxrs_sys::mlx_vector_array,
            mlxrs_sys::mlx_vector_array,
            mlxrs_sys::mlx_vector_array,
            mlxrs_sys::mlx_vector_array,
            *mut c_void,
        ) -> c_int,
    >,
    _payload: *mut c_void,
    _dtor: Option<unsafe extern "C" fn(*mut c_void)>,
) -> mlxrs_sys::mlx_closure_custom {
    let ctx = ptr::null_mut();
    mlxrs_sys::mlx_closure_custom { ctx }
}
/// Regression test for the custom-closure path where the FFI constructor runs the payload
/// destructor and then hands back a null `ctx`: `closure_custom_new` must return `Err`, and the
/// boxed closure must have been dropped exactly once.
#[test]
fn closure_custom_new_returns_err_without_double_free_when_ffi_returns_null_after_invoking_destructor()
{
    // Declared first so it is released last, after everything else in this test is dropped.
    let _serialized = serial_guard();

    let drops = Arc::new(AtomicUsize::new(0));
    let probe = DropSentinel {
        counter: Arc::clone(&drops),
    };

    // The binding keeps the stub installed until the end of the test.
    let _stub_installed = ScopedCustomCtor::install(
        stub_closure_custom_new_invokes_dtor_then_returns_null as ClosureCustomNewFn,
    );

    let outcome = closure_custom_new(move |_first: &[Array], _second: &[Array], _third: &[Array]| {
        // Referencing `probe` inside a `move` closure makes it part of the closure's state.
        let _captured = &probe;
        Ok(Vec::new())
    });

    assert!(
        outcome.is_err(),
        "closure_custom_new must surface Err when mlx-c returns NULL ctx"
    );

    let observed = drops.load(Ordering::SeqCst);
    assert_eq!(
        observed, 1,
        "REGRESSION (custom-VJP): boxed closure was dropped {observed} times; \
         expected EXACTLY 1 (a `Box::from_raw` reclaim on the NULL-ctx branch \
         produces 2 = double-free / UAF)."
    );
}
/// Checks the custom-closure path where the FFI constructor returns a null `ctx` without
/// running the payload destructor: `closure_custom_new` must return `Err` and must not drop the
/// boxed closure itself, so the drop count stays at zero.
#[test]
fn closure_custom_new_returns_err_without_reclaim_when_ffi_returns_null_without_invoking_destructor()
{
    // Declared first so it is released last, after everything else in this test is dropped.
    let _serialized = serial_guard();

    let drops = Arc::new(AtomicUsize::new(0));
    let probe = DropSentinel {
        counter: Arc::clone(&drops),
    };

    // The binding keeps the stub installed until the end of the test.
    let _stub_installed =
        ScopedCustomCtor::install(stub_closure_custom_new_returns_null_no_dtor as ClosureCustomNewFn);

    let outcome = closure_custom_new(move |_first: &[Array], _second: &[Array], _third: &[Array]| {
        // Referencing `probe` inside a `move` closure makes it part of the closure's state.
        let _captured = &probe;
        Ok(Vec::new())
    });

    assert!(
        outcome.is_err(),
        "closure_custom_new must surface Err when mlx-c returns NULL ctx (no-dtor path)"
    );

    let observed = drops.load(Ordering::SeqCst);
    assert_eq!(
        observed, 0,
        "Leak-over-UAF contract violated (custom-VJP): boxed closure was dropped \
         {observed} times; expected EXACTLY 0."
    );
}