1#![allow(dead_code)]
16#![allow(unsafe_code)]
21
// Operation identifiers, numbered contiguously from 0 to 15.
// NOTE(review): these look like slot indices for the 16-entry
// `CudaBlockScratch::op_us` timing table (the `op` argument of `op_end`);
// confirm against the call sites, which are outside this chunk.

// Ids 0-7. NOTE(review): judging by the names, one id per forward-pass stage.
#[cfg(feature = "cuda")]
const OP_RMSNORM_ATTN: usize = 0;
#[cfg(feature = "cuda")]
const OP_QKV_GEMM: usize = 1;
#[cfg(feature = "cuda")]
const OP_ATTENTION: usize = 2;
#[cfg(feature = "cuda")]
const OP_O_PROJ: usize = 3;
#[cfg(feature = "cuda")]
const OP_RMSNORM_FFN: usize = 4;
#[cfg(feature = "cuda")]
const OP_GATE_UP_GEMM: usize = 5;
#[cfg(feature = "cuda")]
const OP_SILU: usize = 6;
#[cfg(feature = "cuda")]
const OP_DOWN_GEMM: usize = 7;

// Ids 8-15. NOTE(review): judging by the names, LoRA forward plus the
// backward-pass stages.
#[cfg(feature = "cuda")]
const OP_LORA_FWD: usize = 8;
#[cfg(feature = "cuda")]
const OP_DOWN_BWD: usize = 9;
#[cfg(feature = "cuda")]
const OP_SWIGLU_BWD: usize = 10;
#[cfg(feature = "cuda")]
const OP_GATE_UP_BWD: usize = 11;
#[cfg(feature = "cuda")]
const OP_ATTN_BWD: usize = 12;
#[cfg(feature = "cuda")]
const OP_QKV_BWD: usize = 13;
#[cfg(feature = "cuda")]
const OP_NORM_BWD: usize = 14;
#[cfg(feature = "cuda")]
const OP_LORA_BWD: usize = 15;
58
59#[cfg(feature = "cuda")]
60use std::sync::Arc;
61
62#[cfg(feature = "cuda")]
/// Converts `v` to `u32`, clamping anything above `u32::MAX` to `u32::MAX`
/// rather than letting the cast truncate.
#[inline]
fn saturating_u32(v: usize) -> u32 {
    if v > u32::MAX as usize {
        u32::MAX
    } else {
        // In range, so the cast is lossless.
        v as u32
    }
}
67
68#[cfg(feature = "cuda")]
/// Consumes `val` without running its destructor.
///
/// Used in this file to discard non-owning `GpuBuffer` views built with
/// `from_raw_parts`, so that dropping the view cannot release memory that
/// another buffer still owns.
#[inline]
fn leak<T>(val: T) {
    std::mem::forget(val);
}
74
75#[cfg(feature = "cuda")]
76use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
77
78#[cfg(feature = "cuda")]
79use crate::autograd::cuda_backward::{
80 batched_softmax_backward, gemm_backward_a, gemm_backward_a_fp16_dispatch,
81 gemm_backward_a_fp16_dispatch_accumulate, gemm_backward_b, rms_norm_backward, silu_backward,
82};
83#[cfg(feature = "cuda")]
84use crate::autograd::cuda_forward::{
85 batched_4d_gemm_forward, batched_rope_neox_backward, batched_rope_neox_forward,
86 batched_softmax_forward, batched_to_interleaved_forward, batched_transpose_forward,
87 cast_f32_to_f16_gpu, elementwise_mul_forward, expand_kv_heads, fused_residual_rmsnorm_forward,
88 fused_swiglu_forward, gemm_f16_to_f32_forward, gemm_forward, interleaved_to_batched_forward,
89 per_head_rmsnorm_forward, residual_add_forward, rms_norm_forward, rms_norm_forward_with_eps,
90 scale_forward, silu_forward,
91};
92#[cfg(feature = "cuda")]
93use crate::autograd::cuda_optim::{adamw_step_cuda, gradient_clip_cuda, squared_sum_cuda};
94#[cfg(feature = "cuda")]
95use crate::autograd::cuda_tensor::Result;
96
97#[cfg(feature = "cuda")]
98use super::config::TransformerConfig;
99
/// One transformer layer whose weights and scratch space live on the GPU.
///
/// Built by [`CudaTransformerBlock::new`], which uploads every weight slice to
/// the device and allocates the per-layer scratch buffers.
#[cfg(feature = "cuda")]
pub struct CudaTransformerBlock {
    /// Copy of the model configuration (cloned in `new`).
    config: TransformerConfig,
    /// Index of this layer, as supplied to `new`.
    layer_idx: usize,
    /// RMSNorm weight applied to the block input in `forward`.
    input_norm_weight: GpuBuffer<f32>,
    /// RMSNorm weight applied to the first residual sum in `forward`.
    post_attn_norm_weight: GpuBuffer<f32>,
    /// Query projection weight.
    w_q: GpuBuffer<f32>,
    /// Key projection weight.
    w_k: GpuBuffer<f32>,
    /// Value projection weight.
    w_v: GpuBuffer<f32>,
    /// Attention output projection weight.
    w_o: GpuBuffer<f32>,
    /// FFN gate projection weight.
    w_gate: GpuBuffer<f32>,
    /// FFN up projection weight.
    w_up: GpuBuffer<f32>,
    /// FFN down projection weight.
    w_down: GpuBuffer<f32>,
    /// CUDA context used for every allocation made by this block.
    ctx: Arc<CudaContext>,
    /// Intermediate activations and gradient scratch for this layer.
    scratch: CudaBlockScratch,
    /// Host-side zeros (`hidden_size` elements) passed to
    /// `CudaGradWorkspace::zero_norm_grads` at the start of `backward`.
    norm_zero_buf: Vec<f32>,
    /// Optional per-head RMSNorm weight for Q; `None` until `set_qk_norm`.
    q_norm_weight: Option<GpuBuffer<f32>>,
    /// Optional per-head RMSNorm weight for K; `None` until `set_qk_norm`.
    k_norm_weight: Option<GpuBuffer<f32>>,
    /// Q bias repeated `max_seq_len` times so it can be added element-wise
    /// to the projection output; `None` when no bias was supplied.
    b_q_replicated: Option<GpuBuffer<f32>>,
    /// K bias, replicated like `b_q_replicated`.
    b_k_replicated: Option<GpuBuffer<f32>>,
    /// V bias, replicated like `b_q_replicated`.
    b_v_replicated: Option<GpuBuffer<f32>>,
}
148
/// Device-side scratch buffers for one transformer block.
///
/// Sizes quoted below are element counts as allocated by the constructors
/// (`S` = `max_seq_len`). Several buffers are reused for different data
/// during the backward pass.
#[cfg(feature = "cuda")]
pub(crate) struct CudaBlockScratch {
    /// `S * hidden_size`. Output of the input RMSNorm in `forward`.
    norm1_out: GpuBuffer<f32>,
    /// `S * q_dim`. Query projection; rotated in place by RoPE.
    q: GpuBuffer<f32>,
    /// `S * kv_hidden_size`. Key projection; rotated in place by RoPE.
    k: GpuBuffer<f32>,
    /// `S * kv_hidden_size`. Value projection.
    v: GpuBuffer<f32>,
    /// `num_heads * S * S`. Attention scores, scaled, masked and
    /// soft-maxed in place.
    attn_scores: GpuBuffer<f32>,
    /// `S * q_dim`. Attention output fed to the output projection.
    attn_out: GpuBuffer<f32>,
    /// `S * hidden_size`. Output of the attention output projection.
    o_proj_out: GpuBuffer<f32>,
    /// `S * hidden_size`. Block input plus `o_proj_out`.
    residual1: GpuBuffer<f32>,
    /// `S * hidden_size`. RMSNorm of `residual1`; overwritten by
    /// `backward_ffn` with the summed gate/up input gradients.
    norm2_out: GpuBuffer<f32>,
    /// `S * intermediate_size`. Gate projection; reused in `backward_ffn`.
    gate_out: GpuBuffer<f32>,
    /// `S * intermediate_size`. Up projection; reused in `backward_ffn`.
    up_out: GpuBuffer<f32>,
    /// `S * intermediate_size`. Fused SwiGLU output; reused in `backward_ffn`.
    swiglu_out: GpuBuffer<f32>,
    /// `S * hidden_size`. Down projection output; reused in `backward_ffn`.
    ffn_out: GpuBuffer<f32>,
    /// Optional f16 copy of `norm1_out`; both visible constructors leave it `None`.
    norm1_out_f16: Option<GpuBuffer<u16>>,
    /// Optional f16 copy of `attn_out`; left `None` by the visible constructors.
    attn_out_f16: Option<GpuBuffer<u16>>,
    /// Optional f16 copy of `norm2_out`; left `None` by the visible constructors.
    norm2_out_f16: Option<GpuBuffer<u16>>,
    /// Optional f16 copy of `swiglu_out`; left `None` by the visible constructors.
    swiglu_out_f16: Option<GpuBuffer<u16>>,
    /// `S * hidden_size`. General gradient scratch for the backward pass.
    grad_hidden: GpuBuffer<f32>,
    /// `S * intermediate_size`. Written first in `backward_ffn` from
    /// `grad_output` and `w_down`.
    grad_swiglu: GpuBuffer<f32>,
    /// `num_heads * S * head_dim`. Q in batched layout; later receives the
    /// attention result before it is converted back to interleaved layout.
    attn_q_batched: GpuBuffer<f32>,
    /// `num_heads * S * head_dim`. Staging buffer for K/V layout changes.
    attn_kv_temp: GpuBuffer<f32>,
    /// `num_heads * S * head_dim`. Second K/V staging buffer.
    attn_kv_temp2: GpuBuffer<f32>,
    /// `num_heads * S * max(S, head_dim)`. Not touched by the code visible
    /// in this part of the file.
    grad_attn_scores: GpuBuffer<f32>,
    /// `max(S * lora_rank, 1)` via `CudaBlockScratch::new`; a single element
    /// when built by `CudaTransformerBlock::new`.
    lora_inter: GpuBuffer<f32>,
    /// `max(S * max(q_dim, kv_hidden_size), 1)` via `CudaBlockScratch::new`;
    /// a single element when built by `CudaTransformerBlock::new`.
    lora_temp: GpuBuffer<f32>,
    /// Position ids `0..S` used by the RoPE kernels.
    rope_positions: GpuBuffer<u32>,
    /// Row-major additive causal mask (`0.0` on and below the diagonal,
    /// `-inf` above) for `causal_mask_cached_seq_len` positions.
    causal_mask_contiguous: GpuBuffer<f32>,
    /// Sequence length the cached causal mask was built for.
    pub(crate) causal_mask_cached_seq_len: usize,
    /// Accumulated wall-clock microseconds per operation id.
    pub(crate) op_us: [u64; 16],
    /// When `false`, `op_begin` returns `None` and nothing is accumulated.
    pub(crate) op_profiling_enabled: bool,
}
228
#[cfg(feature = "cuda")]
impl CudaBlockScratch {
    /// Starts a timing sample: returns the current instant when profiling is
    /// enabled, `None` otherwise.
    #[inline]
    pub(crate) fn op_begin(&self) -> Option<std::time::Instant> {
        self.op_profiling_enabled.then(std::time::Instant::now)
    }

    /// Finishes a timing sample started by `op_begin`.
    ///
    /// Adds the elapsed microseconds to `op_us[op]`. Does nothing when
    /// `start` is `None` or `op` is outside the table.
    #[inline]
    pub(crate) fn op_end(&mut self, start: Option<std::time::Instant>, op: usize) {
        if let (Some(began), Some(slot)) = (start, self.op_us.get_mut(op)) {
            *slot += began.elapsed().as_micros() as u64;
        }
    }

    /// Returns how many rows of `hidden_size` elements fit in `norm1_out`.
    ///
    /// A `hidden_size` of zero is treated as one to avoid dividing by zero.
    pub(crate) fn max_seq_len(&self, hidden_size: usize) -> usize {
        let row_width = hidden_size.max(1);
        self.norm1_out.len() / row_width
    }

    /// Queues an asynchronous zero-fill of every f32 activation, gradient and
    /// LoRA scratch buffer on `stream`, then invalidates the cached causal mask.
    ///
    /// Failures of individual zero-fills are ignored (best effort).
    pub(crate) fn zero_forward_buffers(&mut self, stream: &CudaStream) {
        for buf in [
            &mut self.norm1_out,
            &mut self.q,
            &mut self.k,
            &mut self.v,
            &mut self.attn_scores,
            &mut self.attn_out,
            &mut self.o_proj_out,
            &mut self.residual1,
            &mut self.norm2_out,
            &mut self.gate_out,
            &mut self.up_out,
            &mut self.swiglu_out,
            &mut self.ffn_out,
            &mut self.attn_q_batched,
            &mut self.attn_kv_temp,
            &mut self.attn_kv_temp2,
            &mut self.grad_hidden,
            &mut self.grad_swiglu,
            &mut self.grad_attn_scores,
            &mut self.lora_inter,
            &mut self.lora_temp,
        ] {
            buf.zero_async(stream).ok();
        }
        // Forces `prepare_causal_mask` to rebuild the mask on its next call.
        self.causal_mask_cached_seq_len = 0;
    }

    /// Builds a row-major `n x n` additive causal mask on the host:
    /// `0.0` where `col <= row`, negative infinity elsewhere.
    fn build_causal_mask_host(n: usize) -> Vec<f32> {
        (0..n)
            .flat_map(|row| {
                (0..n).map(move |col| if col <= row { 0.0f32 } else { f32::NEG_INFINITY })
            })
            .collect()
    }

    /// Allocates all scratch buffers for sequences of up to `max_seq_len`
    /// positions.
    ///
    /// `lora_rank` sizes `lora_inter`; both LoRA buffers get at least one
    /// element. The causal mask and RoPE positions are uploaded for
    /// `max_seq_len`.
    ///
    /// # Errors
    /// Propagates any allocation or host-to-device copy failure.
    pub(crate) fn new(
        config: &TransformerConfig,
        max_seq_len: usize,
        ctx: &Arc<CudaContext>,
        lora_rank: usize,
    ) -> Result<Self> {
        let head_dim = config.head_dim();
        let num_heads = config.num_attention_heads;
        let q_dim = config.q_dim();
        let kv_hidden_size = config.num_kv_heads * head_dim;

        // Element counts shared by several buffers.
        let hidden_elems = max_seq_len * config.hidden_size;
        let q_elems = max_seq_len * q_dim;
        let kv_elems = max_seq_len * kv_hidden_size;
        let inter_elems = max_seq_len * config.intermediate_size;
        let score_elems = num_heads * max_seq_len * max_seq_len;
        let head_batch_elems = num_heads * max_seq_len * head_dim;
        let grad_score_elems = num_heads * max_seq_len * max_seq_len.max(head_dim);
        let lora_inter_elems = (max_seq_len * lora_rank).max(1);
        let lora_temp_elems = (max_seq_len * q_dim.max(kv_hidden_size)).max(1);

        // Host-side data uploaded below.
        let mask_host = Self::build_causal_mask_host(max_seq_len);
        let positions: Vec<u32> = (0..max_seq_len as u32).collect();

        Ok(Self {
            norm1_out: GpuBuffer::new(ctx, hidden_elems)?,
            q: GpuBuffer::new(ctx, q_elems)?,
            k: GpuBuffer::new(ctx, kv_elems)?,
            v: GpuBuffer::new(ctx, kv_elems)?,
            attn_scores: GpuBuffer::new(ctx, score_elems)?,
            attn_out: GpuBuffer::new(ctx, q_elems)?,
            o_proj_out: GpuBuffer::new(ctx, hidden_elems)?,
            residual1: GpuBuffer::new(ctx, hidden_elems)?,
            norm2_out: GpuBuffer::new(ctx, hidden_elems)?,
            gate_out: GpuBuffer::new(ctx, inter_elems)?,
            up_out: GpuBuffer::new(ctx, inter_elems)?,
            swiglu_out: GpuBuffer::new(ctx, inter_elems)?,
            ffn_out: GpuBuffer::new(ctx, hidden_elems)?,
            norm1_out_f16: None,
            attn_out_f16: None,
            norm2_out_f16: None,
            swiglu_out_f16: None,
            grad_hidden: GpuBuffer::new(ctx, hidden_elems)?,
            grad_swiglu: GpuBuffer::new(ctx, inter_elems)?,
            attn_q_batched: GpuBuffer::new(ctx, head_batch_elems)?,
            attn_kv_temp: GpuBuffer::new(ctx, head_batch_elems)?,
            attn_kv_temp2: GpuBuffer::new(ctx, head_batch_elems)?,
            grad_attn_scores: GpuBuffer::new(ctx, grad_score_elems)?,
            lora_inter: GpuBuffer::new(ctx, lora_inter_elems)?,
            lora_temp: GpuBuffer::new(ctx, lora_temp_elems)?,
            rope_positions: {
                let mut buf = GpuBuffer::new(ctx, max_seq_len)?;
                buf.copy_from_host(&positions)?;
                buf
            },
            causal_mask_contiguous: GpuBuffer::from_host(ctx, &mask_host)?,
            causal_mask_cached_seq_len: max_seq_len,
            op_us: [0u64; 16],
            op_profiling_enabled: false,
        })
    }

    /// Ensures `causal_mask_contiguous` holds a contiguous `seq_len x seq_len`
    /// mask.
    ///
    /// Returns immediately when the cached mask was already built for
    /// `seq_len`; otherwise builds a new mask on the host and replaces the
    /// device buffer.
    ///
    /// # Errors
    /// Propagates a failed upload; the cached length is only updated on success.
    pub(crate) fn prepare_causal_mask(
        &mut self,
        seq_len: usize,
        ctx: &Arc<CudaContext>,
    ) -> crate::autograd::cuda_tensor::Result<()> {
        if self.causal_mask_cached_seq_len != seq_len {
            let mask_host = Self::build_causal_mask_host(seq_len);
            self.causal_mask_contiguous = GpuBuffer::from_host(ctx, &mask_host)?;
            self.causal_mask_cached_seq_len = seq_len;
        }
        Ok(())
    }
}
373
/// Device buffers that receive the weight gradients of one transformer block
/// during `CudaTransformerBlock::backward`.
///
/// Element counts below are those allocated by `CudaGradWorkspace::new`.
#[cfg(feature = "cuda")]
pub struct CudaGradWorkspace {
    /// `hidden_size`. Gradient of the input RMSNorm weight.
    pub(crate) grad_input_norm: GpuBuffer<f32>,
    /// `hidden_size`. Gradient of the post-attention RMSNorm weight.
    pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
    /// `hidden_size * intermediate_size`. Gradient of the gate projection weight.
    pub(crate) grad_gate: GpuBuffer<f32>,
    /// `hidden_size * intermediate_size`. Gradient of the up projection weight.
    pub(crate) grad_up: GpuBuffer<f32>,
    /// `intermediate_size * hidden_size`. Gradient of the down projection weight.
    pub(crate) grad_down: GpuBuffer<f32>,
    /// `q_dim * hidden_size`. Gradient of the query projection weight.
    pub(crate) grad_w_q: GpuBuffer<f32>,
    /// `hidden_size * kv_hidden_size`. Gradient of the key projection weight.
    pub(crate) grad_w_k: GpuBuffer<f32>,
    /// `hidden_size * kv_hidden_size`. Gradient of the value projection weight.
    pub(crate) grad_w_v: GpuBuffer<f32>,
    /// `hidden_size * q_dim`. Gradient of the output projection weight.
    pub(crate) grad_w_o: GpuBuffer<f32>,
}
408
#[cfg(feature = "cuda")]
impl CudaGradWorkspace {
    /// Allocates one gradient buffer per trainable weight of a block, sized
    /// from `config`.
    ///
    /// # Errors
    /// Propagates any device allocation failure.
    pub fn new(ctx: &Arc<CudaContext>, config: &TransformerConfig) -> Result<Self> {
        let h = config.hidden_size;
        let q = config.q_dim();
        let kv = config.num_kv_heads * config.head_dim();
        let i = config.intermediate_size;

        Ok(Self {
            grad_input_norm: GpuBuffer::new(ctx, h)?,
            grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
            grad_gate: GpuBuffer::new(ctx, h * i)?,
            grad_up: GpuBuffer::new(ctx, h * i)?,
            grad_down: GpuBuffer::new(ctx, i * h)?,
            grad_w_q: GpuBuffer::new(ctx, q * h)?,
            grad_w_k: GpuBuffer::new(ctx, h * kv)?,
            grad_w_v: GpuBuffer::new(ctx, h * kv)?,
            grad_w_o: GpuBuffer::new(ctx, h * q)?,
        })
    }

