use mlxrs::{
Device, DeviceKind, Stream,
stream::{get_default_stream, set_default_stream},
};
/// Lock taken by the tests in this file that call `Device::set_default`, so their
/// changes to the default device do not interleave with one another. Users of the
/// lock recover from poisoning via `PoisonError::into_inner`, so one failing test
/// does not make every later lock attempt fail too.
static DEFAULT_DEVICE_TEST_GUARD: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// `Device::cpu()` must yield a CPU device at index 0.
#[test]
fn device_cpu_constructs() {
    let device = Device::cpu().expect("cpu device");
    let reported_kind = device.kind().expect("kind");
    assert_eq!(reported_kind, DeviceKind::Cpu);
    let reported_index = device.index().expect("index");
    assert_eq!(reported_index, 0);
}
/// `Device::gpu()` must yield a GPU device at index 0.
#[test]
fn device_gpu_constructs() {
    let device = Device::gpu().expect("gpu device");
    let reported_kind = device.kind().expect("kind");
    assert_eq!(reported_kind, DeviceKind::Gpu);
    let reported_index = device.index().expect("index");
    assert_eq!(reported_index, 0);
}
/// A device built with an explicit kind and index reports both back unchanged.
#[test]
fn device_with_index_round_trips_kind_and_index() {
    let cpu0 = Device::with_index(DeviceKind::Cpu, 0).expect("cpu(0)");
    let observed_kind = cpu0.kind().unwrap();
    assert_eq!(observed_kind, DeviceKind::Cpu);
    let observed_index = cpu0.index().unwrap();
    assert_eq!(observed_index, 0);
}
/// Every host has a CPU, so the CPU device count must be positive.
#[test]
fn device_kind_count_returns_at_least_one_for_cpu() {
    let n = DeviceKind::Cpu
        .count()
        .expect("cpu count");
    assert!(n >= 1, "expected at least one CPU device, got {n}");
}
/// The GPU device-count query must succeed on every platform. On Apple Silicon
/// (macOS + aarch64) it must also report at least one GPU, as the test name promises.
///
/// NOTE(review): this assumes mlxrs is built with GPU support on Apple Silicon,
/// which the other tests in this file already assume via `Device::gpu().unwrap()`.
#[test]
fn device_kind_count_returns_at_least_one_for_gpu_on_apple_silicon() {
    let n = DeviceKind::Gpu.count().expect("gpu count");
    // Off Apple Silicon a GPU is not guaranteed; there we only require that the
    // query succeeds.
    if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
        assert!(
            n >= 1,
            "expected at least one GPU device on Apple Silicon, got {n}"
        );
    }
}
/// Whatever the current default device is, it must be queryable and report a kind.
#[test]
fn device_current_returns_some_device() {
    let current = Device::current().expect("current device");
    current.kind().expect("current device has a kind");
}
/// Setting the CPU as the default device is observable through `Device::current()`.
///
/// The original default is restored *before* asserting. A failing assertion
/// therefore cannot leave the default device switched to CPU for the other tests
/// in this binary.
#[test]
fn device_set_default_round_trip() {
    let _guard = DEFAULT_DEVICE_TEST_GUARD
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let original = Device::current().expect("current");
    let cpu = Device::cpu().expect("cpu");
    cpu.set_default().expect("set_default cpu");
    let after_kind = Device::current()
        .expect("current after")
        .kind()
        .unwrap();
    // Undo the global change first, then check what we observed.
    original.set_default().expect("restore");
    assert_eq!(after_kind, DeviceKind::Cpu);
}
/// The CPU device must always report itself as available.
#[test]
fn device_is_available_for_cpu() {
    let available = Device::cpu()
        .unwrap()
        .is_available()
        .expect("availability query");
    assert!(available);
}
/// `Device::equal` and `PartialEq` must give the same answer, both for two
/// equal CPU handles and for a CPU/GPU pair.
#[test]
fn device_equal_and_eq_agree() {
    let a = Device::cpu().unwrap();
    let b = Device::cpu().unwrap();
    assert!(a.equal(&b));
    assert_eq!(a, b);
    let g = Device::gpu().unwrap();
    assert_ne!(a, g);
    // The original only checked agreement in the equal case; `equal` must also
    // reject the pair that `!=` rejects.
    assert!(
        !a.equal(&g),
        "equal() must reject a CPU/GPU pair, consistent with !="
    );
}
/// A cloned device handle compares equal to the handle it was cloned from.
#[test]
fn device_try_clone_produces_equal_handle() {
    let source = Device::cpu().unwrap();
    let duplicate = source.try_clone().expect("test: device clone");
    assert_eq!(source, duplicate);
}
/// `Debug` output for a device starts with the `Device(` wrapper.
#[test]
fn device_debug_prints_something() {
    let s = format!("{:?}", Device::cpu().unwrap());
    assert!(s.starts_with("Device("), "unexpected debug format: {s}");
}
/// `Display` is the bare description, and `Debug` wraps it exactly once in `Device(...)`.
#[test]
fn device_display_is_concise_and_consistent_with_debug() {
    let dev = Device::cpu().unwrap();
    let (disp, dbg) = (format!("{dev}"), format!("{dev:?}"));

    assert!(!disp.is_empty(), "Display produced empty string");

    let mentions_cpu = disp.to_lowercase().contains("cpu");
    assert!(
        mentions_cpu,
        "expected cpu device Display to mention cpu: {disp}"
    );

    let expected_debug = format!("Device({disp})");
    assert_eq!(dbg, expected_debug, "Debug must wrap Display once");

    let wrappers_in_display = disp.matches("Device(").count();
    let wrappers_in_debug = dbg.matches("Device(").count();
    assert_eq!(
        wrappers_in_display + 1,
        wrappers_in_debug,
        "Display should have exactly one fewer Device( than Debug"
    );
}
/// `Hash` agrees with `Eq`: equal CPU handles collapse to one set entry, and the
/// GPU handle adds a second entry.
#[test]
fn device_hash_consistent_with_eq_in_a_hashset() {
    use std::collections::HashSet;

    let first_cpu = Device::cpu().unwrap();
    let second_cpu = Device::cpu().unwrap();
    assert_eq!(first_cpu, second_cpu);

    let mut seen: HashSet<Device> = HashSet::new();
    seen.insert(first_cpu);
    let second_was_new = seen.insert(second_cpu);
    assert!(!second_was_new, "equal Device must already be in the set");
    assert_eq!(seen.len(), 1);

    let gpu = Device::gpu().unwrap();
    let gpu_was_new = seen.insert(gpu);
    assert!(gpu_was_new, "a GPU device is distinct from CPU");
    assert_eq!(seen.len(), 2);
}
/// `Device::get_default()` is an alias of `Device::current()` and returns the same device.
///
/// Holds `DEFAULT_DEVICE_TEST_GUARD` so that tests calling `set_default` in
/// parallel cannot change the default between the two reads, which would make
/// this test fail spuriously.
#[test]
fn device_get_default_alias_agrees_with_current() {
    let _guard = DEFAULT_DEVICE_TEST_GUARD
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let a = Device::get_default().expect("get_default");
    let b = Device::current().expect("current");
    assert_eq!(a, b);
}
/// The default GPU stream can be obtained, and its index can be queried.
#[test]
fn stream_default_gpu_constructs() {
    let gpu_stream = Stream::default_gpu().expect("default gpu stream");
    gpu_stream.index().expect("stream index");
}
/// The default CPU stream lives on a CPU device.
#[test]
fn stream_default_cpu_constructs() {
    let cpu_stream = Stream::default_cpu().expect("default cpu stream");
    let owner = cpu_stream.device().expect("device");
    let owner_kind = owner.kind().unwrap();
    assert_eq!(owner_kind, DeviceKind::Cpu);
}
/// A stream created on the CPU device reports that device back.
#[test]
fn stream_new_on_cpu_targets_cpu() {
    let target = Device::cpu().unwrap();
    let stream = Stream::new_on(&target).expect("stream on cpu");
    let owner = stream.device().unwrap();
    assert_eq!(owner.kind().unwrap(), DeviceKind::Cpu);
}
/// A stream created on the GPU device reports that device back.
#[test]
fn stream_new_on_gpu_targets_gpu() {
    let target = Device::gpu().unwrap();
    let stream = Stream::new_on(&target).expect("stream on gpu");
    let owner = stream.device().unwrap();
    assert_eq!(owner.kind().unwrap(), DeviceKind::Gpu);
}
/// Synchronizing a stream that has no queued work succeeds.
#[test]
fn stream_synchronize_succeeds_on_idle_stream() {
    Stream::default_cpu()
        .expect("cpu stream")
        .synchronize()
        .expect("sync on idle stream");
}
/// A cloned stream handle compares equal to the handle it was cloned from.
#[test]
fn stream_try_clone_equals_source() {
    let source = Stream::default_cpu().unwrap();
    let duplicate = source.try_clone().expect("test: stream clone");
    assert_eq!(source, duplicate);
}
/// The default CPU stream reports a non-negative index.
#[test]
fn stream_default_cpu_index_is_non_negative() {
    let stream = Stream::default_cpu().unwrap();
    let i = stream.index().unwrap();
    assert!(i >= 0, "stream index unexpectedly negative: {i}");
}
/// `Debug` output for a stream starts with the `Stream(` wrapper.
#[test]
fn stream_debug_prints_something() {
    let txt = format!("{:?}", Stream::default_cpu().unwrap());
    assert!(txt.starts_with("Stream("), "unexpected debug format: {txt}");
}
/// The default stream looked up for the CPU device lives on a CPU device.
#[test]
fn get_default_stream_for_cpu_device() {
    let cpu = Device::cpu().unwrap();
    let cpu_default = get_default_stream(&cpu).expect("default stream for cpu");
    let owner_kind = cpu_default.device().unwrap().kind().unwrap();
    assert_eq!(owner_kind, DeviceKind::Cpu);
}
/// Installing a fresh CPU stream as the default is visible via `get_default_stream`.
///
/// The original default stream is restored *before* asserting. A failing
/// comparison therefore cannot leave `new_stream` installed as the CPU default
/// for the rest of the test binary.
#[test]
fn set_default_stream_round_trip() {
    let cpu = Device::cpu().unwrap();
    let original = get_default_stream(&cpu).expect("original cpu default");
    let new_stream = Stream::new_on(&cpu).expect("new cpu stream");
    set_default_stream(&new_stream).expect("set default");
    let after = get_default_stream(&cpu).expect("after");
    // Undo the global change first, then check what we observed.
    set_default_stream(&original).expect("restore");
    assert_eq!(after, new_stream);
}
/// A `Device` can be moved into a spawned thread and queried there.
#[test]
fn device_can_move_to_another_thread() {
    let dev = Device::cpu().unwrap();
    let observed = std::thread::spawn(move || dev.kind().unwrap())
        .join()
        .unwrap();
    assert_eq!(observed, DeviceKind::Cpu);
}
/// An `Arc<Device>` can be queried from another thread and from the owning thread.
#[test]
fn device_can_be_shared_across_threads() {
    use std::sync::Arc;

    let shared = Arc::new(Device::cpu().unwrap());
    let worker = std::thread::spawn({
        let shared = Arc::clone(&shared);
        move || shared.kind().unwrap()
    });
    assert_eq!(worker.join().unwrap(), DeviceKind::Cpu);
    assert_eq!(shared.kind().unwrap(), DeviceKind::Cpu);
}
/// Eight threads toggle/read the default device concurrently (even-numbered
/// threads also set it to CPU). Afterwards the main thread must observe a
/// coherent default.
///
/// Worker join results are collected and the original default is restored
/// *before* any assertion. A panicking worker or a failed check therefore can no
/// longer skip the restore and leak a changed default device into other tests.
#[test]
fn concurrent_set_default_and_current_is_race_free() {
    use std::thread;
    let _guard = DEFAULT_DEVICE_TEST_GUARD
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let cpu = Device::cpu().unwrap();
    let gpu_available = Device::gpu().is_ok();
    let original = Device::current().unwrap();
    let handles: Vec<_> = (0..8)
        .map(|i| {
            let cpu = cpu.try_clone().unwrap();
            thread::spawn(move || {
                for _ in 0..50 {
                    if i % 2 == 0 {
                        cpu.set_default().unwrap();
                    }
                    let _ = Device::current().unwrap();
                }
            })
        })
        .collect();
    // Join every worker without unwrapping, so the restore below always runs.
    let worker_panics = handles
        .into_iter()
        .map(thread::JoinHandle::join)
        .filter(Result::is_err)
        .count();
    let kind = Device::current().unwrap().kind().unwrap();
    original.set_default().unwrap();
    assert_eq!(
        worker_panics, 0,
        "a worker thread panicked while toggling the default device"
    );
    assert!(
        kind == DeviceKind::Cpu || (gpu_available && kind == DeviceKind::Gpu),
        "default device left in an incoherent state: {kind:?}",
    );
}
/// Formatting an array created before `clear_current_thread_streams` must panic,
/// and the panic message must name that API.
#[test]
fn post_clear_array_display_also_panics_fast() {
    let worker = std::thread::spawn(|| {
        let a = mlxrs::Array::ones::<f32>(&(2usize, 2)).unwrap();
        Stream::clear_current_thread_streams().unwrap();
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _ = format!("{a}");
        }))
    });
    let outcome = worker
        .join()
        .expect("spawned thread itself should not abort");

    let payload = outcome.expect_err("post-clear Display must panic");
    // Panic payloads are a `String` (formatted panics) or a `&'static str` (literal panics).
    let msg: &str = if let Some(owned) = payload.downcast_ref::<String>() {
        owned.as_str()
    } else if let Some(borrowed) = payload.downcast_ref::<&str>() {
        *borrowed
    } else {
        ""
    };
    assert!(
        msg.contains("clear_current_thread_streams"),
        "panic message should name the culprit API; got: {msg:?}"
    );
}
/// Running a CPU linear-algebra op (SVD) after `clear_current_thread_streams`
/// must panic, and the panic message must name that API.
#[test]
fn post_clear_cpu_linalg_panics_fast() {
    let worker = std::thread::spawn(|| {
        let a = mlxrs::Array::eye::<f32>(2, None, 0).unwrap();
        Stream::clear_current_thread_streams().unwrap();
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _ = mlxrs::ops::linalg_full::svd(&a, false);
        }))
    });
    let outcome = worker
        .join()
        .expect("spawned thread itself should not abort");

    let payload = outcome.expect_err("post-clear CPU linalg must panic");
    // Panic payloads are a `String` (formatted panics) or a `&'static str` (literal panics).
    let msg: &str = if let Some(owned) = payload.downcast_ref::<String>() {
        owned.as_str()
    } else if let Some(borrowed) = payload.downcast_ref::<&str>() {
        *borrowed
    } else {
        ""
    };
    assert!(
        msg.contains("clear_current_thread_streams"),
        "panic message should name the culprit API; got: {msg:?}"
    );
}
/// `random::key` after `clear_current_thread_streams` must panic, and the panic
/// message must name that API.
#[test]
fn post_clear_random_key_panics_fast() {
    let worker = std::thread::spawn(|| {
        Stream::clear_current_thread_streams().unwrap();
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _ = mlxrs::ops::random::key(0);
        }))
    });
    let outcome = worker
        .join()
        .expect("spawned thread itself should not abort");

    let payload = outcome.expect_err("post-clear random::key must panic");
    // Panic payloads are a `String` (formatted panics) or a `&'static str` (literal panics).
    let msg: &str = if let Some(owned) = payload.downcast_ref::<String>() {
        owned.as_str()
    } else if let Some(borrowed) = payload.downcast_ref::<&str>() {
        *borrowed
    } else {
        ""
    };
    assert!(
        msg.contains("clear_current_thread_streams"),
        "panic message should name the culprit API; got: {msg:?}"
    );
}
/// `random::seed` after `clear_current_thread_streams` must panic, and the panic
/// message must name that API.
#[test]
fn post_clear_random_seed_panics_fast() {
    let worker = std::thread::spawn(|| {
        Stream::clear_current_thread_streams().unwrap();
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _ = mlxrs::ops::random::seed(0);
        }))
    });
    let outcome = worker
        .join()
        .expect("spawned thread itself should not abort");

    let payload = outcome.expect_err("post-clear random::seed must panic");
    // Panic payloads are a `String` (formatted panics) or a `&'static str` (literal panics).
    let msg: &str = if let Some(owned) = payload.downcast_ref::<String>() {
        owned.as_str()
    } else if let Some(borrowed) = payload.downcast_ref::<&str>() {
        *borrowed
    } else {
        ""
    };
    assert!(
        msg.contains("clear_current_thread_streams"),
        "panic message should name the culprit API; got: {msg:?}"
    );
}
/// A worker thread does real array work, then calls `clear_current_thread_streams`
/// as its final action. The main thread must still be able to create, evaluate,
/// and read back arrays afterwards.
#[test]
fn clear_current_thread_streams_is_end_of_thread_cleanup() {
    std::thread::spawn(|| {
        let mut a =
            mlxrs::Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
        let mut total = mlxrs::ops::reduction::sum(&a, false).unwrap();
        assert_eq!(total.item::<f32>().unwrap(), 10.0);
        let _ = a.to_vec::<f32>().unwrap();
        // Final statement of the worker: clear this thread's streams.
        Stream::clear_current_thread_streams().unwrap();
    })
    .join()
    .expect("worker thread panicked / aborted");

    let mut b = mlxrs::Array::ones::<f32>(&(4usize, 4)).unwrap();
    b.eval().unwrap();
    assert_eq!(b.to_vec::<f32>().unwrap(), vec![1.0; 16]);
}
/// Clearing streams on a thread that never used any must succeed without panicking.
#[test]
fn clear_current_thread_streams_returns_ok_on_idle_thread() {
    let idle = std::thread::spawn(|| {
        Stream::clear_current_thread_streams().unwrap();
    });
    idle.join().expect("idle clear should not panic");
}
/// Creating a new array on a thread that already called
/// `clear_current_thread_streams` must panic, and the panic message must name
/// that API.
#[test]
fn reusing_a_cleared_thread_panics_fast_with_actionable_message() {
    let worker = std::thread::spawn(|| {
        Stream::clear_current_thread_streams().unwrap();
        std::panic::catch_unwind(|| {
            let _ = mlxrs::Array::ones::<f32>(&(2usize, 2));
        })
    });
    let outcome = worker
        .join()
        .expect("spawned thread itself should not abort");

    let payload = outcome.expect_err("op on a cleared/poisoned thread must panic");
    // Panic payloads are a `String` (formatted panics) or a `&'static str` (literal panics).
    let msg: &str = if let Some(owned) = payload.downcast_ref::<String>() {
        owned.as_str()
    } else if let Some(borrowed) = payload.downcast_ref::<&str>() {
        *borrowed
    } else {
        ""
    };
    assert!(
        msg.contains("clear_current_thread_streams"),
        "panic message should name the culprit API; got: {msg:?}"
    );
}
/// After `clear_current_thread_streams`, every public stream entry point listed
/// below must panic instead of returning, even when called on handles created
/// before the clear.
#[test]
fn post_clear_every_public_stream_entry_point_panics_fast() {
    let outcome = std::thread::spawn(|| {
        /// Runs `op` under `catch_unwind` and reports whether it panicked. The
        /// result is dropped *inside* the unwind boundary, so a panic raised
        /// while dropping it also counts.
        fn panics_when_run<R>(op: impl FnOnce() -> R) -> bool {
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                let _ = op();
            }))
            .is_err()
        }

        let dev = Device::gpu().expect("gpu device");
        let s = Stream::default_gpu().expect("default gpu stream");
        Stream::clear_current_thread_streams().unwrap();

        let results: Vec<(&str, bool)> = vec![
            ("default_gpu", panics_when_run(Stream::default_gpu)),
            ("default_cpu", panics_when_run(Stream::default_cpu)),
            ("new_on", panics_when_run(|| Stream::new_on(&dev))),
            ("try_clone", panics_when_run(|| s.try_clone())),
            ("synchronize", panics_when_run(|| s.synchronize())),
            ("device", panics_when_run(|| s.device())),
            ("index", panics_when_run(|| s.index())),
            ("equal", panics_when_run(|| s.equal(&s))),
            (
                "get_default_stream",
                panics_when_run(|| mlxrs::stream::get_default_stream(&dev)),
            ),
            (
                "set_default_stream",
                panics_when_run(|| mlxrs::stream::set_default_stream(&s)),
            ),
        ];
        results
    })
    .join()
    .expect("spawned thread itself should not abort");

    for (label, panicked) in outcome {
        assert!(
            panicked,
            "post-clear `{label}` must panic-fast on a poisoned thread, but returned"
        );
    }
}
/// Evaluating and reading back an array created before
/// `clear_current_thread_streams` must panic, and the panic message must name
/// that API.
#[test]
fn post_clear_eval_of_existing_array_also_panics_fast() {
    let worker = std::thread::spawn(|| {
        let mut a = mlxrs::Array::ones::<f32>(&(2usize, 2)).unwrap();
        Stream::clear_current_thread_streams().unwrap();
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _ = a.to_vec::<f32>();
        }))
    });
    let outcome = worker
        .join()
        .expect("spawned thread itself should not abort");

    let payload = outcome.expect_err("post-clear eval/to_vec must panic");
    // Panic payloads are a `String` (formatted panics) or a `&'static str` (literal panics).
    let msg: &str = if let Some(owned) = payload.downcast_ref::<String>() {
        owned.as_str()
    } else if let Some(borrowed) = payload.downcast_ref::<&str>() {
        *borrowed
    } else {
        ""
    };
    assert!(
        msg.contains("clear_current_thread_streams"),
        "panic message should name the culprit API; got: {msg:?}"
    );
}