entrenar/finetune/backward_graph.rs
1//! PMAT-464/477: CUDA graph capture for NF4 backward pass.
2//!
3//! With fused LoRA gradient clipping (PMAT-477, zero D2H sync), the entire
4//! backward loop is capturable: backward + fused_clip + optimizer are all
5//! async GPU kernel launches with no host-device synchronization.
6//!
7//! This module implements capture/replay of the 28-layer backward loop,
8//! eliminating per-step kernel launch overhead.
9//!
10//! # Contract: cuda-graph-backward-v1.yaml
11//!
12//! - F-GRAPH-BWD-001: Loss trajectory matches ungraphed within 0.1
13//! - F-GRAPH-BWD-002: Graph capture succeeds (no CUDA_ERROR)
14//! - F-GRAPH-BWD-003: Throughput >= 1.10x ungraphed at batch=4
15
16#[cfg(feature = "cuda")]
17use trueno_gpu::driver::{CaptureMode, CudaGraphExec, CudaStream};
18
/// State kept between training steps so a captured backward graph can be replayed.
///
/// Produced by a successful backward-graph capture and consumed on replay.
#[cfg(feature = "cuda")]
pub(crate) struct BackwardGraphState {
    /// Instantiated CUDA graph executable that is launched to replay the backward pass.
    pub exec: CudaGraphExec,
    /// Sequence length in effect when the graph was captured; the cached graph
    /// must be invalidated when the sequence length changes.
    pub cached_seq_len: usize,
}
27
/// Report whether backward graph capture has been requested through the environment.
///
/// Capture is enabled only when the `CUDA_GRAPH` environment variable is set to
/// exactly `"1"`. The variable is read on the first call and the answer is cached
/// in a process-wide `OnceLock`, so later changes to the environment are ignored.
#[cfg(feature = "cuda")]
pub(crate) fn use_backward_graph() -> bool {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    // Unset, non-Unicode, or any value other than "1" all mean "disabled".
    *ENABLED.get_or_init(|| matches!(std::env::var("CUDA_GRAPH").as_deref(), Ok("1")))
}
34
/// Attempt to record the backward loop into a CUDA graph.
///
/// Intended to run on the first backward pass at a given `seq_len`: capture is
/// started on `stream`, `backward_fn` is invoked so that the work it enqueues
/// (backward + fused_clip + optimizer per layer) is recorded, and the recorded
/// graph is then instantiated into an executable.
///
/// # Parameters
///
/// - `stream`: stream on which capture is begun and ended.
/// - `seq_len`: sequence length stored in the returned state for later invalidation.
/// - `backward_fn`: closure that runs the backward pass; `None` signals failure.
///
/// # Returns
///
/// `Some(BackwardGraphState)` when capture and instantiation both succeed.
/// `None` when beginning capture fails, `backward_fn` fails, ending capture
/// fails, or instantiation fails; each failure is reported on stderr.
#[cfg(feature = "cuda")]
pub(crate) fn try_capture_backward<F>(
    stream: &CudaStream,
    seq_len: usize,
    backward_fn: F,
) -> Option<BackwardGraphState>
where
    F: FnOnce() -> Option<()>,
{
    // The cuBLAS workspace must already have been pre-allocated by this point (PMAT-063).
    if let Err(e) = stream.begin_capture(CaptureMode::ThreadLocal) {
        eprintln!("[CUDA] Backward graph capture begin failed: {e}");
        return None;
    }

    if backward_fn().is_none() {
        // Backward failed mid-capture: still end the capture (presumably so the
        // stream is not left in capture mode) and deliberately ignore its outcome.
        let _ = stream.end_capture();
        eprintln!("[CUDA] Backward graph capture aborted: backward failed");
        return None;
    }

    let graph = match stream.end_capture() {
        Ok(captured) => captured,
        Err(e) => {
            eprintln!("[CUDA] Backward graph end_capture failed: {e}");
            return None;
        }
    };

    let exec = match graph.instantiate() {
        Ok(instantiated) => instantiated,
        Err(e) => {
            eprintln!("[CUDA] Backward graph instantiate failed: {e}");
            return None;
        }
    };

    eprintln!("[CUDA] Backward graph captured: seq_len={seq_len}");
    Some(BackwardGraphState {
        exec,
        cached_seq_len: seq_len,
    })
}
84
/// Launch a previously captured backward graph on `stream`.
///
/// The caller must use the same seq_len the graph was captured at, and every
/// buffer pointer recorded in the graph must still be valid (pre-allocated
/// training state).
///
/// # Returns
///
/// `Some(())` when the launch call succeeds; `None` after logging the error to
/// stderr when it fails.
#[cfg(feature = "cuda")]
pub(crate) fn replay_backward(state: &BackwardGraphState, stream: &CudaStream) -> Option<()> {
    match state.exec.launch(stream.raw()) {
        Ok(()) => Some(()),
        Err(e) => {
            eprintln!("[CUDA] Backward graph replay failed: {e}");
            None
        }
    }
}