    /// Overwrites both norm-weight gradient buffers with the leading elements
    /// of `zero_buf`.
    ///
    /// `zero_buf` is copied as-is; the caller is responsible for it actually
    /// containing zeros.
    ///
    /// # Errors
    /// Returns `TransferFailed` when `zero_buf` is shorter than a gradient
    /// buffer (previously this panicked on the slice index) or when a
    /// host-to-device copy fails.
    pub fn zero_norm_grads(&mut self, zero_buf: &[f32]) -> Result<()> {
        Self::fill_from_host(&mut self.grad_input_norm, zero_buf, "grad_input_norm")?;
        Self::fill_from_host(&mut self.grad_post_attn_norm, zero_buf, "grad_post_attn_norm")
    }

    /// Copies the first `buf.len()` elements of `host` into `buf`, labelling
    /// any error with `name`.
    fn fill_from_host(buf: &mut GpuBuffer<f32>, host: &[f32], name: &str) -> Result<()> {
        use crate::autograd::cuda_tensor::CudaTensorError;

        // Each buffer is sized by its own length, and a short host slice
        // becomes an error rather than a panic.
        let needed = buf.len();
        let src = host.get(..needed).ok_or_else(|| {
            CudaTensorError::TransferFailed(format!(
                "Failed to zero {name}: host buffer has {} elements, need {needed}",
                host.len()
            ))
        })?;
        buf.copy_from_host(src).map_err(|e| {
            CudaTensorError::TransferFailed(format!("Failed to zero {name}: {e:?}"))
        })?;
        Ok(())
    }
}
455
/// Per-block optimizer state held on the GPU: one `m_*` and one `v_*` buffer
/// for every trainable weight.
///
/// NOTE(review): `m`/`v` are presumably the AdamW first and second moment
/// estimates (the file imports `adamw_step_cuda`) — confirm where the state
/// is consumed. The buffers are serialised under the names `"m.<weight>"` /
/// `"v.<weight>"` by `download_to_host`.
#[cfg(feature = "cuda")]
pub struct GpuBlockOptimizerState {
    // Attention projections.
    m_w_q: GpuBuffer<f32>,
    v_w_q: GpuBuffer<f32>,
    m_w_k: GpuBuffer<f32>,
    v_w_k: GpuBuffer<f32>,
    m_w_v: GpuBuffer<f32>,
    v_w_v: GpuBuffer<f32>,
    m_w_o: GpuBuffer<f32>,
    v_w_o: GpuBuffer<f32>,
    // Feed-forward projections.
    m_w_gate: GpuBuffer<f32>,
    v_w_gate: GpuBuffer<f32>,
    m_w_up: GpuBuffer<f32>,
    v_w_up: GpuBuffer<f32>,
    m_w_down: GpuBuffer<f32>,
    v_w_down: GpuBuffer<f32>,
    // Norm weights.
    m_input_norm: GpuBuffer<f32>,
    v_input_norm: GpuBuffer<f32>,
    m_post_attn_norm: GpuBuffer<f32>,
    v_post_attn_norm: GpuBuffer<f32>,
}
491
#[cfg(feature = "cuda")]
impl GpuBlockOptimizerState {
    /// Copies every state buffer to the host.
    ///
    /// Returns `(name, values)` pairs in a fixed order, named `"m.<weight>"`
    /// and `"v.<weight>"`.
    ///
    /// # Errors
    /// Stops at the first failed device-to-host copy and returns
    /// `TransferFailed` naming the buffer.
    pub fn download_to_host(
        &self,
    ) -> crate::autograd::cuda_tensor::Result<Vec<(String, Vec<f32>)>> {
        use crate::autograd::cuda_tensor::CudaTensorError;

        let sources = [
            ("m.w_q", &self.m_w_q),
            ("v.w_q", &self.v_w_q),
            ("m.w_k", &self.m_w_k),
            ("v.w_k", &self.v_w_k),
            ("m.w_v", &self.m_w_v),
            ("v.w_v", &self.v_w_v),
            ("m.w_o", &self.m_w_o),
            ("v.w_o", &self.v_w_o),
            ("m.w_gate", &self.m_w_gate),
            ("v.w_gate", &self.v_w_gate),
            ("m.w_up", &self.m_w_up),
            ("v.w_up", &self.v_w_up),
            ("m.w_down", &self.m_w_down),
            ("v.w_down", &self.v_w_down),
            ("m.input_norm", &self.m_input_norm),
            ("v.input_norm", &self.v_input_norm),
            ("m.post_attn_norm", &self.m_post_attn_norm),
            ("v.post_attn_norm", &self.v_post_attn_norm),
        ];

        let mut downloaded = Vec::with_capacity(sources.len());
        for (name, buf) in sources {
            let mut host = vec![0.0f32; buf.len()];
            buf.copy_to_host(&mut host).map_err(|e| {
                CudaTensorError::TransferFailed(format!("optimizer D2H {name}: {e}"))
            })?;
            downloaded.push((name.to_string(), host));
        }
        Ok(downloaded)
    }

