entrenar/autograd/cuda_backward/
cache.rs1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use std::collections::HashMap;
8#[cfg(feature = "cuda")]
9use std::sync::{Arc, Mutex, OnceLock};
10
11#[cfg(feature = "cuda")]
12use trueno_gpu::driver::{CublasHandle, CudaContext, CudaModule, CudaStream};
13
14use super::super::cuda_tensor::{CudaTensorError, Result};
18
/// Process-wide slot holding the backward-pass [`KernelCache`].
///
/// The slot starts empty and is filled at most once (see `init_kernel_cache`);
/// the inner `Mutex` serialises every access to the cache afterwards.
#[cfg(feature = "cuda")]
pub(super) static KERNEL_CACHE: OnceLock<Mutex<KernelCache>> = OnceLock::new();
22
/// Compiled-module cache used by the backward-pass kernels.
///
/// Bundles the CUDA context, the modules compiled so far, the device target
/// string used to validate incoming PTX, and an optional cuBLAS handle.
#[cfg(feature = "cuda")]
pub(super) struct KernelCache {
    /// Context every module in `modules` is compiled against.
    ctx: Arc<CudaContext>,
    /// Compiled modules, keyed by the name handed to `get_or_compile`.
    modules: HashMap<String, CudaModule>,
    /// Device target string (for example `sm_70`) that PTX `.target` lines must match.
    sm_target: String,
    /// cuBLAS handle; `None` when creating it failed in `new`.
    cublas: Option<CublasHandle>,
}
37
#[cfg(feature = "cuda")]
impl KernelCache {
    /// Creates an empty cache bound to `ctx`.
    ///
    /// Two failures are tolerated on purpose: if the context cannot report its
    /// target the cache assumes `"sm_70"`, and if the cuBLAS handle cannot be
    /// created the cache simply stores `None`.
    pub(super) fn new(ctx: Arc<CudaContext>) -> Self {
        let sm_target = match ctx.sm_target() {
            Ok(reported) => reported,
            Err(_) => "sm_70".to_string(),
        };
        let cublas = CublasHandle::new(&ctx).ok();
        Self { ctx, modules: HashMap::new(), sm_target, cublas }
    }

    /// Borrows the cuBLAS handle, if one was created.
    pub(super) fn cublas(&self) -> Option<&CublasHandle> {
        self.cublas.as_ref()
    }

    /// Points the cuBLAS handle at `stream`.
    ///
    /// A cache without a cuBLAS handle treats this as a successful no-op.
    ///
    /// # Errors
    /// Returns `CudaTensorError::KernelError` when the handle rejects the stream.
    pub(super) fn set_cublas_stream(&self, stream: &CudaStream) -> Result<()> {
        let Some(handle) = self.cublas.as_ref() else {
            return Ok(());
        };
        handle.set_stream(stream).map_err(|e| {
            CudaTensorError::KernelError(format!("cuBLAS set_stream failed: {e:?}"))
        })?;
        Ok(())
    }

    /// Device target string this cache validates PTX against.
    pub(super) fn sm_target(&self) -> &str {
        self.sm_target.as_str()
    }

    /// Shared handle to the CUDA context owned by this cache.
    pub(super) fn ctx(&self) -> &Arc<CudaContext> {
        &self.ctx
    }

    /// Looks up an already-compiled module by `name` without compiling anything.
    pub(super) fn get_cached(&mut self, name: &str) -> Option<&mut CudaModule> {
        self.modules.get_mut(name)
    }

