#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CaptureMode, CudaGraphExec, CudaStream};
#[cfg(feature = "cuda")]
/// An instantiated CUDA graph for the backward pass, paired with the sequence
/// length it was captured for.
///
/// Values of this type are produced by `try_capture_backward` and consumed by
/// `replay_backward`.
pub(crate) struct BackwardGraphState {
/// Executable graph obtained from instantiating the captured stream work.
pub exec: CudaGraphExec,
/// The `seq_len` passed in at capture time.
// NOTE(review): presumably callers compare this against the current sequence
// length to decide whether the graph can be replayed — confirm at call sites.
pub cached_seq_len: usize,
}
/// Reports whether backward-pass graph capture has been requested.
///
/// The answer is `true` only when the `CUDA_GRAPH` environment variable is set
/// to exactly `1`. The variable is read on the first call and the result is
/// cached for the lifetime of the process, so later changes to the environment
/// are not observed.
#[cfg(feature = "cuda")]
pub(crate) fn use_backward_graph() -> bool {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

    let flag = ENABLED.get_or_init(|| {
        // Unset, non-UTF-8, or any value other than "1" all count as disabled.
        matches!(std::env::var("CUDA_GRAPH").as_deref(), Ok("1"))
    });
    *flag
}
/// Attempts to capture one run of the backward pass into an executable graph.
///
/// Steps, in order:
/// 1. start a thread-local capture on `stream`;
/// 2. invoke `backward_fn`;
/// 3. end the capture and instantiate the resulting graph.
///
/// # Parameters
/// - `stream`: the stream on which capture is started and ended.
/// - `seq_len`: stored in the returned state as `cached_seq_len` and echoed in
///   the success log line.
/// - `backward_fn`: called exactly once after capture has begun; returning
///   `None` aborts the capture.
///
/// # Returns
/// `Some(BackwardGraphState)` when every step succeeds. `None` when beginning
/// the capture, running `backward_fn`, ending the capture, or instantiating
/// the graph fails; each failure is reported on stderr.
#[cfg(feature = "cuda")]
pub(crate) fn try_capture_backward<F>(
    stream: &CudaStream,
    seq_len: usize,
    backward_fn: F,
) -> Option<BackwardGraphState>
where
    F: FnOnce() -> Option<()>,
{
    if let Err(e) = stream.begin_capture(CaptureMode::ThreadLocal) {
        eprintln!("[CUDA] Backward graph capture begin failed: {e}");
        return None;
    }

    let Some(()) = backward_fn() else {
        // Capture is already open: close it before bailing out. The outcome of
        // closing is intentionally ignored since we are abandoning the graph.
        let _ = stream.end_capture();
        eprintln!("[CUDA] Backward graph capture aborted: backward failed");
        return None;
    };

    let graph = match stream.end_capture() {
        Ok(captured) => captured,
        Err(e) => {
            eprintln!("[CUDA] Backward graph end_capture failed: {e}");
            return None;
        }
    };

    let exec = match graph.instantiate() {
        Ok(instance) => instance,
        Err(e) => {
            eprintln!("[CUDA] Backward graph instantiate failed: {e}");
            return None;
        }
    };

    eprintln!("[CUDA] Backward graph captured: seq_len={seq_len}");
    Some(BackwardGraphState {
        exec,
        cached_seq_len: seq_len,
    })
}
/// Launches a previously captured backward graph on `stream`.
///
/// # Parameters
/// - `state`: holds the executable graph to launch.
/// - `stream`: its raw handle is passed to the graph launch call.
///
/// # Returns
/// `Some(())` if the launch call reports success; `None` otherwise, after
/// writing the error to stderr.
#[cfg(feature = "cuda")]
pub(crate) fn replay_backward(state: &BackwardGraphState, stream: &CudaStream) -> Option<()> {
    match state.exec.launch(stream.raw()) {
        Ok(()) => Some(()),
        Err(e) => {
            eprintln!("[CUDA] Backward graph replay failed: {e}");
            None
        }
    }
}