#![allow(dead_code)]
#![allow(unsafe_code)]
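//! CUDA transformer block implementations: a trainable FP32 block, an
//! NF4-quantized block with LoRA adapters, shared scratch buffers, gradient
//! workspaces, and on-device AdamW optimizer state. Everything is gated on
//! the `cuda` feature; a stub `CudaTransformerBlock` is provided otherwise.
//! The `OP_*` constants below index into `CudaBlockScratch::op_us` for
//! per-op profiling.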
#[cfg(feature = "cuda")]
const OP_RMSNORM_ATTN: usize = 0;
#[cfg(feature = "cuda")]
const OP_QKV_GEMM: usize = 1;
#[cfg(feature = "cuda")]
const OP_ATTENTION: usize = 2;
#[cfg(feature = "cuda")]
const OP_O_PROJ: usize = 3;
#[cfg(feature = "cuda")]
const OP_RMSNORM_FFN: usize = 4;
#[cfg(feature = "cuda")]
const OP_GATE_UP_GEMM: usize = 5;
#[cfg(feature = "cuda")]
const OP_SILU: usize = 6;
#[cfg(feature = "cuda")]
const OP_DOWN_GEMM: usize = 7;
#[cfg(feature = "cuda")]
const OP_LORA_FWD: usize = 8;
#[cfg(feature = "cuda")]
const OP_DOWN_BWD: usize = 9;
#[cfg(feature = "cuda")]
const OP_SWIGLU_BWD: usize = 10;
#[cfg(feature = "cuda")]
const OP_GATE_UP_BWD: usize = 11;
#[cfg(feature = "cuda")]
const OP_ATTN_BWD: usize = 12;
#[cfg(feature = "cuda")]
const OP_QKV_BWD: usize = 13;
#[cfg(feature = "cuda")]
const OP_NORM_BWD: usize = 14;
#[cfg(feature = "cuda")]
const OP_LORA_BWD: usize = 15;
#[cfg(feature = "cuda")]
use std::sync::Arc;
#[cfg(feature = "cuda")]
#[inline]
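/// Clamps a `usize` to `u32::MAX` before casting so oversized dimensions
/// saturate instead of silently wrapping when handed to kernel launches.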
fn saturating_u32(v: usize) -> u32 {
v.min(u32::MAX as usize) as u32
}
#[cfg(feature = "cuda")]
#[inline]
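/// Intentionally forgets `val`. Used on `GpuBuffer` views created with
/// `from_raw_parts` over memory owned by another buffer, so the view's
/// destructor never frees device memory it does not own.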
fn leak<T>(val: T) {
let _ = std::mem::ManuallyDrop::new(val);
}
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
#[cfg(feature = "cuda")]
use crate::autograd::cuda_backward::{
batched_softmax_backward, gemm_backward_a, gemm_backward_a_fp16_dispatch,
gemm_backward_a_fp16_dispatch_accumulate, gemm_backward_b, rms_norm_backward, silu_backward,
};
#[cfg(feature = "cuda")]
use crate::autograd::cuda_forward::{
batched_4d_gemm_forward, batched_rope_neox_backward, batched_rope_neox_forward,
batched_softmax_forward, batched_to_interleaved_forward, batched_transpose_forward,
cast_f32_to_f16_gpu, elementwise_mul_forward, expand_kv_heads, fused_residual_rmsnorm_forward,
fused_swiglu_forward, gemm_f16_to_f32_forward, gemm_forward, interleaved_to_batched_forward,
per_head_rmsnorm_forward, residual_add_forward, rms_norm_forward, scale_forward, silu_forward,
};
#[cfg(feature = "cuda")]
use crate::autograd::cuda_optim::{adamw_step_cuda, gradient_clip_cuda, squared_sum_cuda};
#[cfg(feature = "cuda")]
use crate::autograd::cuda_tensor::Result;
#[cfg(feature = "cuda")]
use super::config::TransformerConfig;
#[cfg(feature = "cuda")]
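/// One FP32 transformer layer resident on the GPU: RMSNorm weights, Q/K/V/O
/// attention projections, SwiGLU FFN weights, optional per-head QK-norm
/// weights, and scratch buffers sized for a fixed maximum sequence length.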
pub struct CudaTransformerBlock {
config: TransformerConfig,
layer_idx: usize,
input_norm_weight: GpuBuffer<f32>,
post_attn_norm_weight: GpuBuffer<f32>,
w_q: GpuBuffer<f32>,
w_k: GpuBuffer<f32>,
w_v: GpuBuffer<f32>,
w_o: GpuBuffer<f32>,
w_gate: GpuBuffer<f32>,
w_up: GpuBuffer<f32>,
w_down: GpuBuffer<f32>,
ctx: Arc<CudaContext>,
scratch: CudaBlockScratch,
norm_zero_buf: Vec<f32>,
q_norm_weight: Option<GpuBuffer<f32>>,
k_norm_weight: Option<GpuBuffer<f32>>,
}
#[cfg(feature = "cuda")]
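/// Reusable device scratch space for a layer's forward and backward passes:
/// activation buffers, attention workspaces, optional FP16 staging buffers,
/// cached RoPE positions, a cached causal mask, and per-op timing counters.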
pub(crate) struct CudaBlockScratch {
norm1_out: GpuBuffer<f32>,
q: GpuBuffer<f32>,
k: GpuBuffer<f32>,
v: GpuBuffer<f32>,
attn_scores: GpuBuffer<f32>,
attn_out: GpuBuffer<f32>,
o_proj_out: GpuBuffer<f32>,
residual1: GpuBuffer<f32>,
norm2_out: GpuBuffer<f32>,
gate_out: GpuBuffer<f32>,
up_out: GpuBuffer<f32>,
swiglu_out: GpuBuffer<f32>,
ffn_out: GpuBuffer<f32>,
norm1_out_f16: Option<GpuBuffer<u16>>,
attn_out_f16: Option<GpuBuffer<u16>>,
norm2_out_f16: Option<GpuBuffer<u16>>,
swiglu_out_f16: Option<GpuBuffer<u16>>,
grad_hidden: GpuBuffer<f32>,
grad_swiglu: GpuBuffer<f32>,
attn_q_batched: GpuBuffer<f32>,
attn_kv_temp: GpuBuffer<f32>,
attn_kv_temp2: GpuBuffer<f32>,
grad_attn_scores: GpuBuffer<f32>,
lora_inter: GpuBuffer<f32>,
lora_temp: GpuBuffer<f32>,
rope_positions: GpuBuffer<u32>,
causal_mask_contiguous: GpuBuffer<f32>,
pub(crate) causal_mask_cached_seq_len: usize,
pub(crate) op_us: [u64; 16],
pub(crate) op_profiling_enabled: bool,
}
#[cfg(feature = "cuda")]
impl CudaBlockScratch {
#[inline]
pub(crate) fn op_begin(&self) -> Option<std::time::Instant> {
if self.op_profiling_enabled {
Some(std::time::Instant::now())
} else {
None
}
}
#[inline]
pub(crate) fn op_end(&mut self, start: Option<std::time::Instant>, op: usize) {
if let Some(t) = start {
if op < 16 {
self.op_us[op] += t.elapsed().as_micros() as u64;
}
}
}
pub(crate) fn max_seq_len(&self, hidden_size: usize) -> usize {
self.norm1_out.len() / hidden_size.max(1)
}
#[rustfmt::skip]
pub(crate) fn zero_forward_buffers(&mut self, stream: &CudaStream) {
let z = |b: &mut GpuBuffer<f32>| { b.zero_async(stream).ok(); };
z(&mut self.norm1_out); z(&mut self.q); z(&mut self.k); z(&mut self.v); z(&mut self.attn_scores); z(&mut self.attn_out);
z(&mut self.o_proj_out); z(&mut self.residual1); z(&mut self.norm2_out); z(&mut self.gate_out); z(&mut self.up_out);
z(&mut self.swiglu_out); z(&mut self.ffn_out); z(&mut self.attn_q_batched); z(&mut self.attn_kv_temp); z(&mut self.attn_kv_temp2);
z(&mut self.grad_hidden); z(&mut self.grad_swiglu); z(&mut self.grad_attn_scores); z(&mut self.lora_inter); z(&mut self.lora_temp);
self.causal_mask_cached_seq_len = 0;
}
pub(crate) fn new(
config: &TransformerConfig,
max_seq_len: usize,
ctx: &Arc<CudaContext>,
lora_rank: usize,
) -> Result<Self> {
let hidden_size = config.hidden_size;
let q_dim = config.q_dim();
let kv_hidden_size = config.num_kv_heads * config.head_dim();
let intermediate_size = config.intermediate_size;
let num_heads = config.num_attention_heads;
let head_dim = config.head_dim();
let max_proj_dim = q_dim.max(kv_hidden_size);
let lora_inter_size = (max_seq_len * lora_rank).max(1);
let lora_temp_size = (max_seq_len * max_proj_dim).max(1);
let causal_mask_data: Vec<f32> = (0..max_seq_len * max_seq_len)
.map(|idx| {
let row = idx / max_seq_len;
let col = idx % max_seq_len;
if col <= row {
0.0f32
} else {
f32::NEG_INFINITY
}
})
.collect();
Ok(Self {
norm1_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
q: GpuBuffer::new(ctx, max_seq_len * q_dim)?,
k: GpuBuffer::new(ctx, max_seq_len * kv_hidden_size)?,
v: GpuBuffer::new(ctx, max_seq_len * kv_hidden_size)?,
attn_scores: GpuBuffer::new(ctx, num_heads * max_seq_len * max_seq_len)?,
attn_out: GpuBuffer::new(ctx, max_seq_len * q_dim)?,
o_proj_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
residual1: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
norm2_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
gate_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
up_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
swiglu_out: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
ffn_out: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
norm1_out_f16: None,
attn_out_f16: None,
norm2_out_f16: None,
swiglu_out_f16: None,
grad_hidden: GpuBuffer::new(ctx, max_seq_len * hidden_size)?,
grad_swiglu: GpuBuffer::new(ctx, max_seq_len * intermediate_size)?,
attn_q_batched: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
attn_kv_temp: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
attn_kv_temp2: GpuBuffer::new(ctx, num_heads * max_seq_len * head_dim)?,
grad_attn_scores: GpuBuffer::new(
ctx,
num_heads * max_seq_len * max_seq_len.max(head_dim),
)?,
lora_inter: GpuBuffer::new(ctx, lora_inter_size)?,
lora_temp: GpuBuffer::new(ctx, lora_temp_size)?,
rope_positions: {
let positions: Vec<u32> = (0..max_seq_len as u32).collect();
let mut buf = GpuBuffer::new(ctx, max_seq_len)?;
buf.copy_from_host(&positions)?;
buf
},
causal_mask_contiguous: GpuBuffer::from_host(ctx, &causal_mask_data)?,
causal_mask_cached_seq_len: max_seq_len,
op_us: [0u64; 16],
op_profiling_enabled: false,
})
}
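    /// Rebuilds the contiguous causal mask (0.0 on and below the diagonal,
    /// -inf above) for `seq_len`, skipping the host build and upload when the
    /// cached mask already has that size.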
pub(crate) fn prepare_causal_mask(
&mut self,
seq_len: usize,
ctx: &Arc<CudaContext>,
) -> crate::autograd::cuda_tensor::Result<()> {
if seq_len == self.causal_mask_cached_seq_len {
return Ok(());
}
let mask_data: Vec<f32> = (0..seq_len * seq_len)
.map(|idx| {
let row = idx / seq_len;
let col = idx % seq_len;
if col <= row {
0.0f32
} else {
f32::NEG_INFINITY
}
})
.collect();
self.causal_mask_contiguous = GpuBuffer::from_host(ctx, &mask_data)?;
self.causal_mask_cached_seq_len = seq_len;
Ok(())
}
}
#[cfg(feature = "cuda")]
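/// Device-side gradient buffers for one layer's trainable tensors: the two
/// RMSNorm scales, the attention projections, and the FFN matrices.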
pub struct CudaGradWorkspace {
pub(crate) grad_input_norm: GpuBuffer<f32>,
pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
pub(crate) grad_gate: GpuBuffer<f32>,
pub(crate) grad_up: GpuBuffer<f32>,
pub(crate) grad_down: GpuBuffer<f32>,
pub(crate) grad_w_q: GpuBuffer<f32>,
pub(crate) grad_w_k: GpuBuffer<f32>,
pub(crate) grad_w_v: GpuBuffer<f32>,
pub(crate) grad_w_o: GpuBuffer<f32>,
}
#[cfg(feature = "cuda")]
impl CudaGradWorkspace {
pub fn new(ctx: &Arc<CudaContext>, config: &TransformerConfig) -> Result<Self> {
let h = config.hidden_size;
let q = config.q_dim();
let kv = config.num_kv_heads * config.head_dim();
let i = config.intermediate_size;
Ok(Self {
grad_input_norm: GpuBuffer::new(ctx, h)?,
grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
grad_gate: GpuBuffer::new(ctx, h * i)?,
grad_up: GpuBuffer::new(ctx, h * i)?,
grad_down: GpuBuffer::new(ctx, i * h)?,
grad_w_q: GpuBuffer::new(ctx, q * h)?,
grad_w_k: GpuBuffer::new(ctx, h * kv)?,
grad_w_v: GpuBuffer::new(ctx, h * kv)?,
grad_w_o: GpuBuffer::new(ctx, h * q)?,
})
}
pub fn zero_norm_grads(&mut self, zero_buf: &[f32]) -> Result<()> {
let n = self.grad_input_norm.len();
self.grad_input_norm.copy_from_host(&zero_buf[..n]).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"Failed to zero grad_input_norm: {e:?}"
))
})?;
self.grad_post_attn_norm.copy_from_host(&zero_buf[..n]).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"Failed to zero grad_post_attn_norm: {e:?}"
))
})?;
Ok(())
}
}
#[cfg(feature = "cuda")]
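/// AdamW first- and second-moment buffers (`m.*` / `v.*`) kept on the GPU for
/// every trainable tensor of a block, with helpers to checkpoint them to host
/// memory and restore them again.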
pub struct GpuBlockOptimizerState {
m_w_q: GpuBuffer<f32>,
v_w_q: GpuBuffer<f32>,
m_w_k: GpuBuffer<f32>,
v_w_k: GpuBuffer<f32>,
m_w_v: GpuBuffer<f32>,
v_w_v: GpuBuffer<f32>,
m_w_o: GpuBuffer<f32>,
v_w_o: GpuBuffer<f32>,
m_w_gate: GpuBuffer<f32>,
v_w_gate: GpuBuffer<f32>,
m_w_up: GpuBuffer<f32>,
v_w_up: GpuBuffer<f32>,
m_w_down: GpuBuffer<f32>,
v_w_down: GpuBuffer<f32>,
m_input_norm: GpuBuffer<f32>,
v_input_norm: GpuBuffer<f32>,
m_post_attn_norm: GpuBuffer<f32>,
v_post_attn_norm: GpuBuffer<f32>,
}
#[cfg(feature = "cuda")]
impl GpuBlockOptimizerState {
pub fn download_to_host(
&self,
) -> crate::autograd::cuda_tensor::Result<Vec<(String, Vec<f32>)>> {
let dl = |name: &str,
buf: &GpuBuffer<f32>|
-> crate::autograd::cuda_tensor::Result<(String, Vec<f32>)> {
let mut host = vec![0.0f32; buf.len()];
buf.copy_to_host(&mut host).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"optimizer D2H {name}: {e}"
))
})?;
Ok((name.to_string(), host))
};
Ok(vec![
dl("m.w_q", &self.m_w_q)?,
dl("v.w_q", &self.v_w_q)?,
dl("m.w_k", &self.m_w_k)?,
dl("v.w_k", &self.v_w_k)?,
dl("m.w_v", &self.m_w_v)?,
dl("v.w_v", &self.v_w_v)?,
dl("m.w_o", &self.m_w_o)?,
dl("v.w_o", &self.v_w_o)?,
dl("m.w_gate", &self.m_w_gate)?,
dl("v.w_gate", &self.v_w_gate)?,
dl("m.w_up", &self.m_w_up)?,
dl("v.w_up", &self.v_w_up)?,
dl("m.w_down", &self.m_w_down)?,
dl("v.w_down", &self.v_w_down)?,
dl("m.input_norm", &self.m_input_norm)?,
dl("v.input_norm", &self.v_input_norm)?,
dl("m.post_attn_norm", &self.m_post_attn_norm)?,
dl("v.post_attn_norm", &self.v_post_attn_norm)?,
])
}
pub fn restore_from_host(
&mut self,
data: &std::collections::HashMap<String, Vec<f32>>,
) -> crate::autograd::cuda_tensor::Result<()> {
let ul = |name: &str,
buf: &mut GpuBuffer<f32>,
data: &std::collections::HashMap<String, Vec<f32>>|
-> crate::autograd::cuda_tensor::Result<()> {
if let Some(host_data) = data.get(name) {
if host_data.len() == buf.len() {
buf.copy_from_host(host_data).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"optimizer H2D {name}: {e}"
))
})?;
}
}
Ok(())
};
ul("m.w_q", &mut self.m_w_q, data)?;
ul("v.w_q", &mut self.v_w_q, data)?;
ul("m.w_k", &mut self.m_w_k, data)?;
ul("v.w_k", &mut self.v_w_k, data)?;
ul("m.w_v", &mut self.m_w_v, data)?;
ul("v.w_v", &mut self.v_w_v, data)?;
ul("m.w_o", &mut self.m_w_o, data)?;
ul("v.w_o", &mut self.v_w_o, data)?;
ul("m.w_gate", &mut self.m_w_gate, data)?;
ul("v.w_gate", &mut self.v_w_gate, data)?;
ul("m.w_up", &mut self.m_w_up, data)?;
ul("v.w_up", &mut self.v_w_up, data)?;
ul("m.w_down", &mut self.m_w_down, data)?;
ul("v.w_down", &mut self.v_w_down, data)?;
ul("m.input_norm", &mut self.m_input_norm, data)?;
ul("v.input_norm", &mut self.v_input_norm, data)?;
ul("m.post_attn_norm", &mut self.m_post_attn_norm, data)?;
ul("v.post_attn_norm", &mut self.v_post_attn_norm, data)?;
Ok(())
}
}
#[cfg(feature = "cuda")]
impl CudaTransformerBlock {
pub fn new(
config: &TransformerConfig,
layer_idx: usize,
ctx: Arc<CudaContext>,
input_norm_weight: &[f32],
post_attn_norm_weight: &[f32],
w_q: &[f32],
w_k: &[f32],
w_v: &[f32],
w_o: &[f32],
w_gate: &[f32],
w_up: &[f32],
w_down: &[f32],
max_seq_len: usize,
) -> Result<Self> {
let hidden_size = config.hidden_size;
        let q_dim = config.q_dim();
        let kv_hidden_size = config.num_kv_heads * config.head_dim();
let intermediate_size = config.intermediate_size;
let num_heads = config.num_attention_heads;
let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
let w_q = GpuBuffer::from_host(&ctx, w_q)?;
let w_k = GpuBuffer::from_host(&ctx, w_k)?;
let w_v = GpuBuffer::from_host(&ctx, w_v)?;
let w_o = GpuBuffer::from_host(&ctx, w_o)?;
let w_gate = GpuBuffer::from_host(&ctx, w_gate)?;
let w_up = GpuBuffer::from_host(&ctx, w_up)?;
let w_down = GpuBuffer::from_host(&ctx, w_down)?;
let single_mask: Vec<f32> = (0..max_seq_len * max_seq_len)
.map(|idx| {
let row = idx / max_seq_len;
let col = idx % max_seq_len;
if col <= row {
0.0f32
} else {
f32::NEG_INFINITY
}
})
.collect();
let scratch = CudaBlockScratch {
norm1_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
q: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
k: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
v: GpuBuffer::new(&ctx, max_seq_len * kv_hidden_size)?,
attn_scores: GpuBuffer::new(&ctx, num_heads * max_seq_len * max_seq_len)?,
attn_out: GpuBuffer::new(&ctx, max_seq_len * q_dim)?,
o_proj_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
residual1: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
norm2_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
gate_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
up_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
swiglu_out: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
ffn_out: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
norm1_out_f16: None,
attn_out_f16: None,
norm2_out_f16: None,
swiglu_out_f16: None,
grad_hidden: GpuBuffer::new(&ctx, max_seq_len * hidden_size)?,
grad_swiglu: GpuBuffer::new(&ctx, max_seq_len * intermediate_size)?,
attn_q_batched: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
attn_kv_temp: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
attn_kv_temp2: GpuBuffer::new(&ctx, num_heads * max_seq_len * config.head_dim())?,
grad_attn_scores: GpuBuffer::new(
&ctx,
num_heads * max_seq_len * max_seq_len.max(config.head_dim()),
)?,
lora_inter: GpuBuffer::new(&ctx, 1)?,
lora_temp: GpuBuffer::new(&ctx, 1)?,
rope_positions: {
let positions: Vec<u32> = (0..max_seq_len as u32).collect();
let mut buf = GpuBuffer::new(&ctx, max_seq_len)?;
buf.copy_from_host(&positions)?;
buf
},
causal_mask_contiguous: GpuBuffer::from_host(&ctx, &single_mask)?,
causal_mask_cached_seq_len: max_seq_len,
op_us: [0u64; 16],
op_profiling_enabled: false,
};
Ok(Self {
config: config.clone(),
layer_idx,
input_norm_weight,
post_attn_norm_weight,
w_q,
w_k,
w_v,
w_o,
w_gate,
w_up,
w_down,
ctx,
scratch,
norm_zero_buf: vec![0.0f32; hidden_size],
            q_norm_weight: None,
            k_norm_weight: None,
})
}
#[allow(dead_code)]
pub fn set_qk_norm(&mut self, q_norm: &[f32], k_norm: &[f32]) -> Result<()> {
self.q_norm_weight = Some(GpuBuffer::from_host(&self.ctx, q_norm)?);
self.k_norm_weight = Some(GpuBuffer::from_host(&self.ctx, k_norm)?);
Ok(())
}
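    /// Runs the full layer forward pass on `stream`: input RMSNorm, Q/K/V
    /// projections, attention, output projection, residual add, post-attention
    /// RMSNorm, gate/up GEMMs, SwiGLU, down GEMM, and the final residual add
    /// into `output`. Intermediate activations stay in `self.scratch` so that
    /// `backward` can reuse them.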
pub fn forward(
&mut self,
input: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
) -> Result<()> {
let hidden_size = self.config.hidden_size;
let q_dim = self.config.q_dim();
let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
let intermediate_size = self.config.intermediate_size;
rms_norm_forward(
input,
&self.input_norm_weight,
&mut self.scratch.norm1_out,
saturating_u32(seq_len),
saturating_u32(hidden_size),
stream,
)?;
gemm_forward(
&self.scratch.norm1_out,
&self.w_q,
&mut self.scratch.q,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(q_dim),
stream,
)?;
gemm_forward(
&self.scratch.norm1_out,
&self.w_k,
&mut self.scratch.k,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(kv_hidden_size),
stream,
)?;
gemm_forward(
&self.scratch.norm1_out,
&self.w_v,
&mut self.scratch.v,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(kv_hidden_size),
stream,
)?;
self.compute_attention_cuda(seq_len, stream)?;
gemm_forward(
&self.scratch.attn_out,
&self.w_o,
&mut self.scratch.o_proj_out,
saturating_u32(seq_len),
saturating_u32(q_dim),
saturating_u32(hidden_size),
stream,
)?;
cuda_add(
input,
&self.scratch.o_proj_out,
&mut self.scratch.residual1,
seq_len * hidden_size,
stream,
)?;
rms_norm_forward(
&self.scratch.residual1,
&self.post_attn_norm_weight,
&mut self.scratch.norm2_out,
saturating_u32(seq_len),
saturating_u32(hidden_size),
stream,
)?;
gemm_forward(
&self.scratch.norm2_out,
&self.w_gate,
&mut self.scratch.gate_out,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(intermediate_size),
stream,
)?;
gemm_forward(
&self.scratch.norm2_out,
&self.w_up,
&mut self.scratch.up_out,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(intermediate_size),
stream,
)?;
fused_swiglu_forward(
&self.scratch.gate_out,
&self.scratch.up_out,
&mut self.scratch.swiglu_out,
saturating_u32(seq_len * intermediate_size),
stream,
)?;
gemm_forward(
&self.scratch.swiglu_out,
&self.w_down,
&mut self.scratch.ffn_out,
saturating_u32(seq_len),
saturating_u32(intermediate_size),
saturating_u32(hidden_size),
stream,
)?;
cuda_add(
&self.scratch.residual1,
&self.scratch.ffn_out,
output,
seq_len * hidden_size,
stream,
)?;
Ok(())
}
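    /// Multi-head attention over the projected Q/K/V held in scratch: optional
    /// per-head QK RMSNorm, NeoX RoPE, GQA key/value head expansion when
    /// `num_heads > num_kv_heads`, scaled QK^T scores with a causal mask,
    /// softmax, and the scores-times-V product written back to
    /// `scratch.attn_out` in interleaved layout.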
fn compute_attention_cuda(&mut self, seq_len: usize, stream: &CudaStream) -> Result<()> {
let num_heads = self.config.num_attention_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = self.config.head_dim();
let heads_per_kv = num_heads / num_kv_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
let seq = saturating_u32(seq_len);
let nh = saturating_u32(num_heads);
let nkv = saturating_u32(num_kv_heads);
let hd = saturating_u32(head_dim);
self.scratch.prepare_causal_mask(seq_len, &self.ctx)?;
if let Some(ref q_norm) = self.q_norm_weight {
for pos in 0..seq_len {
let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
per_head_rmsnorm_forward(q_ref, q_norm, &mut self.scratch.q, nh, hd, pos, stream)?;
}
}
if let Some(ref k_norm) = self.k_norm_weight {
for pos in 0..seq_len {
let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
per_head_rmsnorm_forward(k_ref, k_norm, &mut self.scratch.k, nkv, hd, pos, stream)?;
}
}
let rope_theta = self.config.rope_theta;
{
let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.q)) };
batched_rope_neox_forward(
q_ref,
&mut self.scratch.q,
&self.scratch.rope_positions,
nh,
hd,
seq,
rope_theta,
stream,
)?;
let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.k)) };
batched_rope_neox_forward(
k_ref,
&mut self.scratch.k,
&self.scratch.rope_positions,
nkv,
hd,
seq,
rope_theta,
stream,
)?;
}
interleaved_to_batched_forward(
&self.scratch.q,
&mut self.scratch.attn_q_batched,
seq,
nh,
hd,
stream,
)?;
interleaved_to_batched_forward(
&self.scratch.k,
&mut self.scratch.attn_kv_temp,
seq,
nkv,
hd,
stream,
)?;
if heads_per_kv == 1 {
batched_transpose_forward(
&self.scratch.attn_kv_temp,
&mut self.scratch.attn_kv_temp2,
nh,
seq,
hd,
stream,
)?;
} else {
expand_kv_heads(
&self.scratch.attn_kv_temp,
&mut self.scratch.attn_kv_temp2,
num_kv_heads,
heads_per_kv,
seq_len * head_dim,
stream,
)?;
batched_transpose_forward(
&self.scratch.attn_kv_temp2,
&mut self.scratch.attn_kv_temp,
nh,
seq,
hd,
stream,
)?;
unsafe {
self.scratch
.attn_kv_temp2
.copy_from_buffer_async(&self.scratch.attn_kv_temp, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"K^T buffer copy failed: {e}"
))
})?;
}
}
batched_4d_gemm_forward(
&self.scratch.attn_q_batched,
&self.scratch.attn_kv_temp2,
&mut self.scratch.attn_scores,
1,
nh,
seq,
seq,
hd,
stream,
)?;
let total_scores = nh * seq * seq;
{
let scores_view = unsafe {
GpuBuffer::<f32>::from_raw_parts(
self.scratch.attn_scores.as_ptr(),
self.scratch.attn_scores.len(),
)
};
scale_forward(
&scores_view,
&mut self.scratch.attn_scores,
scale,
total_scores,
stream,
)?;
leak(scores_view);
}
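        // Add the causal mask to each head's score matrix in place. Every head's
        // slice of `attn_scores` is wrapped in a temporary raw-parts view, then
        // leaked so that the owning buffer keeps sole ownership of the memory.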
{
let seq_sq = (seq * seq) as usize;
let mask_ptr = self.scratch.causal_mask_contiguous.as_ptr();
let scores_base = self.scratch.attn_scores.as_ptr();
for head in 0..nh as usize {
                let byte_offset = (head * seq_sq * std::mem::size_of::<f32>()) as u64;
                let head_ptr = scores_base + byte_offset;
let mask_view = unsafe { GpuBuffer::<f32>::from_raw_parts(mask_ptr, seq_sq) };
let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
residual_add_forward(&mask_view, &scores_view, &mut out_view, seq * seq, stream)?;
leak(mask_view);
leak(scores_view);
leak(out_view);
}
}
let total_rows = nh * seq;
{
let scores_view = unsafe {
GpuBuffer::<f32>::from_raw_parts(
self.scratch.attn_scores.as_ptr(),
self.scratch.attn_scores.len(),
)
};
batched_softmax_forward(
&scores_view,
&mut self.scratch.attn_scores,
total_rows,
seq,
stream,
)?;
leak(scores_view);
}
interleaved_to_batched_forward(
&self.scratch.v,
&mut self.scratch.attn_kv_temp,
seq,
nkv,
hd,
stream,
)?;
        if heads_per_kv > 1 {
expand_kv_heads(
&self.scratch.attn_kv_temp,
&mut self.scratch.attn_kv_temp2,
num_kv_heads,
heads_per_kv,
seq_len * head_dim,
stream,
)?;
unsafe {
self.scratch
.attn_kv_temp
.copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"V expanded buffer copy failed: {e}"
))
})?;
}
}
batched_4d_gemm_forward(
&self.scratch.attn_scores,
&self.scratch.attn_kv_temp,
&mut self.scratch.attn_q_batched,
1,
nh,
seq,
hd,
seq,
stream,
)?;
batched_to_interleaved_forward(
&self.scratch.attn_q_batched,
&mut self.scratch.attn_out,
seq,
nh,
hd,
stream,
)?;
Ok(())
}
pub fn layer_idx(&self) -> usize {
self.layer_idx
}
pub fn config(&self) -> &TransformerConfig {
&self.config
}
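    /// Full layer backward pass: FFN gradients, post-attention RMSNorm,
    /// attention gradients, then the input RMSNorm and residual path. Weight
    /// gradients are accumulated into `grad_ws`; the gradient with respect to
    /// the layer input ends up in `grad_input`.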
#[provable_contracts_macros::contract("backward-pass-v1", equation = "backward")]
pub fn backward(
&mut self,
input: &GpuBuffer<f32>,
grad_output: &GpuBuffer<f32>,
grad_input: &mut GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
grad_ws: &mut CudaGradWorkspace,
) -> Result<()> {
let hidden_size = self.config.hidden_size;
let intermediate_size = self.config.intermediate_size;
let eps = 1e-5_f32;
grad_ws.zero_norm_grads(&self.norm_zero_buf)?;
self.backward_ffn(grad_output, seq_len, hidden_size, intermediate_size, stream, grad_ws)?;
self.backward_post_attn_norm(grad_input, seq_len, hidden_size, eps, stream, grad_ws)?;
cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;
self.backward_attention(grad_input, seq_len, stream, grad_ws)?;
self.backward_residual_and_input_norm(
input,
grad_output,
grad_input,
seq_len,
hidden_size,
eps,
stream,
grad_ws,
)?;
Ok(())
}
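    /// Backpropagates through the down projection, SwiGLU, and the gate/up
    /// GEMMs, accumulating `grad_down`, `grad_gate`, and `grad_up`. The
    /// gradient with respect to `norm2_out` is left in `scratch.norm2_out`;
    /// the forward FFN activations in scratch are overwritten in the process.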
fn backward_ffn(
&mut self,
grad_output: &GpuBuffer<f32>,
seq_len: usize,
hidden_size: usize,
intermediate_size: usize,
stream: &CudaStream,
grad_ws: &mut CudaGradWorkspace,
) -> Result<()> {
let n_inter = saturating_u32(seq_len * intermediate_size);
let n_hidden = saturating_u32(seq_len * hidden_size);
gemm_backward_a(
grad_output,
&self.w_down,
&mut self.scratch.grad_swiglu,
saturating_u32(seq_len),
saturating_u32(intermediate_size),
saturating_u32(hidden_size),
stream,
)?;
gemm_backward_b(
&self.scratch.swiglu_out,
grad_output,
&mut grad_ws.grad_down,
saturating_u32(seq_len),
saturating_u32(intermediate_size),
saturating_u32(hidden_size),
stream,
)?;
elementwise_mul_forward(
&self.scratch.grad_swiglu,
&self.scratch.up_out,
&mut self.scratch.swiglu_out,
n_inter,
stream,
)?;
silu_backward(
&self.scratch.gate_out,
&self.scratch.swiglu_out,
&mut self.scratch.up_out,
stream,
)?;
silu_forward(&self.scratch.gate_out, &mut self.scratch.swiglu_out, n_inter, stream)?;
elementwise_mul_forward(
&self.scratch.grad_swiglu,
&self.scratch.swiglu_out,
&mut self.scratch.gate_out,
n_inter,
stream,
)?;
gemm_backward_b(
&self.scratch.norm2_out,
&self.scratch.up_out,
&mut grad_ws.grad_gate,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(intermediate_size),
stream,
)?;
gemm_backward_b(
&self.scratch.norm2_out,
&self.scratch.gate_out,
&mut grad_ws.grad_up,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(intermediate_size),
stream,
)?;
gemm_backward_a(
&self.scratch.up_out,
&self.w_gate,
&mut self.scratch.ffn_out,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(intermediate_size),
stream,
)?;
gemm_backward_a(
&self.scratch.gate_out,
&self.w_up,
&mut self.scratch.grad_hidden,
saturating_u32(seq_len),
saturating_u32(hidden_size),
saturating_u32(intermediate_size),
stream,
)?;
residual_add_forward(
&self.scratch.ffn_out,
&self.scratch.grad_hidden,
&mut self.scratch.norm2_out,
n_hidden,
stream,
)?;
Ok(())
}
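    /// Backpropagates the post-attention RMSNorm: copies the FFN input
    /// gradient out of `scratch.norm2_out`, writes the gradient with respect
    /// to `residual1` into `grad_input`, and accumulates
    /// `grad_post_attn_norm`.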
fn backward_post_attn_norm(
&mut self,
grad_input: &mut GpuBuffer<f32>,
seq_len: usize,
hidden_size: usize,
eps: f32,
stream: &CudaStream,
grad_ws: &mut CudaGradWorkspace,
) -> Result<()> {
unsafe {
self.scratch
.grad_hidden
.copy_from_buffer_async(&self.scratch.norm2_out, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"Backward norm D2D copy failed: {e}"
))
})?;
}
rms_norm_backward(
&self.scratch.residual1,
&self.post_attn_norm_weight,
&self.scratch.grad_hidden,
grad_input,
&mut grad_ws.grad_post_attn_norm,
saturating_u32(seq_len),
saturating_u32(hidden_size),
eps,
stream,
)
}
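    /// Backpropagates through the output projection, the attention core
    /// (including GQA gradient reduction and inverse RoPE), and the Q/K/V
    /// projections, accumulating `grad_w_q/k/v/o` and writing the gradient
    /// with respect to `norm1_out` into `grad_input`.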
fn backward_attention(
&mut self,
grad_input: &mut GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
grad_ws: &mut CudaGradWorkspace,
) -> Result<()> {
let hidden_size = self.config.hidden_size;
let q_dim = self.config.q_dim();
let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
let num_heads = self.config.num_attention_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = self.config.head_dim();
let heads_per_kv = num_heads / num_kv_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
let seq = saturating_u32(seq_len);
let nh = saturating_u32(num_heads);
let nkv = saturating_u32(num_kv_heads);
let hd = saturating_u32(head_dim);
gemm_backward_a(
grad_input,
&self.w_o,
&mut self.scratch.grad_hidden,
seq,
saturating_u32(q_dim),
saturating_u32(hidden_size),
stream,
)?;
gemm_backward_b(
&self.scratch.attn_out,
grad_input,
&mut grad_ws.grad_w_o,
seq,
saturating_u32(q_dim),
saturating_u32(hidden_size),
stream,
)?;
interleaved_to_batched_forward(
&self.scratch.grad_hidden,
&mut self.scratch.attn_q_batched,
seq,
nh,
hd,
stream,
)?;
interleaved_to_batched_forward(
&self.scratch.v,
&mut self.scratch.attn_kv_temp,
seq,
nkv,
hd,
stream,
)?;
if heads_per_kv > 1 {
expand_kv_heads(
&self.scratch.attn_kv_temp,
&mut self.scratch.attn_kv_temp2,
num_kv_heads,
heads_per_kv,
seq_len * head_dim,
stream,
)?;
unsafe {
self.scratch
.attn_kv_temp
.copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"Attn backward V expand D2D copy failed: {e}"
))
})?;
}
}
batched_transpose_forward(
&self.scratch.attn_kv_temp,
&mut self.scratch.attn_kv_temp2,
nh,
seq,
hd,
stream,
)?;
batched_4d_gemm_forward(
&self.scratch.attn_q_batched,
&self.scratch.attn_kv_temp2,
&mut self.scratch.grad_attn_scores,
1,
nh,
seq,
seq,
hd,
stream,
)?;
        batched_transpose_forward(
            &self.scratch.attn_q_batched,
            &mut self.scratch.attn_kv_temp,
            nh,
            seq,
            hd,
            stream,
        )?;
        batched_4d_gemm_forward(
            &self.scratch.attn_kv_temp,
            &self.scratch.attn_scores,
            &mut self.scratch.attn_kv_temp2,
            1,
            nh,
            hd,
            seq,
            seq,
            stream,
        )?;
        batched_transpose_forward(
            &self.scratch.attn_kv_temp2,
            &mut self.scratch.attn_kv_temp,
            nh,
            hd,
            seq,
            stream,
        )?;
let total_rows = nh * seq;
{
let grad_scores_view = unsafe {
GpuBuffer::<f32>::from_raw_parts(
self.scratch.grad_attn_scores.as_ptr(),
self.scratch.grad_attn_scores.len(),
)
};
batched_softmax_backward(
&self.scratch.attn_scores,
&grad_scores_view,
&mut self.scratch.grad_attn_scores,
total_rows,
seq,
stream,
)?;
leak(grad_scores_view);
}
let total_scores = nh * seq * seq;
{
let scores_view = unsafe {
GpuBuffer::<f32>::from_raw_parts(
self.scratch.grad_attn_scores.as_ptr(),
self.scratch.grad_attn_scores.len(),
)
};
scale_forward(
&scores_view,
&mut self.scratch.grad_attn_scores,
scale,
total_scores,
stream,
)?;
leak(scores_view);
}
interleaved_to_batched_forward(
&self.scratch.k,
&mut self.scratch.attn_kv_temp2,
seq,
nkv,
hd,
stream,
)?;
if heads_per_kv > 1 {
unsafe {
self.scratch
.attn_q_batched
.copy_from_buffer_async(&self.scratch.attn_kv_temp2, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"Attn backward K copy for GQA expand failed: {e}"
))
})?;
}
expand_kv_heads(
&self.scratch.attn_q_batched,
&mut self.scratch.attn_kv_temp2,
num_kv_heads,
heads_per_kv,
seq_len * head_dim,
stream,
)?;
}
batched_4d_gemm_forward(
&self.scratch.grad_attn_scores,
&self.scratch.attn_kv_temp2,
&mut self.scratch.attn_q_batched,
1,
nh,
seq,
hd,
seq,
stream,
)?;
        interleaved_to_batched_forward(
            &self.scratch.q,
            &mut self.scratch.o_proj_out,
            seq,
            nh,
            hd,
            stream,
        )?;
        batched_transpose_forward(
            &self.scratch.o_proj_out,
            &mut self.scratch.attn_kv_temp2,
            nh,
            seq,
            hd,
            stream,
        )?;
        batched_4d_gemm_forward(
            &self.scratch.attn_kv_temp2,
            &self.scratch.grad_attn_scores,
            &mut self.scratch.ffn_out,
            1,
            nh,
            hd,
            seq,
            seq,
            stream,
        )?;
        batched_transpose_forward(
            &self.scratch.ffn_out,
            &mut self.scratch.attn_kv_temp2,
            nh,
            hd,
            seq,
            stream,
        )?;
if heads_per_kv > 1 {
self.reduce_gqa_gradients(num_kv_heads, heads_per_kv, seq_len, head_dim, stream)?;
}
batched_to_interleaved_forward(
&self.scratch.attn_q_batched,
&mut self.scratch.o_proj_out,
seq,
nh,
hd,
stream,
)?;
batched_to_interleaved_forward(
&self.scratch.attn_kv_temp2,
&mut self.scratch.norm2_out,
seq,
nkv,
hd,
stream,
)?;
batched_to_interleaved_forward(
&self.scratch.attn_kv_temp,
&mut self.scratch.ffn_out,
seq,
nkv,
hd,
stream,
)?;
let rope_theta = self.config.rope_theta;
{
let q_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.o_proj_out)) };
batched_rope_neox_backward(
q_ref,
&mut self.scratch.o_proj_out,
&self.scratch.rope_positions,
nh,
hd,
seq,
rope_theta,
stream,
)?;
let k_ref = unsafe { &*(std::ptr::addr_of!(self.scratch.norm2_out)) };
batched_rope_neox_backward(
k_ref,
&mut self.scratch.norm2_out,
&self.scratch.rope_positions,
nkv,
hd,
seq,
rope_theta,
stream,
)?;
}
        gemm_backward_a(
            &self.scratch.o_proj_out,
            &self.w_q,
            &mut self.scratch.grad_hidden,
            seq,
            saturating_u32(hidden_size),
            saturating_u32(q_dim),
            stream,
        )?;
        gemm_backward_a(
            &self.scratch.norm2_out,
            &self.w_k,
            &mut self.scratch.grad_attn_scores,
            seq,
            saturating_u32(hidden_size),
            saturating_u32(kv_hidden_size),
            stream,
        )?;
cuda_add_inplace(
&mut self.scratch.grad_hidden,
&self.scratch.grad_attn_scores,
seq_len * hidden_size,
stream,
)?;
        gemm_backward_a(
            &self.scratch.ffn_out,
            &self.w_v,
            &mut self.scratch.grad_attn_scores,
            seq,
            saturating_u32(hidden_size),
            saturating_u32(kv_hidden_size),
            stream,
        )?;
cuda_add_inplace(
&mut self.scratch.grad_hidden,
&self.scratch.grad_attn_scores,
seq_len * hidden_size,
stream,
)?;
        gemm_backward_b(
            &self.scratch.norm1_out,
            &self.scratch.o_proj_out,
            &mut grad_ws.grad_w_q,
            seq,
            saturating_u32(hidden_size),
            saturating_u32(q_dim),
            stream,
        )?;
        gemm_backward_b(
            &self.scratch.norm1_out,
            &self.scratch.norm2_out,
            &mut grad_ws.grad_w_k,
            seq,
            saturating_u32(hidden_size),
            saturating_u32(kv_hidden_size),
            stream,
        )?;
        gemm_backward_b(
            &self.scratch.norm1_out,
            &self.scratch.ffn_out,
            &mut grad_ws.grad_w_v,
            seq,
            saturating_u32(hidden_size),
            saturating_u32(kv_hidden_size),
            stream,
        )?;
unsafe {
grad_input.copy_from_buffer_async(&self.scratch.grad_hidden, stream).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"Attn backward grad_hidden → grad_input D2D copy failed: {e}"
))
})?;
}
Ok(())
}
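    /// For grouped-query attention, sums the per-query-head K and V gradients
    /// down to `num_kv_heads` heads and compacts the reduced results to the
    /// front of their scratch buffers.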
fn reduce_gqa_gradients(
&mut self,
num_kv_heads: usize,
heads_per_kv: usize,
seq_len: usize,
head_dim: usize,
stream: &CudaStream,
) -> Result<()> {
let elems_per_head = seq_len * head_dim;
self.reduce_single_gqa_gradient(true, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
self.reduce_single_gqa_gradient(false, num_kv_heads, heads_per_kv, elems_per_head, stream)?;
let kv_elems = num_kv_heads * elems_per_head;
unsafe {
self.scratch
.attn_kv_temp2
.copy_from_buffer_at_async(&self.scratch.grad_attn_scores, 0, 0, kv_elems, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"GQA grad_K reduced final copy failed: {e}"
))
})?;
self.scratch
.attn_kv_temp
.copy_from_buffer_at_async(&self.scratch.ffn_out, 0, 0, kv_elems, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"GQA grad_V reduced final copy failed: {e}"
))
})?;
}
Ok(())
}
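    /// Reduces one of the K or V gradient tensors over its query-head group:
    /// the first replica is copied into the destination slot, then each
    /// remaining replica is added via a staging copy and `residual_add_forward`
    /// on temporary raw-parts views.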
fn reduce_single_gqa_gradient(
&mut self,
is_k: bool,
num_kv_heads: usize,
heads_per_kv: usize,
elems_per_head: usize,
stream: &CudaStream,
) -> Result<()> {
let label = if is_k { "K" } else { "V" };
for kv_h in 0..num_kv_heads {
let dst_offset = kv_h * elems_per_head;
let first_h = kv_h * heads_per_kv;
let src_offset = first_h * elems_per_head;
unsafe {
let (dst, src) = if is_k {
(&mut self.scratch.grad_attn_scores, &self.scratch.attn_kv_temp2)
} else {
(&mut self.scratch.ffn_out, &self.scratch.attn_kv_temp)
};
dst.copy_from_buffer_at_async(src, dst_offset, src_offset, elems_per_head, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"GQA grad_{label} reduce base copy failed: {e}"
))
})?;
}
for rep in 1..heads_per_kv {
let h = kv_h * heads_per_kv + rep;
let h_offset = h * elems_per_head;
unsafe {
let src =
if is_k { &self.scratch.attn_kv_temp2 } else { &self.scratch.attn_kv_temp };
self.scratch
.o_proj_out
.copy_from_buffer_at_async(src, 0, h_offset, elems_per_head, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"GQA grad_{label} reduce head copy failed: {e}"
))
})?;
}
unsafe {
let dst_buf =
if is_k { &self.scratch.grad_attn_scores } else { &self.scratch.ffn_out };
let dst_view = GpuBuffer::<f32>::from_raw_parts(
dst_buf.as_ptr() + (dst_offset as u64 * 4),
elems_per_head,
);
let src_view = GpuBuffer::<f32>::from_raw_parts(
self.scratch.o_proj_out.as_ptr(),
elems_per_head,
);
let mut sum_view = GpuBuffer::<f32>::from_raw_parts(
self.scratch.grad_hidden.as_ptr(),
elems_per_head,
);
residual_add_forward(
&dst_view,
&src_view,
&mut sum_view,
saturating_u32(elems_per_head),
stream,
)?;
let dst_buf = if is_k {
&mut self.scratch.grad_attn_scores
} else {
&mut self.scratch.ffn_out
};
dst_buf
.copy_from_buffer_at_async(
&self.scratch.grad_hidden,
dst_offset,
0,
elems_per_head,
stream,
)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"GQA grad_{label} reduce sum copy failed: {e}"
))
})?;
leak(dst_view);
leak(src_view);
leak(sum_view);
}
}
}
Ok(())
}
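    /// Final stage of the backward pass: backpropagates the input RMSNorm
    /// (accumulating `grad_input_norm`) and adds the upstream `grad_output`
    /// residual contribution into `grad_input`.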
fn backward_residual_and_input_norm(
&mut self,
input: &GpuBuffer<f32>,
grad_output: &GpuBuffer<f32>,
grad_input: &mut GpuBuffer<f32>,
seq_len: usize,
hidden_size: usize,
eps: f32,
stream: &CudaStream,
grad_ws: &mut CudaGradWorkspace,
) -> Result<()> {
unsafe {
self.scratch.grad_hidden.copy_from_buffer_async(grad_input, stream).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"Backward residual grad_hidden D2D copy failed: {e}"
))
})?;
}
rms_norm_backward(
input,
&self.input_norm_weight,
&self.scratch.grad_hidden,
grad_input,
&mut grad_ws.grad_input_norm,
saturating_u32(seq_len),
saturating_u32(hidden_size),
eps,
stream,
)?;
cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)
}
pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
let hidden = self.config.hidden_size;
let q_dim = self.config.q_dim();
let kv_hidden = self.config.num_kv_heads * self.config.head_dim();
let intermediate = self.config.intermediate_size;
let z = |n: usize| -> Result<GpuBuffer<f32>> {
Ok(GpuBuffer::from_host(&self.ctx, &vec![0.0f32; n])?)
};
Ok(GpuBlockOptimizerState {
m_w_q: z(q_dim * hidden)?,
v_w_q: z(q_dim * hidden)?,
m_w_k: z(hidden * kv_hidden)?,
v_w_k: z(hidden * kv_hidden)?,
m_w_v: z(hidden * kv_hidden)?,
v_w_v: z(hidden * kv_hidden)?,
m_w_o: z(hidden * q_dim)?,
v_w_o: z(hidden * q_dim)?,
m_w_gate: z(hidden * intermediate)?,
v_w_gate: z(hidden * intermediate)?,
m_w_up: z(hidden * intermediate)?,
v_w_up: z(hidden * intermediate)?,
m_w_down: z(intermediate * hidden)?,
v_w_down: z(intermediate * hidden)?,
m_input_norm: z(hidden)?,
v_input_norm: z(hidden)?,
m_post_attn_norm: z(hidden)?,
v_post_attn_norm: z(hidden)?,
})
}
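    /// Applies one on-device AdamW update to every trainable tensor of the
    /// block, consuming the gradients in `grad_ws` and the moment buffers in
    /// `state`. `step` is 1-based so the bias correction is well defined.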
pub fn optimizer_step(
&mut self,
state: &mut GpuBlockOptimizerState,
step: u32,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
stream: &CudaStream,
grad_ws: &CudaGradWorkspace,
) -> Result<()> {
debug_assert!(step > 0, "C-OPTSTEP-001: step must be > 0 for bias adjust");
let n_wq = self.w_q.len() as u32;
let n_wk = self.w_k.len() as u32;
let n_wv = self.w_v.len() as u32;
let n_wo = self.w_o.len() as u32;
let n_gate = self.w_gate.len() as u32;
let n_up = self.w_up.len() as u32;
let n_down = self.w_down.len() as u32;
let n_inorm = self.input_norm_weight.len() as u32;
let n_panorm = self.post_attn_norm_weight.len() as u32;
adamw_step_cuda(
&mut self.w_q,
&grad_ws.grad_w_q,
&mut state.m_w_q,
&mut state.v_w_q,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_wq,
stream,
)?;
adamw_step_cuda(
&mut self.w_k,
&grad_ws.grad_w_k,
&mut state.m_w_k,
&mut state.v_w_k,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_wk,
stream,
)?;
adamw_step_cuda(
&mut self.w_v,
&grad_ws.grad_w_v,
&mut state.m_w_v,
&mut state.v_w_v,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_wv,
stream,
)?;
adamw_step_cuda(
&mut self.w_o,
&grad_ws.grad_w_o,
&mut state.m_w_o,
&mut state.v_w_o,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_wo,
stream,
)?;
adamw_step_cuda(
&mut self.w_gate,
&grad_ws.grad_gate,
&mut state.m_w_gate,
&mut state.v_w_gate,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_gate,
stream,
)?;
adamw_step_cuda(
&mut self.w_up,
&grad_ws.grad_up,
&mut state.m_w_up,
&mut state.v_w_up,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_up,
stream,
)?;
adamw_step_cuda(
&mut self.w_down,
&grad_ws.grad_down,
&mut state.m_w_down,
&mut state.v_w_down,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_down,
stream,
)?;
adamw_step_cuda(
&mut self.input_norm_weight,
&grad_ws.grad_input_norm,
&mut state.m_input_norm,
&mut state.v_input_norm,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_inorm,
stream,
)?;
adamw_step_cuda(
&mut self.post_attn_norm_weight,
&grad_ws.grad_post_attn_norm,
&mut state.m_post_attn_norm,
&mut state.v_post_attn_norm,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
n_panorm,
stream,
)?;
Ok(())
}
pub fn download_weights(&self) -> Result<BlockWeights> {
let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
let mut host = vec![0.0f32; buf.len()];
buf.copy_to_host(&mut host).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"Weight download failed: {e}"
))
})?;
Ok(host)
};
Ok(BlockWeights {
w_q: download(&self.w_q)?,
w_k: download(&self.w_k)?,
w_v: download(&self.w_v)?,
w_o: download(&self.w_o)?,
w_gate: download(&self.w_gate)?,
w_up: download(&self.w_up)?,
w_down: download(&self.w_down)?,
input_norm_weight: download(&self.input_norm_weight)?,
post_attn_norm_weight: download(&self.post_attn_norm_weight)?,
})
}
}
#[cfg(feature = "cuda")]
pub struct BlockWeights {
pub w_q: Vec<f32>,
pub w_k: Vec<f32>,
pub w_v: Vec<f32>,
pub w_o: Vec<f32>,
pub w_gate: Vec<f32>,
pub w_up: Vec<f32>,
pub w_down: Vec<f32>,
pub input_norm_weight: Vec<f32>,
pub post_attn_norm_weight: Vec<f32>,
}
#[cfg(feature = "cuda")]
fn cuda_add(
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
n: usize,
stream: &CudaStream,
) -> Result<()> {
residual_add_forward(a, b, output, saturating_u32(n), stream)
}
#[cfg(feature = "cuda")]
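/// In-place elementwise add: `target += source`. A shared alias of `target`
/// is passed as the kernel's first input; `residual_add_forward` reads and
/// writes each element independently, so aliasing input and output is assumed
/// safe here.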
pub(crate) fn cuda_add_inplace(
target: &mut GpuBuffer<f32>,
source: &GpuBuffer<f32>,
n: usize,
stream: &CudaStream,
) -> Result<()> {
let target_ref: &GpuBuffer<f32> = unsafe { &*std::ptr::from_ref::<GpuBuffer<f32>>(target) };
residual_add_forward(target_ref, source, target, saturating_u32(n), stream)
}
#[cfg(feature = "cuda")]
fn cuda_mul(
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
n: usize,
stream: &CudaStream,
) -> Result<()> {
crate::autograd::cuda_forward::elementwise_mul_forward(a, b, output, saturating_u32(n), stream)
}
#[cfg(not(feature = "cuda"))]
pub struct CudaTransformerBlock;
#[cfg(not(feature = "cuda"))]
impl CudaTransformerBlock {
pub fn layer_idx(&self) -> usize {
0
}
}
#[cfg(feature = "cuda")]
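/// A transformer layer in either full-precision (trainable) or NF4-quantized
/// (frozen base weights plus LoRA adapters) form, dispatching forward,
/// backward, and optimizer calls to the matching implementation.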
pub enum CudaBlock {
Fp32(CudaTransformerBlock),
Nf4(CudaNf4TransformerBlock),
}
#[cfg(feature = "cuda")]
impl CudaBlock {
pub(crate) fn forward(
&mut self,
input: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
shared_scratch: Option<&mut CudaBlockScratch>,
) -> Result<()> {
match self {
CudaBlock::Fp32(b) => b.forward(input, output, seq_len, stream),
CudaBlock::Nf4(b) => {
let scratch =
shared_scratch.expect("C-SCRATCH-001: NF4 blocks require shared scratch");
b.forward(input, output, seq_len, stream, scratch)
}
}
}
pub fn layer_idx(&self) -> usize {
match self {
CudaBlock::Fp32(b) => b.layer_idx(),
CudaBlock::Nf4(b) => b.layer_idx,
}
}
pub fn backward(
&mut self,
input: &GpuBuffer<f32>,
grad_output: &GpuBuffer<f32>,
grad_input: &mut GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
grad_ws: &mut CudaGradWorkspace,
) -> Result<()> {
match self {
CudaBlock::Fp32(b) => {
b.backward(input, grad_output, grad_input, seq_len, stream, grad_ws)
}
CudaBlock::Nf4(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"backward not supported on NF4 blocks (frozen weights)".into(),
)),
}
}
pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState> {
match self {
CudaBlock::Fp32(b) => b.init_optimizer_state(),
CudaBlock::Nf4(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"init_optimizer_state not supported on NF4 blocks".into(),
)),
}
}
pub fn download_weights(&self) -> Result<BlockWeights> {
match self {
CudaBlock::Fp32(b) => b.download_weights(),
CudaBlock::Nf4(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"download_weights not supported on NF4 blocks".into(),
)),
}
}
pub fn optimizer_step(
&mut self,
state: &mut GpuBlockOptimizerState,
step: u32,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
stream: &CudaStream,
grad_ws: &CudaGradWorkspace,
) -> Result<()> {
match self {
CudaBlock::Fp32(b) => {
b.optimizer_step(state, step, lr, beta1, beta2, eps, weight_decay, stream, grad_ws)
}
CudaBlock::Nf4(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"optimizer_step not supported on NF4 blocks (frozen weights)".into(),
)),
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn backward_nf4(
&self,
layer_input: &GpuBuffer<f32>,
grad_output: &GpuBuffer<f32>,
grad_input: &mut GpuBuffer<f32>,
output_scratch: &mut GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
shared_scratch: &mut CudaBlockScratch,
grad_lora: &mut CudaLoraGradWorkspace,
) -> Result<()> {
match self {
CudaBlock::Nf4(b) => b.backward(
layer_input,
grad_output,
grad_input,
output_scratch,
seq_len,
stream,
shared_scratch,
grad_lora,
),
CudaBlock::Fp32(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"backward_nf4 only supported on NF4 blocks".into(),
)),
}
}
pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
match self {
CudaBlock::Nf4(b) => b.init_lora_optimizer_state(),
CudaBlock::Fp32(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"init_lora_optimizer_state only supported on NF4 blocks".into(),
)),
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn lora_optimizer_step(
&mut self,
state: &mut GpuLoraOptimizerState,
step: u32,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
stream: &CudaStream,
grad_lora: &CudaLoraGradWorkspace,
) -> Result<()> {
match self {
CudaBlock::Nf4(b) => b.lora_optimizer_step(
state,
step,
lr,
beta1,
beta2,
eps,
weight_decay,
stream,
grad_lora,
),
CudaBlock::Fp32(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"lora_optimizer_step only supported on NF4 blocks".into(),
)),
}
}
pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
match self {
CudaBlock::Nf4(b) => b.download_lora_weights(),
CudaBlock::Fp32(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"download_lora_weights only supported on NF4 blocks".into(),
)),
}
}
pub fn upload_lora_weights(
&mut self,
a_q: &[f32],
b_q: &[f32],
a_v: &[f32],
b_v: &[f32],
) -> Result<()> {
match self {
CudaBlock::Nf4(b) => b.upload_lora_weights(a_q, b_q, a_v, b_v),
CudaBlock::Fp32(_) => Err(crate::autograd::cuda_tensor::CudaTensorError::KernelError(
"upload_lora_weights only supported on NF4 blocks".into(),
)),
}
}
}
#[cfg(not(feature = "cuda"))]
pub enum CudaBlock {
Fp32(CudaTransformerBlock),
}
#[cfg(feature = "cuda")]
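/// NF4-quantized transformer layer for QLoRA-style fine-tuning: base weights
/// are stored as NF4 codes plus per-block scales (with dequantized FP32 and
/// optional FP16 copies for the GEMM paths); only the Q and V LoRA adapters
/// are trainable.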
pub struct CudaNf4TransformerBlock {
config: TransformerConfig,
layer_idx: usize,
input_norm_weight: GpuBuffer<f32>,
post_attn_norm_weight: GpuBuffer<f32>,
w_q_nf4: GpuBuffer<u8>,
w_q_scales: GpuBuffer<f32>,
w_k_nf4: GpuBuffer<u8>,
w_k_scales: GpuBuffer<f32>,
w_v_nf4: GpuBuffer<u8>,
w_v_scales: GpuBuffer<f32>,
w_o_nf4: GpuBuffer<u8>,
w_o_scales: GpuBuffer<f32>,
w_gate_nf4: GpuBuffer<u8>,
w_gate_scales: GpuBuffer<f32>,
w_up_nf4: GpuBuffer<u8>,
w_up_scales: GpuBuffer<f32>,
w_down_nf4: GpuBuffer<u8>,
w_down_scales: GpuBuffer<f32>,
w_q_fp32: GpuBuffer<f32>,
w_k_fp32: GpuBuffer<f32>,
w_v_fp32: GpuBuffer<f32>,
w_o_fp32: GpuBuffer<f32>,
w_gate_fp32: GpuBuffer<f32>,
w_up_fp32: GpuBuffer<f32>,
w_down_fp32: GpuBuffer<f32>,
lora_a_q: Option<GpuBuffer<f32>>, lora_b_q: Option<GpuBuffer<f32>>, lora_a_v: Option<GpuBuffer<f32>>, lora_b_v: Option<GpuBuffer<f32>>, lora_scale: f32,
lora_rank: usize,
q_norm_weight: Option<GpuBuffer<f32>>,
k_norm_weight: Option<GpuBuffer<f32>>,
w_q_fp16: Option<GpuBuffer<u16>>,
w_k_fp16: Option<GpuBuffer<u16>>,
w_v_fp16: Option<GpuBuffer<u16>>,
w_o_fp16: Option<GpuBuffer<u16>>,
w_gate_fp16: Option<GpuBuffer<u16>>,
w_up_fp16: Option<GpuBuffer<u16>>,
w_down_fp16: Option<GpuBuffer<u16>>,
ctx: Arc<CudaContext>,
}
#[cfg(feature = "cuda")]
impl CudaNf4TransformerBlock {
#[allow(clippy::too_many_arguments)]
pub fn new(
config: &TransformerConfig,
layer_idx: usize,
ctx: Arc<CudaContext>,
input_norm_weight: &[f32],
post_attn_norm_weight: &[f32],
w_q: &[f32],
w_k: &[f32],
w_v: &[f32],
w_o: &[f32],
w_gate: &[f32],
w_up: &[f32],
w_down: &[f32],
_max_seq_len: usize, q_lora: Option<(&[f32], &[f32])>,
v_lora: Option<(&[f32], &[f32])>,
lora_scale: f32,
lora_rank: usize,
q_norm: Option<&[f32]>,
k_norm: Option<&[f32]>,
) -> Result<Self> {
use trueno_gpu::kernels::{quantize_nf4, NF4_BLOCK_SIZE};
let hidden_size = config.hidden_size;
let q_dim = config.q_dim(); let kv_hidden_size = config.num_kv_heads * config.head_dim();
let intermediate_size = config.intermediate_size;
assert_eq!(
w_q.len(),
q_dim * hidden_size,
"C-NF4SHAPE-001: w_q expected {}, got {} (q_dim={q_dim}, hidden={hidden_size})",
q_dim * hidden_size,
w_q.len()
);
assert_eq!(
w_k.len(),
kv_hidden_size * hidden_size,
"C-NF4SHAPE-001: w_k expected {}, got {}",
kv_hidden_size * hidden_size,
w_k.len()
);
assert_eq!(
w_v.len(),
kv_hidden_size * hidden_size,
"C-NF4SHAPE-001: w_v expected {}, got {}",
kv_hidden_size * hidden_size,
w_v.len()
);
assert_eq!(
w_o.len(),
hidden_size * q_dim,
"C-NF4SHAPE-001: w_o expected {}, got {}",
hidden_size * q_dim,
w_o.len()
);
assert_eq!(
w_gate.len(),
intermediate_size * hidden_size,
"C-NF4SHAPE-001: w_gate expected {}, got {}",
intermediate_size * hidden_size,
w_gate.len()
);
assert_eq!(
w_up.len(),
intermediate_size * hidden_size,
"C-NF4SHAPE-001: w_up expected {}, got {}",
intermediate_size * hidden_size,
w_up.len()
);
assert_eq!(
w_down.len(),
hidden_size * intermediate_size,
"C-NF4SHAPE-001: w_down expected {}, got {}",
hidden_size * intermediate_size,
w_down.len()
);
let input_norm_weight = GpuBuffer::from_host(&ctx, input_norm_weight)?;
let post_attn_norm_weight = GpuBuffer::from_host(&ctx, post_attn_norm_weight)?;
let quantize_and_upload = |weights: &[f32],
total: usize|
-> Result<(
GpuBuffer<u8>,
GpuBuffer<f32>,
trueno_gpu::kernels::Nf4Quantized,
)> {
assert_eq!(weights.len(), total, "weight length mismatch");
assert!(
total.is_multiple_of(NF4_BLOCK_SIZE),
"weight count {total} not divisible by NF4 block size {NF4_BLOCK_SIZE}"
);
let q = quantize_nf4(weights, total / NF4_BLOCK_SIZE, NF4_BLOCK_SIZE);
let nf4_buf = GpuBuffer::from_host(&ctx, &q.data)?;
let scales_buf = GpuBuffer::from_host(&ctx, &q.scales)?;
Ok((nf4_buf, scales_buf, q))
};
let (w_q_nf4, w_q_scales, w_q_nf4_q) = quantize_and_upload(w_q, q_dim * hidden_size)?;
let (w_k_nf4, w_k_scales, w_k_nf4_q) =
quantize_and_upload(w_k, kv_hidden_size * hidden_size)?;
let (w_v_nf4, w_v_scales, w_v_nf4_q) =
quantize_and_upload(w_v, kv_hidden_size * hidden_size)?;
let (w_o_nf4, w_o_scales, w_o_nf4_q) = quantize_and_upload(w_o, hidden_size * q_dim)?;
let (w_gate_nf4, w_gate_scales, w_gate_nf4_q) =
quantize_and_upload(w_gate, intermediate_size * hidden_size)?;
let (w_up_nf4, w_up_scales, w_up_nf4_q) =
quantize_and_upload(w_up, intermediate_size * hidden_size)?;
let (w_down_nf4, w_down_scales, w_down_nf4_q) =
quantize_and_upload(w_down, hidden_size * intermediate_size)?;
use trueno_gpu::kernels::dequantize_nf4;
let dequant_transpose_upload = |q: &trueno_gpu::kernels::Nf4Quantized,
n: usize,
k: usize|
-> std::result::Result<
GpuBuffer<f32>,
crate::autograd::cuda_tensor::CudaTensorError,
> {
            let deq = dequantize_nf4(q);
            let nonzero = deq.iter().filter(|&&x| x != 0.0).count();
eprintln!(
"[TRACE] dequant n={n} k={k} len={} nonzero={nonzero} first5={:?}",
deq.len(),
&deq[..5.min(deq.len())]
);
assert_eq!(deq.len(), n * k, "dequant size mismatch: {} vs {}x{}", deq.len(), n, k);
let mut transposed = vec![0.0f32; n * k];
for row in 0..n {
for col in 0..k {
transposed[col * n + row] = deq[row * k + col];
}
}
let buf = GpuBuffer::from_host(&ctx, &transposed).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"dequant transpose upload: {e:?}"
))
})?;
let mut verify_full = vec![0.0f32; buf.len()];
let verify_ok = buf.copy_to_host(&mut verify_full).is_ok();
let verify5: Vec<f32> = verify_full.iter().copied().take(5).collect();
let nz = verify_full.iter().filter(|&&x| x != 0.0).count();
eprintln!("[TRACE] uploaded ptr={:?} len={} copy_ok={verify_ok} nonzero={nz} verify[:5]={verify5:?}", buf.as_ptr(), buf.len());
Ok(buf)
};
let w_q_fp32 = dequant_transpose_upload(&w_q_nf4_q, q_dim, hidden_size)?;
let w_k_fp32 = dequant_transpose_upload(&w_k_nf4_q, kv_hidden_size, hidden_size)?;
let w_v_fp32 = dequant_transpose_upload(&w_v_nf4_q, kv_hidden_size, hidden_size)?;
let w_o_fp32 = dequant_transpose_upload(&w_o_nf4_q, hidden_size, q_dim)?;
let w_gate_fp32 = dequant_transpose_upload(&w_gate_nf4_q, intermediate_size, hidden_size)?;
let w_up_fp32 = dequant_transpose_upload(&w_up_nf4_q, intermediate_size, hidden_size)?;
let w_down_fp32 = dequant_transpose_upload(&w_down_nf4_q, hidden_size, intermediate_size)?;
let (lora_a_q, lora_b_q) = match q_lora {
Some((a_data, b_data)) => {
let a = GpuBuffer::from_host(&ctx, a_data)?;
let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
(Some(a), Some(b))
}
None => (None, None),
};
let (lora_a_v, lora_b_v) = match v_lora {
Some((a_data, b_data)) => {
let a = GpuBuffer::from_host(&ctx, a_data)?;
let scaled_b: Vec<f32> = b_data.iter().map(|&v| v * lora_scale).collect();
let b = GpuBuffer::from_host(&ctx, &scaled_b)?;
(Some(a), Some(b))
}
None => (None, None),
};
let q_norm_weight = match q_norm {
Some(w) => {
assert_eq!(
w.len(),
config.head_dim(),
"ENT-270: q_norm weight expected [head_dim={}], got [{}]",
config.head_dim(),
w.len()
);
Some(GpuBuffer::from_host(&ctx, w)?)
}
None => None,
};
let k_norm_weight = match k_norm {
Some(w) => {
assert_eq!(
w.len(),
config.head_dim(),
"ENT-270: k_norm weight expected [head_dim={}], got [{}]",
config.head_dim(),
w.len()
);
Some(GpuBuffer::from_host(&ctx, w)?)
}
None => None,
};
Ok(Self {
config: config.clone(),
layer_idx,
input_norm_weight,
post_attn_norm_weight,
w_q_nf4,
w_q_scales,
w_k_nf4,
w_k_scales,
w_v_nf4,
w_v_scales,
w_o_nf4,
w_o_scales,
w_gate_nf4,
w_gate_scales,
w_up_nf4,
w_up_scales,
w_down_nf4,
w_down_scales,
w_q_fp32,
w_k_fp32,
w_v_fp32,
w_o_fp32,
w_gate_fp32,
w_up_fp32,
w_down_fp32,
lora_a_q,
lora_b_q,
lora_a_v,
lora_b_v,
lora_scale,
lora_rank,
q_norm_weight,
k_norm_weight,
w_q_fp16: None,
w_k_fp16: None,
w_v_fp16: None,
w_o_fp16: None,
w_gate_fp16: None,
w_up_fp16: None,
w_down_fp16: None,
ctx,
})
}
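    /// Casts the dequantized FP32 weight copies to FP16 on-device for the
    /// FP16 GEMM path, then replaces each FP32 buffer with a one-element dummy
    /// so the full-precision copies are freed.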
pub fn set_fp16_weights(&mut self, stream: &CudaStream) -> Result<()> {
let cast_weight = |w_fp32: &GpuBuffer<f32>, ctx: &CudaContext| -> Result<GpuBuffer<u16>> {
let n = w_fp32.len();
let mut w_fp16 = GpuBuffer::<u16>::new(ctx, n)?;
cast_f32_to_f16_gpu(w_fp32, &mut w_fp16, n as u32, stream)?;
Ok(w_fp16)
};
self.w_q_fp16 = Some(cast_weight(&self.w_q_fp32, &self.ctx)?);
self.w_k_fp16 = Some(cast_weight(&self.w_k_fp32, &self.ctx)?);
self.w_v_fp16 = Some(cast_weight(&self.w_v_fp32, &self.ctx)?);
self.w_o_fp16 = Some(cast_weight(&self.w_o_fp32, &self.ctx)?);
self.w_gate_fp16 = Some(cast_weight(&self.w_gate_fp32, &self.ctx)?);
self.w_up_fp16 = Some(cast_weight(&self.w_up_fp32, &self.ctx)?);
self.w_down_fp16 = Some(cast_weight(&self.w_down_fp32, &self.ctx)?);
stream.synchronize().map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::KernelError(format!(
"FP16 weight cast sync failed: {e:?}"
))
})?;
let dummy = |ctx: &CudaContext| GpuBuffer::<f32>::new(ctx, 1).unwrap();
self.w_q_fp32 = dummy(&self.ctx);
self.w_k_fp32 = dummy(&self.ctx);
self.w_v_fp32 = dummy(&self.ctx);
self.w_o_fp32 = dummy(&self.ctx);
self.w_gate_fp32 = dummy(&self.ctx);
self.w_up_fp32 = dummy(&self.ctx);
self.w_down_fp32 = dummy(&self.ctx);
eprintln!("[FP16] Weights cast + fp32 dropped (~2.6 GB freed)");
Ok(())
}
#[rustfmt::skip]
pub(crate) fn forward(
&self,
input: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
scratch: &mut CudaBlockScratch,
) -> Result<()> {
use crate::autograd::cuda_forward::{gemm_forward, gemm_nf4_forward, gemm_nf4_tc_forward};
let hidden_size = self.config.hidden_size;
let q_dim = self.config.q_dim();
let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
let intermediate_size = self.config.intermediate_size;
scratch.prepare_causal_mask(seq_len, &self.ctx)?;
let _t = scratch.op_begin();
rms_norm_forward(
input,
&self.input_norm_weight,
&mut scratch.norm1_out,
saturating_u32(seq_len),
saturating_u32(hidden_size),
stream,
)?;
scratch.op_end(_t, OP_RMSNORM_ATTN);
static USE_NF4_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let nf4_gemm = *USE_NF4_GEMM.get_or_init(|| std::env::var("NF4_FUSED_GEMM").as_deref() == Ok("1"));
static USE_NF4_TC_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let nf4_tc_gemm = *USE_NF4_TC_GEMM.get_or_init(|| std::env::var("NF4_TC_GEMM").as_deref() == Ok("1"));
static USE_FP16_GEMM: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let fp16_gemm = *USE_FP16_GEMM.get_or_init(|| std::env::var("FP16_GEMM").as_deref() == Ok("1"));
let act_n = (seq_len * hidden_size) as u32;
if fp16_gemm && self.w_q_fp16.is_some() {
if scratch.norm1_out_f16.is_none() {
scratch.norm1_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
}
let f16_buf = scratch.norm1_out_f16.as_mut().unwrap();
cast_f32_to_f16_gpu(&scratch.norm1_out, f16_buf, act_n, stream)?;
}
        let _t = scratch.op_begin();
        if fp16_gemm && self.w_q_fp16.is_some() {
let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
gemm_f16_to_f32_forward(f16_act, self.w_q_fp16.as_ref().unwrap(), &mut scratch.q,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
} else if nf4_tc_gemm {
gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
} else if nf4_gemm {
gemm_nf4_forward(&scratch.norm1_out, &self.w_q_nf4, &self.w_q_scales, &mut scratch.q,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
} else {
gemm_forward(&scratch.norm1_out, &self.w_q_fp32, &mut scratch.q,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(q_dim), stream)?;
}
if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
let s = saturating_u32(seq_len);
let h = saturating_u32(hidden_size);
let r = saturating_u32(self.lora_rank);
let qd = saturating_u32(q_dim);
gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
gemm_forward(&scratch.lora_inter, b_q, &mut scratch.lora_temp, s, r, qd, stream)?;
cuda_add_inplace(&mut scratch.q, &scratch.lora_temp, seq_len * q_dim, stream)?;
}
if fp16_gemm && self.w_k_fp16.is_some() {
let f16_act = scratch.norm1_out_f16.as_ref().unwrap();
gemm_f16_to_f32_forward(f16_act, self.w_k_fp16.as_ref().unwrap(), &mut scratch.k,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
gemm_f16_to_f32_forward(f16_act, self.w_v_fp16.as_ref().unwrap(), &mut scratch.v,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
} else if nf4_tc_gemm {
gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_k_nf4, &self.w_k_scales, &mut scratch.k,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
gemm_nf4_tc_forward(&scratch.norm1_out, &self.w_v_nf4, &self.w_v_scales, &mut scratch.v,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
} else if nf4_gemm {
crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
&scratch.norm1_out,
&self.w_k_nf4, &self.w_k_scales,
&self.w_v_nf4, &self.w_v_scales,
&mut scratch.k, &mut scratch.v,
saturating_u32(seq_len), saturating_u32(hidden_size),
saturating_u32(kv_hidden_size), stream,
)?;
} else {
gemm_forward(&scratch.norm1_out, &self.w_k_fp32, &mut scratch.k,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
gemm_forward(&scratch.norm1_out, &self.w_v_fp32, &mut scratch.v,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(kv_hidden_size), stream)?;
}
scratch.op_end(_t, OP_QKV_GEMM);
if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
let s = saturating_u32(seq_len);
let h = saturating_u32(hidden_size);
let r = saturating_u32(self.lora_rank);
let vd = saturating_u32(kv_hidden_size);
gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
gemm_forward(&scratch.lora_inter, b_v, &mut scratch.lora_temp, s, r, vd, stream)?;
cuda_add_inplace(&mut scratch.v, &scratch.lora_temp, seq_len * kv_hidden_size, stream)?;
}
let _t = scratch.op_begin();
self.compute_attention_cuda(seq_len, stream, scratch)?;
scratch.op_end(_t, OP_ATTENTION);
let _t = scratch.op_begin();
if fp16_gemm && self.w_o_fp16.is_some() {
if scratch.attn_out_f16.is_none() {
scratch.attn_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * q_dim)?);
}
let f16_buf = scratch.attn_out_f16.as_mut().unwrap();
cast_f32_to_f16_gpu(&scratch.attn_out, f16_buf, (seq_len * q_dim) as u32, stream)?;
gemm_f16_to_f32_forward(f16_buf, self.w_o_fp16.as_ref().unwrap(), &mut scratch.o_proj_out,
saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
} else if nf4_tc_gemm {
gemm_nf4_tc_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
} else if nf4_gemm {
gemm_nf4_forward(&scratch.attn_out, &self.w_o_nf4, &self.w_o_scales, &mut scratch.o_proj_out,
saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
} else {
gemm_forward(&scratch.attn_out, &self.w_o_fp32, &mut scratch.o_proj_out,
saturating_u32(seq_len), saturating_u32(q_dim), saturating_u32(hidden_size), stream)?;
}
scratch.op_end(_t, OP_O_PROJ);
let _t = scratch.op_begin();
fused_residual_rmsnorm_forward(
input,
&scratch.o_proj_out,
&mut scratch.residual1,
&mut scratch.norm2_out,
&self.post_attn_norm_weight,
saturating_u32(seq_len),
saturating_u32(hidden_size),
stream,
)?;
scratch.op_end(_t, OP_RMSNORM_FFN);
let _t = scratch.op_begin();
if fp16_gemm && self.w_gate_fp16.is_some() {
if scratch.norm2_out_f16.is_none() {
scratch.norm2_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * hidden_size)?);
}
let f16_buf = scratch.norm2_out_f16.as_mut().unwrap();
cast_f32_to_f16_gpu(&scratch.norm2_out, f16_buf, (seq_len * hidden_size) as u32, stream)?;
gemm_f16_to_f32_forward(f16_buf, self.w_gate_fp16.as_ref().unwrap(), &mut scratch.gate_out,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
gemm_f16_to_f32_forward(f16_buf, self.w_up_fp16.as_ref().unwrap(), &mut scratch.up_out,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
} else if nf4_tc_gemm {
gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_gate_nf4, &self.w_gate_scales, &mut scratch.gate_out,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
gemm_nf4_tc_forward(&scratch.norm2_out, &self.w_up_nf4, &self.w_up_scales, &mut scratch.up_out,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
} else if nf4_gemm {
crate::autograd::cuda_forward::gemm_nf4_gate_up_forward(
&scratch.norm2_out,
&self.w_gate_nf4, &self.w_gate_scales,
&self.w_up_nf4, &self.w_up_scales,
&mut scratch.gate_out, &mut scratch.up_out,
saturating_u32(seq_len), saturating_u32(hidden_size),
saturating_u32(intermediate_size), stream,
)?;
} else {
gemm_forward(&scratch.norm2_out, &self.w_gate_fp32, &mut scratch.gate_out,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
gemm_forward(&scratch.norm2_out, &self.w_up_fp32, &mut scratch.up_out,
saturating_u32(seq_len), saturating_u32(hidden_size), saturating_u32(intermediate_size), stream)?;
}
scratch.op_end(_t, OP_GATE_UP_GEMM);
let _t = scratch.op_begin();
fused_swiglu_forward(&scratch.gate_out, &scratch.up_out, &mut scratch.swiglu_out,
saturating_u32(seq_len * intermediate_size), stream)?;
scratch.op_end(_t, OP_SILU);
let _t = scratch.op_begin();
if fp16_gemm && self.w_down_fp16.is_some() {
if scratch.swiglu_out_f16.is_none() {
scratch.swiglu_out_f16 = Some(GpuBuffer::new(&self.ctx, seq_len * intermediate_size)?);
}
let f16_buf = scratch.swiglu_out_f16.as_mut().unwrap();
cast_f32_to_f16_gpu(&scratch.swiglu_out, f16_buf, (seq_len * intermediate_size) as u32, stream)?;
gemm_f16_to_f32_forward(f16_buf, self.w_down_fp16.as_ref().unwrap(), &mut scratch.ffn_out,
saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
} else if nf4_tc_gemm {
gemm_nf4_tc_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
} else if nf4_gemm {
gemm_nf4_forward(&scratch.swiglu_out, &self.w_down_nf4, &self.w_down_scales, &mut scratch.ffn_out,
saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
} else {
gemm_forward(&scratch.swiglu_out, &self.w_down_fp32, &mut scratch.ffn_out,
saturating_u32(seq_len), saturating_u32(intermediate_size), saturating_u32(hidden_size), stream)?;
}
scratch.op_end(_t, OP_DOWN_GEMM);
cuda_add(&scratch.residual1, &scratch.ffn_out, output, seq_len * hidden_size, stream)?;
Ok(())
}
pub fn layer_idx(&self) -> usize {
self.layer_idx
}
}
#[cfg(feature = "cuda")]
impl CudaNf4TransformerBlock {
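/// Scaled dot-product attention over the Q/K/V scratch buffers: optional
/// per-head RMSNorm on Q and K, NeoX-style RoPE, GQA head expansion, Q.K^T,
/// causal masking, softmax, and the weighted sum with V. The result is written
/// back to `scratch.attn_out` in the interleaved layout consumed by the output
/// projection.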
fn compute_attention_cuda(
&self,
seq_len: usize,
stream: &CudaStream,
scratch: &mut CudaBlockScratch,
) -> Result<()> {
let num_heads = self.config.num_attention_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = self.config.head_dim();
let heads_per_kv = num_heads / num_kv_heads;
let s = saturating_u32(seq_len);
let nh = saturating_u32(num_heads);
let nkv = saturating_u32(num_kv_heads);
let hd = saturating_u32(head_dim);
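// Optional per-head RMSNorm on Q and K (one launch per position) for models
// that provide q_norm/k_norm weights. The raw-pointer reborrow gives the
// kernel a read view of the same buffer it writes in place.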
if let Some(ref q_norm) = self.q_norm_weight {
for pos in 0..seq_len {
let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
per_head_rmsnorm_forward(q_ref, q_norm, &mut scratch.q, nh, hd, pos, stream)?;
}
}
if let Some(ref k_norm) = self.k_norm_weight {
for pos in 0..seq_len {
let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
per_head_rmsnorm_forward(k_ref, k_norm, &mut scratch.k, nkv, hd, pos, stream)?;
}
}
let rope_theta = self.config.rope_theta;
{
let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
batched_rope_neox_forward(
q_ref,
&mut scratch.q,
&scratch.rope_positions,
nh,
hd,
s,
rope_theta,
stream,
)?;
let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
batched_rope_neox_forward(
k_ref,
&mut scratch.k,
&scratch.rope_positions,
nkv,
hd,
s,
rope_theta,
stream,
)?;
}
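// Convert Q/K from interleaved to per-head batched layout; with grouped-query
// attention (heads_per_kv > 1) the K heads are replicated so each query head
// has a matching K slice before the batched Q.K^T GEMM below.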
interleaved_to_batched_forward(&scratch.q, &mut scratch.attn_q_batched, s, nh, hd, stream)?;
interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
if heads_per_kv > 1 {
expand_kv_heads(
&scratch.attn_kv_temp,
&mut scratch.attn_kv_temp2,
num_kv_heads,
heads_per_kv,
seq_len * head_dim,
stream,
)?;
} else {
unsafe {
scratch
.attn_kv_temp2
.copy_from_buffer_async(&scratch.attn_kv_temp, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"K copy failed: {e:?}"
))
})?;
}
}
batched_transpose_forward(
&scratch.attn_kv_temp2,
&mut scratch.attn_kv_temp,
nh,
s,
hd,
stream,
)?;
batched_4d_gemm_forward(
&scratch.attn_q_batched,
&scratch.attn_kv_temp,
&mut scratch.attn_scores,
1,
nh,
s,
s,
hd,
stream,
)?;
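// Scale the raw scores by 1/sqrt(head_dim). `from_raw_parts` creates a second
// view over the same device allocation so the kernel can read and write in
// place; `leak` keeps the alias from freeing memory it does not own.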
let scale_factor = 1.0 / (head_dim as f32).sqrt();
let total_scores = num_heads * seq_len * seq_len;
let scores_view = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_scores.as_ptr(),
scratch.attn_scores.len(),
)
};
scale_forward(
&scores_view,
&mut scratch.attn_scores,
scale_factor,
saturating_u32(total_scores),
stream,
)?;
leak(scores_view);
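// Apply the precomputed causal mask to each head's [seq, seq] score tile by
// offsetting raw device pointers (4 bytes per f32) and reusing the
// residual-add kernel for the elementwise add.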
{
let seq_sq = seq_len * seq_len;
let mask_ptr = scratch.causal_mask_contiguous.as_ptr();
let scores_base = scratch.attn_scores.as_ptr();
for head in 0..num_heads {
let byte_offset = (head * seq_sq * 4) as u64;
let head_ptr = scores_base + byte_offset;
let mask_view = unsafe { GpuBuffer::<f32>::from_raw_parts(mask_ptr, seq_sq) };
let scores_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
let mut out_view = unsafe { GpuBuffer::<f32>::from_raw_parts(head_ptr, seq_sq) };
residual_add_forward(
&mask_view,
&scores_view,
&mut out_view,
saturating_u32(seq_sq),
stream,
)?;
leak(mask_view);
leak(scores_view);
leak(out_view);
}
}
let scores_view = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_scores.as_ptr(),
scratch.attn_scores.len(),
)
};
batched_softmax_forward(
&scores_view,
&mut scratch.attn_scores,
saturating_u32(num_heads * seq_len),
s,
stream,
)?;
leak(scores_view);
interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
if heads_per_kv > 1 {
expand_kv_heads(
&scratch.attn_kv_temp,
&mut scratch.attn_kv_temp2,
num_kv_heads,
heads_per_kv,
seq_len * head_dim,
stream,
)?;
} else {
unsafe {
scratch
.attn_kv_temp2
.copy_from_buffer_async(&scratch.attn_kv_temp, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"V copy failed: {e:?}"
))
})?;
}
}
batched_4d_gemm_forward(
&scratch.attn_scores,
&scratch.attn_kv_temp2,
&mut scratch.attn_q_batched,
1,
nh,
s,
hd,
s,
stream,
)?;
batched_to_interleaved_forward(
&scratch.attn_q_batched,
&mut scratch.attn_out,
s,
nh,
hd,
stream,
)?;
Ok(())
}
}
#[cfg(feature = "cuda")]
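/// GPU-resident gradient buffers for the trainable parameters of one layer:
/// the LoRA A/B matrices on the Q and V projections and the two RMSNorm
/// weight vectors.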
pub(crate) struct CudaLoraGradWorkspace {
pub(crate) grad_lora_a_q: GpuBuffer<f32>,
pub(crate) grad_lora_b_q: GpuBuffer<f32>,
pub(crate) grad_lora_a_v: GpuBuffer<f32>,
pub(crate) grad_lora_b_v: GpuBuffer<f32>,
pub(crate) grad_input_norm: GpuBuffer<f32>,
pub(crate) grad_post_attn_norm: GpuBuffer<f32>,
}
#[cfg(feature = "cuda")]
impl CudaLoraGradWorkspace {
pub(crate) fn new(
ctx: &Arc<CudaContext>,
config: &super::config::TransformerConfig,
lora_rank: usize,
) -> Result<Self> {
let h = config.hidden_size;
let q_dim = config.q_dim();
let kv = config.num_kv_heads * config.head_dim();
let r = lora_rank;
Ok(Self {
grad_lora_a_q: GpuBuffer::new(ctx, h * r)?,
grad_lora_b_q: GpuBuffer::new(ctx, r * q_dim)?,
grad_lora_a_v: GpuBuffer::new(ctx, h * r)?,
grad_lora_b_v: GpuBuffer::new(ctx, r * kv)?,
grad_input_norm: GpuBuffer::new(ctx, h)?,
grad_post_attn_norm: GpuBuffer::new(ctx, h)?,
})
}
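/// Global-norm gradient clipping across all six buffers: if the combined L2
/// norm exceeds `max_norm`, every buffer is scaled by max_norm / (norm + 1e-6).
/// Failed reductions are treated as zero so clipping never aborts a step.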
pub(crate) fn clip_gradients(&mut self, max_norm: f32, stream: &CudaStream) {
let sq_a_q = squared_sum_cuda(&self.grad_lora_a_q, self.grad_lora_a_q.len() as u32, stream)
.unwrap_or(0.0);
let sq_b_q = squared_sum_cuda(&self.grad_lora_b_q, self.grad_lora_b_q.len() as u32, stream)
.unwrap_or(0.0);
let sq_a_v = squared_sum_cuda(&self.grad_lora_a_v, self.grad_lora_a_v.len() as u32, stream)
.unwrap_or(0.0);
let sq_b_v = squared_sum_cuda(&self.grad_lora_b_v, self.grad_lora_b_v.len() as u32, stream)
.unwrap_or(0.0);
let sq_in =
squared_sum_cuda(&self.grad_input_norm, self.grad_input_norm.len() as u32, stream)
.unwrap_or(0.0);
let sq_pa = squared_sum_cuda(
&self.grad_post_attn_norm,
self.grad_post_attn_norm.len() as u32,
stream,
)
.unwrap_or(0.0);
let total_norm = (sq_a_q + sq_b_q + sq_a_v + sq_b_v + sq_in + sq_pa).sqrt();
if total_norm <= max_norm {
return;
}
let clip_scale = max_norm / (total_norm + 1e-6);
let n_aq = self.grad_lora_a_q.len() as u32;
let n_bq = self.grad_lora_b_q.len() as u32;
let n_av = self.grad_lora_a_v.len() as u32;
let n_bv = self.grad_lora_b_v.len() as u32;
let n_in = self.grad_input_norm.len() as u32;
let n_pa = self.grad_post_attn_norm.len() as u32;
let _ = gradient_clip_cuda(&mut self.grad_lora_a_q, clip_scale, n_aq, stream);
let _ = gradient_clip_cuda(&mut self.grad_lora_b_q, clip_scale, n_bq, stream);
let _ = gradient_clip_cuda(&mut self.grad_lora_a_v, clip_scale, n_av, stream);
let _ = gradient_clip_cuda(&mut self.grad_lora_b_v, clip_scale, n_bv, stream);
let _ = gradient_clip_cuda(&mut self.grad_input_norm, clip_scale, n_in, stream);
let _ = gradient_clip_cuda(&mut self.grad_post_attn_norm, clip_scale, n_pa, stream);
}
}
#[cfg(feature = "cuda")]
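/// AdamW first/second moment (m/v) buffers, one pair per trainable parameter
/// tracked in `CudaLoraGradWorkspace`, allocated zero-initialized on the GPU.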
pub(crate) struct GpuLoraOptimizerState {
m_lora_a_q: GpuBuffer<f32>,
v_lora_a_q: GpuBuffer<f32>,
m_lora_b_q: GpuBuffer<f32>,
v_lora_b_q: GpuBuffer<f32>,
m_lora_a_v: GpuBuffer<f32>,
v_lora_a_v: GpuBuffer<f32>,
m_lora_b_v: GpuBuffer<f32>,
v_lora_b_v: GpuBuffer<f32>,
m_input_norm: GpuBuffer<f32>,
v_input_norm: GpuBuffer<f32>,
m_post_attn_norm: GpuBuffer<f32>,
v_post_attn_norm: GpuBuffer<f32>,
}
#[cfg(feature = "cuda")]
impl GpuLoraOptimizerState {
fn new(
ctx: &Arc<CudaContext>,
config: &super::config::TransformerConfig,
lora_rank: usize,
) -> Result<Self> {
let h = config.hidden_size;
let q_dim = config.q_dim();
let kv = config.num_kv_heads * config.head_dim();
let r = lora_rank;
let z = |n: usize| -> Result<GpuBuffer<f32>> {
Ok(GpuBuffer::from_host(ctx, &vec![0.0f32; n])?)
};
Ok(Self {
m_lora_a_q: z(h * r)?,
v_lora_a_q: z(h * r)?,
m_lora_b_q: z(r * q_dim)?,
v_lora_b_q: z(r * q_dim)?,
m_lora_a_v: z(h * r)?,
v_lora_a_v: z(h * r)?,
m_lora_b_v: z(r * kv)?,
v_lora_b_v: z(r * kv)?,
m_input_norm: z(h)?,
v_input_norm: z(h)?,
m_post_attn_norm: z(h)?,
v_post_attn_norm: z(h)?,
})
}
}
#[cfg(feature = "cuda")]
impl CudaNf4TransformerBlock {
#[allow(clippy::too_many_arguments)]
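/// Layer backward pass with activation checkpointing: the forward pass is
/// recomputed into `scratch` first, then gradients are propagated back through
/// the FFN, the post-attention RMSNorm, the attention block, and the input
/// RMSNorm. LoRA and norm-weight gradients are written into `grad_lora`;
/// `grad_input` receives the gradient flowing back toward the previous layer.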
pub(crate) fn backward(
&self,
layer_input: &GpuBuffer<f32>,
grad_output: &GpuBuffer<f32>,
grad_input: &mut GpuBuffer<f32>,
output_scratch: &mut GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
scratch: &mut CudaBlockScratch,
grad_lora: &mut CudaLoraGradWorkspace,
) -> Result<()> {
let hidden_size = self.config.hidden_size;
let _q_dim = self.config.q_dim();
let _kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
let intermediate_size = self.config.intermediate_size;
let eps = 1e-5_f32;
self.forward(layer_input, output_scratch, seq_len, stream, scratch).map_err(|e| {
eprintln!(
"[backward] Layer {} activation-checkpoint forward FAILED: {e:?}",
self.layer_idx
);
e
})?;
self.backward_nf4_ffn(
grad_output,
seq_len,
hidden_size,
intermediate_size,
stream,
scratch,
)?;
let _t = scratch.op_begin();
rms_norm_backward(
&scratch.residual1,
&self.post_attn_norm_weight,
&scratch.grad_hidden, grad_input, &mut grad_lora.grad_post_attn_norm,
saturating_u32(seq_len),
saturating_u32(hidden_size),
eps,
stream,
)?;
cuda_add_inplace(grad_input, grad_output, seq_len * hidden_size, stream)?;
self.backward_nf4_attention(
grad_input, seq_len, stream, scratch, grad_lora,
)?;
rms_norm_backward(
layer_input,
&self.input_norm_weight,
&scratch.grad_hidden, grad_input, &mut grad_lora.grad_input_norm,
saturating_u32(seq_len),
saturating_u32(hidden_size),
eps,
stream,
)?;
scratch.op_end(_t, OP_NORM_BWD);
Ok(())
}
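/// FFN backward: propagates `grad_output` through the down projection, the
/// SwiGLU nonlinearity, and the gate/up projections, leaving the gradient with
/// respect to the post-attention norm output in `scratch.grad_hidden`.
/// NF4_TC_BWD_GEMM=1 selects the tensor-core dequantizing kernels; otherwise
/// the fp16/fp32 dispatch path is used, and NF4_FUSED_BWD_GEMM=1 switches to
/// the accumulating variant that avoids one extra add.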
fn backward_nf4_ffn(
&self,
grad_output: &GpuBuffer<f32>,
seq_len: usize,
hidden_size: usize,
intermediate_size: usize,
stream: &CudaStream,
scratch: &mut CudaBlockScratch,
) -> Result<()> {
let s = saturating_u32(seq_len);
let h = saturating_u32(hidden_size);
let i_size = saturating_u32(intermediate_size);
let n_inter = saturating_u32(seq_len * intermediate_size);
static USE_NF4_TC_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let nf4_tc_bwd =
*USE_NF4_TC_BWD.get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
let _t = scratch.op_begin();
if nf4_tc_bwd {
crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
grad_output,
&self.w_down_nf4,
&self.w_down_scales,
&mut scratch.grad_swiglu,
s,
h, i_size, stream,
)?;
} else {
gemm_backward_a_fp16_dispatch(
grad_output,
self.w_down_fp16.as_ref(),
&self.w_down_fp32,
&mut scratch.grad_swiglu,
s,
i_size,
h,
stream,
&self.ctx,
)?;
}
scratch.op_end(_t, OP_DOWN_BWD);
let _t = scratch.op_begin();
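// SwiGLU backward with aggressive buffer reuse (forward computed
// swiglu = silu(gate) * up). Reading silu_backward(x, dy, dx) as
// dx = dy * silu'(x), after this block `up_out` holds d(gate) and `gate_out`
// holds d(up); the saved activations are consumed in the process.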
elementwise_mul_forward(
&scratch.grad_swiglu,
&scratch.up_out,
&mut scratch.swiglu_out,
n_inter,
stream,
)?;
silu_backward(
&scratch.gate_out,
&scratch.swiglu_out,
&mut scratch.up_out, stream,
)?;
silu_forward(&scratch.gate_out, &mut scratch.swiglu_out, n_inter, stream)?;
elementwise_mul_forward(
&scratch.grad_swiglu,
&scratch.swiglu_out,
&mut scratch.gate_out, n_inter,
stream,
)?;
scratch.op_end(_t, OP_SWIGLU_BWD);
let _t = scratch.op_begin();
static USE_FUSED_BWD: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let fused_bwd = *USE_FUSED_BWD
.get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
if nf4_tc_bwd {
crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
&scratch.gate_out, &self.w_up_nf4,
&self.w_up_scales,
&mut scratch.grad_hidden,
s,
i_size, h, stream,
)?;
crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
&scratch.up_out, &self.w_gate_nf4,
&self.w_gate_scales,
&mut scratch.ffn_out,
s,
i_size, h, stream,
)?;
cuda_add_inplace(
&mut scratch.grad_hidden,
&scratch.ffn_out,
seq_len * hidden_size,
stream,
)?;
} else if fused_bwd {
gemm_backward_a_fp16_dispatch(
&scratch.gate_out,
self.w_up_fp16.as_ref(),
&self.w_up_fp32,
&mut scratch.grad_hidden,
s,
h,
i_size,
stream,
&self.ctx,
)?;
gemm_backward_a_fp16_dispatch_accumulate(
&scratch.up_out,
self.w_gate_fp16.as_ref(),
&self.w_gate_fp32,
&mut scratch.grad_hidden,
s,
h,
i_size,
stream,
&self.ctx,
)?;
} else {
gemm_backward_a_fp16_dispatch(
&scratch.up_out,
self.w_gate_fp16.as_ref(),
&self.w_gate_fp32,
&mut scratch.ffn_out,
s,
h,
i_size,
stream,
&self.ctx,
)?;
gemm_backward_a_fp16_dispatch(
&scratch.gate_out,
self.w_up_fp16.as_ref(),
&self.w_up_fp32,
&mut scratch.grad_hidden,
s,
h,
i_size,
stream,
&self.ctx,
)?;
cuda_add_inplace(
&mut scratch.grad_hidden,
&scratch.ffn_out,
seq_len * hidden_size,
stream,
)?;
}
scratch.op_end(_t, OP_GATE_UP_BWD);
Ok(())
}
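/// Attention backward: propagates `grad_residual1` through the output
/// projection, the attention mechanism, RoPE, and the Q/K/V projections,
/// accumulating the result (the gradient w.r.t. the input-norm output) in
/// `scratch.o_proj_out` before copying it into `scratch.grad_hidden`.
/// Gradients for the Q/V LoRA adapters are computed along the way.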
fn backward_nf4_attention(
&self,
grad_residual1: &GpuBuffer<f32>,
seq_len: usize,
stream: &CudaStream,
scratch: &mut CudaBlockScratch,
grad_lora: &mut CudaLoraGradWorkspace,
) -> Result<()> {
use crate::autograd::cuda_forward::gemm_forward;
let hidden_size = self.config.hidden_size;
let q_dim = self.config.q_dim();
let kv_hidden_size = self.config.num_kv_heads * self.config.head_dim();
let num_heads = self.config.num_attention_heads;
let head_dim = self.config.head_dim();
let s = saturating_u32(seq_len);
let h = saturating_u32(hidden_size);
let qd = saturating_u32(q_dim);
let kvh = saturating_u32(kv_hidden_size);
static USE_NF4_TC_BWD_O: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let nf4_tc_bwd_o = *USE_NF4_TC_BWD_O
.get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
let _t = scratch.op_begin();
if nf4_tc_bwd_o {
crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
grad_residual1,
&self.w_o_nf4,
&self.w_o_scales,
&mut scratch.attn_out,
s,
h, qd, stream,
)?;
} else {
gemm_backward_a_fp16_dispatch(
grad_residual1,
self.w_o_fp16.as_ref(),
&self.w_o_fp32,
&mut scratch.attn_out,
s,
qd,
h,
stream,
&self.ctx,
)?;
}
self.backward_nf4_attention_mechanism(seq_len, num_heads, head_dim, stream, scratch)?;
let rope_theta = self.config.rope_theta;
let num_kv_heads = self.config.num_kv_heads;
let nkv = saturating_u32(num_kv_heads);
let nh = saturating_u32(num_heads);
let hd = saturating_u32(head_dim);
{
let q_ref = unsafe { &*(std::ptr::addr_of!(scratch.q)) };
batched_rope_neox_backward(
q_ref,
&mut scratch.q,
&scratch.rope_positions,
nh,
hd,
s,
rope_theta,
stream,
)?;
let k_ref = unsafe { &*(std::ptr::addr_of!(scratch.k)) };
batched_rope_neox_backward(
k_ref,
&mut scratch.k,
&scratch.rope_positions,
nkv,
hd,
s,
rope_theta,
stream,
)?;
}
scratch.op_end(_t, OP_ATTN_BWD);
static USE_NF4_TC_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let nf4_tc_bwd = *USE_NF4_TC_BWD_ATTN
.get_or_init(|| std::env::var("NF4_TC_BWD_GEMM").as_deref() == Ok("1"));
let _t = scratch.op_begin();
if nf4_tc_bwd {
crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
&scratch.q,
&self.w_q_nf4,
&self.w_q_scales,
&mut scratch.o_proj_out,
s,
qd, h, stream,
)?;
} else {
gemm_backward_a_fp16_dispatch(
&scratch.q,
self.w_q_fp16.as_ref(),
&self.w_q_fp32,
&mut scratch.o_proj_out,
s,
h,
qd,
stream,
&self.ctx,
)?;
}
if let (Some(a_q), Some(b_q)) = (&self.lora_a_q, &self.lora_b_q) {
let r = saturating_u32(self.lora_rank);
gemm_forward(&scratch.norm1_out, a_q, &mut scratch.lora_inter, s, h, r, stream)?;
gemm_backward_b(
&scratch.lora_inter,
&scratch.q,
&mut grad_lora.grad_lora_b_q,
s,
r,
qd,
stream,
)?;
gemm_backward_a(
&scratch.q,
b_q,
&mut scratch.lora_inter, s,
qd,
r,
stream,
)?;
gemm_backward_b(
&scratch.norm1_out,
&scratch.lora_inter,
&mut grad_lora.grad_lora_a_q,
s,
h,
r,
stream,
)?;
gemm_backward_a(
&scratch.lora_inter,
a_q,
&mut scratch.lora_temp, s,
r,
h,
stream,
)?;
cuda_add_inplace(
&mut scratch.o_proj_out,
&scratch.lora_temp,
seq_len * hidden_size,
stream,
)?;
}
static USE_FUSED_BWD_ATTN: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
let fused_bwd = *USE_FUSED_BWD_ATTN
.get_or_init(|| std::env::var("NF4_FUSED_BWD_GEMM").as_deref() == Ok("1"));
if nf4_tc_bwd {
crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
&scratch.k,
&self.w_k_nf4,
&self.w_k_scales,
&mut scratch.ffn_out,
s,
kvh, h, stream,
)?;
cuda_add_inplace(
&mut scratch.o_proj_out,
&scratch.ffn_out,
seq_len * hidden_size,
stream,
)?;
crate::autograd::cuda_forward::gemm_nf4_tc_backward_a(
&scratch.v,
&self.w_v_nf4,
&self.w_v_scales,
&mut scratch.ffn_out,
s,
kvh, h, stream,
)?;
cuda_add_inplace(
&mut scratch.o_proj_out,
&scratch.ffn_out,
seq_len * hidden_size,
stream,
)?;
} else if fused_bwd {
gemm_backward_a_fp16_dispatch_accumulate(
&scratch.k,
self.w_k_fp16.as_ref(),
&self.w_k_fp32,
&mut scratch.o_proj_out,
s,
h,
kvh,
stream,
&self.ctx,
)?;
gemm_backward_a_fp16_dispatch_accumulate(
&scratch.v,
self.w_v_fp16.as_ref(),
&self.w_v_fp32,
&mut scratch.o_proj_out,
s,
h,
kvh,
stream,
&self.ctx,
)?;
} else {
gemm_backward_a_fp16_dispatch(
&scratch.k,
self.w_k_fp16.as_ref(),
&self.w_k_fp32,
&mut scratch.ffn_out,
s,
h,
kvh,
stream,
&self.ctx,
)?;
cuda_add_inplace(
&mut scratch.o_proj_out,
&scratch.ffn_out,
seq_len * hidden_size,
stream,
)?;
gemm_backward_a_fp16_dispatch(
&scratch.v,
self.w_v_fp16.as_ref(),
&self.w_v_fp32,
&mut scratch.ffn_out,
s,
h,
kvh,
stream,
&self.ctx,
)?;
cuda_add_inplace(
&mut scratch.o_proj_out,
&scratch.ffn_out,
seq_len * hidden_size,
stream,
)?;
}
if let (Some(a_v), Some(b_v)) = (&self.lora_a_v, &self.lora_b_v) {
let r = saturating_u32(self.lora_rank);
gemm_forward(&scratch.norm1_out, a_v, &mut scratch.lora_inter, s, h, r, stream)?;
gemm_backward_b(
&scratch.lora_inter,
&scratch.v,
&mut grad_lora.grad_lora_b_v,
s,
r,
kvh,
stream,
)?;
gemm_backward_a(&scratch.v, b_v, &mut scratch.lora_inter, s, kvh, r, stream)?;
gemm_backward_b(
&scratch.norm1_out,
&scratch.lora_inter,
&mut grad_lora.grad_lora_a_v,
s,
h,
r,
stream,
)?;
gemm_backward_a(&scratch.lora_inter, a_v, &mut scratch.lora_temp, s, r, h, stream)?;
cuda_add_inplace(
&mut scratch.o_proj_out,
&scratch.lora_temp,
seq_len * hidden_size,
stream,
)?;
}
scratch.op_end(_t, OP_QKV_BWD);
unsafe {
scratch.grad_hidden.copy_from_buffer_async(&scratch.o_proj_out, stream).map_err(
|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"grad_norm1 copy failed: {e}"
))
},
)?;
}
Ok(())
}
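/// Backward through the attention mechanism itself, using the checkpointed
/// probabilities in `scratch.attn_scores` and the saved Q/K/V: computes the
/// score gradients and dV, applies the softmax backward and the
/// 1/sqrt(head_dim) rescale, then forms dQ and dK, writing the results back
/// into `scratch.q`, `scratch.k` and `scratch.v` in interleaved layout.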
fn backward_nf4_attention_mechanism(
&self,
seq_len: usize,
num_heads: usize,
head_dim: usize,
stream: &CudaStream,
scratch: &mut CudaBlockScratch,
) -> Result<()> {
let num_kv_heads = self.config.num_kv_heads;
let heads_per_kv = num_heads / num_kv_heads;
let s = saturating_u32(seq_len);
let nh = saturating_u32(num_heads);
let nkv = saturating_u32(num_kv_heads);
let hd = saturating_u32(head_dim);
let scale = 1.0 / (head_dim as f32).sqrt();
interleaved_to_batched_forward(
&scratch.attn_out,
&mut scratch.attn_q_batched, s,
nh,
hd,
stream,
)?;
interleaved_to_batched_forward(&scratch.v, &mut scratch.attn_kv_temp, s, nkv, hd, stream)?;
if heads_per_kv > 1 {
expand_kv_heads(
&scratch.attn_kv_temp,
&mut scratch.attn_kv_temp2,
num_kv_heads,
heads_per_kv,
seq_len * head_dim,
stream,
)?;
} else {
unsafe {
scratch
.attn_kv_temp2
.copy_from_buffer_async(&scratch.attn_kv_temp, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"V copy for attn backward: {e:?}"
))
})?;
}
}
batched_transpose_forward(
&scratch.attn_kv_temp2,
&mut scratch.attn_kv_temp, nh,
s,
hd,
stream,
)?;
batched_4d_gemm_forward(
&scratch.attn_q_batched,
&scratch.attn_kv_temp,
&mut scratch.grad_attn_scores,
1,
nh,
s,
s,
hd,
stream,
)?;
batched_transpose_forward(
&scratch.attn_q_batched,
&mut scratch.attn_kv_temp, nh,
s,
hd,
stream,
)?;
batched_4d_gemm_forward(
&scratch.attn_kv_temp,
&scratch.attn_scores, &mut scratch.attn_kv_temp2, 1,
nh,
hd,
s,
s,
stream,
)?;
batched_transpose_forward(
&scratch.attn_kv_temp2,
&mut scratch.attn_kv_temp, nh,
hd,
s,
stream,
)?;
let total_rows = nh * s;
{
let grad_scores_view = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.grad_attn_scores.as_ptr(),
scratch.grad_attn_scores.len(),
)
};
batched_softmax_backward(
&scratch.attn_scores,
&grad_scores_view,
&mut scratch.grad_attn_scores,
total_rows,
s,
stream,
)?;
leak(grad_scores_view);
}
let total_scores = saturating_u32(num_heads * seq_len * seq_len);
{
let scores_view = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.grad_attn_scores.as_ptr(),
scratch.grad_attn_scores.len(),
)
};
scale_forward(
&scores_view,
&mut scratch.grad_attn_scores,
scale,
total_scores,
stream,
)?;
leak(scores_view);
}
interleaved_to_batched_forward(&scratch.k, &mut scratch.attn_kv_temp2, s, nkv, hd, stream)?;
if heads_per_kv > 1 {
unsafe {
scratch
.attn_q_batched
.copy_from_buffer_async(&scratch.attn_kv_temp2, stream)
.map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"K copy for GQA expand: {e}"
))
})?;
}
expand_kv_heads(
&scratch.attn_q_batched,
&mut scratch.attn_kv_temp2,
num_kv_heads,
heads_per_kv,
seq_len * head_dim,
stream,
)?;
}
batched_4d_gemm_forward(
&scratch.grad_attn_scores,
&scratch.attn_kv_temp2,
&mut scratch.attn_q_batched,
1,
nh,
s,
hd,
s,
stream,
)?;
interleaved_to_batched_forward(
&scratch.q,
&mut scratch.o_proj_out, s,
nh,
hd,
stream,
)?;
batched_transpose_forward(
&scratch.o_proj_out,
&mut scratch.attn_kv_temp2, nh,
s,
hd,
stream,
)?;
batched_4d_gemm_forward(
&scratch.attn_kv_temp2,
&scratch.grad_attn_scores,
&mut scratch.ffn_out, 1,
nh,
hd,
s,
s,
stream,
)?;
batched_transpose_forward(
&scratch.ffn_out,
&mut scratch.attn_kv_temp2, nh,
hd,
s,
stream,
)?;
if heads_per_kv > 1 {
self.reduce_gqa_gradients_nf4(
num_kv_heads,
heads_per_kv,
seq_len,
head_dim,
stream,
scratch,
)?;
}
batched_to_interleaved_forward(&scratch.attn_q_batched, &mut scratch.q, s, nh, hd, stream)?;
batched_to_interleaved_forward(&scratch.attn_kv_temp2, &mut scratch.k, s, nkv, hd, stream)?;
batched_to_interleaved_forward(&scratch.attn_kv_temp, &mut scratch.v, s, nkv, hd, stream)?;
Ok(())
}
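/// For grouped-query attention, sums the per-query-head K and V gradients of
/// each KV group into the group's leading slot so the buffers can then be read
/// back with `num_kv_heads` heads.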
fn reduce_gqa_gradients_nf4(
&self,
num_kv_heads: usize,
heads_per_kv: usize,
seq_len: usize,
head_dim: usize,
stream: &CudaStream,
scratch: &mut CudaBlockScratch,
) -> Result<()> {
let chunk = seq_len * head_dim;
for g in 0..num_kv_heads {
let dst_off = g * chunk;
let src_off = g * heads_per_kv * chunk;
{
let src = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_kv_temp2.as_ptr() + (src_off * 4) as u64,
chunk,
)
};
let mut dst = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
chunk,
)
};
if src_off != dst_off {
unsafe {
dst.copy_from_buffer_async(&src, stream).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"GQA K reduce copy: {e}"
))
})?;
}
}
leak(src);
leak(dst);
}
for h in 1..heads_per_kv {
let add_off = (g * heads_per_kv + h) * chunk;
let src = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_kv_temp2.as_ptr() + (add_off * 4) as u64,
chunk,
)
};
let mut dst = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_kv_temp2.as_ptr() + (dst_off * 4) as u64,
chunk,
)
};
cuda_add_inplace(&mut dst, &src, chunk, stream)?;
leak(src);
leak(dst);
}
{
let src = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_kv_temp.as_ptr() + (src_off * 4) as u64,
chunk,
)
};
let mut dst = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
chunk,
)
};
if src_off != dst_off {
unsafe {
dst.copy_from_buffer_async(&src, stream).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"GQA V reduce copy: {e}"
))
})?;
}
}
leak(src);
leak(dst);
}
for h in 1..heads_per_kv {
let add_off = (g * heads_per_kv + h) * chunk;
let src = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_kv_temp.as_ptr() + (add_off * 4) as u64,
chunk,
)
};
let mut dst = unsafe {
GpuBuffer::<f32>::from_raw_parts(
scratch.attn_kv_temp.as_ptr() + (dst_off * 4) as u64,
chunk,
)
};
cuda_add_inplace(&mut dst, &src, chunk, stream)?;
leak(src);
leak(dst);
}
}
Ok(())
}
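/// Allocates zeroed AdamW moment buffers matching this layer's LoRA and norm
/// parameter shapes.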
pub(crate) fn init_lora_optimizer_state(&self) -> Result<GpuLoraOptimizerState> {
GpuLoraOptimizerState::new(&self.ctx, &self.config, self.lora_rank)
}
#[allow(clippy::too_many_arguments)]
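/// Applies one AdamW update (step `step`) to every LoRA matrix that is present
/// and to both RMSNorm weight vectors, using the gradients in `grad_lora` and
/// the per-parameter moment buffers in `state`.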
pub(crate) fn lora_optimizer_step(
&mut self,
state: &mut GpuLoraOptimizerState,
step: u32,
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
stream: &CudaStream,
grad_lora: &CudaLoraGradWorkspace,
) -> Result<()> {
let h = self.config.hidden_size;
let q_dim = self.config.q_dim();
let kv = self.config.num_kv_heads * self.config.head_dim();
let r = self.lora_rank;
if let Some(ref mut a_q) = self.lora_a_q {
adamw_step_cuda(
a_q,
&grad_lora.grad_lora_a_q,
&mut state.m_lora_a_q,
&mut state.v_lora_a_q,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
saturating_u32(h * r),
stream,
)?;
}
if let Some(ref mut b_q) = self.lora_b_q {
adamw_step_cuda(
b_q,
&grad_lora.grad_lora_b_q,
&mut state.m_lora_b_q,
&mut state.v_lora_b_q,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
saturating_u32(r * q_dim),
stream,
)?;
}
if let Some(ref mut a_v) = self.lora_a_v {
adamw_step_cuda(
a_v,
&grad_lora.grad_lora_a_v,
&mut state.m_lora_a_v,
&mut state.v_lora_a_v,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
saturating_u32(h * r),
stream,
)?;
}
if let Some(ref mut b_v) = self.lora_b_v {
adamw_step_cuda(
b_v,
&grad_lora.grad_lora_b_v,
&mut state.m_lora_b_v,
&mut state.v_lora_b_v,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
saturating_u32(r * kv),
stream,
)?;
}
adamw_step_cuda(
&mut self.input_norm_weight,
&grad_lora.grad_input_norm,
&mut state.m_input_norm,
&mut state.v_input_norm,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
saturating_u32(h),
stream,
)?;
adamw_step_cuda(
&mut self.post_attn_norm_weight,
&grad_lora.grad_post_attn_norm,
&mut state.m_post_attn_norm,
&mut state.v_post_attn_norm,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
saturating_u32(h),
stream,
)?;
Ok(())
}
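/// Copies the four LoRA matrices (A/B for the Q and V projections) from the
/// GPU to host memory, returning empty vectors for adapters that are not
/// allocated.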
pub fn download_lora_weights(&self) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)> {
let download = |buf: &GpuBuffer<f32>| -> Result<Vec<f32>> {
let mut host = vec![0.0f32; buf.len()];
buf.copy_to_host(&mut host).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"LoRA weight download failed: {e}"
))
})?;
Ok(host)
};
let a_q = self.lora_a_q.as_ref().map(&download).transpose()?.unwrap_or_default();
let b_q = self.lora_b_q.as_ref().map(&download).transpose()?.unwrap_or_default();
let a_v = self.lora_a_v.as_ref().map(&download).transpose()?.unwrap_or_default();
let b_v = self.lora_b_v.as_ref().map(&download).transpose()?.unwrap_or_default();
Ok((a_q, b_q, a_v, b_v))
}
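/// Uploads checkpointed LoRA weights into the existing GPU buffers, verifying
/// that each slice length matches the allocated buffer before copying.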
pub fn upload_lora_weights(
&mut self,
a_q: &[f32],
b_q: &[f32],
a_v: &[f32],
b_v: &[f32],
) -> Result<()> {
let upload = |buf: &mut GpuBuffer<f32>, data: &[f32], name: &str| -> Result<()> {
if data.len() != buf.len() {
return Err(crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(
format!(
"LoRA {name} size mismatch: checkpoint has {} but GPU buffer expects {}",
data.len(),
buf.len()
),
));
}
buf.copy_from_host(data).map_err(|e| {
crate::autograd::cuda_tensor::CudaTensorError::TransferFailed(format!(
"LoRA {name} upload failed: {e}"
))
})
};
if let Some(ref mut buf) = self.lora_a_q {
upload(buf, a_q, "a_q")?;
}
if let Some(ref mut buf) = self.lora_b_q {
upload(buf, b_q, "b_q")?;
}
if let Some(ref mut buf) = self.lora_a_v {
upload(buf, a_v, "a_v")?;
}
if let Some(ref mut buf) = self.lora_b_v {
upload(buf, b_v, "b_v")?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
#[test]
fn test_cuda_block_compiles() {
#[cfg(feature = "cuda")]
{
use super::*;
let _ = std::mem::size_of::<CudaTransformerBlock>();
let _ = std::mem::size_of::<CudaNf4TransformerBlock>();
}
}
}