Skip to main content

entrenar/finetune/
backward_graph.rs

1//! PMAT-464/477: CUDA graph capture for NF4 backward pass.
2//!
3//! With fused LoRA gradient clipping (PMAT-477, zero D2H sync), the entire
4//! backward loop is capturable: backward + fused_clip + optimizer are all
5//! async GPU kernel launches with no host-device synchronization.
6//!
7//! This module implements capture/replay of the 28-layer backward loop,
8//! eliminating per-step kernel launch overhead.
9//!
10//! # Contract: cuda-graph-backward-v1.yaml
11//!
12//! - F-GRAPH-BWD-001: Loss trajectory matches ungraphed within 0.1
13//! - F-GRAPH-BWD-002: Graph capture succeeds (no CUDA_ERROR)
14//! - F-GRAPH-BWD-003: Throughput >= 1.10x ungraphed at batch=4
15
16#[cfg(feature = "cuda")]
17use trueno_gpu::driver::{CaptureMode, CudaGraphExec, CudaStream};
18
/// Executable CUDA graph for the backward pass, cached together with the
/// sequence length it was recorded for.
///
/// The graph is tied to the `seq_len` it was captured at; when the sequence
/// length changes, the cached state should be discarded and recaptured.
#[cfg(feature = "cuda")]
pub(crate) struct BackwardGraphState {
    /// Sequence length the graph was captured at (invalidate on change).
    pub cached_seq_len: usize,
    /// Instantiated graph executable, launched on each backward replay.
    pub exec: CudaGraphExec,
}
27
/// Report whether CUDA graph capture of the backward pass is enabled.
///
/// Capture is enabled only when the `CUDA_GRAPH` environment variable is
/// exactly `"1"`; an unset variable or any other value disables it. The
/// variable is read once, on the first call, and the answer is cached for
/// the lifetime of the process.
#[cfg(feature = "cuda")]
pub(crate) fn use_backward_graph() -> bool {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *ENABLED.get_or_init(|| std::env::var("CUDA_GRAPH").is_ok_and(|value| value == "1"))
}
34
/// Try to capture the backward loop into a CUDA graph.
///
/// Called on the first backward at a given `seq_len`. Begins capture on
/// `stream` (thread-local capture mode), runs `backward_fn` so that every
/// kernel launch it enqueues (backward + fused_clip + optimizer per layer) is
/// recorded, then ends the capture and instantiates an executable graph.
///
/// # Preconditions
///
/// The cuBLAS workspace must already be pre-allocated before this is called
/// (PMAT-063).
///
/// NOTE(review): under CUDA stream capture, enqueued work is recorded rather
/// than executed. Presumably the caller replays the returned graph to perform
/// this step, and falls back to an ungraphed backward on `None`. Confirm this
/// against the call site.
///
/// # Returns
///
/// `Some(BackwardGraphState)` on successful capture. Returns `None` if
/// beginning capture, `backward_fn`, ending capture, or instantiation fails;
/// each failure is reported on stderr.
#[cfg(feature = "cuda")]
pub(crate) fn try_capture_backward<F>(
    stream: &CudaStream,
    seq_len: usize,
    backward_fn: F,
) -> Option<BackwardGraphState>
where
    F: FnOnce() -> Option<()>,
{
    // The cuBLAS workspace must already be pre-allocated at this point (PMAT-063).
    stream
        .begin_capture(CaptureMode::ThreadLocal)
        .map_err(|e| eprintln!("[CUDA] Backward graph capture begin failed: {e}"))
        .ok()?;

    if backward_fn().is_none() {
        // Backward failed mid-capture: still end the capture so it is not left
        // open, and discard the partial graph. Report (rather than silently
        // drop) any error from ending the capture.
        if let Err(e) = stream.end_capture() {
            eprintln!("[CUDA] Backward graph end_capture failed: {e}");
        }
        eprintln!("[CUDA] Backward graph capture aborted: backward failed");
        return None;
    }

    let graph = stream
        .end_capture()
        .map_err(|e| eprintln!("[CUDA] Backward graph end_capture failed: {e}"))
        .ok()?;
    let exec = graph
        .instantiate()
        .map_err(|e| eprintln!("[CUDA] Backward graph instantiate failed: {e}"))
        .ok()?;

    eprintln!("[CUDA] Backward graph captured: seq_len={seq_len}");
    Some(BackwardGraphState { exec, cached_seq_len: seq_len })
}
84
/// Launch a previously captured backward graph on `stream`.
///
/// The caller must use the same sequence length the graph was captured at
/// (recorded in `state.cached_seq_len`); this function does not check it.
/// All buffer pointers baked into the graph must remain valid, which holds
/// for the pre-allocated training state.
///
/// Returns `Some(())` when the launch call succeeds. Otherwise it logs the
/// error to stderr and returns `None`.
#[cfg(feature = "cuda")]
pub(crate) fn replay_backward(state: &BackwardGraphState, stream: &CudaStream) -> Option<()> {
    let raw_stream = stream.raw();
    match state.exec.launch(raw_stream) {
        Ok(launched) => Some(launched),
        Err(e) => {
            eprintln!("[CUDA] Backward graph replay failed: {e}");
            None
        }
    }
}