use cuda_async::cuda_graph::CudaGraph;
use cuda_async::device_operation::{value, DeviceOp};
use cuda_async::error::DeviceError;
/// Returns `true` when the device-count query reports at least one CUDA device.
///
/// If the query itself does not yield a count, this is treated as "no GPU", so
/// the tests below skip instead of failing on machines without CUDA.
fn has_gpu() -> bool {
    let device_count = cuda_core::CudaContext::device_count();
    device_count.map_or(false, |count| count > 0)
}
/// Runs `f` on a newly spawned thread and blocks until it finishes.
///
/// If `f` panics, the panic is re-raised on the calling thread via `expect`,
/// so the enclosing `#[test]` fails.
///
/// NOTE(review): presumably used so each test starts from a clean per-thread
/// CUDA/driver state. That rationale is not visible in this file; confirm.
fn on_fresh_thread<F>(f: F)
where
    F: FnOnce() + Send + 'static,
{
    let worker = std::thread::spawn(f);
    worker.join().expect("test thread panicked");
}
/// A scope whose closure records nothing must still produce a graph that can
/// be launched and run to completion.
#[test]
fn scope_empty_closure() {
    if !has_gpu() {
        return;
    }
    on_fresh_thread(|| {
        let ctx = cuda_core::CudaContext::new(0).unwrap();
        let stream = ctx.new_stream().unwrap();
        let graph = CudaGraph::scope(&stream, |_s| Ok(())).unwrap();
        graph.launch().unwrap();
        // If `launch` only enqueues work, execution errors would otherwise surface
        // only at a later synchronizing call and never fail this test. Syncing
        // here makes them fail the test.
        // NOTE(review): assumes the graph is launched on `stream`; confirm.
        stream.synchronize().unwrap();
    });
}
/// Values produced by `record` inside a scope are returned to the closure
/// immediately. The resulting graph is then launched once.
#[test]
fn scope_records_value_ops() {
    if !has_gpu() {
        return;
    }
    on_fresh_thread(|| {
        let ctx = cuda_core::CudaContext::new(0).unwrap();
        let stream = ctx.new_stream().unwrap();
        let mut seen: Vec<i32> = Vec::new();
        let graph = CudaGraph::scope(&stream, |rec| {
            let number = rec.record(value(42))?;
            let text = rec.record(value("hello"))?;
            // Store the integer as-is and the string by its length ("hello" -> 5).
            seen.extend([number, text.len() as i32]);
            Ok(())
        })
        .unwrap();
        assert_eq!(seen, [42, 5]);
        graph.launch().unwrap();
    });
}
/// An error returned from the scope closure is passed back by
/// `CudaGraph::scope` as the same `DeviceError::Internal` variant and message.
#[test]
fn scope_error_propagation() {
    if !has_gpu() {
        return;
    }
    on_fresh_thread(|| {
        let ctx = cuda_core::CudaContext::new(0).unwrap();
        let stream = ctx.new_stream().unwrap();
        let result = CudaGraph::scope(&stream, |_s| {
            Err(DeviceError::Internal("test error".into()))
        });
        // This match is exhaustive and already rejects `Ok`. A separate
        // `is_err()` assert would fire first and hide the specific message below.
        match result {
            Err(DeviceError::Internal(msg)) => {
                assert!(
                    msg.contains("test error"),
                    "Expected test error, got: {msg}"
                );
            }
            Err(e) => panic!("Expected Internal error, got: {e}"),
            Ok(_) => panic!("Expected error, got Ok"),
        }
    });
}
/// A panic inside the scope closure must not leave `stream` unusable, and it
/// must not produce a graph from a closure that never finished.
#[test]
fn scope_panic_safety() {
    if !has_gpu() {
        return;
    }
    let result = std::thread::spawn(|| {
        let ctx = cuda_core::CudaContext::new(0).unwrap();
        let stream = ctx.new_stream().unwrap();
        let caught = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            CudaGraph::scope(&stream, |_s| {
                panic!("intentional panic in scope");
            })
        }));
        // The outcome used to be discarded. Either the panic propagates (`Err`)
        // or `scope` turns it into a `DeviceError` (`Ok(Err)`). Returning a graph
        // (`Ok(Ok)`) would mean the panic was silently swallowed.
        assert!(
            !matches!(caught, Ok(Ok(_))),
            "scope must not return a graph when its closure panics"
        );
        // NOTE(review): CUDA does not allow synchronizing a stream that is still
        // capturing. Assuming `synchronize` maps to cudaStreamSynchronize, this
        // checks that capture was ended during unwinding.
        stream.synchronize().unwrap();
    })
    .join();
    assert!(
        result.is_ok(),
        "Thread should not panic after scope cleanup"
    );
}
/// A captured graph can be launched repeatedly, and every launch completes.
#[test]
fn scope_multiple_launches() {
    if !has_gpu() {
        return;
    }
    const LAUNCHES: usize = 10;
    on_fresh_thread(|| {
        let ctx = cuda_core::CudaContext::new(0).unwrap();
        let stream = ctx.new_stream().unwrap();
        let graph = CudaGraph::scope(&stream, |_s| Ok(())).unwrap();
        for _ in 0..LAUNCHES {
            graph.launch().unwrap();
        }
        // If launches only enqueue work, execution errors would otherwise surface
        // only at a later synchronizing call. Wait for all of them here so they
        // fail this test.
        // NOTE(review): assumes the graph is launched on `stream`; confirm.
        stream.synchronize().unwrap();
    });
}
/// Synchronous execution (`sync_on` or `sync`) from inside a capture scope is
/// rejected. This holds whether it targets the capturing stream, another
/// stream, or the default `sync` path.
#[test]
fn scope_nested_execution_rejected() {
    if !has_gpu() {
        return;
    }

    /// Fails the test with `what` as the message unless `outcome` is an error.
    fn expect_rejected<T>(outcome: Result<T, DeviceError>, what: &str) {
        assert!(outcome.is_err(), "{what}");
    }

    on_fresh_thread(|| {
        let ctx = cuda_core::CudaContext::new(0).unwrap();
        let capture_stream = ctx.new_stream().unwrap();
        let second_stream = ctx.new_stream().unwrap();

        // Blocking on the very stream being captured.
        expect_rejected(
            CudaGraph::scope(&capture_stream, |_scope| {
                let _ = value(42).sync_on(&capture_stream)?;
                Ok(())
            }),
            "nested sync_on should fail",
        );

        // Blocking on a different stream while a capture is active.
        expect_rejected(
            CudaGraph::scope(&capture_stream, |_scope| {
                let _ = value(42).sync_on(&second_stream)?;
                Ok(())
            }),
            "nested sync_on (other stream) should fail",
        );

        // Plain `sync()`, where no stream is named at the call site.
        expect_rejected(
            CudaGraph::scope(&capture_stream, |_scope| {
                value(42).sync()?;
                Ok(())
            }),
            "nested sync should fail",
        );
    });
}