    /// Uploads saved state from `data` into the matching device buffers.
    ///
    /// Entries are looked up by the names produced by `download_to_host`.
    /// A buffer is left untouched when its entry is missing or has a
    /// different length from the device buffer.
    ///
    /// # Errors
    /// Stops at the first failed host-to-device copy and returns
    /// `TransferFailed` naming the buffer.
    pub fn restore_from_host(
        &mut self,
        data: &std::collections::HashMap<String, Vec<f32>>,
    ) -> crate::autograd::cuda_tensor::Result<()> {
        use crate::autograd::cuda_tensor::CudaTensorError;

        let targets = [
            ("m.w_q", &mut self.m_w_q),
            ("v.w_q", &mut self.v_w_q),
            ("m.w_k", &mut self.m_w_k),
            ("v.w_k", &mut self.v_w_k),
            ("m.w_v", &mut self.m_w_v),
            ("v.w_v", &mut self.v_w_v),
            ("m.w_o", &mut self.m_w_o),
            ("v.w_o", &mut self.v_w_o),
            ("m.w_gate", &mut self.m_w_gate),
            ("v.w_gate", &mut self.v_w_gate),
            ("m.w_up", &mut self.m_w_up),
            ("v.w_up", &mut self.v_w_up),
            ("m.w_down", &mut self.m_w_down),
            ("v.w_down", &mut self.v_w_down),
            ("m.input_norm", &mut self.m_input_norm),
            ("v.input_norm", &mut self.v_input_norm),
            ("m.post_attn_norm", &mut self.m_post_attn_norm),
            ("v.post_attn_norm", &mut self.v_post_attn_norm),
        ];

        for (name, buf) in targets {
            match data.get(name) {
                Some(host_data) if host_data.len() == buf.len() => {
                    buf.copy_from_host(host_data).map_err(|e| {
                        CudaTensorError::TransferFailed(format!("optimizer H2D {name}: {e}"))
                    })?;
                }
                // Missing or mis-sized entries are skipped without error.
                _ => {}
            }
        }
        Ok(())
    }
}
575
576#[cfg(feature = "cuda")]
577impl CudaTransformerBlock {
578 #[allow(clippy::too_many_arguments)]
582 pub fn new(
583 config: &TransformerConfig,
584 layer_idx: usize,
585 ctx: Arc<CudaContext>,
586 input_norm_weight: &[f32],
587 post_attn_norm_weight: &[f32],
588 w_q: &[f32],
589 w_k: &[f32],
590 w_v: &[f32],
591 w_o: &[f32],
592 w_gate: &[f32],
593 w_up: &[f32],
594 w_down: &[f32],
595 max_seq_len: usize,
596 b_q: Option<&[f32]>,
602 b_k: Option<&[f32]>,
603 b_v: Option<&[f32]>,
604 ) -> Result<Self> {
605 let hidden_size = config.hidden_size;
606 let q_dim = config.q_dim(); let kv_hidden_size = config.num_kv_heads * config.head_dim();
608 let intermediate_size = config.intermediate_size;
609 let num_heads = config.num_attention_heads;
610
611 let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
613 let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
614 let w_q = GpuBuffer::from_host(&ctx, w_q)?;
615 let w_k = GpuBuffer::from_host(&ctx, w_k)?;
616 let w_v = GpuBuffer::from_host(&ctx, w_v)?;
617 let w_o = GpuBuffer::from_host(&ctx, w_o)?;
618 let w_gate = GpuBuffer::from_host(&ctx, w_gate)?;
619 let w_up = GpuBuffer::from_host(&ctx, w_up)?;
620 let w_down = GpuBuffer::from_host(&ctx, w_down)?;
621
622 let single_mask: Vec<f32> = (0..max_seq_len * max_seq_len)
624 .map(|idx| {
625 let row = idx / max_seq_len;
626 let col = idx % max_seq_len;
627 if col <= row {
628 0.0f32
629 } else {
630 f32::NEG_INFINITY
631 }
632 })
633 .collect();
634 let scratch = CudaBlockScratch {
636 norm1_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
637 q: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
638 k: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
639 v: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
640 attn_scores: GpuBuffer::new(&ctx, num_heads * max_seq_len * max_seq_len)?,
641 attn_out: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
642 o_proj_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
643 residual1: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
644 norm2_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
645 gate_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
646 up_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
647 swiglu_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
648 ffn_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
649 norm1_out_f16: None,
650 attn_out_f16: None,
651 norm2_out_f16: None,
652 swiglu_out_f16: None,
653 grad_hidden: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
655 grad_swiglu: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
656 attn_q_batched: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
658 attn_kv_temp: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
659 attn_kv_temp2: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
660 grad_attn_scores: GpuBuffer::new(
663 &ctx,
664 num_heads * max_seq_len * max_seq_len.max(config.head_dim()),
665 )?,
666 lora_inter: GpuBuffer::new(&ctx, 1)?,
668 lora_temp: GpuBuffer::new(&ctx, 1)?,
669 rope_positions: {
670 let positions: Vec<u32> = (0..max_seq_len as u32).collect();
671 let mut buf = GpuBuffer::new(&ctx, max_seq_len)?;
672 buf.copy_from_host(&positions)?;
673 buf
674 },
675 causal_mask_contiguous: GpuBuffer::from_host(&ctx, &single_mask)?,
676 causal_mask_cached_seq_len: max_seq_len,
677 op_us: [0u64; 16],
678 op_profiling_enabled: false,
679 };
680
681 let replicate = |bias: Option<&[f32]>, dim: usize| -> Result<Option<GpuBuffer<f32>>> {
687 match bias {
688 Some(slice) => {
689 debug_assert_eq!(
690 slice.len(),
691 dim,
692 "bias slice len {} != expected dim {dim}",
693 slice.len()
694 );
695 let mut repl: Vec<f32> = Vec::with_capacity(max_seq_len * dim);
696 for _ in 0..max_seq_len {
697 repl.extend_from_slice(slice);
698 }
699 Ok(Some(GpuBuffer::from_host(&ctx, &repl)?))
700 }
701 None => Ok(None),
702 }
703 };
704 let b_q_replicated = replicate(b_q, q_dim)?;
705 let b_k_replicated = replicate(b_k, kv_hidden_size)?;
706 let b_v_replicated = replicate(b_v, kv_hidden_size)?;
707
708 Ok(Self {
709 config: config.clone(),
710 layer_idx,
711 input_norm_weight,
712 post_attn_norm_weight,
713 w_q,
714 w_k,
715 w_v,
716 w_o,
717 w_gate,
718 w_up,
719 w_down,
720 ctx,
721 scratch,
722 norm_zero_buf: vec![0.0f32; hidden_size],
723 q_norm_weight: None, k_norm_weight: None,
725 b_q_replicated,
726 b_k_replicated,
727 b_v_replicated,
728 })
729 }
730
731 #[allow(dead_code)]
733 pub fn set_qk_norm(&mut self, q_norm: &[f32], k_norm: &[f32]) -> Result<()> {
734 self.q_norm_weight = Some(GpuBuffer::from_host(&self.ctx, q_norm)?);
735 self.k_norm_weight = Some(GpuBuffer::from_host(&self.ctx, k_norm)?);
736 Ok(())
737 }
738
739 pub fn forward(
747 &mut self,
748 input: &GpuBuffer<f32>,
749 output: &mut GpuBuffer<f32>,
750 seq_len: usize,
751 stream: &CudaStream,
752 ) -> Result<()> {
753 let hidden_size = self.config.hidden_size;
754 let q_dim = self.config.q_dim();
755 let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
756 let intermediate_size = self.config.intermediate_size;
757
758 rms_norm_forward_with_eps(
763 input,
764 &self.input_norm_weight,
765 &mut self.scratch.norm1_out,
766 saturating_u32(seq_len),
767 saturating_u32(hidden_size),
768 self.config.rms_norm_eps,
769 stream,
770 )?;
771
772 gemm_forward(
775 &self.scratch.norm1_out,
776 &self.w_q,
777 &mut self.scratch.q,
778 saturating_u32(seq_len),
779 saturating_u32(hidden_size),
780 saturating_u32(q_dim),
781 stream,
782 )?;
783 if let Some(b_q_repl) = self.b_q_replicated.as_ref() {
788 cuda_add_inplace(&mut self.scratch.q, b_q_repl, seq_len * q_dim, stream)?;
789 }
790
791 gemm_forward(
792 &self.scratch.norm1_out,
793 &self.w_k,
794 &mut self.scratch.k,
795 saturating_u32(seq_len),
796 saturating_u32(hidden_size),
797 saturating_u32(kv_hidden_size),
798 stream,
799 )?;
800 if let Some(b_k_repl) = self.b_k_replicated.as_ref() {
801 cuda_add_inplace(&mut self.scratch.k, b_k_repl, seq_len * kv_hidden_size, stream)?;
802 }
803
804 gemm_forward(
805 &self.scratch.norm1_out,
806 &self.w_v,
807 &mut self.scratch.v,
808 saturating_u32(seq_len),
809 saturating_u32(hidden_size),
810 saturating_u32(kv_hidden_size),
811 stream,
812 )?;
813 if let Some(b_v_repl) = self.b_v_replicated.as_ref() {
814 cuda_add_inplace(&mut self.scratch.v, b_v_repl, seq_len * kv_hidden_size, stream)?;
815 }
816
817 self.compute_attention_cuda(seq_len, stream)?;
819
820 gemm_forward(
823 &self.scratch.attn_out,
824 &self.w_o,
825 &mut self.scratch.o_proj_out,
826 saturating_u32(seq_len),
827 saturating_u32(q_dim),
828 saturating_u32(hidden_size),
829 stream,
830 )?;
831
832 cuda_add(
834 input,
835 &self.scratch.o_proj_out,
836 &mut self.scratch.residual1,
837 seq_len * hidden_size,
838 stream,
839 )?;
840
841 rms_norm_forward_with_eps(
844 &self.scratch.residual1,
845 &self.post_attn_norm_weight,
846 &mut self.scratch.norm2_out,
847 saturating_u32(seq_len),
848 saturating_u32(hidden_size),
849 self.config.rms_norm_eps,
850 stream,
851 )?;
852
853 gemm_forward(
855 &self.scratch.norm2_out,
856 &self.w_gate,
857 &mut self.scratch.gate_out,
858 saturating_u32(seq_len),
859 saturating_u32(hidden_size),
860 saturating_u32(intermediate_size),
861 stream,
862 )?;
863
864 gemm_forward(
865 &self.scratch.norm2_out,
866 &self.w_up,
867 &mut self.scratch.up_out,
868 saturating_u32(seq_len),
869 saturating_u32(hidden_size),
870 saturating_u32(intermediate_size),
871 stream,
872 )?;
873
874 fused_swiglu_forward(
876 &self.scratch.gate_out,
877 &self.scratch.up_out,
878 &mut self.scratch.swiglu_out,
879 saturating_u32(seq_len * intermediate_size),
880 stream,
881 )?;
882
883 gemm_forward(
885 &self.scratch.swiglu_out,
886 &self.w_down,
887 &mut self.scratch.ffn_out,
888 saturating_u32(seq_len),
889 saturating_u32(intermediate_size),
890 saturating_u32(hidden_size),
891 stream,
892 )?;
893
894 cuda_add(
896 &self.scratch.residual1,
897 &self.scratch.ffn_out,
898 output,
899 seq_len * hidden_size,
900 stream,
901 )?;
902
903 Ok(())
904 }
905
906 fn compute_attention_cuda(&mut self, seq_len: usize, stream: &CudaStream) -> Result<()> {
924 let num_heads = self.config.num_attention_heads;
925 let num_kv_heads = self.config.num_kv_heads;
926 let head_dim = self.config.head_dim();
927 let heads_per_kv = num_heads / num_kv_heads;
928 let scale = 1.0 / (head_dim as f32).sqrt();
929
930 let seq = saturating_u32(seq_len);
931 let nh = saturating_u32(num_heads);
932 let nkv = saturating_u32(num_kv_heads);
933 let hd = saturating_u32(head_dim);
934
935 self.scratch.prepare_causal_mask(seq_len, &self.ctx)?;
937
938 if let Some(ref q_norm) = self.q_norm_weight {
943 for pos in 0..seq_len {
944 let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
945 per_head_rmsnorm_forward(q_ref, q_norm, &mut self.scratch.q, nh, hd, pos, stream)?;
946 }
947 }
948 if let Some(ref k_norm) = self.k_norm_weight {
949 for pos in 0..seq_len {
950 let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
951 per_head_rmsnorm_forward(k_ref, k_norm, &mut self.scratch.k, nkv, hd, pos, stream)?;
952 }
953 }
954
955 let rope_theta = self.config.rope_theta;
958 {
959 let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
960 batched_rope_neox_forward(
961 q_ref,
962 &mut self.scratch.q,
963 &self.scratch.rope_positions,
964 nh,
965 hd,
966 seq,
967 rope_theta,
968 stream,
969 )?;
970 let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
971 batched_rope_neox_forward(
972 k_ref,
973 &mut self.scratch.k,
974 &self.scratch.rope_positions,
975 nkv,
976 hd,
977 seq,
978 rope_theta,
979 stream,
980 )?;
981 }
982
983 interleaved_to_batched_forward(
985 &self.scratch.q,
986 &mut self.scratch.attn_q_batched,
987 seq,
988 nh,
989 hd,
990 stream,
991 )?;
992
993 interleaved_to_batched_forward(
995 &self.scratch.k,
996 &mut self.scratch.attn_kv_temp,
997 seq,
998 nkv,
999 hd,
1000 stream,
1001 )?;
1002
1003 if heads_per_kv == 1 {
1005 batched_transpose_forward(
1007 &self.scratch.attn_kv_temp,
1008 &mut self.scratch.attn_kv_temp2,
1009 nh,
1010 seq,
1011 hd,
1012 stream,
1013 )?;
1014 } else {
1015 expand_kv_heads(
1017 &self.scratch.attn_kv_temp,
1018 &mut self.scratch.attn_kv_temp2,
1019 num_kv_heads,
1020 heads_per_kv,
1021 seq_len * head_dim,
1022 stream,
1023 )?;
1024 batched_transpose_forward(
1026 &self.scratch.attn_kv_temp2,
1027 &mut self.scratch.attn_kv_temp,
1028 nh,
1029 seq,
1030 hd,
1031 stream,
1032 )?;
1033 unsafe {
1037 self.scratch
1038 .attn_kv_temp2
1039 .copy_from_buffer_async(&self.scratch.attn_kv_temp, stream)
1040 .map_err(|e| {
1041 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1042 "K^T buffer copy failed: {e}"
1043 ))
1044 })?;
1045 }
1046 }
1047
1048 batched_4d_gemm_forward(
1053 &self.scratch.attn_q_batched,
1054 &self.scratch.attn_kv_temp2,
1055 &mut self.scratch.attn_scores,
1056 1,
1057 nh,
1058 seq,
1059 seq,
1060 hd,
1061 stream,
1062 )?;
1063
1064 let total_scores = nh * seq * seq;
1066 {
1067 let scores_view = unsafe {
1071 GpuBuffer::<f32>::from_raw_parts(
1072 self.scratch.attn_scores.as_ptr(),
1073 self.scratch.attn_scores.len(),
1074 )
1075 };
1076 scale_forward(
1077 &scores_view,
1078 &mut self.scratch.attn_scores,
1079 scale,
1080 total_scores,
1081 stream,
1082 )?;
1083 leak(scores_view);
1084 }
1085
1086 {
1092 let seq_sq = (seq * seq) as usize;
1093 let mask_ptr = self.scratch.causal_mask_contiguous.as_ptr();
1094 let scores_base = self.scratch.attn_scores.as_ptr();
1095 for head in 0..nh as usize {
1096 let byte_offset = (head * seq_sq * 4) as u64; let head_ptr = scores_base + byte_offset;
1098 let mask_view = unsafe { GpuBuffer::<f32>::from_raw_parts(mask_ptr, seq_sq) };
1102 let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
1103 let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
1104 residual_add_forward(&mask_view, &scores_view, &mut out_view, seq * seq, stream)?;
1105 leak(mask_view);
1106 leak(scores_view);
1107 leak(out_view);
1108 }
1109 }
1110
1111 let total_rows = nh * seq;
1113 {
1114 let scores_view = unsafe {
1118 GpuBuffer::<f32>::from_raw_parts(
1119 self.scratch.attn_scores.as_ptr(),
1120 self.scratch.attn_scores.len(),
1121 )
1122 };
1123 batched_softmax_forward(
1124 &scores_view,
1125 &mut self.scratch.attn_scores,
1126 total_rows,
1127 seq,
1128 stream,
1129 )?;
1130 leak(scores_view);
1131 }
1132
1133 interleaved_to_batched_forward(
1135 &self.scratch.v,
1136 &mut self.scratch.attn_kv_temp,
1137 seq,
1138 nkv,
1139 hd,
1140 stream,
1141 )?;
1142
1143 if heads_per_kv == 1 {
1144 } else {
1146 expand_kv_heads(
1148 &self.scratch.attn_kv_temp,
1149 &mut self.scratch.attn_kv_temp2,
1150 num_kv_heads,
1151 heads_per_kv,
1152 seq_len * head_dim,
1153 stream,
1154 )?;
1155 unsafe {
1158 self.scratch
1159 .attn_kv_temp
1160 .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1161 .map_err(|e| {
1162 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1163 "V expanded buffer copy failed: {e}"
1164 ))
1165 })?;
1166 }
1167 }
1168
1169 batched_4d_gemm_forward(
1174 &self.scratch.attn_scores,
1175 &self.scratch.attn_kv_temp,
1176 &mut self.scratch.attn_q_batched,
1177 1,
1178 nh,
1179 seq,
1180 hd,
1181 seq,
1182 stream,
1183 )?;
1184
1185 batched_to_interleaved_forward(
1187 &self.scratch.attn_q_batched,
1188 &mut self.scratch.attn_out,
1189 seq,
1190 nh,
1191 hd,
1192 stream,
1193 )?;
1194
1195 Ok(())
1196 }
1197
1198 pub fn layer_idx(&self) -> usize {
1200 self.layer_idx
1201 }
1202
1203 pub fn config(&self) -> &TransformerConfig {
1205 &self.config
1206 }
1207
1208 #[provable_contracts_macros::contract("backward-pass-v1", equation = "backward")]
1226 pub fn backward(
1227 &mut self,
1228 input: &GpuBuffer<f32>,
1229 grad_output: &GpuBuffer<f32>,
1230 grad_input: &mut GpuBuffer<f32>,
1231 seq_len: usize,
1232 stream: &CudaStream,
1233 grad_ws: &mut CudaGradWorkspace,
1234 ) -> Result<()> {
1235 let hidden_size = self.config.hidden_size;
1236 let intermediate_size = self.config.intermediate_size;
1237 let eps = 1e-5_f32;
1238
1239 grad_ws.zero_norm_grads(&self.norm_zero_buf)?;
1243
1244 self.backward_ffn(grad_output, seq_len, hidden_size, intermediate_size, stream, grad_ws)?;
1247
1248 self.backward_post_attn_norm(grad_input, seq_len, hidden_size, eps, stream, grad_ws)?;
1250
1251 cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;
1256
1257 self.backward_attention(grad_input, seq_len, stream, grad_ws)?;
1260
1261 self.backward_residual_and_input_norm(
1263 input,
1264 grad_output,
1265 grad_input,
1266 seq_len,
1267 hidden_size,
1268 eps,
1269 stream,
1270 grad_ws,
1271 )?;
1272
1273 Ok(())
1274 }
1275
1276 fn backward_ffn(
1291 &mut self,
1292 grad_output: &GpuBuffer<f32>,
1293 seq_len: usize,
1294 hidden_size: usize,
1295 intermediate_size: usize,
1296 stream: &CudaStream,
1297 grad_ws: &mut CudaGradWorkspace,
1298 ) -> Result<()> {
1299 let n_inter = saturating_u32(seq_len * intermediate_size);
1300 let n_hidden = saturating_u32(seq_len * hidden_size);
1301
1302 gemm_backward_a(
1304 grad_output,
1305 &self.w_down,
1306 &mut self.scratch.grad_swiglu,
1307 saturating_u32(seq_len),
1308 saturating_u32(intermediate_size),
1309 saturating_u32(hidden_size),
1310 stream,
1311 )?;
1312
1313 gemm_backward_b(
1316 &self.scratch.swiglu_out,
1317 grad_output,
1318 &mut grad_ws.grad_down,
1319 saturating_u32(seq_len),
1320 saturating_u32(intermediate_size),
1321 saturating_u32(hidden_size),
1322 stream,
1323 )?;
1324
1325 elementwise_mul_forward(
1329 &self.scratch.grad_swiglu,
1330 &self.scratch.up_out,
1331 &mut self.scratch.swiglu_out,
1332 n_inter,
1333 stream,
1334 )?;
1335
1336 silu_backward(
1339 &self.scratch.gate_out,
1340 &self.scratch.swiglu_out,
1341 &mut self.scratch.up_out,
1342 stream,
1343 )?;
1344 silu_forward(&self.scratch.gate_out, &mut self.scratch.swiglu_out, n_inter, stream)?;
1348
1349 elementwise_mul_forward(
1351 &self.scratch.grad_swiglu,
1352 &self.scratch.swiglu_out,
1353 &mut self.scratch.gate_out,
1354 n_inter,
1355 stream,
1356 )?;
1357 gemm_backward_b(
1363 &self.scratch.norm2_out,
1364 &self.scratch.up_out,
1365 &mut grad_ws.grad_gate,
1366 saturating_u32(seq_len),
1367 saturating_u32(hidden_size),
1368 saturating_u32(intermediate_size),
1369 stream,
1370 )?;
1371
1372 gemm_backward_b(
1374 &self.scratch.norm2_out,
1375 &self.scratch.gate_out,
1376 &mut grad_ws.grad_up,
1377 saturating_u32(seq_len),
1378 saturating_u32(hidden_size),
1379 saturating_u32(intermediate_size),
1380 stream,
1381 )?;
1382
1383 gemm_backward_a(
1387 &self.scratch.up_out,
1388 &self.w_gate,
1389 &mut self.scratch.ffn_out,
1390 saturating_u32(seq_len),
1391 saturating_u32(hidden_size),
1392 saturating_u32(intermediate_size),
1393 stream,
1394 )?;
1395
1396 gemm_backward_a(
1398 &self.scratch.gate_out,
1399 &self.w_up,
1400 &mut self.scratch.grad_hidden,
1401 saturating_u32(seq_len),
1402 saturating_u32(hidden_size),
1403 saturating_u32(intermediate_size),
1404 stream,
1405 )?;
1406
1407 residual_add_forward(
1409 &self.scratch.ffn_out,
1410 &self.scratch.grad_hidden,
1411 &mut self.scratch.norm2_out,
1412 n_hidden,
1413 stream,
1414 )?;
1415
1416 Ok(())
1417 }
1418
1419 fn backward_post_attn_norm(
1421 &mut self,
1422 grad_input: &mut GpuBuffer<f32>,
1423 seq_len: usize,
1424 hidden_size: usize,
1425 eps: f32,
1426 stream: &CudaStream,
1427 grad_ws: &mut CudaGradWorkspace,
1428 ) -> Result<()> {
1429 unsafe {
1432 self.scratch
1433 .grad_hidden
1434 .copy_from_buffer_async(&self.scratch.norm2_out, stream)
1435 .map_err(|e| {
1436 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1437 "Backward norm D2D copy failed: {e}"
1438 ))
1439 })?;
1440 }
1441
1442 rms_norm_backward(
1443 &self.scratch.residual1,
1444 &self.post_attn_norm_weight,
1445 &self.scratch.grad_hidden,
1446 grad_input,
1447 &mut grad_ws.grad_post_attn_norm,
1448 saturating_u32(seq_len),
1449 saturating_u32(hidden_size),
1450 eps,
1451 stream,
1452 )
1453 }
1454
1455 fn backward_attention(
1468 &mut self,
1469 grad_input: &mut GpuBuffer<f32>,
1470 seq_len: usize,
1471 stream: &CudaStream,
1472 grad_ws: &mut CudaGradWorkspace,
1473 ) -> Result<()> {
1474 let hidden_size = self.config.hidden_size;
1475 let q_dim = self.config.q_dim();
1476 let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
1477 let num_heads = self.config.num_attention_heads;
1478 let num_kv_heads = self.config.num_kv_heads;
1479 let head_dim = self.config.head_dim();
1480 let heads_per_kv = num_heads / num_kv_heads;
1481 let scale = 1.0 / (head_dim as f32).sqrt();
1482
1483 let seq = saturating_u32(seq_len);
1484 let nh = saturating_u32(num_heads);
1485 let nkv = saturating_u32(num_kv_heads);
1486 let hd = saturating_u32(head_dim);
1487
1488 gemm_backward_a(
1493 grad_input,
1494 &self.w_o,
1495 &mut self.scratch.grad_hidden,
1496 seq,
1497 saturating_u32(q_dim),
1498 saturating_u32(hidden_size),
1499 stream,
1500 )?;
1501
1502 gemm_backward_b(
1504 &self.scratch.attn_out,
1505 grad_input,
1506 &mut grad_ws.grad_w_o,
1507 seq,
1508 saturating_u32(q_dim),
1509 saturating_u32(hidden_size),
1510 stream,
1511 )?;
1512
1513 interleaved_to_batched_forward(
1517 &self.scratch.grad_hidden,
1518 &mut self.scratch.attn_q_batched,
1519 seq,
1520 nh,
1521 hd,
1522 stream,
1523 )?;
1524
1525 interleaved_to_batched_forward(
1529 &self.scratch.v,
1530 &mut self.scratch.attn_kv_temp,
1531 seq,
1532 nkv,
1533 hd,
1534 stream,
1535 )?;
1536
1537 if heads_per_kv > 1 {
1539 expand_kv_heads(
1540 &self.scratch.attn_kv_temp,
1541 &mut self.scratch.attn_kv_temp2,
1542 num_kv_heads,
1543 heads_per_kv,
1544 seq_len * head_dim,
1545 stream,
1546 )?;
1547 unsafe {
1549 self.scratch
1550 .attn_kv_temp
1551 .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1552 .map_err(|e| {
1553 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1554 "Attn backward V expand D2D copy failed: {e}"
1555 ))
1556 })?;
1557 }
1558 }
1559 batched_transpose_forward(
1563 &self.scratch.attn_kv_temp,
1564 &mut self.scratch.attn_kv_temp2,
1565 nh,
1566 seq,
1567 hd,
1568 stream,
1569 )?;
1570 batched_4d_gemm_forward(
1574 &self.scratch.attn_q_batched,
1575 &self.scratch.attn_kv_temp2,
1576 &mut self.scratch.grad_attn_scores,
1577 1,
1578 nh,
1579 seq,
1580 seq,
1581 hd,
1582 stream,
1583 )?;
1584
1585 batched_transpose_forward(
1595 &self.scratch.attn_q_batched, &mut self.scratch.attn_kv_temp, nh,
1598 seq,
1599 hd,
1600 stream,
1601 )?;
1602
1603 batched_4d_gemm_forward(
1605 &self.scratch.attn_kv_temp, &self.scratch.attn_scores, &mut self.scratch.attn_kv_temp2, 1,
1609 nh,
1610 hd, seq, seq, stream,
1614 )?;
1615
1616 batched_transpose_forward(
1618 &self.scratch.attn_kv_temp2, &mut self.scratch.attn_kv_temp, nh,
1621 hd,
1622 seq,
1623 stream,
1624 )?;
1625 let total_rows = nh * seq;
1632 {
1633 let grad_scores_view = unsafe {
1637 GpuBuffer::<f32>::from_raw_parts(
1638 self.scratch.grad_attn_scores.as_ptr(),
1639 self.scratch.grad_attn_scores.len(),
1640 )
1641 };
1642 batched_softmax_backward(
1643 &self.scratch.attn_scores,
1644 &grad_scores_view,
1645 &mut self.scratch.grad_attn_scores,
1646 total_rows,
1647 seq,
1648 stream,
1649 )?;
1650 leak(grad_scores_view);
1651 }
1652 let total_scores = nh * seq * seq;
1657 {
1658 let scores_view = unsafe {
1660 GpuBuffer::<f32>::from_raw_parts(
1661 self.scratch.grad_attn_scores.as_ptr(),
1662 self.scratch.grad_attn_scores.len(),
1663 )
1664 };
1665 scale_forward(
1666 &scores_view,
1667 &mut self.scratch.grad_attn_scores,
1668 scale,
1669 total_scores,
1670 stream,
1671 )?;
1672 leak(scores_view);
1673 }
1674
1675 interleaved_to_batched_forward(
1679 &self.scratch.k,
1680 &mut self.scratch.attn_kv_temp2,
1681 seq,
1682 nkv,
1683 hd,
1684 stream,
1685 )?;
1686
1687 if heads_per_kv > 1 {
1688 unsafe {
1691 self.scratch
1692 .attn_q_batched
1693 .copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
1694 .map_err(|e| {
1695 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1696 "Attn backward K copy for GQA expand failed: {e}"
1697 ))
1698 })?;
1699 }
1700 expand_kv_heads(
1701 &self.scratch.attn_q_batched,
1702 &mut self.scratch.attn_kv_temp2,
1703 num_kv_heads,
1704 heads_per_kv,
1705 seq_len * head_dim,
1706 stream,
1707 )?;
1708 }
1709 batched_4d_gemm_forward(
1713 &self.scratch.grad_attn_scores,
1714 &self.scratch.attn_kv_temp2,
1715 &mut self.scratch.attn_q_batched,
1716 1,
1717 nh,
1718 seq,
1719 hd,
1720 seq,
1721 stream,
1722 )?;
1723
1724 interleaved_to_batched_forward(
1728 &self.scratch.q,
1729 &mut self.scratch.o_proj_out, seq,
1731 nh,
1732 hd,
1733 stream,
1734 )?;
1735
1736 batched_transpose_forward(
1738 &self.scratch.o_proj_out,
1739 &mut self.scratch.attn_kv_temp2, nh,
1741 seq,
1742 hd,
1743 stream,
1744 )?;
1745
1746 batched_4d_gemm_forward(
1748 &self.scratch.attn_kv_temp2,
1749 &self.scratch.grad_attn_scores,
1750 &mut self.scratch.ffn_out, 1,
1752 nh,
1753 hd,
1754 seq,
1755 seq,
1756 stream,
1757 )?;
1758
1759 batched_transpose_forward(
1761 &self.scratch.ffn_out,
1762 &mut self.scratch.attn_kv_temp2, nh,
1764 hd,
1765 seq,
1766 stream,
1767 )?;
1768
1769 if heads_per_kv > 1 {
1772 self.reduce_gqa_gradients(num_kv_heads, heads_per_kv, seq_len, head_dim, stream)?;
1773 }
1774
1775 batched_to_interleaved_forward(
1778 &self.scratch.attn_q_batched,
1779 &mut self.scratch.o_proj_out,
1780 seq,
1781 nh,
1782 hd,
1783 stream,
1784 )?;
1785
1786 batched_to_interleaved_forward(
1788 &self.scratch.attn_kv_temp2,
1789 &mut self.scratch.norm2_out,
1790 seq,
1791 nkv,
1792 hd,
1793 stream,
1794 )?;
1795
1796 batched_to_interleaved_forward(
1798 &self.scratch.attn_kv_temp,
1799 &mut self.scratch.ffn_out,
1800 seq,
1801 nkv,
1802 hd,
1803 stream,
1804 )?;
1805
1806 let rope_theta = self.config.rope_theta;
1811 {
1812 let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.o_proj_out)) };
1814 batched_rope_neox_backward(
1815 q_ref,
1816 &mut self.scratch.o_proj_out,
1817 &self.scratch.rope_positions,
1818 nh,
1819 hd,
1820 seq,
1821 rope_theta,
1822 stream,
1823 )?;
1824 let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.norm2_out)) };
1826 batched_rope_neox_backward(
1827 k_ref,
1828 &mut self.scratch.norm2_out,
1829 &self.scratch.rope_positions,
1830 nkv,
1831 hd,
1832 seq,
1833 rope_theta,
1834 stream,
1835 )?;
1836 }
1837
1838 gemm_backward_a(
1843 &self.scratch.o_proj_out, &self.w_q,
1845 &mut self.scratch.grad_hidden,
1846 seq,
1847 saturating_u32(hidden_size),
1848 saturating_u32(q_dim),
1849 stream,
1850 )?;
1851
1852 gemm_backward_a(
1857 &self.scratch.norm2_out, &self.w_k,
1859 &mut self.scratch.grad_attn_scores, seq,
1861 saturating_u32(hidden_size),
1862 saturating_u32(kv_hidden_size),
1863 stream,
1864 )?;
1865 cuda_add_inplace(
1866 &mut self.scratch.grad_hidden,
1867 &self.scratch.grad_attn_scores,
1868 seq_len * hidden_size,
1869 stream,
1870 )?;
1871
1872 gemm_backward_a(
1876 &self.scratch.ffn_out, &self.w_v,
1878 &mut self.scratch.grad_attn_scores, seq,
1880 saturating_u32(hidden_size),
1881 saturating_u32(kv_hidden_size),
1882 stream,
1883 )?;
1884 cuda_add_inplace(
1885 &mut self.scratch.grad_hidden,
1886 &self.scratch.grad_attn_scores,
1887 seq_len * hidden_size,
1888 stream,
1889 )?;
1890
1891 gemm_backward_b(
1893 &self.scratch.norm1_out,
1894 &self.scratch.o_proj_out, &mut grad_ws.grad_w_q,
1896 seq,
1897 saturating_u32(hidden_size),
1898 saturating_u32(q_dim),
1899 stream,
1900 )?;
1901
1902 gemm_backward_b(
1904 &self.scratch.norm1_out,
1905 &self.scratch.norm2_out, &mut grad_ws.grad_w_k,
1907 seq,
1908 saturating_u32(hidden_size),
1909 saturating_u32(kv_hidden_size),
1910 stream,
1911 )?;
1912
1913 gemm_backward_b(
1915 &self.scratch.norm1_out,
1916 &self.scratch.ffn_out, &mut grad_ws.grad_w_v,
1918 seq,
1919 saturating_u32(hidden_size),
1920 saturating_u32(kv_hidden_size),
1921 stream,
1922 )?;
1923
1924 unsafe {
1927 grad_input.copy_from_buffer_async(&self.scratch.grad_hidden, stream).map_err(|e| {
1928 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1929 "Attn backward grad_hidden → grad_input D2D copy failed: {e}"
1930 ))
1931 })?;
1932 }
1933
1934 Ok(())
1935 }
1936
1937 fn reduce_gqa_gradients(
1945 &mut self,
1946 num_kv_heads: usize,
1947 heads_per_kv: usize,
1948 seq_len: usize,
1949 head_dim: usize,
1950 stream: &CudaStream,
1951 ) -> Result<()> {
1952 let elems_per_head = seq_len * head_dim;
1953
1954 self.reduce_single_gqa_gradient(true, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
1956
1957 self.reduce_single_gqa_gradient(false, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
1959
1960 let kv_elems = num_kv_heads * elems_per_head;
1962 unsafe {
1964 self.scratch
1965 .attn_kv_temp2
1966 .copy_from_buffer_at_async(&self.scratch.grad_attn_scores, 0, 0, kv_elems, stream)
1967 .map_err(|e| {
1968 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1969 "GQA grad_K reduced final copy failed: {e}"
1970 ))
1971 })?;
1972 self.scratch
1973 .attn_kv_temp
1974 .copy_from_buffer_at_async(&self.scratch.ffn_out, 0, 0, kv_elems, stream)
1975 .map_err(|e| {
1976 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
1977 "GQA grad_V reduced final copy failed: {e}"
1978 ))
1979 })?;
1980 }
1981 Ok(())
1982 }
1983
1984 fn reduce_single_gqa_gradient(
1990 &mut self,
1991 is_k: bool,
1992 num_kv_heads: usize,
1993 heads_per_kv: usize,
1994 elems_per_head: usize,
1995 stream: &CudaStream,
1996 ) -> Result<()> {
1997 let label = if is_k { "K" } else { "V" };
1998
1999 for kv_h in 0..num_kv_heads {
2000 let dst_offset = kv_h * elems_per_head;
2001 let first_h = kv_h * heads_per_kv;
2002 let src_offset = first_h * elems_per_head;
2003
2004 unsafe {
2007 let (dst, src) = if is_k {
2008 (&mut self.scratch.grad_attn_scores, &self.scratch.attn_kv_temp2)
2009 } else {
2010 (&mut self.scratch.ffn_out, &self.scratch.attn_kv_temp)
2011 };
2012 dst.copy_from_buffer_at_async(src, dst_offset, src_offset, elems_per_head, stream)
2013 .map_err(|e| {
2014 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2015 "GQA grad_{label} reduce base copy failed: {e}"
2016 ))
2017 })?;
2018 }
2019
2020 for rep in 1..heads_per_kv {
2022 let h = kv_h * heads_per_kv + rep;
2023 let h_offset = h * elems_per_head;
2024
2025 unsafe {
2028 let src =
2029 if is_k { &self.scratch.attn_kv_temp2 } else { &self.scratch.attn_kv_temp };
2030 self.scratch
2031 .o_proj_out
2032 .copy_from_buffer_at_async(src, 0, h_offset, elems_per_head, stream)
2033 .map_err(|e| {
2034 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2035 "GQA grad_{label} reduce head copy failed: {e}"
2036 ))
2037 })?;
2038 }
2039
2040 unsafe {
2043 let dst_buf =
2044 if is_k { &self.scratch.grad_attn_scores } else { &self.scratch.ffn_out };
2045 let dst_view = GpuBuffer::<f32>::from_raw_parts(
2046 dst_buf.as_ptr() + (dst_offset as u64 * 4),
2047 elems_per_head,
2048 );
2049 let src_view = GpuBuffer::<f32>::from_raw_parts(
2050 self.scratch.o_proj_out.as_ptr(),
2051 elems_per_head,
2052 );
2053 let mut sum_view = GpuBuffer::<f32>::from_raw_parts(
2054 self.scratch.grad_hidden.as_ptr(),
2055 elems_per_head,
2056 );
2057 residual_add_forward(
2058 &dst_view,
2059 &src_view,
2060 &mut sum_view,
2061 saturating_u32(elems_per_head),
2062 stream,
2063 )?;
2064 let dst_buf = if is_k {
2066 &mut self.scratch.grad_attn_scores
2067 } else {
2068 &mut self.scratch.ffn_out
2069 };
2070 dst_buf
2071 .copy_from_buffer_at_async(
2072 &self.scratch.grad_hidden,
2073 dst_offset,
2074 0,
2075 elems_per_head,
2076 stream,
2077 )
2078 .map_err(|e| {
2079 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2080 "GQA grad_{label} reduce sum copy failed: {e}"
2081 ))
2082 })?;
2083 leak(dst_view);
2084 leak(src_view);
2085 leak(sum_view);
2086 }
2087 }
2088 }
2089 Ok(())
2090 }
2091
2092 fn backward_residual_and_input_norm(
2094 &mut self,
2095 input: &GpuBuffer<f32>,
2096 grad_output: &GpuBuffer<f32>,
2097 grad_input: &mut GpuBuffer<f32>,
2098 seq_len: usize,
2099 hidden_size: usize,
2100 eps: f32,
2101 stream: &CudaStream,
2102 grad_ws: &mut CudaGradWorkspace,
2103 ) -> Result<()> {
2104 unsafe {
2115 self.scratch.grad_hidden.copy_from_buffer_async(grad_input, stream).map_err(|e| {
2116 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2117 "Backward residual grad_hidden D2D copy failed: {e}"
2118 ))
2119 })?;
2120 }
2121
2122 rms_norm_backward(
2124 input,
2125 &self.input_norm_weight,
2126 &self.scratch.grad_hidden,
2127 grad_input,
2128 &mut grad_ws.grad_input_norm,
2129 saturating_u32(seq_len),
2130 saturating_u32(hidden_size),
2131 eps,
2132 stream,
2133 )?;
2134
2135 cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)
2137 }
2138
2139 pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
2151 let hidden = self.config.hidden_size;
2152 let q_dim = self.config.q_dim();
2153 let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
2154 let intermediate = self.config.intermediate_size;
2155
2156 let z = |n: usize| -> Result<GpuBuffer<f32>> {
2160 Ok(GpuBuffer::from_host(&self.ctx, &vec![0.0f32; n])?)
2161 };
2162 Ok(GpuBlockOptimizerState {
2163 m_w_q: z(q_dim * hidden)?,
2164 v_w_q: z(q_dim * hidden)?,
2165 m_w_k: z(hidden * kv_hidden)?,
2166 v_w_k: z(hidden * kv_hidden)?,
2167 m_w_v: z(hidden * kv_hidden)?,
2168 v_w_v: z(hidden * kv_hidden)?,
2169 m_w_o: z(hidden * q_dim)?,
2170 v_w_o: z(hidden * q_dim)?,
2171 m_w_gate: z(hidden * intermediate)?,
2172 v_w_gate: z(hidden * intermediate)?,
2173 m_w_up: z(hidden * intermediate)?,
2174 v_w_up: z(hidden * intermediate)?,
2175 m_w_down: z(intermediate * hidden)?,
2176 v_w_down: z(intermediate * hidden)?,
2177 m_input_norm: z(hidden)?,
2178 v_input_norm: z(hidden)?,
2179 m_post_attn_norm: z(hidden)?,
2180 v_post_attn_norm: z(hidden)?,
2181 })
2182 }
2183
2184 pub fn optimizer_step(
2197 &mut self,
2198 state: &mut GpuBlockOptimizerState,
2199 step: u32,
2200 lr: f32,
2201 beta1: f32,
2202 beta2: f32,
2203 eps: f32,
2204 weight_decay: f32,
2205 stream: &CudaStream,
2206 grad_ws: &CudaGradWorkspace,
2207 ) -> Result<()> {
2208 debug_assert!(step > 0, "C-OPTSTEP-001: step must be > 0 for bias adjust");
2209
2210 let n_wq = self.w_q.len() as u32;
2213 let n_wk = self.w_k.len() as u32;
2214 let n_wv = self.w_v.len() as u32;
2215 let n_wo = self.w_o.len() as u32;
2216 let n_gate = self.w_gate.len() as u32;
2217 let n_up = self.w_up.len() as u32;
2218 let n_down = self.w_down.len() as u32;
2219 let n_inorm = self.input_norm_weight.len() as u32;
2220 let n_panorm = self.post_attn_norm_weight.len() as u32;
2221
2222 adamw_step_cuda(
2224 &mut self.w_q,
2225 &grad_ws.grad_w_q,
2226 &mut state.m_w_q,
2227 &mut state.v_w_q,
2228 lr,
2229 beta1,
2230 beta2,
2231 eps,
2232 weight_decay,
2233 step,
2234 n_wq,
2235 stream,
2236 )?;
2237 adamw_step_cuda(
2238 &mut self.w_k,
2239 &grad_ws.grad_w_k,
2240 &mut state.m_w_k,
2241 &mut state.v_w_k,
2242 lr,
2243 beta1,
2244 beta2,
2245 eps,
2246 weight_decay,
2247 step,
2248 n_wk,
2249 stream,
2250 )?;
2251 adamw_step_cuda(
2252 &mut self.w_v,
2253 &grad_ws.grad_w_v,
2254 &mut state.m_w_v,
2255 &mut state.v_w_v,
2256 lr,
2257 beta1,
2258 beta2,
2259 eps,
2260 weight_decay,
2261 step,
2262 n_wv,
2263 stream,
2264 )?;
2265 adamw_step_cuda(
2266 &mut self.w_o,
2267 &grad_ws.grad_w_o,
2268 &mut state.m_w_o,
2269 &mut state.v_w_o,
2270 lr,
2271 beta1,
2272 beta2,
2273 eps,
2274 weight_decay,
2275 step,
2276 n_wo,
2277 stream,
2278 )?;
2279
2280 adamw_step_cuda(
2282 &mut self.w_gate,
2283 &grad_ws.grad_gate,
2284 &mut state.m_w_gate,
2285 &mut state.v_w_gate,
2286 lr,
2287 beta1,
2288 beta2,
2289 eps,
2290 weight_decay,
2291 step,
2292 n_gate,
2293 stream,
2294 )?;
2295 adamw_step_cuda(
2296 &mut self.w_up,
2297 &grad_ws.grad_up,
2298 &mut state.m_w_up,
2299 &mut state.v_w_up,
2300 lr,
2301 beta1,
2302 beta2,
2303 eps,
2304 weight_decay,
2305 step,
2306 n_up,
2307 stream,
2308 )?;
2309 adamw_step_cuda(
2310 &mut self.w_down,
2311 &grad_ws.grad_down,
2312 &mut state.m_w_down,
2313 &mut state.v_w_down,
2314 lr,
2315 beta1,
2316 beta2,
2317 eps,
2318 weight_decay,
2319 step,
2320 n_down,
2321 stream,
2322 )?;
2323
2324 adamw_step_cuda(
2326 &mut self.input_norm_weight,
2327 &grad_ws.grad_input_norm,
2328 &mut state.m_input_norm,
2329 &mut state.v_input_norm,
2330 lr,
2331 beta1,
2332 beta2,
2333 eps,
2334 weight_decay,
2335 step,
2336 n_inorm,
2337 stream,
2338 )?;
2339 adamw_step_cuda(
2340 &mut self.post_attn_norm_weight,
2341 &grad_ws.grad_post_attn_norm,
2342 &mut state.m_post_attn_norm,
2343 &mut state.v_post_attn_norm,
2344 lr,
2345 beta1,
2346 beta2,
2347 eps,
2348 weight_decay,
2349 step,
2350 n_panorm,
2351 stream,
2352 )?;
2353
2354 Ok(())
2355 }
2356
2357 pub fn download_weights(&self) -> Result<BlockWeights> {
2367 let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
2368 let mut host = vec![0.0f32; buf.len()];
2369 buf.copy_to_host(&mut host).map_err(|e| {
2370 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2371 "Weight download failed: {e}"
2372 ))
2373 })?;
2374 Ok(host)
2375 };
2376
2377 Ok(BlockWeights {
2378 w_q: download(&self.w_q)?,
2379 w_k: download(&self.w_k)?,
2380 w_v: download(&self.w_v)?,
2381 w_o: download(&self.w_o)?,
2382 w_gate: download(&self.w_gate)?,
2383 w_up: download(&self.w_up)?,
2384 w_down: download(&self.w_down)?,
2385 input_norm_weight: download(&self.input_norm_weight)?,
2386 post_attn_norm_weight: download(&self.post_attn_norm_weight)?,
2387 })
2388 }
2389}
2390
2391#[cfg(feature = "cuda")]
/// Host-side copy of one transformer block's weights.
///
/// Produced by `download_weights`; every field is the flat `f32` contents of
/// the device buffer with the same name.
#[derive(Debug, Clone, PartialEq)]
pub struct BlockWeights {
    /// Query projection weights.
    pub w_q: Vec<f32>,
    /// Key projection weights.
    pub w_k: Vec<f32>,
    /// Value projection weights.
    pub w_v: Vec<f32>,
    /// Attention output projection weights.
    pub w_o: Vec<f32>,
    /// Feed-forward gate projection weights.
    pub w_gate: Vec<f32>,
    /// Feed-forward up projection weights.
    pub w_up: Vec<f32>,
    /// Feed-forward down projection weights.
    pub w_down: Vec<f32>,
    /// Weight of the norm applied to the block input.
    pub input_norm_weight: Vec<f32>,
    /// Weight of the norm applied after attention.
    pub post_attn_norm_weight: Vec<f32>,
}
2408
/// Element-wise addition `output = a + b` on the GPU.
///
/// Thin wrapper over `residual_add_forward`; `n` is saturated to `u32` and
/// forwarded as the element count.
#[cfg(feature = "cuda")]
fn cuda_add(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    let count = saturating_u32(n);
    residual_add_forward(a, b, output, count, stream)
}
2422
/// In-place element-wise addition `target += source` on the GPU.
///
/// `residual_add_forward` takes separate input and output buffers, so the
/// input side is a non-owning view of `target`'s device memory (same
/// pointer, same length). The view is wrapped in `ManuallyDrop` and therefore
/// never dropped. This replaces an earlier version that built a `&GpuBuffer`
/// alias of the live `&mut GpuBuffer` and passed both to the same call, which
/// violates Rust's reference aliasing rules.
///
/// `n` is saturated to `u32` and forwarded as the element count.
///
/// # Errors
/// Propagates whatever `residual_add_forward` reports.
#[cfg(feature = "cuda")]
pub(crate) fn cuda_add_inplace(
    target: &mut GpuBuffer<f32>,
    source: &GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    // SAFETY: the view points at memory owned by `target`, is used only as
    // the read side of the call below, and is never dropped.
    let target_view = std::mem::ManuallyDrop::new(unsafe {
        GpuBuffer::<f32>::from_raw_parts(target.as_ptr(), target.len())
    });
    residual_add_forward(&*target_view, source, target, saturating_u32(n), stream)
}
2443
/// Element-wise multiplication `output = a * b` on the GPU.
///
/// Thin wrapper over `elementwise_mul_forward`; `n` is saturated to `u32`
/// and forwarded as the element count.
#[cfg(feature = "cuda")]
fn cuda_mul(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: usize,
    stream: &CudaStream,
) -> Result<()> {
    let count = saturating_u32(n);
    crate::autograd::cuda_forward::elementwise_mul_forward(a, b, output, count, stream)
}
2457
/// Placeholder used when the crate is built without the `cuda` feature.
///
/// It carries no data; it exists so code naming `CudaTransformerBlock`
/// still compiles in non-CUDA builds.
#[cfg(not(feature = "cuda"))]
pub struct CudaTransformerBlock;