    /// Returns the module cached under `name`, compiling `ptx` on a miss.
    ///
    /// The first line of `ptx` that starts with `.target ` (if any) must name
    /// exactly this cache's target; the check runs even when `name` is already
    /// cached. PTX without such a line is accepted unchecked.
    ///
    /// # Errors
    /// Returns `CudaTensorError::KernelError` when the PTX target differs from
    /// the device target, when the compute capability cannot be queried, or
    /// when module compilation fails.
    pub(super) fn get_or_compile(&mut self, name: &str, ptx: &str) -> Result<&mut CudaModule> {
        use std::collections::hash_map::Entry;

        let declared_target = ptx
            .lines()
            .find(|line| line.starts_with(".target "))
            .map(|line| line.trim().trim_start_matches(".target "));
        if let Some(ptx_target) = declared_target {
            if ptx_target != self.sm_target {
                return Err(CudaTensorError::KernelError(format!(
                    "F-PTX-001 violated: PTX target '{ptx_target}' != device target '{}'",
                    self.sm_target
                )));
            }
        }

        // Cache hit: hand back the existing module. Otherwise keep the vacant
        // slot so the freshly compiled module can be stored without a second lookup.
        let slot = match self.modules.entry(name.to_string()) {
            Entry::Occupied(hit) => return Ok(hit.into_mut()),
            Entry::Vacant(slot) => slot,
        };

        eprintln!("[BWD-CACHE] Compiling '{name}' (ptx_len={})", ptx.len());

        let (major, _minor) = self.ctx.compute_capability().map_err(|e| {
            CudaTensorError::KernelError(format!("compute_capability: {e:?}"))
        })?;

        // Devices reporting major version 12 or above go through the direct loader.
        let loaded = if major >= 12 {
            CudaModule::from_ptx_direct(&self.ctx, ptx)
        } else {
            CudaModule::from_ptx(&self.ctx, ptx)
        };
        let module = loaded.map_err(|err| {
            eprintln!("[BWD-CACHE] FAILED '{name}': {err:?}");
            CudaTensorError::KernelError(format!("Failed to compile {name}: {err:?}"))
        })?;

        eprintln!("[BWD-CACHE] OK '{name}'");
        Ok(slot.insert(module))
    }
}
120
/// Installs the global backward kernel cache for `ctx`.
///
/// Only the first call builds a cache; later calls leave the existing cache
/// in place and drop the `ctx` they were given. Always returns `Ok(())`.
#[cfg(feature = "cuda")]
pub fn init_kernel_cache(ctx: Arc<CudaContext>) -> Result<()> {
    let _installed = KERNEL_CACHE.get_or_init(move || {
        let cache = KernelCache::new(ctx);
        Mutex::new(cache)
    });
    Ok(())
}
127
/// Binds the backward cache's cuBLAS handle to `stream`.
///
/// # Errors
/// - `CudaTensorError::DeviceNotInitialized` when the global cache has not been installed.
/// - `CudaTensorError::KernelError` when the cache lock is poisoned or cuBLAS
///   rejects the stream.
#[cfg(feature = "cuda")]
pub fn set_backward_cublas_stream(stream: &CudaStream) -> Result<()> {
    let Some(slot) = KERNEL_CACHE.get() else {
        return Err(CudaTensorError::DeviceNotInitialized);
    };
    match slot.lock() {
        Ok(guard) => guard.set_cublas_stream(stream),
        Err(_poisoned) => Err(CudaTensorError::KernelError(
            "Failed to acquire backward kernel cache lock".to_string(),
        )),
    }
}
137
/// Compiles the backward-pass kernels for one model shape ahead of time and
/// stores them in the global kernel cache.
///
/// What gets compiled:
/// - when `lora_rank > 0`: the `gemm_backward_a` / `gemm_backward_b` variants
///   that involve the LoRA rank (paired with `q_dim`, `kv_hidden_size` and
///   `hidden_size`);
/// - when `quantize_nf4` is `false`: the `gemm_backward_a` / `gemm_backward_b`
///   variants built from `hidden_size`, `kv_hidden_size` and `intermediate_size`;
/// - always: `silu_backward`, `batched_softmax_backward`,
///   `batched_rms_norm_backward` and `rms_norm_gamma_reduce`.
///
/// All dimensions are validated before any kernel is compiled, so an
/// out-of-range shape fails without leaving the cache partially warmed by
/// this call.
///
/// # Errors
/// - `CudaTensorError::DeviceNotInitialized` when the global cache has not
///   been installed.
/// - `CudaTensorError::KernelError` when the cache lock is poisoned, when a
///   dimension (or the `max_seq_len * intermediate_size` /
///   `num_heads * max_seq_len` product) does not fit in `u32`, or when any
///   kernel fails to compile.
#[cfg(feature = "cuda")]
pub fn pre_warm_lora_backward_kernels(
    hidden_size: usize,
    q_dim: usize,
    kv_hidden_size: usize,
    max_seq_len: usize,
    lora_rank: usize,
    intermediate_size: usize,
    num_heads: usize,
    quantize_nf4: bool,
) -> Result<()> {
    use trueno_gpu::kernels::backward::{
        BatchedRmsNormBackwardKernel, BatchedSoftmaxBackwardKernel, GemmBackwardAKernel,
        GemmBackwardBKernel, RmsNormGammaReduceKernel, SiluBackwardKernel,
    };
    use trueno_gpu::kernels::Kernel;

    /// Narrows one dimension to `u32`, reporting an error instead of truncating.
    fn dim_u32(label: &str, value: usize) -> Result<u32> {
        u32::try_from(value).map_err(|_| {
            CudaTensorError::KernelError(format!(
                "pre-warm: {label}={value} does not fit in u32"
            ))
        })
    }

    /// Multiplies two dimensions, reporting an error instead of wrapping or panicking.
    fn checked_product(label: &str, lhs: u32, rhs: u32) -> Result<u32> {
        lhs.checked_mul(rhs).ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "pre-warm: {label} ({lhs} * {rhs}) overflows u32"
            ))
        })
    }

    eprintln!("[BWD-PREWARM] Called with lora_rank={lora_rank}, hidden={hidden_size}, inter={intermediate_size}");

    let is_lora = lora_rank > 0;

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire backward kernel cache lock".to_string())
    })?;

    // The kernel builders are called with `u32` dimensions. A bare `as u32`
    // would silently truncate an oversized `usize` and pre-compile a kernel
    // under the wrong key, so every dimension is converted with a range check.
    let s = dim_u32("max_seq_len", max_seq_len)?;
    let h = dim_u32("hidden_size", hidden_size)?;
    let r = dim_u32("lora_rank", lora_rank)?;
    let qd = dim_u32("q_dim", q_dim)?;
    let kv = dim_u32("kv_hidden_size", kv_hidden_size)?;
    let i = dim_u32("intermediate_size", intermediate_size)?;
    let nh = dim_u32("num_heads", num_heads)?;

    // Derived sizes are checked up front as well, before anything is compiled.
    let si = checked_product("max_seq_len * intermediate_size", s, i)?;
    let softmax_rows = checked_product("num_heads * max_seq_len", nh, s)?;

    let mut count = 0u32;
    let target = cache.sm_target().to_string();

    // Emits PTX for the device target, compiles it under `$key` (or reuses the
    // cached module) and bumps the warm-up counter.
    macro_rules! warm {
        ($key:expr, $kernel:expr) => {{
            let key = $key;
            let ptx = $kernel.emit_ptx_for_target(&target);
            cache.get_or_compile(&key, &ptx)?;
            count += 1;
        }};
    }

    let tile: u32 = 16;

    if is_lora {
        // GEMM backward-B kernels whose shapes involve the LoRA rank.
        warm!(
            format!("gemm_backward_b_{s}_{r}_{qd}"),
            GemmBackwardBKernel::tiled_unrolled(s, r, qd, tile)
        );
        if kv != qd {
            warm!(
                format!("gemm_backward_b_{s}_{r}_{kv}"),
                GemmBackwardBKernel::tiled_unrolled(s, r, kv, tile)
            );
        }
        warm!(
            format!("gemm_backward_b_{s}_{h}_{r}"),
            GemmBackwardBKernel::tiled_unrolled(s, h, r, tile)
        );

        // GEMM backward-A kernels whose shapes involve the LoRA rank.
        warm!(
            format!("gemm_backward_a_{s}_{qd}_{r}"),
            GemmBackwardAKernel::tiled_unrolled(s, qd, r, tile)
        );
        if kv != qd {
            warm!(
                format!("gemm_backward_a_{s}_{kv}_{r}"),
                GemmBackwardAKernel::tiled_unrolled(s, kv, r, tile)
            );
        }
        warm!(
            format!("gemm_backward_a_{s}_{r}_{h}"),
            GemmBackwardAKernel::tiled_unrolled(s, r, h, tile)
        );
    }

    if !quantize_nf4 {
        // Full-size GEMM backward kernels, only warmed when NF4 quantisation is off.
        warm!(
            format!("gemm_backward_a_{s}_{h}_{h}"),
            GemmBackwardAKernel::tiled_unrolled(s, h, h, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{h}_{h}"),
            GemmBackwardBKernel::tiled_unrolled(s, h, h, tile)
        );
        if kv != h {
            warm!(
                format!("gemm_backward_a_{s}_{kv}_{h}"),
                GemmBackwardAKernel::tiled_unrolled(s, kv, h, tile)
            );
            warm!(
                format!("gemm_backward_b_{s}_{kv}_{h}"),
                GemmBackwardBKernel::tiled_unrolled(s, kv, h, tile)
            );
        }

        warm!(
            format!("gemm_backward_a_{s}_{h}_{i}"),
            GemmBackwardAKernel::tiled_unrolled(s, h, i, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{h}_{i}"),
            GemmBackwardBKernel::tiled_unrolled(s, h, i, tile)
        );
        warm!(
            format!("gemm_backward_a_{s}_{i}_{h}"),
            GemmBackwardAKernel::tiled_unrolled(s, i, h, tile)
        );
        warm!(
            format!("gemm_backward_b_{s}_{i}_{h}"),
            GemmBackwardBKernel::tiled_unrolled(s, i, h, tile)
        );
    }

    // Kernels that are warmed regardless of the LoRA / NF4 settings.
    warm!("silu_backward".to_string(), SiluBackwardKernel::new(si));

    warm!(
        "batched_softmax_backward".to_string(),
        BatchedSoftmaxBackwardKernel::new(softmax_rows, s)
    );

    let eps = 1e-5_f32;
    warm!("batched_rms_norm_backward".to_string(), BatchedRmsNormBackwardKernel::new(s, h, eps));

    warm!("rms_norm_gamma_reduce".to_string(), RmsNormGammaReduceKernel::new(s, h));

    // `count` is only incremented by `warm!`; read it once so it is not flagged as unused.
    let _ = count;
    Ok(())
}