1#![allow(dead_code)]
16#![allow(unsafe_code)]
21
// Per-op timing slots. `CudaBlockScratch::op_us` is a `[u64; 16]` array and
// `CudaBlockScratch::op_end(start, op)` accumulates into `op_us[op]`; these
// constants presumably name those 16 slots (values 0..=15).

/// Forward: RMSNorm ahead of attention.
#[cfg(feature = "cuda")]
const OP_RMSNORM_ATTN: usize = 0;
/// Forward: Q/K/V projection GEMMs.
#[cfg(feature = "cuda")]
const OP_QKV_GEMM: usize = 1;
/// Forward: attention core.
#[cfg(feature = "cuda")]
const OP_ATTENTION: usize = 2;
/// Forward: attention output projection.
#[cfg(feature = "cuda")]
const OP_O_PROJ: usize = 3;
/// Forward: RMSNorm ahead of the FFN.
#[cfg(feature = "cuda")]
const OP_RMSNORM_FFN: usize = 4;
/// Forward: gate/up projection GEMMs.
#[cfg(feature = "cuda")]
const OP_GATE_UP_GEMM: usize = 5;
/// Forward: SiLU / SwiGLU activation.
#[cfg(feature = "cuda")]
const OP_SILU: usize = 6;
/// Forward: down projection GEMM.
#[cfg(feature = "cuda")]
const OP_DOWN_GEMM: usize = 7;

/// LoRA forward work.
#[cfg(feature = "cuda")]
const OP_LORA_FWD: usize = 8;
/// Backward: down projection.
#[cfg(feature = "cuda")]
const OP_DOWN_BWD: usize = 9;
/// Backward: SwiGLU activation.
#[cfg(feature = "cuda")]
const OP_SWIGLU_BWD: usize = 10;
/// Backward: gate/up projections.
#[cfg(feature = "cuda")]
const OP_GATE_UP_BWD: usize = 11;
/// Backward: attention core.
#[cfg(feature = "cuda")]
const OP_ATTN_BWD: usize = 12;
/// Backward: Q/K/V projections.
#[cfg(feature = "cuda")]
const OP_QKV_BWD: usize = 13;
/// Backward: RMSNorm layers.
#[cfg(feature = "cuda")]
const OP_NORM_BWD: usize = 14;
/// Backward: LoRA adapters.
#[cfg(feature = "cuda")]
const OP_LORA_BWD: usize = 15;
58
59#[cfg(feature = "cuda")]
60use std::sync::Arc;
61
/// Narrow a `usize` to `u32` for kernel launch arguments, clamping values
/// that do not fit to `u32::MAX` instead of silently wrapping.
#[cfg(feature = "cuda")]
#[inline]
fn saturating_u32(v: usize) -> u32 {
    u32::try_from(v).unwrap_or(u32::MAX)
}
67
/// Consume `val` without running its destructor.
///
/// Used for non-owning `GpuBuffer` views built with `from_raw_parts`, whose
/// `Drop` would otherwise free memory still owned by another buffer.
#[cfg(feature = "cuda")]
#[inline]
fn leak<T>(val: T) {
    std::mem::forget(val);
}
74
75#[cfg(feature = "cuda")]
76use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
77
78#[cfg(feature = "cuda")]
79use crate::autograd::cuda_backward::{
80 batched_softmax_backward, gemm_backward_a, gemm_backward_a_fp16_dispatch,
81 gemm_backward_a_fp16_dispatch_accumulate, gemm_backward_b, rms_norm_backward, silu_backward,
82};
83#[cfg(feature = "cuda")]
84use crate::autograd::cuda_forward::{
85 batched_4d_gemm_forward, batched_rope_neox_backward, batched_rope_neox_forward,
86 batched_softmax_forward, batched_to_interleaved_forward, batched_transpose_forward,
87 cast_f32_to_f16_gpu, elementwise_mul_forward, expand_kv_heads, fused_residual_rmsnorm_forward,
88 fused_swiglu_forward, gemm_f16_to_f32_forward, gemm_forward, interleaved_to_batched_forward,
89 per_head_rmsnorm_forward, residual_add_forward, rms_norm_forward, scale_forward, silu_forward,
90};
91#[cfg(feature = "cuda")]
92use crate::autograd::cuda_optim::{adamw_step_cuda, gradient_clip_cuda, squared_sum_cuda};
93#[cfg(feature = "cuda")]
94use crate::autograd::cuda_tensor::Result;
95
96#[cfg(feature = "cuda")]
97use super::config::TransformerConfig;
98
/// One transformer decoder layer whose weights and activations live on the GPU.
///
/// Weights are uploaded once by `new`. Intermediate activations live in a
/// private `CudaBlockScratch` that `forward` fills and `backward` consumes.
#[cfg(feature = "cuda")]
pub struct CudaTransformerBlock {
    /// Hyper-parameters this layer was built for.
    config: TransformerConfig,
    /// Index of this layer within the model.
    layer_idx: usize,
    /// RMSNorm weight applied before attention (`hidden_size` elements).
    input_norm_weight: GpuBuffer<f32>,
    /// RMSNorm weight applied before the FFN (`hidden_size` elements).
    post_attn_norm_weight: GpuBuffer<f32>,
    /// Query projection; passed to `gemm_forward` with k = hidden_size, n = q_dim.
    w_q: GpuBuffer<f32>,
    /// Key projection; k = hidden_size, n = num_kv_heads * head_dim.
    w_k: GpuBuffer<f32>,
    /// Value projection; k = hidden_size, n = num_kv_heads * head_dim.
    w_v: GpuBuffer<f32>,
    /// Attention output projection; k = q_dim, n = hidden_size.
    w_o: GpuBuffer<f32>,
    /// FFN gate projection; k = hidden_size, n = intermediate_size.
    w_gate: GpuBuffer<f32>,
    /// FFN up projection; k = hidden_size, n = intermediate_size.
    w_up: GpuBuffer<f32>,
    /// FFN down projection; k = intermediate_size, n = hidden_size.
    w_down: GpuBuffer<f32>,
    /// Activation / temporary buffers shared by forward and backward.
    scratch: CudaBlockScratch,
    /// Host zeros (`hidden_size` elements) used to reset norm-weight gradients.
    norm_zero_buf: Vec<f32>,
    /// Optional per-head RMSNorm weight for Q (see `set_qk_norm`).
    q_norm_weight: Option<GpuBuffer<f32>>,
    /// Optional per-head RMSNorm weight for K (see `set_qk_norm`).
    k_norm_weight: Option<GpuBuffer<f32>>,
    /// CUDA context the buffers above were allocated in.
    ///
    /// Deliberately the LAST field. Struct fields drop in declaration order,
    /// so every device buffer above is released while this context reference
    /// is still held. That ordering matters when this is the last `Arc` and
    /// `GpuBuffer` does not keep its own context reference (TODO confirm in
    /// trueno_gpu). Either way it is harmless.
    ctx: Arc<CudaContext>,
}
136
/// Per-layer GPU scratch space for activations and temporaries.
///
/// Sizes below are in `f32`/`u16`/`u32` elements as allocated by
/// `CudaBlockScratch::new`, where `S = max_seq_len`. Several buffers are
/// reused as gradient temporaries during backward, so their contents after
/// `backward` are not the forward activations.
#[cfg(feature = "cuda")]
pub(crate) struct CudaBlockScratch {
    /// Pre-attention RMSNorm output, `S * hidden_size`.
    norm1_out: GpuBuffer<f32>,
    /// Query projection, `S * q_dim`.
    q: GpuBuffer<f32>,
    /// Key projection, `S * num_kv_heads * head_dim`.
    k: GpuBuffer<f32>,
    /// Value projection, `S * num_kv_heads * head_dim`.
    v: GpuBuffer<f32>,
    /// Attention scores / probabilities, `num_heads * S * S`.
    attn_scores: GpuBuffer<f32>,
    /// Attention output before the O projection, `S * q_dim`.
    attn_out: GpuBuffer<f32>,
    /// O projection output, `S * hidden_size`.
    o_proj_out: GpuBuffer<f32>,
    /// `input + o_proj_out` (first residual), `S * hidden_size`.
    residual1: GpuBuffer<f32>,
    /// Pre-FFN RMSNorm output, `S * hidden_size`; holds its gradient after `backward_ffn`.
    norm2_out: GpuBuffer<f32>,
    /// Gate projection, `S * intermediate_size`; reused as a gradient temp in backward.
    gate_out: GpuBuffer<f32>,
    /// Up projection, `S * intermediate_size`; reused as a gradient temp in backward.
    up_out: GpuBuffer<f32>,
    /// SwiGLU activation, `S * intermediate_size`; reused as a temp in backward.
    swiglu_out: GpuBuffer<f32>,
    /// Down projection output, `S * hidden_size`; reused as a temp in backward.
    ffn_out: GpuBuffer<f32>,
    /// Optional f16 copy of `norm1_out` (`None` from `new`; presumably for fp16 GEMMs).
    norm1_out_f16: Option<GpuBuffer<u16>>,
    /// Optional f16 copy of `attn_out` (`None` from `new`).
    attn_out_f16: Option<GpuBuffer<u16>>,
    /// Optional f16 copy of `norm2_out` (`None` from `new`).
    norm2_out_f16: Option<GpuBuffer<u16>>,
    /// Optional f16 copy of `swiglu_out` (`None` from `new`).
    swiglu_out_f16: Option<GpuBuffer<u16>>,
    /// Hidden-size gradient temporary, `S * hidden_size`.
    grad_hidden: GpuBuffer<f32>,
    /// Gradient w.r.t. the SwiGLU output, `S * intermediate_size`.
    grad_swiglu: GpuBuffer<f32>,
    /// Q (or its gradient) in per-head batched layout, `num_heads * S * head_dim`.
    attn_q_batched: GpuBuffer<f32>,
    /// K/V temporary in batched layout, `num_heads * S * head_dim`.
    attn_kv_temp: GpuBuffer<f32>,
    /// Second K/V temporary (e.g. K transposed), `num_heads * S * head_dim`.
    attn_kv_temp2: GpuBuffer<f32>,
    /// Gradient w.r.t. attention scores, `num_heads * S * max(S, head_dim)`.
    grad_attn_scores: GpuBuffer<f32>,
    /// LoRA intermediate, `max(S * lora_rank, 1)` (size 1 when built by `CudaTransformerBlock::new`).
    lora_inter: GpuBuffer<f32>,
    /// LoRA temporary, `max(S * max(q_dim, kv_dim), 1)` (size 1 when built by `CudaTransformerBlock::new`).
    lora_temp: GpuBuffer<f32>,
    /// Token positions `0..S` used by RoPE.
    rope_positions: GpuBuffer<u32>,
    /// Additive causal mask (`0` on/below diagonal, `-inf` above), row-major, side `causal_mask_cached_seq_len`.
    causal_mask_contiguous: GpuBuffer<f32>,
    /// Side length of `causal_mask_contiguous`; 0 forces a rebuild on the next `prepare_causal_mask`.
    pub(crate) causal_mask_cached_seq_len: usize,
    /// Accumulated wall-clock microseconds per op slot (see `op_end`).
    pub(crate) op_us: [u64; 16],
    /// When false, `op_begin` returns `None` and no timing is recorded.
    pub(crate) op_profiling_enabled: bool,
}
216
#[cfg(feature = "cuda")]
impl CudaBlockScratch {
    /// Start timing an op; returns `None` when profiling is disabled.
    #[inline]
    pub(crate) fn op_begin(&self) -> Option<std::time::Instant> {
        self.op_profiling_enabled.then(std::time::Instant::now)
    }

    /// Add the time elapsed since `start` to slot `op` of `op_us`.
    ///
    /// No-op when `start` is `None` (profiling disabled) or `op` is out of
    /// range. The slot saturates instead of overflowing.
    #[inline]
    pub(crate) fn op_end(&mut self, start: Option<std::time::Instant>, op: usize) {
        if let (Some(t), Some(slot)) = (start, self.op_us.get_mut(op)) {
            let elapsed = u64::try_from(t.elapsed().as_micros()).unwrap_or(u64::MAX);
            *slot = slot.saturating_add(elapsed);
        }
    }

    /// Number of tokens the activation buffers were sized for.
    pub(crate) fn max_seq_len(&self, hidden_size: usize) -> usize {
        self.norm1_out.len() / hidden_size.max(1)
    }

    /// Asynchronously zero every f32 activation/temporary buffer on `stream`
    /// and invalidate the cached causal mask.
    ///
    /// Zeroing is best-effort: individual `zero_async` failures are ignored,
    /// as before. `rope_positions` and the mask buffer itself are not touched.
    /// Resetting `causal_mask_cached_seq_len` to 0 makes the next
    /// `prepare_causal_mask` rebuild the mask.
    pub(crate) fn zero_forward_buffers(&mut self, stream: &CudaStream) {
        let buffers: [&mut GpuBuffer<f32>; 21] = [
            &mut self.norm1_out,
            &mut self.q,
            &mut self.k,
            &mut self.v,
            &mut self.attn_scores,
            &mut self.attn_out,
            &mut self.o_proj_out,
            &mut self.residual1,
            &mut self.norm2_out,
            &mut self.gate_out,
            &mut self.up_out,
            &mut self.swiglu_out,
            &mut self.ffn_out,
            &mut self.attn_q_batched,
            &mut self.attn_kv_temp,
            &mut self.attn_kv_temp2,
            &mut self.grad_hidden,
            &mut self.grad_swiglu,
            &mut self.grad_attn_scores,
            &mut self.lora_inter,
            &mut self.lora_temp,
        ];
        for buf in buffers {
            let _ = buf.zero_async(stream);
        }
        self.causal_mask_cached_seq_len = 0;
    }

    /// Allocate scratch for up to `max_seq_len` tokens of a layer described by `config`.
    ///
    /// `lora_rank` sizes `lora_inter` (`max_seq_len * lora_rank`, at least 1).
    /// The causal mask is built for `max_seq_len` and RoPE positions are
    /// `0..max_seq_len`.
    ///
    /// # Errors
    /// Propagates any GPU allocation or host-to-device copy failure.
    pub(crate) fn new(
        config: &TransformerConfig,
        max_seq_len: usize,
        ctx: &Arc<CudaContext>,
        lora_rank: usize,
    ) -> Result<Self> {
        let hidden_size = config.hidden_size;
        let q_dim = config.q_dim();
        let kv_hidden_size = config.num_kv_heads * config.head_dim();
        let intermediate_size = config.intermediate_size;
        let num_heads = config.num_attention_heads;
        let head_dim = config.head_dim();

        let max_proj_dim = q_dim.max(kv_hidden_size);
        let lora_inter_size = (max_seq_len * lora_rank).max(1);
        let lora_temp_size = (max_seq_len * max_proj_dim).max(1);

        let causal_mask_data = Self::causal_mask_host_data(max_seq_len);
        Ok(Self {
            norm1_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            q: GpuBuffer::new(ctx, max_seq_len * q_dim)?,
            k: GpuBuffer::new(ctx, max_seq_len * kv_hidden_size)?,
            v: GpuBuffer::new(ctx, max_seq_len * kv_hidden_size)?,
            attn_scores: GpuBuffer::new(ctx, num_heads * max_seq_len * max_seq_len)?,
            attn_out: GpuBuffer::new(ctx, max_seq_len * q_dim)?,
            o_proj_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            residual1: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            norm2_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            gate_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            up_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            swiglu_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            ffn_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            norm1_out_f16: None,
            attn_out_f16: None,
            norm2_out_f16: None,
            swiglu_out_f16: None,
            grad_hidden: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
            grad_swiglu: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
            attn_q_batched: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            attn_kv_temp: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            attn_kv_temp2: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
            // Sized for both the S x S score gradients and S x head_dim temporaries.
            grad_attn_scores: GpuBuffer::new(
                ctx,
                num_heads * max_seq_len * max_seq_len.max(head_dim),
            )?,
            lora_inter: GpuBuffer::new(ctx, lora_inter_size)?,
            lora_temp: GpuBuffer::new(ctx, lora_temp_size)?,
            rope_positions: {
                let positions: Vec<u32> = (0..max_seq_len as u32).collect();
                let mut buf = GpuBuffer::new(ctx, max_seq_len)?;
                buf.copy_from_host(&positions)?;
                buf
            },
            causal_mask_contiguous: GpuBuffer::from_host(ctx, &causal_mask_data)?,
            causal_mask_cached_seq_len: max_seq_len,
            op_us: [0u64; 16],
            op_profiling_enabled: false,
        })
    }

    /// Ensure `causal_mask_contiguous` is a `seq_len x seq_len` mask.
    ///
    /// Does nothing when the cached side length already equals `seq_len`;
    /// otherwise builds the mask on the host and uploads it into a freshly
    /// allocated buffer, replacing the previous one.
    ///
    /// # Errors
    /// Propagates any GPU allocation or upload failure; the cache is left unchanged then.
    pub(crate) fn prepare_causal_mask(
        &mut self,
        seq_len: usize,
        ctx: &Arc<CudaContext>,
    ) -> crate::autograd::cuda_tensor::Result<()> {
        if seq_len == self.causal_mask_cached_seq_len {
            return Ok(());
        }
        let mask_data = Self::causal_mask_host_data(seq_len);
        self.causal_mask_contiguous = GpuBuffer::from_host(ctx, &mask_data)?;
        self.causal_mask_cached_seq_len = seq_len;
        Ok(())
    }