#[cfg(not(feature = "cuda"))]
impl CudaTransformerBlock {
    /// Layer index of the placeholder block; always `0`.
    pub fn layer_idx(&self) -> usize {
        const STUB_LAYER_INDEX: usize = 0;
        STUB_LAYER_INDEX
    }
}
2468
/// A transformer block resident on the GPU, in one of two weight formats.
///
/// The methods on this enum dispatch to the wrapped block and return a
/// `KernelError` for operations the active variant does not support.
#[cfg(feature = "cuda")]
pub enum CudaBlock {
    /// Block with full-precision weights; supports `backward`,
    /// `optimizer_step` and `download_weights`.
    Fp32(CudaTransformerBlock),
    /// Block with NF4-quantized base weights; supports the LoRA operations
    /// (`backward_nf4`, `lora_optimizer_step`, LoRA weight transfer).
    Nf4(CudaNf4TransformerBlock),
}
2484
#[cfg(feature = "cuda")]
impl CudaBlock {
    /// Builds the `KernelError` returned when an operation is requested on a
    /// variant that does not support it.
    fn unsupported<T>(message: &'static str) -> Result<T> {
        Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(message.into()))
    }

    /// Runs the forward pass of the wrapped block.
    ///
    /// `shared_scratch` is ignored by FP32 blocks and required by NF4 blocks.
    ///
    /// # Panics
    /// Panics with `C-SCRATCH-001` if the block is NF4 and `shared_scratch`
    /// is `None`.
    pub(crate) fn forward(
        &mut self,
        input: &GpuBuffer<f32>,
        output: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        shared_scratch: Option<&mut CudaBlockScratch>,
    ) -> Result<()> {
        match self {
            CudaBlock::Nf4(block) => {
                let scratch =
                    shared_scratch.expect("C-SCRATCH-001: NF4 blocks require shared scratch");
                block.forward(input, output, seq_len, stream, scratch)
            }
            CudaBlock::Fp32(block) => block.forward(input, output, seq_len, stream),
        }
    }

    /// Returns the layer index of the wrapped block.
    pub fn layer_idx(&self) -> usize {
        match self {
            CudaBlock::Nf4(block) => block.layer_idx,
            CudaBlock::Fp32(block) => block.layer_idx(),
        }
    }

    /// Runs the full-precision backward pass.
    ///
    /// # Errors
    /// Returns `KernelError` for NF4 blocks; otherwise forwards the result of
    /// the FP32 block's `backward`.
    pub fn backward(
        &mut self,
        input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        grad_ws: &mut CudaGradWorkspace,
    ) -> Result<()> {
        let CudaBlock::Fp32(block) = self else {
            return Self::unsupported("backward not supported on NF4 blocks (frozen weights)");
        };
        block.backward(input, grad_output, grad_input, seq_len, stream, grad_ws)
    }

    /// Allocates the optimizer state for an FP32 block.
    ///
    /// # Errors
    /// Returns `KernelError` for NF4 blocks.
    pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
        let CudaBlock::Fp32(block) = self else {
            return Self::unsupported("init_optimizer_state not supported on NF4 blocks");
        };
        block.init_optimizer_state()
    }

    /// Downloads the weights of an FP32 block to the host.
    ///
    /// # Errors
    /// Returns `KernelError` for NF4 blocks.
    pub fn download_weights(&self) -> Result<BlockWeights> {
        let CudaBlock::Fp32(block) = self else {
            return Self::unsupported("download_weights not supported on NF4 blocks");
        };
        block.download_weights()
    }

    /// Applies one optimizer step to an FP32 block.
    ///
    /// # Errors
    /// Returns `KernelError` for NF4 blocks.
    pub fn optimizer_step(
        &mut self,
        state: &mut GpuBlockOptimizerState,
        step: u32,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        stream: &CudaStream,
        grad_ws: &CudaGradWorkspace,
    ) -> Result<()> {
        let CudaBlock::Fp32(block) = self else {
            return Self::unsupported(
                "optimizer_step not supported on NF4 blocks (frozen weights)",
            );
        };
        block.optimizer_step(state, step, lr, beta1, beta2, eps, weight_decay, stream, grad_ws)
    }

    /// Runs the backward pass of an NF4 block, writing LoRA gradients into
    /// `grad_lora`.
    ///
    /// # Errors
    /// Returns `KernelError` for FP32 blocks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn backward_nf4(
        &self,
        layer_input: &GpuBuffer<f32>,
        grad_output: &GpuBuffer<f32>,
        grad_input: &mut GpuBuffer<f32>,
        output_scratch: &mut GpuBuffer<f32>,
        seq_len: usize,
        stream: &CudaStream,
        shared_scratch: &mut CudaBlockScratch,
        grad_lora: &mut CudaLoraGradWorkspace,
    ) -> Result<()> {
        let CudaBlock::Nf4(block) = self else {
            return Self::unsupported("backward_nf4 only supported on NF4 blocks");
        };
        block.backward(
            layer_input,
            grad_output,
            grad_input,
            output_scratch,
            seq_len,
            stream,
            shared_scratch,
            grad_lora,
        )
    }

    /// Allocates the LoRA optimizer state for an NF4 block.
    ///
    /// # Errors
    /// Returns `KernelError` for FP32 blocks.
    pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
        let CudaBlock::Nf4(block) = self else {
            return Self::unsupported("init_lora_optimizer_state only supported on NF4 blocks");
        };
        block.init_lora_optimizer_state()
    }

    /// Applies one optimizer step to the LoRA adapters of an NF4 block.
    ///
    /// # Errors
    /// Returns `KernelError` for FP32 blocks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn lora_optimizer_step(
        &mut self,
        state: &mut GpuLoraOptimizerState,
        step: u32,
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        stream: &CudaStream,
        grad_lora: &CudaLoraGradWorkspace,
    ) -> Result<()> {
        let CudaBlock::Nf4(block) = self else {
            return Self::unsupported("lora_optimizer_step only supported on NF4 blocks");
        };
        block.lora_optimizer_step(
            state,
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            stream,
            grad_lora,
        )
    }

    /// Downloads the four LoRA matrices of an NF4 block to the host.
    ///
    /// # Errors
    /// Returns `KernelError` for FP32 blocks.
    pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
        let CudaBlock::Nf4(block) = self else {
            return Self::unsupported("download_lora_weights only supported on NF4 blocks");
        };
        block.download_lora_weights()
    }

    /// Uploads the four LoRA matrices (`a_q`, `b_q`, `a_v`, `b_v`) into an
    /// NF4 block.
    ///
    /// # Errors
    /// Returns `KernelError` for FP32 blocks.
    pub fn upload_lora_weights(
        &mut self,
        a_q: &[f32],
        b_q: &[f32],
        a_v: &[f32],
        b_v: &[f32],
    ) -> Result<()> {
        let CudaBlock::Nf4(block) = self else {
            return Self::unsupported("upload_lora_weights only supported on NF4 blocks");
        };
        block.upload_lora_weights(a_q, b_q, a_v, b_v)
    }
}
2683
2684#[cfg(not(feature = "cuda"))]
2686pub enum CudaBlock {
2687 Fp32(CudaTransformerBlock),
2688}
2689
/// Transformer block whose projection weights are stored NF4-quantized on the
/// GPU, with optional LoRA adapters on the Q and V projections.
///
/// Field contents below are described as the constructor (`new`) sets them.
#[cfg(feature = "cuda")]
pub struct CudaNf4TransformerBlock {
    /// Copy of the model configuration passed to the constructor.
    config: TransformerConfig,
    /// Index of this layer within the model.
    layer_idx: usize,

