Skip to main content

entrenar/autograd/cuda_backward/
cache.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use std::collections::HashMap;
8#[cfg(feature = "cuda")]
9use std::sync::{Arc, Mutex, OnceLock};
10
11#[cfg(feature = "cuda")]
12use trueno_gpu::driver::{CublasHandle, CudaContext, CudaModule, CudaStream};
13
// trueno#200: The Blackwell skip requires a local trueno-gpu build with the from_ptx fix.
// Add `[patch.crates-io] trueno-gpu = { path = "../trueno/trueno-gpu" }` to Cargo.toml.
16
17use super::super::cuda_tensor::{CudaTensorError, Result};
18
/// Backward-pass kernel cache: compiled CUDA modules keyed by name, plus the
/// device target and optional cuBLAS handle used by the backward kernels.
///
/// # Contract: F-PTX-001 (Target Parity)
///
/// Same invariant as forward cache — PTX target must match device.
#[cfg(feature = "cuda")]
pub(super) struct KernelCache {
    /// Context every module in `modules` was loaded into.
    ctx: Arc<CudaContext>,
    /// Compiled modules, keyed by the caller-chosen cache name.
    modules: HashMap<String, CudaModule>,
    /// Device SM target string that PTX `.target` directives must match.
    sm_target: String,
    /// cuBLAS handle for backward GEMMs (ALB-075). Uses CUBLAS_DEFAULT_MATH
    /// (SIMD, no tensor cores) per ALB-076/trueno#170 to avoid NaN in transposed GEMMs.
    /// `None` when handle creation failed.
    cublas: Option<CublasHandle>,
}

/// Process-wide backward kernel cache, created lazily on first initialisation
/// and guarded by a `Mutex` so callers get exclusive `&mut` access to it.
#[cfg(feature = "cuda")]
pub(super) static KERNEL_CACHE: OnceLock<Mutex<KernelCache>> = OnceLock::new();
37
#[cfg(feature = "cuda")]
impl KernelCache {
    /// Create an empty cache bound to `ctx`.
    ///
    /// The device SM target is read once from `ctx`; if that query fails the
    /// cache falls back to `"sm_70"`. cuBLAS handle creation is best-effort:
    /// on failure no handle is stored and [`Self::cublas`] returns `None`.
    pub(super) fn new(ctx: Arc<CudaContext>) -> Self {
        let sm_target = ctx.sm_target().unwrap_or_else(|_| "sm_70".to_string());
        let cublas = CublasHandle::new(&ctx).ok();
        Self { ctx, modules: HashMap::new(), sm_target, cublas }
    }

    /// Get a reference to the cuBLAS handle, if available.
    pub(super) fn cublas(&self) -> Option<&CublasHandle> {
        self.cublas.as_ref()
    }

    /// Bind cuBLAS to a stream for the current training step.
    ///
    /// Does nothing and returns `Ok(())` when no cuBLAS handle was created.
    ///
    /// # Errors
    ///
    /// `CudaTensorError::KernelError` if binding the handle to `stream` fails.
    pub(super) fn set_cublas_stream(&self, stream: &CudaStream) -> Result<()> {
        if let Some(ref handle) = self.cublas {
            handle.set_stream(stream).map_err(|e| {
                CudaTensorError::KernelError(format!("cuBLAS set_stream failed: {e:?}"))
            })?;
        }
        Ok(())
    }

    /// SM target string that PTX compiled through this cache must declare.
    pub(super) fn sm_target(&self) -> &str {
        &self.sm_target
    }

    /// Get the CudaContext this cache was built against. Used by callers that
    /// need to allocate temp `GpuBuffer<T>` (which require a context) — e.g.
    /// `rms_norm_backward`'s per-row partial buffer for the FALSIFY-GPUTRAIN-006
    /// deterministic two-stage gamma reduction.
    pub(super) fn ctx(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Look up a previously compiled module by key (KAIZEN-058).
    ///
    /// Never compiles anything; returns `None` on a cache miss.
    pub(super) fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
        self.modules.get_mut(name)
    }