    /// Row-major additive causal mask of side `n`: `0.0` where `col <= row`
    /// (visible), `-inf` above the diagonal (masked). Empty when `n == 0`.
    fn causal_mask_host_data(n: usize) -> Vec<f32> {
        (0..n * n)
            .map(|idx| {
                let (row, col) = (idx / n, idx % n);
                if col <= row {
                    0.0f32
                } else {
                    f32::NEG_INFINITY
                }
            })
            .collect()
    }
}
361
/// Reusable GPU buffers for one layer's weight gradients.
///
/// Sizes (elements) as allocated by `CudaGradWorkspace::new`, with
/// `h = hidden_size`, `q = q_dim`, `kv = num_kv_heads * head_dim`,
/// `i = intermediate_size`.
#[cfg(feature = "cuda")]
pub struct CudaGradWorkspace {
    /// Gradient of the pre-attention RMSNorm weight, `h`.
    pub(crate) grad_input_norm: GpuBuffer<f32>,
    /// Gradient of the pre-FFN RMSNorm weight, `h`.
    pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
    /// Gradient of `w_gate`, `h * i`.
    pub(crate) grad_gate: GpuBuffer<f32>,
    /// Gradient of `w_up`, `h * i`.
    pub(crate) grad_up: GpuBuffer<f32>,
    /// Gradient of `w_down`, `i * h`.
    pub(crate) grad_down: GpuBuffer<f32>,
    /// Gradient of `w_q`, `q * h`.
    pub(crate) grad_w_q: GpuBuffer<f32>,
    /// Gradient of `w_k`, `h * kv`.
    pub(crate) grad_w_k: GpuBuffer<f32>,
    /// Gradient of `w_v`, `h * kv`.
    pub(crate) grad_w_v: GpuBuffer<f32>,
    /// Gradient of `w_o`, `h * q`.
    pub(crate) grad_w_o: GpuBuffer<f32>,
}
396
#[cfg(feature = "cuda")]
impl CudaGradWorkspace {
    /// Allocate weight-gradient buffers for a layer described by `config`.
    ///
    /// # Errors
    /// Propagates any GPU allocation failure.
    pub fn new(ctx: &Arc<CudaContext>, config: &TransformerConfig) -> Result<Self> {
        let h = config.hidden_size;
        let q = config.q_dim();
        let kv = config.num_kv_heads * config.head_dim();
        let i = config.intermediate_size;

        Ok(Self {
            grad_input_norm: GpuBuffer::new(ctx, h)?,
            grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
            grad_gate: GpuBuffer::new(ctx, h * i)?,
            grad_up: GpuBuffer::new(ctx, h * i)?,
            grad_down: GpuBuffer::new(ctx, i * h)?,
            grad_w_q: GpuBuffer::new(ctx, q * h)?,
            grad_w_k: GpuBuffer::new(ctx, h * kv)?,
            grad_w_v: GpuBuffer::new(ctx, h * kv)?,
            grad_w_o: GpuBuffer::new(ctx, h * q)?,
        })
    }

    /// Reset both RMSNorm weight-gradient buffers to zero by uploading the
    /// leading elements of `zero_buf` (expected to be all zeros).
    ///
    /// Each buffer uses its own length. A `zero_buf` shorter than a buffer now
    /// yields an error instead of a slice-index panic.
    ///
    /// # Errors
    /// `CudaTensorError::TransferFailed` if `zero_buf` is too short or a
    /// host-to-device copy fails.
    pub fn zero_norm_grads(&mut self, zero_buf: &[f32]) -> Result<()> {
        let zero = |buf: &mut GpuBuffer<f32>, name: &str| -> Result<()> {
            let n = buf.len();
            let zeros = zero_buf.get(..n).ok_or_else(|| {
                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                    "Failed to zero {name}: zero buffer has {} elements, need {n}",
                    zero_buf.len()
                ))
            })?;
            buf.copy_from_host(zeros).map_err(|e| {
                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                    "Failed to zero {name}: {e:?}"
                ))
            })?;
            Ok(())
        };
        zero(&mut self.grad_input_norm, "grad_input_norm")?;
        zero(&mut self.grad_post_attn_norm, "grad_post_attn_norm")
    }
}
443
/// Per-layer optimizer moment buffers kept on the GPU.
///
/// For every trainable tensor there is an `m_*` buffer and a `v_*` buffer.
/// The names suggest AdamW first/second moment estimates (presumably consumed
/// by `adamw_step_cuda`; confirm at the call site). They are checkpointed
/// under the keys `"m.<name>"` / `"v.<name>"` by `download_to_host` and
/// `restore_from_host`.
#[cfg(feature = "cuda")]
pub struct GpuBlockOptimizerState {
    // Attention projections.
    m_w_q: GpuBuffer<f32>,
    v_w_q: GpuBuffer<f32>,
    m_w_k: GpuBuffer<f32>,
    v_w_k: GpuBuffer<f32>,
    m_w_v: GpuBuffer<f32>,
    v_w_v: GpuBuffer<f32>,
    m_w_o: GpuBuffer<f32>,
    v_w_o: GpuBuffer<f32>,
    // FFN projections.
    m_w_gate: GpuBuffer<f32>,
    v_w_gate: GpuBuffer<f32>,
    m_w_up: GpuBuffer<f32>,
    v_w_up: GpuBuffer<f32>,
    m_w_down: GpuBuffer<f32>,
    v_w_down: GpuBuffer<f32>,
    // RMSNorm weights.
    m_input_norm: GpuBuffer<f32>,
    v_input_norm: GpuBuffer<f32>,
    m_post_attn_norm: GpuBuffer<f32>,
    v_post_attn_norm: GpuBuffer<f32>,
}
479
#[cfg(feature = "cuda")]
impl GpuBlockOptimizerState {
    /// Copy all 18 moment buffers to the host for checkpointing.
    ///
    /// Returns `(key, values)` pairs in a fixed order: for each tensor, `m.`
    /// then `v.`. The tensors are w_q, w_k, w_v, w_o, w_gate, w_up, w_down,
    /// input_norm, post_attn_norm.
    ///
    /// # Errors
    /// Stops at the first failed device-to-host copy and reports its key.
    pub fn download_to_host(
        &self,
    ) -> crate::autograd::cuda_tensor::Result<Vec<(String, Vec<f32>)>> {
        let sources: [(&str, &GpuBuffer<f32>); 18] = [
            ("m.w_q", &self.m_w_q),
            ("v.w_q", &self.v_w_q),
            ("m.w_k", &self.m_w_k),
            ("v.w_k", &self.v_w_k),
            ("m.w_v", &self.m_w_v),
            ("v.w_v", &self.v_w_v),
            ("m.w_o", &self.m_w_o),
            ("v.w_o", &self.v_w_o),
            ("m.w_gate", &self.m_w_gate),
            ("v.w_gate", &self.v_w_gate),
            ("m.w_up", &self.m_w_up),
            ("v.w_up", &self.v_w_up),
            ("m.w_down", &self.m_w_down),
            ("v.w_down", &self.v_w_down),
            ("m.input_norm", &self.m_input_norm),
            ("v.input_norm", &self.v_input_norm),
            ("m.post_attn_norm", &self.m_post_attn_norm),
            ("v.post_attn_norm", &self.v_post_attn_norm),
        ];
        sources
            .iter()
            .map(
                |&(name, buf)| -> crate::autograd::cuda_tensor::Result<(String, Vec<f32>)> {
                    let mut host = vec![0.0f32; buf.len()];
                    buf.copy_to_host(&mut host).map_err(|e| {
                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                            "optimizer D2H {name}: {e}"
                        ))
                    })?;
                    Ok((name.to_string(), host))
                },
            )
            .collect()
    }

    /// Upload moment buffers from a checkpoint map produced by `download_to_host`.
    ///
    /// Keys that are absent, or whose value length differs from the buffer
    /// length, are skipped silently and that buffer keeps its current contents.
    ///
    /// # Errors
    /// Stops at the first failed host-to-device copy and reports its key.
    pub fn restore_from_host(
        &mut self,
        data: &std::collections::HashMap<String, Vec<f32>>,
    ) -> crate::autograd::cuda_tensor::Result<()> {
        let targets: [(&str, &mut GpuBuffer<f32>); 18] = [
            ("m.w_q", &mut self.m_w_q),
            ("v.w_q", &mut self.v_w_q),
            ("m.w_k", &mut self.m_w_k),
            ("v.w_k", &mut self.v_w_k),
            ("m.w_v", &mut self.m_w_v),
            ("v.w_v", &mut self.v_w_v),
            ("m.w_o", &mut self.m_w_o),
            ("v.w_o", &mut self.v_w_o),
            ("m.w_gate", &mut self.m_w_gate),
            ("v.w_gate", &mut self.v_w_gate),
            ("m.w_up", &mut self.m_w_up),
            ("v.w_up", &mut self.v_w_up),
            ("m.w_down", &mut self.m_w_down),
            ("v.w_down", &mut self.v_w_down),
            ("m.input_norm", &mut self.m_input_norm),
            ("v.input_norm", &mut self.v_input_norm),
            ("m.post_attn_norm", &mut self.m_post_attn_norm),
            ("v.post_attn_norm", &mut self.v_post_attn_norm),
        ];
        for (name, buf) in targets {
            let host_data = match data.get(name) {
                Some(values) if values.len() == buf.len() => values,
                _ => continue,
            };
            buf.copy_from_host(host_data).map_err(|e| {
                crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                    "optimizer H2D {name}: {e}"
                ))
            })?;
        }
        Ok(())
    }
}
563
564#[cfg(feature = "cuda")]
565impl CudaTransformerBlock {
566 pub fn new(
570 config: &TransformerConfig,
571 layer_idx: usize,
572 ctx: Arc<CudaContext>,
573 input_norm_weight: &[f32],
574 post_attn_norm_weight: &[f32],
575 w_q: &[f32],
576 w_k: &[f32],
577 w_v: &[f32],
578 w_o: &[f32],
579 w_gate: &[f32],
580 w_up: &[f32],
581 w_down: &[f32],
582 max_seq_len: usize,
583 ) -> Result<Self> {
584 let hidden_size = config.hidden_size;
585 let q_dim = config.q_dim(); let kv_hidden_size = config.num_kv_heads * config.head_dim();
587 let intermediate_size = config.intermediate_size;
588 let num_heads = config.num_attention_heads;
589
590 let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
592 let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
593 let w_q = GpuBuffer::from_host(&ctx, w_q)?;
594 let w_k = GpuBuffer::from_host(&ctx, w_k)?;
595 let w_v = GpuBuffer::from_host(&ctx, w_v)?;
596 let w_o = GpuBuffer::from_host(&ctx, w_o)?;
597 let w_gate = GpuBuffer::from_host(&ctx, w_gate)?;
598 let w_up = GpuBuffer::from_host(&ctx, w_up)?;
599 let w_down = GpuBuffer::from_host(&ctx, w_down)?;
600
601 let single_mask: Vec<f32> = (0..max_seq_len * max_seq_len)
603 .map(|idx| {
604 let row = idx / max_seq_len;
605 let col = idx % max_seq_len;
606 if col <= row {
607 0.0f32
608 } else {
609 f32::NEG_INFINITY
610 }
611 })
612 .collect();
613 let scratch = CudaBlockScratch {
615 norm1_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
616 q: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
617 k: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
618 v: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
619 attn_scores: GpuBuffer::new(&ctx, num_heads * max_seq_len * max_seq_len)?,
620 attn_out: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
621 o_proj_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
622 residual1: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
623 norm2_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
624 gate_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
625 up_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
626 swiglu_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
627 ffn_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
628 norm1_out_f16: None,
629 attn_out_f16: None,
630 norm2_out_f16: None,
631 swiglu_out_f16: None,
632 grad_hidden: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
634 grad_swiglu: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
635 attn_q_batched: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
637 attn_kv_temp: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
638 attn_kv_temp2: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
639 grad_attn_scores: GpuBuffer::new(
642 &ctx,
643 num_heads * max_seq_len * max_seq_len.max(config.head_dim()),
644 )?,
645 lora_inter: GpuBuffer::new(&ctx, 1)?,
647 lora_temp: GpuBuffer::new(&ctx, 1)?,
648 rope_positions: {
649 let positions: Vec<u32> = (0..max_seq_len as u32).collect();
650 let mut buf = GpuBuffer::new(&ctx, max_seq_len)?;
651 buf.copy_from_host(&positions)?;
652 buf
653 },
654 causal_mask_contiguous: GpuBuffer::from_host(&ctx, &single_mask)?,
655 causal_mask_cached_seq_len: max_seq_len,
656 op_us: [0u64; 16],
657 op_profiling_enabled: false,
658 };
659
660 Ok(Self {
661 config: config.clone(),
662 layer_idx,
663 input_norm_weight,
664 post_attn_norm_weight,
665 w_q,
666 w_k,
667 w_v,
668 w_o,
669 w_gate,
670 w_up,
671 w_down,
672 ctx,
673 scratch,
674 norm_zero_buf: vec![0.0f32; hidden_size],
675 q_norm_weight: None, k_norm_weight: None,
677 })
678 }
679
680 #[allow(dead_code)]
682 pub fn set_qk_norm(&mut self, q_norm: &[f32], k_norm: &[f32]) -> Result<()> {
683 self.q_norm_weight = Some(GpuBuffer::from_host(&self.ctx, q_norm)?);
684 self.k_norm_weight = Some(GpuBuffer::from_host(&self.ctx, k_norm)?);
685 Ok(())
686 }
687
688 pub fn forward(
696 &mut self,
697 input: &GpuBuffer<f32>,
698 output: &mut GpuBuffer<f32>,
699 seq_len: usize,
700 stream: &CudaStream,
701 ) -> Result<()> {
702 let hidden_size = self.config.hidden_size;
703 let q_dim = self.config.q_dim();
704 let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
705 let intermediate_size = self.config.intermediate_size;
706
707 rms_norm_forward(
709 input,
710 &self.input_norm_weight,
711 &mut self.scratch.norm1_out,
712 saturating_u32(seq_len),
713 saturating_u32(hidden_size),
714 stream,
715 )?;
716
717 gemm_forward(
720 &self.scratch.norm1_out,
721 &self.w_q,
722 &mut self.scratch.q,
723 saturating_u32(seq_len),
724 saturating_u32(hidden_size),
725 saturating_u32(q_dim),
726 stream,
727 )?;
728
729 gemm_forward(
730 &self.scratch.norm1_out,
731 &self.w_k,
732 &mut self.scratch.k,
733 saturating_u32(seq_len),
734 saturating_u32(hidden_size),
735 saturating_u32(kv_hidden_size),
736 stream,
737 )?;
738
739 gemm_forward(
740 &self.scratch.norm1_out,
741 &self.w_v,
742 &mut self.scratch.v,
743 saturating_u32(seq_len),
744 saturating_u32(hidden_size),
745 saturating_u32(kv_hidden_size),
746 stream,
747 )?;
748
749 self.compute_attention_cuda(seq_len, stream)?;
751
752 gemm_forward(
755 &self.scratch.attn_out,
756 &self.w_o,
757 &mut self.scratch.o_proj_out,
758 saturating_u32(seq_len),
759 saturating_u32(q_dim),
760 saturating_u32(hidden_size),
761 stream,
762 )?;
763
764 cuda_add(
766 input,
767 &self.scratch.o_proj_out,
768 &mut self.scratch.residual1,
769 seq_len * hidden_size,
770 stream,
771 )?;
772
773 rms_norm_forward(
775 &self.scratch.residual1,
776 &self.post_attn_norm_weight,
777 &mut self.scratch.norm2_out,
778 saturating_u32(seq_len),
779 saturating_u32(hidden_size),
780 stream,
781 )?;
782
783 gemm_forward(
785 &self.scratch.norm2_out,
786 &self.w_gate,
787 &mut self.scratch.gate_out,
788 saturating_u32(seq_len),
789 saturating_u32(hidden_size),
790 saturating_u32(intermediate_size),
791 stream,
792 )?;
793
794 gemm_forward(
795 &self.scratch.norm2_out,
796 &self.w_up,
797 &mut self.scratch.up_out,
798 saturating_u32(seq_len),
799 saturating_u32(hidden_size),
800 saturating_u32(intermediate_size),
801 stream,
802 )?;
803
804 fused_swiglu_forward(
806 &self.scratch.gate_out,
807 &self.scratch.up_out,
808 &mut self.scratch.swiglu_out,
809 saturating_u32(seq_len * intermediate_size),
810 stream,
811 )?;
812
813 gemm_forward(
815 &self.scratch.swiglu_out,
816 &self.w_down,
817 &mut self.scratch.ffn_out,
818 saturating_u32(seq_len),
819 saturating_u32(intermediate_size),
820 saturating_u32(hidden_size),
821 stream,
822 )?;
823
824 cuda_add(
826 &self.scratch.residual1,
827 &self.scratch.ffn_out,
828 output,
829 seq_len * hidden_size,
830 stream,
831 )?;
832
833 Ok(())
834 }
835
836 fn compute_attention_cuda(&mut self, seq_len: usize, stream: &CudaStream) -> Result<()> {
854 let num_heads = self.config.num_attention_heads;
855 let num_kv_heads = self.config.num_kv_heads;
856 let head_dim = self.config.head_dim();
857 let heads_per_kv = num_heads / num_kv_heads;
858 let scale = 1.0 / (head_dim as f32).sqrt();
859
860 let seq = saturating_u32(seq_len);
861 let nh = saturating_u32(num_heads);
862 let nkv = saturating_u32(num_kv_heads);
863 let hd = saturating_u32(head_dim);
864
865 self.scratch.prepare_causal_mask(seq_len, &self.ctx)?;
867
868 if let Some(ref q_norm) = self.q_norm_weight {
873 for pos in 0..seq_len {
874 let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
875 per_head_rmsnorm_forward(q_ref, q_norm, &mut self.scratch.q, nh, hd, pos, stream)?;
876 }
877 }
878 if let Some(ref k_norm) = self.k_norm_weight {
879 for pos in 0..seq_len {
880 let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
881 per_head_rmsnorm_forward(k_ref, k_norm, &mut self.scratch.k, nkv, hd, pos, stream)?;
882 }
883 }
884
885 let rope_theta = self.config.rope_theta;
888 {
889 let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
890 batched_rope_neox_forward(
891 q_ref,
892 &mut self.scratch.q,
893 &self.scratch.rope_positions,
894 nh,
895 hd,
896 seq,
897 rope_theta,
898 stream,
899 )?;
900 let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
901 batched_rope_neox_forward(
902 k_ref,
903 &mut self.scratch.k,
904 &self.scratch.rope_positions,
905 nkv,
906 hd,
907 seq,
908 rope_theta,
909 stream,
910 )?;
911 }
912
913 interleaved_to_batched_forward(
915 &self.scratch.q,
916 &mut self.scratch.attn_q_batched,
917 seq,
918 nh,
919 hd,
920 stream,
921 )?;
922
923 interleaved_to_batched_forward(
925 &self.scratch.k,
926 &mut self.scratch.attn_kv_temp,
927 seq,
928 nkv,
929 hd,
930 stream,
931 )?;
932
933 if heads_per_kv == 1 {
935 batched_transpose_forward(
937 &self.scratch.attn_kv_temp,
938 &mut self.scratch.attn_kv_temp2,
939 nh,
940 seq,
941 hd,
942 stream,
943 )?;
944 } else {
945 expand_kv_heads(
947 &self.scratch.attn_kv_temp,
948 &mut self.scratch.attn_kv_temp2,
949 num_kv_heads,
950 heads_per_kv,
951 seq_len * head_dim,
952 stream,
953 )?;
954 batched_transpose_forward(
956 &self.scratch.attn_kv_temp2,
957 &mut self.scratch.attn_kv_temp,
958 nh,
959 seq,
960 hd,
961 stream,
962 )?;
963 unsafe {
967 self.scratch
968 .attn_kv_temp2
969 .copy_from_buffer_async(&self.scratch.attn_kv_temp, stream)
970 .map_err(|e| {
971 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
972 "K^T buffer copy failed: {e}"
973 ))
974 })?;
975 }
976 }
977
978 batched_4d_gemm_forward(
983 &self.scratch.attn_q_batched,
984 &self.scratch.attn_kv_temp2,
985 &mut self.scratch.attn_scores,
986 1,
987 nh,
988 seq,
989 seq,
990 hd,
991 stream,
992 )?;
993
994 let total_scores = nh * seq * seq;
996 {
997 let scores_view = unsafe {
1001 GpuBuffer::<f32>::from_raw_parts(
1002 self.scratch.attn_scores.as_ptr(),
1003 self.scratch.attn_scores.len(),
1004 )
1005 };
1006 scale_forward(
1007 &scores_view,
1008 &mut self.scratch.attn_scores,
1009 scale,
1010 total_scores,
1011 stream,
1012 )?;
1013 leak(scores_view);
1014 }
1015
1016 {
1022 let seq_sq = (seq * seq) as usize;
1023 let mask_ptr = self.scratch.causal_mask_contiguous.as_ptr();
1024 let scores_base = self.scratch.attn_scores.as_ptr();
1025 for head in 0..nh as usize {
1026 let byte_offset = (head * seq_sq * 4) as u64; let head_ptr = scores_base + byte_offset;
1028 let mask_view = unsafe { GpuBuffer::<f32>::from_raw_parts(mask_ptr, seq_sq) };
1032 let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
1033 let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
1034 residual_add_forward(&mask_view, &scores_view, &mut out_view, seq * seq, stream)?;
1035 leak(mask_view);
1036 leak(scores_view);
1037 leak(out_view);
1038 }
1039 }
1040
1041 let total_rows = nh * seq;
1043 {
1044 let scores_view = unsafe {
1048 GpuBuffer::<f32>::from_raw_parts(
1049 self.scratch.attn_scores.as_ptr(),
1050 self.scratch.attn_scores.len(),
1051 )
1052 };
1053 batched_softmax_forward(
1054 &scores_view,
1055 &mut self.scratch.attn_scores,
1056 total_rows,
1057 seq,
1058 stream,
1059 )?;
1060 leak(scores_view);
1061 }
1062
1063 interleaved_to_batched_forward(
1065 &self.scratch.v,
1066 &mut self.scratch.attn_kv_temp,
1067 seq,
1068 nkv,
1069 hd,
1070 stream,
1071 )?;
1072
1073 if heads_per_kv == 1 {
1074 } else {
1076 expand_kv_heads(
1078 &self.scratch.attn_kv_temp,
1079 &mut self.scratch.attn_kv_temp2,
1080 num_kv_heads,
1081 heads_per_kv,
1082 seq_len * head_dim,
1083 stream,
1084 )?;
1085 unsafe {
1088 self.scratch
1089 .attn_kv_temp
1090 .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1091 .map_err(|e| {
1092 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1093 "V expanded buffer copy failed: {e}"
1094 ))
1095 })?;
1096 }
1097 }
1098
1099 batched_4d_gemm_forward(
1104 &self.scratch.attn_scores,
1105 &self.scratch.attn_kv_temp,
1106 &mut self.scratch.attn_q_batched,
1107 1,
1108 nh,
1109 seq,
1110 hd,
1111 seq,
1112 stream,
1113 )?;
1114
1115 batched_to_interleaved_forward(
1117 &self.scratch.attn_q_batched,
1118 &mut self.scratch.attn_out,
1119 seq,
1120 nh,
1121 hd,
1122 stream,
1123 )?;
1124
1125 Ok(())
1126 }
1127
1128 pub fn layer_idx(&self) -> usize {
1130 self.layer_idx
1131 }
1132
1133 pub fn config(&self) -> &TransformerConfig {
1135 &self.config
1136 }
1137
    #[provable_contracts_macros::contract("backward-pass-v1", equation = "backward")]
    /// Backpropagate `grad_output` through this layer.
    ///
    /// Reads the activations `forward` left in `self.scratch` (e.g.
    /// `swiglu_out`, `up_out`, `gate_out`, `norm2_out`, `residual1`), so it must
    /// follow a `forward` over the same `input` and `seq_len`. Several of those
    /// buffers are overwritten here, so a second `backward` needs a new `forward`.
    ///
    /// The gradient w.r.t. the layer input ends up in `grad_input`. Weight
    /// gradients go to `grad_ws`.
    ///
    /// # Errors
    /// Propagates failures from the zeroing upload and from every kernel launch.
    pub fn backward(
        &mut self,
        input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        grad_ws: &mut CudaGradWorkspace,
    ) -> Result<()> {
        let hidden_size = self.config.hidden_size;
        let intermediate_size = self.config.intermediate_size;
        // NOTE(review): hard-coded RMSNorm epsilon; `forward`'s rms_norm_forward
        // takes no eps argument, so confirm this matches the kernel's value.
        let eps = 1e-5_f32;

        // Reset both norm-weight gradients (presumably because the norm-backward
        // kernels accumulate into them — confirm).
        grad_ws.zero_norm_grads(&self.norm_zero_buf)?;

        // FFN backward: fills grad_down/grad_gate/grad_up and leaves the
        // gradient w.r.t. the post-attention norm output in scratch.norm2_out.
        self.backward_ffn(grad_output, seq_len, hidden_size, intermediate_size, stream, grad_ws)?;

        // Through the post-attention RMSNorm: grad_input = d(residual1) via the norm path.
        self.backward_post_attn_norm(grad_input, seq_len, hidden_size, eps, stream, grad_ws)?;

        // Residual skip connection (output = residual1 + ffn) contributes grad_output directly.
        cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;

        // Attention backward; consumes d(residual1) in grad_input (implementation not shown here).
        self.backward_attention(grad_input, seq_len, stream, grad_ws)?;

        // First residual + pre-attention RMSNorm backward (implementation not shown here).
        self.backward_residual_and_input_norm(
            input,
            grad_output,
            grad_input,
            seq_len,
            hidden_size,
            eps,
            stream,
            grad_ws,
        )?;

        Ok(())
    }
