#![cfg(feature = "cuda")]
use std::sync::{Arc, Mutex, MutexGuard};
use ferrotorch_gpu::buffer::CudaBuffer;
use ferrotorch_gpu::device::GpuDevice;
use ferrotorch_gpu::graph::{CapturePool, begin_capture_with_pool, end_capture_with_pool};
use ferrotorch_gpu::kernels::gpu_add_into;
use ferrotorch_gpu::transfer::{alloc_zeros_f32, cpu_to_gpu};
/// Acquires the process-wide lock that the stream-capture tests hold for
/// their whole body, so those tests run one at a time.
///
/// If a previous holder panicked, the lock is poisoned; the poison is
/// discarded and the guard returned anyway, so one failing test does not
/// make every later capture test fail on the lock as well.
fn capture_lock() -> MutexGuard<'static, ()> {
    static CAPTURE_MUTEX: Mutex<()> = Mutex::new(());
    match CAPTURE_MUTEX.lock() {
        Ok(guard) => guard,
        Err(poison) => poison.into_inner(),
    }
}
/// Opens CUDA device 0 for a test, panicking with an explanatory message
/// (plus the underlying error) when the device cannot be opened.
fn dev() -> GpuDevice {
    let opened = GpuDevice::new(0);
    opened.expect("CUDA device 0 must be available for these tests")
}
/// Captures an element-wise add into a graph backed by a `CapturePool`,
/// then hands the input and output device buffers over to the pool and
/// replays the graph three times.
///
/// The pool is shared (via `Arc`) between the test and the graph, so buffers
/// recorded after `end_capture_with_pool` are visible through the graph's
/// `pool_buffer_count()` as well.
#[test]
fn captured_graph_holds_pool_buffers_alive() {
    let _guard = capture_lock();
    let device = dev();

    let lhs_host: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
    let rhs_host: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
    let mut sum: CudaBuffer<f32> = alloc_zeros_f32(4, &device).expect("alloc out");
    let lhs = cpu_to_gpu(&lhs_host, &device).expect("upload a");
    let rhs = cpu_to_gpu(&rhs_host, &device).expect("upload b");

    let pool = Arc::new(CapturePool::new());
    let capture_stream = device.context().new_stream().expect("non-blocking stream");
    begin_capture_with_pool(&pool, &capture_stream).expect("begin capture");
    gpu_add_into(&lhs, &rhs, &mut sum, &device).expect("add_into during capture");
    let graph = end_capture_with_pool(&capture_stream, Arc::clone(&pool))
        .expect("end_capture_with_pool");

    // Move the device buffers into the pool; the test keeps no other handle
    // to them from here on.
    pool.record_buffer(lhs);
    pool.record_buffer(rhs);
    pool.record_buffer(sum);

    assert_eq!(pool.buffer_count(), 3);
    assert!(graph.has_pool());
    assert_eq!(graph.pool_buffer_count(), 3);

    // Panic message matches `expect("replay N")`: "replay N: {err:?}".
    for replay in 1..=3 {
        graph
            .launch()
            .unwrap_or_else(|err| panic!("replay {replay}: {err:?}"));
    }
}
/// Dropping a pool-backed graph must release the graph's handle on the pool.
/// The buffers the pool recorded are then freed as soon as the last
/// test-held `Arc<CapturePool>` is dropped.
///
/// An `Arc<u8>` sentinel stands in for a device buffer. Its strong count
/// shows whether the pool still owns its recorded entry.
#[test]
fn dropping_graph_releases_pool_buffers() {
    let _guard = capture_lock();
    let device = dev();

    let sentinel = Arc::new(42u8);
    let pool = Arc::new(CapturePool::new());
    pool.record_buffer(Arc::clone(&sentinel));
    // One reference held by the test, one by the pool.
    assert_eq!(Arc::strong_count(&sentinel), 2);

    let stream = device.context().new_stream().expect("non-blocking stream");
    begin_capture_with_pool(&pool, &stream).expect("begin");
    let graph = end_capture_with_pool(&stream, Arc::clone(&pool)).expect("end");
    assert!(graph.has_pool());
    // Capturing must not add or drop references to the recorded entry.
    assert_eq!(Arc::strong_count(&sentinel), 2);

    drop(graph);
    // Nothing but the test's `pool` handle may keep the pool alive now.
    assert_eq!(
        Arc::strong_count(&pool),
        1,
        "dropping the graph must release its handle on the pool"
    );
    // The pool itself is still alive, so its recorded entry is too.
    assert_eq!(Arc::strong_count(&sentinel), 2);

    drop(pool);
    // Last pool handle gone: the recorded entry has been released.
    assert_eq!(Arc::strong_count(&sentinel), 1);
}
/// A sealed `CapturePool` must make `begin_capture_with_pool` return an
/// error.
///
/// This test takes the shared capture lock like every other capture test.
/// If the seal check ever regressed, `begin_capture_with_pool` could start a
/// capture on `stream`. That capture could then interfere with another test's
/// capture or device work running concurrently.
#[test]
fn pool_seal_blocks_begin_capture_with_pool() {
    let _guard = capture_lock();
    let device = dev();
    let stream = device.context().new_stream().expect("non-blocking stream");
    let pool = Arc::new(CapturePool::new());
    pool.seal();
    let result = begin_capture_with_pool(&pool, &stream);
    assert!(
        result.is_err(),
        "sealed pool must reject begin_capture_with_pool"
    );
}
/// A graph produced by the plain, pool-less capture API
/// (`graph::begin_capture` / `graph::end_capture`) reports no pool and a
/// pool buffer count of zero.
#[test]
fn captured_graph_without_pool_has_zero_buffer_count() {
    let _guard = capture_lock();
    let device = dev();
    let stream = device.context().new_stream().expect("non-blocking stream");

    ferrotorch_gpu::graph::begin_capture(&stream).expect("begin");
    let graph = ferrotorch_gpu::graph::end_capture(&stream).expect("end");

    assert!(!graph.has_pool());
    assert_eq!(graph.pool_buffer_count(), 0);
}
/// `CapturePool::buffer_count` starts at zero and goes up by exactly one on
/// every `record_buffer` call, whatever the length of the recorded vector.
///
/// Needs no GPU, so it does not take the capture lock.
#[test]
fn pool_buffer_count_grows_with_record_buffer() {
    let pool = Arc::new(CapturePool::new());
    assert_eq!(pool.buffer_count(), 0);

    // Check after each call (not just at the end) so the test actually
    // pins the per-call growth its name promises.
    pool.record_buffer(vec![0u8; 16]);
    assert_eq!(pool.buffer_count(), 1);
    pool.record_buffer(vec![0u8; 32]);
    assert_eq!(pool.buffer_count(), 2);
    pool.record_buffer(vec![0u8; 64]);
    assert_eq!(pool.buffer_count(), 3);
}