use cuda_async::device_context::{
get_default_device, init_device_contexts, set_default_device, DEFAULT_DEVICE_ID,
};
use cuda_async::device_future::DeviceFuture;
use cuda_async::device_operation::Value;
use cuda_async::error::{device_assert, device_error, DeviceError};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
/// `device_error` must build a `DeviceError::Context` carrying exactly the given id and message.
#[test]
fn device_error_returns_context_variant() {
    let produced = device_error(3, "something went wrong");
    let wanted = DeviceError::Context {
        device_id: 3,
        message: String::from("something went wrong"),
    };
    assert_eq!(produced, wanted);
}
/// A true predicate passed to `device_assert` must not produce an error.
#[test]
fn device_assert_ok_when_predicate_is_true() {
    let outcome = device_assert(0, true, "should not fire");
    assert!(outcome.is_ok());
}
/// A false predicate must yield a `Context` error tagged with the supplied device id and message.
#[test]
fn device_assert_err_when_predicate_is_false() {
    let failure = device_assert(7, false, "assertion failed").unwrap_err();
    let wanted = DeviceError::Context {
        device_id: 7,
        message: String::from("assertion failed"),
    };
    assert_eq!(failure, wanted);
}
/// The `Display` text of a `Context` error must mention the device id (as `device_id=N`)
/// and the message.
#[test]
fn context_error_display_contains_device_id_and_message() {
    let display = DeviceError::Context {
        device_id: 42,
        message: String::from("bad thing"),
    }
    .to_string();

    let mentions_id = display.contains("device_id=42");
    assert!(mentions_id, "expected device_id in display, got: {display}");

    let mentions_message = display.contains("bad thing");
    assert!(mentions_message, "expected message in display, got: {display}");
}
/// An `Internal` error must expose its payload through `Display`.
#[test]
fn internal_error_display() {
    let display = DeviceError::Internal(String::from("oops")).to_string();
    if !display.contains("oops") {
        panic!("expected message in display, got: {display}");
    }
}
/// A `Launch` error must expose its payload through `Display`.
#[test]
fn launch_error_display() {
    let display = DeviceError::Launch(String::from("kernel failed")).to_string();
    if !display.contains("kernel failed") {
        panic!("expected message in display, got: {display}");
    }
}
/// A `Scheduling` error must expose its payload through `Display`.
#[test]
fn scheduling_error_display() {
    let display = DeviceError::Scheduling(String::from("no streams")).to_string();
    if !display.contains("no streams") {
        panic!("expected message in display, got: {display}");
    }
}
/// Converting an `anyhow::Error` into `DeviceError` must produce the `Anyhow` variant, and
/// that variant must keep the original error text.
#[test]
fn anyhow_error_converts_to_device_error() {
    let converted: DeviceError = anyhow::anyhow!("something from anyhow").into();
    if let DeviceError::Anyhow(msg) = &converted {
        assert!(
            msg.contains("something from anyhow"),
            "expected anyhow message, got: {msg}"
        );
    } else {
        panic!("expected Anyhow variant, got: {converted:?}");
    }
}
/// Runs `f` on a newly spawned thread and blocks until it finishes.
///
/// If `f` panics, the join fails and the caller panics with "test thread panicked". The
/// tests use this presumably so that per-thread device state (e.g. the default device)
/// starts from its initial value. TODO: confirm that state is thread-local.
fn on_fresh_thread<F>(f: F)
where
    F: FnOnce() + Send + 'static,
{
    let worker = std::thread::spawn(f);
    worker.join().expect("test thread panicked");
}
/// On a fresh thread, a second `init_device_contexts(0, 1)` call must be rejected with a
/// `Context` error for device 0 that mentions "Context already initialized".
#[test]
fn double_init_returns_context_already_initialized() {
    on_fresh_thread(|| {
        let first = init_device_contexts(0, 1);
        first.expect("first init should succeed");

        let rejection = init_device_contexts(0, 1).unwrap_err();
        if let DeviceError::Context { device_id, message } = &rejection {
            assert_eq!(*device_id, 0);
            assert!(
                message.contains("Context already initialized"),
                "unexpected message: {message}"
            );
        } else {
            panic!("expected Context variant, got: {rejection:?}");
        }
    });
}
/// On a freshly spawned thread, `get_default_device` must report `DEFAULT_DEVICE_ID`, and
/// that constant must be 0.
#[test]
fn default_device_id_is_zero() {
    on_fresh_thread(|| {
        let current = get_default_device();
        assert_eq!(current, DEFAULT_DEVICE_ID);
        assert_eq!(DEFAULT_DEVICE_ID, 0);
    });
}
/// After `set_default_device(5)`, `get_default_device` must report 5 on the same thread.
#[test]
fn set_default_device_changes_value() {
    let scenario = || {
        set_default_device(5);
        let current = get_default_device();
        assert_eq!(current, 5);
    };
    on_fresh_thread(scenario);
}
/// Creating a CUDA context for the out-of-range ordinal 9999 must fail. The failure must
/// also be representable as `DeviceError::Driver`.
#[test]
fn new_device_context_with_invalid_device_returns_driver_error() {
    match cuda_core::CudaContext::new(9999) {
        Err(driver_err) => {
            // Wrapping is the meaningful check: it only compiles if the constructor's error
            // type is the payload carried by `DeviceError::Driver`. No runtime variant check
            // is needed, because the variant is chosen right here.
            let _wrapped = DeviceError::Driver(driver_err);
        }
        Ok(_) => panic!("expected error for invalid device 9999, but got Ok"),
    }
}
/// `DeviceError` must implement `Clone`, and a clone must compare equal to its source.
#[test]
fn device_error_is_cloneable_and_eq() {
    let source = DeviceError::Context {
        device_id: 1,
        message: String::from("test"),
    };
    let duplicate = source.clone();
    assert_eq!(source, duplicate);
}
/// Two errors with identical payloads must still compare unequal when their variants differ.
#[test]
fn different_variants_are_not_equal() {
    let payload = "x";
    assert_ne!(
        DeviceError::Internal(payload.to_string()),
        DeviceError::Launch(payload.to_string())
    );
}
/// Returns a [`Waker`] that ignores every wake-up request.
///
/// This lets the tests poll futures by hand without an executor. Cloning yields another no-op
/// waker; wake, wake-by-ref and drop do nothing. The data pointer is null and is never read.
fn noop_waker() -> Waker {
    static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, ignore, ignore, ignore);

    fn clone_raw(data: *const ()) -> RawWaker {
        RawWaker::new(data, &NOOP_VTABLE)
    }

    fn ignore(_data: *const ()) {}

    let raw = RawWaker::new(std::ptr::null(), &NOOP_VTABLE);
    // SAFETY: no vtable function dereferences the data pointer, so null is fine. The
    // functions trivially satisfy the `RawWaker` contract: clone returns an equivalent
    // waker, and wake / wake_by_ref / drop are no-ops.
    unsafe { Waker::from_raw(raw) }
}
/// A future built with `DeviceFuture::failed` must resolve to its stored `Internal` error on
/// the very first poll.
#[test]
fn failed_future_returns_err_on_first_poll() {
    let mut future: DeviceFuture<(), Value<()>> =
        DeviceFuture::failed(DeviceError::Internal(String::from("test failure")));
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    match Pin::new(&mut future).poll(&mut cx) {
        Poll::Pending => panic!("expected Ready, got Pending"),
        Poll::Ready(Ok(_)) => panic!("expected Err, got Ok"),
        Poll::Ready(Err(DeviceError::Internal(msg))) => assert_eq!(msg, "test failure"),
        Poll::Ready(Err(other)) => panic!("expected Internal variant, got: {other:?}"),
    }
}
/// Polling a failed future must never yield `Pending`: the first poll must already be
/// `Ready(Err(Internal(..)))`.
#[test]
fn failed_future_is_immediately_ready() {
    let mut future: DeviceFuture<(), Value<()>> =
        DeviceFuture::failed(DeviceError::Internal(String::from("done")));
    let waker = noop_waker();
    let polled = Pin::new(&mut future).poll(&mut Context::from_waker(&waker));

    let ready_with_internal = matches!(polled, Poll::Ready(Err(DeviceError::Internal(_))));
    assert!(
        ready_with_internal,
        "expected Poll::Ready(Err(Internal(...))), got: {polled:?}"
    );
}
/// Polling a failed future again after it has completed must panic with
/// "Poll called after completion".
///
/// The first poll is asserted to complete with the stored error. That pins the expected panic
/// to the second poll. Otherwise a `Pending` or an early panic from the first poll could let
/// the test pass for the wrong reason.
#[test]
#[should_panic(expected = "Poll called after completion")]
fn failed_future_panics_on_second_poll() {
    let error = DeviceError::Internal("once".to_string());
    let mut future: DeviceFuture<(), Value<()>> = DeviceFuture::failed(error);
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    let first = Pin::new(&mut future).poll(&mut cx);
    // If this assertion fires, its message lacks the `should_panic` text, so the test fails.
    assert!(
        matches!(first, Poll::Ready(Err(DeviceError::Internal(_)))),
        "first poll should complete with the stored error, got: {first:?}"
    );

    let _ = Pin::new(&mut future).poll(&mut cx);
}
/// A failed future must hand back exactly the error it was built with, including every
/// field of a `Context` error.
#[test]
fn failed_future_preserves_error_variant() {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    let mut future: DeviceFuture<String, Value<String>> =
        DeviceFuture::failed(DeviceError::Context {
            device_id: 42,
            message: String::from("device gone"),
        });

    let polled = Pin::new(&mut future).poll(&mut cx);
    match polled {
        Poll::Ready(Err(DeviceError::Context { device_id, message })) => {
            assert_eq!(device_id, 42);
            assert_eq!(message, "device gone");
        }
        other => panic!("expected Context error, got: {other:?}"),
    }
}