1205
    /// Backpropagate through the SwiGLU FFN
    /// `y = (silu(x·W_gate) ⊙ (x·W_up)) · W_down` with `x = scratch.norm2_out`.
    ///
    /// Assumes `gemm_backward_a(dC, B, dA, m, k, n)` computes `dA = dC·Bᵀ` and
    /// `gemm_backward_b(A, dC, dB, m, k, n)` computes `dB = Aᵀ·dC`, matching how
    /// `forward` calls `gemm_forward(A, B, C, m, k, n)` (TODO confirm).
    ///
    /// Forward activation buffers are reused as temporaries, so statement order
    /// matters: each buffer is read for the last time before it is overwritten.
    /// On return:
    /// - `grad_ws.grad_down`, `grad_ws.grad_gate` and `grad_ws.grad_up` hold the
    ///   weight gradients;
    /// - `scratch.norm2_out` holds `dL/dx`;
    /// - `gate_out`, `up_out`, `swiglu_out`, `ffn_out`, `grad_hidden` and
    ///   `grad_swiglu` are clobbered.
    fn backward_ffn(
        &mut self,
        grad_output: &GpuBuffer<f32>,
        seq_len: usize,
        hidden_size: usize,
        intermediate_size: usize,
        stream: &CudaStream,
        grad_ws: &mut CudaGradWorkspace,
    ) -> Result<()> {
        let n_inter = saturating_u32(seq_len * intermediate_size);
        let n_hidden = saturating_u32(seq_len * hidden_size);

        // grad_swiglu = dy · W_downᵀ
        gemm_backward_a(
            grad_output,
            &self.w_down,
            &mut self.scratch.grad_swiglu,
            saturating_u32(seq_len),
            saturating_u32(intermediate_size),
            saturating_u32(hidden_size),
            stream,
        )?;

        // grad_down = swigluᵀ · dy (last read of the forward swiglu_out).
        gemm_backward_b(
            &self.scratch.swiglu_out,
            grad_output,
            &mut grad_ws.grad_down,
            saturating_u32(seq_len),
            saturating_u32(intermediate_size),
            saturating_u32(hidden_size),
            stream,
        )?;

        // swiglu_out := d(silu(gate)) = grad_swiglu ⊙ up (last read of up_out).
        elementwise_mul_forward(
            &self.scratch.grad_swiglu,
            &self.scratch.up_out,
            &mut self.scratch.swiglu_out,
            n_inter,
            stream,
        )?;

        // up_out := d(gate) = d(silu(gate)) ⊙ silu'(gate).
        silu_backward(
            &self.scratch.gate_out,
            &self.scratch.swiglu_out,
            &mut self.scratch.up_out,
            stream,
        )?;
        // swiglu_out := silu(gate), recomputed for the d(up) term.
        silu_forward(&self.scratch.gate_out, &mut self.scratch.swiglu_out, n_inter, stream)?;

        // gate_out := d(up) = grad_swiglu ⊙ silu(gate) (last read of gate_out was above).
        elementwise_mul_forward(
            &self.scratch.grad_swiglu,
            &self.scratch.swiglu_out,
            &mut self.scratch.gate_out,
            n_inter,
            stream,
        )?;
        // grad_gate = xᵀ · d(gate)   (d(gate) is in up_out)
        gemm_backward_b(
            &self.scratch.norm2_out,
            &self.scratch.up_out,
            &mut grad_ws.grad_gate,
            saturating_u32(seq_len),
            saturating_u32(hidden_size),
            saturating_u32(intermediate_size),
            stream,
        )?;

        // grad_up = xᵀ · d(up)   (d(up) is in gate_out)
        gemm_backward_b(
            &self.scratch.norm2_out,
            &self.scratch.gate_out,
            &mut grad_ws.grad_up,
            saturating_u32(seq_len),
            saturating_u32(hidden_size),
            saturating_u32(intermediate_size),
            stream,
        )?;

        // ffn_out := d(gate) · W_gateᵀ
        gemm_backward_a(
            &self.scratch.up_out,
            &self.w_gate,
            &mut self.scratch.ffn_out,
            saturating_u32(seq_len),
            saturating_u32(hidden_size),
            saturating_u32(intermediate_size),
            stream,
        )?;

        // grad_hidden := d(up) · W_upᵀ
        gemm_backward_a(
            &self.scratch.gate_out,
            &self.w_up,
            &mut self.scratch.grad_hidden,
            saturating_u32(seq_len),
            saturating_u32(hidden_size),
            saturating_u32(intermediate_size),
            stream,
        )?;

        // norm2_out := dL/dx = sum of both branches (x no longer needed).
        residual_add_forward(
            &self.scratch.ffn_out,
            &self.scratch.grad_hidden,
            &mut self.scratch.norm2_out,
            n_hidden,
            stream,
        )?;

        Ok(())
    }
