use std::ffi::c_void;
use std::ptr;
use flodl_sys as ffi;
use crate::tensor::{check_err, Result};
/// 128-bit identifier for a CUDA memory pool, split into two `u64` halves
/// exactly as the FFI layer passes them. `(0, 0)` is used as the "no pool"
/// value (see `CudaGraph::capture_begin`).
///
/// Derives `PartialEq`/`Eq`/`Hash` so pool ids can be compared and used as
/// map keys; this is purely additive and does not affect the FFI layout.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct MemPoolId {
    /// High 64 bits of the pool id.
    pub hi: u64,
    /// Low 64 bits of the pool id.
    pub lo: u64,
}
/// Stream-capture mode forwarded to the FFI layer as a raw `i32`
/// (`mode as i32` in `CudaGraph::capture_begin`), so the `#[repr(i32)]`
/// discriminants below must never be renumbered.
///
/// NOTE(review): the variants presumably mirror CUDA's
/// `cudaStreamCaptureMode{Global,ThreadLocal,Relaxed}` — confirm against
/// the C side of `flodl_sys`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(i32)]
pub enum CaptureMode {
    Global = 0,
    ThreadLocal = 1,
    /// Default mode used by `cuda_graph_capture`.
    #[default]
    Relaxed = 2,
}
/// Owning handle to a CUDA graph allocated by the FFI layer.
///
/// The pointer is produced by `flodl_cuda_graph_new` (see `CudaGraph::new`)
/// and released exactly once by `flodl_cuda_graph_delete` in the `Drop`
/// impl. The raw-pointer field makes this type neither `Send` nor `Sync`
/// by default, which is the conservative choice for an FFI handle.
pub struct CudaGraph {
// Opaque handle owned by the C side; only ever passed back through FFI.
ptr: *mut c_void,
}
impl CudaGraph {
pub fn new() -> Result<Self> {
let mut ptr: *mut c_void = ptr::null_mut();
let err = unsafe { ffi::flodl_cuda_graph_new(&mut ptr) };
check_err(err)?;
Ok(CudaGraph { ptr })
}
pub fn capture_begin(&mut self, pool: Option<MemPoolId>, mode: CaptureMode) -> Result<()> {
let (hi, lo) = pool.map_or((0, 0), |p| (p.hi, p.lo));
let err = unsafe {
ffi::flodl_cuda_graph_capture_begin(self.ptr, hi, lo, mode as i32)
};
check_err(err)
}
pub fn capture_end(&mut self) -> Result<()> {
let err = unsafe { ffi::flodl_cuda_graph_capture_end(self.ptr) };
check_err(err)
}
pub fn replay(&self) -> Result<()> {
let err = unsafe { ffi::flodl_cuda_graph_replay(self.ptr) };
check_err(err)
}
pub fn reset(&mut self) -> Result<()> {
let err = unsafe { ffi::flodl_cuda_graph_reset(self.ptr) };
check_err(err)
}
pub fn pool(&self) -> MemPoolId {
let mut hi: u64 = 0;
let mut lo: u64 = 0;
unsafe { ffi::flodl_cuda_graph_pool(self.ptr, &mut hi, &mut lo) };
MemPoolId { hi, lo }
}
}
impl Drop for CudaGraph {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe { ffi::flodl_cuda_graph_delete(self.ptr) };
self.ptr = ptr::null_mut();
}
}
}
pub fn cuda_graph_pool_handle() -> MemPoolId {
let mut hi: u64 = 0;
let mut lo: u64 = 0;
unsafe { ffi::flodl_cuda_graph_pool_handle(&mut hi, &mut lo) };
MemPoolId { hi, lo }
}
/// Runs `f` for `warmup_runs` iterations, synchronizes device 0, then
/// captures a single invocation of `f` into a fresh [`CudaGraph`].
///
/// The warmup passes run `f` for real before capture (presumably to let
/// lazy allocations and one-time setup settle — confirm against the
/// backend's capture requirements). Capture uses `CaptureMode::default()`
/// (`Relaxed`) and the optional `pool` id.
///
/// # Errors
/// Propagates the first error from any warmup run, from graph creation,
/// or from the capture itself.
pub fn cuda_graph_capture<F>(
    warmup_runs: usize,
    pool: Option<MemPoolId>,
    mut f: F,
) -> Result<CudaGraph>
where
    F: FnMut() -> Result<()>,
{
    // Every warmup run must succeed before we attempt a capture.
    (0..warmup_runs).try_for_each(|_| f())?;
    // Ensure no warmup work is still in flight on device 0.
    crate::tensor::cuda_synchronize(0);

    let mut graph = CudaGraph::new()?;
    graph.capture_begin(pool, CaptureMode::default())?;
    f()?;
    graph.capture_end()?;
    Ok(graph)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Tensor, test_device, test_opts};
    use std::sync::Mutex;

    // CUDA graph capture is device-global state: serialize the graph tests
    // so two captures never overlap. `unwrap_or_else(e.into_inner())`
    // recovers the lock even if a previous test panicked while holding it.
    static GRAPH_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn test_copy_basic() {
        let opts = test_opts();
        let src = Tensor::ones(&[3, 4], opts).unwrap();
        let dst = Tensor::zeros(&[3, 4], opts).unwrap();
        dst.copy_(&src, false).unwrap();
        let buf = dst.to_f32_vec().unwrap();
        assert!(buf.iter().all(|&v| v == 1.0), "copy_ should have filled dst with 1.0");
    }

    #[test]
    fn test_cuda_graph_fails_on_cpu() {
        // Only meaningful when the test device is NOT CUDA.
        if test_device().is_cuda() {
            return;
        }
        let result = CudaGraph::new();
        assert!(result.is_err(), "CudaGraph::new() should fail on CPU");
    }

    #[test]
    #[ignore = "CUDA graph capture blocks device-wide RNG; run with: fdl cuda-test-graph"]
    fn test_cuda_graph_capture_replay() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = GRAPH_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let opts = test_opts();
        let a = Tensor::ones(&[4, 4], opts).unwrap();
        let b = Tensor::ones(&[4, 4], opts).unwrap();
        let c = Tensor::zeros(&[4, 4], opts).unwrap();
        let graph = cuda_graph_capture(1, None, || {
            let sum = a.add(&b)?;
            c.copy_(&sum, false)?;
            Ok(())
        })
        .unwrap();
        graph.replay().unwrap();
        crate::tensor::cuda_synchronize(0);
        let buf = c.to_f32_vec().unwrap();
        assert!(
            buf.iter().all(|&v| (v - 2.0).abs() < 1e-5),
            "c should be 2.0 after replay, got {:?}",
            &buf[..4]
        );
    }

    #[test]
    #[ignore = "CUDA graph capture blocks device-wide RNG; run with: fdl cuda-test-graph"]
    fn test_cuda_graph_with_model() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = GRAPH_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        use crate::autograd::Variable;
        use crate::nn::{Linear, Module, mse_loss, Adam, Optimizer};
        let dev = test_device();
        let model = Linear::on_device(4, 2, dev).unwrap();
        let params = model.parameters();
        // FIX: this line previously read `Adam::new(¶ms, 0.01)` — the `¶`
        // is a mis-encoded `&par…` entity; the intended token is `&params`.
        let mut optimizer = Adam::new(&params, 0.01);
        let init_data = params[0].variable.data().to_f32_vec().unwrap();
        let opts = test_opts();
        // Static input/target tensors: graph replay re-runs the captured
        // kernels against these same buffers.
        let static_input = Tensor::randn(&[8, 4], opts).unwrap();
        let static_target = Tensor::randn(&[8, 2], opts).unwrap();
        let graph = cuda_graph_capture(3, None, || {
            let inp = Variable::new(static_input.clone(), true);
            let tgt = Variable::new(static_target.clone(), false);
            optimizer.zero_grad();
            let pred = model.forward(&inp)?;
            let loss = mse_loss(&pred, &tgt)?;
            loss.backward()?;
            optimizer.step()
        })
        .unwrap();
        for _ in 0..5 {
            graph.replay().unwrap();
        }
        crate::tensor::cuda_synchronize(0);
        let final_data = params[0].variable.data().to_f32_vec().unwrap();
        assert_ne!(init_data, final_data, "params should have changed after graph replay");
    }

    #[test]
    #[ignore = "CUDA graph capture blocks device-wide RNG; run with: fdl cuda-test-graph"]
    fn test_cuda_graph_pool_handle() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = GRAPH_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let pool = cuda_graph_pool_handle();
        assert!(pool.hi != 0 || pool.lo != 0, "pool handle should be nonzero");
    }

    #[test]
    #[ignore = "CUDA graph capture blocks device-wide RNG; run with: fdl cuda-test-graph"]
    fn test_cuda_graph_reset_recapture() {
        if !test_device().is_cuda() {
            return;
        }
        let _lock = GRAPH_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let opts = test_opts();
        let a = Tensor::ones(&[4], opts).unwrap();
        let b = Tensor::ones(&[4], opts).unwrap();
        let c = Tensor::zeros(&[4], opts).unwrap();
        let mut graph = cuda_graph_capture(1, None, || {
            let sum = a.add(&b)?;
            c.copy_(&sum, false)?;
            Ok(())
        })
        .unwrap();
        graph.replay().unwrap();
        crate::tensor::cuda_synchronize(0);
        let buf = c.to_f32_vec().unwrap();
        assert!(buf.iter().all(|&v| (v - 2.0).abs() < 1e-5));
        // After reset, the same graph object must accept a brand-new capture.
        graph.reset().unwrap();
        let three = Tensor::full(&[4], 3.0, opts).unwrap();
        graph.capture_begin(None, CaptureMode::default()).unwrap();
        let prod = a.mul(&three).unwrap();
        c.copy_(&prod, false).unwrap();
        graph.capture_end().unwrap();
        graph.replay().unwrap();
        crate::tensor::cuda_synchronize(0);
        let buf = c.to_f32_vec().unwrap();
        assert!(
            buf.iter().all(|&v| (v - 3.0).abs() < 1e-5),
            "after recapture, c should be 3.0, got {:?}",
            &buf[..4]
        );
    }
}