    /// Input norm weight, uploaded unchanged.
    input_norm_weight: GpuBuffer<f32>,
    /// Post-attention norm weight, uploaded unchanged.
    post_attn_norm_weight: GpuBuffer<f32>,

    // NF4 data and the matching scales for each projection, as produced by
    // `quantize_nf4`.
    w_q_nf4: GpuBuffer<u8>,
    w_q_scales: GpuBuffer<f32>,
    w_k_nf4: GpuBuffer<u8>,
    w_k_scales: GpuBuffer<f32>,
    w_v_nf4: GpuBuffer<u8>,
    w_v_scales: GpuBuffer<f32>,
    w_o_nf4: GpuBuffer<u8>,
    w_o_scales: GpuBuffer<f32>,
    w_gate_nf4: GpuBuffer<u8>,
    w_gate_scales: GpuBuffer<f32>,
    w_up_nf4: GpuBuffer<u8>,
    w_up_scales: GpuBuffer<f32>,
    w_down_nf4: GpuBuffer<u8>,
    w_down_scales: GpuBuffer<f32>,

    // Dequantized copies of the NF4 weights, transposed on the host before
    // upload.
    w_q_fp32: GpuBuffer<f32>,
    w_k_fp32: GpuBuffer<f32>,
    w_v_fp32: GpuBuffer<f32>,
    w_o_fp32: GpuBuffer<f32>,
    w_gate_fp32: GpuBuffer<f32>,
    w_up_fp32: GpuBuffer<f32>,
    w_down_fp32: GpuBuffer<f32>,

    /// LoRA `A` matrix for the Q projection, if an adapter was supplied.
    lora_a_q: Option<GpuBuffer<f32>>,
    /// LoRA `B` matrix for the Q projection, pre-multiplied by `lora_scale`.
    lora_b_q: Option<GpuBuffer<f32>>,
    /// LoRA `A` matrix for the V projection, if an adapter was supplied.
    lora_a_v: Option<GpuBuffer<f32>>,
    /// LoRA `B` matrix for the V projection, pre-multiplied by `lora_scale`.
    lora_b_v: Option<GpuBuffer<f32>>,
    /// Scale factor that was folded into the `B` matrices at construction.
    lora_scale: f32,
    /// LoRA rank supplied by the caller.
    lora_rank: usize,

    /// Optional Q norm weight; must have `head_dim` elements.
    q_norm_weight: Option<GpuBuffer<f32>>,
    /// Optional K norm weight; must have `head_dim` elements.
    k_norm_weight: Option<GpuBuffer<f32>>,

    // Optional half-precision weight copies stored as raw `u16`; all `None`
    // after construction (presumably filled in by `set_fp16_weights` — confirm).
    w_q_fp16: Option<GpuBuffer<u16>>,
    w_k_fp16: Option<GpuBuffer<u16>>,
    w_v_fp16: Option<GpuBuffer<u16>>,
    w_o_fp16: Option<GpuBuffer<u16>>,
    w_gate_fp16: Option<GpuBuffer<u16>>,
    w_up_fp16: Option<GpuBuffer<u16>>,
    w_down_fp16: Option<GpuBuffer<u16>>,