1348
1349 fn backward_post_attn_norm(
1351 &mut self,
1352 grad_input: &mut GpuBuffer<f32>,
1353 seq_len: usize,
1354 hidden_size: usize,
1355 eps: f32,
1356 stream: &CudaStream,
1357 grad_ws: &mut CudaGradWorkspace,
1358 ) -> Result<()> {
1359 unsafe {
1362 self.scratch
1363 .grad_hidden
1364 .copy_from_buffer_async(&self.scratch.norm2_out, stream)
1365 .map_err(|e| {
1366 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1367 "Backward norm D2D copy failed: {e}"
1368 ))
1369 })?;
1370 }
1371
1372 rms_norm_backward(
1373 &self.scratch.residual1,
1374 &self.post_attn_norm_weight,
1375 &self.scratch.grad_hidden,
1376 grad_input,
1377 &mut grad_ws.grad_post_attn_norm,
1378 saturating_u32(seq_len),
1379 saturating_u32(hidden_size),
1380 eps,
1381 stream,
1382 )
1383 }
1384
1385 fn backward_attention(
1398 &mut self,
1399 grad_input: &mut GpuBuffer<f32>,
1400 seq_len: usize,
1401 stream: &CudaStream,
1402 grad_ws: &mut CudaGradWorkspace,
1403 ) -> Result<()> {
1404 let hidden_size = self.config.hidden_size;
1405 let q_dim = self.config.q_dim();
1406 let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
1407 let num_heads = self.config.num_attention_heads;
1408 let num_kv_heads = self.config.num_kv_heads;
1409 let head_dim = self.config.head_dim();
1410 let heads_per_kv = num_heads / num_kv_heads;
1411 let scale = 1.0 / (head_dim as f32).sqrt();
1412
1413 let seq = saturating_u32(seq_len);
1414 let nh = saturating_u32(num_heads);
1415 let nkv = saturating_u32(num_kv_heads);
1416 let hd = saturating_u32(head_dim);
1417
1418 gemm_backward_a(
1423 grad_input,
1424 &self.w_o,
1425 &mut self.scratch.grad_hidden,
1426 seq,
1427 saturating_u32(q_dim),
1428 saturating_u32(hidden_size),
1429 stream,
1430 )?;
1431
1432 gemm_backward_b(
1434 &self.scratch.attn_out,
1435 grad_input,
1436 &mut grad_ws.grad_w_o,
1437 seq,
1438 saturating_u32(q_dim),
1439 saturating_u32(hidden_size),
1440 stream,
1441 )?;
1442
1443 interleaved_to_batched_forward(
1447 &self.scratch.grad_hidden,
1448 &mut self.scratch.attn_q_batched,
1449 seq,
1450 nh,
1451 hd,
1452 stream,
1453 )?;
1454
1455 interleaved_to_batched_forward(
1459 &self.scratch.v,
1460 &mut self.scratch.attn_kv_temp,
1461 seq,
1462 nkv,
1463 hd,
1464 stream,
1465 )?;
1466
1467 if heads_per_kv > 1 {
1469 expand_kv_heads(
1470 &self.scratch.attn_kv_temp,
1471 &mut self.scratch.attn_kv_temp2,
1472 num_kv_heads,
1473 heads_per_kv,
1474 seq_len * head_dim,
1475 stream,
1476 )?;
1477 unsafe {
1479 self.scratch
1480 .attn_kv_temp
1481 .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1482 .map_err(|e| {
1483 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1484 "Attn backward V expand D2D copy failed: {e}"
1485 ))
1486 })?;
1487 }
1488 }
1489 batched_transpose_forward(
1493 &self.scratch.attn_kv_temp,
1494 &mut self.scratch.attn_kv_temp2,
1495 nh,
1496 seq,
1497 hd,
1498 stream,
1499 )?;
1500 batched_4d_gemm_forward(
1504 &self.scratch.attn_q_batched,
1505 &self.scratch.attn_kv_temp2,
1506 &mut self.scratch.grad_attn_scores,
1507 1,
1508 nh,
1509 seq,
1510 seq,
1511 hd,
1512 stream,
1513 )?;
1514
1515 batched_transpose_forward(
1525 &self.scratch.attn_q_batched, &mut self.scratch.attn_kv_temp, nh,
1528 seq,
1529 hd,
1530 stream,
1531 )?;
1532
1533 batched_4d_gemm_forward(
1535 &self.scratch.attn_kv_temp, &self.scratch.attn_scores, &mut self.scratch.attn_kv_temp2, 1,
1539 nh,
1540 hd, seq, seq, stream,
1544 )?;
1545
1546 batched_transpose_forward(
1548 &self.scratch.attn_kv_temp2, &mut self.scratch.attn_kv_temp, nh,
1551 hd,
1552 seq,
1553 stream,
1554 )?;
1555 let total_rows = nh * seq;
1562 {
1563 let grad_scores_view = unsafe {
1567 GpuBuffer::<f32>::from_raw_parts(
1568 self.scratch.grad_attn_scores.as_ptr(),
1569 self.scratch.grad_attn_scores.len(),
1570 )
1571 };
1572 batched_softmax_backward(
1573 &self.scratch.attn_scores,
1574 &grad_scores_view,
1575 &mut self.scratch.grad_attn_scores,
1576 total_rows,
1577 seq,
1578 stream,
1579 )?;
1580 leak(grad_scores_view);
1581 }
1582 let total_scores = nh * seq * seq;
1587 {
1588 let scores_view = unsafe {
1590 GpuBuffer::<f32>::from_raw_parts(
1591 self.scratch.grad_attn_scores.as_ptr(),
1592 self.scratch.grad_attn_scores.len(),
1593 )
1594 };
1595 scale_forward(
1596 &scores_view,
1597 &mut self.scratch.grad_attn_scores,
1598 scale,
1599 total_scores,
1600 stream,
1601 )?;
1602 leak(scores_view);
1603 }
1604
1605 interleaved_to_batched_forward(
1609 &self.scratch.k,
1610 &mut self.scratch.attn_kv_temp2,
1611 seq,
1612 nkv,
1613 hd,
1614 stream,
1615 )?;
1616
1617 if heads_per_kv > 1 {
1618 unsafe {
1621 self.scratch
1622 .attn_q_batched
1623 .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1624 .map_err(|e| {
1625 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1626 "Attn backward K copy for GQA expand failed: {e}"
1627 ))
1628 })?;
1629 }
1630 expand_kv_heads(
1631 &self.scratch.attn_q_batched,
1632 &mut self.scratch.attn_kv_temp2,
1633 num_kv_heads,
1634 heads_per_kv,
1635 seq_len * head_dim,
1636 stream,
1637 )?;
1638 }
1639 batched_4d_gemm_forward(
1643 &self.scratch.grad_attn_scores,
1644 &self.scratch.attn_kv_temp2,
1645 &mut self.scratch.attn_q_batched,
1646 1,
1647 nh,
1648 seq,
1649 hd,
1650 seq,
1651 stream,
1652 )?;
1653
1654 interleaved_to_batched_forward(
1658 &self.scratch.q,
1659 &mut self.scratch.o_proj_out, seq,
1661 nh,
1662 hd,
1663 stream,
1664 )?;
1665
1666 batched_transpose_forward(
1668 &self.scratch.o_proj_out,
1669 &mut self.scratch.attn_kv_temp2, nh,
1671 seq,
1672 hd,
1673 stream,
1674 )?;
1675
1676 batched_4d_gemm_forward(
1678 &self.scratch.attn_kv_temp2,
1679 &self.scratch.grad_attn_scores,
1680 &mut self.scratch.ffn_out, 1,
1682 nh,
1683 hd,
1684 seq,
1685 seq,
1686 stream,
1687 )?;
1688
1689 batched_transpose_forward(
1691 &self.scratch.ffn_out,
1692 &mut self.scratch.attn_kv_temp2, nh,
1694 hd,
1695 seq,
1696 stream,
1697 )?;
1698
1699 if heads_per_kv > 1 {
1702 self.reduce_gqa_gradients(num_kv_heads, heads_per_kv, seq_len, head_dim, stream)?;
1703 }
1704
1705 batched_to_interleaved_forward(
1708 &self.scratch.attn_q_batched,
1709 &mut self.scratch.o_proj_out,
1710 seq,
1711 nh,
1712 hd,
1713 stream,
1714 )?;
1715
1716 batched_to_interleaved_forward(
1718 &self.scratch.attn_kv_temp2,
1719 &mut self.scratch.norm2_out,
1720 seq,
1721 nkv,
1722 hd,
1723 stream,
1724 )?;
1725
1726 batched_to_interleaved_forward(
1728 &self.scratch.attn_kv_temp,
1729 &mut self.scratch.ffn_out,
1730 seq,
1731 nkv,
1732 hd,
1733 stream,
1734 )?;
1735
1736 let rope_theta = self.config.rope_theta;
1741 {
1742 let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.o_proj_out)) };
1744 batched_rope_neox_backward(
1745 q_ref,
1746 &mut self.scratch.o_proj_out,
1747 &self.scratch.rope_positions,
1748 nh,
1749 hd,
1750 seq,
1751 rope_theta,
1752 stream,
1753 )?;
1754 let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.norm2_out)) };
1756 batched_rope_neox_backward(
1757 k_ref,
1758 &mut self.scratch.norm2_out,
1759 &self.scratch.rope_positions,
1760 nkv,
1761 hd,
1762 seq,
1763 rope_theta,
1764 stream,
1765 )?;
1766 }
1767
1768 gemm_backward_a(
1773 &self.scratch.o_proj_out, &self.w_q,
1775 &mut self.scratch.grad_hidden,
1776 seq,
1777 saturating_u32(hidden_size),
1778 saturating_u32(q_dim),
1779 stream,
1780 )?;
1781
1782 gemm_backward_a(
1787 &self.scratch.norm2_out, &self.w_k,
1789 &mut self.scratch.grad_attn_scores, seq,
1791 saturating_u32(hidden_size),
1792 saturating_u32(kv_hidden_size),
1793 stream,
1794 )?;
1795 cuda_add_inplace(
1796 &mut self.scratch.grad_hidden,
1797 &self.scratch.grad_attn_scores,
1798 seq_len * hidden_size,
1799 stream,
1800 )?;
1801
1802 gemm_backward_a(
1806 &self.scratch.ffn_out, &self.w_v,
1808 &mut self.scratch.grad_attn_scores, seq,
1810 saturating_u32(hidden_size),
1811 saturating_u32(kv_hidden_size),
1812 stream,
1813 )?;
1814 cuda_add_inplace(
1815 &mut self.scratch.grad_hidden,
1816 &self.scratch.grad_attn_scores,
1817 seq_len * hidden_size,
1818 stream,
1819 )?;
1820
1821 gemm_backward_b(
1823 &self.scratch.norm1_out,
1824 &self.scratch.o_proj_out, &mut grad_ws.grad_w_q,
1826 seq,
1827 saturating_u32(hidden_size),
1828 saturating_u32(q_dim),
1829 stream,
1830 )?;
1831
1832 gemm_backward_b(
1834 &self.scratch.norm1_out,
1835 &self.scratch.norm2_out, &mut grad_ws.grad_w_k,
1837 seq,
1838 saturating_u32(hidden_size),
1839 saturating_u32(kv_hidden_size),
1840 stream,
1841 )?;
1842
1843 gemm_backward_b(
1845 &self.scratch.norm1_out,
1846 &self.scratch.ffn_out, &mut grad_ws.grad_w_v,
1848 seq,
1849 saturating_u32(hidden_size),
1850 saturating_u32(kv_hidden_size),
1851 stream,
1852 )?;
1853
1854 unsafe {
1857 grad_input.copy_from_buffer_async(&self.scratch.grad_hidden, stream).map_err(|e| {
1858 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1859 "Attn backward grad_hidden → grad_input D2D copy failed: {e}"
1860 ))
1861 })?;
1862 }
1863
1864 Ok(())
1865 }
1866
1867 fn reduce_gqa_gradients(
1875 &mut self,
1876 num_kv_heads: usize,
1877 heads_per_kv: usize,
1878 seq_len: usize,
1879 head_dim: usize,
1880 stream: &CudaStream,
1881 ) -> Result<()> {
1882 let elems_per_head = seq_len * head_dim;
1883
1884 self.reduce_single_gqa_gradient(true, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
1886
1887 self.reduce_single_gqa_gradient(false, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
1889
1890 let kv_elems = num_kv_heads * elems_per_head;
1892 unsafe {
1894 self.scratch
1895 .attn_kv_temp2
1896 .copy_from_buffer_at_async(&self.scratch.grad_attn_scores, 0, 0, kv_elems, stream)
1897 .map_err(|e| {
1898 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1899 "GQA grad_K reduced final copy failed: {e}"
1900 ))
1901 })?;
1902 self.scratch
1903 .attn_kv_temp
1904 .copy_from_buffer_at_async(&self.scratch.ffn_out, 0, 0, kv_elems, stream)
1905 .map_err(|e| {
1906 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1907 "GQA grad_V reduced final copy failed: {e}"
1908 ))
1909 })?;
1910 }
1911 Ok(())
1912 }
1913
1914 fn reduce_single_gqa_gradient(
1920 &mut self,
1921 is_k: bool,
1922 num_kv_heads: usize,
1923 heads_per_kv: usize,
1924 elems_per_head: usize,
1925 stream: &CudaStream,
1926 ) -> Result<()> {
1927 let label = if is_k { "K" } else { "V" };
1928
1929 for kv_h in 0..num_kv_heads {
1930 let dst_offset = kv_h * elems_per_head;
1931 let first_h = kv_h * heads_per_kv;
1932 let src_offset = first_h * elems_per_head;
1933
1934 unsafe {
1937 let (dst, src) = if is_k {
1938 (&mut self.scratch.grad_attn_scores, &self.scratch.attn_kv_temp2)
1939 } else {
1940 (&mut self.scratch.ffn_out, &self.scratch.attn_kv_temp)
1941 };
1942 dst.copy_from_buffer_at_async(src, dst_offset, src_offset, elems_per_head, stream)
1943 .map_err(|e| {
1944 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1945 "GQA grad_{label} reduce base copy failed: {e}"
1946 ))
1947 })?;
1948 }
1949
1950 for rep in 1..heads_per_kv {
1952 let h = kv_h * heads_per_kv + rep;
1953 let h_offset = h * elems_per_head;
1954
1955 unsafe {
1958 let src =
1959 if is_k { &self.scratch.attn_kv_temp2 } else { &self.scratch.attn_kv_temp };
1960 self.scratch
1961 .o_proj_out
1962 .copy_from_buffer_at_async(src, 0, h_offset, elems_per_head, stream)
1963 .map_err(|e| {
1964 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1965 "GQA grad_{label} reduce head copy failed: {e}"
1966 ))
1967 })?;
1968 }
1969
1970 unsafe {
1973 let dst_buf =
1974 if is_k { &self.scratch.grad_attn_scores } else { &self.scratch.ffn_out };
1975 let dst_view = GpuBuffer::<f32>::from_raw_parts(
1976 dst_buf.as_ptr() + (dst_offset as u64 * 4),
1977 elems_per_head,
1978 );
1979 let src_view = GpuBuffer::<f32>::from_raw_parts(
1980 self.scratch.o_proj_out.as_ptr(),
1981 elems_per_head,
1982 );
1983 let mut sum_view = GpuBuffer::<f32>::from_raw_parts(
1984 self.scratch.grad_hidden.as_ptr(),
1985 elems_per_head,
1986 );
1987 residual_add_forward(
1988 &dst_view,
1989 &src_view,
1990 &mut sum_view,
1991 saturating_u32(elems_per_head),
1992 stream,
1993 )?;
1994 let dst_buf = if is_k {
1996 &mut self.scratch.grad_attn_scores
1997 } else {
1998 &mut self.scratch.ffn_out
1999 };
2000 dst_buf
2001 .copy_from_buffer_at_async(
2002 &self.scratch.grad_hidden,
2003 dst_offset,
2004 0,
2005 elems_per_head,
2006 stream,
2007 )
2008 .map_err(|e| {
2009 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2010 "GQA grad_{label} reduce sum copy failed: {e}"
2011 ))
2012 })?;
2013 leak(dst_view);
2014 leak(src_view);
2015 leak(sum_view);
2016 }
2017 }
2018 }
2019 Ok(())
2020 }
2021
2022 fn backward_residual_and_input_norm(
2024 &mut self,
2025 input: &GpuBuffer<f32>,
2026 grad_output: &GpuBuffer<f32>,
2027 grad_input: &mut GpuBuffer<f32>,
2028 seq_len: usize,
2029 hidden_size: usize,
2030 eps: f32,
2031 stream: &CudaStream,
2032 grad_ws: &mut CudaGradWorkspace,
2033 ) -> Result<()> {
2034 unsafe {
2045 self.scratch.grad_hidden.copy_from_buffer_async(grad_input, stream).map_err(|e| {
2046 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2047 "Backward residual grad_hidden D2D copy failed: {e}"
2048 ))
2049 })?;
2050 }
2051
2052 rms_norm_backward(
2054 input,
2055 &self.input_norm_weight,
2056 &self.scratch.grad_hidden,
2057 grad_input,
2058 &mut grad_ws.grad_input_norm,
2059 saturating_u32(seq_len),
2060 saturating_u32(hidden_size),
2061 eps,
2062 stream,
2063 )?;
2064
2065 cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)
2067 }
2068
2069 pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
2081 let hidden = self.config.hidden_size;
2082 let q_dim = self.config.q_dim();
2083 let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
2084 let intermediate = self.config.intermediate_size;
2085
2086 let z = |n: usize| -> Result<GpuBuffer<f32>> {
2090 Ok(GpuBuffer::from_host(&self.ctx, &vec![0.0f32; n])?)
2091 };
2092 Ok(GpuBlockOptimizerState {
2093 m_w_q: z(q_dim * hidden)?,
2094 v_w_q: z(q_dim * hidden)?,
2095 m_w_k: z(hidden * kv_hidden)?,
2096 v_w_k: z(hidden * kv_hidden)?,
2097 m_w_v: z(hidden * kv_hidden)?,
2098 v_w_v: z(hidden * kv_hidden)?,
2099 m_w_o: z(hidden * q_dim)?,
2100 v_w_o: z(hidden * q_dim)?,
2101 m_w_gate: z(hidden * intermediate)?,
2102 v_w_gate: z(hidden * intermediate)?,
2103 m_w_up: z(hidden * intermediate)?,
2104 v_w_up: z(hidden * intermediate)?,
2105 m_w_down: z(intermediate * hidden)?,
2106 v_w_down: z(intermediate * hidden)?,
2107 m_input_norm: z(hidden)?,
2108 v_input_norm: z(hidden)?,
2109 m_post_attn_norm: z(hidden)?,
2110 v_post_attn_norm: z(hidden)?,
2111 })
2112 }
2113
2114 pub fn optimizer_step(
2127 &mut self,
2128 state: &mut GpuBlockOptimizerState,
2129 step: u32,
2130 lr: f32,
2131 beta1: f32,
2132 beta2: f32,
2133 eps: f32,
2134 weight_decay: f32,
2135 stream: &CudaStream,
2136 grad_ws: &CudaGradWorkspace,
2137 ) -> Result<()> {
2138 debug_assert!(step > 0, "C-OPTSTEP-001: step must be > 0 for bias adjust");
2139
2140 let n_wq = self.w_q.len() as u32;
2143 let n_wk = self.w_k.len() as u32;
2144 let n_wv = self.w_v.len() as u32;
2145 let n_wo = self.w_o.len() as u32;
2146 let n_gate = self.w_gate.len() as u32;
2147 let n_up = self.w_up.len() as u32;
2148 let n_down = self.w_down.len() as u32;
2149 let n_inorm = self.input_norm_weight.len() as u32;
2150 let n_panorm = self.post_attn_norm_weight.len() as u32;
2151
2152 adamw_step_cuda(
2154 &mut self.w_q,
2155 &grad_ws.grad_w_q,
2156 &mut state.m_w_q,
2157 &mut state.v_w_q,
2158 lr,
2159 beta1,
2160 beta2,
2161 eps,
2162 weight_decay,
2163 step,
2164 n_wq,
2165 stream,
2166 )?;
2167 adamw_step_cuda(
2168 &mut self.w_k,
2169 &grad_ws.grad_w_k,
2170 &mut state.m_w_k,
2171 &mut state.v_w_k,
2172 lr,
2173 beta1,
2174 beta2,
2175 eps,
2176 weight_decay,
2177 step,
2178 n_wk,
2179 stream,
2180 )?;
2181 adamw_step_cuda(
2182 &mut self.w_v,
2183 &grad_ws.grad_w_v,
2184 &mut state.m_w_v,
2185 &mut state.v_w_v,
2186 lr,
2187 beta1,
2188 beta2,
2189 eps,
2190 weight_decay,
2191 step,
2192 n_wv,
2193 stream,
2194 )?;
2195 adamw_step_cuda(
2196 &mut self.w_o,
2197 &grad_ws.grad_w_o,
2198 &mut state.m_w_o,
2199 &mut state.v_w_o,
2200 lr,
2201 beta1,
2202 beta2,
2203 eps,
2204 weight_decay,
2205 step,
2206 n_wo,
2207 stream,
2208 )?;
2209
2210 adamw_step_cuda(
2212 &mut self.w_gate,
2213 &grad_ws.grad_gate,
2214 &mut state.m_w_gate,
2215 &mut state.v_w_gate,
2216 lr,
2217 beta1,
2218 beta2,
2219 eps,
2220 weight_decay,
2221 step,
2222 n_gate,
2223 stream,
2224 )?;
2225 adamw_step_cuda(
2226 &mut self.w_up,
2227 &grad_ws.grad_up,
2228 &mut state.m_w_up,
2229 &mut state.v_w_up,
2230 lr,
2231 beta1,
2232 beta2,
2233 eps,
2234 weight_decay,
2235 step,
2236 n_up,
2237 stream,
2238 )?;
2239 adamw_step_cuda(
2240 &mut self.w_down,
2241 &grad_ws.grad_down,
2242 &mut state.m_w_down,
2243 &mut state.v_w_down,
2244 lr,
2245 beta1,
2246 beta2,
2247 eps,
2248 weight_decay,
2249 step,
2250 n_down,
2251 stream,
2252 )?;
2253
2254 adamw_step_cuda(
2256 &mut self.input_norm_weight,
2257 &grad_ws.grad_input_norm,
2258 &mut state.m_input_norm,
2259 &mut state.v_input_norm,
2260 lr,
2261 beta1,
2262 beta2,
2263 eps,
2264 weight_decay,
2265 step,
2266 n_inorm,
2267 stream,
2268 )?;
2269 adamw_step_cuda(
2270 &mut self.post_attn_norm_weight,
2271 &grad_ws.grad_post_attn_norm,
2272 &mut state.m_post_attn_norm,
2273 &mut state.v_post_attn_norm,
2274 lr,
2275 beta1,
2276 beta2,
2277 eps,
2278 weight_decay,
2279 step,
2280 n_panorm,
2281 stream,
2282 )?;
2283
2284 Ok(())
2285 }
2286
2287 pub fn download_weights(&self) -> Result<BlockWeights> {
2297 let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
2298 let mut host = vec![0.0f32; buf.len()];
2299 buf.copy_to_host(&mut host).map_err(|e| {
2300 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2301 "Weight download failed: {e}"
2302 ))
2303 })?;
2304 Ok(host)
2305 };
2306
2307 Ok(BlockWeights {
2308 w_q: download(&self.w_q)?,
2309 w_k: download(&self.w_k)?,
2310 w_v: download(&self.w_v)?,
2311 w_o: download(&self.w_o)?,
2312 w_gate: download(&self.w_gate)?,
2313 w_up: download(&self.w_up)?,
2314 w_down: download(&self.w_down)?,
2315 input_norm_weight: download(&self.input_norm_weight)?,
2316 post_attn_norm_weight: download(&self.post_attn_norm_weight)?,
2317 })
2318 }
2319}
2320
2321#[cfg(feature = "cuda")]
/// Host-side copies of an FP32 transformer block's trainable tensors, as produced by
/// `download_weights`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BlockWeights {
    /// Query projection weights.
    pub w_q: Vec<f32>,
    /// Key projection weights.
    pub w_k: Vec<f32>,
    /// Value projection weights.
    pub w_v: Vec<f32>,
    /// Attention output projection weights.
    pub w_o: Vec<f32>,
    /// Feed-forward gate projection weights.
    pub w_gate: Vec<f32>,
    /// Feed-forward up projection weights.
    pub w_up: Vec<f32>,
    /// Feed-forward down projection weights.
    pub w_down: Vec<f32>,
    /// Input RMSNorm weight.
    pub input_norm_weight: Vec<f32>,
    /// Post-attention RMSNorm weight.
    pub post_attn_norm_weight: Vec<f32>,
}
2338
/// Element-wise sum of two device buffers into `output`, over the first `n` elements.
///
/// `n` is clamped to `u32::MAX` before being handed to the kernel.
#[cfg(feature = "cuda")]
fn cuda_add(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    let count = saturating_u32(n);
    residual_add_forward(a, b, output, count, stream)
}
2352
/// In-place element-wise add over the first `n` elements, computed as
/// `target = target + source` with `residual_add_forward`.
///
/// The kernel reads and writes `target`. Rust forbids holding `&target` and
/// `&mut target` at the same time, so the read side goes through a second,
/// non-owning handle to the same device allocation. That handle is wrapped in
/// `ManuallyDrop` at creation, so it never frees the memory, not even on an error path.
#[cfg(feature = "cuda")]
pub(crate) fn cuda_add_inplace(
    target: &mut GpuBuffer<f32>,
    source: &GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    // SAFETY: pointer and length describe `target`'s live allocation, which outlives
    // `target_view`; ManuallyDrop guarantees the view never releases it.
    let target_view = std::mem::ManuallyDrop::new(unsafe {
        GpuBuffer::<f32>::from_raw_parts(target.as_ptr(), target.len())
    });
    residual_add_forward(&*target_view, source, target, saturating_u32(n), stream)
}
2373
/// Element-wise product of two device buffers into `output`, over the first `n` elements.
///
/// `n` is clamped to `u32::MAX` before being handed to the kernel.
#[cfg(feature = "cuda")]
fn cuda_mul(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    let count = saturating_u32(n);
    elementwise_mul_forward(a, b, output, count, stream)
}
2387
/// Placeholder block type used when the crate is built without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub struct CudaTransformerBlock;

#[cfg(not(feature = "cuda"))]
impl CudaTransformerBlock {
    /// Layer index reported by the placeholder; there is no real layer without CUDA.
    const PLACEHOLDER_LAYER: usize = 0;