    /// Return the module cached under `name`, compiling `ptx` on first use.
    ///
    /// # Contract: F-PTX-001
    ///
    /// If `ptx` declares a `.target`, its architecture must equal
    /// [`Self::sm_target`]. The check runs on every call, including cache hits.
    ///
    /// # Errors
    ///
    /// `CudaTensorError::KernelError` on a target mismatch, if the device
    /// compute capability cannot be queried, or if module loading fails.
    pub(super) fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
        use std::collections::hash_map::Entry;

        // F-PTX-001: Validate PTX target matches device
        if let Some(ptx_target) = Self::declared_ptx_target(ptx) {
            if ptx_target != self.sm_target {
                return Err(CudaTensorError::KernelError(format!(
                    "F-PTX-001 violated: PTX target '{ptx_target}' != device target '{}'",
                    self.sm_target
                )));
            }
        }

        match self.modules.entry(name.to_string()) {
            Entry::Occupied(e) => Ok(e.into_mut()),
            Entry::Vacant(e) => {
                eprintln!("[BWD-CACHE] Compiling '{name}' (ptx_len={})", ptx.len());

                // trueno#200: On Blackwell, CudaModule::from_ptx uses cuModuleLoadDataEx
                // which poisons the CUDA context. Bypass it entirely with direct
                // cuModuleLoadData via the raw driver API.
                let (major, _minor) = self.ctx.compute_capability().map_err(|e| {
                    CudaTensorError::KernelError(format!("compute_capability: {e:?}"))
                })?;

                // trueno#200: Use from_ptx_direct on Blackwell to avoid context poisoning
                let module = if major >= 12 {
                    CudaModule::from_ptx_direct(&self.ctx, ptx)
                } else {
                    CudaModule::from_ptx(&self.ctx, ptx)
                }
                .map_err(|err| {
                    eprintln!("[BWD-CACHE] FAILED '{name}': {err:?}");
                    CudaTensorError::KernelError(format!("Failed to compile {name}: {err:?}"))
                })?;
                eprintln!("[BWD-CACHE] OK '{name}'");
                Ok(e.insert(module))
            }
        }
    }

    /// Architecture named by the first `.target` directive in `ptx`, if any.
    ///
    /// A PTX `.target` directive may be indented and may list extra
    /// comma-separated options after the architecture (e.g.
    /// `.target sm_80, debug`). The line is therefore trimmed and only the
    /// first token is returned, so those forms do not produce a spurious
    /// F-PTX-001 violation.
    fn declared_ptx_target(ptx: &str) -> Option<&str> {
        ptx.lines()
            .map(str::trim_start)
            .find_map(|line| {
                line.strip_prefix(".target")
                    .filter(|rest| rest.starts_with(char::is_whitespace))
            })
            .and_then(|rest| {
                rest.split(|c: char| c == ',' || c.is_whitespace())
                    .find(|token| !token.is_empty())
            })
    }
}
120
/// Install the process-wide backward kernel cache for `ctx`.
///
/// Only the first call has any effect. Once [`KERNEL_CACHE`] is populated,
/// later calls drop their `ctx` and leave the existing cache untouched.
/// Always returns `Ok(())`.
#[cfg(feature = "cuda")]
pub fn init_kernel_cache(ctx: Arc<CudaContext>) -> Result<()> {
    let build_cache = move || Mutex::new(KernelCache::new(ctx));
    let _installed = KERNEL_CACHE.get_or_init(build_cache);
    Ok(())
}
127
/// Bind the backward cache's cuBLAS handle to `stream` (ALB-075).
///
/// # Errors
///
/// - `CudaTensorError::DeviceNotInitialized` if the backward cache was never
///   initialised.
/// - `CudaTensorError::KernelError` if the cache lock is poisoned or the
///   stream binding itself fails.
#[cfg(feature = "cuda")]
pub fn set_backward_cublas_stream(stream: &CudaStream) -> Result<()> {
    let Some(shared) = KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match shared.lock() {
        Ok(guard) => guard.set_cublas_stream(stream),
        Err(_poisoned) => Err(CudaTensorError::KernelError(
            "Failed to acquire backward kernel cache lock".to_string(),
        )),
    }
}
137
/// Pre-warm backward GEMM kernels for training gradient computation (ENT-153).
///
/// Covers both LoRA-only shapes (NF4 QLoRA) and full fp32 backward shapes.
///
/// ## LoRA backward shapes
/// `gemm_backward_b` (weight grads): `(S,R,qd)`, `(S,R,kv)`, `(S,H,R)`
/// `gemm_backward_a` (input grads): `(S,qd,R)`, `(S,kv,R)`, `(S,R,H)`
///
/// ## Full fp32 backward shapes (non-NF4)
/// `gemm_backward_a`: `(S,I,H)` down, `(S,H,I)` gate/up, `(S,H,H)` Q/O, `(S,kv,H)` K/V
/// `gemm_backward_b`: `(S,I,H)` grad_w_down, `(S,H,I)` grad_w_gate/up,
///                     `(S,H,H)` grad_w_q/o, `(S,kv,H)` grad_w_k/v
///
/// ## Other backward kernels (always warmed)
/// `silu_backward` built with `S*I`, `batched_softmax_backward` built with
/// `(num_heads*S, S)`, and `batched_rms_norm_backward` built with `(S, H, 1e-5)`.
///
/// All must be JIT-compiled before block upload fills VRAM (C-PREWARM-001).
/// Keys already present in the backward cache are skipped without regenerating
/// their PTX.
///
/// Does nothing and returns `Ok(())` when `lora_rank == 0`.
///
/// # Errors
///
/// - `CudaTensorError::KernelError` if any dimension does not fit in `u32`.
/// - `CudaTensorError::KernelError` if `max_seq_len * intermediate_size` or
///   `num_heads * max_seq_len` overflows `u32`.
/// - `CudaTensorError::DeviceNotInitialized` if the backward cache was never
///   initialised.
/// - `CudaTensorError::KernelError` if the cache lock is poisoned, or if a
///   kernel fails to compile or violates F-PTX-001.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_backward_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    max_seq_len: usize,
    lora_rank: usize,
    intermediate_size: usize,
    num_heads: usize,
    quantize_nf4: bool,
) -> Result<()> {
    use trueno_gpu::kernels::backward::{
        BatchedRmsNormBackwardKernel, BatchedSoftmaxBackwardKernel, GemmBackwardAKernel,
        GemmBackwardBKernel, SiluBackwardKernel,
    };
    use trueno_gpu::kernels::Kernel;

    eprintln!("[BWD-PREWARM] Called with lora_rank={lora_rank}, hidden={hidden_size}, inter={intermediate_size}");
    if lora_rank == 0 {
        // NOTE(review): this early return also skips the full fp32 shapes below,
        // none of which depend on the LoRA rank. Confirm that full fine-tuning
        // (rank 0, non-NF4) never relies on this pre-warm.
        eprintln!("[BWD-PREWARM] Skipping (lora_rank=0)");
        return Ok(());
    }

    // Kernel dimensions are u32. A bare `as` cast would silently truncate, and
    // the derived element counts could overflow. Validate everything up front,
    // before taking the cache lock.
    let dim = |value: usize, name: &str| -> Result<u32> {
        u32::try_from(value).map_err(|_| {
            CudaTensorError::KernelError(format!(
                "[BWD-PREWARM] {name}={value} does not fit in u32"
            ))
        })
    };
    let s = dim(max_seq_len, "max_seq_len")?;
    let h = dim(hidden_size, "hidden_size")?;
    let r = dim(lora_rank, "lora_rank")?;
    let qd = dim(q_dim, "q_dim")?;
    let kv = dim(kv_hidden_size, "kv_hidden_size")?;
    let i = dim(intermediate_size, "intermediate_size")?;
    let nh = dim(num_heads, "num_heads")?;

    let product = |a: u32, b: u32, what: &str| -> Result<u32> {
        a.checked_mul(b).ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "[BWD-PREWARM] {what} = {a} * {b} overflows u32"
            ))
        })
    };
    // Long contexts with wide FFNs can exceed u32::MAX here.
    let si = product(s, i, "max_seq_len * intermediate_size")?;
    let softmax_rows = product(nh, s, "num_heads * max_seq_len")?;

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire backward kernel cache lock".to_string())
    })?;

    // Owned copy so `cache` can be borrowed mutably inside `warm!`.
    let target = cache.sm_target().to_string();
    let mut requested = 0u32;
    let mut compiled = 0u32;

    // Must be defined after `cache`, `target`, `requested` and `compiled`
    // (macro_rules! hygiene resolves them at the definition site).
    macro_rules! warm {
        ($key:expr, $kernel:expr) => {{
            let key: String = $key;
            requested += 1;
            // KAIZEN-058: skip building the kernel and emitting PTX for modules
            // that are already compiled.
            if cache.get_cached(&key).is_none() {
                let ptx = $kernel.emit_ptx_for_target(&target);
                cache.get_or_compile(&key, &ptx)?;
                compiled += 1;
            }
        }};
    }

    // Tile size must match BACKWARD_TILE_SIZE in gemm.rs (C-TILE-BWD-007)
    let tile: u32 = 16;

    // ── LoRA backward shapes (always needed) ──
    // gemm_backward_b: weight gradients
    warm!(
        format!("gemm_backward_b_{s}_{r}_{qd}"),
        GemmBackwardBKernel::tiled_unrolled(s, r, qd, tile)
    );
    if kv != qd {
        warm!(
            format!("gemm_backward_b_{s}_{r}_{kv}"),
            GemmBackwardBKernel::tiled_unrolled(s, r, kv, tile)
        );
    }
    warm!(
        format!("gemm_backward_b_{s}_{h}_{r}"),
        GemmBackwardBKernel::tiled_unrolled(s, h, r, tile)
    );

    // gemm_backward_a: input gradients
    warm!(
        format!("gemm_backward_a_{s}_{qd}_{r}"),
        GemmBackwardAKernel::tiled_unrolled(s, qd, r, tile)
    );
    if kv != qd {
        warm!(
            format!("gemm_backward_a_{s}_{kv}_{r}"),
            GemmBackwardAKernel::tiled_unrolled(s, kv, r, tile)
        );
    }
    warm!(
        format!("gemm_backward_a_{s}_{r}_{h}"),
        GemmBackwardAKernel::tiled_unrolled(s, r, h, tile)
    );

    // ── Full fp32 backward shapes (non-NF4 mode) ──
    if !quantize_nf4 {
        // Attention backward: Q/O (S,H,H), K/V (S,kv,H)
        // NOTE(review): Q/O are warmed as (S,H,H). If q_dim != hidden_size,
        // the Q/O backward shapes presumably involve q_dim instead. Confirm
        // against the backward call sites.
        warm!(
            format!("gemm_backward_a_{s}_{h}_{h}"),
            GemmBackwardAKernel::tiled_unrolled(s, h, h, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{h}_{h}"),
            GemmBackwardBKernel::tiled_unrolled(s, h, h, tile)
        );
        if kv != h {
            warm!(
                format!("gemm_backward_a_{s}_{kv}_{h}"),
                GemmBackwardAKernel::tiled_unrolled(s, kv, h, tile)
            );
            warm!(
                format!("gemm_backward_b_{s}_{kv}_{h}"),
                GemmBackwardBKernel::tiled_unrolled(s, kv, h, tile)
            );
        }

        // FFN backward: gate/up (S,H,I), down (S,I,H)
        warm!(
            format!("gemm_backward_a_{s}_{h}_{i}"),
            GemmBackwardAKernel::tiled_unrolled(s, h, i, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{h}_{i}"),
            GemmBackwardBKernel::tiled_unrolled(s, h, i, tile)
        );
        warm!(
            format!("gemm_backward_a_{s}_{i}_{h}"),
            GemmBackwardAKernel::tiled_unrolled(s, i, h, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{i}_{h}"),
            GemmBackwardBKernel::tiled_unrolled(s, i, h, tile)
        );
    }

    // ── Activation backward: SiLU ──
    warm!("silu_backward".to_string(), SiluBackwardKernel::new(si));

    // ── Structured backward kernels (attention + normalization) ──
    // Batched softmax backward (dimension-independent)
    warm!(
        "batched_softmax_backward".to_string(),
        BatchedSoftmaxBackwardKernel::new(softmax_rows, s)
    );

    // RMSNorm backward (dimension-independent)
    let eps = 1e-5_f32;
    warm!("batched_rms_norm_backward".to_string(), BatchedRmsNormBackwardKernel::new(s, h, eps));

    eprintln!(
        "[BWD-PREWARM] Done: {compiled} compiled, {} already cached",
        requested - compiled
    );
    Ok(())
}