    /// CUDA context used for buffer allocation.
    ctx: Arc<CudaContext>,
}
2763
2764#[cfg(feature = "cuda")]
2765impl CudaNf4TransformerBlock {
2766 #[allow(clippy::too_many_arguments)]
2771 pub fn new(
2772 config: &TransformerConfig,
2773 layer_idx: usize,
2774 ctx: Arc<CudaContext>,
2775 input_norm_weight: &[f32],
2776 post_attn_norm_weight: &[f32],
2777 w_q: &[f32],
2778 w_k: &[f32],
2779 w_v: &[f32],
2780 w_o: &[f32],
2781 w_gate: &[f32],
2782 w_up: &[f32],
2783 w_down: &[f32],
2784 _max_seq_len: usize, q_lora: Option<(&[f32], &[f32])>,
2787 v_lora: Option<(&[f32], &[f32])>,
2788 lora_scale: f32,
2789 lora_rank: usize,
2790 q_norm: Option<&[f32]>,
2792 k_norm: Option<&[f32]>,
2793 ) -> Result<Self> {
2794 use trueno_gpu::kernels::{quantize_nf4, NF4_BLOCK_SIZE};
2795
2796 let hidden_size = config.hidden_size;
2797 let q_dim = config.q_dim(); let kv_hidden_size = config.num_kv_heads * config.head_dim();
2799 let intermediate_size = config.intermediate_size;
2800
2801 assert_eq!(
2806 w_q.len(),
2807 q_dim * hidden_size,
2808 "C-NF4SHAPE-001: w_q expected {}, got {} (q_dim={q_dim}, hidden={hidden_size})",
2809 q_dim * hidden_size,
2810 w_q.len()
2811 );
2812 assert_eq!(
2813 w_k.len(),
2814 kv_hidden_size * hidden_size,
2815 "C-NF4SHAPE-001: w_k expected {}, got {}",
2816 kv_hidden_size * hidden_size,
2817 w_k.len()
2818 );
2819 assert_eq!(
2820 w_v.len(),
2821 kv_hidden_size * hidden_size,
2822 "C-NF4SHAPE-001: w_v expected {}, got {}",
2823 kv_hidden_size * hidden_size,
2824 w_v.len()
2825 );
2826 assert_eq!(
2827 w_o.len(),
2828 hidden_size * q_dim,
2829 "C-NF4SHAPE-001: w_o expected {}, got {}",
2830 hidden_size * q_dim,
2831 w_o.len()
2832 );
2833 assert_eq!(
2834 w_gate.len(),
2835 intermediate_size * hidden_size,
2836 "C-NF4SHAPE-001: w_gate expected {}, got {}",
2837 intermediate_size * hidden_size,
2838 w_gate.len()
2839 );
2840 assert_eq!(
2841 w_up.len(),
2842 intermediate_size * hidden_size,
2843 "C-NF4SHAPE-001: w_up expected {}, got {}",
2844 intermediate_size * hidden_size,
2845 w_up.len()
2846 );
2847 assert_eq!(
2848 w_down.len(),
2849 hidden_size * intermediate_size,
2850 "C-NF4SHAPE-001: w_down expected {}, got {}",
2851 hidden_size * intermediate_size,
2852 w_down.len()
2853 );
2854
2855 let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
2857 let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
2858
2859 let quantize_and_upload = |weights: &[f32],
2863 total: usize|
2864 -> Result<(
2865 GpuBuffer<u8>,
2866 GpuBuffer<f32>,
2867 trueno_gpu::kernels::Nf4Quantized,
2868 )> {
2869 assert_eq!(weights.len(), total, "weight length mismatch");
2870 assert!(
2871 total.is_multiple_of(NF4_BLOCK_SIZE),
2872 "weight count {total} not divisible by NF4 block size {NF4_BLOCK_SIZE}"
2873 );
2874
2875 let q = quantize_nf4(weights, total / NF4_BLOCK_SIZE, NF4_BLOCK_SIZE);
2876 let nf4_buf = GpuBuffer::from_host(&ctx, &q.data)?;
2877 let scales_buf = GpuBuffer::from_host(&ctx, &q.scales)?;
2878 Ok((nf4_buf, scales_buf, q))
2879 };
2880
2881 let (w_q_nf4, w_q_scales, w_q_nf4_q) = quantize_and_upload(w_q, q_dim * hidden_size)?;
2883 let (w_k_nf4, w_k_scales, w_k_nf4_q) =
2884 quantize_and_upload(w_k, kv_hidden_size * hidden_size)?;
2885 let (w_v_nf4, w_v_scales, w_v_nf4_q) =
2886 quantize_and_upload(w_v, kv_hidden_size * hidden_size)?;
2887 let (w_o_nf4, w_o_scales, w_o_nf4_q) = quantize_and_upload(w_o, hidden_size * q_dim)?;
2888 let (w_gate_nf4, w_gate_scales, w_gate_nf4_q) =
2889 quantize_and_upload(w_gate, intermediate_size * hidden_size)?;
2890 let (w_up_nf4, w_up_scales, w_up_nf4_q) =
2891 quantize_and_upload(w_up, intermediate_size * hidden_size)?;
2892 let (w_down_nf4, w_down_scales, w_down_nf4_q) =
2893 quantize_and_upload(w_down, hidden_size * intermediate_size)?;
2894
2895 use trueno_gpu::kernels::dequantize_nf4;
2907 let dequant_transpose_upload = |q: &trueno_gpu::kernels::Nf4Quantized,
2908 n: usize,
2909 k: usize|
2910 -> std::result::Result<
2911 GpuBuffer<f32>,
2912 crate::autograd::cuda_tensor::CudaTensorError,
2913 > {
2914 let deq = dequantize_nf4(q); let nonzero = deq.iter().filter(|&&x| x != 0.0).count();
2916 eprintln!(
2917 "[TRACE] dequant n={n} k={k} len={} nonzero={nonzero} first5={:?}",
2918 deq.len(),
2919 &deq[..5.min(deq.len())]
2920 );
2921 assert_eq!(deq.len(), n * k, "dequant size mismatch: {} vs {}x{}", deq.len(), n, k);
2922 let mut transposed = vec![0.0f32; n * k];
2924 for row in 0..n {
2925 for col in 0..k {
2926 transposed[col * n + row] = deq[row * k + col];
2927 }
2928 }
2929 let buf = GpuBuffer::from_host(&ctx, &transposed).map_err(|e| {
2930 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
2931 "dequant transpose upload: {e:?}"
2932 ))
2933 })?;
2934 let mut verify_full = vec![0.0f32; buf.len()];
2936 let verify_ok = buf.copy_to_host(&mut verify_full).is_ok();
2937 let verify5: Vec<f32> = verify_full.iter().copied().take(5).collect();
2938 let nz = verify_full.iter().filter(|&&x| x != 0.0).count();
2939 eprintln!("[TRACE] uploaded ptr={:?} len={} copy_ok={verify_ok} nonzero={nz} verify[:5]={verify5:?}", buf.as_ptr(), buf.len());
2940 Ok(buf)
2941 };
2942 let w_q_fp32 = dequant_transpose_upload(&w_q_nf4_q, q_dim, hidden_size)?;
2945 let w_k_fp32 = dequant_transpose_upload(&w_k_nf4_q, kv_hidden_size, hidden_size)?;
2946 let w_v_fp32 = dequant_transpose_upload(&w_v_nf4_q, kv_hidden_size, hidden_size)?;
2947 let w_o_fp32 = dequant_transpose_upload(&w_o_nf4_q, hidden_size, q_dim)?;
2948 let w_gate_fp32 = dequant_transpose_upload(&w_gate_nf4_q, intermediate_size, hidden_size)?;
2949 let w_up_fp32 = dequant_transpose_upload(&w_up_nf4_q, intermediate_size, hidden_size)?;
2950 let w_down_fp32 = dequant_transpose_upload(&w_down_nf4_q, hidden_size, intermediate_size)?;
2951
2952 let (lora_a_q, lora_b_q) = match q_lora {
2959 Some((a_data, b_data)) => {
2960 let a = GpuBuffer::from_host(&ctx, a_data)?;
2961 let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
2962 let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
2963 (Some(a), Some(b))
2964 }
2965 None => (None, None),
2966 };
2967 let (lora_a_v, lora_b_v) = match v_lora {
2968 Some((a_data, b_data)) => {
2969 let a = GpuBuffer::from_host(&ctx, a_data)?;
2970 let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
2971 let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
2972 (Some(a), Some(b))
2973 }
2974 None => (None, None),
2975 };
2976
2977 let q_norm_weight = match q_norm {
2979 Some(w) => {
2980 assert_eq!(
2981 w.len(),
2982 config.head_dim(),
2983 "ENT-270: q_norm weight expected [head_dim={}], got [{}]",
2984 config.head_dim(),
2985 w.len()
2986 );
2987 Some(GpuBuffer::from_host(&ctx, w)?)
2988 }
2989 None => None,
2990 };
2991 let k_norm_weight = match k_norm {
2992 Some(w) => {
2993 assert_eq!(
2994 w.len(),
2995 config.head_dim(),
2996 "ENT-270: k_norm weight expected [head_dim={}], got [{}]",
2997 config.head_dim(),
2998 w.len()
2999 );
3000 Some(GpuBuffer::from_host(&ctx, w)?)
3001 }
3002 None => None,
3003 };
3004
3005 Ok(Self {
3006 config: config.clone(),
3007 layer_idx,
3008 input_norm_weight,
3009 post_attn_norm_weight,
3010 w_q_nf4,
3011 w_q_scales,
3012 w_k_nf4,
3013 w_k_scales,
3014 w_v_nf4,
3015 w_v_scales,
3016 w_o_nf4,
3017 w_o_scales,
3018 w_gate_nf4,
3019 w_gate_scales,
3020 w_up_nf4,
3021 w_up_scales,
3022 w_down_nf4,
3023 w_down_scales,
3024 w_q_fp32,
3025 w_k_fp32,
3026 w_v_fp32,
3027 w_o_fp32,
3028 w_gate_fp32,
3029 w_up_fp32,
3030 w_down_fp32,
3031 lora_a_q,
3032 lora_b_q,
3033 lora_a_v,
3034 lora_b_v,
3035 lora_scale,
3036 lora_rank,
3037 q_norm_weight,
3038 k_norm_weight,
3039 w_q_fp16: None,
3041 w_k_fp16: None,
3042 w_v_fp16: None,
3043 w_o_fp16: None,
3044 w_gate_fp16: None,
3045 w_up_fp16: None,
3046 w_down_fp16: None,
3047 ctx,
3048 })
3049 }
3050
3051 pub fn set_fp16_weights(&mut self, stream: &CudaStream) -> Result<()> {
3053 let cast_weight = |w_fp32: &GpuBuffer<f32>, ctx: &CudaContext| -> Result<GpuBuffer<u16>> {
3054 let n = w_fp32.len();
3055 let mut w_fp16 = GpuBuffer::<u16>::new(ctx, n)?;
3056 cast_f32_to_f16_gpu(w_fp32, &mut w_fp16, n as u32, stream)?;
3057 Ok(w_fp16)
3058 };
3059
3060 self.w_q_fp16 = Some(cast_weight(&self.w_q_fp32, &self.ctx)?);
3061 self.w_k_fp16 = Some(cast_weight(&self.w_k_fp32, &self.ctx)?);
3062 self.w_v_fp16 = Some(cast_weight(&self.w_v_fp32, &self.ctx)?);
3063 self.w_o_fp16 = Some(cast_weight(&self.w_o_fp32, &self.ctx)?);
3064 self.w_gate_fp16 = Some(cast_weight(&self.w_gate_fp32, &self.ctx)?);
3065 self.w_up_fp16 = Some(cast_weight(&self.w_up_fp32, &self.ctx)?);
3066 self.w_down_fp16 = Some(cast_weight(&self.w_down_fp32, &self.ctx)?);
3067
3068 stream.synchronize().map_err(|e| {
3069 crate::autograd::cuda_tensor::CudaTensorError::KernelError(format!(
3070 "FP16 weight cast sync failed: {e:?}"
3071 ))
3072 })?;
3073 let dummy = |ctx: &CudaContext| GpuBuffer::<f32>::new(ctx, 1).unwrap();
3076 self.w_q_fp32 = dummy(&self.ctx);
3077 self.w_k_fp32 = dummy(&self.ctx);
3078 self.w_v_fp32 = dummy(&self.ctx);
3079 self.w_o_fp32 = dummy(&self.ctx);
3080 self.w_gate_fp32 = dummy(&self.ctx);
3081 self.w_up_fp32 = dummy(&self.ctx);
3082 self.w_down_fp32 = dummy(&self.ctx);
3083 eprintln!("[FP16] Weights cast + fp32 dropped (~2.6 GB freed)");
3084
3085 Ok(())
3086 }
3087
3088 #[rustfmt::skip]
3090 pub(crate) fn forward(
3091 &self,
3092 input: &GpuBuffer<f32>,
3093 output: &mut GpuBuffer<f32>,
3094 seq_len: usize,
3095 stream: &CudaStream,
3096 scratch: &mut CudaBlockScratch,
3097 ) -> Result<()> {
3098 use crate::autograd::cuda_forward::{gemm_forward, gemm_nf4_forward, gemm_nf4_tc_forward};
3099
3100 let hidden_size = self.config.hidden_size;
3101 let q_dim = self.config.q_dim();
3102 let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3103 let intermediate_size = self.config.intermediate_size;
3104
3105 scratch.prepare_causal_mask(seq_len, &self.ctx)?;
3107
3108 let _t = scratch.op_begin();
3112 rms_norm_forward_with_eps(
3113 input,
3114 &self.input_norm_weight,
3115 &mut scratch.norm1_out,
3116 saturating_u32(seq_len),
3117 saturating_u32(hidden_size),
3118 self.config.rms_norm_eps,
3119 stream,
3120 )?;
3121 scratch.op_end(_t, OP_RMSNORM_ATTN);
3122
3123 static USE_NF4_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3129 let nf4_gemm = *USE_NF4_GEMM.get_or_init(|| std::env::var("NF4_FUSED_GEMM").as_deref() == Ok("1"));
3130 static USE_NF4_TC_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3131 let nf4_tc_gemm = *USE_NF4_TC_GEMM.get_or_init(|| std::env::var("NF4_TC_GEMM").as_deref() == Ok("1"));
3132 static USE_FP16_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3133 let fp16_gemm = *USE_FP16_GEMM.get_or_init(|| std::env::var("FP16_GEMM").as_deref() == Ok("1"));
3134
3135 let act_n = (seq_len * hidden_size) as u32;
3137 if fp16_gemm && self.w_q_fp16.is_some() {
3138 if scratch.norm1_out_f16.is_none() {
3140 scratch.norm1_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
3141 }
3142 let f16_buf = scratch.norm1_out_f16.as_mut().unwrap();
3143 cast_f32_to_f16_gpu(&scratch.norm1_out, f16_buf, act_n, stream)?;
3144 }
3145
3146 let _t = scratch.op_begin(); if fp16_gemm && self.w_q_fp16.is_some() {
3148 let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
3149 gemm_f16_to_f32_forward(f16_act, self.w_q_fp16.as_ref().unwrap(), &mut scratch.q,
3150 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3151 } else if nf4_tc_gemm {
3152 gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
3153 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3154 } else if nf4_gemm {
3155 gemm_nf4_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
3156 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3157 } else {
3158 gemm_forward(&scratch.norm1_out, &self.w_q_fp32, &mut scratch.q,
3159 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
3160 }
3161
3162 if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
3164 let s = saturating_u32(seq_len);
3165 let h = saturating_u32(hidden_size);
3166 let r = saturating_u32(self.lora_rank);
3167 let qd = saturating_u32(q_dim);
3168 gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
3170 gemm_forward(&scratch.lora_inter, b_q, &mut scratch.lora_temp, s, r, qd, stream)?;
3172 cuda_add_inplace(&mut scratch.q, &scratch.lora_temp, seq_len * q_dim, stream)?;
3174 }
3175
3176 if fp16_gemm && self.w_k_fp16.is_some() {
3177 let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
3178 gemm_f16_to_f32_forward(f16_act, self.w_k_fp16.as_ref().unwrap(), &mut scratch.k,
3179 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3180 gemm_f16_to_f32_forward(f16_act, self.w_v_fp16.as_ref().unwrap(), &mut scratch.v,
3181 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3182 } else if nf4_tc_gemm {
3183 gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_k_nf4, &self.w_k_scales, &mut scratch.k,
3185 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3186 gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_v_nf4, &self.w_v_scales, &mut scratch.v,
3187 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3188 } else if nf4_gemm {
3189 crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
3191 &scratch.norm1_out,
3192 &self.w_k_nf4, &self.w_k_scales,
3193 &self.w_v_nf4, &self.w_v_scales,
3194 &mut scratch.k, &mut scratch.v,
3195 saturating_u32(seq_len), saturating_u32(hidden_size),
3196 saturating_u32(kv_hidden_size), stream,
3197 )?;
3198 } else {
3199 gemm_forward(&scratch.norm1_out, &self.w_k_fp32, &mut scratch.k,
3200 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3201 gemm_forward(&scratch.norm1_out, &self.w_v_fp32, &mut scratch.v,
3202 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
3203 }
3204
3205 scratch.op_end(_t, OP_QKV_GEMM); if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
3209 let s = saturating_u32(seq_len);
3210 let h = saturating_u32(hidden_size);
3211 let r = saturating_u32(self.lora_rank);
3212 let vd = saturating_u32(kv_hidden_size);
3213 gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
3215 gemm_forward(&scratch.lora_inter, b_v, &mut scratch.lora_temp, s, r, vd, stream)?;
3217 cuda_add_inplace(&mut scratch.v, &scratch.lora_temp, seq_len * kv_hidden_size, stream)?;
3219 }
3220
3221 let _t = scratch.op_begin();
3223 self.compute_attention_cuda(seq_len, stream, scratch)?;
3224 scratch.op_end(_t, OP_ATTENTION);
3225
3226 let _t = scratch.op_begin();
3228 if fp16_gemm && self.w_o_fp16.is_some() {
3229 if scratch.attn_out_f16.is_none() {
3230 scratch.attn_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * q_dim)?);
3231 }
3232 let f16_buf = scratch.attn_out_f16.as_mut().unwrap();
3233 cast_f32_to_f16_gpu(&scratch.attn_out, f16_buf, (seq_len * q_dim) as u32, stream)?;
3234 gemm_f16_to_f32_forward(f16_buf, self.w_o_fp16.as_ref().unwrap(), &mut scratch.o_proj_out,
3235 saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3236 } else if nf4_tc_gemm {
3237 gemm_nf4_tc_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
3238 saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3239 } else if nf4_gemm {
3240 gemm_nf4_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
3241 saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3242 } else {
3243 gemm_forward(&scratch.attn_out, &self.w_o_fp32, &mut scratch.o_proj_out,
3244 saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
3245 }
3246
3247 scratch.op_end(_t, OP_O_PROJ);
3248
3249 let _t = scratch.op_begin();
3251 fused_residual_rmsnorm_forward(
3255 input,
3256 &scratch.o_proj_out,
3257 &mut scratch.residual1,
3258 &mut scratch.norm2_out,
3259 &self.post_attn_norm_weight,
3260 saturating_u32(seq_len),
3261 saturating_u32(hidden_size),
3262 stream,
3263 )?;
3264
3265 scratch.op_end(_t, OP_RMSNORM_FFN); let _t = scratch.op_begin(); if fp16_gemm && self.w_gate_fp16.is_some() {
3270 if scratch.norm2_out_f16.is_none() {
3271 scratch.norm2_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
3272 }
3273 let f16_buf = scratch.norm2_out_f16.as_mut().unwrap();
3274 cast_f32_to_f16_gpu(&scratch.norm2_out, f16_buf, (seq_len * hidden_size) as u32, stream)?;
3275 gemm_f16_to_f32_forward(f16_buf, self.w_gate_fp16.as_ref().unwrap(), &mut scratch.gate_out,
3276 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3277 gemm_f16_to_f32_forward(f16_buf, self.w_up_fp16.as_ref().unwrap(), &mut scratch.up_out,
3278 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3279 } else if nf4_tc_gemm {
3280 gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_gate_nf4, &self.w_gate_scales, &mut scratch.gate_out,
3282 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3283 gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_up_nf4, &self.w_up_scales, &mut scratch.up_out,
3284 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3285 } else if nf4_gemm {
3286 crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
3288 &scratch.norm2_out,
3289 &self.w_gate_nf4, &self.w_gate_scales,
3290 &self.w_up_nf4, &self.w_up_scales,
3291 &mut scratch.gate_out, &mut scratch.up_out,
3292 saturating_u32(seq_len), saturating_u32(hidden_size),
3293 saturating_u32(intermediate_size), stream,
3294 )?;
3295 } else {
3296 gemm_forward(&scratch.norm2_out, &self.w_gate_fp32, &mut scratch.gate_out,
3297 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3298 gemm_forward(&scratch.norm2_out, &self.w_up_fp32, &mut scratch.up_out,
3299 saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
3300 }
3301
3302 scratch.op_end(_t, OP_GATE_UP_GEMM);
3303
3304 let _t = scratch.op_begin();
3306 fused_swiglu_forward(&scratch.gate_out, &scratch.up_out, &mut scratch.swiglu_out,
3307 saturating_u32(seq_len * intermediate_size), stream)?;
3308 scratch.op_end(_t, OP_SILU);
3309
3310 let _t = scratch.op_begin();
3312 if fp16_gemm && self.w_down_fp16.is_some() {
3313 if scratch.swiglu_out_f16.is_none() {
3314 scratch.swiglu_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * intermediate_size)?);
3315 }
3316 let f16_buf = scratch.swiglu_out_f16.as_mut().unwrap();
3317 cast_f32_to_f16_gpu(&scratch.swiglu_out, f16_buf, (seq_len * intermediate_size) as u32, stream)?;
3318 gemm_f16_to_f32_forward(f16_buf, self.w_down_fp16.as_ref().unwrap(), &mut scratch.ffn_out,
3319 saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3320 } else if nf4_tc_gemm {
3321 gemm_nf4_tc_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
3322 saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3323 } else if nf4_gemm {
3324 gemm_nf4_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
3325 saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3326 } else {
3327 gemm_forward(&scratch.swiglu_out, &self.w_down_fp32, &mut scratch.ffn_out,
3328 saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
3329 }
3330
3331 scratch.op_end(_t, OP_DOWN_GEMM);
3332
3333 cuda_add(&scratch.residual1, &scratch.ffn_out, output, seq_len * hidden_size, stream)?;
3335
3336 Ok(())
3337 }
3338
    /// Returns the `layer_idx` value stored on this block when it was constructed.
    pub fn layer_idx(&self) -> usize {
        self.layer_idx
    }
3343}
3344
#[cfg(feature = "cuda")]
impl CudaNf4TransformerBlock {
    /// Computes attention from `scratch.q` / `scratch.k` / `scratch.v` and
    /// writes the result to `scratch.attn_out`.
    ///
    /// Steps, all enqueued on `stream`:
    /// 1. Optional per-head RMSNorm of Q and K (when `q_norm_weight` /
    ///    `k_norm_weight` are set), one launch per sequence position.
    /// 2. RoPE applied in place to Q and K.
    /// 3. Q and K converted from interleaved to batched layout; K heads are
    ///    expanded when `num_attention_heads / num_kv_heads > 1`, then K is
    ///    transposed.
    /// 4. Scores = Q·Kᵀ, scaled by `1 / sqrt(head_dim)`, with
    ///    `scratch.causal_mask_contiguous` added per head, then a batched
    ///    softmax.
    /// 5. V converted/expanded the same way, scores·V computed, and the result
    ///    converted back to interleaved layout in `scratch.attn_out`.
    ///
    /// `scratch.attn_q_batched`, `attn_kv_temp`, `attn_kv_temp2` and
    /// `attn_scores` are used as working buffers and are overwritten.
    ///
    /// # Errors
    /// Propagates the first kernel-launch or copy error.
    fn compute_attention_cuda(
        &self,
        seq_len: usize,
        stream: &CudaStream,
        scratch: &mut CudaBlockScratch,
    ) -> Result<()> {
        use std::mem::ManuallyDrop;

        let num_heads = self.config.num_attention_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let head_dim = self.config.head_dim();
        let heads_per_kv = num_heads / num_kv_heads;

        let s = saturating_u32(seq_len);
        let nh = saturating_u32(num_heads);
        let nkv = saturating_u32(num_kv_heads);
        let hd = saturating_u32(head_dim);

        // Optional per-head RMSNorm on Q and K, applied in place.
        if let Some(ref q_norm) = self.q_norm_weight {
            for pos in 0..seq_len {
                // SAFETY: NOTE(review): this shared reference aliases the
                // `&mut scratch.q` passed as the output. It relies on the callee
                // only using the buffers as device-pointer handles — confirm.
                let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
                per_head_rmsnorm_forward(q_ref, q_norm, &mut scratch.q, nh, hd, pos, stream)?;
            }
        }
        if let Some(ref k_norm) = self.k_norm_weight {
            for pos in 0..seq_len {
                // SAFETY: same in-place aliasing as for Q above.
                let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
                per_head_rmsnorm_forward(k_ref, k_norm, &mut scratch.k, nkv, hd, pos, stream)?;
            }
        }

        // RoPE, in place on Q and K.
        let rope_theta = self.config.rope_theta;
        {
            // SAFETY: same in-place aliasing as above (input and output are
            // the same buffer).
            let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
            batched_rope_neox_forward(
                q_ref,
                &mut scratch.q,
                &scratch.rope_positions,
                nh,
                hd,
                s,
                rope_theta,
                stream,
            )?;
            // SAFETY: as for Q.
            let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
            batched_rope_neox_forward(
                k_ref,
                &mut scratch.k,
                &scratch.rope_positions,
                nkv,
                hd,
                s,
                rope_theta,
                stream,
            )?;
        }

        // Q and K into batched layout.
        interleaved_to_batched_forward(&scratch.q, &mut scratch.attn_q_batched, s, nh, hd, stream)?;
        interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;

        // Bring K up to `num_heads` heads in attn_kv_temp2.
        if heads_per_kv > 1 {
            expand_kv_heads(
                &scratch.attn_kv_temp,
                &mut scratch.attn_kv_temp2,
                num_kv_heads,
                heads_per_kv,
                seq_len * head_dim,
                stream,
            )?;
        } else {
            // SAFETY: NOTE(review): the contract of `copy_from_buffer_async`
            // is not visible here; presumably both buffers must stay alive
            // until the stream has executed the copy — confirm.
            unsafe {
                scratch
                    .attn_kv_temp2
                    .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
                    .map_err(|e| {
                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                            "K copy failed: {e:?}"
                        ))
                    })?;
            }
        }

        // Transposed K goes back into attn_kv_temp.
        batched_transpose_forward(
            &scratch.attn_kv_temp2,
            &mut scratch.attn_kv_temp,
            nh,
            s,
            hd,
            stream,
        )?;

        // scores = Q · Kᵀ
        batched_4d_gemm_forward(
            &scratch.attn_q_batched,
            &scratch.attn_kv_temp,
            &mut scratch.attn_scores,
            1,
            nh,
            s,
            s,
            hd,
            stream,
        )?;

        // Non-owning alias of the scores buffer, used as the input side of the
        // in-place scale and softmax kernels. It is wrapped in `ManuallyDrop`
        // at creation so it is never dropped — including on the `?` error
        // paths below — and only `scratch` ever releases the allocation.
        // SAFETY: pointer and length come from the live `scratch.attn_scores`,
        // which outlives this alias.
        let scores_alias = ManuallyDrop::new(unsafe {
            GpuBuffer::<f32>::from_raw_parts(
                scratch.attn_scores.as_ptr(),
                scratch.attn_scores.len(),
            )
        });

        // scores *= 1 / sqrt(head_dim)
        let scale_factor = 1.0 / (head_dim as f32).sqrt();
        let total_scores = num_heads * seq_len * seq_len;
        scale_forward(
            &*scores_alias,
            &mut scratch.attn_scores,
            scale_factor,
            saturating_u32(total_scores),
            stream,
        )?;

        // Add the causal mask to each head's [seq_len x seq_len] score slice.
        {
            let seq_sq = seq_len * seq_len;
            let elem_bytes = std::mem::size_of::<f32>();
            let scores_base = scratch.attn_scores.as_ptr();
            // The mask view is identical for every head, so build it once.
            // SAFETY: aliases the live `scratch.causal_mask_contiguous`; never
            // dropped thanks to `ManuallyDrop`.
            // NOTE(review): assumes the mask holds at least `seq_sq` elements
            // after `prepare_causal_mask(seq_len, ..)` — confirm.
            let mask_view = ManuallyDrop::new(unsafe {
                GpuBuffer::<f32>::from_raw_parts(scratch.causal_mask_contiguous.as_ptr(), seq_sq)
            });
            for head in 0..num_heads {
                let byte_offset = (head * seq_sq * elem_bytes) as u64;
                let head_ptr = scores_base + byte_offset;
                // SAFETY: `head_ptr` addresses `seq_sq` elements inside
                // `scratch.attn_scores` (which the GEMM above filled with
                // `num_heads * seq_sq` scores); the views are never dropped.
                let head_in = ManuallyDrop::new(unsafe {
                    GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq)
                });
                let mut head_out = ManuallyDrop::new(unsafe {
                    GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq)
                });
                residual_add_forward(
                    &*mask_view,
                    &*head_in,
                    &mut *head_out,
                    saturating_u32(seq_sq),
                    stream,
                )?;
            }
        }

        // Row-wise softmax over the scores, in place.
        batched_softmax_forward(
            &*scores_alias,
            &mut scratch.attn_scores,
            saturating_u32(num_heads * seq_len),
            s,
            stream,
        )?;

        // V into batched layout, expanded to `num_heads` heads in attn_kv_temp2.
        interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;

        if heads_per_kv > 1 {
            expand_kv_heads(
                &scratch.attn_kv_temp,
                &mut scratch.attn_kv_temp2,
                num_kv_heads,
                heads_per_kv,
                seq_len * head_dim,
                stream,
            )?;
        } else {
            // SAFETY: see the K copy above.
            unsafe {
                scratch
                    .attn_kv_temp2
                    .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
                    .map_err(|e| {
                        crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
                            "V copy failed: {e:?}"
                        ))
                    })?;
            }
        }

        // attn = softmax(scores) · V, written over attn_q_batched.
        batched_4d_gemm_forward(
            &scratch.attn_scores,
            &scratch.attn_kv_temp2,
            &mut scratch.attn_q_batched,
            1,
            nh,
            s,
            hd,
            s,
            stream,
        )?;

        // Back to interleaved layout for the output projection.
        batched_to_interleaved_forward(
            &scratch.attn_q_batched,
            &mut scratch.attn_out,
            s,
            nh,
            hd,
            stream,
        )?;

        Ok(())
    }
}
3581
/// GPU buffers that receive the gradients of one block's trainable
/// parameters: the Q and V LoRA matrices and the two RMSNorm weights.
///
/// Element counts below are those allocated by `CudaLoraGradWorkspace::new`.
#[cfg(feature = "cuda")]
pub(crate) struct CudaLoraGradWorkspace {
    /// Gradient for the Q LoRA A matrix; `hidden_size * lora_rank` elements.
    pub(crate) grad_lora_a_q: GpuBuffer<f32>,
    /// Gradient for the Q LoRA B matrix; `lora_rank * q_dim` elements.
    pub(crate) grad_lora_b_q: GpuBuffer<f32>,
    /// Gradient for the V LoRA A matrix; `hidden_size * lora_rank` elements.
    pub(crate) grad_lora_a_v: GpuBuffer<f32>,
    /// Gradient for the V LoRA B matrix; `lora_rank * (num_kv_heads * head_dim)` elements.
    pub(crate) grad_lora_b_v: GpuBuffer<f32>,
    /// Gradient for the input RMSNorm weight; `hidden_size` elements.
    pub(crate) grad_input_norm: GpuBuffer<f32>,
    /// Gradient for the post-attention RMSNorm weight; `hidden_size` elements.
    pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
}
3612
#[cfg(feature = "cuda")]
impl CudaLoraGradWorkspace {
    /// Allocates one gradient buffer per trainable tensor.
    ///
    /// Sizes are derived from `config` and `lora_rank`: `hidden_size * rank`
    /// for the two LoRA A matrices, `rank * q_dim` and
    /// `rank * (num_kv_heads * head_dim)` for the Q and V LoRA B matrices, and
    /// `hidden_size` for each norm weight.
    ///
    /// NOTE(review): buffers come from `GpuBuffer::new`; whether that
    /// zero-initializes device memory is not visible here — confirm that
    /// every buffer is written before it is read.
    ///
    /// # Errors
    /// Returns the first allocation error.
    pub(crate) fn new(
        ctx: &Arc<CudaContext>,
        config: &super::config::TransformerConfig,
        lora_rank: usize,
    ) -> Result<Self> {
        let h = config.hidden_size;
        let q_dim = config.q_dim();
        let kv = config.num_kv_heads * config.head_dim();
        let r = lora_rank;

        Ok(Self {
            grad_lora_a_q: GpuBuffer::new(ctx, h * r)?,
            grad_lora_b_q: GpuBuffer::new(ctx, r * q_dim)?,
            grad_lora_a_v: GpuBuffer::new(ctx, h * r)?,
            grad_lora_b_v: GpuBuffer::new(ctx, r * kv)?,
            grad_input_norm: GpuBuffer::new(ctx, h)?,
            grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
        })
    }