    /// Returns the layer index of this (placeholder) block, which is always 0.
    pub fn layer_idx(&self) -> usize {
        Self::PLACEHOLDER_LAYER
    }
}
2398
#[cfg(feature = "cuda")]
/// A GPU-resident transformer block in one of two weight representations.
///
/// `Fp32` blocks support a full backward pass and AdamW updates of their own weights.
/// `Nf4` blocks keep frozen NF4-quantised weights and are trained through their LoRA
/// adapters instead.
pub enum CudaBlock {
    /// Trainable full-precision block.
    Fp32(CudaTransformerBlock),
    /// Block with frozen NF4-quantised weights (optionally carrying LoRA adapters).
    Nf4(CudaNf4TransformerBlock),
}
2414
#[cfg(feature = "cuda")]
impl CudaBlock {
    /// Builds the error returned when an operation is invoked on a variant that does
    /// not support it.
    fn unsupported_op(message: &str) -> crate::autograd::cuda_tensor::CudaTensorError {
        crate::autograd::cuda_tensor::CudaTensorError::KernelError(message.to_owned())
    }

    /// Runs the forward pass of the wrapped block.
    ///
    /// # Panics
    /// If `self` is `Nf4` and `shared_scratch` is `None` (`C-SCRATCH-001`).
    pub(crate) fn forward(
        &mut self,
        input: &GpuBuffer<f32>,
        output: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        shared_scratch: Option<&mut CudaBlockScratch>,
    ) -> Result<()> {
        match self {
            Self::Fp32(block) => block.forward(input, output, seq_len, stream),
            Self::Nf4(block) => {
                let scratch =
                    shared_scratch.expect("C-SCRATCH-001: NF4 blocks require shared scratch");
                block.forward(input, output, seq_len, stream, scratch)
            }
        }
    }

    /// Index of this block within the model.
    pub fn layer_idx(&self) -> usize {
        match self {
            Self::Fp32(block) => block.layer_idx(),
            Self::Nf4(block) => block.layer_idx,
        }
    }

    /// Full-weight backward pass; only supported on `Fp32` blocks.
    pub fn backward(
        &mut self,
        input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        grad_ws: &mut CudaGradWorkspace,
    ) -> Result<()> {
        match self {
            Self::Fp32(block) => {
                block.backward(input, grad_output, grad_input, seq_len, stream, grad_ws)
            }
            Self::Nf4(_) => Err(Self::unsupported_op(
                "backward not supported on NF4 blocks (frozen weights)",
            )),
        }
    }

    /// Allocates full-weight AdamW state; only supported on `Fp32` blocks.
    pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
        match self {
            Self::Fp32(block) => block.init_optimizer_state(),
            Self::Nf4(_) => {
                Err(Self::unsupported_op("init_optimizer_state not supported on NF4 blocks"))
            }
        }
    }

    /// Downloads the full weights to the host; only supported on `Fp32` blocks.
    pub fn download_weights(&self) -> Result<BlockWeights> {
        match self {
            Self::Fp32(block) => block.download_weights(),
            Self::Nf4(_) => {
                Err(Self::unsupported_op("download_weights not supported on NF4 blocks"))
            }
        }
    }

    /// Applies one full-weight AdamW step; only supported on `Fp32` blocks.
    pub fn optimizer_step(
        &mut self,
        state: &mut GpuBlockOptimizerState,
        step: u32,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        stream: &CudaStream,
        grad_ws: &CudaGradWorkspace,
    ) -> Result<()> {
        match self {
            Self::Fp32(block) => {
                block.optimizer_step(state, step, lr, beta1, beta2, eps, weight_decay, stream, grad_ws)
            }
            Self::Nf4(_) => Err(Self::unsupported_op(
                "optimizer_step not supported on NF4 blocks (frozen weights)",
            )),
        }
    }

    /// LoRA backward pass; only supported on `Nf4` blocks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn backward_nf4(
        &self,
        layer_input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        output_scratch: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        shared_scratch: &mut CudaBlockScratch,
        grad_lora: &mut CudaLoraGradWorkspace,
    ) -> Result<()> {
        match self {
            Self::Nf4(block) => block.backward(
                layer_input,
                grad_output,
                grad_input,
                output_scratch,
                seq_len,
                stream,
                shared_scratch,
                grad_lora,
            ),
            Self::Fp32(_) => Err(Self::unsupported_op("backward_nf4 only supported on NF4 blocks")),
        }
    }

    /// Allocates LoRA AdamW state; only supported on `Nf4` blocks.
    pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
        match self {
            Self::Nf4(block) => block.init_lora_optimizer_state(),
            Self::Fp32(_) => Err(Self::unsupported_op(
                "init_lora_optimizer_state only supported on NF4 blocks",
            )),
        }
    }

    /// Applies one AdamW step to the LoRA adapters; only supported on `Nf4` blocks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn lora_optimizer_step(
        &mut self,
        state: &mut GpuLoraOptimizerState,
        step: u32,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        stream: &CudaStream,
        grad_lora: &CudaLoraGradWorkspace,
    ) -> Result<()> {
        match self {
            Self::Nf4(block) => block.lora_optimizer_step(
                state,
                step,
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
                stream,
                grad_lora,
            ),
            Self::Fp32(_) => {
                Err(Self::unsupported_op("lora_optimizer_step only supported on NF4 blocks"))
            }
        }
    }

    /// Downloads the LoRA adapter tensors to the host; only supported on `Nf4` blocks.
    pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
        match self {
            Self::Nf4(block) => block.download_lora_weights(),
            Self::Fp32(_) => {
                Err(Self::unsupported_op("download_lora_weights only supported on NF4 blocks"))
            }
        }
    }

    /// Uploads LoRA adapter tensors from the host; only supported on `Nf4` blocks.
    pub fn upload_lora_weights(
        &mut self,
        a_q: &[f32],
        b_q: &[f32],
        a_v: &[f32],
        b_v: &[f32],
    ) -> Result<()> {
        match self {
            Self::Nf4(block) => block.upload_lora_weights(a_q, b_q, a_v, b_v),
            Self::Fp32(_) => {
                Err(Self::unsupported_op("upload_lora_weights only supported on NF4 blocks"))
            }
        }
    }
}
2613
#[cfg(not(feature = "cuda"))]
/// Stand-in for the GPU block enum when built without the `cuda` feature; it only
/// wraps the placeholder [`CudaTransformerBlock`].
pub enum CudaBlock {
    /// Placeholder variant; no GPU state exists in this build.
    Fp32(CudaTransformerBlock),
}
2619
#[cfg(feature = "cuda")]
/// Transformer block whose projection weights are held NF4-quantised on the GPU
/// (frozen), with optional LoRA adapters on the Q and V projections.
pub struct CudaNf4TransformerBlock {
    config: TransformerConfig,
    layer_idx: usize,
    // RMSNorm weights, uploaded unquantised.
    input_norm_weight: GpuBuffer<f32>,
    post_attn_norm_weight: GpuBuffer<f32>,
    // Packed NF4 codes and per-block scales produced by `quantize_nf4` in `new`.
    w_q_nf4: GpuBuffer<u8>,
    w_q_scales: GpuBuffer<f32>,
    w_k_nf4: GpuBuffer<u8>,
    w_k_scales: GpuBuffer<f32>,
    w_v_nf4: GpuBuffer<u8>,
    w_v_scales: GpuBuffer<f32>,
    w_o_nf4: GpuBuffer<u8>,
    w_o_scales: GpuBuffer<f32>,
    w_gate_nf4: GpuBuffer<u8>,
    w_gate_scales: GpuBuffer<f32>,
    w_up_nf4: GpuBuffer<u8>,
    w_up_scales: GpuBuffer<f32>,
    w_down_nf4: GpuBuffer<u8>,
    w_down_scales: GpuBuffer<f32>,
    // Dequantised copies uploaded transposed by `new`: element (row, col) of the
    // row-major [n, k] dequantised matrix is stored at index `col * n + row`.
    // `set_fp16_weights` replaces these with 1-element placeholders.
    w_q_fp32: GpuBuffer<f32>,
    w_k_fp32: GpuBuffer<f32>,
    w_v_fp32: GpuBuffer<f32>,
    w_o_fp32: GpuBuffer<f32>,
    w_gate_fp32: GpuBuffer<f32>,
    w_up_fp32: GpuBuffer<f32>,
    w_down_fp32: GpuBuffer<f32>,
    // Optional LoRA adapters; the B matrices are uploaded pre-multiplied by `lora_scale`.
    lora_a_q: Option<GpuBuffer<f32>>,
    lora_b_q: Option<GpuBuffer<f32>>,
    lora_a_v: Option<GpuBuffer<f32>>,
    lora_b_v: Option<GpuBuffer<f32>>,
    lora_scale: f32,
    lora_rank: usize,
    // Optional per-head Q/K norm weights; `new` asserts each has length `head_dim`.
    q_norm_weight: Option<GpuBuffer<f32>>,
    k_norm_weight: Option<GpuBuffer<f32>>,
    // FP16 casts of the `*_fp32` buffers; `None` until `set_fp16_weights` runs.
    w_q_fp16: Option<GpuBuffer<u16>>,
    w_k_fp16: Option<GpuBuffer<u16>>,
    w_v_fp16: Option<GpuBuffer<u16>>,
    w_o_fp16: Option<GpuBuffer<u16>>,
    w_gate_fp16: Option<GpuBuffer<u16>>,
    w_up_fp16: Option<GpuBuffer<u16>>,
    w_down_fp16: Option<GpuBuffer<u16>>,
    ctx: Arc<CudaContext>,
}
2693
2694#[cfg(feature = "cuda")]
2695impl CudaNf4TransformerBlock {
2696 #[allow(clippy::too_many_arguments)]
2701 pub fn new(
2702 config: &TransformerConfig,
2703 layer_idx: usize,
2704 ctx: Arc<CudaContext>,
2705 input_norm_weight: &[f32],
2706 post_attn_norm_weight: &[f32],
2707 w_q: &[f32],
2708 w_k: &[f32],
2709 w_v: &[f32],
2710 w_o: &[f32],
2711 w_gate: &[f32],
2712 w_up: &[f32],
2713 w_down: &[f32],
2714 _max_seq_len: usize, q_lora: Option<(&[f32], &[f32])>,
2717 v_lora: Option<(&[f32], &[f32])>,
2718 lora_scale: f32,
2719 lora_rank: usize,
2720 q_norm: Option<&[f32]>,
2722 k_norm: Option<&[f32]>,
2723 ) -> Result<Self> {
2724 use trueno_gpu::kernels::{quantize_nf4, NF4_BLOCK_SIZE};
2725
2726 let hidden_size = config.hidden_size;
2727 let q_dim = config.q_dim(); let kv_hidden_size = config.num_kv_heads * config.head_dim();
2729 let intermediate_size = config.intermediate_size;
2730
2731 assert_eq!(
2736 w_q.len(),
2737 q_dim * hidden_size,
2738 "C-NF4SHAPE-001: w_q expected {}, got {} (q_dim={q_dim}, hidden={hidden_size})",
2739 q_dim * hidden_size,
2740 w_q.len()
2741 );
2742 assert_eq!(
2743 w_k.len(),
2744 kv_hidden_size * hidden_size,
2745 "C-NF4SHAPE-001: w_k expected {}, got {}",
2746 kv_hidden_size * hidden_size,
2747 w_k.len()
2748 );
2749 assert_eq!(
2750 w_v.len(),
2751 kv_hidden_size * hidden_size,
2752 "C-NF4SHAPE-001: w_v expected {}, got {}",
2753 kv_hidden_size * hidden_size,
2754 w_v.len()
2755 );
2756 assert_eq!(
2757 w_o.len(),
2758 hidden_size * q_dim,
2759 "C-NF4SHAPE-001: w_o expected {}, got {}",
2760 hidden_size * q_dim,
2761 w_o.len()
2762 );
2763 assert_eq!(
2764 w_gate.len(),
2765 intermediate_size * hidden_size,
2766 "C-NF4SHAPE-001: w_gate expected {}, got {}",
2767 intermediate_size * hidden_size,
2768 w_gate.len()
2769 );
2770 assert_eq!(
2771 w_up.len(),
2772 intermediate_size * hidden_size,
2773 "C-NF4SHAPE-001: w_up expected {}, got {}",
2774 intermediate_size * hidden_size,
2775 w_up.len()
2776 );
2777 assert_eq!(
2778 w_down.len(),
2779 hidden_size * intermediate_size,
2780 "C-NF4SHAPE-001: w_down expected {}, got {}",
2781 hidden_size * intermediate_size,
2782 w_down.len()
2783 );
2784
2785 let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
2787 let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
2788
2789 let quantize_and_upload = |weights: &[f32],
2793 total: usize|
2794 -> Result<(
2795 GpuBuffer<u8>,
2796 GpuBuffer<f32>,
2797 trueno_gpu::kernels::Nf4Quantized,
2798 )> {
2799 assert_eq!(weights.len(), total, "weight length mismatch");
2800 assert!(
2801 total.is_multiple_of(NF4_BLOCK_SIZE),
2802 "weight count {total} not divisible by NF4 block size {NF4_BLOCK_SIZE}"
2803 );
2804
2805 let q = quantize_nf4(weights, total / NF4_BLOCK_SIZE, NF4_BLOCK_SIZE);
2806 let nf4_buf = GpuBuffer::from_host(&ctx, &q.data)?;
2807 let scales_buf = GpuBuffer::from_host(&ctx, &q.scales)?;
2808 Ok((nf4_buf, scales_buf, q))
2809 };
2810
2811 let (w_q_nf4, w_q_scales, w_q_nf4_q) = quantize_and_upload(w_q, q_dim * hidden_size)?;
2813 let (w_k_nf4, w_k_scales, w_k_nf4_q) =
2814 quantize_and_upload(w_k, kv_hidden_size * hidden_size)?;
2815 let (w_v_nf4, w_v_scales, w_v_nf4_q) =
2816 quantize_and_upload(w_v, kv_hidden_size * hidden_size)?;
2817 let (w_o_nf4, w_o_scales, w_o_nf4_q) = quantize_and_upload(w_o, hidden_size * q_dim)?;
2818 let (w_gate_nf4, w_gate_scales, w_gate_nf4_q) =
2819 quantize_and_upload(w_gate, intermediate_size * hidden_size)?;
2820 let (w_up_nf4, w_up_scales, w_up_nf4_q) =
2821 quantize_and_upload(w_up, intermediate_size * hidden_size)?;
2822 let (w_down_nf4, w_down_scales, w_down_nf4_q) =
2823 quantize_and_upload(w_down, hidden_size * intermediate_size)?;
2824
2825 use trueno_gpu::kernels::dequantize_nf4;
2837 let dequant_transpose_upload = |q: &trueno_gpu::kernels::Nf4Quantized,
2838 n: usize,
2839 k: usize|
2840 -> std::result::Result<
2841 GpuBuffer<f32>,
2842 crate::autograd::cuda_tensor::CudaTensorError,
2843 > {
2844 let deq = dequantize_nf4(q); let nonzero = deq.iter().filter(|&&x| x != 0.0).count();
2846 eprintln!(
2847 "[TRACE] dequant n={n} k={k} len={} nonzero={nonzero} first5={:?}",
2848 deq.len(),
2849 &deq[..5.min(deq.len())]
2850 );
2851 assert_eq!(deq.len(), n * k, "dequant size mismatch: {} vs {}x{}", deq.len(), n, k);
2852 let mut transposed = vec![0.0f32; n * k];
2854 for row in 0..n {
2855 for col in 0..k {
2856 transposed[col * n + row] = deq[row * k + col];
2857 }
2858 }
2859 let buf = GpuBuffer::from_host(&ctx, &transposed).map_err(|e| {
2860 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2861 "dequant transpose upload: {e:?}"
2862 ))
2863 })?;
2864 let mut verify_full = vec![0.0f32; buf.len()];
2866 let verify_ok = buf.copy_to_host(&mut verify_full).is_ok();
2867 let verify5: Vec<f32> = verify_full.iter().copied().take(5).collect();
2868 let nz = verify_full.iter().filter(|&&x| x != 0.0).count();
2869 eprintln!("[TRACE] uploaded ptr={:?} len={} copy_ok={verify_ok} nonzero={nz} verify[:5]={verify5:?}", buf.as_ptr(), buf.len());
2870 Ok(buf)
2871 };
2872 let w_q_fp32 = dequant_transpose_upload(&w_q_nf4_q, q_dim, hidden_size)?;
2875 let w_k_fp32 = dequant_transpose_upload(&w_k_nf4_q, kv_hidden_size, hidden_size)?;
2876 let w_v_fp32 = dequant_transpose_upload(&w_v_nf4_q, kv_hidden_size, hidden_size)?;
2877 let w_o_fp32 = dequant_transpose_upload(&w_o_nf4_q, hidden_size, q_dim)?;
2878 let w_gate_fp32 = dequant_transpose_upload(&w_gate_nf4_q, intermediate_size, hidden_size)?;
2879 let w_up_fp32 = dequant_transpose_upload(&w_up_nf4_q, intermediate_size, hidden_size)?;
2880 let w_down_fp32 = dequant_transpose_upload(&w_down_nf4_q, hidden_size, intermediate_size)?;
2881
2882 let (lora_a_q, lora_b_q) = match q_lora {
2889 Some((a_data, b_data)) => {
2890 let a = GpuBuffer::from_host(&ctx, a_data)?;
2891 let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
2892 let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
2893 (Some(a), Some(b))
2894 }
2895 None => (None, None),
2896 };
2897 let (lora_a_v, lora_b_v) = match v_lora {
2898 Some((a_data, b_data)) => {
2899 let a = GpuBuffer::from_host(&ctx, a_data)?;
2900 let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
2901 let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
2902 (Some(a), Some(b))
2903 }
2904 None => (None, None),
2905 };
2906
2907 let q_norm_weight = match q_norm {
2909 Some(w) => {
2910 assert_eq!(
2911 w.len(),
2912 config.head_dim(),
2913 "ENT-270: q_norm weight expected [head_dim={}], got [{}]",
2914 config.head_dim(),
2915 w.len()
2916 );
2917 Some(GpuBuffer::from_host(&ctx, w)?)
2918 }
2919 None => None,
2920 };
2921 let k_norm_weight = match k_norm {
2922 Some(w) => {
2923 assert_eq!(
2924 w.len(),
2925 config.head_dim(),
2926 "ENT-270: k_norm weight expected [head_dim={}], got [{}]",
2927 config.head_dim(),
2928 w.len()
2929 );
2930 Some(GpuBuffer::from_host(&ctx, w)?)
2931 }
2932 None => None,
2933 };
2934
2935 Ok(Self {
2936 config: config.clone(),
2937 layer_idx,
2938 input_norm_weight,
2939 post_attn_norm_weight,
2940 w_q_nf4,
2941 w_q_scales,
2942 w_k_nf4,
2943 w_k_scales,
2944 w_v_nf4,
2945 w_v_scales,
2946 w_o_nf4,
2947 w_o_scales,
2948 w_gate_nf4,
2949 w_gate_scales,
2950 w_up_nf4,
2951 w_up_scales,
2952 w_down_nf4,
2953 w_down_scales,
2954 w_q_fp32,
2955 w_k_fp32,
2956 w_v_fp32,
2957 w_o_fp32,
2958 w_gate_fp32,
2959 w_up_fp32,
2960 w_down_fp32,
2961 lora_a_q,
2962 lora_b_q,
2963 lora_a_v,
2964 lora_b_v,
2965 lora_scale,
2966 lora_rank,
2967 q_norm_weight,
2968 k_norm_weight,
2969 w_q_fp16: None,
2971 w_k_fp16: None,
2972 w_v_fp16: None,
2973 w_o_fp16: None,
2974 w_gate_fp16: None,
2975 w_up_fp16: None,
2976 w_down_fp16: None,
2977 ctx,
2978 })
2979 }
2980
2981 pub fn set_fp16_weights(&mut self, stream: &CudaStream) -> Result<()> {
2983 let cast_weight = |w_fp32: &GpuBuffer<f32>, ctx: &CudaContext| -> Result<GpuBuffer<u16>> {
2984 let n = w_fp32.len();
2985 let mut w_fp16 = GpuBuffer::<u16>::new(ctx, n)?;
2986 cast_f32_to_f16_gpu(w_fp32, &mut w_fp16, n as u32, stream)?;
2987 Ok(w_fp16)
2988 };
2989
2990 self.w_q_fp16 = Some(cast_weight(&self.w_q_fp32, &self.ctx)?);
2991 self.w_k_fp16 = Some(cast_weight(&self.w_k_fp32, &self.ctx)?);
2992 self.w_v_fp16 = Some(cast_weight(&self.w_v_fp32, &self.ctx)?);
2993 self.w_o_fp16 = Some(cast_weight(&self.w_o_fp32, &self.ctx)?);
2994 self.w_gate_fp16 = Some(cast_weight(&self.w_gate_fp32, &self.ctx)?);
2995 self.w_up_fp16 = Some(cast_weight(&self.w_up_fp32, &self.ctx)?);
2996 self.w_down_fp16 = Some(cast_weight(&self.w_down_fp32, &self.ctx)?);
2997
2998 stream.synchronize().map_err(|e| {
2999 crate::autograd::cuda_tensor::CudaTensorError::KernelError(format!(
3000 "FP16 weight cast sync failed: {e:?}"
3001 ))
3002 })?;
3003 let dummy = |ctx: &CudaContext| GpuBuffer::<f32>::new(ctx, 1).unwrap();
3006 self.w_q_fp32 = dummy(&self.ctx);
3007 self.w_k_fp32 = dummy(&self.ctx);
3008 self.w_v_fp32 = dummy(&self.ctx);
3009 self.w_o_fp32 = dummy(&self.ctx);
3010 self.w_gate_fp32 = dummy(&self.ctx);
3011 self.w_up_fp32 = dummy(&self.ctx);
3012 self.w_down_fp32 = dummy(&self.ctx);
3013 eprintln!("[FP16] Weights cast + fp32 dropped (~2.6 GB freed)");
3014
3015 Ok(())
3016 }
3017
3018 #[rustfmt::skip]
3020 pub(crate) fn forward(
3021 &self,
3022 input: &GpuBuffer<f32>,
3023 output: &mut GpuBuffer<f32>,
3024 seq_len: usize,
3025 stream: &CudaStream,
3026 scratch: &mut CudaBlockScratch,
3027 ) -> Result<()> {
3028 use crate::autograd::cuda_forward::{gemm_forward, gemm_nf4_forward, gemm_nf4_tc_forward};
3029
3030 let hidden_size = self.config.hidden_size;
3031 let q_dim = self.config.q_dim();
3032 let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3033 let intermediate_size = self.config.intermediate_size;
3034
3035 scratch.prepare_causal_mask(seq_len, &self.ctx)?;
3037
3038 let _t = scratch.op_begin();
3040 rms_norm_forward(
3041 input,
3042 &self.input_norm_weight,
3043 &mut scratch.norm1_out,
3044 saturating_u32(seq_len),
3045 saturating_u32(hidden_size),
3046 stream,
3047 )?;
3048 scratch.op_end(_t, OP_RMSNORM_ATTN);
3049
3050 static USE_NF4_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3056 let nf4_gemm = *USE_NF4_GEMM.get_or_init(|| std::env::var("NF4_FUSED_GEMM").as_deref() == Ok("1"));
3057 static USE_NF4_TC_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3058 let nf4_tc_gemm = *USE_NF4_TC_GEMM.get_or_init(|| std::env::var("NF4_TC_GEMM").as_deref() == Ok("1"));
3059 static USE_FP16_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3060 let fp16_gemm = *USE_FP16_GEMM.get_or_init(|| std::env::var("FP16_GEMM").as_deref() == Ok("1"));
3061
3062 let act_n = (seq_len * hidden_size) as u32;
3064 if fp16_gemm && self.w_q_fp16.is_some() {
3065 if scratch.norm1_out_f16.is_none() {
3067 scratch.norm1_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
3068 }
3069 let f16_buf = scratch.norm1_out_f16.as_mut().unwrap();
3070 cast_f32_to_f16_gpu(&scratch.norm1_out, f16_buf, act_n, stream)?;
3071 }
3072
3073 let _t = scratch.op_begin(); if fp16_gemm && self.w_q_fp16.is_some() {
3075 let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
3076 gemm_f16_to_f32_forward(f16_act, self.w_q_fp16.as_ref().unwrap(), &mut scratch.q,
3077 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3078 } else if nf4_tc_gemm {
3079 gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
3080 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3081 } else if nf4_gemm {
3082 gemm_nf4_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
3083 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3084 } else {
3085 gemm_forward(&scratch.norm1_out, &self.w_q_fp32, &mut scratch.q,
3086 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3087 }
3088
3089 if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
3091 let s = saturating_u32(seq_len);
3092 let h = saturating_u32(hidden_size);
3093 let r = saturating_u32(self.lora_rank);
3094 let qd = saturating_u32(q_dim);
3095 gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
3097 gemm_forward(&scratch.lora_inter, b_q, &mut scratch.lora_temp, s, r, qd, stream)?;
3099 cuda_add_inplace(&mut scratch.q, &scratch.lora_temp, seq_len * q_dim, stream)?;
3101 }
3102
3103 if fp16_gemm && self.w_k_fp16.is_some() {
3104 let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
3105 gemm_f16_to_f32_forward(f16_act, self.w_k_fp16.as_ref().unwrap(), &mut scratch.k,
3106 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3107 gemm_f16_to_f32_forward(f16_act, self.w_v_fp16.as_ref().unwrap(), &mut scratch.v,
3108 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3109 } else if nf4_tc_gemm {
3110 gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_k_nf4, &self.w_k_scales, &mut scratch.k,
3112 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3113 gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_v_nf4, &self.w_v_scales, &mut scratch.v,
3114 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3115 } else if nf4_gemm {
3116 crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
3118 &scratch.norm1_out,
3119 &self.w_k_nf4, &self.w_k_scales,
3120 &self.w_v_nf4, &self.w_v_scales,
3121 &mut scratch.k, &mut scratch.v,
3122 saturating_u32(seq_len), saturating_u32(hidden_size),
3123 saturating_u32(kv_hidden_size), stream,
3124 )?;
3125 } else {
3126 gemm_forward(&scratch.norm1_out, &self.w_k_fp32, &mut scratch.k,
3127 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3128 gemm_forward(&scratch.norm1_out, &self.w_v_fp32, &mut scratch.v,
3129 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3130 }
3131
3132 scratch.op_end(_t, OP_QKV_GEMM); if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
3136 let s = saturating_u32(seq_len);
3137 let h = saturating_u32(hidden_size);
3138 let r = saturating_u32(self.lora_rank);
3139 let vd = saturating_u32(kv_hidden_size);
3140 gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
3142 gemm_forward(&scratch.lora_inter, b_v, &mut scratch.lora_temp, s, r, vd, stream)?;
3144 cuda_add_inplace(&mut scratch.v, &scratch.lora_temp, seq_len * kv_hidden_size, stream)?;
3146 }
3147
3148 let _t = scratch.op_begin();
3150 self.compute_attention_cuda(seq_len, stream, scratch)?;
3151 scratch.op_end(_t, OP_ATTENTION);
3152
3153 let _t = scratch.op_begin();
3155 if fp16_gemm && self.w_o_fp16.is_some() {
3156 if scratch.attn_out_f16.is_none() {
3157 scratch.attn_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * q_dim)?);
3158 }
3159 let f16_buf = scratch.attn_out_f16.as_mut().unwrap();
3160 cast_f32_to_f16_gpu(&scratch.attn_out, f16_buf, (seq_len * q_dim) as u32, stream)?;
3161 gemm_f16_to_f32_forward(f16_buf, self.w_o_fp16.as_ref().unwrap(), &mut scratch.o_proj_out,
3162 saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3163 } else if nf4_tc_gemm {
3164 gemm_nf4_tc_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
3165 saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3166 } else if nf4_gemm {
3167 gemm_nf4_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
3168 saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3169 } else {
3170 gemm_forward(&scratch.attn_out, &self.w_o_fp32, &mut scratch.o_proj_out,
3171 saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3172 }
3173
3174 scratch.op_end(_t, OP_O_PROJ);
3175
3176 let _t = scratch.op_begin();
3178 fused_residual_rmsnorm_forward(
3182 input,
3183 &scratch.o_proj_out,
3184 &mut scratch.residual1,
3185 &mut scratch.norm2_out,
3186 &self.post_attn_norm_weight,
3187 saturating_u32(seq_len),
3188 saturating_u32(hidden_size),
3189 stream,
3190 )?;
3191
3192 scratch.op_end(_t, OP_RMSNORM_FFN); let _t = scratch.op_begin(); if fp16_gemm && self.w_gate_fp16.is_some() {
3197 if scratch.norm2_out_f16.is_none() {
3198 scratch.norm2_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
3199 }
3200 let f16_buf = scratch.norm2_out_f16.as_mut().unwrap();
3201 cast_f32_to_f16_gpu(&scratch.norm2_out, f16_buf, (seq_len * hidden_size) as u32, stream)?;
3202 gemm_f16_to_f32_forward(f16_buf, self.w_gate_fp16.as_ref().unwrap(), &mut scratch.gate_out,
3203 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3204 gemm_f16_to_f32_forward(f16_buf, self.w_up_fp16.as_ref().unwrap(), &mut scratch.up_out,
3205 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3206 } else if nf4_tc_gemm {
3207 gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_gate_nf4, &self.w_gate_scales, &mut scratch.gate_out,
3209 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3210 gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_up_nf4, &self.w_up_scales, &mut scratch.up_out,
3211 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3212 } else if nf4_gemm {
3213 crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
3215 &scratch.norm2_out,
3216 &self.w_gate_nf4, &self.w_gate_scales,
3217 &self.w_up_nf4, &self.w_up_scales,
3218 &mut scratch.gate_out, &mut scratch.up_out,
3219 saturating_u32(seq_len), saturating_u32(hidden_size),
3220 saturating_u32(intermediate_size), stream,
3221 )?;
3222 } else {
3223 gemm_forward(&scratch.norm2_out, &self.w_gate_fp32, &mut scratch.gate_out,
3224 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3225 gemm_forward(&scratch.norm2_out, &self.w_up_fp32, &mut scratch.up_out,
3226 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3227 }
3228
3229 scratch.op_end(_t, OP_GATE_UP_GEMM);
3230
3231 let _t = scratch.op_begin();
3233 fused_swiglu_forward(&scratch.gate_out, &scratch.up_out, &mut scratch.swiglu_out,
3234 saturating_u32(seq_len * intermediate_size), stream)?;
3235 scratch.op_end(_t, OP_SILU);
3236
3237 let _t = scratch.op_begin();
3239 if fp16_gemm && self.w_down_fp16.is_some() {
3240 if scratch.swiglu_out_f16.is_none() {
3241 scratch.swiglu_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * intermediate_size)?);
3242 }
3243 let f16_buf = scratch.swiglu_out_f16.as_mut().unwrap();
3244 cast_f32_to_f16_gpu(&scratch.swiglu_out, f16_buf, (seq_len * intermediate_size) as u32, stream)?;
3245 gemm_f16_to_f32_forward(f16_buf, self.w_down_fp16.as_ref().unwrap(), &mut scratch.ffn_out,
3246 saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3247 } else if nf4_tc_gemm {
3248 gemm_nf4_tc_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
3249 saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3250 } else if nf4_gemm {
3251 gemm_nf4_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
3252 saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3253 } else {
3254 gemm_forward(&scratch.swiglu_out, &self.w_down_fp32, &mut scratch.ffn_out,
3255 saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3256 }
3257
3258 scratch.op_end(_t, OP_DOWN_GEMM);
3259
3260 cuda_add(&scratch.residual1, &scratch.ffn_out, output, seq_len * hidden_size, stream)?;
3262
3263 Ok(())
3264 }
3265
3266 pub fn layer_idx(&self) -> usize {
3268 self.layer_idx
3269 }
3270}
3271
3272#[cfg(feature = "cuda")]
3277impl CudaNf4TransformerBlock {
3278 fn compute_attention_cuda(
3279 &self,
3280 seq_len: usize,
3281 stream: &CudaStream,
3282 scratch: &mut CudaBlockScratch,
3283 ) -> Result<()> {
3284 let num_heads = self.config.num_attention_heads;
3285 let num_kv_heads = self.config.num_kv_heads;
3286 let head_dim = self.config.head_dim();
3287 let heads_per_kv = num_heads / num_kv_heads;
3288
3289 let s = saturating_u32(seq_len);
3290 let nh = saturating_u32(num_heads);
3291 let nkv = saturating_u32(num_kv_heads);
3292 let hd = saturating_u32(head_dim);
3293
3294 if let Some(ref q_norm) = self.q_norm_weight {
3299 for pos in 0..seq_len {
3300 let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
3301 per_head_rmsnorm_forward(q_ref, q_norm, &mut scratch.q, nh, hd, pos, stream)?;
3302 }
3303 }
3304 if let Some(ref k_norm) = self.k_norm_weight {
3305 for pos in 0..seq_len {
3306 let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
3307 per_head_rmsnorm_forward(k_ref, k_norm, &mut scratch.k, nkv, hd, pos, stream)?;
3308 }
3309 }
3310
3311 let rope_theta = self.config.rope_theta;
3314 {
3315 let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
3316 batched_rope_neox_forward(
3317 q_ref,
3318 &mut scratch.q,
3319 &scratch.rope_positions,
3320 nh,
3321 hd,
3322 s,
3323 rope_theta,
3324 stream,
3325 )?;
3326 let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
3327 batched_rope_neox_forward(
3328 k_ref,
3329 &mut scratch.k,
3330 &scratch.rope_positions,
3331 nkv,
3332 hd,
3333 s,
3334 rope_theta,
3335 stream,
3336 )?;
3337 }
3338
3339 interleaved_to_batched_forward(&scratch.q, &mut scratch.attn_q_batched, s, nh, hd, stream)?;
3341
3342 interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
3344
3345 if heads_per_kv > 1 {
3346 expand_kv_heads(
3347 &scratch.attn_kv_temp,
3348 &mut scratch.attn_kv_temp2,
3349 num_kv_heads,
3350 heads_per_kv,
3351 seq_len * head_dim,
3352 stream,
3353 )?;
3354 } else {
3355 unsafe {
3357 scratch
3358 .attn_kv_temp2
3359 .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
3360 .map_err(|e| {
3361 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
3362 "K copy failed: {e:?}"
3363 ))
3364 })?;
3365 }
3366 }
3367
3368 batched_transpose_forward(
3370 &scratch.attn_kv_temp2,
3371 &mut scratch.attn_kv_temp,
3372 nh,
3373 s,
3374 hd,
3375 stream,
3376 )?;
3377
3378 batched_4d_gemm_forward(
3380 &scratch.attn_q_batched,
3381 &scratch.attn_kv_temp,
3382 &mut scratch.attn_scores,
3383 1,
3384 nh,
3385 s,
3386 s,
3387 hd,
3388 stream,
3389 )?;
3390
3391 let scale_factor = 1.0 / (head_dim as f32).sqrt();
3393 let total_scores = num_heads * seq_len * seq_len;
3394 let scores_view = unsafe {
3395 GpuBuffer::<f32>::from_raw_parts(
3396 scratch.attn_scores.as_ptr(),
3397 scratch.attn_scores.len(),
3398 )
3399 };
3400 scale_forward(
3401 &scores_view,
3402 &mut scratch.attn_scores,
3403 scale_factor,
3404 saturating_u32(total_scores),
3405 stream,
3406 )?;
3407 leak(scores_view);
3408
3409 {
3415 let seq_sq = seq_len * seq_len;
3416 let mask_ptr = scratch.causal_mask_contiguous.as_ptr();
3417 let scores_base = scratch.attn_scores.as_ptr();
3418 for head in 0..num_heads {
3419 let byte_offset = (head * seq_sq * 4) as u64;
3420 let head_ptr = scores_base + byte_offset;
3421 let mask_view = unsafe { GpuBuffer::<f32>::from_raw_parts(mask_ptr, seq_sq) };
3422 let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
3423 let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
3424 residual_add_forward(
3425 &mask_view,
3426 &scores_view,
3427 &mut out_view,
3428 saturating_u32(seq_sq),
3429 stream,
3430 )?;
3431 leak(mask_view);
3432 leak(scores_view);
3433 leak(out_view);
3434 }
3435 }
3436
3437 let scores_view = unsafe {
3440 GpuBuffer::<f32>::from_raw_parts(
3441 scratch.attn_scores.as_ptr(),
3442 scratch.attn_scores.len(),
3443 )
3444 };
3445 batched_softmax_forward(
3446 &scores_view,
3447 &mut scratch.attn_scores,
3448 saturating_u32(num_heads * seq_len),
3449 s,
3450 stream,
3451 )?;
3452 leak(scores_view);
3453
3454 interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
3456
3457 if heads_per_kv > 1 {
3458 expand_kv_heads(
3459 &scratch.attn_kv_temp,
3460 &mut scratch.attn_kv_temp2,
3461 num_kv_heads,
3462 heads_per_kv,
3463 seq_len * head_dim,
3464 stream,
3465 )?;
3466 } else {
3467 unsafe {
3471 scratch
3472 .attn_kv_temp2
3473 .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
3474 .map_err(|e| {
3475 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
3476 "V copy failed: {e:?}"
3477 ))
3478 })?;
3479 }
3480 }
3481
3482 batched_4d_gemm_forward(
3484 &scratch.attn_scores,
3485 &scratch.attn_kv_temp2,
3486 &mut scratch.attn_q_batched,
3487 1,
3488 nh,
3489 s,
3490 hd,
3491 s,
3492 stream,
3493 )?;
3494
3495 batched_to_interleaved_forward(
3497 &scratch.attn_q_batched,
3498 &mut scratch.attn_out,
3499 s,
3500 nh,
3501 hd,
3502 stream,
3503 )?;
3504
3505 Ok(())
3506 }
3507}
3508
3509#[cfg(feature = "cuda")]
3525pub(crate) struct CudaLoraGradWorkspace {
3526 pub(crate) grad_lora_a_q: GpuBuffer<f32>,
3528 pub(crate) grad_lora_b_q: GpuBuffer<f32>,
3530 pub(crate) grad_lora_a_v: GpuBuffer<f32>,
3532 pub(crate) grad_lora_b_v: GpuBuffer<f32>,
3534 pub(crate) grad_input_norm: GpuBuffer<f32>,
3536 pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
3538}
3539
3540#[cfg(feature = "cuda")]
3541impl CudaLoraGradWorkspace {
3542 pub(crate) fn new(
3544 ctx: &Arc<CudaContext>,
3545 config: &super::config::TransformerConfig,
3546 lora_rank: usize,
3547 ) -> Result<Self> {
3548 let h = config.hidden_size;
3549 let q_dim = config.q_dim();
3550 let kv = config.num_kv_heads * config.head_dim();
3551 let r = lora_rank;
3552
3553 Ok(Self {
3554 grad_lora_a_q: GpuBuffer::new(ctx, h * r)?,
3555 grad_lora_b_q: GpuBuffer::new(ctx, r * q_dim)?,
3556 grad_lora_a_v: GpuBuffer::new(ctx, h * r)?,
3557 grad_lora_b_v: GpuBuffer::new(ctx, r * kv)?,
3558 grad_input_norm: GpuBuffer::new(ctx, h)?,
3559 grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
3560 })
3561 }
3562
3563 pub(crate) fn clip_gradients(&mut self, max_norm: f32, stream: &CudaStream) {
3573 let sq_a_q = squared_sum_cuda(&self.grad_lora_a_q, self.grad_lora_a_q.len() as u32, stream)
3575 .unwrap_or(0.0);
3576 let sq_b_q = squared_sum_cuda(&self.grad_lora_b_q, self.grad_lora_b_q.len() as u32, stream)
3577 .unwrap_or(0.0);
3578 let sq_a_v = squared_sum_cuda(&self.grad_lora_a_v, self.grad_lora_a_v.len() as u32, stream)
3579 .unwrap_or(0.0);
3580 let sq_b_v = squared_sum_cuda(&self.grad_lora_b_v, self.grad_lora_b_v.len() as u32, stream)
3581 .unwrap_or(0.0);
3582 let sq_in =
3583 squared_sum_cuda(&self.grad_input_norm, self.grad_input_norm.len() as u32, stream)
3584 .unwrap_or(0.0);
3585 let sq_pa = squared_sum_cuda(
3586 &self.grad_post_attn_norm,
3587 self.grad_post_attn_norm.len() as u32,
3588 stream,
3589 )
3590 .unwrap_or(0.0);
3591 let total_norm = (sq_a_q + sq_b_q + sq_a_v + sq_b_v + sq_in + sq_pa).sqrt();
3592
3593 if total_norm <= max_norm {
3594 return;
3595 }
3596
3597 let clip_scale = max_norm / (total_norm + 1e-6);
3599 let n_aq = self.grad_lora_a_q.len() as u32;
3600 let n_bq = self.grad_lora_b_q.len() as u32;
3601 let n_av = self.grad_lora_a_v.len() as u32;
3602 let n_bv = self.grad_lora_b_v.len() as u32;
3603 let n_in = self.grad_input_norm.len() as u32;
3604 let n_pa = self.grad_post_attn_norm.len() as u32;
3605 let _ = gradient_clip_cuda(&mut self.grad_lora_a_q, clip_scale, n_aq, stream);
3606 let _ = gradient_clip_cuda(&mut self.grad_lora_b_q, clip_scale, n_bq, stream);
3607 let _ = gradient_clip_cuda(&mut self.grad_lora_a_v, clip_scale, n_av, stream);
3608 let _ = gradient_clip_cuda(&mut self.grad_lora_b_v, clip_scale, n_bv, stream);
3609 let _ = gradient_clip_cuda(&mut self.grad_input_norm, clip_scale, n_in, stream);
3610 let _ = gradient_clip_cuda(&mut self.grad_post_attn_norm, clip_scale, n_pa, stream);
3611 }
3612}
3613
3614#[cfg(feature = "cuda")]
3626pub(crate) struct GpuLoraOptimizerState {
3627 m_lora_a_q: GpuBuffer<f32>,
3628 v_lora_a_q: GpuBuffer<f32>,
3629 m_lora_b_q: GpuBuffer<f32>,
3630 v_lora_b_q: GpuBuffer<f32>,
3631 m_lora_a_v: GpuBuffer<f32>,
3632 v_lora_a_v: GpuBuffer<f32>,
3633 m_lora_b_v: GpuBuffer<f32>,
3634 v_lora_b_v: GpuBuffer<f32>,
3635 m_input_norm: GpuBuffer<f32>,
3636 v_input_norm: GpuBuffer<f32>,
3637 m_post_attn_norm: GpuBuffer<f32>,
3638 v_post_attn_norm: GpuBuffer<f32>,
3639}
3640
3641#[cfg(feature = "cuda")]
3642impl GpuLoraOptimizerState {
3643 fn new(
3644 ctx: &Arc<CudaContext>,
3645 config: &super::config::TransformerConfig,
3646 lora_rank: usize,
3647 ) -> Result<Self> {
3648 let h = config.hidden_size;
3649 let q_dim = config.q_dim();
3650 let kv = config.num_kv_heads * config.head_dim();
3651 let r = lora_rank;
3652
3653 let z = |n: usize| -> Result<GpuBuffer<f32>> {
3656 Ok(GpuBuffer::from_host(ctx, &vec![0.0f32; n])?)
3657 };
3658 Ok(Self {
3659 m_lora_a_q: z(h * r)?,
3660 v_lora_a_q: z(h * r)?,
3661 m_lora_b_q: z(r * q_dim)?,
3662 v_lora_b_q: z(r * q_dim)?,
3663 m_lora_a_v: z(h * r)?,
3664 v_lora_a_v: z(h * r)?,
3665 m_lora_b_v: z(r * kv)?,
3666 v_lora_b_v: z(r * kv)?,
3667 m_input_norm: z(h)?,
3668 v_input_norm: z(h)?,
3669 m_post_attn_norm: z(h)?,
3670 v_post_attn_norm: z(h)?,
3671 })
3672 }
3673}
3674
3675#[cfg(feature = "cuda")]
3680impl CudaNf4TransformerBlock {
3681 #[allow(clippy::too_many_arguments)]
3689 pub(crate) fn backward(
3690 &self,
3691 layer_input: &GpuBuffer<f32>,
3692 grad_output: &GpuBuffer<f32>,
3693 grad_input: &mut GpuBuffer<f32>,
3694 output_scratch: &mut GpuBuffer<f32>,
3695 seq_len: usize,
3696 stream: &CudaStream,
3697 scratch: &mut CudaBlockScratch,
3698 grad_lora: &mut CudaLoraGradWorkspace,
3699 ) -> Result<()> {
3700 let hidden_size = self.config.hidden_size;
3701 let _q_dim = self.config.q_dim();
3702 let _kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3703 let intermediate_size = self.config.intermediate_size;
3704 let eps = 1e-5_f32;
3705
3706 self.forward(layer_input, output_scratch, seq_len, stream, scratch).map_err(|e| {
3709 eprintln!(
3710 "[backward] Layer {} activation-checkpoint forward FAILED: {e:?}",
3711 self.layer_idx
3712 );
3713 e
3714 })?;
3715
3716 self.backward_nf4_ffn(
3718 grad_output,
3719 seq_len,
3720 hidden_size,
3721 intermediate_size,
3722 stream,
3723 scratch,
3724 )?;
3725
3726 let _t = scratch.op_begin(); rms_norm_backward(
3729 &scratch.residual1,
3730 &self.post_attn_norm_weight,
3731 &scratch.grad_hidden, grad_input, &mut grad_lora.grad_post_attn_norm,
3734 saturating_u32(seq_len),
3735 saturating_u32(hidden_size),
3736 eps,
3737 stream,
3738 )?;
3739
3740 cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;
3743
3744 self.backward_nf4_attention(
3746 grad_input, seq_len, stream, scratch, grad_lora,
3748 )?;
3749
3750 rms_norm_backward(
3754 layer_input,
3755 &self.input_norm_weight,
3756 &scratch.grad_hidden, grad_input, &mut grad_lora.grad_input_norm,
3759 saturating_u32(seq_len),
3760 saturating_u32(hidden_size),
3761 eps,
3762 stream,
3763 )?;
3764
3765 scratch.op_end(_t, OP_NORM_BWD);
3766
3767 Ok(())
3768 }
3769
3770 fn backward_nf4_ffn(
3776 &self,
3777 grad_output: &GpuBuffer<f32>,
3778 seq_len: usize,
3779 hidden_size: usize,
3780 intermediate_size: usize,
3781 stream: &CudaStream,
3782 scratch: &mut CudaBlockScratch,
3783 ) -> Result<()> {
3784 let s = saturating_u32(seq_len);
3785 let h = saturating_u32(hidden_size);
3786 let i_size = saturating_u32(intermediate_size);
3787 let n_inter = saturating_u32(seq_len * intermediate_size);
3788
3789 static USE_NF4_TC_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3791 let nf4_tc_bwd =
3792 *USE_NF4_TC_BWD.get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
3793
3794 let _t = scratch.op_begin(); if nf4_tc_bwd {
3797 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3799 grad_output,
3800 &self.w_down_nf4,
3801 &self.w_down_scales,
3802 &mut scratch.grad_swiglu,
3803 s,
3804 h, i_size, stream,
3807 )?;
3808 } else {
3809 gemm_backward_a_fp16_dispatch(
3810 grad_output,
3811 self.w_down_fp16.as_ref(),
3812 &self.w_down_fp32,
3813 &mut scratch.grad_swiglu,
3814 s,
3815 i_size,
3816 h,
3817 stream,
3818 &self.ctx,
3819 )?;
3820 }
3821
3822 scratch.op_end(_t, OP_DOWN_BWD);
3823
3824 let _t = scratch.op_begin(); elementwise_mul_forward(
3831 &scratch.grad_swiglu,
3832 &scratch.up_out,
3833 &mut scratch.swiglu_out,
3834 n_inter,
3835 stream,
3836 )?;
3837
3838 silu_backward(
3842 &scratch.gate_out,
3843 &scratch.swiglu_out,
3844 &mut scratch.up_out, stream,
3846 )?;
3847
3848 silu_forward(&scratch.gate_out, &mut scratch.swiglu_out, n_inter, stream)?;
3851 elementwise_mul_forward(
3853 &scratch.grad_swiglu,
3854 &scratch.swiglu_out,
3855 &mut scratch.gate_out, n_inter,
3857 stream,
3858 )?;
3859
3860 scratch.op_end(_t, OP_SWIGLU_BWD);
3861
3862 let _t = scratch.op_begin(); static USE_FUSED_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3867 let fused_bwd = *USE_FUSED_BWD
3868 .get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
3869
3870 if nf4_tc_bwd {
3871 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3874 &scratch.gate_out, &self.w_up_nf4,
3876 &self.w_up_scales,
3877 &mut scratch.grad_hidden,
3878 s,
3879 i_size, h, stream,
3882 )?;
3883 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3885 &scratch.up_out, &self.w_gate_nf4,
3887 &self.w_gate_scales,
3888 &mut scratch.ffn_out,
3889 s,
3890 i_size, h, stream,
3893 )?;
3894 cuda_add_inplace(
3896 &mut scratch.grad_hidden,
3897 &scratch.ffn_out,
3898 seq_len * hidden_size,
3899 stream,
3900 )?;
3901 } else if fused_bwd {
3902 gemm_backward_a_fp16_dispatch(
3904 &scratch.gate_out,
3905 self.w_up_fp16.as_ref(),
3906 &self.w_up_fp32,
3907 &mut scratch.grad_hidden,
3908 s,
3909 h,
3910 i_size,
3911 stream,
3912 &self.ctx,
3913 )?;
3914 gemm_backward_a_fp16_dispatch_accumulate(
3916 &scratch.up_out,
3917 self.w_gate_fp16.as_ref(),
3918 &self.w_gate_fp32,
3919 &mut scratch.grad_hidden,
3920 s,
3921 h,
3922 i_size,
3923 stream,
3924 &self.ctx,
3925 )?;
3926 } else {
3927 gemm_backward_a_fp16_dispatch(
3929 &scratch.up_out,
3930 self.w_gate_fp16.as_ref(),
3931 &self.w_gate_fp32,
3932 &mut scratch.ffn_out,
3933 s,
3934 h,
3935 i_size,
3936 stream,
3937 &self.ctx,
3938 )?;
3939 gemm_backward_a_fp16_dispatch(
3940 &scratch.gate_out,
3941 self.w_up_fp16.as_ref(),
3942 &self.w_up_fp32,
3943 &mut scratch.grad_hidden,
3944 s,
3945 h,
3946 i_size,
3947 stream,
3948 &self.ctx,
3949 )?;
3950
3951 cuda_add_inplace(
3953 &mut scratch.grad_hidden,
3954 &scratch.ffn_out,
3955 seq_len * hidden_size,
3956 stream,
3957 )?;
3958 }
3959 scratch.op_end(_t, OP_GATE_UP_BWD);
3960
3961 Ok(())
3962 }
3963
3964 fn backward_nf4_attention(
3970 &self,
3971 grad_residual1: &GpuBuffer<f32>,
3972 seq_len: usize,
3973 stream: &CudaStream,
3974 scratch: &mut CudaBlockScratch,
3975 grad_lora: &mut CudaLoraGradWorkspace,
3976 ) -> Result<()> {
3977 use crate::autograd::cuda_forward::gemm_forward;
3978
3979 let hidden_size = self.config.hidden_size;
3980 let q_dim = self.config.q_dim();
3981 let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3982 let num_heads = self.config.num_attention_heads;
3983 let head_dim = self.config.head_dim();
3984
3985 let s = saturating_u32(seq_len);
3986 let h = saturating_u32(hidden_size);
3987 let qd = saturating_u32(q_dim);
3988 let kvh = saturating_u32(kv_hidden_size);
3989
3990 static USE_NF4_TC_BWD_O: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3992 let nf4_tc_bwd_o = *USE_NF4_TC_BWD_O
3993 .get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
3994
3995 let _t = scratch.op_begin(); if nf4_tc_bwd_o {
3997 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3998 grad_residual1,
3999 &self.w_o_nf4,
4000 &self.w_o_scales,
4001 &mut scratch.attn_out,
4002 s,
4003 h, qd, stream,
4006 )?;
4007 } else {
4008 gemm_backward_a_fp16_dispatch(
4009 grad_residual1,
4010 self.w_o_fp16.as_ref(),
4011 &self.w_o_fp32,
4012 &mut scratch.attn_out,
4013 s,
4014 qd,
4015 h,
4016 stream,
4017 &self.ctx,
4018 )?;
4019 }
4020
4021 self.backward_nf4_attention_mechanism(seq_len, num_heads, head_dim, stream, scratch)?;
4025
4026 let rope_theta = self.config.rope_theta;
4033 let num_kv_heads = self.config.num_kv_heads;
4034 let nkv = saturating_u32(num_kv_heads);
4035 let nh = saturating_u32(num_heads);
4036 let hd = saturating_u32(head_dim);
4037 {
4038 let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
4039 batched_rope_neox_backward(
4040 q_ref,
4041 &mut scratch.q,
4042 &scratch.rope_positions,
4043 nh,
4044 hd,
4045 s,
4046 rope_theta,
4047 stream,
4048 )?;
4049 let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
4050 batched_rope_neox_backward(
4051 k_ref,
4052 &mut scratch.k,
4053 &scratch.rope_positions,
4054 nkv,
4055 hd,
4056 s,
4057 rope_theta,
4058 stream,
4059 )?;
4060 }
4061
4062 scratch.op_end(_t, OP_ATTN_BWD);
4063
4064 static USE_NF4_TC_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4066 let nf4_tc_bwd = *USE_NF4_TC_BWD_ATTN
4067 .get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
4068
4069 let _t = scratch.op_begin(); if nf4_tc_bwd {
4071 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4073 &scratch.q,
4074 &self.w_q_nf4,
4075 &self.w_q_scales,
4076 &mut scratch.o_proj_out,
4077 s,
4078 qd, h, stream,
4081 )?;
4082 } else {
4083 gemm_backward_a_fp16_dispatch(
4084 &scratch.q,
4085 self.w_q_fp16.as_ref(),
4086 &self.w_q_fp32,
4087 &mut scratch.o_proj_out,
4088 s,
4089 h,
4090 qd,
4091 stream,
4092 &self.ctx,
4093 )?;
4094 }
4095
4096 if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
4098 let r = saturating_u32(self.lora_rank);
4099
4100 gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
4102
4103 gemm_backward_b(
4106 &scratch.lora_inter,
4107 &scratch.q,
4108 &mut grad_lora.grad_lora_b_q,
4109 s,
4110 r,
4111 qd,
4112 stream,
4113 )?;
4114
4115 gemm_backward_a(
4117 &scratch.q,
4118 b_q,
4119 &mut scratch.lora_inter, s,
4121 qd,
4122 r,
4123 stream,
4124 )?;
4125
4126 gemm_backward_b(
4128 &scratch.norm1_out,
4129 &scratch.lora_inter,
4130 &mut grad_lora.grad_lora_a_q,
4131 s,
4132 h,
4133 r,
4134 stream,
4135 )?;
4136
4137 gemm_backward_a(
4139 &scratch.lora_inter,
4140 a_q,
4141 &mut scratch.lora_temp, s,
4143 r,
4144 h,
4145 stream,
4146 )?;
4147 cuda_add_inplace(
4148 &mut scratch.o_proj_out,
4149 &scratch.lora_temp,
4150 seq_len * hidden_size,
4151 stream,
4152 )?;
4153 }
4154
4155 static USE_FUSED_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4157 let fused_bwd = *USE_FUSED_BWD_ATTN
4158 .get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
4159
4160 if nf4_tc_bwd {
4161 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4163 &scratch.k,
4164 &self.w_k_nf4,
4165 &self.w_k_scales,
4166 &mut scratch.ffn_out,
4167 s,
4168 kvh, h, stream,
4171 )?;
4172 cuda_add_inplace(
4173 &mut scratch.o_proj_out,
4174 &scratch.ffn_out,
4175 seq_len * hidden_size,
4176 stream,
4177 )?;
4178 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4179 &scratch.v,
4180 &self.w_v_nf4,
4181 &self.w_v_scales,
4182 &mut scratch.ffn_out,
4183 s,
4184 kvh, h, stream,
4187 )?;
4188 cuda_add_inplace(
4189 &mut scratch.o_proj_out,
4190 &scratch.ffn_out,
4191 seq_len * hidden_size,
4192 stream,
4193 )?;
4194 } else if fused_bwd {
4195 gemm_backward_a_fp16_dispatch_accumulate(
4197 &scratch.k,
4198 self.w_k_fp16.as_ref(),
4199 &self.w_k_fp32,
4200 &mut scratch.o_proj_out,
4201 s,
4202 h,
4203 kvh,
4204 stream,
4205 &self.ctx,
4206 )?;
4207 gemm_backward_a_fp16_dispatch_accumulate(
4208 &scratch.v,
4209 self.w_v_fp16.as_ref(),
4210 &self.w_v_fp32,
4211 &mut scratch.o_proj_out,
4212 s,
4213 h,
4214 kvh,
4215 stream,
4216 &self.ctx,
4217 )?;
4218 } else {
4219 gemm_backward_a_fp16_dispatch(
4221 &scratch.k,
4222 self.w_k_fp16.as_ref(),
4223 &self.w_k_fp32,
4224 &mut scratch.ffn_out,
4225 s,
4226 h,
4227 kvh,
4228 stream,
4229 &self.ctx,
4230 )?;
4231 cuda_add_inplace(
4232 &mut scratch.o_proj_out,
4233 &scratch.ffn_out,
4234 seq_len * hidden_size,
4235 stream,
4236 )?;
4237
4238 gemm_backward_a_fp16_dispatch(
4239 &scratch.v,
4240 self.w_v_fp16.as_ref(),
4241 &self.w_v_fp32,
4242 &mut scratch.ffn_out,
4243 s,
4244 h,
4245 kvh,
4246 stream,
4247 &self.ctx,
4248 )?;
4249 cuda_add_inplace(
4250 &mut scratch.o_proj_out,
4251 &scratch.ffn_out,
4252 seq_len * hidden_size,
4253 stream,
4254 )?;
4255 }
4256
4257 if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
4259 let r = saturating_u32(self.lora_rank);
4260
4261 gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
4263
4264 gemm_backward_b(
4266 &scratch.lora_inter,
4267 &scratch.v,
4268 &mut grad_lora.grad_lora_b_v,
4269 s,
4270 r,
4271 kvh,
4272 stream,
4273 )?;
4274
4275 gemm_backward_a(&scratch.v, b_v, &mut scratch.lora_inter, s, kvh, r, stream)?;
4277
4278 gemm_backward_b(
4280 &scratch.norm1_out,
4281 &scratch.lora_inter,
4282 &mut grad_lora.grad_lora_a_v,
4283 s,
4284 h,
4285 r,
4286 stream,
4287 )?;
4288
4289 gemm_backward_a(&scratch.lora_inter, a_v, &mut scratch.lora_temp, s, r, h, stream)?;
4291 cuda_add_inplace(
4292 &mut scratch.o_proj_out,
4293 &scratch.lora_temp,
4294 seq_len * hidden_size,
4295 stream,
4296 )?;
4297 }
4298
4299 scratch.op_end(_t, OP_QKV_BWD);
4300
4301 unsafe {
4303 scratch.grad_hidden.copy_from_buffer_async(&scratch.o_proj_out, stream).map_err(
4304 |e| {
4305 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4306 "grad_norm1 copy failed: {e}"
4307 ))
4308 },
4309 )?;
4310 }
4311
4312 Ok(())
4313 }
4314
4315 fn backward_nf4_attention_mechanism(
4337 &self,
4338 seq_len: usize,
4339 num_heads: usize,
4340 head_dim: usize,
4341 stream: &CudaStream,
4342 scratch: &mut CudaBlockScratch,
4343 ) -> Result<()> {
4344 let num_kv_heads = self.config.num_kv_heads;
4345 let heads_per_kv = num_heads / num_kv_heads;
4346 let s = saturating_u32(seq_len);
4347 let nh = saturating_u32(num_heads);
4348 let nkv = saturating_u32(num_kv_heads);
4349 let hd = saturating_u32(head_dim);
4350 let scale = 1.0 / (head_dim as f32).sqrt();
4351
4352 interleaved_to_batched_forward(
4355 &scratch.attn_out,
4356 &mut scratch.attn_q_batched, s,
4358 nh,
4359 hd,
4360 stream,
4361 )?;
4362
4363 interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
4367
4368 if heads_per_kv > 1 {
4369 expand_kv_heads(
4370 &scratch.attn_kv_temp,
4371 &mut scratch.attn_kv_temp2,
4372 num_kv_heads,
4373 heads_per_kv,
4374 seq_len * head_dim,
4375 stream,
4376 )?;
4377 } else {
4378 unsafe {
4379 scratch
4380 .attn_kv_temp2
4381 .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
4382 .map_err(|e| {
4383 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4384 "V copy for attn backward: {e:?}"
4385 ))
4386 })?;
4387 }
4388 }
4389 batched_transpose_forward(
4393 &scratch.attn_kv_temp2,
4394 &mut scratch.attn_kv_temp, nh,
4396 s,
4397 hd,
4398 stream,
4399 )?;
4400
4401 batched_4d_gemm_forward(
4404 &scratch.attn_q_batched,
4405 &scratch.attn_kv_temp,
4406 &mut scratch.grad_attn_scores,
4407 1,
4408 nh,
4409 s,
4410 s,
4411 hd,
4412 stream,
4413 )?;
4414
4415 batched_transpose_forward(
4421 &scratch.attn_q_batched,
4422 &mut scratch.attn_kv_temp, nh,
4424 s,
4425 hd,
4426 stream,
4427 )?;
4428
4429 batched_4d_gemm_forward(
4431 &scratch.attn_kv_temp,
4432 &scratch.attn_scores, &mut scratch.attn_kv_temp2, 1,
4435 nh,
4436 hd,
4437 s,
4438 s,
4439 stream,
4440 )?;
4441
4442 batched_transpose_forward(
4444 &scratch.attn_kv_temp2,
4445 &mut scratch.attn_kv_temp, nh,
4447 hd,
4448 s,
4449 stream,
4450 )?;
4451 let total_rows = nh * s;
4456 {
4457 let grad_scores_view = unsafe {
4458 GpuBuffer::<f32>::from_raw_parts(
4459 scratch.grad_attn_scores.as_ptr(),
4460 scratch.grad_attn_scores.len(),
4461 )
4462 };
4463 batched_softmax_backward(
4464 &scratch.attn_scores,
4465 &grad_scores_view,
4466 &mut scratch.grad_attn_scores,
4467 total_rows,
4468 s,
4469 stream,
4470 )?;
4471 leak(grad_scores_view);
4472 }
4473
4474 let total_scores = saturating_u32(num_heads * seq_len * seq_len);
4476 {
4477 let scores_view = unsafe {
4478 GpuBuffer::<f32>::from_raw_parts(
4479 scratch.grad_attn_scores.as_ptr(),
4480 scratch.grad_attn_scores.len(),
4481 )
4482 };
4483 scale_forward(
4484 &scores_view,
4485 &mut scratch.grad_attn_scores,
4486 scale,
4487 total_scores,
4488 stream,
4489 )?;
4490 leak(scores_view);
4491 }
4492
4493 interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp2, s, nkv, hd, stream)?;
4495
4496 if heads_per_kv > 1 {
4497 unsafe {
4498 scratch
4499 .attn_q_batched
4500 .copy_from_buffer_async(&scratch.attn_kv_temp2, stream)
4501 .map_err(|e| {
4502 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4503 "K copy for GQA expand: {e}"
4504 ))
4505 })?;
4506 }
4507 expand_kv_heads(
4508 &scratch.attn_q_batched,
4509 &mut scratch.attn_kv_temp2,
4510 num_kv_heads,
4511 heads_per_kv,
4512 seq_len * head_dim,
4513 stream,
4514 )?;
4515 }
4516 batched_4d_gemm_forward(
4520 &scratch.grad_attn_scores,
4521 &scratch.attn_kv_temp2,
4522 &mut scratch.attn_q_batched,
4523 1,
4524 nh,
4525 s,
4526 hd,
4527 s,
4528 stream,
4529 )?;
4530
4531 interleaved_to_batched_forward(
4535 &scratch.q,
4536 &mut scratch.o_proj_out, s,
4538 nh,
4539 hd,
4540 stream,
4541 )?;
4542
4543 batched_transpose_forward(
4545 &scratch.o_proj_out,
4546 &mut scratch.attn_kv_temp2, nh,
4548 s,
4549 hd,
4550 stream,
4551 )?;
4552
4553 batched_4d_gemm_forward(
4555 &scratch.attn_kv_temp2,
4556 &scratch.grad_attn_scores,
4557 &mut scratch.ffn_out, 1,
4559 nh,
4560 hd,
4561 s,
4562 s,
4563 stream,
4564 )?;
4565
4566 batched_transpose_forward(
4568 &scratch.ffn_out,
4569 &mut scratch.attn_kv_temp2, nh,
4571 hd,
4572 s,
4573 stream,
4574 )?;
4575
4576 if heads_per_kv > 1 {
4578 self.reduce_gqa_gradients_nf4(
4579 num_kv_heads,
4580 heads_per_kv,
4581 seq_len,
4582 head_dim,
4583 stream,
4584 scratch,
4585 )?;
4586 }
4587
4588 batched_to_interleaved_forward(&scratch.attn_q_batched, &mut scratch.q, s, nh, hd, stream)?;
4591
4592 batched_to_interleaved_forward(&scratch.attn_kv_temp2, &mut scratch.k, s, nkv, hd, stream)?;
4594
4595 batched_to_interleaved_forward(&scratch.attn_kv_temp, &mut scratch.v, s, nkv, hd, stream)?;
4598
4599 Ok(())
4600 }
4601
4602 fn reduce_gqa_gradients_nf4(
4606 &self,
4607 num_kv_heads: usize,
4608 heads_per_kv: usize,
4609 seq_len: usize,
4610 head_dim: usize,
4611 stream: &CudaStream,
4612 scratch: &mut CudaBlockScratch,
4613 ) -> Result<()> {
4614 let chunk = seq_len * head_dim;
4615 for g in 0..num_kv_heads {
4616 let dst_off = g * chunk;
4617 let src_off = g * heads_per_kv * chunk;
4619 {
4621 let src = unsafe {
4622 GpuBuffer::<f32>::from_raw_parts(
4623 scratch.attn_kv_temp2.as_ptr() + (src_off * 4) as u64,
4624 chunk,
4625 )
4626 };
4627 let mut dst = unsafe {
4628 GpuBuffer::<f32>::from_raw_parts(
4629 scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
4630 chunk,
4631 )
4632 };
4633 if src_off != dst_off {
4634 unsafe {
4635 dst.copy_from_buffer_async(&src, stream).map_err(|e| {
4636 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4637 "GQA K reduce copy: {e}"
4638 ))
4639 })?;
4640 }
4641 }
4642 leak(src);
4643 leak(dst);
4644 }
4645 for h in 1..heads_per_kv {
4647 let add_off = (g * heads_per_kv + h) * chunk;
4648 let src = unsafe {
4649 GpuBuffer::<f32>::from_raw_parts(
4650 scratch.attn_kv_temp2.as_ptr() + (add_off * 4) as u64,
4651 chunk,
4652 )
4653 };
4654 let mut dst = unsafe {
4655 GpuBuffer::<f32>::from_raw_parts(
4656 scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
4657 chunk,
4658 )
4659 };
4660 cuda_add_inplace(&mut dst, &src, chunk, stream)?;
4661 leak(src);
4662 leak(dst);
4663 }
4664 {
4666 let src = unsafe {
4667 GpuBuffer::<f32>::from_raw_parts(
4668 scratch.attn_kv_temp.as_ptr() + (src_off * 4) as u64,
4669 chunk,
4670 )
4671 };
4672 let mut dst = unsafe {
4673 GpuBuffer::<f32>::from_raw_parts(
4674 scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
4675 chunk,
4676 )
4677 };
4678 if src_off != dst_off {
4679 unsafe {
4680 dst.copy_from_buffer_async(&src, stream).map_err(|e| {
4681 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4682 "GQA V reduce copy: {e}"
4683 ))
4684 })?;
4685 }
4686 }
4687 leak(src);
4688 leak(dst);
4689 }
4690 for h in 1..heads_per_kv {
4691 let add_off = (g * heads_per_kv + h) * chunk;
4692 let src = unsafe {
4693 GpuBuffer::<f32>::from_raw_parts(
4694 scratch.attn_kv_temp.as_ptr() + (add_off * 4) as u64,
4695 chunk,
4696 )
4697 };
4698 let mut dst = unsafe {
4699 GpuBuffer::<f32>::from_raw_parts(
4700 scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
4701 chunk,
4702 )
4703 };
4704 cuda_add_inplace(&mut dst, &src, chunk, stream)?;
4705 leak(src);
4706 leak(dst);
4707 }
4708 }
4709 Ok(())
4710 }
4711
4712 pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
4714 GpuLoraOptimizerState::new(&self.ctx, &self.config, self.lora_rank)
4715 }
4716
4717 #[allow(clippy::too_many_arguments)]
4719 pub(crate) fn lora_optimizer_step(
4720 &mut self,
4721 state: &mut GpuLoraOptimizerState,
4722 step: u32,
4723 lr: f32,
4724 beta1: f32,
4725 beta2: f32,
4726 eps: f32,
4727 weight_decay: f32,
4728 stream: &CudaStream,
4729 grad_lora: &CudaLoraGradWorkspace,
4730 ) -> Result<()> {
4731 let h = self.config.hidden_size;
4732 let q_dim = self.config.q_dim();
4733 let kv = self.config.num_kv_heads * self.config.head_dim();
4734 let r = self.lora_rank;
4735
4736 if let Some(ref mut a_q) = self.lora_a_q {
4738 adamw_step_cuda(
4739 a_q,
4740 &grad_lora.grad_lora_a_q,
4741 &mut state.m_lora_a_q,
4742 &mut state.v_lora_a_q,
4743 lr,
4744 beta1,
4745 beta2,
4746 eps,
4747 weight_decay,
4748 step,
4749 saturating_u32(h * r),
4750 stream,
4751 )?;
4752 }
4753 if let Some(ref mut b_q) = self.lora_b_q {
4754 adamw_step_cuda(
4755 b_q,
4756 &grad_lora.grad_lora_b_q,
4757 &mut state.m_lora_b_q,
4758 &mut state.v_lora_b_q,
4759 lr,
4760 beta1,
4761 beta2,
4762 eps,
4763 weight_decay,
4764 step,
4765 saturating_u32(r * q_dim),
4766 stream,
4767 )?;
4768 }
4769 if let Some(ref mut a_v) = self.lora_a_v {
4770 adamw_step_cuda(
4771 a_v,
4772 &grad_lora.grad_lora_a_v,
4773 &mut state.m_lora_a_v,
4774 &mut state.v_lora_a_v,
4775 lr,
4776 beta1,
4777 beta2,
4778 eps,
4779 weight_decay,
4780 step,
4781 saturating_u32(h * r),
4782 stream,
4783 )?;
4784 }
4785 if let Some(ref mut b_v) = self.lora_b_v {
4786 adamw_step_cuda(
4787 b_v,
4788 &grad_lora.grad_lora_b_v,
4789 &mut state.m_lora_b_v,
4790 &mut state.v_lora_b_v,
4791 lr,
4792 beta1,
4793 beta2,
4794 eps,
4795 weight_decay,
4796 step,
4797 saturating_u32(r * kv),
4798 stream,
4799 )?;
4800 }
4801
4802 adamw_step_cuda(
4804 &mut self.input_norm_weight,
4805 &grad_lora.grad_input_norm,
4806 &mut state.m_input_norm,
4807 &mut state.v_input_norm,
4808 lr,
4809 beta1,
4810 beta2,
4811 eps,
4812 weight_decay,
4813 step,
4814 saturating_u32(h),
4815 stream,
4816 )?;
4817 adamw_step_cuda(
4818 &mut self.post_attn_norm_weight,
4819 &grad_lora.grad_post_attn_norm,
4820 &mut state.m_post_attn_norm,
4821 &mut state.v_post_attn_norm,
4822 lr,
4823 beta1,
4824 beta2,
4825 eps,
4826 weight_decay,
4827 step,
4828 saturating_u32(h),
4829 stream,
4830 )?;
4831
4832 Ok(())
4833 }
4834
4835 pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
4841 let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
4842 let mut host = vec![0.0f32; buf.len()];
4843 buf.copy_to_host(&mut host).map_err(|e| {
4844 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4845 "LoRA weight download failed: {e}"
4846 ))
4847 })?;
4848 Ok(host)
4849 };
4850 let a_q = self.lora_a_q.as_ref().map(&download).transpose()?.unwrap_or_default();
4851 let b_q = self.lora_b_q.as_ref().map(&download).transpose()?.unwrap_or_default();
4852 let a_v = self.lora_a_v.as_ref().map(&download).transpose()?.unwrap_or_default();
4853 let b_v = self.lora_b_v.as_ref().map(&download).transpose()?.unwrap_or_default();
4854 Ok((a_q, b_q, a_v, b_v))
4855 }
4856
4857 pub fn upload_lora_weights(
4863 &mut self,
4864 a_q: &[f32],
4865 b_q: &[f32],
4866 a_v: &[f32],
4867 b_v: &[f32],
4868 ) -> Result<()> {
4869 let upload = |buf: &mut GpuBuffer<f32>, data: &[f32], name: &str| -> Result<()> {
4870 if data.len() != buf.len() {
4871 return Err(crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(
4872 format!(
4873 "LoRA {name} size mismatch: checkpoint has {} but GPU buffer expects {}",
4874 data.len(),
4875 buf.len()
4876 ),
4877 ));
4878 }
4879 buf.copy_from_host(data).map_err(|e| {
4880 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4881 "LoRA {name} upload failed: {e}"
4882 ))
4883 })
4884 };
4885 if let Some(ref mut buf) = self.lora_a_q {
4886 upload(buf, a_q, "a_q")?;
4887 }
4888 if let Some(ref mut buf) = self.lora_b_q {
4889 upload(buf, b_q, "b_q")?;
4890 }
4891 if let Some(ref mut buf) = self.lora_a_v {
4892 upload(buf, a_v, "a_v")?;
4893 }
4894 if let Some(ref mut buf) = self.lora_b_v {
4895 upload(buf, b_v, "b_v")?;
4896 }
4897 Ok(())
4898 }
4899}
4900
#[cfg(test)]
mod tests {
    /// Smoke test: the CUDA block types are nameable and sized when the `cuda` feature is on.
    /// Without the feature this test is an empty pass.
    #[test]
    fn test_cuda_block_compiles() {
        #[cfg(feature = "cuda")]
        {
            use super::{CudaNf4TransformerBlock, CudaTransformerBlock};
            let _sizes = (
                std::mem::size_of::<CudaTransformerBlock>(),
                std::mem::size_of::<CudaNf4TransformerBlock>(),
            );
        }
    }
}