    /// Clips all six gradient buffers by their combined L2 norm.
    ///
    /// The squared sums of every buffer are added and square-rooted; if the
    /// result exceeds `max_norm`, each buffer is scaled by
    /// `max_norm / (total_norm + 1e-6)`.
    ///
    /// This is best-effort and returns nothing. A failed squared-sum counts
    /// as `0.0` toward the norm and a failed scaling launch is skipped; both
    /// are now reported on stderr instead of being silently discarded.
    pub(crate) fn clip_gradients(&mut self, max_norm: f32, stream: &CudaStream) {
        // Accumulate in the same order as the fields are declared.
        let mut sq_total = 0.0_f32;
        for (name, grad) in [
            ("grad_lora_a_q", &self.grad_lora_a_q),
            ("grad_lora_b_q", &self.grad_lora_b_q),
            ("grad_lora_a_v", &self.grad_lora_a_v),
            ("grad_lora_b_v", &self.grad_lora_b_v),
            ("grad_input_norm", &self.grad_input_norm),
            ("grad_post_attn_norm", &self.grad_post_attn_norm),
        ] {
            match squared_sum_cuda(grad, saturating_u32(grad.len()), stream) {
                Ok(sq) => sq_total += sq,
                Err(e) => eprintln!(
                    "[clip_gradients] squared sum of {name} failed, counted as 0: {e:?}"
                ),
            }
        }
        let total_norm = sq_total.sqrt();

        if total_norm <= max_norm {
            return;
        }

        // The small epsilon keeps the scale finite for very small norms.
        let clip_scale = max_norm / (total_norm + 1e-6);
        for (name, grad) in [
            ("grad_lora_a_q", &mut self.grad_lora_a_q),
            ("grad_lora_b_q", &mut self.grad_lora_b_q),
            ("grad_lora_a_v", &mut self.grad_lora_a_v),
            ("grad_lora_b_v", &mut self.grad_lora_b_v),
            ("grad_input_norm", &mut self.grad_input_norm),
            ("grad_post_attn_norm", &mut self.grad_post_attn_norm),
        ] {
            let n = saturating_u32(grad.len());
            if let Err(e) = gradient_clip_cuda(grad, clip_scale, n, stream) {
                eprintln!("[clip_gradients] scaling {name} failed, left unclipped: {e:?}");
            }
        }
    }
}
3686
/// Per-block optimizer state kept on the GPU: an `m_*` / `v_*` buffer pair for
/// each trainable tensor (Q/V LoRA A and B matrices and the two norm weights).
///
/// All buffers are created zero-filled by `GpuLoraOptimizerState::new`, with
/// the element counts noted on each pair.
/// NOTE(review): the `m_` / `v_` prefixes presumably denote first- and
/// second-moment estimates of an Adam-style optimizer — confirm against the
/// optimizer step that consumes them.
#[cfg(feature = "cuda")]
pub(crate) struct GpuLoraOptimizerState {
    // Q LoRA A: `hidden_size * lora_rank` elements each.
    m_lora_a_q: GpuBuffer<f32>,
    v_lora_a_q: GpuBuffer<f32>,
    // Q LoRA B: `lora_rank * q_dim` elements each.
    m_lora_b_q: GpuBuffer<f32>,
    v_lora_b_q: GpuBuffer<f32>,
    // V LoRA A: `hidden_size * lora_rank` elements each.
    m_lora_a_v: GpuBuffer<f32>,
    v_lora_a_v: GpuBuffer<f32>,
    // V LoRA B: `lora_rank * (num_kv_heads * head_dim)` elements each.
    m_lora_b_v: GpuBuffer<f32>,
    v_lora_b_v: GpuBuffer<f32>,
    // Input RMSNorm weight: `hidden_size` elements each.
    m_input_norm: GpuBuffer<f32>,
    v_input_norm: GpuBuffer<f32>,
    // Post-attention RMSNorm weight: `hidden_size` elements each.
    m_post_attn_norm: GpuBuffer<f32>,
    v_post_attn_norm: GpuBuffer<f32>,
}
3713
#[cfg(feature = "cuda")]
impl GpuLoraOptimizerState {
    /// Builds the optimizer state for one block with every buffer zero-filled.
    ///
    /// Each buffer is uploaded from a host vector of zeros, so its contents
    /// are defined from the start. Buffer lengths follow the shapes of the
    /// tensors they track (see the struct fields).
    ///
    /// # Errors
    /// Returns the first upload/allocation error.
    fn new(
        ctx: &Arc<CudaContext>,
        config: &super::config::TransformerConfig,
        lora_rank: usize,
    ) -> Result<Self> {
        let hidden = config.hidden_size;
        let kv_dim = config.num_kv_heads * config.head_dim();

        // Element counts of the tracked tensors.
        let lora_a_len = hidden * lora_rank;
        let lora_b_q_len = lora_rank * config.q_dim();
        let lora_b_v_len = lora_rank * kv_dim;

        // Uploads `len` zeros to a new device buffer.
        let zeros = |len: usize| -> Result<GpuBuffer<f32>> {
            let host = vec![0.0f32; len];
            let device = GpuBuffer::from_host(ctx, &host)?;
            Ok(device)
        };

        // Allocation order matches the field order: m then v for each tensor.
        let m_lora_a_q = zeros(lora_a_len)?;
        let v_lora_a_q = zeros(lora_a_len)?;
        let m_lora_b_q = zeros(lora_b_q_len)?;
        let v_lora_b_q = zeros(lora_b_q_len)?;
        let m_lora_a_v = zeros(lora_a_len)?;
        let v_lora_a_v = zeros(lora_a_len)?;
        let m_lora_b_v = zeros(lora_b_v_len)?;
        let v_lora_b_v = zeros(lora_b_v_len)?;
        let m_input_norm = zeros(hidden)?;
        let v_input_norm = zeros(hidden)?;
        let m_post_attn_norm = zeros(hidden)?;
        let v_post_attn_norm = zeros(hidden)?;

        Ok(Self {
            m_lora_a_q,
            v_lora_a_q,
            m_lora_b_q,
            v_lora_b_q,
            m_lora_a_v,
            v_lora_a_v,
            m_lora_b_v,
            v_lora_b_v,
            m_input_norm,
            v_input_norm,
            m_post_attn_norm,
            v_post_attn_norm,
        })
    }
}
3747
3748#[cfg(feature = "cuda")]
3753impl CudaNf4TransformerBlock {
3754 #[allow(clippy::too_many_arguments)]
3762 pub(crate) fn backward(
3763 &self,
3764 layer_input: &GpuBuffer<f32>,
3765 grad_output: &GpuBuffer<f32>,
3766 grad_input: &mut GpuBuffer<f32>,
3767 output_scratch: &mut GpuBuffer<f32>,
3768 seq_len: usize,
3769 stream: &CudaStream,
3770 scratch: &mut CudaBlockScratch,
3771 grad_lora: &mut CudaLoraGradWorkspace,
3772 ) -> Result<()> {
3773 let hidden_size = self.config.hidden_size;
3774 let _q_dim = self.config.q_dim();
3775 let _kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
3776 let intermediate_size = self.config.intermediate_size;
3777 let eps = 1e-5_f32;
3778
3779 self.forward(layer_input, output_scratch, seq_len, stream, scratch).map_err(|e| {
3782 eprintln!(
3783 "[backward] Layer {} activation-checkpoint forward FAILED: {e:?}",
3784 self.layer_idx
3785 );
3786 e
3787 })?;
3788
3789 self.backward_nf4_ffn(
3791 grad_output,
3792 seq_len,
3793 hidden_size,
3794 intermediate_size,
3795 stream,
3796 scratch,
3797 )?;
3798
3799 let _t = scratch.op_begin(); rms_norm_backward(
3802 &scratch.residual1,
3803 &self.post_attn_norm_weight,
3804 &scratch.grad_hidden, grad_input, &mut grad_lora.grad_post_attn_norm,
3807 saturating_u32(seq_len),
3808 saturating_u32(hidden_size),
3809 eps,
3810 stream,
3811 )?;
3812
3813 cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;
3816
3817 self.backward_nf4_attention(
3819 grad_input, seq_len, stream, scratch, grad_lora,
3821 )?;
3822
3823 rms_norm_backward(
3827 layer_input,
3828 &self.input_norm_weight,
3829 &scratch.grad_hidden, grad_input, &mut grad_lora.grad_input_norm,
3832 saturating_u32(seq_len),
3833 saturating_u32(hidden_size),
3834 eps,
3835 stream,
3836 )?;
3837
3838 scratch.op_end(_t, OP_NORM_BWD);
3839
3840 Ok(())
3841 }
3842
3843 fn backward_nf4_ffn(
3849 &self,
3850 grad_output: &GpuBuffer<f32>,
3851 seq_len: usize,
3852 hidden_size: usize,
3853 intermediate_size: usize,
3854 stream: &CudaStream,
3855 scratch: &mut CudaBlockScratch,
3856 ) -> Result<()> {
3857 let s = saturating_u32(seq_len);
3858 let h = saturating_u32(hidden_size);
3859 let i_size = saturating_u32(intermediate_size);
3860 let n_inter = saturating_u32(seq_len * intermediate_size);
3861
3862 static USE_NF4_TC_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3864 let nf4_tc_bwd =
3865 *USE_NF4_TC_BWD.get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
3866
3867 let _t = scratch.op_begin(); if nf4_tc_bwd {
3870 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3872 grad_output,
3873 &self.w_down_nf4,
3874 &self.w_down_scales,
3875 &mut scratch.grad_swiglu,
3876 s,
3877 h, i_size, stream,
3880 )?;
3881 } else {
3882 gemm_backward_a_fp16_dispatch(
3883 grad_output,
3884 self.w_down_fp16.as_ref(),
3885 &self.w_down_fp32,
3886 &mut scratch.grad_swiglu,
3887 s,
3888 i_size,
3889 h,
3890 stream,
3891 &self.ctx,
3892 )?;
3893 }
3894
3895 scratch.op_end(_t, OP_DOWN_BWD);
3896
3897 let _t = scratch.op_begin(); elementwise_mul_forward(
3904 &scratch.grad_swiglu,
3905 &scratch.up_out,
3906 &mut scratch.swiglu_out,
3907 n_inter,
3908 stream,
3909 )?;
3910
3911 silu_backward(
3915 &scratch.gate_out,
3916 &scratch.swiglu_out,
3917 &mut scratch.up_out, stream,
3919 )?;
3920
3921 silu_forward(&scratch.gate_out, &mut scratch.swiglu_out, n_inter, stream)?;
3924 elementwise_mul_forward(
3926 &scratch.grad_swiglu,
3927 &scratch.swiglu_out,
3928 &mut scratch.gate_out, n_inter,
3930 stream,
3931 )?;
3932
3933 scratch.op_end(_t, OP_SWIGLU_BWD);
3934
3935 let _t = scratch.op_begin(); static USE_FUSED_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
3940 let fused_bwd = *USE_FUSED_BWD
3941 .get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
3942
3943 if nf4_tc_bwd {
3944 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3947 &scratch.gate_out, &self.w_up_nf4,
3949 &self.w_up_scales,
3950 &mut scratch.grad_hidden,
3951 s,
3952 i_size, h, stream,
3955 )?;
3956 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
3958 &scratch.up_out, &self.w_gate_nf4,
3960 &self.w_gate_scales,
3961 &mut scratch.ffn_out,
3962 s,
3963 i_size, h, stream,
3966 )?;
3967 cuda_add_inplace(
3969 &mut scratch.grad_hidden,
3970 &scratch.ffn_out,
3971 seq_len * hidden_size,
3972 stream,
3973 )?;
3974 } else if fused_bwd {
3975 gemm_backward_a_fp16_dispatch(
3977 &scratch.gate_out,
3978 self.w_up_fp16.as_ref(),
3979 &self.w_up_fp32,
3980 &mut scratch.grad_hidden,
3981 s,
3982 h,
3983 i_size,
3984 stream,
3985 &self.ctx,
3986 )?;
3987 gemm_backward_a_fp16_dispatch_accumulate(
3989 &scratch.up_out,
3990 self.w_gate_fp16.as_ref(),
3991 &self.w_gate_fp32,
3992 &mut scratch.grad_hidden,
3993 s,
3994 h,
3995 i_size,
3996 stream,
3997 &self.ctx,
3998 )?;
3999 } else {
4000 gemm_backward_a_fp16_dispatch(
4002 &scratch.up_out,
4003 self.w_gate_fp16.as_ref(),
4004 &self.w_gate_fp32,
4005 &mut scratch.ffn_out,
4006 s,
4007 h,
4008 i_size,
4009 stream,
4010 &self.ctx,
4011 )?;
4012 gemm_backward_a_fp16_dispatch(
4013 &scratch.gate_out,
4014 self.w_up_fp16.as_ref(),
4015 &self.w_up_fp32,
4016 &mut scratch.grad_hidden,
4017 s,
4018 h,
4019 i_size,
4020 stream,
4021 &self.ctx,
4022 )?;
4023
4024 cuda_add_inplace(
4026 &mut scratch.grad_hidden,
4027 &scratch.ffn_out,
4028 seq_len * hidden_size,
4029 stream,
4030 )?;
4031 }
4032 scratch.op_end(_t, OP_GATE_UP_BWD);
4033
4034 Ok(())
4035 }
4036
4037 fn backward_nf4_attention(
4043 &self,
4044 grad_residual1: &GpuBuffer<f32>,
4045 seq_len: usize,
4046 stream: &CudaStream,
4047 scratch: &mut CudaBlockScratch,
4048 grad_lora: &mut CudaLoraGradWorkspace,
4049 ) -> Result<()> {
4050 use crate::autograd::cuda_forward::gemm_forward;
4051
4052 let hidden_size = self.config.hidden_size;
4053 let q_dim = self.config.q_dim();
4054 let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
4055 let num_heads = self.config.num_attention_heads;
4056 let head_dim = self.config.head_dim();
4057
4058 let s = saturating_u32(seq_len);
4059 let h = saturating_u32(hidden_size);
4060 let qd = saturating_u32(q_dim);
4061 let kvh = saturating_u32(kv_hidden_size);
4062
4063 static USE_NF4_TC_BWD_O: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4065 let nf4_tc_bwd_o = *USE_NF4_TC_BWD_O
4066 .get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
4067
4068 let _t = scratch.op_begin(); if nf4_tc_bwd_o {
4070 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4071 grad_residual1,
4072 &self.w_o_nf4,
4073 &self.w_o_scales,
4074 &mut scratch.attn_out,
4075 s,
4076 h, qd, stream,
4079 )?;
4080 } else {
4081 gemm_backward_a_fp16_dispatch(
4082 grad_residual1,
4083 self.w_o_fp16.as_ref(),
4084 &self.w_o_fp32,
4085 &mut scratch.attn_out,
4086 s,
4087 qd,
4088 h,
4089 stream,
4090 &self.ctx,
4091 )?;
4092 }
4093
4094 self.backward_nf4_attention_mechanism(seq_len, num_heads, head_dim, stream, scratch)?;
4098
4099 let rope_theta = self.config.rope_theta;
4106 let num_kv_heads = self.config.num_kv_heads;
4107 let nkv = saturating_u32(num_kv_heads);
4108 let nh = saturating_u32(num_heads);
4109 let hd = saturating_u32(head_dim);
4110 {
4111 let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
4112 batched_rope_neox_backward(
4113 q_ref,
4114 &mut scratch.q,
4115 &scratch.rope_positions,
4116 nh,
4117 hd,
4118 s,
4119 rope_theta,
4120 stream,
4121 )?;
4122 let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
4123 batched_rope_neox_backward(
4124 k_ref,
4125 &mut scratch.k,
4126 &scratch.rope_positions,
4127 nkv,
4128 hd,
4129 s,
4130 rope_theta,
4131 stream,
4132 )?;
4133 }
4134
4135 scratch.op_end(_t, OP_ATTN_BWD);
4136
4137 static USE_NF4_TC_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4139 let nf4_tc_bwd = *USE_NF4_TC_BWD_ATTN
4140 .get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
4141
4142 let _t = scratch.op_begin(); if nf4_tc_bwd {
4144 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4146 &scratch.q,
4147 &self.w_q_nf4,
4148 &self.w_q_scales,
4149 &mut scratch.o_proj_out,
4150 s,
4151 qd, h, stream,
4154 )?;
4155 } else {
4156 gemm_backward_a_fp16_dispatch(
4157 &scratch.q,
4158 self.w_q_fp16.as_ref(),
4159 &self.w_q_fp32,
4160 &mut scratch.o_proj_out,
4161 s,
4162 h,
4163 qd,
4164 stream,
4165 &self.ctx,
4166 )?;
4167 }
4168
4169 if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
4171 let r = saturating_u32(self.lora_rank);
4172
4173 gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
4175
4176 gemm_backward_b(
4179 &scratch.lora_inter,
4180 &scratch.q,
4181 &mut grad_lora.grad_lora_b_q,
4182 s,
4183 r,
4184 qd,
4185 stream,
4186 )?;
4187
4188 gemm_backward_a(
4190 &scratch.q,
4191 b_q,
4192 &mut scratch.lora_inter, s,
4194 qd,
4195 r,
4196 stream,
4197 )?;
4198
4199 gemm_backward_b(
4201 &scratch.norm1_out,
4202 &scratch.lora_inter,
4203 &mut grad_lora.grad_lora_a_q,
4204 s,
4205 h,
4206 r,
4207 stream,
4208 )?;
4209
4210 gemm_backward_a(
4212 &scratch.lora_inter,
4213 a_q,
4214 &mut scratch.lora_temp, s,
4216 r,
4217 h,
4218 stream,
4219 )?;
4220 cuda_add_inplace(
4221 &mut scratch.o_proj_out,
4222 &scratch.lora_temp,
4223 seq_len * hidden_size,
4224 stream,
4225 )?;
4226 }
4227
4228 static USE_FUSED_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
4230 let fused_bwd = *USE_FUSED_BWD_ATTN
4231 .get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
4232
4233 if nf4_tc_bwd {
4234 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4236 &scratch.k,
4237 &self.w_k_nf4,
4238 &self.w_k_scales,
4239 &mut scratch.ffn_out,
4240 s,
4241 kvh, h, stream,
4244 )?;
4245 cuda_add_inplace(
4246 &mut scratch.o_proj_out,
4247 &scratch.ffn_out,
4248 seq_len * hidden_size,
4249 stream,
4250 )?;
4251 crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
4252 &scratch.v,
4253 &self.w_v_nf4,
4254 &self.w_v_scales,
4255 &mut scratch.ffn_out,
4256 s,
4257 kvh, h, stream,
4260 )?;
4261 cuda_add_inplace(
4262 &mut scratch.o_proj_out,
4263 &scratch.ffn_out,
4264 seq_len * hidden_size,
4265 stream,
4266 )?;
4267 } else if fused_bwd {
4268 gemm_backward_a_fp16_dispatch_accumulate(
4270 &scratch.k,
4271 self.w_k_fp16.as_ref(),
4272 &self.w_k_fp32,
4273 &mut scratch.o_proj_out,
4274 s,
4275 h,
4276 kvh,
4277 stream,
4278 &self.ctx,
4279 )?;
4280 gemm_backward_a_fp16_dispatch_accumulate(
4281 &scratch.v,
4282 self.w_v_fp16.as_ref(),
4283 &self.w_v_fp32,
4284 &mut scratch.o_proj_out,
4285 s,
4286 h,
4287 kvh,
4288 stream,
4289 &self.ctx,
4290 )?;
4291 } else {
4292 gemm_backward_a_fp16_dispatch(
4294 &scratch.k,
4295 self.w_k_fp16.as_ref(),
4296 &self.w_k_fp32,
4297 &mut scratch.ffn_out,
4298 s,
4299 h,
4300 kvh,
4301 stream,
4302 &self.ctx,
4303 )?;
4304 cuda_add_inplace(
4305 &mut scratch.o_proj_out,
4306 &scratch.ffn_out,
4307 seq_len * hidden_size,
4308 stream,
4309 )?;
4310
4311 gemm_backward_a_fp16_dispatch(
4312 &scratch.v,
4313 self.w_v_fp16.as_ref(),
4314 &self.w_v_fp32,
4315 &mut scratch.ffn_out,
4316 s,
4317 h,
4318 kvh,
4319 stream,
4320 &self.ctx,
4321 )?;
4322 cuda_add_inplace(
4323 &mut scratch.o_proj_out,
4324 &scratch.ffn_out,
4325 seq_len * hidden_size,
4326 stream,
4327 )?;
4328 }
4329
4330 if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
4332 let r = saturating_u32(self.lora_rank);
4333
4334 gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
4336
4337 gemm_backward_b(
4339 &scratch.lora_inter,
4340 &scratch.v,
4341 &mut grad_lora.grad_lora_b_v,
4342 s,
4343 r,
4344 kvh,
4345 stream,
4346 )?;
4347
4348 gemm_backward_a(&scratch.v, b_v, &mut scratch.lora_inter, s, kvh, r, stream)?;
4350
4351 gemm_backward_b(
4353 &scratch.norm1_out,
4354 &scratch.lora_inter,
4355 &mut grad_lora.grad_lora_a_v,
4356 s,
4357 h,
4358 r,
4359 stream,
4360 )?;
4361
4362 gemm_backward_a(&scratch.lora_inter, a_v, &mut scratch.lora_temp, s, r, h, stream)?;
4364 cuda_add_inplace(
4365 &mut scratch.o_proj_out,
4366 &scratch.lora_temp,
4367 seq_len * hidden_size,
4368 stream,
4369 )?;
4370 }
4371
4372 scratch.op_end(_t, OP_QKV_BWD);
4373
4374 unsafe {
4376 scratch.grad_hidden.copy_from_buffer_async(&scratch.o_proj_out, stream).map_err(
4377 |e| {
4378 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4379 "grad_norm1 copy failed: {e}"
4380 ))
4381 },
4382 )?;
4383 }
4384
4385 Ok(())
4386 }
4387
4388 fn backward_nf4_attention_mechanism(
4410 &self,
4411 seq_len: usize,
4412 num_heads: usize,
4413 head_dim: usize,
4414 stream: &CudaStream,
4415 scratch: &mut CudaBlockScratch,
4416 ) -> Result<()> {
4417 let num_kv_heads = self.config.num_kv_heads;
4418 let heads_per_kv = num_heads / num_kv_heads;
4419 let s = saturating_u32(seq_len);
4420 let nh = saturating_u32(num_heads);
4421 let nkv = saturating_u32(num_kv_heads);
4422 let hd = saturating_u32(head_dim);
4423 let scale = 1.0 / (head_dim as f32).sqrt();
4424
4425 interleaved_to_batched_forward(
4428 &scratch.attn_out,
4429 &mut scratch.attn_q_batched, s,
4431 nh,
4432 hd,
4433 stream,
4434 )?;
4435
4436 interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
4440
4441 if heads_per_kv > 1 {
4442 expand_kv_heads(
4443 &scratch.attn_kv_temp,
4444 &mut scratch.attn_kv_temp2,
4445 num_kv_heads,
4446 heads_per_kv,
4447 seq_len * head_dim,
4448 stream,
4449 )?;
4450 } else {
4451 unsafe {
4452 scratch
4453 .attn_kv_temp2
4454 .copy_from_buffer_async(&scratch.attn_kv_temp, stream)
4455 .map_err(|e| {
4456 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4457 "V copy for attn backward: {e:?}"
4458 ))
4459 })?;
4460 }
4461 }
4462 batched_transpose_forward(
4466 &scratch.attn_kv_temp2,
4467 &mut scratch.attn_kv_temp, nh,
4469 s,
4470 hd,
4471 stream,
4472 )?;
4473
4474 batched_4d_gemm_forward(
4477 &scratch.attn_q_batched,
4478 &scratch.attn_kv_temp,
4479 &mut scratch.grad_attn_scores,
4480 1,
4481 nh,
4482 s,
4483 s,
4484 hd,
4485 stream,
4486 )?;
4487
4488 batched_transpose_forward(
4494 &scratch.attn_q_batched,
4495 &mut scratch.attn_kv_temp, nh,
4497 s,
4498 hd,
4499 stream,
4500 )?;
4501
4502 batched_4d_gemm_forward(
4504 &scratch.attn_kv_temp,
4505 &scratch.attn_scores, &mut scratch.attn_kv_temp2, 1,
4508 nh,
4509 hd,
4510 s,
4511 s,
4512 stream,
4513 )?;
4514
4515 batched_transpose_forward(
4517 &scratch.attn_kv_temp2,
4518 &mut scratch.attn_kv_temp, nh,
4520 hd,
4521 s,
4522 stream,
4523 )?;
4524 let total_rows = nh * s;
4529 {
4530 let grad_scores_view = unsafe {
4531 GpuBuffer::<f32>::from_raw_parts(
4532 scratch.grad_attn_scores.as_ptr(),
4533 scratch.grad_attn_scores.len(),
4534 )
4535 };
4536 batched_softmax_backward(
4537 &scratch.attn_scores,
4538 &grad_scores_view,
4539 &mut scratch.grad_attn_scores,
4540 total_rows,
4541 s,
4542 stream,
4543 )?;
4544 leak(grad_scores_view);
4545 }
4546
4547 let total_scores = saturating_u32(num_heads * seq_len * seq_len);
4549 {
4550 let scores_view = unsafe {
4551 GpuBuffer::<f32>::from_raw_parts(
4552 scratch.grad_attn_scores.as_ptr(),
4553 scratch.grad_attn_scores.len(),
4554 )
4555 };
4556 scale_forward(
4557 &scores_view,
4558 &mut scratch.grad_attn_scores,
4559 scale,
4560 total_scores,
4561 stream,
4562 )?;
4563 leak(scores_view);
4564 }
4565
4566 interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp2, s, nkv, hd, stream)?;
4568
4569 if heads_per_kv > 1 {
4570 unsafe {
4571 scratch
4572 .attn_q_batched
4573 .copy_from_buffer_async(&scratch.attn_kv_temp2, stream)
4574 .map_err(|e| {
4575 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4576 "K copy for GQA expand: {e}"
4577 ))
4578 })?;
4579 }
4580 expand_kv_heads(
4581 &scratch.attn_q_batched,
4582 &mut scratch.attn_kv_temp2,
4583 num_kv_heads,
4584 heads_per_kv,
4585 seq_len * head_dim,
4586 stream,
4587 )?;
4588 }
4589 batched_4d_gemm_forward(
4593 &scratch.grad_attn_scores,
4594 &scratch.attn_kv_temp2,
4595 &mut scratch.attn_q_batched,
4596 1,
4597 nh,
4598 s,
4599 hd,
4600 s,
4601 stream,
4602 )?;
4603
4604 interleaved_to_batched_forward(
4608 &scratch.q,
4609 &mut scratch.o_proj_out, s,
4611 nh,
4612 hd,
4613 stream,
4614 )?;
4615
4616 batched_transpose_forward(
4618 &scratch.o_proj_out,
4619 &mut scratch.attn_kv_temp2, nh,
4621 s,
4622 hd,
4623 stream,
4624 )?;
4625
4626 batched_4d_gemm_forward(
4628 &scratch.attn_kv_temp2,
4629 &scratch.grad_attn_scores,
4630 &mut scratch.ffn_out, 1,
4632 nh,
4633 hd,
4634 s,
4635 s,
4636 stream,
4637 )?;
4638
4639 batched_transpose_forward(
4641 &scratch.ffn_out,
4642 &mut scratch.attn_kv_temp2, nh,
4644 hd,
4645 s,
4646 stream,
4647 )?;
4648
4649 if heads_per_kv > 1 {
4651 self.reduce_gqa_gradients_nf4(
4652 num_kv_heads,
4653 heads_per_kv,
4654 seq_len,
4655 head_dim,
4656 stream,
4657 scratch,
4658 )?;
4659 }
4660
4661 batched_to_interleaved_forward(&scratch.attn_q_batched, &mut scratch.q, s, nh, hd, stream)?;
4664
4665 batched_to_interleaved_forward(&scratch.attn_kv_temp2, &mut scratch.k, s, nkv, hd, stream)?;
4667
4668 batched_to_interleaved_forward(&scratch.attn_kv_temp, &mut scratch.v, s, nkv, hd, stream)?;
4671
4672 Ok(())
4673 }
4674
4675 fn reduce_gqa_gradients_nf4(
4679 &self,
4680 num_kv_heads: usize,
4681 heads_per_kv: usize,
4682 seq_len: usize,
4683 head_dim: usize,
4684 stream: &CudaStream,
4685 scratch: &mut CudaBlockScratch,
4686 ) -> Result<()> {
4687 let chunk = seq_len * head_dim;
4688 for g in 0..num_kv_heads {
4689 let dst_off = g * chunk;
4690 let src_off = g * heads_per_kv * chunk;
4692 {
4694 let src = unsafe {
4695 GpuBuffer::<f32>::from_raw_parts(
4696 scratch.attn_kv_temp2.as_ptr() + (src_off * 4) as u64,
4697 chunk,
4698 )
4699 };
4700 let mut dst = unsafe {
4701 GpuBuffer::<f32>::from_raw_parts(
4702 scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
4703 chunk,
4704 )
4705 };
4706 if src_off != dst_off {
4707 unsafe {
4708 dst.copy_from_buffer_async(&src, stream).map_err(|e| {
4709 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4710 "GQA K reduce copy: {e}"
4711 ))
4712 })?;
4713 }
4714 }
4715 leak(src);
4716 leak(dst);
4717 }
4718 for h in 1..heads_per_kv {
4720 let add_off = (g * heads_per_kv + h) * chunk;
4721 let src = unsafe {
4722 GpuBuffer::<f32>::from_raw_parts(
4723 scratch.attn_kv_temp2.as_ptr() + (add_off * 4) as u64,
4724 chunk,
4725 )
4726 };
4727 let mut dst = unsafe {
4728 GpuBuffer::<f32>::from_raw_parts(
4729 scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
4730 chunk,
4731 )
4732 };
4733 cuda_add_inplace(&mut dst, &src, chunk, stream)?;
4734 leak(src);
4735 leak(dst);
4736 }
4737 {
4739 let src = unsafe {
4740 GpuBuffer::<f32>::from_raw_parts(
4741 scratch.attn_kv_temp.as_ptr() + (src_off * 4) as u64,
4742 chunk,
4743 )
4744 };
4745 let mut dst = unsafe {
4746 GpuBuffer::<f32>::from_raw_parts(
4747 scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
4748 chunk,
4749 )
4750 };
4751 if src_off != dst_off {
4752 unsafe {
4753 dst.copy_from_buffer_async(&src, stream).map_err(|e| {
4754 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4755 "GQA V reduce copy: {e}"
4756 ))
4757 })?;
4758 }
4759 }
4760 leak(src);
4761 leak(dst);
4762 }
4763 for h in 1..heads_per_kv {
4764 let add_off = (g * heads_per_kv + h) * chunk;
4765 let src = unsafe {
4766 GpuBuffer::<f32>::from_raw_parts(
4767 scratch.attn_kv_temp.as_ptr() + (add_off * 4) as u64,
4768 chunk,
4769 )
4770 };
4771 let mut dst = unsafe {
4772 GpuBuffer::<f32>::from_raw_parts(
4773 scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
4774 chunk,
4775 )
4776 };
4777 cuda_add_inplace(&mut dst, &src, chunk, stream)?;
4778 leak(src);
4779 leak(dst);
4780 }
4781 }
4782 Ok(())
4783 }
4784
4785 pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
4787 GpuLoraOptimizerState::new(&self.ctx, &self.config, self.lora_rank)
4788 }
4789
4790 #[allow(clippy::too_many_arguments)]
4792 pub(crate) fn lora_optimizer_step(
4793 &mut self,
4794 state: &mut GpuLoraOptimizerState,
4795 step: u32,
4796 lr: f32,
4797 beta1: f32,
4798 beta2: f32,
4799 eps: f32,
4800 weight_decay: f32,
4801 stream: &CudaStream,
4802 grad_lora: &CudaLoraGradWorkspace,
4803 ) -> Result<()> {
4804 let h = self.config.hidden_size;
4805 let q_dim = self.config.q_dim();
4806 let kv = self.config.num_kv_heads * self.config.head_dim();
4807 let r = self.lora_rank;
4808
4809 if let Some(ref mut a_q) = self.lora_a_q {
4811 adamw_step_cuda(
4812 a_q,
4813 &grad_lora.grad_lora_a_q,
4814 &mut state.m_lora_a_q,
4815 &mut state.v_lora_a_q,
4816 lr,
4817 beta1,
4818 beta2,
4819 eps,
4820 weight_decay,
4821 step,
4822 saturating_u32(h * r),
4823 stream,
4824 )?;
4825 }
4826 if let Some(ref mut b_q) = self.lora_b_q {
4827 adamw_step_cuda(
4828 b_q,
4829 &grad_lora.grad_lora_b_q,
4830 &mut state.m_lora_b_q,
4831 &mut state.v_lora_b_q,
4832 lr,
4833 beta1,
4834 beta2,
4835 eps,
4836 weight_decay,
4837 step,
4838 saturating_u32(r * q_dim),
4839 stream,
4840 )?;
4841 }
4842 if let Some(ref mut a_v) = self.lora_a_v {
4843 adamw_step_cuda(
4844 a_v,
4845 &grad_lora.grad_lora_a_v,
4846 &mut state.m_lora_a_v,
4847 &mut state.v_lora_a_v,
4848 lr,
4849 beta1,
4850 beta2,
4851 eps,
4852 weight_decay,
4853 step,
4854 saturating_u32(h * r),
4855 stream,
4856 )?;
4857 }
4858 if let Some(ref mut b_v) = self.lora_b_v {
4859 adamw_step_cuda(
4860 b_v,
4861 &grad_lora.grad_lora_b_v,
4862 &mut state.m_lora_b_v,
4863 &mut state.v_lora_b_v,
4864 lr,
4865 beta1,
4866 beta2,
4867 eps,
4868 weight_decay,
4869 step,
4870 saturating_u32(r * kv),
4871 stream,
4872 )?;
4873 }
4874
4875 adamw_step_cuda(
4877 &mut self.input_norm_weight,
4878 &grad_lora.grad_input_norm,
4879 &mut state.m_input_norm,
4880 &mut state.v_input_norm,
4881 lr,
4882 beta1,
4883 beta2,
4884 eps,
4885 weight_decay,
4886 step,
4887 saturating_u32(h),
4888 stream,
4889 )?;
4890 adamw_step_cuda(
4891 &mut self.post_attn_norm_weight,
4892 &grad_lora.grad_post_attn_norm,
4893 &mut state.m_post_attn_norm,
4894 &mut state.v_post_attn_norm,
4895 lr,
4896 beta1,
4897 beta2,
4898 eps,
4899 weight_decay,
4900 step,
4901 saturating_u32(h),
4902 stream,
4903 )?;
4904
4905 Ok(())
4906 }
4907
4908 pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
4914 let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
4915 let mut host = vec![0.0f32; buf.len()];
4916 buf.copy_to_host(&mut host).map_err(|e| {
4917 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4918 "LoRA weight download failed: {e}"
4919 ))
4920 })?;
4921 Ok(host)
4922 };
4923 let a_q = self.lora_a_q.as_ref().map(&download).transpose()?.unwrap_or_default();
4924 let b_q = self.lora_b_q.as_ref().map(&download).transpose()?.unwrap_or_default();
4925 let a_v = self.lora_a_v.as_ref().map(&download).transpose()?.unwrap_or_default();
4926 let b_v = self.lora_b_v.as_ref().map(&download).transpose()?.unwrap_or_default();
4927 Ok((a_q, b_q, a_v, b_v))
4928 }
4929
4930 pub fn upload_lora_weights(
4936 &mut self,
4937 a_q: &[f32],
4938 b_q: &[f32],
4939 a_v: &[f32],
4940 b_v: &[f32],
4941 ) -> Result<()> {
4942 let upload = |buf: &mut GpuBuffer<f32>, data: &[f32], name: &str| -> Result<()> {
4943 if data.len() != buf.len() {
4944 return Err(crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(
4945 format!(
4946 "LoRA {name} size mismatch: checkpoint has {} but GPU buffer expects {}",
4947 data.len(),
4948 buf.len()
4949 ),
4950 ));
4951 }
4952 buf.copy_from_host(data).map_err(|e| {
4953 crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
4954 "LoRA {name} upload failed: {e}"
4955 ))
4956 })
4957 };
4958 if let Some(ref mut buf) = self.lora_a_q {
4959 upload(buf, a_q, "a_q")?;
4960 }
4961 if let Some(ref mut buf) = self.lora_b_q {
4962 upload(buf, b_q, "b_q")?;
4963 }
4964 if let Some(ref mut buf) = self.lora_a_v {
4965 upload(buf, a_v, "a_v")?;
4966 }
4967 if let Some(ref mut buf) = self.lora_b_v {
4968 upload(buf, b_v, "b_v")?;
4969 }
4970 Ok(())
4971 }
4972}
4973
#[cfg(test)]
mod tests {
    /// Compile-only smoke test: with the `cuda` feature enabled it merely names
    /// both block types, so the test fails to build if either stops compiling.
    /// Without the feature the body is empty.
    #[test]
    fn test_cuda_block_compiles() {
        #[cfg(feature = "cuda")]
        {
            let _block_size = std::mem::size_of::<super::CudaTransformerBlock>();
            let _nf4_block_size = std::mem::size_of::<super::CudaNf4TransformerBlock>();
        }
    }
}